diff --git a/internal/wasm/module_instance.go b/internal/wasm/module_instance.go index a12f5b6d..61fb31e8 100644 --- a/internal/wasm/module_instance.go +++ b/internal/wasm/module_instance.go @@ -102,7 +102,7 @@ func (m *ModuleInstance) CloseWithExitCode(ctx context.Context, exitCode uint32) if !m.setExitCode(exitCode, exitCodeFlagResourceClosed) { return nil // not an error to have already closed } - _ = m.s.deleteModule(m.moduleListNode) + _ = m.s.deleteModule(m) return m.ensureResourcesClosed(ctx) } @@ -110,7 +110,7 @@ func (m *ModuleInstance) closeWithExitCodeWithoutClosingResource(exitCode uint32 if !m.setExitCode(exitCode, exitCodeFlagResourceNotClosed) { return nil // not an error to have already closed } - _ = m.s.deleteModule(m.moduleListNode) + _ = m.s.deleteModule(m) return nil } diff --git a/internal/wasm/store.go b/internal/wasm/store.go index c56c134e..df6fa04e 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -26,11 +26,11 @@ type ( // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#store%E2%91%A0 Store struct { // moduleList ensures modules are closed in reverse initialization order. - moduleList *moduleListNode // guarded by mux + moduleList *ModuleInstance // guarded by mux - // nameToNode holds the instantiated Wasm modules by module name from Instantiate. + // nameToModule holds the instantiated Wasm modules by module name from Instantiate. // It ensures no race conditions instantiating two modules of the same name. - nameToNode map[string]*moduleListNode // guarded by mux + nameToModule map[string]*ModuleInstance // guarded by mux // EnabledFeatures are read-only to allow optimizations. EnabledFeatures api.CoreFeatures @@ -92,8 +92,6 @@ type ( // or external objects (unimplemented). ElementInstances []ElementInstance - moduleListNode *moduleListNode - // Sys is exposed for use in special imports such as WASI, assemblyscript // and gojs. // @@ -110,6 +108,8 @@ type ( // s is the Store on which this module is instantiated. s *Store + // prev and next hold the nodes in the linked list of ModuleInstance held by Store. + prev, next *ModuleInstance // Definitions is derived from *Module, and is constructed during compilation phrase. Definitions []FunctionDefinition } @@ -258,7 +258,7 @@ func NewStore(enabledFeatures api.CoreFeatures, engine Engine) *Store { typeIDs[k] = v } return &Store{ - nameToNode: map[string]*moduleListNode{}, + nameToModule: map[string]*ModuleInstance{}, EnabledFeatures: enabledFeatures, Engine: engine, typeIDs: typeIDs, @@ -294,33 +294,16 @@ func (s *Store) Instantiate( return nil, err } - var listNode *moduleListNode - if name == "" { - listNode = s.registerAnonymous() - } else { - // Write-Lock the store and claim the name of the current module. - listNode, err = s.requireModuleName(name) - if err != nil { - return nil, err - } - } - // Instantiate the module and add it to the store so that other modules can import it. m, err := s.instantiate(ctx, module, name, sys, importedModules, typeIDs) if err != nil { - _ = s.deleteModule(listNode) return nil, err } - m.moduleListNode = listNode - - if name != "" { - // Now that the instantiation is complete without error, add it. - // This makes the module visible for import, and ensures it is closed when the store is. - if err := s.setModule(m); err != nil { - m.Close(ctx) - return nil, err - } + // Now that the instantiation is complete without error, add it. + if err = s.registerModule(m); err != nil { + _ = m.Close(ctx) + return nil, err } return m, nil } @@ -638,17 +621,15 @@ func (s *Store) CloseWithExitCode(ctx context.Context, exitCode uint32) (err err s.mux.Lock() defer s.mux.Unlock() // Close modules in reverse initialization order. - for node := s.moduleList; node != nil; node = node.next { + for m := s.moduleList; m != nil; m = m.next { // If closing this module errs, proceed anyway to close the others. - if m := node.module; m != nil { - if e := m.closeWithExitCode(ctx, exitCode); e != nil && err == nil { - // TODO: use multiple errors handling in Go 1.20. - err = e // first error - } + if e := m.closeWithExitCode(ctx, exitCode); e != nil && err == nil { + // TODO: use multiple errors handling in Go 1.20. + err = e // first error } } s.moduleList = nil - s.nameToNode = nil + s.nameToModule = nil s.typeIDs = nil return } diff --git a/internal/wasm/store_module_list.go b/internal/wasm/store_module_list.go index 3291198a..7f39a708 100644 --- a/internal/wasm/store_module_list.go +++ b/internal/wasm/store_module_list.go @@ -1,58 +1,34 @@ package wasm import ( + "errors" "fmt" "github.com/tetratelabs/wazero/api" ) -// moduleListNode is a node in a doubly linked list of names. -type moduleListNode struct { - name string - module *ModuleInstance - next, prev *moduleListNode -} - -// setModule makes the module visible for import. -func (s *Store) setModule(m *ModuleInstance) error { - s.mux.Lock() - defer s.mux.Unlock() - - node, ok := s.nameToNode[m.ModuleName] - if !ok { - return fmt.Errorf("module[%s] name has not been required", m.ModuleName) - } - - node.module = m - return nil -} - // deleteModule makes the moduleName available for instantiation again. -func (s *Store) deleteModule(node *moduleListNode) error { - if node == nil { - return nil - } - +func (s *Store) deleteModule(m *ModuleInstance) error { s.mux.Lock() defer s.mux.Unlock() - // remove this module name - if node.prev != nil { - node.prev.next = node.next + // Remove this module name. + if m.prev != nil { + m.prev.next = m.next } - if node.next != nil { - node.next.prev = node.prev + if m.next != nil { + m.next.prev = m.prev } - if s.moduleList == node { - s.moduleList = node.next + if s.moduleList == m { + s.moduleList = m.next } - // clear the node state so it does not enter any other branch - // on subsequent calls to deleteModule - node.prev = nil - node.next = nil + // Clear the m state so it does not enter any other branch + // on subsequent calls to deleteModule. + m.prev = nil + m.next = nil - if node.name != "" { - delete(s.nameToNode, node.name) + if m.ModuleName != "" { + delete(s.nameToModule, m.ModuleName) } return nil } @@ -61,16 +37,15 @@ func (s *Store) deleteModule(node *moduleListNode) error { func (s *Store) module(moduleName string) (*ModuleInstance, error) { s.mux.RLock() defer s.mux.RUnlock() - node, ok := s.nameToNode[moduleName] + m, ok := s.nameToModule[moduleName] if !ok { return nil, fmt.Errorf("module[%s] not in store", moduleName) } - if node.module == nil { + if m == nil { return nil, fmt.Errorf("module[%s] not set in store", moduleName) } - - return node.module, nil + return m, nil } // requireModules returns all instantiated modules whose names equal the keys in the input, or errs if any are missing. @@ -81,50 +56,39 @@ func (s *Store) requireModules(moduleNames map[string]struct{}) (map[string]*Mod defer s.mux.RUnlock() for n := range moduleNames { - node, ok := s.nameToNode[n] + module, ok := s.nameToModule[n] if !ok { return nil, fmt.Errorf("module[%s] not instantiated", n) } - ret[n] = node.module + ret[n] = module } return ret, nil } -// requireModuleName is a pre-flight check to reserve a module. -// This must be reverted on error with deleteModule if initialization fails. -func (s *Store) requireModuleName(moduleName string) (*moduleListNode, error) { - node := &moduleListNode{name: moduleName} - - s.mux.Lock() - defer s.mux.Unlock() - if _, ok := s.nameToNode[moduleName]; ok { - return nil, fmt.Errorf("module[%s] has already been instantiated", moduleName) - } - - // add the newest node to the moduleNamesList as the head. - node.next = s.moduleList - if node.next != nil { - node.next.prev = node - } - s.moduleList = node - s.nameToNode[moduleName] = node - return node, nil -} - -func (s *Store) registerAnonymous() *moduleListNode { - node := &moduleListNode{name: ""} - +// registerModule registers a ModuleInstance into the store. +// This makes the ModuleInstance visible for import if it's not anonymous, and ensures it is closed when the store is. +func (s *Store) registerModule(m *ModuleInstance) error { s.mux.Lock() defer s.mux.Unlock() - // add the newest node to the moduleNamesList as the head. - node.next = s.moduleList - if node.next != nil { - node.next.prev = node + if s.nameToModule == nil { + return errors.New("already closed") } - s.moduleList = node - return node + if m.ModuleName != "" { + if _, ok := s.nameToModule[m.ModuleName]; ok { + return fmt.Errorf("module[%s] has already been instantiated", m.ModuleName) + } + s.nameToModule[m.ModuleName] = m + } + + // Add the newest node to the moduleNamesList as the head. + m.next = s.moduleList + if m.next != nil { + m.next.prev = m + } + s.moduleList = m + return nil } // AliasModule aliases the instantiated module named `src` as `dst`. @@ -133,7 +97,7 @@ func (s *Store) registerAnonymous() *moduleListNode { func (s *Store) AliasModule(src, dst string) error { s.mux.Lock() defer s.mux.Unlock() - s.nameToNode[dst] = s.nameToNode[src] + s.nameToModule[dst] = s.nameToModule[src] return nil } diff --git a/internal/wasm/store_module_list_test.go b/internal/wasm/store_module_list_test.go index e2858cdc..b5ca500c 100644 --- a/internal/wasm/store_module_list_test.go +++ b/internal/wasm/store_module_list_test.go @@ -7,44 +7,31 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) -func TestStore_setModule(t *testing.T) { +func TestStore_registerModule(t *testing.T) { s := newStore() m1 := &ModuleInstance{ModuleName: "m1"} - t.Run("errors if not required", func(t *testing.T) { - require.Error(t, s.setModule(m1)) - }) - t.Run("adds module", func(t *testing.T) { - s.nameToNode[m1.ModuleName] = &moduleListNode{name: m1.ModuleName} - require.NoError(t, s.setModule(m1)) - require.Equal(t, map[string]*moduleListNode{m1.ModuleName: {name: m1.ModuleName, module: m1}}, s.nameToNode) - - // Doesn't affect module names - require.Nil(t, s.moduleList) - }) - - t.Run("redundant ok", func(t *testing.T) { - require.NoError(t, s.setModule(m1)) - require.Equal(t, map[string]*moduleListNode{m1.ModuleName: {name: m1.ModuleName, module: m1}}, s.nameToNode) - - // Doesn't affect module names - require.Nil(t, s.moduleList) + require.NoError(t, s.registerModule(m1)) + require.Equal(t, map[string]*ModuleInstance{m1.ModuleName: m1}, s.nameToModule) + require.Equal(t, m1, s.moduleList) }) t.Run("adds second module", func(t *testing.T) { m2 := &ModuleInstance{ModuleName: "m2"} - s.nameToNode[m2.ModuleName] = &moduleListNode{name: m2.ModuleName} - require.NoError(t, s.setModule(m2)) - require.Equal(t, map[string]*moduleListNode{m1.ModuleName: {name: m1.ModuleName, module: m1}, m2.ModuleName: {name: m2.ModuleName, module: m2}}, s.nameToNode) + require.NoError(t, s.registerModule(m2)) + require.Equal(t, map[string]*ModuleInstance{m1.ModuleName: m1, m2.ModuleName: m2}, s.nameToModule) + require.Equal(t, m2, s.moduleList) + }) - // Doesn't affect module names - require.Nil(t, s.moduleList) + t.Run("error on duplicated non anonymous", func(t *testing.T) { + m1Second := &ModuleInstance{ModuleName: "m1"} + require.EqualError(t, s.registerModule(m1Second), "module[m1] has already been instantiated") }) t.Run("error on closed", func(t *testing.T) { require.NoError(t, s.CloseWithExitCode(context.Background(), 0)) - require.Error(t, s.setModule(m1)) + require.Error(t, s.registerModule(m1)) }) } @@ -52,24 +39,44 @@ func TestStore_deleteModule(t *testing.T) { s, m1, m2 := newTestStore() t.Run("delete one module", func(t *testing.T) { - require.NoError(t, s.deleteModule(m2.moduleListNode)) + require.NoError(t, s.deleteModule(m2)) - // Leaves the other module alone - m1Node := &moduleListNode{name: m1.ModuleName, module: m1} - require.Equal(t, map[string]*moduleListNode{m1.ModuleName: m1Node}, s.nameToNode) - require.Equal(t, m1Node, s.moduleList) + // Leaves the other module alone. + require.Equal(t, map[string]*ModuleInstance{m1.ModuleName: m1}, s.nameToModule) + require.Equal(t, m1, s.moduleList) }) t.Run("ok if missing", func(t *testing.T) { - require.NoError(t, s.deleteModule(m2.moduleListNode)) + require.NoError(t, s.deleteModule(m2)) }) t.Run("delete last module", func(t *testing.T) { - require.NoError(t, s.deleteModule(m1.moduleListNode)) + require.NoError(t, s.deleteModule(m1)) - require.Zero(t, len(s.nameToNode)) + require.Zero(t, len(s.nameToModule)) require.Nil(t, s.moduleList) }) + + t.Run("delete middle", func(t *testing.T) { + s := newStore() + one, two, three := &ModuleInstance{ModuleName: "1"}, &ModuleInstance{ModuleName: "2"}, &ModuleInstance{ModuleName: "3"} + require.NoError(t, s.registerModule(one)) + require.NoError(t, s.registerModule(two)) + require.NoError(t, s.registerModule(three)) + require.Equal(t, three, s.moduleList) + require.Nil(t, three.prev) + require.Equal(t, two, three.next) + require.Equal(t, two.prev, three) + require.Equal(t, one, two.next) + require.Equal(t, one.prev, two) + require.Nil(t, one.next) + require.NoError(t, s.deleteModule(two)) + require.Equal(t, three, s.moduleList) + require.Nil(t, three.prev) + require.Equal(t, one, three.next) + require.Equal(t, one.prev, three) + require.Nil(t, one.next) + }) } func TestStore_module(t *testing.T) { @@ -87,13 +94,6 @@ func TestStore_module(t *testing.T) { require.Nil(t, got) }) - t.Run("not set", func(t *testing.T) { - s.nameToNode["not set"] = &moduleListNode{name: "not set"} - got, err := s.module("not set") - require.Error(t, err) - require.Nil(t, got) - }) - t.Run("store closed", func(t *testing.T) { require.NoError(t, s.CloseWithExitCode(context.Background(), 0)) got, err := s.module(m1.ModuleName) @@ -125,44 +125,14 @@ func TestStore_requireModules(t *testing.T) { }) } -func TestStore_requireModuleName(t *testing.T) { - s := newStore() - - t.Run("first", func(t *testing.T) { - _, err := s.requireModuleName("m1") - require.NoError(t, err) - - // Ensure it adds the module name, and doesn't impact the module list. - require.Equal(t, &moduleListNode{name: "m1"}, s.moduleList) - require.Equal(t, map[string]*moduleListNode{"m1": {name: "m1"}}, s.nameToNode) - }) - t.Run("second", func(t *testing.T) { - _, err := s.requireModuleName("m2") - require.NoError(t, err) - m2Node := &moduleListNode{name: "m2"} - m1Node := &moduleListNode{name: "m1", prev: m2Node} - m2Node.next = m1Node - - // Appends in order. - require.Equal(t, m2Node, s.moduleList) - require.Equal(t, map[string]*moduleListNode{"m1": m1Node, "m2": m2Node}, s.nameToNode) - }) - t.Run("existing", func(t *testing.T) { - _, err := s.requireModuleName("m2") - require.EqualError(t, err, "module[m2] has already been instantiated") - }) -} - func TestStore_AliasModule(t *testing.T) { s := newStore() - m1 := &ModuleInstance{ModuleName: "m1"} - s.nameToNode[m1.ModuleName] = &moduleListNode{name: m1.ModuleName, module: m1} + s.nameToModule[m1.ModuleName] = m1 t.Run("alias module", func(t *testing.T) { require.NoError(t, s.AliasModule("m1", "m2")) - m1node := &moduleListNode{name: "m1", module: m1} - require.Equal(t, map[string]*moduleListNode{"m1": m1node, "m2": m1node}, s.nameToNode) + require.Equal(t, map[string]*ModuleInstance{"m1": m1, "m2": m1}, s.nameToModule) // Doesn't affect module names require.Nil(t, s.moduleList) }) @@ -186,13 +156,9 @@ func newTestStore() (*Store, *ModuleInstance, *ModuleInstance) { m1 := &ModuleInstance{ModuleName: "m1"} m2 := &ModuleInstance{ModuleName: "m2"} - node1 := &moduleListNode{name: m1.ModuleName, module: m1} - node2 := &moduleListNode{name: m2.ModuleName, module: m2, next: node1} - node1.prev = node2 - s.nameToNode = map[string]*moduleListNode{m1.ModuleName: node1, m2.ModuleName: node2} - s.moduleList = node2 - - m1.moduleListNode = node1 - m2.moduleListNode = node2 + m1.prev = m2 + m2.next = m1 + s.nameToModule = map[string]*ModuleInstance{m1.ModuleName: m1, m2.ModuleName: m2} + s.moduleList = m2 return s, m1, m2 } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index eb8646a8..bc282733 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -114,8 +114,8 @@ func TestStore_Instantiate(t *testing.T) { defer mod.Close(testCtx) t.Run("ModuleInstance defaults", func(t *testing.T) { - require.Equal(t, s.nameToNode["bar"].module, mod) - require.Equal(t, s.nameToNode["bar"].module.MemoryInstance, mod.MemoryInstance) + require.Equal(t, s.nameToModule["bar"], mod) + require.Equal(t, s.nameToModule["bar"].MemoryInstance, mod.MemoryInstance) require.Equal(t, s, mod.s) require.Equal(t, sysCtx, mod.Sys) }) @@ -191,7 +191,7 @@ func TestStore_hammer(t *testing.T) { imported, err := s.Instantiate(testCtx, m, importedModuleName, nil, []FunctionTypeID{0}) require.NoError(t, err) - _, ok := s.nameToNode[imported.Name()] + _, ok := s.nameToModule[imported.Name()] require.True(t, ok) importingModule := &Module{ @@ -246,7 +246,7 @@ func TestStore_hammer_close(t *testing.T) { imported, err := s.Instantiate(testCtx, m, importedModuleName, nil, []FunctionTypeID{0}) require.NoError(t, err) - _, ok := s.nameToNode[imported.Name()] + _, ok := s.nameToModule[imported.Name()] require.True(t, ok) importingModule := &Module{ @@ -317,7 +317,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(testCtx, m, importedModuleName, nil, []FunctionTypeID{0}) require.NoError(t, err) - hm := s.nameToNode[importedModuleName] + hm := s.nameToModule[importedModuleName] require.NotNil(t, hm) _, err = s.Instantiate(testCtx, &Module{ @@ -338,7 +338,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(testCtx, m, importedModuleName, nil, []FunctionTypeID{0}) require.NoError(t, err) - hm := s.nameToNode[importedModuleName] + hm := s.nameToModule[importedModuleName] require.NotNil(t, hm) engine := s.Engine.(*mockEngine) @@ -370,7 +370,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(testCtx, m, importedModuleName, nil, []FunctionTypeID{0}) require.NoError(t, err) - hm := s.nameToNode[importedModuleName] + hm := s.nameToModule[importedModuleName] require.NotNil(t, hm) startFuncIndex := uint32(1)