diff --git a/builder.go b/builder.go index c8ac5e02..8e7e4abe 100644 --- a/builder.go +++ b/builder.go @@ -309,11 +309,11 @@ func (b *hostModuleBuilder) Compile(ctx context.Context) (CompiledModule, error) } c := &compiledModule{module: module, compiledEngine: b.r.store.Engine} - if c.listeners, err = buildListeners(ctx, b.r, module); err != nil { + if c.listeners, err = buildListeners(ctx, module); err != nil { return nil, err } - if err = b.r.store.Engine.CompileModule(ctx, module); err != nil { + if err = b.r.store.Engine.CompileModule(ctx, module, c.listeners); err != nil { return nil, err } diff --git a/config_test.go b/config_test.go index 3cc59622..358a5769 100644 --- a/config_test.go +++ b/config_test.go @@ -607,7 +607,7 @@ func Test_compiledModule_Close(t *testing.T) { var cs []*compiledModule for i := 0; i < 10; i++ { m := &wasm.Module{} - err := e.CompileModule(ctx, m) + err := e.CompileModule(ctx, m, nil) require.NoError(t, err) cs = append(cs, &compiledModule{module: m, compiledEngine: e}) } diff --git a/experimental/experimental_test.go b/experimental/experimental_test.go deleted file mode 100644 index 2b89d07e..00000000 --- a/experimental/experimental_test.go +++ /dev/null @@ -1,8 +0,0 @@ -// Package experimental_test includes examples for experimental features. When these complete, they'll end up as real -// examples in the /examples directory. -package experimental_test - -import "context" - -// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. -var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") diff --git a/experimental/listener.go b/experimental/listener.go index 9ea0e854..c7e5c571 100644 --- a/experimental/listener.go +++ b/experimental/listener.go @@ -8,8 +8,6 @@ import ( // FunctionListenerFactoryKey is a context.Context Value key. Its associated value should be a FunctionListenerFactory. // -// Note: This is interpreter-only for now! -// // See https://github.com/tetratelabs/wazero/issues/451 type FunctionListenerFactoryKey struct{} diff --git a/experimental/listener_example_test.go b/experimental/listener_example_test.go index 0e3b4276..ee1fbe0b 100644 --- a/experimental/listener_example_test.go +++ b/experimental/listener_example_test.go @@ -58,7 +58,7 @@ func Example_customListenerFactory() { // Set context to one that has an experimental listener ctx := context.WithValue(context.Background(), FunctionListenerFactoryKey{}, u) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) // This closes everything this Runtime created. wasi_snapshot_preview1.MustInstantiate(ctx, r) diff --git a/experimental/listener_test.go b/experimental/listener_test.go index e034522c..1d557929 100644 --- a/experimental/listener_test.go +++ b/experimental/listener_test.go @@ -8,7 +8,6 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" . "github.com/tetratelabs/wazero/experimental" - "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/wasm" "github.com/tetratelabs/wazero/internal/wasm/binary" @@ -18,17 +17,27 @@ import ( var _ FunctionListenerFactory = &recorder{} type recorder struct { - m map[string]struct{} + m map[string]struct{} + beforeNames, afterNames []string +} + +func (r *recorder) Before(ctx context.Context, def api.FunctionDefinition, _ []uint64) context.Context { + r.beforeNames = append(r.beforeNames, def.DebugName()) + return ctx +} + +func (r *recorder) After(_ context.Context, def api.FunctionDefinition, _ error, _ []uint64) { + r.afterNames = append(r.afterNames, def.DebugName()) } func (r *recorder) NewListener(definition api.FunctionDefinition) FunctionListener { r.m[definition.Name()] = struct{}{} - return nil + return r } func TestFunctionListenerFactory(t *testing.T) { // Set context to one that has an experimental listener - factory := &recorder{map[string]struct{}{}} + factory := &recorder{m: map[string]struct{}{}} ctx := context.WithValue(context.Background(), FunctionListenerFactoryKey{}, factory) // Define a module with two functions @@ -37,37 +46,57 @@ func TestFunctionListenerFactory(t *testing.T) { ImportSection: []*wasm.Import{{}}, FunctionSection: []wasm.Index{0, 0}, CodeSection: []*wasm.Code{ - {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + // fn1 + {Body: []byte{ + // call fn2 twice + wasm.OpcodeCall, 2, + wasm.OpcodeCall, 2, + wasm.OpcodeEnd, + }}, + // fn2 + {Body: []byte{wasm.OpcodeEnd}}, }, + ExportSection: []*wasm.Export{{Name: "fn1", Type: wasm.ExternTypeFunc, Index: 1}}, NameSection: &wasm.NameSection{ ModuleName: "test", FunctionNames: wasm.NameMap{ - {Index: 0, Name: "import"}, // should skip + {Index: 0, Name: "import"}, // should skip for building listeners. {Index: 1, Name: "fn1"}, {Index: 2, Name: "fn2"}, }, }, }) - if platform.CompilerSupported() { - t.Run("fails on compile if compiler", func(t *testing.T) { - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler()) - defer r.Close(testCtx) // This closes everything this Runtime created. - _, err := r.CompileModule(ctx, bin) - require.EqualError(t, err, - "context includes a FunctionListenerFactoryKey, which is only supported in the interpreter") - }) - } - - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) // This closes everything this Runtime created. - _, err := r.CompileModule(ctx, bin) + _, err := r.NewHostModuleBuilder("").NewFunctionBuilder().WithFunc(func() {}).Export("").Instantiate(ctx, r) + require.NoError(t, err) + + // Ensure the imported function was converted to a listener. + require.Equal(t, map[string]struct{}{"": {}}, factory.m) + + compiled, err := r.CompileModule(ctx, bin) require.NoError(t, err) // Ensure each function was converted to a listener eagerly require.Equal(t, map[string]struct{}{ + "": {}, "fn1": {}, "fn2": {}, }, factory.m) + + // Ensures that FunctionListener is a compile-time option, so passing context.Background here + // is ok to use listeners at runtime. + m, err := r.InstantiateModule(context.Background(), compiled, wazero.NewModuleConfig()) + require.NoError(t, err) + + fn1 := m.ExportedFunction("fn1") + require.NotNil(t, fn1) + + _, err = fn1.Call(context.Background()) + require.NoError(t, err) + + require.Equal(t, []string{"test.fn1", "test.fn2", "test.fn2"}, factory.beforeNames) + require.Equal(t, []string{"test.fn2", "test.fn2", "test.fn1"}, factory.afterNames) // after is in the reverse order. } diff --git a/experimental/logging/log_listener_example_test.go b/experimental/logging/log_listener_example_test.go index 481f8cb0..d21d31de 100644 --- a/experimental/logging/log_listener_example_test.go +++ b/experimental/logging/log_listener_example_test.go @@ -25,7 +25,7 @@ func Example_newHostLoggingListenerFactory() { // Set context to one that has an experimental listener ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewHostLoggingListenerFactory(os.Stdout)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) // This closes everything this Runtime created. wasi_snapshot_preview1.MustInstantiate(ctx, r) @@ -63,7 +63,7 @@ func Example_newLoggingListenerFactory() { // Set context to one that has an experimental listener ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(os.Stdout)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) // This closes everything this Runtime created. wasi_snapshot_preview1.MustInstantiate(ctx, r) diff --git a/imports/assemblyscript/assemblyscript_test.go b/imports/assemblyscript/assemblyscript_test.go index bf07c13d..0a19133d 100644 --- a/imports/assemblyscript/assemblyscript_test.go +++ b/imports/assemblyscript/assemblyscript_test.go @@ -422,7 +422,7 @@ func requireProxyModule(t *testing.T, fns FunctionExporter, config wazero.Module // Set context to one that has an experimental listener ctx := context.WithValue(testCtx, FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&log)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) builder := r.NewHostModuleBuilder("env") fns.ExportFunctions(builder) diff --git a/imports/emscripten/emscripten_test.go b/imports/emscripten/emscripten_test.go index 11b42ed0..487a2aad 100644 --- a/imports/emscripten/emscripten_test.go +++ b/imports/emscripten/emscripten_test.go @@ -36,7 +36,7 @@ func TestGrow(t *testing.T) { // Set context to one that has an experimental listener ctx := context.WithValue(testCtx, FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&log)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) wasi_snapshot_preview1.MustInstantiate(ctx, r) @@ -59,7 +59,7 @@ func TestInvoke(t *testing.T) { // Set context to one that has an experimental listener ctx := context.WithValue(testCtx, FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&log)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) _, err := Instantiate(ctx, r) diff --git a/imports/wasi_snapshot_preview1/wasi_test.go b/imports/wasi_snapshot_preview1/wasi_test.go index e8be54ae..7d759275 100644 --- a/imports/wasi_snapshot_preview1/wasi_test.go +++ b/imports/wasi_snapshot_preview1/wasi_test.go @@ -87,7 +87,7 @@ func requireProxyModule(t *testing.T, config wazero.ModuleConfig) (api.Module, a // Set context to one that has an experimental listener ctx := context.WithValue(testCtx, FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&log)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) wasiModuleCompiled, err := (&builder{r}).hostModuleBuilder().Compile(ctx) require.NoError(t, err) @@ -115,7 +115,7 @@ func requireErrnoNosys(t *testing.T, funcName string, params ...uint64) string { // Set context to one that has an experimental listener ctx := context.WithValue(testCtx, FunctionListenerFactoryKey{}, logging.NewHostLoggingListenerFactory(&log)) - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter()) + r := wazero.NewRuntime(ctx) defer r.Close(ctx) // Instantiate the wasi module. diff --git a/internal/engine/compiler/arch_amd64.go b/internal/engine/compiler/arch_amd64.go index c33f80c6..39a53532 100644 --- a/internal/engine/compiler/arch_amd64.go +++ b/internal/engine/compiler/arch_amd64.go @@ -22,6 +22,6 @@ func newArchContextImpl() (ret archContext) { return } // newCompiler returns a new compiler interface which can be used to compile the given function instance. // Note: ir param can be nil for host functions. -func newCompiler(ir *wazeroir.CompilationResult) (compiler, error) { - return newAmd64Compiler(ir) +func newCompiler(ir *wazeroir.CompilationResult, withListener bool) (compiler, error) { + return newAmd64Compiler(ir, withListener) } diff --git a/internal/engine/compiler/arch_arm64.go b/internal/engine/compiler/arch_arm64.go index d975d9bd..e15c08f9 100644 --- a/internal/engine/compiler/arch_arm64.go +++ b/internal/engine/compiler/arch_arm64.go @@ -47,6 +47,6 @@ func newArchContextImpl() archContext { // newCompiler returns a new compiler interface which can be used to compile the given function instance. // Note: ir param can be nil for host functions. -func newCompiler(ir *wazeroir.CompilationResult) (compiler, error) { - return newArm64Compiler(ir) +func newCompiler(ir *wazeroir.CompilationResult, withListener bool) (compiler, error) { + return newArm64Compiler(ir, withListener) } diff --git a/internal/engine/compiler/arch_other.go b/internal/engine/compiler/arch_other.go index 962ee3d9..9fc62359 100644 --- a/internal/engine/compiler/arch_other.go +++ b/internal/engine/compiler/arch_other.go @@ -13,6 +13,6 @@ import ( type archContext struct{} // newCompiler returns an unsupported error. -func newCompiler(ir *wazeroir.CompilationResult) (compiler, error) { +func newCompiler(ir *wazeroir.CompilationResult, _ bool) (compiler, error) { return nil, fmt.Errorf("unsupported GOARCH %s", runtime.GOARCH) } diff --git a/internal/engine/compiler/compiler_bench_test.go b/internal/engine/compiler/compiler_bench_test.go index 0cc88135..1d0ef906 100644 --- a/internal/engine/compiler/compiler_bench_test.go +++ b/internal/engine/compiler/compiler_bench_test.go @@ -25,7 +25,7 @@ func BenchmarkCompiler_compileMemoryCopy(b *testing.B) { testMem[i] = byte(i) } - compiler, _ := newCompiler(&wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}) + compiler, _ := newCompiler(&wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}, false) err := compiler.compilePreamble() requireNoError(b, err) @@ -77,7 +77,7 @@ func BenchmarkCompiler_compileMemoryFill(b *testing.B) { testMem[i] = byte(i) } - compiler, _ := newCompiler(&wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}) + compiler, _ := newCompiler(&wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}, false) err := compiler.compilePreamble() requireNoError(b, err) diff --git a/internal/engine/compiler/compiler_drop_test.go b/internal/engine/compiler/compiler_drop_test.go index 5c39041e..028f157b 100644 --- a/internal/engine/compiler/compiler_drop_test.go +++ b/internal/engine/compiler/compiler_drop_test.go @@ -11,7 +11,7 @@ import ( func Test_compileDropRange(t *testing.T) { t.Run("nil range", func(t *testing.T) { - c, err := newCompiler(nil) // we don't use ir in compileDropRange, so passing nil is fine. + c, err := newCompiler(nil, false) // we don't use ir in compileDropRange, so passing nil is fine. require.NoError(t, err) err = compileDropRange(c, nil) @@ -19,7 +19,7 @@ func Test_compileDropRange(t *testing.T) { }) t.Run("start at the top", func(t *testing.T) { - c, err := newCompiler(nil) // we don't use ir in compileDropRange, so passing nil is fine. + c, err := newCompiler(nil, false) // we don't use ir in compileDropRange, so passing nil is fine. require.NoError(t, err) // Use up all unreserved registers. @@ -94,7 +94,7 @@ func TestRuntimeValueLocationStack_dropsLivesForInclusiveRange(t *testing.T) { func Test_getTemporariesForStackedLiveValues(t *testing.T) { t.Run("no stacked values", func(t *testing.T) { liveValues := []*runtimeValueLocation{{register: 1}, {register: 2}} - c, err := newCompiler(nil) // we don't use ir in compileDropRange, so passing nil is fine. + c, err := newCompiler(nil, false) // we don't use ir in compileDropRange, so passing nil is fine. require.NoError(t, err) gpTmp, vecTmp, err := getTemporariesForStackedLiveValues(c, liveValues) @@ -113,7 +113,7 @@ func Test_getTemporariesForStackedLiveValues(t *testing.T) { {valueType: runtimeValueTypeI32}, {valueType: runtimeValueTypeI64}, } - c, err := newCompiler(nil) // we don't use ir in compileDropRange, so passing nil is fine. + c, err := newCompiler(nil, false) // we don't use ir in compileDropRange, so passing nil is fine. require.NoError(t, err) if !freeRegisterExists { @@ -154,7 +154,7 @@ func Test_getTemporariesForStackedLiveValues(t *testing.T) { {valueType: runtimeValueTypeV128Lo}, {valueType: runtimeValueTypeV128Hi}, } - c, err := newCompiler(nil) // we don't use ir in compileDropRange, so passing nil is fine. + c, err := newCompiler(nil, false) // we don't use ir in compileDropRange, so passing nil is fine. require.NoError(t, err) if !freeRegisterExists { @@ -189,7 +189,7 @@ func Test_migrateLiveValue(t *testing.T) { }) t.Run("already on register", func(t *testing.T) { // This case, we don't use tmp registers. - c, err := newCompiler(nil) // we don't use ir in compileDropRange, so passing nil is fine. + c, err := newCompiler(nil, false) // we don't use ir in compileDropRange, so passing nil is fine. require.NoError(t, err) // Push the dummy values. diff --git a/internal/engine/compiler/compiler_test.go b/internal/engine/compiler/compiler_test.go index 5b62773f..c9f906b5 100644 --- a/internal/engine/compiler/compiler_test.go +++ b/internal/engine/compiler/compiler_test.go @@ -226,7 +226,7 @@ func (j *compilerEnv) exec(codeSegment []byte) { } // newTestCompiler allows us to test a different architecture than the current one. -type newTestCompiler func(ir *wazeroir.CompilationResult) (compiler, error) +type newTestCompiler func(ir *wazeroir.CompilationResult, _ bool) (compiler, error) func (j *compilerEnv) requireNewCompiler(t *testing.T, fn newTestCompiler, ir *wazeroir.CompilationResult) compilerImpl { requireSupportedOSArch(t) @@ -237,7 +237,7 @@ func (j *compilerEnv) requireNewCompiler(t *testing.T, fn newTestCompiler, ir *w Signature: &wasm.FunctionType{}, } } - c, err := fn(ir) + c, err := fn(ir, false) require.NoError(t, err) diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index 92670a5a..570bff09 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -10,6 +10,7 @@ import ( "unsafe" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/compilationcache" "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/internal/version" @@ -118,6 +119,20 @@ type ( // initialFn is the initial function for this call engine. initialFn *function + + // ctx is the context.Context passed to all the host function calls. + // This is modified when there's a function listener call, otherwise it's always the context.Context + // passed to the Call API. + ctx context.Context + // contextStack is a stack of contexts which is pushed and popped by function listeners. + // This is used and modified when there are function listeners. + contextStack *contextStack + } + + // contextStack is a stack of context.Context. + contextStack struct { + self context.Context + prev *contextStack } // moduleContext holds the per-function call specific module information. @@ -441,7 +456,7 @@ func (e *engine) DeleteCompiledModule(module *wasm.Module) { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { +func (e *engine) CompileModule(ctx context.Context, module *wasm.Module, listeners []experimental.FunctionListener) error { if _, ok, err := e.getCodes(module); ok { // cache hit! return nil } else if err != nil { @@ -455,15 +470,18 @@ func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { importedFuncs := module.ImportFuncCount() funcs := make([]*code, len(module.FunctionSection)) + ln := len(listeners) for i, ir := range irs { + withListener := i < ln && listeners[i] != nil + funcIndex := wasm.Index(i) var compiled *code if ir.GoFunc != nil { - if compiled, err = compileGoDefinedHostFunction(ir); err != nil { + if compiled, err = compileGoDefinedHostFunction(ir, withListener); err != nil { def := module.FunctionDefinitionSection[funcIndex+importedFuncs] return fmt.Errorf("error compiling host go func[%s]: %w", def.DebugName(), err) } - } else if compiled, err = compileWasmFunction(e.enabledFeatures, ir); err != nil { + } else if compiled, err = compileWasmFunction(e.enabledFeatures, ir, withListener); err != nil { def := module.FunctionDefinitionSection[funcIndex+importedFuncs] return fmt.Errorf("error compiling wasm func[%s]: %w", def.DebugName(), err) } @@ -805,6 +823,8 @@ const ( builtinFunctionIndexMemoryGrow wasm.Index = iota builtinFunctionIndexGrowStack builtinFunctionIndexTableGrow + builtinFunctionIndexFunctionListenerBefore + builtinFunctionIndexFunctionListenerAfter // builtinFunctionIndexBreakPoint is internal (only for wazero developers). Disabled by default. builtinFunctionIndexBreakPoint ) @@ -812,6 +832,7 @@ const ( func (ce *callEngine) execWasmFunction(ctx context.Context, callCtx *wasm.CallContext) { codeAddr := ce.initialFn.codeInitialAddress modAddr := ce.initialFn.moduleInstanceAddress + ce.ctx = ctx entry: { @@ -837,9 +858,9 @@ entry: fn := calleeHostFunction.source.GoFunc switch fn := fn.(type) { case api.GoModuleFunction: - fn.Call(ctx, callCtx.WithMemory(ce.memoryInstance), stack) + fn.Call(ce.ctx, callCtx.WithMemory(ce.memoryInstance), stack) case api.GoFunction: - fn.Call(ctx, stack) + fn.Call(ce.ctx, stack) } codeAddr, modAddr = ce.returnAddress, ce.moduleInstanceAddress @@ -848,11 +869,15 @@ entry: caller := ce.moduleContext.fn switch ce.exitContext.builtinFunctionCallIndex { case builtinFunctionIndexMemoryGrow: - ce.builtinFunctionMemoryGrow(ctx, caller.source.Module.Memory) + ce.builtinFunctionMemoryGrow(ce.ctx, caller.source.Module.Memory) case builtinFunctionIndexGrowStack: ce.builtinFunctionGrowStack(caller.stackPointerCeil) case builtinFunctionIndexTableGrow: - ce.builtinFunctionTableGrow(ctx, caller.source.Module.Tables) + ce.builtinFunctionTableGrow(ce.ctx, caller.source.Module.Tables) + case builtinFunctionIndexFunctionListenerBefore: + ce.builtinFunctionFunctionListenerBefore(ce.ctx, caller.source) + case builtinFunctionIndexFunctionListenerAfter: + ce.builtinFunctionFunctionListenerAfter(ce.ctx, caller.source) } if false { if ce.exitContext.builtinFunctionCallIndex == builtinFunctionIndexBreakPoint { @@ -917,8 +942,23 @@ func (ce *callEngine) builtinFunctionTableGrow(ctx context.Context, tables []*wa ce.pushValue(uint64(res)) } -func compileGoDefinedHostFunction(ir *wazeroir.CompilationResult) (*code, error) { - compiler, err := newCompiler(ir) +func (ce *callEngine) builtinFunctionFunctionListenerBefore(ctx context.Context, fn *wasm.FunctionInstance) { + base := int(ce.stackBasePointerInBytes >> 3) + listerCtx := fn.Listener.Before(ctx, fn.Definition, ce.stack[base:base+fn.Type.ParamNumInUint64]) + prevStackTop := ce.contextStack + ce.contextStack = &contextStack{self: ctx, prev: prevStackTop} + ce.ctx = listerCtx +} + +func (ce *callEngine) builtinFunctionFunctionListenerAfter(ctx context.Context, fn *wasm.FunctionInstance) { + base := int(ce.stackBasePointerInBytes >> 3) + fn.Listener.After(ctx, fn.Definition, nil, ce.stack[base:base+fn.Type.ResultNumInUint64]) + ce.ctx = ce.contextStack.self + ce.contextStack = ce.contextStack.prev +} + +func compileGoDefinedHostFunction(ir *wazeroir.CompilationResult, withListener bool) (*code, error) { + compiler, err := newCompiler(ir, withListener) if err != nil { return nil, err } @@ -935,8 +975,8 @@ func compileGoDefinedHostFunction(ir *wazeroir.CompilationResult) (*code, error) return &code{codeSegment: c}, nil } -func compileWasmFunction(_ api.CoreFeatures, ir *wazeroir.CompilationResult) (*code, error) { - compiler, err := newCompiler(ir) +func compileWasmFunction(_ api.CoreFeatures, ir *wazeroir.CompilationResult, withListener bool) (*code, error) { + compiler, err := newCompiler(ir, withListener) if err != nil { return nil, fmt.Errorf("failed to initialize assembly builder: %w", err) } diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index ee0a398d..a45b5713 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -1,6 +1,7 @@ package compiler import ( + "bytes" "context" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/experimental/logging" "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/internal/testing/enginetest" "github.com/tetratelabs/wazero/internal/testing/require" @@ -19,8 +21,12 @@ import ( // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") -// et is used for tests defined in the enginetest package. -var et = &engineTester{} +var ( + // et is used for tests defined in the enginetest package. + et = &engineTester{} + functionLog bytes.Buffer + listenerFactory = logging.NewLoggingListenerFactory(&functionLog) +) // engineTester implements enginetest.EngineTester. type engineTester struct{} @@ -32,7 +38,7 @@ func (e *engineTester) IsCompiler() bool { // ListenerFactory implements the same method as documented on enginetest.EngineTester. func (e *engineTester) ListenerFactory() experimental.FunctionListenerFactory { - return nil + return listenerFactory } // NewEngine implements the same method as documented on enginetest.EngineTester. @@ -82,18 +88,74 @@ func TestCompiler_ModuleEngine_LookupFunction(t *testing.T) { } func TestCompiler_ModuleEngine_Call(t *testing.T) { + defer functionLog.Reset() requireSupportedOSArch(t) enginetest.RunTestModuleEngine_Call(t, et) + require.Equal(t, ` +--> .$0(1,2) +<-- (1,2) +`, "\n"+functionLog.String()) } func TestCompiler_ModuleEngine_Call_HostFn(t *testing.T) { + defer functionLog.Reset() requireSupportedOSArch(t) enginetest.RunTestModuleEngine_Call_HostFn(t, et) } func TestCompiler_ModuleEngine_Call_Errors(t *testing.T) { + defer functionLog.Reset() requireSupportedOSArch(t) enginetest.RunTestModuleEngine_Call_Errors(t, et) + + // TODO: Currently, the listener doesn't get notified on errors as they are + // implemented with panic. This means the end hooks aren't make resulting + // in dangling logs like this: + // ==> host.host_div_by(4294967295) + // instead of seeing a return like + // <== DivByZero + require.Equal(t, ` +--> imported.div_by.wasm(1) +<-- (1) +--> imported.div_by.wasm(1) +<-- (1) +--> imported.div_by.wasm(0) +--> imported.div_by.wasm(1) +<-- (1) +--> imported.call->div_by.go(4294967295) + ==> host.div_by.go(4294967295) +--> imported.call->div_by.go(1) + ==> host.div_by.go(1) + <== (1) +<-- (1) +--> importing.call_import->call->div_by.go(0) + --> imported.call->div_by.go(0) + ==> host.div_by.go(0) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) + <== (1) + <-- (1) +<-- (1) +--> importing.call_import->call->div_by.go(4294967295) + --> imported.call->div_by.go(4294967295) + ==> host.div_by.go(4294967295) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) + <== (1) + <-- (1) +<-- (1) +--> importing.call_import->call->div_by.go(0) + --> imported.call->div_by.go(0) + ==> host.div_by.go(0) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) + <== (1) + <-- (1) +<-- (1) +`, "\n"+functionLog.String()) } func TestCompiler_ModuleEngine_Memory(t *testing.T) { @@ -136,11 +198,11 @@ func TestCompiler_CompileModule(t *testing.T) { ID: wasm.ModuleID{}, } - err := e.CompileModule(testCtx, okModule) + err := e.CompileModule(testCtx, okModule, nil) require.NoError(t, err) // Compiling same module shouldn't be compiled again, but instead should be cached. - err = e.CompileModule(testCtx, okModule) + err = e.CompileModule(testCtx, okModule, nil) require.NoError(t, err) compiled, ok := e.codes[okModule.ID] @@ -167,7 +229,7 @@ func TestCompiler_CompileModule(t *testing.T) { errModule.BuildFunctionDefinitions() e := et.NewEngine(api.CoreFeaturesV1).(*engine) - err := e.CompileModule(testCtx, errModule) + err := e.CompileModule(testCtx, errModule, nil) require.EqualError(t, err, "failed to lower func[.$2] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, the compiled functions must not be cached. @@ -217,7 +279,7 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) { }}, nil, enabledFeatures) require.NoError(t, err) - err = s.Engine.CompileModule(testCtx, hm) + err = s.Engine.CompileModule(testCtx, hm, nil) require.NoError(t, err) _, err = s.Instantiate(testCtx, ns, hm, hostModuleName, nil, nil) @@ -273,7 +335,7 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) { } m.BuildFunctionDefinitions() - err = s.Engine.CompileModule(testCtx, m) + err = s.Engine.CompileModule(testCtx, m, nil) require.NoError(t, err) mi, err := s.Instantiate(testCtx, ns, m, t.Name(), nil, nil) @@ -518,3 +580,65 @@ func Test_callFrameOffset(t *testing.T) { require.Equal(t, 10, callFrameOffset(&wasm.FunctionType{ParamNumInUint64: 10, ResultNumInUint64: 5})) require.Equal(t, 100, callFrameOffset(&wasm.FunctionType{ParamNumInUint64: 100, ResultNumInUint64: 50})) } + +func TestCallEngine_builtinFunctionFunctionListenerBefore(t *testing.T) { + nextContext, currentContext, prevContext := context.Background(), context.Background(), context.Background() + f := &wasm.FunctionInstance{ + Definition: newMockFunctionDefinition("1"), + Type: &wasm.FunctionType{ParamNumInUint64: 3}, + Listener: mockListener{ + before: func(ctx context.Context, def api.FunctionDefinition, paramValues []uint64) context.Context { + require.Equal(t, currentContext, ctx) + require.Equal(t, []uint64{2, 3, 4}, paramValues) + return nextContext + }, + }, + } + ce := &callEngine{ + ctx: currentContext, stack: []uint64{0, 1, 2, 3, 4, 5}, + stackContext: stackContext{stackBasePointerInBytes: 16}, + contextStack: &contextStack{self: prevContext}, + } + ce.builtinFunctionFunctionListenerBefore(ce.ctx, f) + + // Contexts must be stacked. + require.Equal(t, currentContext, ce.contextStack.self) + require.Equal(t, prevContext, ce.contextStack.prev.self) +} + +func TestCallEngine_builtinFunctionFunctionListenerAfter(t *testing.T) { + currentContext, prevContext := context.Background(), context.Background() + f := &wasm.FunctionInstance{ + Definition: newMockFunctionDefinition("1"), + Type: &wasm.FunctionType{ResultNumInUint64: 1}, + Listener: mockListener{ + after: func(ctx context.Context, def api.FunctionDefinition, err error, resultValues []uint64) { + require.Equal(t, currentContext, ctx) + require.Equal(t, []uint64{5}, resultValues) + }, + }, + } + ce := &callEngine{ + ctx: currentContext, stack: []uint64{0, 1, 2, 3, 4, 5}, + stackContext: stackContext{stackBasePointerInBytes: 40}, + contextStack: &contextStack{self: prevContext}, + } + ce.builtinFunctionFunctionListenerAfter(ce.ctx, f) + + // Contexts must be popped. + require.Nil(t, ce.contextStack) + require.Equal(t, prevContext, ce.ctx) +} + +type mockListener struct { + before func(ctx context.Context, def api.FunctionDefinition, paramValues []uint64) context.Context + after func(ctx context.Context, def api.FunctionDefinition, err error, resultValues []uint64) +} + +func (m mockListener) Before(ctx context.Context, def api.FunctionDefinition, paramValues []uint64) context.Context { + return m.before(ctx, def, paramValues) +} + +func (m mockListener) After(ctx context.Context, def api.FunctionDefinition, err error, resultValues []uint64) { + m.after(ctx, def, err, resultValues) +} diff --git a/internal/engine/compiler/impl_amd64.go b/internal/engine/compiler/impl_amd64.go index 281d6df8..03c01e2b 100644 --- a/internal/engine/compiler/impl_amd64.go +++ b/internal/engine/compiler/impl_amd64.go @@ -91,15 +91,17 @@ type amd64Compiler struct { currentLabel string // onStackPointerCeilDeterminedCallBack hold a callback which are called when the max stack pointer is determined BEFORE generating native code. onStackPointerCeilDeterminedCallBack func(stackPointerCeil uint64) + withListener bool } -func newAmd64Compiler(ir *wazeroir.CompilationResult) (compiler, error) { +func newAmd64Compiler(ir *wazeroir.CompilationResult, withListener bool) (compiler, error) { c := &amd64Compiler{ assembler: amd64.NewAssembler(), locationStack: newRuntimeValueLocationStack(), currentLabel: wazeroir.EntrypointLabel, ir: ir, labels: map[string]*amd64LabelInfo{}, + withListener: withListener, } return c, nil } @@ -158,6 +160,12 @@ func (c *amd64Compiler) compileGoDefinedHostFunction() error { // First we must update the location stack to reflect the number of host function inputs. c.locationStack.init(c.ir.Signature) + if c.withListener { + if err := c.compileCallBuiltinFunction(builtinFunctionIndexFunctionListenerBefore); err != nil { + return err + } + } + if err := c.compileCallGoHostFunction(); err != nil { return err } @@ -4572,6 +4580,14 @@ func (c *amd64Compiler) compileReturnFunction() error { return err } + if c.withListener { + if err := c.compileCallBuiltinFunction(builtinFunctionIndexFunctionListenerAfter); err != nil { + return err + } + // After return, we re-initialize the stack base pointer as that is used to return to the caller below. + c.compileReservedStackBasePointerInitialization() + } + // amd64CallingConventionDestinationFunctionModuleInstanceAddressRegister holds the module instance's address // so mark it used so that it won't be used as a free register. c.locationStack.markRegisterUsed(amd64CallingConventionDestinationFunctionModuleInstanceAddressRegister) @@ -4733,6 +4749,12 @@ func (c *amd64Compiler) compilePreamble() (err error) { return err } + if c.withListener { + if err = c.compileCallBuiltinFunction(builtinFunctionIndexFunctionListenerBefore); err != nil { + return err + } + } + c.compileReservedStackBasePointerInitialization() // Finally, we initialize the reserved memory register based on the module context. diff --git a/internal/engine/compiler/impl_arm64.go b/internal/engine/compiler/impl_arm64.go index 09cf5205..c58232b5 100644 --- a/internal/engine/compiler/impl_arm64.go +++ b/internal/engine/compiler/impl_arm64.go @@ -28,14 +28,16 @@ type arm64Compiler struct { stackPointerCeil uint64 // onStackPointerCeilDeterminedCallBack hold a callback which are called when the ceil of stack pointer is determined before generating native code. onStackPointerCeilDeterminedCallBack func(stackPointerCeil uint64) + withListener bool } -func newArm64Compiler(ir *wazeroir.CompilationResult) (compiler, error) { +func newArm64Compiler(ir *wazeroir.CompilationResult, withListener bool) (compiler, error) { return &arm64Compiler{ assembler: arm64.NewAssembler(arm64ReservedRegisterForTemporary), locationStack: newRuntimeValueLocationStack(), ir: ir, labels: map[string]*arm64LabelInfo{}, + withListener: withListener, }, nil } @@ -188,10 +190,17 @@ func (c *arm64Compiler) compilePreamble() error { return err } + if c.withListener { + if err := c.compileCallGoFunction(nativeCallStatusCodeCallBuiltInFunction, builtinFunctionIndexFunctionListenerBefore); err != nil { + return err + } + } + // We must initialize the stack base pointer register so that we can manipulate the stack properly. c.compileReservedStackBasePointerRegisterInitialization() c.compileReservedMemoryRegisterInitialization() + return nil } @@ -263,6 +272,14 @@ func (c *arm64Compiler) compileReturnFunction() error { return err } + if c.withListener { + if err := c.compileCallGoFunction(nativeCallStatusCodeCallBuiltInFunction, builtinFunctionIndexFunctionListenerAfter); err != nil { + return err + } + // After return, we re-initialize the stack base pointer as that is used to return to the caller below. + c.compileReservedStackBasePointerRegisterInitialization() + } + // arm64CallingConventionModuleInstanceAddressRegister holds the module intstance's address // so mark it used so that it won't be used as a free register. c.locationStack.markRegisterUsed(arm64CallingConventionModuleInstanceAddressRegister) @@ -339,6 +356,13 @@ func (c *arm64Compiler) compileGoDefinedHostFunction() error { // First we must update the location stack to reflect the number of host function inputs. c.locationStack.init(c.ir.Signature) + if c.withListener { + if err := c.compileCallGoFunction(nativeCallStatusCodeCallBuiltInFunction, + builtinFunctionIndexFunctionListenerBefore); err != nil { + return err + } + } + if err := c.compileCallGoFunction(nativeCallStatusCodeCallGoHostFunction, 0); err != nil { return err } diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index 5a0b0361..a479388b 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -220,7 +220,7 @@ type interpreterOp struct { const callFrameStackSize = 0 // CompileModule implements the same method as documented on wasm.Engine. -func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { +func (e *engine) CompileModule(ctx context.Context, module *wasm.Module, _ []experimental.FunctionListener) error { if _, ok := e.getCodes(module); ok { // cache hit! return nil } diff --git a/internal/engine/interpreter/interpreter_test.go b/internal/engine/interpreter/interpreter_test.go index 09b8cc80..9ae5890b 100644 --- a/internal/engine/interpreter/interpreter_test.go +++ b/internal/engine/interpreter/interpreter_test.go @@ -542,7 +542,7 @@ func TestInterpreter_Compile(t *testing.T) { } errModule.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, errModule) + err := e.CompileModule(testCtx, errModule, nil) require.EqualError(t, err, "failed to lower func[.$2] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, all the compiled functions including succeeded ones must be released. @@ -563,7 +563,7 @@ func TestInterpreter_Compile(t *testing.T) { }, ID: wasm.ModuleID{}, } - err := e.CompileModule(testCtx, okModule) + err := e.CompileModule(testCtx, okModule, nil) require.NoError(t, err) compiled, ok := e.codes[okModule.ID] diff --git a/internal/integration_test/bench/hostfunc_bench_test.go b/internal/integration_test/bench/hostfunc_bench_test.go index b4b3bba4..cb36f36c 100644 --- a/internal/integration_test/bench/hostfunc_bench_test.go +++ b/internal/integration_test/bench/hostfunc_bench_test.go @@ -191,7 +191,7 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { goReflectFn := host.Exports["go-reflect"].Function wasnFn := host.Exports["wasm"].Function - err := eng.CompileModule(testCtx, hostModule) + err := eng.CompileModule(testCtx, hostModule, nil) requireNoError(err) hostME, err := eng.NewModuleEngine(host.Name, hostModule, nil, host.Functions, nil, nil) @@ -224,7 +224,7 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { } importingModule.BuildFunctionDefinitions() - err = eng.CompileModule(testCtx, importingModule) + err = eng.CompileModule(testCtx, importingModule, nil) requireNoError(err) importing := &wasm.ModuleInstance{TypeIDs: []wasm.FunctionTypeID{0}} diff --git a/internal/integration_test/spectest/spectest.go b/internal/integration_test/spectest/spectest.go index 7c342dd4..d150c203 100644 --- a/internal/integration_test/spectest/spectest.go +++ b/internal/integration_test/spectest/spectest.go @@ -367,7 +367,7 @@ func addSpectestModule(t *testing.T, ctx context.Context, s *wasm.Store, ns *was err = mod.Validate(enabledFeatures) require.NoError(t, err) - err = s.Engine.CompileModule(ctx, mod) + err = s.Engine.CompileModule(ctx, mod, nil) require.NoError(t, err) _, err = s.Instantiate(ctx, ns, mod, mod.NameSection.ModuleName, sys.DefaultContext(nil), nil) @@ -433,7 +433,7 @@ func Run(t *testing.T, testDataFS embed.FS, ctx context.Context, newEngine func( maybeSetMemoryCap(mod) mod.BuildFunctionDefinitions() - err = s.Engine.CompileModule(ctx, mod) + err = s.Engine.CompileModule(ctx, mod, nil) require.NoError(t, err, msg) _, err = s.Instantiate(ctx, ns, mod, moduleName, nil, nil) @@ -571,7 +571,7 @@ func Run(t *testing.T, testDataFS embed.FS, ctx context.Context, newEngine func( maybeSetMemoryCap(mod) mod.BuildFunctionDefinitions() - err = s.Engine.CompileModule(ctx, mod) + err = s.Engine.CompileModule(ctx, mod, nil) require.NoError(t, err, msg) _, err = s.Instantiate(ctx, ns, mod, t.Name(), nil, nil) @@ -604,7 +604,7 @@ func requireInstantiationError(t *testing.T, ctx context.Context, s *wasm.Store, maybeSetMemoryCap(mod) mod.BuildFunctionDefinitions() - err = s.Engine.CompileModule(ctx, mod) + err = s.Engine.CompileModule(ctx, mod, nil) if err != nil { return } diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 6c1fcefd..10c3045d 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -70,7 +70,7 @@ func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { t.Run("sets module name", func(t *testing.T) { m := &wasm.Module{} - err := e.CompileModule(testCtx, m) + err := e.CompileModule(testCtx, m, nil) require.NoError(t, err) me, err := e.NewModuleEngine(t.Name(), m, nil, nil, nil, nil) require.NoError(t, err) @@ -92,7 +92,7 @@ func RunTestEngine_InitializeFuncrefGlobals(t *testing.T, et EngineTester) { }, } m.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, m) + err := e.CompileModule(testCtx, m, nil) require.NoError(t, err) // To use the function, we first need to add it to a module. @@ -140,13 +140,15 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeLocalGet, 1, wasm.OpcodeEnd}}, }, } + m.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, m) + listeners := buildListeners(et.ListenerFactory(), m) + err := e.CompileModule(testCtx, m, listeners) require.NoError(t, err) // To use the function, we first need to add it to a module. module := &wasm.ModuleInstance{Name: t.Name(), TypeIDs: []wasm.FunctionTypeID{0}} - module.Functions = module.BuildFunctions(m, buildListeners(et.ListenerFactory(), m)) + module.Functions = module.BuildFunctions(m, listeners) // Compile the module me, err := e.NewModuleEngine(module.Name, m, nil, module.Functions, nil, nil) @@ -210,7 +212,7 @@ func requireNewModuleEngine_emptyTable(t *testing.T, e wasm.Engine, et EngineTes ID: wasm.ModuleID{0}, } m.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, m) + err := e.CompileModule(testCtx, m, nil) require.NoError(t, err) module = &wasm.ModuleInstance{Name: t.Name(), Tables: tables, TypeIDs: []wasm.FunctionTypeID{0}} @@ -245,7 +247,7 @@ func requireNewModuleEngine_multiTable(t *testing.T, e wasm.Engine, et EngineTes ID: wasm.ModuleID{1}, } m.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, m) + err := e.CompileModule(testCtx, m, nil) require.NoError(t, err) module = &wasm.ModuleInstance{Name: t.Name(), Tables: tables, TypeIDs: []wasm.FunctionTypeID{0}} @@ -284,7 +286,7 @@ func requireNewModuleEngine_tableWithImportedFunction(t *testing.T, e wasm.Engin ID: wasm.ModuleID{2}, } importedModule.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, importedModule) + err := e.CompileModule(testCtx, importedModule, nil) require.NoError(t, err) imported := &wasm.ModuleInstance{Name: t.Name(), Tables: tables, TypeIDs: []wasm.FunctionTypeID{0}} @@ -303,7 +305,7 @@ func requireNewModuleEngine_tableWithImportedFunction(t *testing.T, e wasm.Engin ID: wasm.ModuleID{3}, } importingModule.BuildFunctionDefinitions() - err = e.CompileModule(testCtx, importingModule) + err = e.CompileModule(testCtx, importingModule, nil) require.NoError(t, err) tableInits := []wasm.TableInitEntry{ @@ -336,7 +338,7 @@ func requireNewModuleEngine_tableWithMixedFunctions(t *testing.T, e wasm.Engine, ID: wasm.ModuleID{4}, } importedModule.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, importedModule) + err := e.CompileModule(testCtx, importedModule, nil) require.NoError(t, err) imported := &wasm.ModuleInstance{Name: t.Name(), TypeIDs: []wasm.FunctionTypeID{0}} importedFunctions := imported.BuildFunctions(importedModule, buildListeners(et.ListenerFactory(), importedModule)) @@ -358,7 +360,7 @@ func requireNewModuleEngine_tableWithMixedFunctions(t *testing.T, e wasm.Engine, ID: wasm.ModuleID{5}, } importingModule.BuildFunctionDefinitions() - err = e.CompileModule(testCtx, importingModule) + err = e.CompileModule(testCtx, importingModule, nil) require.NoError(t, err) importing = &wasm.ModuleInstance{Name: t.Name(), Tables: tables, TypeIDs: []wasm.FunctionTypeID{0}} @@ -678,8 +680,9 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { }, } m.BuildFunctionDefinitions() + listeners := buildListeners(et.ListenerFactory(), m) - err := e.CompileModule(testCtx, m) + err := e.CompileModule(testCtx, m, listeners) require.NoError(t, err) // Assign memory to the module instance @@ -692,7 +695,7 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { var memory api.Memory = module.Memory // To use functions, we need to instantiate them (associate them with a ModuleInstance). - module.Functions = module.BuildFunctions(m, buildListeners(et.ListenerFactory(), m)) + module.Functions = module.BuildFunctions(m, listeners) module.BuildExports(m.ExportSection) grow, init := module.Functions[0], module.Functions[1] @@ -820,10 +823,11 @@ func setupCallTests(t *testing.T, e wasm.Engine, divBy *wasm.Code, fnlf experime ID: wasm.ModuleID{0}, } hostModule.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, hostModule) + lns := buildListeners(fnlf, hostModule) + err := e.CompileModule(testCtx, hostModule, lns) require.NoError(t, err) host := &wasm.ModuleInstance{Name: hostModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} - host.Functions = host.BuildFunctions(hostModule, buildListeners(fnlf, hostModule)) + host.Functions = host.BuildFunctions(hostModule, lns) host.BuildExports(hostModule.ExportSection) hostFn := host.Exports[divByGoName].Function @@ -854,11 +858,12 @@ func setupCallTests(t *testing.T, e wasm.Engine, divBy *wasm.Code, fnlf experime ID: wasm.ModuleID{1}, } importedModule.BuildFunctionDefinitions() - err = e.CompileModule(testCtx, importedModule) + lns = buildListeners(fnlf, importedModule) + err = e.CompileModule(testCtx, importedModule, lns) require.NoError(t, err) imported := &wasm.ModuleInstance{Name: importedModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} - importedFunctions := imported.BuildFunctions(importedModule, buildListeners(fnlf, importedModule)) + importedFunctions := imported.BuildFunctions(importedModule, lns) imported.Functions = append([]*wasm.FunctionInstance{hostFn}, importedFunctions...) imported.BuildExports(importedModule.ExportSection) callHostFn := imported.Exports[callDivByGoName].Function @@ -886,12 +891,13 @@ func setupCallTests(t *testing.T, e wasm.Engine, divBy *wasm.Code, fnlf experime ID: wasm.ModuleID{2}, } importingModule.BuildFunctionDefinitions() - err = e.CompileModule(testCtx, importingModule) + lns = buildListeners(fnlf, importingModule) + err = e.CompileModule(testCtx, importingModule, lns) require.NoError(t, err) // Add the exported function. importing := &wasm.ModuleInstance{Name: importingModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} - importingFunctions := importing.BuildFunctions(importingModule, buildListeners(fnlf, importingModule)) + importingFunctions := importing.BuildFunctions(importingModule, lns) importing.Functions = append([]*wasm.FunctionInstance{callHostFn}, importingFunctions...) importing.BuildExports(importingModule.ExportSection) @@ -936,7 +942,7 @@ func setupCallMemTests(t *testing.T, e wasm.Engine, readMem *wasm.Code, fnlf exp ID: wasm.ModuleID{0}, } hostModule.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, hostModule) + err := e.CompileModule(testCtx, hostModule, nil) require.NoError(t, err) host := &wasm.ModuleInstance{Name: hostModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} host.Functions = host.BuildFunctions(hostModule, buildListeners(fnlf, hostModule)) @@ -976,7 +982,7 @@ func setupCallMemTests(t *testing.T, e wasm.Engine, readMem *wasm.Code, fnlf exp ID: wasm.ModuleID{1}, } importingModule.BuildFunctionDefinitions() - err = e.CompileModule(testCtx, importingModule) + err = e.CompileModule(testCtx, importingModule, nil) require.NoError(t, err) // Add the exported function. diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index ebafe321..80c3eef2 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -3,13 +3,15 @@ package wasm import ( "context" "errors" + + "github.com/tetratelabs/wazero/experimental" ) // Engine is a Store-scoped mechanism to compile functions declared or imported by a module. // This is a top-level type implemented by an interpreter or compiler. type Engine interface { // CompileModule implements the same method as documented on wasm.Engine. - CompileModule(ctx context.Context, module *Module) error + CompileModule(ctx context.Context, module *Module, listeners []experimental.FunctionListener) error // CompiledModuleCount is exported for testing, to track the size of the compilation cache. CompiledModuleCount() uint32 diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index cefa68df..8d36bc68 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/sys" "github.com/tetratelabs/wazero/internal/testing/hammer" @@ -362,7 +363,9 @@ func newStore() (*Store, *Namespace) { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *mockEngine) CompileModule(context.Context, *Module) error { return nil } +func (e *mockEngine) CompileModule(context.Context, *Module, []experimental.FunctionListener) error { + return nil +} // LookupFunction implements the same method as documented on wasm.Engine. func (e *mockModuleEngine) LookupFunction(*TableInstance, FunctionTypeID, Index) (Index, error) { diff --git a/runtime.go b/runtime.go index 7ec3755f..730d419d 100644 --- a/runtime.go +++ b/runtime.go @@ -189,11 +189,11 @@ func (r *runtime) CompileModule(ctx context.Context, binary []byte) (CompiledMod c := &compiledModule{module: internal, compiledEngine: r.store.Engine} - if c.listeners, err = buildListeners(ctx, r, internal); err != nil { + if c.listeners, err = buildListeners(ctx, internal); err != nil { return nil, err } - if err = r.store.Engine.CompileModule(ctx, internal); err != nil { + if err = r.store.Engine.CompileModule(ctx, internal, c.listeners); err != nil { return nil, err } @@ -201,15 +201,12 @@ func (r *runtime) CompileModule(ctx context.Context, binary []byte) (CompiledMod return c, nil } -func buildListeners(ctx context.Context, r *runtime, internal *wasm.Module) ([]experimentalapi.FunctionListener, error) { +func buildListeners(ctx context.Context, internal *wasm.Module) ([]experimentalapi.FunctionListener, error) { // Test to see if internal code are using an experimental feature. fnlf := ctx.Value(experimentalapi.FunctionListenerFactoryKey{}) if fnlf == nil { return nil, nil } - if !r.isInterpreter { - return nil, errors.New("context includes a FunctionListenerFactoryKey, which is only supported in the interpreter") - } factory := fnlf.(experimentalapi.FunctionListenerFactory) importCount := internal.ImportFuncCount() listeners := make([]experimentalapi.FunctionListener, len(internal.FunctionSection)) diff --git a/runtime_test.go b/runtime_test.go index 1ee7e4bd..720effdc 100644 --- a/runtime_test.go +++ b/runtime_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/version" @@ -297,7 +298,7 @@ func TestModule_Global(t *testing.T) { code := &compiledModule{module: tc.module} - err := r.store.Engine.CompileModule(testCtx, code.module) + err := r.store.Engine.CompileModule(testCtx, code.module, nil) require.NoError(t, err) // Instantiate the module and get the export of the above global @@ -647,7 +648,7 @@ type mockEngine struct { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *mockEngine) CompileModule(_ context.Context, module *wasm.Module) error { +func (e *mockEngine) CompileModule(_ context.Context, module *wasm.Module, _ []experimental.FunctionListener) error { e.cachedModules[module] = struct{}{} return nil }