diff --git a/builder.go b/builder.go index 09062018..87e65a08 100644 --- a/builder.go +++ b/builder.go @@ -265,6 +265,8 @@ func (b *moduleBuilder) Instantiate() (api.Module, error) { if module, err := b.Build(); err != nil { return nil, err } else { + // *wasm.ModuleInstance cannot be tracked, so we release the cache inside of this function. + defer module.Close() return b.r.InstantiateModuleWithConfig(module, NewModuleConfig().WithName(b.moduleName)) } } diff --git a/config.go b/config.go index 7addb6fc..ea37b8aa 100644 --- a/config.go +++ b/config.go @@ -148,6 +148,33 @@ func (c *RuntimeConfig) WithFeatureMultiValue(enabled bool) *RuntimeConfig { // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#semantic-phases%E2%91%A0 type CompiledCode struct { module *wasm.Module + // cachedEngines maps wasm.Engine to []*wasm.Module which originate from this .module. + // This is necessary to track which engine caches *Module where latter might be different + // from .module via import replacement config (ModuleConfig.WithImport). + cachedEngines map[wasm.Engine]map[*wasm.Module]struct{} +} + +// Close releases all the allocated resources for this CompiledCode. +func (c *CompiledCode) Close() { + for engine, modules := range c.cachedEngines { + for module := range modules { + engine.ReleaseCompilationCache(module) + } + } +} + +func (c *CompiledCode) addCacheEntry(module *wasm.Module, engine wasm.Engine) { + if c.cachedEngines == nil { + c.cachedEngines = map[wasm.Engine]map[*wasm.Module]struct{}{} + } + + cache, ok := c.cachedEngines[engine] + if !ok { + cache = map[*wasm.Module]struct{}{} + c.cachedEngines[engine] = cache + } + + cache[module] = struct{}{} } // ModuleConfig configures resources needed by functions that have low-level interactions with the host operating system. diff --git a/examples/replace-import/replace-import.go b/examples/replace-import/replace-import.go index 3793e32f..2c299b3f 100644 --- a/examples/replace-import/replace-import.go +++ b/examples/replace-import/replace-import.go @@ -32,6 +32,7 @@ func main() { if err != nil { log.Fatal(err) } + defer code.Close() // Instantiate the module, replacing the import "env.abort" with "assemblyscript.abort". mod, err := r.InstantiateModuleWithConfig(code, wazero.NewModuleConfig(). diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index ebafe0db..6b828055 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -38,7 +38,7 @@ func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) t.Run("sets module name", func(t *testing.T) { - me, err := e.NewModuleEngine(t.Name(), nil, nil, nil, nil) + me, err := e.NewModuleEngine(t.Name(), nil, nil, nil, nil, nil) require.NoError(t, err) defer me.Close() require.Equal(t, t.Name(), me.Name()) @@ -61,7 +61,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { addFunction(module, "fn", fn) // Compile the module - me, err := e.NewModuleEngine(module.Name, nil, module.Functions, nil, nil) + me, err := e.NewModuleEngine(module.Name, nil, nil, module.Functions, nil, nil) require.NoError(t, err) defer me.Close() linkModuleToEngine(module, me) @@ -92,7 +92,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { var tableInit map[wasm.Index]wasm.Index // Instantiate the module, which has nothing but an empty table. - me, err := e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, table, tableInit) + me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) require.NoError(t, err) defer me.Close() @@ -112,7 +112,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { tableInit := map[wasm.Index]wasm.Index{0: 2} // Instantiate the module whose table points to its own functions. - me, err := e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, table, tableInit) + me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) require.NoError(t, err) defer me.Close() @@ -132,12 +132,12 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { tableInit := map[wasm.Index]wasm.Index{0: 2} // Imported functions are compiled before the importing module is instantiated. - imported, err := e.NewModuleEngine(t.Name(), nil, importedFunctions, nil, nil) + imported, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) require.NoError(t, err) defer imported.Close() // Instantiate the importing module, which is whose table is initialized. - importing, err := e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, table, tableInit) + importing, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) require.NoError(t, err) defer importing.Close() @@ -162,12 +162,12 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { tableInit := map[wasm.Index]wasm.Index{0: 0, 1: 4} // Imported functions are compiled before the importing module is instantiated. - imported, err := e.NewModuleEngine(t.Name(), nil, importedFunctions, nil, nil) + imported, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) require.NoError(t, err) defer imported.Close() // Instantiate the importing module, which is whose table is initialized. - importing, err := e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, table, tableInit) + importing, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) require.NoError(t, err) defer importing.Close() @@ -202,7 +202,7 @@ func runTestModuleEngine_Call_HostFn_ModuleContext(t *testing.T, et EngineTester module.Types = []*wasm.TypeInstance{{Type: f.Type}} module.Functions = []*wasm.FunctionInstance{f} - me, err := e.NewModuleEngine(t.Name(), nil, module.Functions, nil, nil) + me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, module.Functions, nil, nil) require.NoError(t, err) defer me.Close() @@ -440,7 +440,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, wasm.Mod addFunction(imported, callHostFnName, callHostFn) // Compile the imported module - importedMe, err := e.NewModuleEngine(imported.Name, nil, imported.Functions, nil, nil) + importedMe, err := e.NewModuleEngine(imported.Name, &wasm.Module{}, nil, imported.Functions, nil, nil) require.NoError(t, err) linkModuleToEngine(imported, importedMe) @@ -460,7 +460,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, wasm.Mod addFunction(importing, callImportCallHostFnName, callImportedHostFn) // Compile the importing module - importingMe, err := e.NewModuleEngine(importing.Name, importedFunctions, importing.Functions, nil, nil) + importingMe, err := e.NewModuleEngine(importing.Name, &wasm.Module{}, importedFunctions, importing.Functions, nil, nil) require.NoError(t, err) linkModuleToEngine(importing, importingMe) diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index ea1f0be6..46138fc5 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -6,6 +6,7 @@ type Engine interface { // NewModuleEngine compiles down the function instances in a module, and returns ModuleEngine for the module. // // * name is the name the module was instantiated with used for error handling. + // * module is the source module from which moduleFunctions are instantiated. This is used for caching. // * importedFunctions: functions this module imports, already compiled in this engine. // * moduleFunctions: functions declared in this module that must be compiled. // * table: a possibly shared table used by this module. When nil tableInit will be nil. @@ -13,7 +14,10 @@ type Engine interface { // // Note: Input parameters must be pre-validated with wasm.Module Validate, to ensure no fields are invalid // due to reasons such as out-of-bounds. - NewModuleEngine(name string, importedFunctions, moduleFunctions []*FunctionInstance, table *TableInstance, tableInit map[Index]Index) (ModuleEngine, error) + NewModuleEngine(name string, module *Module, importedFunctions, moduleFunctions []*FunctionInstance, table *TableInstance, tableInit map[Index]Index) (ModuleEngine, error) + + // ReleaseCompilationCache releases compilation caches for the given module (source). + ReleaseCompilationCache(module *Module) } // ModuleEngine implements function calls for a given module. diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index adcf0eee..2abafea4 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -21,18 +21,25 @@ var callStackCeiling = buildoptions.CallStackCeiling // engine is an interpreter implementation of wasm.Engine type engine struct { - enabledFeatures wasm.Features - compiledFunctions map[*wasm.FunctionInstance]*compiledFunction // guarded by mutex. - mux sync.RWMutex + enabledFeatures wasm.Features + compiledFunctions map[*wasm.FunctionInstance]*compiledFunction // guarded by mutex. + cachedCompiledFunctionsPerModule map[*wasm.Module][]*compiledFunction // guarded by mutex. + mux sync.RWMutex } func NewEngine(enabledFeatures wasm.Features) wasm.Engine { return &engine{ - enabledFeatures: enabledFeatures, - compiledFunctions: make(map[*wasm.FunctionInstance]*compiledFunction), + enabledFeatures: enabledFeatures, + compiledFunctions: make(map[*wasm.FunctionInstance]*compiledFunction), + cachedCompiledFunctionsPerModule: map[*wasm.Module][]*compiledFunction{}, } } +// ReleaseCompilationCache implements the same method as documented on wasm.Engine. +func (e *engine) ReleaseCompilationCache(m *wasm.Module) { + e.deleteCachedCompiledFunctions(m) +} + func (e *engine) deleteCompiledFunction(f *wasm.FunctionInstance) { e.mux.Lock() defer e.mux.Unlock() @@ -52,6 +59,25 @@ func (e *engine) addCompiledFunction(f *wasm.FunctionInstance, cf *compiledFunct e.compiledFunctions[f] = cf } +func (e *engine) deleteCachedCompiledFunctions(module *wasm.Module) { + e.mux.Lock() + defer e.mux.Unlock() + delete(e.cachedCompiledFunctionsPerModule, module) +} + +func (e *engine) addCachedCompiledFunctions(module *wasm.Module, fs []*compiledFunction) { + e.mux.Lock() + defer e.mux.Unlock() + e.cachedCompiledFunctionsPerModule[module] = fs +} + +func (e *engine) getCachedCompiledFunctions(module *wasm.Module) (fs []*compiledFunction, ok bool) { + e.mux.RLock() + defer e.mux.RUnlock() + fs, ok = e.cachedCompiledFunctionsPerModule[module] + return +} + // moduleEngine implements wasm.ModuleEngine type moduleEngine struct { // name is the name the module was instantiated with used for error handling. @@ -145,6 +171,15 @@ type compiledFunction struct { hostFn *reflect.Value } +func (c *compiledFunction) clone(me *moduleEngine, newSourceInstance *wasm.FunctionInstance) *compiledFunction { + return &compiledFunction{ + moduleEngine: me, + source: newSourceInstance, + body: c.body, + hostFn: c.hostFn, + } +} + // Non-interface union of all the wazeroir operations. type interpreterOp struct { kind wazeroir.OperationKind @@ -154,7 +189,7 @@ type interpreterOp struct { } // NewModuleEngine implements the same method as documented on wasm.Engine. -func (e *engine) NewModuleEngine(name string, importedFunctions, moduleFunctions []*wasm.FunctionInstance, table *wasm.TableInstance, tableInit map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { +func (e *engine) NewModuleEngine(name string, source *wasm.Module, importedFunctions, moduleFunctions []*wasm.FunctionInstance, table *wasm.TableInstance, tableInit map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { imported := uint32(len(importedFunctions)) me := &moduleEngine{ name: name, @@ -171,30 +206,39 @@ func (e *engine) NewModuleEngine(name string, importedFunctions, moduleFunctions me.compiledFunctions = append(me.compiledFunctions, cf) } - for i, f := range moduleFunctions { - var compiled *compiledFunction - if f.Kind == wasm.FunctionKindWasm { - ir, err := wazeroir.Compile(e.enabledFeatures, f) - if err != nil { - me.Close() - // TODO(Adrian): extract Module.funcDesc so that errors here have more context - return nil, fmt.Errorf("function[%d/%d] failed to lower to wazeroir: %w", i, len(moduleFunctions)-1, err) - } - - compiled, err = e.lowerIROps(f, ir.Operations) - if err != nil { - me.Close() - return nil, fmt.Errorf("function[%d/%d] failed to convert wazeroir operations: %w", i, len(moduleFunctions)-1, err) - } - } else { - compiled = &compiledFunction{hostFn: f.GoFunc, source: f} + if cached, ok := e.getCachedCompiledFunctions(source); ok { // cache hit + for i, c := range cached[len(importedFunctions):] { + cloned := c.clone(me, moduleFunctions[i]) + me.compiledFunctions = append(me.compiledFunctions, cloned) } - compiled.moduleEngine = me - me.compiledFunctions = append(me.compiledFunctions, compiled) + } else { // cache miss + for i, f := range moduleFunctions { + var compiled *compiledFunction + if f.Kind == wasm.FunctionKindWasm { + ir, err := wazeroir.Compile(e.enabledFeatures, f) + if err != nil { + me.Close() + // TODO(Adrian): extract Module.funcDesc so that errors here have more context + return nil, fmt.Errorf("function[%d/%d] failed to lower to wazeroir: %w", i, len(moduleFunctions)-1, err) + } - // Add the compiled function to the store-wide engine as well so that - // the future importing module can refer the function instance. - e.addCompiledFunction(f, compiled) + compiled, err = e.lowerIROps(f, ir.Operations) + if err != nil { + me.Close() + return nil, fmt.Errorf("function[%d/%d] failed to convert wazeroir operations: %w", i, len(moduleFunctions)-1, err) + } + } else { + compiled = &compiledFunction{hostFn: f.GoFunc, source: f} + } + compiled.moduleEngine = me + me.compiledFunctions = append(me.compiledFunctions, compiled) + + // Add the compiled function to the store-wide engine as well so that + // the future importing module can refer the function instance. + e.addCompiledFunction(f, compiled) + } + + e.addCachedCompiledFunctions(source, me.compiledFunctions) } for elemIdx, funcidx := range tableInit { // Initialize any elements with compiled functions diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index 5ec626bc..77a824db 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -5,6 +5,7 @@ import ( "math" "strconv" "testing" + "unsafe" "github.com/tetratelabs/wazero/internal/buildoptions" "github.com/tetratelabs/wazero/internal/testing/enginetest" @@ -201,6 +202,7 @@ func TestInterpreter_EngineCompile_Errors(t *testing.T) { t.Run("invalid import", func(t *testing.T) { e := et.NewEngine(wasm.Features20191205).(*engine) _, err := e.NewModuleEngine(t.Name(), + &wasm.Module{}, []*wasm.FunctionInstance{{Module: &wasm.ModuleInstance{Name: "uncompiled"}, DebugName: "uncompiled.fn"}}, nil, // moduleFunctions nil, // table @@ -220,7 +222,7 @@ func TestInterpreter_EngineCompile_Errors(t *testing.T) { } // initialize the module-engine containing imported functions - _, err := e.NewModuleEngine(t.Name(), nil, importedFunctions, nil, nil) + _, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) require.NoError(t, err) require.Len(t, e.compiledFunctions, len(importedFunctions)) @@ -233,7 +235,7 @@ func TestInterpreter_EngineCompile_Errors(t *testing.T) { }, Module: &wasm.ModuleInstance{}}, } - _, err = e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, nil, nil) + _, err = e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, nil, nil) require.EqualError(t, err, "function[2/2] failed to lower to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, all the compiled functions including succeeded ones must be released. @@ -245,11 +247,6 @@ func TestInterpreter_EngineCompile_Errors(t *testing.T) { } func TestInterpreter_Close(t *testing.T) { - newFunctionInstance := func(id int) *wasm.FunctionInstance { - return &wasm.FunctionInstance{ - DebugName: strconv.Itoa(id), Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}} - } - for _, tc := range []struct { name string importedFunctions, moduleFunctions []*wasm.FunctionInstance @@ -273,12 +270,12 @@ func TestInterpreter_Close(t *testing.T) { e := et.NewEngine(wasm.Features20191205).(*engine) if len(tc.importedFunctions) > 0 { // initialize the module-engine containing imported functions - me, err := e.NewModuleEngine(t.Name(), nil, tc.importedFunctions, nil, nil) + me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, tc.importedFunctions, nil, nil) require.NoError(t, err) require.Len(t, me.(*moduleEngine).compiledFunctions, len(tc.importedFunctions)) } - me, err := e.NewModuleEngine(t.Name(), tc.importedFunctions, tc.moduleFunctions, nil, nil) + me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, tc.importedFunctions, tc.moduleFunctions, nil, nil) require.NoError(t, err) require.Len(t, me.(*moduleEngine).compiledFunctions, len(tc.importedFunctions)+len(tc.moduleFunctions)) @@ -302,3 +299,114 @@ func TestInterpreter_Close(t *testing.T) { }) } } + +func TestEngine_CachedCompiledFunctionsPerModule(t *testing.T) { + e := et.NewEngine(wasm.Features20191205).(*engine) + exp := []*compiledFunction{ + {source: &wasm.FunctionInstance{DebugName: "1"}}, + {source: &wasm.FunctionInstance{DebugName: "2"}}, + } + m := &wasm.Module{} + + e.addCachedCompiledFunctions(m, exp) + + actual, ok := e.getCachedCompiledFunctions(m) + require.True(t, ok) + require.Len(t, actual, len(exp)) + for i := range actual { + require.Equal(t, exp[i], actual[i]) + } + + e.deleteCachedCompiledFunctions(m) + _, ok = e.getCachedCompiledFunctions(m) + require.False(t, ok) +} + +func TestEngine_NewModuleEngine_cache(t *testing.T) { + e := et.NewEngine(wasm.Features20191205).(*engine) + importedModuleSource := &wasm.Module{} + + // No cache. + importedME, err := e.NewModuleEngine("1", importedModuleSource, nil, []*wasm.FunctionInstance{ + newFunctionInstance(1), + newFunctionInstance(2), + }, nil, nil) + require.NoError(t, err) + + // Cached. + importedMEFromCache, err := e.NewModuleEngine("2", importedModuleSource, nil, []*wasm.FunctionInstance{ + newFunctionInstance(1), + newFunctionInstance(2), + }, nil, nil) + require.NoError(t, err) + + require.NotEqual(t, importedME, importedMEFromCache) + require.NotEqual(t, importedME.Name(), importedMEFromCache.Name()) + + // Check compiled functions. + ime, imeCache := importedME.(*moduleEngine), importedMEFromCache.(*moduleEngine) + require.Equal(t, len(ime.compiledFunctions), len(imeCache.compiledFunctions)) + + for i, fn := range ime.compiledFunctions { + // Compiled functions must be cloend. + fnCached := imeCache.compiledFunctions[i] + require.NotEqual(t, fn, fnCached) + require.NotEqual(t, fn.moduleEngine, fnCached.moduleEngine) + require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. + // But the body stays the same. + require.Equal(t, fn.body, fnCached.body) + } + + // Next is to veirfy the caching works for modules with imports. + importedFunc := ime.compiledFunctions[0].source + moduleSource := &wasm.Module{} + + // No cache. + modEng, err := e.NewModuleEngine("3", moduleSource, + []*wasm.FunctionInstance{importedFunc}, // Import one function. + []*wasm.FunctionInstance{ + newFunctionInstance(10), + newFunctionInstance(20), + }, nil, nil) + require.NoError(t, err) + + // Cached. + modEngCache, err := e.NewModuleEngine("4", moduleSource, + []*wasm.FunctionInstance{importedFunc}, // Import one function. + []*wasm.FunctionInstance{ + newFunctionInstance(10), + newFunctionInstance(20), + }, nil, nil) + require.NoError(t, err) + + require.NotEqual(t, modEng, modEngCache) + require.NotEqual(t, modEng.Name(), modEngCache.Name()) + + me, meCache := modEng.(*moduleEngine), modEngCache.(*moduleEngine) + require.Equal(t, len(me.compiledFunctions), len(meCache.compiledFunctions)) + + for i, fn := range me.compiledFunctions { + fnCached := meCache.compiledFunctions[i] + if i == 0 { + // This case the function is imported, so it must be the same for both module engines. + require.Equal(t, fn, fnCached) + require.Equal(t, importedFunc, fn.source) + } else { + // Compiled functions must be cloend. + require.NotEqual(t, fn, fnCached) + require.NotEqual(t, fn.moduleEngine, fnCached.moduleEngine) + require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. + // But the code segment stays the same. + require.Equal(t, fn.body, fnCached.body) + } + } +} + +func newFunctionInstance(id int) *wasm.FunctionInstance { + return &wasm.FunctionInstance{ + DebugName: strconv.Itoa(id), + Type: &wasm.FunctionType{}, + Body: []byte{wasm.OpcodeEnd}, + Module: &wasm.ModuleInstance{}, + } +} diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index 5075bbc9..bb37bc50 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -18,9 +18,10 @@ import ( type ( // engine is an JIT implementation of wasm.Engine engine struct { - enabledFeatures wasm.Features - compiledFunctions map[*wasm.FunctionInstance]*compiledFunction // guarded by mutex. - mux sync.RWMutex + enabledFeatures wasm.Features + compiledFunctions map[*wasm.FunctionInstance]*compiledFunction // guarded by mutex. + cachedCompiledFunctionsPerModule map[*wasm.Module][]*compiledFunction // guarded by mutex. + mux sync.RWMutex // setFinalizer defaults to runtime.SetFinalizer, but overridable for tests. setFinalizer func(obj interface{}, finalizer interface{}) } @@ -193,6 +194,19 @@ type ( compiledFunctionStaticData = [][]byte ) +func (c *compiledFunction) clone(newSourceInstance *wasm.FunctionInstance) *compiledFunction { + // Note: we don't need to set finalizer to munmap the code segment since it is + // already a target of munmap by the finalizer set on the original compiledFunction `c`. + return &compiledFunction{ + codeInitialAddress: c.codeInitialAddress, + stackPointerCeil: c.stackPointerCeil, + source: newSourceInstance, + moduleInstanceAddress: uintptr(unsafe.Pointer(newSourceInstance.Module)), + codeSegment: c.codeSegment, + staticData: c.staticData, + } +} + // Native code reads/writes Go's structs with the following constants. // See TestVerifyOffsetValue for how to derive these values. const ( @@ -349,8 +363,13 @@ func releaseCompiledFunction(compiledFn *compiledFunction) { } } +// ReleaseCompilationCache implements the same method as documented on wasm.Engine. +func (e *engine) ReleaseCompilationCache(module *wasm.Module) { + e.deleteCachedCompiledFunctions(module) +} + // NewModuleEngine implements the same method as documented on wasm.Engine. -func (e *engine) NewModuleEngine(name string, importedFunctions, moduleFunctions []*wasm.FunctionInstance, table *wasm.TableInstance, tableInit map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { +func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunctions, moduleFunctions []*wasm.FunctionInstance, table *wasm.TableInstance, tableInit map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { imported := uint32(len(importedFunctions)) me := &moduleEngine{ name: name, @@ -367,28 +386,37 @@ func (e *engine) NewModuleEngine(name string, importedFunctions, moduleFunctions me.compiledFunctions = append(me.compiledFunctions, cf) } - for i, f := range moduleFunctions { - var compiled *compiledFunction - var err error - if f.Kind == wasm.FunctionKindWasm { - compiled, err = compileWasmFunction(e.enabledFeatures, f) - } else { - compiled, err = compileHostFunction(f) + if cached, ok := e.getCachedCompiledFunctions(module); ok { // cache hit. + for i, c := range cached[len(importedFunctions):] { + cloned := c.clone(moduleFunctions[i]) + me.compiledFunctions = append(me.compiledFunctions, cloned) } - if err != nil { - me.Close() // safe because the reference to me was never leaked. - return nil, fmt.Errorf("function[%s(%d/%d)] %w", f.DebugName, i, len(moduleFunctions)-1, err) + } else { // cache miss. + for i, f := range moduleFunctions { + var compiled *compiledFunction + var err error + if f.Kind == wasm.FunctionKindWasm { + compiled, err = compileWasmFunction(e.enabledFeatures, f) + } else { + compiled, err = compileHostFunction(f) + } + if err != nil { + me.Close() // safe because the reference to me was never leaked. + return nil, fmt.Errorf("function[%s(%d/%d)] %w", f.DebugName, i, len(moduleFunctions)-1, err) + } + + // As this uses mmap, we need a finalizer in case moduleEngine.Close was never called. Regardless, we need a + // finalizer due to how moduleEngine.Close is implemented. + e.setFinalizer(compiled, releaseCompiledFunction) + + me.compiledFunctions = append(me.compiledFunctions, compiled) + + // Add the compiled function to the store-wide engine as well so that + // the future importing module can refer the function instance. + e.addCompiledFunction(f, compiled) } - // As this uses mmap, we need a finalizer in case moduleEngine.Close was never called. Regardless, we need a - // finalizer due to how moduleEngine.Close is implemented. - e.setFinalizer(compiled, releaseCompiledFunction) - - me.compiledFunctions = append(me.compiledFunctions, compiled) - - // Add the compiled function to the store-wide engine as well so that - // the future importing module can refer the function instance. - e.addCompiledFunction(f, compiled) + e.addCachedCompiledFunctions(module, me.compiledFunctions) } for elemIdx, funcidx := range tableInit { // Initialize any elements with compiled functions @@ -446,6 +474,24 @@ func (e *engine) getCompiledFunction(f *wasm.FunctionInstance) (cf *compiledFunc cf, ok = e.compiledFunctions[f] return } +func (e *engine) deleteCachedCompiledFunctions(module *wasm.Module) { + e.mux.Lock() + defer e.mux.Unlock() + delete(e.cachedCompiledFunctionsPerModule, module) +} + +func (e *engine) addCachedCompiledFunctions(module *wasm.Module, fs []*compiledFunction) { + e.mux.Lock() + defer e.mux.Unlock() + e.cachedCompiledFunctionsPerModule[module] = fs +} + +func (e *engine) getCachedCompiledFunctions(module *wasm.Module) (fs []*compiledFunction, ok bool) { + e.mux.RLock() + defer e.mux.RUnlock() + fs, ok = e.cachedCompiledFunctionsPerModule[module] + return +} // Name implements the same method as documented on wasm.ModuleEngine. func (me *moduleEngine) Name() string { @@ -523,9 +569,10 @@ func NewEngine(enabledFeatures wasm.Features) wasm.Engine { func newEngine(enabledFeatures wasm.Features) *engine { return &engine{ - enabledFeatures: enabledFeatures, - compiledFunctions: map[*wasm.FunctionInstance]*compiledFunction{}, - setFinalizer: runtime.SetFinalizer, + enabledFeatures: enabledFeatures, + compiledFunctions: map[*wasm.FunctionInstance]*compiledFunction{}, + cachedCompiledFunctionsPerModule: map[*wasm.Module][]*compiledFunction{}, + setFinalizer: runtime.SetFinalizer, } } @@ -785,17 +832,11 @@ func compileHostFunction(f *wasm.FunctionInstance) (*compiledFunction, error) { return nil, err } - stackPointerCeil := uint64(len(f.Type.Params)) - if res := uint64(len(f.Type.Results)); stackPointerCeil < res { - stackPointerCeil = res - } - return &compiledFunction{ source: f, codeSegment: code, codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), moduleInstanceAddress: uintptr(unsafe.Pointer(f.Module)), - stackPointerCeil: stackPointerCeil, }, nil } diff --git a/internal/wasm/jit/engine_test.go b/internal/wasm/jit/engine_test.go index 2f6fa73b..daa9dd65 100644 --- a/internal/wasm/jit/engine_test.go +++ b/internal/wasm/jit/engine_test.go @@ -152,6 +152,7 @@ func TestJIT_EngineCompile_Errors(t *testing.T) { e := et.NewEngine(wasm.Features20191205) _, err := e.NewModuleEngine( t.Name(), + &wasm.Module{}, []*wasm.FunctionInstance{{Module: &wasm.ModuleInstance{Name: "uncompiled"}, DebugName: "uncompiled.fn"}}, nil, // moduleFunctions nil, // table @@ -169,7 +170,7 @@ func TestJIT_EngineCompile_Errors(t *testing.T) { {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, } - _, err := e.NewModuleEngine(t.Name(), nil, importedFunctions, nil, nil) + _, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) require.NoError(t, err) require.Len(t, e.compiledFunctions, len(importedFunctions)) @@ -182,7 +183,7 @@ func TestJIT_EngineCompile_Errors(t *testing.T) { }, Module: &wasm.ModuleInstance{}}, } - _, err = e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, nil, nil) + _, err = e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, nil, nil) require.EqualError(t, err, "function[invalid code(2/2)] failed to lower to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, all the compiled functions including succeeded ones must be released. @@ -204,15 +205,6 @@ func (f fakeFinalizer) setFinalizer(obj interface{}, finalizer interface{}) { } func TestJIT_NewModuleEngine_CompiledFunctions(t *testing.T) { - newFunctionInstance := func(id int) *wasm.FunctionInstance { - return &wasm.FunctionInstance{ - DebugName: strconv.Itoa(id), - Type: &wasm.FunctionType{}, - Body: []byte{wasm.OpcodeEnd}, - Module: &wasm.ModuleInstance{}, - } - } - e := et.NewEngine(wasm.Features20191205).(*engine) importedFinalizer := fakeFinalizer{} @@ -222,7 +214,7 @@ func TestJIT_NewModuleEngine_CompiledFunctions(t *testing.T) { newFunctionInstance(10), newFunctionInstance(20), } - modE, err := e.NewModuleEngine(t.Name(), nil, importedFunctions, nil, nil) + modE, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) require.NoError(t, err) defer modE.Close() imported := modE.(*moduleEngine) @@ -236,7 +228,7 @@ func TestJIT_NewModuleEngine_CompiledFunctions(t *testing.T) { newFunctionInstance(300), } - modE, err = e.NewModuleEngine(t.Name(), importedFunctions, moduleFunctions, nil, nil) + modE, err = e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, nil, nil) require.NoError(t, err) defer modE.Close() importing := modE.(*moduleEngine) @@ -316,6 +308,7 @@ func TestJIT_ModuleEngine_Close(t *testing.T) { // Instantiate the imported module modEngine, err := e.NewModuleEngine( fmt.Sprintf("%s - imported functions", t.Name()), + &wasm.Module{}, nil, // moduleFunctions tc.importedFunctions, nil, // table @@ -328,6 +321,7 @@ func TestJIT_ModuleEngine_Close(t *testing.T) { importing, err := e.NewModuleEngine( fmt.Sprintf("%s - module-defined functions", t.Name()), + &wasm.Module{}, tc.importedFunctions, tc.moduleFunctions, nil, // table @@ -461,3 +455,116 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { }) } } + +// TODO: move most of this logic to enginetest.go so that there is less drift between interpreter and jit +func TestEngine_CachedCompiledFunctionsPerModule(t *testing.T) { + e := newEngine(wasm.Features20191205) + exp := []*compiledFunction{ + {source: &wasm.FunctionInstance{DebugName: "1"}}, + {source: &wasm.FunctionInstance{DebugName: "2"}}, + } + m := &wasm.Module{} + + e.addCachedCompiledFunctions(m, exp) + + actual, ok := e.getCachedCompiledFunctions(m) + require.True(t, ok) + require.Len(t, actual, len(exp)) + for i := range actual { + require.Equal(t, exp[i], actual[i]) + } + + e.deleteCachedCompiledFunctions(m) + _, ok = e.getCachedCompiledFunctions(m) + require.False(t, ok) +} + +// TODO: move most of this logic to enginetest.go so that there is less drift between interpreter and jit +func TestEngine_NewModuleEngine_cache(t *testing.T) { + e := newEngine(wasm.Features20191205) + importedModuleSource := &wasm.Module{} + + // No cache. + importedME, err := e.NewModuleEngine("1", importedModuleSource, nil, []*wasm.FunctionInstance{ + newFunctionInstance(1), + newFunctionInstance(2), + }, nil, nil) + require.NoError(t, err) + + // Cached. + importedMEFromCache, err := e.NewModuleEngine("2", importedModuleSource, nil, []*wasm.FunctionInstance{ + newFunctionInstance(1), + newFunctionInstance(2), + }, nil, nil) + require.NoError(t, err) + + require.NotEqual(t, importedME, importedMEFromCache) + require.NotEqual(t, importedME.Name(), importedMEFromCache.Name()) + + // Check compiled functions. + ime, imeCache := importedME.(*moduleEngine), importedMEFromCache.(*moduleEngine) + require.Equal(t, len(ime.compiledFunctions), len(imeCache.compiledFunctions)) + + for i, fn := range ime.compiledFunctions { + // Compiled functions must be cloend. + fnCached := imeCache.compiledFunctions[i] + require.NotEqual(t, fn, fnCached) + require.NotEqual(t, fn.moduleInstanceAddress, fnCached.moduleInstanceAddress) + require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. + // But the code segment stays the same. + require.Equal(t, fn.codeSegment, fnCached.codeSegment) + } + + // Next is to veirfy the caching works for modules with imports. + importedFunc := ime.compiledFunctions[0].source + moduleSource := &wasm.Module{} + + // No cache. + modEng, err := e.NewModuleEngine("3", moduleSource, + []*wasm.FunctionInstance{importedFunc}, // Import one function. + []*wasm.FunctionInstance{ + newFunctionInstance(10), + newFunctionInstance(20), + }, nil, nil) + require.NoError(t, err) + + // Cached. + modEngCache, err := e.NewModuleEngine("4", moduleSource, + []*wasm.FunctionInstance{importedFunc}, // Import one function. + []*wasm.FunctionInstance{ + newFunctionInstance(10), + newFunctionInstance(20), + }, nil, nil) + require.NoError(t, err) + + require.NotEqual(t, modEng, modEngCache) + require.NotEqual(t, modEng.Name(), modEngCache.Name()) + + me, meCache := modEng.(*moduleEngine), modEngCache.(*moduleEngine) + require.Equal(t, len(me.compiledFunctions), len(meCache.compiledFunctions)) + + for i, fn := range me.compiledFunctions { + fnCached := meCache.compiledFunctions[i] + if i == 0 { + // This case the function is imported, so it must be the same for both module engines. + require.Equal(t, fn, fnCached) + require.Equal(t, importedFunc, fn.source) + } else { + // Compiled functions must be cloend. + require.NotEqual(t, fn, fnCached) + require.NotEqual(t, fn.moduleInstanceAddress, fnCached.moduleInstanceAddress) + require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. + // But the code segment stays the same. + require.Equal(t, fn.codeSegment, fnCached.codeSegment) + } + } +} + +func newFunctionInstance(id int) *wasm.FunctionInstance { + return &wasm.FunctionInstance{ + DebugName: strconv.Itoa(id), + Type: &wasm.FunctionType{}, + Body: []byte{wasm.OpcodeEnd}, + Module: &wasm.ModuleInstance{}, + } +} diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 7e046b32..2d120b0c 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -29,7 +29,7 @@ type ( EnabledFeatures Features // Engine is a global context for a Store which is in responsible for compilation and execution of Wasm modules. - engine Engine + Engine Engine // moduleNames ensures no race conditions instantiating two modules of the same name moduleNames map[string]struct{} // guarded by mux @@ -231,7 +231,7 @@ func (m *ModuleInstance) getExport(name string, et ExternType) (*ExportInstance, func NewStore(enabledFeatures Features, engine Engine) *Store { return &Store{ EnabledFeatures: enabledFeatures, - engine: engine, + Engine: engine, moduleNames: map[string]struct{}{}, modules: map[string]*ModuleInstance{}, typeIDs: map[string]FunctionTypeID{}, @@ -292,7 +292,7 @@ func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sy } // Plus, we are ready to compile functions. - m.Engine, err = s.engine.NewModuleEngine(name, importedFunctions, functions, table, tableInit) + m.Engine, err = s.Engine.NewModuleEngine(name, module, importedFunctions, functions, table, tableInit) if err != nil { return nil, fmt.Errorf("compilation failed: %w", err) } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 567fb285..6f1a027c 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -293,7 +293,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { hm := s.modules[importedModuleName] require.NotNil(t, hm) - engine := s.engine.(*mockEngine) + engine := s.Engine.(*mockEngine) engine.shouldCompileFail = true _, err = s.Instantiate(context.Background(), &Module{ @@ -312,7 +312,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { t.Run("start func failed", func(t *testing.T) { s := newStore() - engine := s.engine.(*mockEngine) + engine := s.Engine.(*mockEngine) engine.callFailIndex = 1 _, err = s.Instantiate(context.Background(), m, importedModuleName, nil) @@ -438,13 +438,16 @@ func newStore() *Store { } // NewModuleEngine implements the same method as documented on wasm.Engine. -func (e *mockEngine) NewModuleEngine(_ string, _, _ []*FunctionInstance, _ *TableInstance, _ map[Index]Index) (ModuleEngine, error) { +func (e *mockEngine) NewModuleEngine(_ string, _ *Module, _, _ []*FunctionInstance, _ *TableInstance, _ map[Index]Index) (ModuleEngine, error) { if e.shouldCompileFail { return nil, fmt.Errorf("some compilation error") } return &mockModuleEngine{callFailIndex: e.callFailIndex}, nil } +// ReleaseCompilationCache implements the same method as documented on wasm.Engine. +func (e *mockEngine) ReleaseCompilationCache(*Module) {} + // Name implements the same method as documented on wasm.ModuleEngine. func (e *mockModuleEngine) Name() string { return e.name diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index 414a2bdb..41d8128f 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -407,7 +407,7 @@ type catchFunctions struct { } // NewModuleEngine implements the same method as documented on wasm.Engine. -func (e *catchFunctions) NewModuleEngine(_ string, _, functions []*wasm.FunctionInstance, _ *wasm.TableInstance, _ map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { +func (e *catchFunctions) NewModuleEngine(_ string, _ *wasm.Module, _, functions []*wasm.FunctionInstance, _ *wasm.TableInstance, _ map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { e.functions = functions return e, nil } @@ -423,5 +423,7 @@ func (e *catchFunctions) Call(_ *wasm.ModuleContext, _ *wasm.FunctionInstance, _ } // Close implements the same method as documented on wasm.ModuleEngine. -func (e *catchFunctions) Close() { -} +func (e *catchFunctions) Close() {} + +// ReleaseCompilationCache implements the same method as documented on wasm.Engine. +func (e *catchFunctions) ReleaseCompilationCache(*wasm.Module) {} diff --git a/tests/bench/bench_test.go b/tests/bench/bench_test.go index f67da6a6..9574461e 100644 --- a/tests/bench/bench_test.go +++ b/tests/bench/bench_test.go @@ -16,22 +16,52 @@ import ( //go:embed testdata/case.wasm var caseWasm []byte -func BenchmarkEngines(b *testing.B) { +func BenchmarkInvocation(b *testing.B) { b.Run("interpreter", func(b *testing.B) { m := instantiateHostFunctionModuleWithEngine(b, wazero.NewRuntimeConfigInterpreter()) defer m.Close() - runAllBenches(b, m) + runAllInvocationBenches(b, m) }) if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { b.Run("jit", func(b *testing.B) { m := instantiateHostFunctionModuleWithEngine(b, wazero.NewRuntimeConfigJIT()) defer m.Close() - runAllBenches(b, m) + runAllInvocationBenches(b, m) }) } } -func runAllBenches(b *testing.B, m api.Module) { +func BenchmarkInitialization(b *testing.B) { + b.Run("interpreter", func(b *testing.B) { + r := createRuntime(b, wazero.NewRuntimeConfigInterpreter()) + runInitializationBench(b, r) + }) + + if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { + b.Run("jit", func(b *testing.B) { + r := createRuntime(b, wazero.NewRuntimeConfigJIT()) + runInitializationBench(b, r) + }) + } +} + +func runInitializationBench(b *testing.B, r wazero.Runtime) { + compiled, err := r.CompileModule(caseWasm) + if err != nil { + b.Fatal(err) + } + defer compiled.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + mod, err := r.InstantiateModule(compiled) + if err != nil { + b.Fatal(err) + } + mod.Close() + } +} + +func runAllInvocationBenches(b *testing.B, m api.Module) { runBase64Benches(b, m) runFibBenches(b, m) runStringManipulationBenches(b, m) @@ -120,6 +150,17 @@ func runRandomMatMul(b *testing.B, m api.Module) { } func instantiateHostFunctionModuleWithEngine(b *testing.B, engine *wazero.RuntimeConfig) api.Module { + r := createRuntime(b, engine) + + // InstantiateModuleFromCode runs the "_start" function which is what TinyGo compiles "main" to. + m, err := r.InstantiateModuleFromCode(caseWasm) + if err != nil { + b.Fatal(err) + } + return m +} + +func createRuntime(b *testing.B, engine *wazero.RuntimeConfig) wazero.Runtime { getRandomString := func(ctx api.Module, retBufPtr uint32, retBufSize uint32) { results, err := ctx.ExportedFunction("allocate_buffer").Call(ctx, 10) if err != nil { @@ -147,11 +188,5 @@ func instantiateHostFunctionModuleWithEngine(b *testing.B, engine *wazero.Runtim if err != nil { b.Fatal(err) } - - // InstantiateModuleFromCode runs the "_start" function which is what TinyGo compiles "main" to. - m, err := r.InstantiateModuleFromCode(caseWasm) - if err != nil { - b.Fatal(err) - } - return m + return r } diff --git a/tests/engine/adhoc_test.go b/tests/engine/adhoc_test.go index d5000b6c..749ad94f 100644 --- a/tests/engine/adhoc_test.go +++ b/tests/engine/adhoc_test.go @@ -5,6 +5,7 @@ import ( _ "embed" "fmt" "math" + "strconv" "testing" "github.com/tetratelabs/wazero" @@ -15,13 +16,14 @@ import ( ) var tests = map[string]func(t *testing.T, r wazero.Runtime){ - "huge stack": testHugeStack, - "unreachable": testUnreachable, - "recursive entry": testRecursiveEntry, - "imported-and-exported func": testImportedAndExportedFunc, - "host function with context parameter": testHostFunctionContextParameter, - "host function with numeric parameter": testHostFunctionNumericParameter, - "close module with in-flight calls": testCloseInFlight, + "huge stack": testHugeStack, + "unreachable": testUnreachable, + "recursive entry": testRecursiveEntry, + "imported-and-exported func": testImportedAndExportedFunc, + "host function with context parameter": testHostFunctionContextParameter, + "host function with numeric parameter": testHostFunctionNumericParameter, + "close module with in-flight calls": testCloseInFlight, + "multiple instantiation from same source": testMultipleInstantiation, } func TestEngineJIT(t *testing.T) { @@ -337,3 +339,41 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { }) } } + +func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { + compiled, err := r.CompileModule([]byte(`(module $test + (memory 1) + (func $store + i32.const 1 ;; memory offset + i64.const 1000 ;; expected value + i64.store + ) + (export "store" (func $store)) + )`)) + require.NoError(t, err) + defer compiled.Close() + + // Instantiate multiple modules with the same source (*CompiledCode). + for i := 0; i < 100; i++ { + module, err := r.InstantiateModuleWithConfig(compiled, wazero.NewModuleConfig().WithName(strconv.Itoa(i))) + require.NoError(t, err) + defer module.Close() + + // Ensure that compilation cache doesn't cause race on memory instance. + before, ok := module.Memory().ReadUint64Le(1) + require.True(t, ok) + // Value must be zero as the memory must not be affected by the previously instantiated modules. + require.Zero(t, before) + + f := module.ExportedFunction("store") + require.NotNil(t, f) + + _, err = f.Call(nil) + require.NoError(t, err) + + // After the call, the value must be set properly. + after, ok := module.Memory().ReadUint64Le(1) + require.True(t, ok) + require.Equal(t, uint64(1000), after) + } +} diff --git a/wasi/usage_test.go b/wasi/usage_test.go index 2f2ae015..e77f36cd 100644 --- a/wasi/usage_test.go +++ b/wasi/usage_test.go @@ -26,6 +26,7 @@ func TestInstantiateModuleWithConfig(t *testing.T) { code, err := r.CompileModule(wasiArg) require.NoError(t, err) + defer code.Close() // Re-use the same module many times. for _, tc := range []string{"a", "b", "c"} { diff --git a/wasi/wasi_test.go b/wasi/wasi_test.go index 1c3a7545..d5ecfbfe 100644 --- a/wasi/wasi_test.go +++ b/wasi/wasi_test.go @@ -2014,15 +2014,16 @@ func instantiateModule(t *testing.T, wasifunction, wasiimport string, sysCtx *wa _, err := r.NewModuleBuilder("wasi_snapshot_preview1").ExportFunctions(fns).Instantiate() require.NoError(t, err) - m, err := r.CompileModule([]byte(fmt.Sprintf(`(module + compiled, err := r.CompileModule([]byte(fmt.Sprintf(`(module %[2]s (memory 1 1) ;; just an arbitrary size big enough for tests (export "memory" (memory 0)) (export "%[1]s" (func $wasi.%[1]s)) )`, wasifunction, wasiimport))) require.NoError(t, err) + defer compiled.Close() - mod, err := r.InstantiateModuleWithConfig(m, wazero.NewModuleConfig().WithName(t.Name())) + mod, err := r.InstantiateModuleWithConfig(compiled, wazero.NewModuleConfig().WithName(t.Name())) require.NoError(t, err) if sysCtx != nil { diff --git a/wasm.go b/wasm.go index f4a4ccb7..c348189d 100644 --- a/wasm.go +++ b/wasm.go @@ -72,6 +72,7 @@ type Runtime interface { // Ex. // r := wazero.NewRuntime() // code, _ := r.CompileModule(source) + // defer code.Close() // module, _ := r.InstantiateModule(code) // defer module.Close() // @@ -170,6 +171,8 @@ func (r *runtime) InstantiateModuleFromCode(source []byte) (api.Module, error) { if code, err := r.CompileModule(source); err != nil { return nil, err } else { + // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. + defer code.Close() return r.InstantiateModule(code) } } @@ -179,6 +182,8 @@ func (r *runtime) InstantiateModuleFromCodeWithConfig(source []byte, config *Mod if code, err := r.CompileModule(source); err != nil { return nil, err } else { + // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. + defer code.Close() return r.InstantiateModuleWithConfig(code, config) } } @@ -217,5 +222,6 @@ func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *Module return } } + code.addCacheEntry(module, r.store.Engine) return } diff --git a/wasm_test.go b/wasm_test.go index 5e8239b7..12d4a9c2 100644 --- a/wasm_test.go +++ b/wasm_test.go @@ -5,6 +5,7 @@ import ( _ "embed" "fmt" "math" + "strconv" "testing" "github.com/tetratelabs/wazero/api" @@ -55,6 +56,7 @@ func TestRuntime_DecodeModule(t *testing.T) { t.Run(tc.name, func(t *testing.T) { code, err := r.CompileModule(tc.source) require.NoError(t, err) + defer code.Close() if tc.expectedName != "" { require.Equal(t, tc.expectedName, code.module.NameSection.ModuleName) } @@ -321,6 +323,7 @@ func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { (start $start) )`)) require.NoError(t, err) + defer code.Close() // Instantiate the module, which calls the start function. This will fail if the context wasn't as intended. m, err := r.InstantiateModule(code) @@ -377,6 +380,7 @@ func TestInstantiateModuleWithConfig_WithName(t *testing.T) { r := NewRuntime() base, err := r.CompileModule([]byte(`(module $0 (memory 1))`)) require.NoError(t, err) + defer base.Close() require.Equal(t, "0", base.module.NameSection.ModuleName) @@ -406,3 +410,72 @@ func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx api `(module (import "host" "%[1]s" (func (result i64))) (export "%[1]s" (func 0)))`, functionName, )), mod.Close } + +func TestCompiledCode_addCacheEntry(t *testing.T) { + c := &CompiledCode{} + + m1, e1 := &wasm.Module{}, &mockEngine{name: "1"} + for i := 0; i < 5; i++ { + c.addCacheEntry(m1, e1) + } + + require.Contains(t, c.cachedEngines, e1) + require.Contains(t, c.cachedEngines[e1], m1) + require.Len(t, c.cachedEngines[e1], 1) + + m2, e2 := &wasm.Module{}, &mockEngine{name: "2"} + for i := 0; i < 5; i++ { + c.addCacheEntry(m2, e2) + } + + require.Contains(t, c.cachedEngines, e1) + require.Contains(t, c.cachedEngines, e2) + require.Contains(t, c.cachedEngines[e1], m1) + require.Contains(t, c.cachedEngines[e2], m2) + require.Len(t, c.cachedEngines[e1], 1) + require.Len(t, c.cachedEngines[e2], 1) +} + +func TestCompiledCode_Close(t *testing.T) { + e1, e2 := &mockEngine{name: "1", cachedModules: map[*wasm.Module]struct{}{}}, + &mockEngine{name: "2", cachedModules: map[*wasm.Module]struct{}{}} + + c := &CompiledCode{} + for _, e := range []wasm.Engine{e1, e2} { + for i := 0; i < 10; i++ { + m := &wasm.Module{} + _, _ = e.NewModuleEngine(strconv.Itoa(i), m, nil, nil, nil, nil) + c.addCacheEntry(m, e) + } + } + + // Before Close. + require.Len(t, e1.cachedModules, 10) + require.Len(t, e2.cachedModules, 10) + require.Len(t, c.cachedEngines, 2) + for _, modules := range c.cachedEngines { + require.Len(t, modules, 10) + } + + c.Close() + + // After Close. + require.Len(t, e1.cachedModules, 0) + require.Len(t, e2.cachedModules, 0) +} + +type mockEngine struct { + name string + cachedModules map[*wasm.Module]struct{} +} + +// NewModuleEngine implements the same method as documented on wasm.Engine. +func (e *mockEngine) NewModuleEngine(_ string, module *wasm.Module, _, _ []*wasm.FunctionInstance, _ *wasm.TableInstance, _ map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { + e.cachedModules[module] = struct{}{} + return nil, nil +} + +// ReleaseCompilationCache implements the same method as documented on wasm.Engine. +func (e *mockEngine) ReleaseCompilationCache(module *wasm.Module) { + delete(e.cachedModules, module) +}