compiler: fix compiledModule leak (#1608)

Signed-off-by: Nuno Cruces <ncruces@users.noreply.github.com>
Co-authored-by: Achille Roussel <achille.roussel@gmail.com>
This commit is contained in:
Nuno Cruces
2023-08-02 02:14:49 +01:00
committed by GitHub
parent 2f2b6a9d2c
commit 90f58bce75
10 changed files with 151 additions and 57 deletions

View File

@@ -822,7 +822,7 @@ func TestCompiler_callIndirect_largeTypeIndex(t *testing.T) {
makeExecutable(code1.Bytes())
f := function{
parent: &compiledFunction{parent: &compiledModule{executable: code1}},
parent: &compiledFunction{parent: &compiledCode{executable: code1}},
codeInitialAddress: uintptr(unsafe.Pointer(&code1.Bytes()[0])),
moduleInstance: env.moduleInstance,
}
@@ -896,7 +896,7 @@ func TestCompiler_compileCall(t *testing.T) {
makeExecutable(code.Bytes())
me.functions = append(me.functions, function{
parent: &compiledFunction{parent: &compiledModule{executable: code}},
parent: &compiledFunction{parent: &compiledCode{executable: code}},
codeInitialAddress: uintptr(unsafe.Pointer(&code.Bytes()[0])),
moduleInstance: env.moduleInstance,
})

View File

@@ -202,7 +202,7 @@ func (j *compilerEnv) callEngine() *callEngine {
}
func (j *compilerEnv) exec(machineCode []byte) {
cm := new(compiledModule)
cm := &compiledModule{compiledCode: &compiledCode{}}
if err := cm.executable.Map(len(machineCode)); err != nil {
panic(err)
}
@@ -211,7 +211,7 @@ func (j *compilerEnv) exec(machineCode []byte) {
makeExecutable(executable)
f := &function{
parent: &compiledFunction{parent: cm},
parent: &compiledFunction{parent: cm.compiledCode},
codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])),
moduleInstance: j.moduleInstance,
}
@@ -268,7 +268,7 @@ func newCompilerEnvironment() *compilerEnv {
Globals: []*wasm.GlobalInstance{},
Engine: me,
},
ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledModule{}}}),
ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}),
}
}

View File

@@ -48,6 +48,10 @@ type (
// as the underlying memory region is accessed by assembly directly by using
// codesElement0Address.
functions []function
// Keep a reference to the compiled module to prevent the GC from reclaiming
// it while the code may still be needed.
module *compiledModule
}
// callEngine holds context per moduleEngine.Call, and shared across all the
@@ -130,11 +134,13 @@ type (
// initialFn is the initial function for this call engine.
initialFn *function
// Keep a reference to the compiled module to prevent the GC from reclaiming
// it while the code may still be needed.
module *compiledModule
// stackIterator provides a way to iterate over the stack for Listeners.
// It is setup and valid only during a call to a Listener hook.
stackIterator stackIterator
ensureTermination bool
}
// moduleContext holds the per-function call specific module information.
@@ -264,12 +270,27 @@ type (
}
compiledModule struct {
executable asm.CodeSegment
// The data that need to be accessed by compiledFunction.parent are
// separated in an embedded field because we use finalizers to manage
// the lifecycle of compiledModule instances and having cyclic pointers
// prevents the Go runtime from calling them, which results in memory
// leaks since the memory mapped code segments cannot be released.
//
// The indirection guarantees that the finalizer set on compiledModule
// instances can run when all references are gone, and the Go GC can
// manage to reclaim the compiledCode when all compiledFunction objects
// referencing it have been freed.
*compiledCode
functions []compiledFunction
source *wasm.Module
ensureTermination bool
}
compiledCode struct {
source *wasm.Module
executable asm.CodeSegment
}
// compiledFunction corresponds to a function in a module (not instantiated one). This holds the machine code
// compiled by wazero compiler.
compiledFunction struct {
@@ -282,7 +303,7 @@ type (
index wasm.Index
goFunc interface{}
listener experimental.FunctionListener
parent *compiledModule
parent *compiledCode
sourceOffsetMap sourceOffsetMap
}
@@ -496,13 +517,6 @@ func (e *engine) Close() (err error) {
e.mux.Lock()
defer e.mux.Unlock()
// Releasing the references to compiled codes including the memory-mapped machine codes.
for i := range e.codes {
for j := range e.codes[i].functions {
e.codes[i].functions[j].parent = nil
}
}
e.codes = nil
return
}
@@ -523,9 +537,11 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners
var withGoFunc bool
localFuncs, importedFuncs := len(module.FunctionSection), module.ImportFunctionCount
cm := &compiledModule{
compiledCode: &compiledCode{
source: module,
},
functions: make([]compiledFunction, localFuncs),
ensureTermination: ensureTermination,
source: module,
}
if localFuncs == 0 {
@@ -559,7 +575,7 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners
funcIndex := wasm.Index(i)
compiledFn := &cm.functions[i]
compiledFn.executableOffset = executable.Size()
compiledFn.parent = cm
compiledFn.parent = cm.compiledCode
compiledFn.index = importedFuncs + funcIndex
if i < ln {
compiledFn.listener = listeners[i]
@@ -628,6 +644,8 @@ func (e *engine) NewModuleEngine(module *wasm.Module, instance *wasm.ModuleInsta
parent: c,
}
}
me.module = cm
return me, nil
}
@@ -720,7 +738,7 @@ func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.initialFn.moduleInstance
if ce.ensureTermination {
if ce.module.ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the call context
@@ -741,12 +759,14 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
err = m.FailIfClosed()
}
// Ensure that the compiled module will never be GC'd before this method returns.
runtime.KeepAlive(ce.module)
}()
ft := ce.initialFn.funcType
ce.initializeStack(ft, params)
if ce.ensureTermination {
if ce.module.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}
@@ -963,7 +983,7 @@ func (e *moduleEngine) newCallEngine(stackSize uint64, fn *function) *callEngine
archContext: newArchContext(),
initialFn: fn,
moduleContext: moduleContext{fn: fn},
ensureTermination: fn.parent.parent.ensureTermination,
module: e.module,
}
stackHeader := (*reflect.SliceHeader)(unsafe.Pointer(&ce.stack))

