From 3ec5928a83d78558a01f407efdad946ebf081e52 Mon Sep 17 00:00:00 2001 From: Clifton Kaznocha Date: Thu, 8 Dec 2022 16:50:48 -0800 Subject: [PATCH] Simplify namespace (#906) Signed-off-by: Clifton Kaznocha Signed-off-by: Adrian Cole --- .../integration_test/spectest/spectest.go | 2 +- internal/wasm/call_context.go | 2 +- internal/wasm/namespace.go | 134 +++++++++----- internal/wasm/namespace_test.go | 169 +++++++++++++----- internal/wasm/store.go | 7 +- internal/wasm/store_test.go | 16 +- 6 files changed, 225 insertions(+), 105 deletions(-) diff --git a/internal/integration_test/spectest/spectest.go b/internal/integration_test/spectest/spectest.go index 3fa17f0e..5747377b 100644 --- a/internal/integration_test/spectest/spectest.go +++ b/internal/integration_test/spectest/spectest.go @@ -444,7 +444,7 @@ func Run(t *testing.T, testDataFS embed.FS, ctx context.Context, newEngine func( if src == "" { src = lastInstantiatedModuleName } - ns.AliasModule(src, c.As) + require.NoError(t, ns.AliasModule(src, c.As)) lastInstantiatedModuleName = c.As case "assert_return", "action": moduleName := lastInstantiatedModuleName diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index 2c012c11..d48a2502 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -95,7 +95,7 @@ func (m *CallContext) CloseWithExitCode(ctx context.Context, exitCode uint32) er if !closed { return nil } - m.ns.deleteModule(m.Name()) + _ = m.ns.deleteModule(m.Name()) if m.CodeCloser == nil { return err } diff --git a/internal/wasm/namespace.go b/internal/wasm/namespace.go index 029ddaa5..747410c5 100644 --- a/internal/wasm/namespace.go +++ b/internal/wasm/namespace.go @@ -2,92 +2,125 @@ package wasm import ( "context" + "errors" "fmt" "sync" + "sync/atomic" "github.com/tetratelabs/wazero/api" ) -// nameListNode is a node in a doubly linked list of names. -type nameListNode struct { +// moduleListNode is a node in a doubly linked list of names. +type moduleListNode struct { name string - next, prev *nameListNode + module *ModuleInstance + next, prev *moduleListNode } // Namespace is a collection of instantiated modules which cannot conflict on name. type Namespace struct { - // moduleNamesList ensures modules are closed in reverse initialization order. - moduleNamesList *nameListNode // guarded by mux + // moduleList ensures modules are closed in reverse initialization order. + moduleList *moduleListNode // guarded by mux - // moduleNamesSet ensures no race conditions instantiating two modules of the same name - moduleNamesSet map[string]*nameListNode // guarded by mux - - // modules holds the instantiated Wasm modules by module name from Instantiate. - modules map[string]*ModuleInstance // guarded by mux + // nameToNode holds the instantiated Wasm modules by module name from Instantiate. + // It ensures no race conditions instantiating two modules of the same name. + nameToNode map[string]*moduleListNode // guarded by mux // mux is used to guard the fields from concurrent access. mux sync.RWMutex + + // closed is the pointer used both to guard Namespace.CloseWithExitCode. + // + // Note: Exclusively reading and updating this with atomics guarantees cross-goroutine observations. + // See /RATIONALE.md + closed *uint32 } // newNamespace returns an empty namespace. func newNamespace() *Namespace { return &Namespace{ - moduleNamesList: nil, - moduleNamesSet: map[string]*nameListNode{}, - modules: map[string]*ModuleInstance{}, + moduleList: nil, + nameToNode: map[string]*moduleListNode{}, + closed: new(uint32), } } -// addModule makes the module visible for import. -func (ns *Namespace) addModule(m *ModuleInstance) { +// setModule makes the module visible for import. +func (ns *Namespace) setModule(m *ModuleInstance) error { + if atomic.LoadUint32(ns.closed) != 0 { + return errors.New("module set on closed namespace") + } ns.mux.Lock() defer ns.mux.Unlock() - ns.modules[m.Name] = m + node, ok := ns.nameToNode[m.Name] + if !ok { + return fmt.Errorf("module[%s] name has not been required", m.Name) + } + + node.module = m + return nil } // deleteModule makes the moduleName available for instantiation again. -func (ns *Namespace) deleteModule(moduleName string) { +func (ns *Namespace) deleteModule(moduleName string) error { + if atomic.LoadUint32(ns.closed) != 0 { + return fmt.Errorf("module[%s] deleted from closed namespace", moduleName) + } ns.mux.Lock() defer ns.mux.Unlock() - node, ok := ns.moduleNamesSet[moduleName] + node, ok := ns.nameToNode[moduleName] if !ok { - return + return nil } // remove this module name - if node.prev == nil { - ns.moduleNamesList = node.next - } else { + if node.prev != nil { node.prev.next = node.next + } else { + ns.moduleList = node.next } if node.next != nil { node.next.prev = node.prev } - - delete(ns.modules, moduleName) - delete(ns.moduleNamesSet, moduleName) + delete(ns.nameToNode, moduleName) + return nil } -// module returns the module of the given name or nil if not in this namespace -func (ns *Namespace) module(moduleName string) *ModuleInstance { +// module returns the module of the given name or error if not in this namespace +func (ns *Namespace) module(moduleName string) (*ModuleInstance, error) { + if atomic.LoadUint32(ns.closed) != 0 { + return nil, fmt.Errorf("module[%s] requested from closed namespace", moduleName) + } ns.mux.RLock() defer ns.mux.RUnlock() - return ns.modules[moduleName] + node, ok := ns.nameToNode[moduleName] + if !ok { + return nil, fmt.Errorf("module[%s] not in namespace", moduleName) + } + + if node.module == nil { + return nil, fmt.Errorf("module[%s] not set in namespace", moduleName) + } + + return node.module, nil } // requireModules returns all instantiated modules whose names equal the keys in the input, or errs if any are missing. func (ns *Namespace) requireModules(moduleNames map[string]struct{}) (map[string]*ModuleInstance, error) { + if atomic.LoadUint32(ns.closed) != 0 { + return nil, errors.New("modules required from closed namespace") + } ret := make(map[string]*ModuleInstance, len(moduleNames)) ns.mux.RLock() defer ns.mux.RUnlock() for n := range moduleNames { - m, ok := ns.modules[n] + node, ok := ns.nameToNode[n] if !ok { return nil, fmt.Errorf("module[%s] not instantiated", n) } - ret[n] = m + ret[n] = node.module } return ret, nil } @@ -95,56 +128,67 @@ func (ns *Namespace) requireModules(moduleNames map[string]struct{}) (map[string // requireModuleName is a pre-flight check to reserve a module. // This must be reverted on error with deleteModule if initialization fails. func (ns *Namespace) requireModuleName(moduleName string) error { + if atomic.LoadUint32(ns.closed) != 0 { + return fmt.Errorf("module[%s] name required on closed namespace", moduleName) + } ns.mux.Lock() defer ns.mux.Unlock() - if _, ok := ns.moduleNamesSet[moduleName]; ok { + if _, ok := ns.nameToNode[moduleName]; ok { return fmt.Errorf("module[%s] has already been instantiated", moduleName) } // add the newest node to the moduleNamesList as the head. - node := &nameListNode{ + node := &moduleListNode{ name: moduleName, - next: ns.moduleNamesList, + next: ns.moduleList, } if node.next != nil { node.next.prev = node } - ns.moduleNamesList = node - ns.moduleNamesSet[moduleName] = node + ns.moduleList = node + ns.nameToNode[moduleName] = node return nil } // AliasModule aliases the instantiated module named `src` as `dst`. // // Note: This is only used for spectests. -func (ns *Namespace) AliasModule(src, dst string) { - ns.modules[dst] = ns.modules[src] +func (ns *Namespace) AliasModule(src, dst string) error { + if atomic.LoadUint32(ns.closed) != 0 { + return fmt.Errorf("module[%s] alias created on closed namespace", src) + } + ns.mux.Lock() + defer ns.mux.Unlock() + ns.nameToNode[dst] = ns.nameToNode[src] + return nil } // CloseWithExitCode implements the same method as documented on wazero.Namespace. func (ns *Namespace) CloseWithExitCode(ctx context.Context, exitCode uint32) (err error) { + if !atomic.CompareAndSwapUint32(ns.closed, 0, 1) { + return nil + } ns.mux.Lock() defer ns.mux.Unlock() // Close modules in reverse initialization order. - for node := ns.moduleNamesList; node != nil; node = node.next { + for node := ns.moduleList; node != nil; node = node.next { // If closing this module errs, proceed anyway to close the others. - if m, ok := ns.modules[node.name]; ok { + if m := node.module; m != nil { if _, e := m.CallCtx.close(ctx, exitCode); e != nil && err == nil { err = e // first error } } } - ns.moduleNamesList = nil - ns.moduleNamesSet = map[string]*nameListNode{} - ns.modules = map[string]*ModuleInstance{} + ns.moduleList = nil + ns.nameToNode = nil return } // Module implements wazero.Namespace Module func (ns *Namespace) Module(moduleName string) api.Module { - if m := ns.module(moduleName); m != nil { - return m.CallCtx - } else { + m, err := ns.module(moduleName) + if err != nil { return nil } + return m.CallCtx } diff --git a/internal/wasm/namespace_test.go b/internal/wasm/namespace_test.go index 551100c2..d0611f62 100644 --- a/internal/wasm/namespace_test.go +++ b/internal/wasm/namespace_test.go @@ -1,6 +1,7 @@ package wasm import ( + "context" "errors" "testing" @@ -11,31 +12,47 @@ import ( func Test_newNamespace(t *testing.T) { ns := newNamespace() - require.NotNil(t, ns.modules) + require.NotNil(t, ns.nameToNode) } -func TestNamespace_addModule(t *testing.T) { +func TestNamespace_setModule(t *testing.T) { ns := newNamespace() m1 := &ModuleInstance{Name: "m1"} - t.Run("adds module", func(t *testing.T) { - ns.addModule(m1) + t.Run("errors if not required", func(t *testing.T) { + require.Error(t, ns.setModule(m1)) + }) + + t.Run("adds module", func(t *testing.T) { + ns.nameToNode[m1.Name] = &moduleListNode{name: m1.Name} + require.NoError(t, ns.setModule(m1)) + require.Equal(t, map[string]*moduleListNode{m1.Name: {name: m1.Name, module: m1}}, ns.nameToNode) - require.Equal(t, map[string]*ModuleInstance{m1.Name: m1}, ns.modules) // Doesn't affect module names - require.Zero(t, len(ns.moduleNamesSet)) - require.Nil(t, ns.moduleNamesList) + require.Nil(t, ns.moduleList) }) t.Run("redundant ok", func(t *testing.T) { - ns.addModule(m1) - require.Equal(t, map[string]*ModuleInstance{m1.Name: m1}, ns.modules) + require.NoError(t, ns.setModule(m1)) + require.Equal(t, map[string]*moduleListNode{m1.Name: {name: m1.Name, module: m1}}, ns.nameToNode) + + // Doesn't affect module names + require.Nil(t, ns.moduleList) }) t.Run("adds second module", func(t *testing.T) { m2 := &ModuleInstance{Name: "m2"} - ns.addModule(m2) - require.Equal(t, map[string]*ModuleInstance{m1.Name: m1, m2.Name: m2}, ns.modules) + ns.nameToNode[m2.Name] = &moduleListNode{name: m2.Name} + require.NoError(t, ns.setModule(m2)) + require.Equal(t, map[string]*moduleListNode{m1.Name: {name: m1.Name, module: m1}, m2.Name: {name: m2.Name, module: m2}}, ns.nameToNode) + + // Doesn't affect module names + require.Nil(t, ns.moduleList) + }) + + t.Run("error on closed", func(t *testing.T) { + require.NoError(t, ns.CloseWithExitCode(context.Background(), 0)) + require.Error(t, ns.setModule(m1)) }) } @@ -43,24 +60,31 @@ func TestNamespace_deleteModule(t *testing.T) { ns, m1, m2 := newTestNamespace() t.Run("delete one module", func(t *testing.T) { - ns.deleteModule(m2.Name) + require.NoError(t, ns.deleteModule(m2.Name)) // Leaves the other module alone - require.Equal(t, map[string]*ModuleInstance{m1.Name: m1}, ns.modules) - require.Equal(t, map[string]*nameListNode{m1.Name: {name: m1.Name}}, ns.moduleNamesSet) - require.Equal(t, &nameListNode{name: m1.Name}, ns.moduleNamesList) + m1Node := &moduleListNode{name: m1.Name, module: m1} + require.Equal(t, map[string]*moduleListNode{m1.Name: m1Node}, ns.nameToNode) + require.Equal(t, m1Node, ns.moduleList) }) t.Run("ok if missing", func(t *testing.T) { - ns.deleteModule(m2.Name) + require.NoError(t, ns.deleteModule(m2.Name)) }) t.Run("delete last module", func(t *testing.T) { - ns.deleteModule(m1.Name) + require.NoError(t, ns.deleteModule(m1.Name)) - require.Zero(t, len(ns.modules)) - require.Zero(t, len(ns.moduleNamesSet)) - require.Nil(t, ns.moduleNamesList) + require.Zero(t, len(ns.nameToNode)) + require.Nil(t, ns.moduleList) + }) + + t.Run("error on closed", func(t *testing.T) { + require.NoError(t, ns.CloseWithExitCode(context.Background(), 0)) + require.Error(t, ns.deleteModule(m1.Name)) + + require.Zero(t, len(ns.nameToNode)) + require.Nil(t, ns.moduleList) }) } @@ -68,11 +92,29 @@ func TestNamespace_module(t *testing.T) { ns, m1, _ := newTestNamespace() t.Run("ok", func(t *testing.T) { - require.Equal(t, m1, ns.module(m1.Name)) + got, err := ns.module(m1.Name) + require.NoError(t, err) + require.Equal(t, m1, got) }) t.Run("unknown", func(t *testing.T) { - require.Nil(t, ns.module("unknown")) + got, err := ns.module("unknown") + require.Error(t, err) + require.Nil(t, got) + }) + + t.Run("not set", func(t *testing.T) { + ns.nameToNode["not set"] = &moduleListNode{name: "not set"} + got, err := ns.module("not set") + require.Error(t, err) + require.Nil(t, got) + }) + + t.Run("namespace closed", func(t *testing.T) { + require.NoError(t, ns.CloseWithExitCode(context.Background(), 0)) + got, err := ns.module(m1.Name) + require.Error(t, err) + require.Nil(t, got) }) } @@ -90,47 +132,65 @@ func TestNamespace_requireModules(t *testing.T) { _, err := ns.requireModules(map[string]struct{}{"unknown": {}}) require.EqualError(t, err, "module[unknown] not instantiated") }) + t.Run("namespace closed", func(t *testing.T) { + ns, _, _ := newTestNamespace() + require.NoError(t, ns.CloseWithExitCode(context.Background(), 0)) + + _, err := ns.requireModules(map[string]struct{}{"unknown": {}}) + require.Error(t, err) + }) } func TestNamespace_requireModuleName(t *testing.T) { - ns := &Namespace{moduleNamesSet: map[string]*nameListNode{}} + ns := &Namespace{nameToNode: map[string]*moduleListNode{}, closed: new(uint32)} t.Run("first", func(t *testing.T) { err := ns.requireModuleName("m1") require.NoError(t, err) // Ensure it adds the module name, and doesn't impact the module list. - require.Equal(t, &nameListNode{name: "m1"}, ns.moduleNamesList) - require.Equal(t, map[string]*nameListNode{"m1": {name: "m1"}}, ns.moduleNamesSet) - require.Zero(t, len(ns.modules)) + require.Equal(t, &moduleListNode{name: "m1"}, ns.moduleList) + require.Equal(t, map[string]*moduleListNode{"m1": {name: "m1"}}, ns.nameToNode) }) t.Run("second", func(t *testing.T) { err := ns.requireModuleName("m2") require.NoError(t, err) - m2Node := &nameListNode{name: "m2"} - m1Node := &nameListNode{name: "m1", prev: m2Node} + m2Node := &moduleListNode{name: "m2"} + m1Node := &moduleListNode{name: "m1", prev: m2Node} m2Node.next = m1Node // Appends in order. - require.Equal(t, m2Node, ns.moduleNamesList) - require.Equal(t, map[string]*nameListNode{"m1": m1Node, "m2": m2Node}, ns.moduleNamesSet) + require.Equal(t, m2Node, ns.moduleList) + require.Equal(t, map[string]*moduleListNode{"m1": m1Node, "m2": m2Node}, ns.nameToNode) }) t.Run("existing", func(t *testing.T) { err := ns.requireModuleName("m2") require.EqualError(t, err, "module[m2] has already been instantiated") }) + t.Run("namespace closed", func(t *testing.T) { + require.NoError(t, ns.CloseWithExitCode(context.Background(), 0)) + require.Error(t, ns.requireModuleName("m3")) + }) } func TestNamespace_AliasModule(t *testing.T) { ns := newNamespace() m1 := &ModuleInstance{Name: "m1"} - ns.addModule(m1) + ns.nameToNode[m1.Name] = &moduleListNode{name: m1.Name, module: m1} - ns.AliasModule("m1", "m2") - require.Equal(t, map[string]*ModuleInstance{"m1": m1, "m2": m1}, ns.modules) - // Doesn't affect module names - require.Zero(t, len(ns.moduleNamesSet)) - require.Nil(t, ns.moduleNamesList) + t.Run("alias module", func(t *testing.T) { + require.NoError(t, ns.AliasModule("m1", "m2")) + m1node := &moduleListNode{name: "m1", module: m1} + require.Equal(t, map[string]*moduleListNode{"m1": m1node, "m2": m1node}, ns.nameToNode) + // Doesn't affect module names + require.Nil(t, ns.moduleList) + }) + t.Run("namespace closed", func(t *testing.T) { + require.NoError(t, ns.CloseWithExitCode(context.Background(), 0)) + require.Error(t, ns.AliasModule("m3", "m4")) + require.Nil(t, ns.nameToNode) + require.Nil(t, ns.moduleList) + }) } func TestNamespace_CloseWithExitCode(t *testing.T) { @@ -166,9 +226,8 @@ func TestNamespace_CloseWithExitCode(t *testing.T) { require.Equal(t, uint64(1)+uint64(2)<<32, *m2.CallCtx.closed) // Namespace state zeroed - require.Zero(t, len(ns.modules)) - require.Zero(t, len(ns.moduleNamesSet)) - require.Nil(t, ns.moduleNamesList) + require.Zero(t, len(ns.nameToNode)) + require.Nil(t, ns.moduleList) }) } @@ -192,9 +251,23 @@ func TestNamespace_CloseWithExitCode(t *testing.T) { require.Equal(t, uint64(1)+uint64(2)<<32, *m2.CallCtx.closed) // Namespace state zeroed - require.Zero(t, len(ns.modules)) - require.Zero(t, len(ns.moduleNamesSet)) - require.Nil(t, ns.moduleNamesList) + require.Zero(t, len(ns.nameToNode)) + require.Nil(t, ns.moduleList) + }) + t.Run("multiple closes", func(t *testing.T) { + ns, m1, m2 := newTestNamespace() + + require.NoError(t, ns.CloseWithExitCode(testCtx, 2)) + + // Both modules were closed + require.Equal(t, uint64(1)+uint64(2)<<32, *m1.CallCtx.closed) + require.Equal(t, uint64(1)+uint64(2)<<32, *m2.CallCtx.closed) + + // Namespace state zeroed + require.Zero(t, len(ns.nameToNode)) + require.Nil(t, ns.moduleList) + + require.NoError(t, ns.CloseWithExitCode(testCtx, 2)) }) } @@ -212,17 +285,17 @@ func TestNamespace_Module(t *testing.T) { // newTestNamespace sets up a new Namespace without adding test coverage its functions. func newTestNamespace() (*Namespace, *ModuleInstance, *ModuleInstance) { - ns := &Namespace{} + ns := &Namespace{closed: new(uint32)} m1 := &ModuleInstance{Name: "m1"} m1.CallCtx = NewCallContext(ns, m1, nil) m2 := &ModuleInstance{Name: "m2"} m2.CallCtx = NewCallContext(ns, m2, nil) - ns.modules = map[string]*ModuleInstance{m1.Name: m1, m2.Name: m2} - node1 := &nameListNode{name: m1.Name} - node2 := &nameListNode{name: m2.Name, next: node1} - ns.moduleNamesSet = map[string]*nameListNode{m1.Name: node1, m2.Name: node2} - ns.moduleNamesList = node2 + node1 := &moduleListNode{name: m1.Name, module: m1} + node2 := &moduleListNode{name: m2.Name, module: m2, next: node1} + node1.prev = node2 + ns.nameToNode = map[string]*moduleListNode{m1.Name: node1, m2.Name: node2} + ns.moduleList = node2 return ns, m1, m2 } diff --git a/internal/wasm/store.go b/internal/wasm/store.go index c779ec20..a2a79b50 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -322,12 +322,15 @@ func (s *Store) Instantiate( // Instantiate the module and add it to the namespace so that other modules can import it. if callCtx, err := s.instantiate(ctx, ns, module, name, sys, importedModules); err != nil { - ns.deleteModule(name) + _ = ns.deleteModule(name) return nil, err } else { // Now that the instantiation is complete without error, add it. // This makes the module visible for import, and ensures it is closed when the namespace is. - ns.addModule(callCtx.module) + if err := ns.setModule(callCtx.module); err != nil { + callCtx.Close(ctx) + return nil, err + } return callCtx, nil } } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index c6849da9..aedef24d 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -111,8 +111,8 @@ func TestStore_Instantiate(t *testing.T) { defer mod.Close(testCtx) t.Run("CallContext defaults", func(t *testing.T) { - require.Equal(t, ns.modules[""], mod.module) - require.Equal(t, ns.modules[""].Memory, mod.memory) + require.Equal(t, ns.nameToNode[""].module, mod.module) + require.Equal(t, ns.nameToNode[""].module.Memory, mod.memory) require.Equal(t, ns, mod.ns) require.Equal(t, sysCtx, mod.Sys) }) @@ -168,7 +168,7 @@ func TestStore_CloseWithExitCode(t *testing.T) { require.NoError(t, err) // If Namespace.CloseWithExitCode was dispatched properly, modules should be empty - require.Zero(t, len(ns.modules)) + require.Nil(t, ns.moduleList) // Store state zeroed require.Zero(t, len(s.namespaces)) @@ -187,7 +187,7 @@ func TestStore_hammer(t *testing.T) { imported, err := s.Instantiate(testCtx, ns, m, importedModuleName, nil) require.NoError(t, err) - _, ok := ns.modules[imported.Name()] + _, ok := ns.nameToNode[imported.Name()] require.True(t, ok) importingModule := &Module{ @@ -227,7 +227,7 @@ func TestStore_hammer(t *testing.T) { require.NoError(t, imported.Close(testCtx)) // All instances are freed. - require.Zero(t, len(ns.modules)) + require.Nil(t, ns.moduleList) } func TestStore_Instantiate_Errors(t *testing.T) { @@ -252,7 +252,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(testCtx, ns, m, importedModuleName, nil) require.NoError(t, err) - hm := ns.modules[importedModuleName] + hm := ns.nameToNode[importedModuleName] require.NotNil(t, hm) _, err = s.Instantiate(testCtx, ns, &Module{ @@ -273,7 +273,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(testCtx, ns, m, importedModuleName, nil) require.NoError(t, err) - hm := ns.modules[importedModuleName] + hm := ns.nameToNode[importedModuleName] require.NotNil(t, hm) engine := s.Engine.(*mockEngine) @@ -304,7 +304,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(testCtx, ns, m, importedModuleName, nil) require.NoError(t, err) - hm := ns.modules[importedModuleName] + hm := ns.nameToNode[importedModuleName] require.NotNil(t, hm) startFuncIndex := uint32(1)