diff --git a/internal/engine/compiler/compiler_controlflow_test.go b/internal/engine/compiler/compiler_controlflow_test.go index 278cb68f..132a25ec 100644 --- a/internal/engine/compiler/compiler_controlflow_test.go +++ b/internal/engine/compiler/compiler_controlflow_test.go @@ -822,7 +822,7 @@ func TestCompiler_callIndirect_largeTypeIndex(t *testing.T) { makeExecutable(code1.Bytes()) f := function{ - parent: &compiledFunction{parent: &compiledModule{executable: code1}}, + parent: &compiledFunction{parent: &compiledCode{executable: code1}}, codeInitialAddress: uintptr(unsafe.Pointer(&code1.Bytes()[0])), moduleInstance: env.moduleInstance, } @@ -896,7 +896,7 @@ func TestCompiler_compileCall(t *testing.T) { makeExecutable(code.Bytes()) me.functions = append(me.functions, function{ - parent: &compiledFunction{parent: &compiledModule{executable: code}}, + parent: &compiledFunction{parent: &compiledCode{executable: code}}, codeInitialAddress: uintptr(unsafe.Pointer(&code.Bytes()[0])), moduleInstance: env.moduleInstance, }) diff --git a/internal/engine/compiler/compiler_test.go b/internal/engine/compiler/compiler_test.go index 995b41cf..285300d7 100644 --- a/internal/engine/compiler/compiler_test.go +++ b/internal/engine/compiler/compiler_test.go @@ -202,7 +202,7 @@ func (j *compilerEnv) callEngine() *callEngine { } func (j *compilerEnv) exec(machineCode []byte) { - cm := new(compiledModule) + cm := &compiledModule{compiledCode: &compiledCode{}} if err := cm.executable.Map(len(machineCode)); err != nil { panic(err) } @@ -211,7 +211,7 @@ func (j *compilerEnv) exec(machineCode []byte) { makeExecutable(executable) f := &function{ - parent: &compiledFunction{parent: cm}, + parent: &compiledFunction{parent: cm.compiledCode}, codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])), moduleInstance: j.moduleInstance, } @@ -268,7 +268,7 @@ func newCompilerEnvironment() *compilerEnv { Globals: []*wasm.GlobalInstance{}, Engine: me, }, - ce: me.newCallEngine(initialStackSize, 
&function{parent: &compiledFunction{parent: &compiledModule{}}}), + ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}), } } diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index ca07a716..86abce4a 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -48,6 +48,10 @@ type ( // as the underlying memory region is accessed by assembly directly by using // codesElement0Address. functions []function + + // Keep a reference to the compiled module to prevent the GC from reclaiming + // it while the code may still be needed. + module *compiledModule } // callEngine holds context per moduleEngine.Call, and shared across all the @@ -130,11 +134,13 @@ type ( // initialFn is the initial function for this call engine. initialFn *function + // Keep a reference to the compiled module to prevent the GC from reclaiming + // it while the code may still be needed. + module *compiledModule + // stackIterator provides a way to iterate over the stack for Listeners. // It is setup and valid only during a call to a Listener hook. stackIterator stackIterator - - ensureTermination bool } // moduleContext holds the per-function call specific module information. @@ -264,12 +270,27 @@ type ( } compiledModule struct { - executable asm.CodeSegment - functions []compiledFunction - source *wasm.Module + // The data that need to be accessed by compiledFunction.parent are + // separated in an embedded field because we use finalizers to manage + // the lifecycle of compiledModule instances and having cyclic pointers + // prevents the Go runtime from calling them, which results in memory + // leaks since the memory mapped code segments cannot be released. 
+ // + // The indirection guarantees that the finalizer set on compiledModule + // instances can run when all references are gone, and the Go GC can + // manage to reclaim the compiledCode when all compiledFunction objects + // referencing it have been freed. + *compiledCode + functions []compiledFunction + ensureTermination bool } + compiledCode struct { + source *wasm.Module + executable asm.CodeSegment + } + // compiledFunction corresponds to a function in a module (not instantiated one). This holds the machine code // compiled by wazero compiler. compiledFunction struct { @@ -282,7 +303,7 @@ type ( index wasm.Index goFunc interface{} listener experimental.FunctionListener - parent *compiledModule + parent *compiledCode sourceOffsetMap sourceOffsetMap } @@ -496,13 +517,6 @@ func (e *engine) Close() (err error) { e.mux.Lock() defer e.mux.Unlock() // Releasing the references to compiled codes including the memory-mapped machine codes. - - for i := range e.codes { - for j := range e.codes[i].functions { - e.codes[i].functions[j].parent = nil - } - } - e.codes = nil return } @@ -523,9 +537,11 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners var withGoFunc bool localFuncs, importedFuncs := len(module.FunctionSection), module.ImportFunctionCount cm := &compiledModule{ + compiledCode: &compiledCode{ + source: module, + }, functions: make([]compiledFunction, localFuncs), ensureTermination: ensureTermination, - source: module, } if localFuncs == 0 { @@ -559,7 +575,7 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners funcIndex := wasm.Index(i) compiledFn := &cm.functions[i] compiledFn.executableOffset = executable.Size() - compiledFn.parent = cm + compiledFn.parent = cm.compiledCode compiledFn.index = importedFuncs + funcIndex if i < ln { compiledFn.listener = listeners[i] @@ -628,6 +644,8 @@ func (e *engine) NewModuleEngine(module *wasm.Module, instance *wasm.ModuleInsta parent: c, } } + + me.module = cm 
return me, nil } @@ -720,7 +738,7 @@ func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error { func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) { m := ce.initialFn.moduleInstance - if ce.ensureTermination { + if ce.module.ensureTermination { select { case <-ctx.Done(): // If the provided context is already done, close the call context @@ -741,12 +759,14 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u // If the module closed during the call, and the call didn't err for another reason, set an ExitError. err = m.FailIfClosed() } + // Ensure that the compiled module will never be GC'd before this method returns. + runtime.KeepAlive(ce.module) }() ft := ce.initialFn.funcType ce.initializeStack(ft, params) - if ce.ensureTermination { + if ce.module.ensureTermination { done := m.CloseModuleOnCanceledOrTimeout(ctx) defer done() } @@ -959,11 +979,11 @@ var initialStackSize uint64 = 512 func (e *moduleEngine) newCallEngine(stackSize uint64, fn *function) *callEngine { ce := &callEngine{ - stack: make([]uint64, stackSize), - archContext: newArchContext(), - initialFn: fn, - moduleContext: moduleContext{fn: fn}, - ensureTermination: fn.parent.parent.ensureTermination, + stack: make([]uint64, stackSize), + archContext: newArchContext(), + initialFn: fn, + moduleContext: moduleContext{fn: fn}, + module: e.module, } stackHeader := (*reflect.SliceHeader)(unsafe.Pointer(&ce.stack)) diff --git a/internal/engine/compiler/engine_bench_test.go b/internal/engine/compiler/engine_bench_test.go index 991a98ac..39bac99e 100644 --- a/internal/engine/compiler/engine_bench_test.go +++ b/internal/engine/compiler/engine_bench_test.go @@ -20,7 +20,7 @@ func BenchmarkCallEngine_builtinFunctionFunctionListener(b *testing.B) { }, }, index: 0, - parent: &compiledModule{ + parent: &compiledCode{ source: &wasm.Module{ TypeSection: []wasm.FunctionType{{}}, FunctionSection: []wasm.Index{0}, diff 
--git a/internal/engine/compiler/engine_cache.go b/internal/engine/compiler/engine_cache.go index e6b3b0e9..37e481bd 100644 --- a/internal/engine/compiler/engine_cache.go +++ b/internal/engine/compiler/engine_cache.go @@ -17,6 +17,7 @@ import ( func (e *engine) deleteCompiledModule(module *wasm.Module) { e.mux.Lock() defer e.mux.Unlock() + delete(e.codes, module.ID) // Note: we do not call e.Cache.Delete, as the lifetime of @@ -158,14 +159,18 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, modul ensureTermination := header[cachedVersionEnd] != 0 functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:]) - cm = &compiledModule{functions: make([]compiledFunction, functionsNum), ensureTermination: ensureTermination} + cm = &compiledModule{ + compiledCode: new(compiledCode), + functions: make([]compiledFunction, functionsNum), + ensureTermination: ensureTermination, + } imported := module.ImportFunctionCount var eightBytes [8]byte for i := uint32(0); i < functionsNum; i++ { f := &cm.functions[i] - f.parent = cm + f.parent = cm.compiledCode // Read the stack pointer ceil. 
if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil { diff --git a/internal/engine/compiler/engine_cache_test.go b/internal/engine/compiler/engine_cache_test.go index a2ad6056..c26433a3 100644 --- a/internal/engine/compiler/engine_cache_test.go +++ b/internal/engine/compiler/engine_cache_test.go @@ -38,7 +38,9 @@ func TestSerializeCompiledModule(t *testing.T) { }{ { in: &compiledModule{ - executable: makeCodeSegment(1, 2, 3, 4, 5), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5), + }, functions: []compiledFunction{ {executableOffset: 0, stackPointerCeil: 12345}, }, @@ -57,11 +59,13 @@ func TestSerializeCompiledModule(t *testing.T) { }, { in: &compiledModule{ - ensureTermination: true, - executable: makeCodeSegment(1, 2, 3, 4, 5), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5), + }, functions: []compiledFunction{ {executableOffset: 0, stackPointerCeil: 12345}, }, + ensureTermination: true, }, exp: concat( []byte(wazeroMagic), @@ -77,12 +81,14 @@ func TestSerializeCompiledModule(t *testing.T) { }, { in: &compiledModule{ - ensureTermination: true, - executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3), + }, functions: []compiledFunction{ {executableOffset: 0, stackPointerCeil: 12345}, {executableOffset: 5, stackPointerCeil: 0xffffffff}, }, + ensureTermination: true, }, exp: concat( []byte(wazeroMagic), @@ -159,7 +165,9 @@ func TestDeserializeCompiledModule(t *testing.T) { []byte{1, 2, 3, 4, 5}, // machine code. ), expCompiledModule: &compiledModule{ - executable: makeCodeSegment(1, 2, 3, 4, 5), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5), + }, functions: []compiledFunction{ {executableOffset: 0, stackPointerCeil: 12345, index: 0}, }, @@ -181,9 +189,11 @@ func TestDeserializeCompiledModule(t *testing.T) { []byte{1, 2, 3, 4, 5}, // code. 
), expCompiledModule: &compiledModule{ - ensureTermination: true, - executable: makeCodeSegment(1, 2, 3, 4, 5), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5), + }, functions: []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}}, + ensureTermination: true, }, expStaleCache: false, expErr: "", @@ -208,7 +218,9 @@ func TestDeserializeCompiledModule(t *testing.T) { ), importedFunctionCount: 1, expCompiledModule: &compiledModule{ - executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + }, functions: []compiledFunction{ {executableOffset: 0, stackPointerCeil: 12345, index: 1}, {executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2}, @@ -279,8 +291,8 @@ func TestDeserializeCompiledModule(t *testing.T) { if tc.expCompiledModule != nil { require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions)) for i := 0; i < len(cm.functions); i++ { - require.Equal(t, cm, cm.functions[i].parent) - tc.expCompiledModule.functions[i].parent = cm + require.Equal(t, cm.compiledCode, cm.functions[i].parent) + tc.expCompiledModule.functions[i].parent = cm.compiledCode } } @@ -361,13 +373,13 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) { }, expHit: true, expCompiledModule: &compiledModule{ - executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + }, functions: []compiledFunction{ {stackPointerCeil: 12345, executableOffset: 0, index: 0}, {stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1}, }, - source: nil, - ensureTermination: false, }, }, } @@ -379,7 +391,7 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) { if exp := tc.expCompiledModule; exp != nil { exp.source = m for i := range tc.expCompiledModule.functions { - tc.expCompiledModule.functions[i].parent = exp + 
tc.expCompiledModule.functions[i].parent = exp.compiledCode } } @@ -422,8 +434,10 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) { tc := filecache.New(t.TempDir()) e := engine{fileCache: tc} cm := &compiledModule{ - executable: makeCodeSegment(1, 2, 3), - functions: []compiledFunction{{stackPointerCeil: 123}}, + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3), + }, + functions: []compiledFunction{{stackPointerCeil: 123}}, } m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module! err := e.addCompiledModuleToCache(m, cm) @@ -438,8 +452,10 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) { e := engine{fileCache: tc} m := &wasm.Module{} cm := &compiledModule{ - executable: makeCodeSegment(1, 2, 3), - functions: []compiledFunction{{stackPointerCeil: 123}}, + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2, 3), + }, + functions: []compiledFunction{{stackPointerCeil: 123}}, } err := e.addCompiledModuleToCache(m, cm) require.NoError(t, err) diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index 520b35f9..e18ace62 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -234,7 +234,9 @@ func TestCompiler_CompileModule(t *testing.T) { func TestCompiler_Releasecode_Panic(t *testing.T) { captured := require.CapturePanic(func() { releaseCompiledModule(&compiledModule{ - executable: makeCodeSegment(1, 2), + compiledCode: &compiledCode{ + executable: makeCodeSegment(1, 2), + }, }) }) require.Contains(t, captured.Error(), "compiler: failed to munmap code segment") @@ -392,15 +394,15 @@ func TestCallEngine_deferredOnCall(t *testing.T) { } f1 := &function{ funcType: &wasm.FunctionType{ParamNumInUint64: 2}, - parent: &compiledFunction{parent: &compiledModule{source: s}, index: 0}, + parent: &compiledFunction{parent: &compiledCode{source: s}, index: 0}, } f2 := &function{ funcType: &wasm.FunctionType{ParamNumInUint64: 
2, ResultNumInUint64: 3}, - parent: &compiledFunction{parent: &compiledModule{source: s}, index: 1}, + parent: &compiledFunction{parent: &compiledCode{source: s}, index: 1}, } f3 := &function{ funcType: &wasm.FunctionType{ResultNumInUint64: 1}, - parent: &compiledFunction{parent: &compiledModule{source: s}, index: 2}, + parent: &compiledFunction{parent: &compiledCode{source: s}, index: 2}, } ce := &callEngine{ @@ -598,7 +600,7 @@ func TestCallEngine_builtinFunctionFunctionListenerBefore(t *testing.T) { }, }, index: 0, - parent: &compiledModule{source: &wasm.Module{ + parent: &compiledCode{source: &wasm.Module{ FunctionSection: []wasm.Index{0}, CodeSection: []wasm.Code{{}}, TypeSection: []wasm.FunctionType{{}}, @@ -624,7 +626,7 @@ func TestCallEngine_builtinFunctionFunctionListenerAfter(t *testing.T) { }, }, index: 0, - parent: &compiledModule{source: &wasm.Module{ + parent: &compiledCode{source: &wasm.Module{ FunctionSection: []wasm.Index{0}, CodeSection: []wasm.Code{{}}, TypeSection: []wasm.FunctionType{{}}, diff --git a/internal/engine/compiler/impl_amd64_test.go b/internal/engine/compiler/impl_amd64_test.go index ae7d08fd..316bfceb 100644 --- a/internal/engine/compiler/impl_amd64_test.go +++ b/internal/engine/compiler/impl_amd64_test.go @@ -44,7 +44,7 @@ func TestAmd64Compiler_indirectCallWithTargetOnCallingConvReg(t *testing.T) { makeExecutable(executable) f := function{ - parent: &compiledFunction{parent: &compiledModule{executable: code}}, + parent: &compiledFunction{parent: &compiledCode{executable: code}}, codeInitialAddress: code.Addr(), moduleInstance: env.moduleInstance, typeID: 0, diff --git a/internal/engine/compiler/impl_arm64_test.go b/internal/engine/compiler/impl_arm64_test.go index 315f0c65..da59b58f 100644 --- a/internal/engine/compiler/impl_arm64_test.go +++ b/internal/engine/compiler/impl_arm64_test.go @@ -42,7 +42,7 @@ func TestArm64Compiler_indirectCallWithTargetOnCallingConvReg(t *testing.T) { makeExecutable(executable) f := function{ - 
parent: &compiledFunction{parent: &compiledModule{executable: code}}, + parent: &compiledFunction{parent: &compiledCode{executable: code}}, codeInitialAddress: code.Addr(), moduleInstance: env.moduleInstance, }
diff --git a/internal/integration_test/engine/memleak_test.go b/internal/integration_test/engine/memleak_test.go
new file mode 100644
index 00000000..4cb0179a
--- /dev/null
+++ b/internal/integration_test/engine/memleak_test.go
@@ -0,0 +1,61 @@
+package adhoc
+
+import (
+	"context"
+	"runtime"
+	"testing"
+	"time"
+
+	"github.com/tetratelabs/wazero"
+)
+
+// TestMemoryLeak repeatedly instantiates and closes a runtime and a module
+// for a fixed duration, then fails if the Go heap still holds more than
+// 100 MiB of live objects after a forced garbage collection.
+func TestMemoryLeak(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping memory leak test in short mode.")
+	}
+
+	duration := 5 * time.Second
+	t.Logf("running memory leak test for %s", duration)
+
+	ctx, cancel := context.WithTimeout(context.Background(), duration)
+	defer cancel()
+
+	for ctx.Err() == nil {
+		if err := testMemoryLeakInstantiateRuntimeAndModule(); err != nil {
+			// Fail through the testing framework rather than panicking, so the
+			// failure is attributed to this test and the remaining tests of the
+			// package still run.
+			t.Fatal(err)
+		}
+	}
+
+	var stats runtime.MemStats
+	runtime.GC()
+	runtime.ReadMemStats(&stats)
+
+	if stats.Alloc > (100 * 1024 * 1024) {
+		t.Errorf("wazero used more than 100 MiB after running the test for %s (alloc=%d)", duration, stats.Alloc)
+	}
+}
+
+// testMemoryLeakInstantiateRuntimeAndModule creates a new runtime, instantiates
+// memoryWasm in it with an empty start-function list, and closes both. It
+// returns the instantiation error, or else the error from closing the module.
+func testMemoryLeakInstantiateRuntimeAndModule() error {
+	ctx := context.Background()
+
+	// Named r rather than runtime so the standard library runtime package is
+	// not shadowed inside this function.
+	r := wazero.NewRuntime(ctx)
+	defer r.Close(ctx)
+
+	mod, err := r.InstantiateWithConfig(ctx, memoryWasm,
+		wazero.NewModuleConfig().WithStartFunctions())
+	if err != nil {
+		return err
+	}
+	return mod.Close(ctx)
+}