Ensures no race condition adding a module of the same name (#344)

Signed-off-by: Adrian Cole <adrian@tetrate.io>
This commit is contained in:
Crypt Keeper
2022-03-08 08:21:09 +08:00
committed by GitHub
parent 2846939df2
commit e7afd75e25
2 changed files with 70 additions and 50 deletions

View File

@@ -94,10 +94,6 @@ func (f *exportedFunction) Call(ctx context.Context, params ...uint64) ([]uint64
// //
// TOOD(adrian): make this goroutine-safe like store.Instantiate. // TOOD(adrian): make this goroutine-safe like store.Instantiate.
func (s *Store) NewHostModule(moduleName string, nameToGoFunc map[string]interface{}) (*HostModule, error) { func (s *Store) NewHostModule(moduleName string, nameToGoFunc map[string]interface{}) (*HostModule, error) {
if err := s.requireModuleUnused(moduleName); err != nil {
return nil, err
}
exportCount := len(nameToGoFunc) exportCount := len(nameToGoFunc)
ret := &HostModule{name: moduleName, NameToFunction: make(map[string]*FunctionInstance, exportCount)} ret := &HostModule{name: moduleName, NameToFunction: make(map[string]*FunctionInstance, exportCount)}
hostModule := &ModuleInstance{ hostModule := &ModuleInstance{
@@ -106,9 +102,14 @@ func (s *Store) NewHostModule(moduleName string, nameToGoFunc map[string]interfa
hostModule: ret, hostModule: ret,
} }
if err := s.requireModuleName(moduleName); err != nil {
return nil, err
}
for name, goFunc := range nameToGoFunc { for name, goFunc := range nameToGoFunc {
hf, err := newGoFunc(goFunc) hf, err := newGoFunc(goFunc)
if err != nil { if err != nil {
s.deleteModule(moduleName)
return nil, fmt.Errorf("func[%s] %w", name, err) return nil, fmt.Errorf("func[%s] %w", name, err)
} }
@@ -124,11 +125,13 @@ func (s *Store) NewHostModule(moduleName string, nameToGoFunc map[string]interfa
ret.NameToFunction[name] = f ret.NameToFunction[name] = f
if err = s.compileHostFunction(f); err != nil { if err = s.compileHostFunction(f); err != nil {
s.deleteModule(moduleName)
return nil, err return nil, err
} }
} }
s.modules[moduleName] = hostModule // Now that the instantiation is complete without error, add it. This makes it visible for import.
s.addModule(hostModule)
return ret, nil return ret, nil
} }
@@ -149,15 +152,6 @@ func (s *Store) compileHostFunction(f *FunctionInstance) (err error) {
return nil return nil
} }
func (s *Store) requireModuleUnused(moduleName string) error {
s.mux.RLock()
defer s.mux.RUnlock()
if _, ok := s.modules[moduleName]; ok {
return fmt.Errorf("module %s has already been instantiated", moduleName)
}
return nil
}
// HostModule implements wasm.HostModule // HostModule implements wasm.HostModule
type HostModule struct { type HostModule struct {
// name is for String and Store.ReleaseModule // name is for String and Store.ReleaseModule

View File

@@ -37,8 +37,11 @@ type (
// EnabledFeatures are read-only to allow optimizations. // EnabledFeatures are read-only to allow optimizations.
EnabledFeatures Features EnabledFeatures Features
// moduleNames ensures no race conditions instantiating two modules of the same name
moduleNames map[string]struct{} // guarded by mux
// modules holds the instantiated Wasm modules by module name from Instantiate. // modules holds the instantiated Wasm modules by module name from Instantiate.
modules map[string]*ModuleInstance modules map[string]*ModuleInstance // guarded by mux
// typeIDs maps each FunctionType.String() to a unique FunctionTypeID. This is used at runtime to // typeIDs maps each FunctionType.String() to a unique FunctionTypeID. This is used at runtime to
// do type-checks on indirect function calls. // do type-checks on indirect function calls.
@@ -236,41 +239,41 @@ const (
maximumFunctionTypes = 1 << 27 maximumFunctionTypes = 1 << 27
) )
// newModuleInstance bundles all the instances for a module and creates a new module instance. // addSections adds section elements to the ModuleInstance
func newModuleInstance(name string, module *Module, importedFunctions, functions []*FunctionInstance, func (m *ModuleInstance) addSections(module *Module, importedFunctions, functions []*FunctionInstance,
importedGlobals, globals []*GlobalInstance, importedTable, table *TableInstance, importedGlobals, globals []*GlobalInstance, importedTable, table *TableInstance,
memory, importedMemory *MemoryInstance, typeInstances []*TypeInstance, moduleImports map[*ModuleInstance]struct{}) *ModuleInstance { memory, importedMemory *MemoryInstance, typeInstances []*TypeInstance, moduleImports map[*ModuleInstance]struct{}) {
instance := &ModuleInstance{Name: name, Types: typeInstances, dependencies: moduleImports} m.Types = typeInstances
m.dependencies = moduleImports
instance.Functions = append(instance.Functions, importedFunctions...) m.Functions = append(m.Functions, importedFunctions...)
for i, f := range functions { for i, f := range functions {
// Associate each function with the type instance and the module instance's pointer. // Associate each function with the type instance and the module instance's pointer.
f.Module = instance f.Module = m
f.TypeID = typeInstances[module.FunctionSection[i]].TypeID f.TypeID = typeInstances[module.FunctionSection[i]].TypeID
instance.Functions = append(instance.Functions, f) m.Functions = append(m.Functions, f)
} }
instance.Globals = append(instance.Globals, importedGlobals...) m.Globals = append(m.Globals, importedGlobals...)
instance.Globals = append(instance.Globals, globals...) m.Globals = append(m.Globals, globals...)
if importedTable != nil { if importedTable != nil {
instance.Table = importedTable m.Table = importedTable
} else { } else {
instance.Table = table m.Table = table
} }
if importedMemory != nil { if importedMemory != nil {
instance.Memory = importedMemory m.Memory = importedMemory
} else { } else {
instance.Memory = memory m.Memory = memory
} }
instance.buildExportInstances(module.ExportSection) m.buildExports(module.ExportSection)
return instance
} }
func (m *ModuleInstance) buildExportInstances(exports map[string]*Export) { func (m *ModuleInstance) buildExports(exports map[string]*Export) {
m.Exports = make(map[string]*ExportInstance, len(exports)) m.Exports = make(map[string]*ExportInstance, len(exports))
for _, exp := range exports { for _, exp := range exports {
index := exp.Index index := exp.Index
@@ -365,6 +368,7 @@ func NewStore(ctx context.Context, engine Engine, enabledFeatures Features) *Sto
ctx: ctx, ctx: ctx,
engine: engine, engine: engine,
EnabledFeatures: enabledFeatures, EnabledFeatures: enabledFeatures,
moduleNames: map[string]struct{}{},
modules: map[string]*ModuleInstance{}, modules: map[string]*ModuleInstance{},
typeIDs: map[string]FunctionTypeID{}, typeIDs: map[string]FunctionTypeID{},
maximumFunctionIndex: maximumFunctionIndex, maximumFunctionIndex: maximumFunctionIndex,
@@ -383,20 +387,23 @@ func (s *Store) checkFunctionIndexOverflow(newInstanceNum int) error {
return nil return nil
} }
// Instantiate uses name instead of the Module.NameSection ModuleName as it allows instantiating the same module under
// different names safely and concurrently.
func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error) { func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error) {
// Note: we do not take lock here in order to enable concurrent instantiation and compilation if err := s.requireModuleName(name); err != nil {
// of multiuple modules. When necessary, we take read or write locks in each method of store used here.
if err := s.requireModuleUnused(name); err != nil {
return nil, err return nil, err
} }
// Note: we do not take lock here in order to enable concurrent instantiation and compilation
// of multiple modules. When necessary, we take read or write locks in each method of store used here.
if err := s.checkFunctionIndexOverflow(len(module.FunctionSection)); err != nil { if err := s.checkFunctionIndexOverflow(len(module.FunctionSection)); err != nil {
s.deleteModule(name)
return nil, err return nil, err
} }
types, err := s.getTypes(module.TypeSection) types, err := s.getTypes(module.TypeSection)
if err != nil { if err != nil {
s.deleteModule(name)
return nil, err return nil, err
} }
@@ -410,6 +417,7 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error)
} }
}() }()
if err != nil { if err != nil {
s.deleteModule(name)
return nil, err return nil, err
} }
@@ -417,14 +425,17 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error)
module.buildFunctions(), module.buildGlobals(importedGlobals), module.buildTable(), module.buildMemory() module.buildFunctions(), module.buildGlobals(importedGlobals), module.buildTable(), module.buildMemory()
// Now we have all instances from imports and local ones, so ready to create a new ModuleInstance. // Now we have all instances from imports and local ones, so ready to create a new ModuleInstance.
instance := newModuleInstance(name, module, importedFunctions, functions, importedGlobals, instance := &ModuleInstance{Name: name}
instance.addSections(module, importedFunctions, functions, importedGlobals,
globals, importedTable, table, importedMemory, memory, types, moduleImports) globals, importedTable, table, importedMemory, memory, types, moduleImports)
if err = instance.validateElements(module.ElementSection); err != nil { if err = instance.validateElements(module.ElementSection); err != nil {
s.deleteModule(name)
return nil, err return nil, err
} }
if err := instance.validateData(module.DataSection); err != nil { if err = instance.validateData(module.DataSection); err != nil {
s.deleteModule(name)
return nil, err return nil, err
} }
@@ -432,11 +443,10 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error)
s.addFunctions(functions...) // Need to assign funcaddr to each instance before compilation. s.addFunctions(functions...) // Need to assign funcaddr to each instance before compilation.
for i, f := range functions { for i, f := range functions {
// TODO: maybe better consider spawning multiple goroutines for compilations to accelerate. // TODO: maybe better consider spawning multiple goroutines for compilations to accelerate.
if err := s.engine.Compile(f); err != nil { if err = s.engine.Compile(f); err != nil {
s.deleteModule(name)
// On the failure, release the assigned funcaddr and already compiled functions. // On the failure, release the assigned funcaddr and already compiled functions.
if err := s.releaseFunctions(functions[:i]...); err != nil { _ = s.releaseFunctions(functions[:i]...) // ignore any release error so we can report the original one.
return nil, err
}
idx := module.SectionElementCount(SectionIDFunction) - 1 idx := module.SectionElementCount(SectionIDFunction) - 1
return nil, fmt.Errorf("compilation failed at index %d/%d: %w", i, idx, err) return nil, fmt.Errorf("compilation failed at index %d/%d: %w", i, idx, err)
} }
@@ -446,8 +456,8 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error)
instance.applyElements(module.ElementSection) instance.applyElements(module.ElementSection)
instance.applyData(module.DataSection) instance.applyData(module.DataSection)
// Persist the module instance. // Build the default context for calls to this module.
s.addModule(instance) instance.Ctx = NewModuleContext(s.ctx, s.engine, instance)
// Plus, we can finalize the module import reference count. // Plus, we can finalize the module import reference count.
moduleImportsFinalized = true moduleImportsFinalized = true
@@ -455,10 +465,14 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error)
// Execute the start function. // Execute the start function.
if module.StartSection != nil { if module.StartSection != nil {
funcIdx := *module.StartSection funcIdx := *module.StartSection
if _, err := s.engine.Call(instance.Ctx, instance.Functions[funcIdx]); err != nil { if _, err = s.engine.Call(instance.Ctx, instance.Functions[funcIdx]); err != nil {
s.deleteModule(name)
return nil, fmt.Errorf("module[%s] start function failed: %w", name, err) return nil, fmt.Errorf("module[%s] start function failed: %w", name, err)
} }
} }
// Now that the instantiation is complete without error, add it. This makes it visible for import.
s.addModule(instance)
return &PublicModule{s: s, instance: instance}, nil return &PublicModule{s: s, instance: instance}, nil
} }
@@ -487,7 +501,7 @@ func (s *Store) ReleaseModule(moduleName string) error {
return fmt.Errorf("unable to release function instance: %w", err) return fmt.Errorf("unable to release function instance: %w", err)
} }
s.deleteModule(m) s.deleteModule(moduleName)
return nil return nil
} }
@@ -541,16 +555,28 @@ func (s *Store) addFunctions(fs ...*FunctionInstance) {
} }
} }
func (s *Store) deleteModule(m *ModuleInstance) { // deleteModule makes the moduleName available for instantiation again.
func (s *Store) deleteModule(moduleName string) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
delete(s.modules, m.Name) delete(s.modules, moduleName)
delete(s.moduleNames, moduleName)
} }
func (s *Store) addModule(m *ModuleInstance) { // requireModuleName is a pre-flight check to reserve a module.
// Build the default context for calls to this module. // This must be reverted on error with deleteModule if initialization fails.
m.Ctx = NewModuleContext(s.ctx, s.engine, m) func (s *Store) requireModuleName(moduleName string) error {
s.mux.Lock()
defer s.mux.Unlock()
if _, ok := s.moduleNames[moduleName]; ok {
return fmt.Errorf("module %s has already been instantiated", moduleName)
}
s.moduleNames[moduleName] = struct{}{}
return nil
}
// addModule makes the module visible for import
func (s *Store) addModule(m *ModuleInstance) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.modules[m.Name] = m s.modules[m.Name] = m