View File

@@ -20,7 +20,7 @@ func BenchmarkCallEngine_builtinFunctionFunctionListener(b *testing.B) {
},
},
index: 0,
parent: &compiledModule{
parent: &compiledCode{
source: &wasm.Module{
TypeSection: []wasm.FunctionType{{}},
FunctionSection: []wasm.Index{0},

View File

@@ -17,6 +17,7 @@ import (
func (e *engine) deleteCompiledModule(module *wasm.Module) {
e.mux.Lock()
defer e.mux.Unlock()
delete(e.codes, module.ID)
// Note: we do not call e.Cache.Delete, as the lifetime of
@@ -158,14 +159,18 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, modul
ensureTermination := header[cachedVersionEnd] != 0
functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
cm = &compiledModule{functions: make([]compiledFunction, functionsNum), ensureTermination: ensureTermination}
cm = &compiledModule{
compiledCode: new(compiledCode),
functions: make([]compiledFunction, functionsNum),
ensureTermination: ensureTermination,
}
imported := module.ImportFunctionCount
var eightBytes [8]byte
for i := uint32(0); i < functionsNum; i++ {
f := &cm.functions[i]
f.parent = cm
f.parent = cm.compiledCode
// Read the stack pointer ceil.
if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {

View File

@@ -38,7 +38,9 @@ func TestSerializeCompiledModule(t *testing.T) {
}{
{
in: &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
},
@@ -57,11 +59,13 @@ func TestSerializeCompiledModule(t *testing.T) {
},
{
in: &compiledModule{
ensureTermination: true,
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
},
ensureTermination: true,
},
exp: concat(
[]byte(wazeroMagic),
@@ -77,12 +81,14 @@ func TestSerializeCompiledModule(t *testing.T) {
},
{
in: &compiledModule{
ensureTermination: true,
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
{executableOffset: 5, stackPointerCeil: 0xffffffff},
},
ensureTermination: true,
},
exp: concat(
[]byte(wazeroMagic),
@@ -159,7 +165,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte{1, 2, 3, 4, 5}, // machine code.
),
expCompiledModule: &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345, index: 0},
},
@@ -181,9 +189,11 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte{1, 2, 3, 4, 5}, // code.
),
expCompiledModule: &compiledModule{
ensureTermination: true,
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}},
ensureTermination: true,
},
expStaleCache: false,
expErr: "",
@@ -208,7 +218,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
),
importedFunctionCount: 1,
expCompiledModule: &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345, index: 1},
{executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2},
@@ -279,8 +291,8 @@ func TestDeserializeCompiledModule(t *testing.T) {
if tc.expCompiledModule != nil {
require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions))
for i := 0; i < len(cm.functions); i++ {
require.Equal(t, cm, cm.functions[i].parent)
tc.expCompiledModule.functions[i].parent = cm
require.Equal(t, cm.compiledCode, cm.functions[i].parent)
tc.expCompiledModule.functions[i].parent = cm.compiledCode
}
}
@@ -361,13 +373,13 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
},
expHit: true,
expCompiledModule: &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
},
functions: []compiledFunction{
{stackPointerCeil: 12345, executableOffset: 0, index: 0},
{stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1},
},
source: nil,
ensureTermination: false,
},
},
}
@@ -379,7 +391,7 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
if exp := tc.expCompiledModule; exp != nil {
exp.source = m
for i := range tc.expCompiledModule.functions {
tc.expCompiledModule.functions[i].parent = exp
tc.expCompiledModule.functions[i].parent = exp.compiledCode
}
}
@@ -422,7 +434,9 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
tc := filecache.New(t.TempDir())
e := engine{fileCache: tc}
cm := &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3),
},
functions: []compiledFunction{{stackPointerCeil: 123}},
}
m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module!
@@ -438,7 +452,9 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
e := engine{fileCache: tc}
m := &wasm.Module{}
cm := &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3),
},
functions: []compiledFunction{{stackPointerCeil: 123}},
}
err := e.addCompiledModuleToCache(m, cm)

View File

@@ -234,7 +234,9 @@ func TestCompiler_CompileModule(t *testing.T) {
func TestCompiler_Releasecode_Panic(t *testing.T) {
captured := require.CapturePanic(func() {
releaseCompiledModule(&compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2),
},
})
})
require.Contains(t, captured.Error(), "compiler: failed to munmap code segment")
@@ -392,15 +394,15 @@ func TestCallEngine_deferredOnCall(t *testing.T) {
}
f1 := &function{
funcType: &wasm.FunctionType{ParamNumInUint64: 2},
parent: &compiledFunction{parent: &compiledModule{source: s}, index: 0},
parent: &compiledFunction{parent: &compiledCode{source: s}, index: 0},
}
f2 := &function{
funcType: &wasm.FunctionType{ParamNumInUint64: 2, ResultNumInUint64: 3},
parent: &compiledFunction{parent: &compiledModule{source: s}, index: 1},
parent: &compiledFunction{parent: &compiledCode{source: s}, index: 1},
}
f3 := &function{
funcType: &wasm.FunctionType{ResultNumInUint64: 1},
parent: &compiledFunction{parent: &compiledModule{source: s}, index: 2},
parent: &compiledFunction{parent: &compiledCode{source: s}, index: 2},
}
ce := &callEngine{
@@ -598,7 +600,7 @@ func TestCallEngine_builtinFunctionFunctionListenerBefore(t *testing.T) {
},
},
index: 0,
parent: &compiledModule{source: &wasm.Module{
parent: &compiledCode{source: &wasm.Module{
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{{}},
TypeSection: []wasm.FunctionType{{}},
@@ -624,7 +626,7 @@ func TestCallEngine_builtinFunctionFunctionListenerAfter(t *testing.T) {
},
},
index: 0,
parent: &compiledModule{source: &wasm.Module{
parent: &compiledCode{source: &wasm.Module{
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{{}},
TypeSection: []wasm.FunctionType{{}},

View File

@@ -44,7 +44,7 @@ func TestAmd64Compiler_indirectCallWithTargetOnCallingConvReg(t *testing.T) {
makeExecutable(executable)
f := function{
parent: &compiledFunction{parent: &compiledModule{executable: code}},
parent: &compiledFunction{parent: &compiledCode{executable: code}},
codeInitialAddress: code.Addr(),
moduleInstance: env.moduleInstance,
typeID: 0,

View File

@@ -42,7 +42,7 @@ func TestArm64Compiler_indirectCallWithTargetOnCallingConvReg(t *testing.T) {
makeExecutable(executable)
f := function{
parent: &compiledFunction{parent: &compiledModule{executable: code}},
parent: &compiledFunction{parent: &compiledCode{executable: code}},
codeInitialAddress: code.Addr(),
moduleInstance: env.moduleInstance,
}

View File

@@ -0,0 +1,51 @@
package adhoc
import (
"context"
"log"
"runtime"
"testing"
"time"
"github.com/tetratelabs/wazero"
)
func TestMemoryLeak(t *testing.T) {
if testing.Short() {
t.Skip("skipping memory leak test in short mode.")
}
duration := 5 * time.Second
t.Logf("running memory leak test for %s", duration)
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
for ctx.Err() == nil {
if err := testMemoryLeakInstantiateRuntimeAndModule(); err != nil {
log.Panicln(err)
}
}
var stats runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&stats)
if stats.Alloc > (100 * 1024 * 1024) {
t.Errorf("wazero used more than 100 MiB after running the test for %s (alloc=%d)", duration, stats.Alloc)
}
}
func testMemoryLeakInstantiateRuntimeAndModule() error {
ctx := context.Background()
runtime := wazero.NewRuntime(ctx)
defer runtime.Close(ctx)
mod, err := runtime.InstantiateWithConfig(ctx, memoryWasm,
wazero.NewModuleConfig().WithStartFunctions())
if err != nil {
return err
}
return mod.Close(ctx)
}