diff --git a/api/wasm.go b/api/wasm.go index c427a8eb..1bf33b09 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -272,6 +272,9 @@ type Function interface { // If Module.Close or Module.CloseWithExitCode were invoked during this call, the error returned may be a // sys.ExitError. Interpreting this is specific to the module. For example, some "main" functions always call a // function that exits. + // + // Call is not goroutine-safe, therefore it is recommended to create another Function if you want to invoke + // the same function concurrently. On the other hand, sequential invocations of Call is allowed. Call(ctx context.Context, params ...uint64) ([]uint64, error) } diff --git a/internal/engine/compiler/compiler_bench_test.go b/internal/engine/compiler/compiler_bench_test.go index d9c623a0..b736adfc 100644 --- a/internal/engine/compiler/compiler_bench_test.go +++ b/internal/engine/compiler/compiler_bench_test.go @@ -110,7 +110,7 @@ func BenchmarkCompiler_compileMemoryFill(b *testing.B) { } func (j *compilerEnv) execBench(b *testing.B, codeSegment []byte) { - f := j.newFunctionFrame(codeSegment) + f := j.newFunction(codeSegment) b.StartTimer() for i := 0; i < b.N; i++ { diff --git a/internal/engine/compiler/compiler_test.go b/internal/engine/compiler/compiler_test.go index 5a78f6ec..c2bbc51c 100644 --- a/internal/engine/compiler/compiler_test.go +++ b/internal/engine/compiler/compiler_test.go @@ -220,7 +220,7 @@ func (j *compilerEnv) callEngine() *callEngine { return j.ce } -func (j *compilerEnv) newFunctionFrame(codeSegment []byte) *function { +func (j *compilerEnv) newFunction(codeSegment []byte) *function { return &function{ parent: &code{codeSegment: codeSegment}, codeInitialAddress: uintptr(unsafe.Pointer(&codeSegment[0])), @@ -234,8 +234,10 @@ func (j *compilerEnv) newFunctionFrame(codeSegment []byte) *function { } func (j *compilerEnv) exec(codeSegment []byte) { - j.ce.callFrameStack[j.ce.globalContext.callFrameStackPointer] = callFrame{function: j.newFunctionFrame(codeSegment)} + f := j.newFunction(codeSegment) + j.ce.callFrameStack[j.ce.globalContext.callFrameStackPointer] = callFrame{function: f} j.ce.globalContext.callFrameStackPointer++ + j.ce.compiled = f nativecall( uintptr(unsafe.Pointer(&codeSegment[0])), @@ -292,6 +294,6 @@ func newCompilerEnvironment() *compilerEnv { Globals: []*wasm.GlobalInstance{}, Engine: me, }, - ce: me.newCallEngine(), + ce: me.newCallEngine(nil, nil), } } diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index e911b6e5..61cc4a77 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -67,6 +67,11 @@ type ( // The currently executed function call frame lives at callFrameStack[callFrameStackPointer-1] // and that is equivalent to engine.callFrameTop(). callFrameStack []callFrame + + // compiled is the initial function for this call engine. + compiled *function + // source is the FunctionInstance from which compiled is created from. + source *wasm.FunctionInstance } // globalContext holds the data which is constant across multiple function calls. @@ -380,7 +385,7 @@ func (s nativeCallStatusCode) String() (ret string) { func (c *callFrame) String() string { return fmt.Sprintf( "[%s: return address=0x%x, return stack base pointer=%d]", - c.function.source.Definition().DebugName(), c.returnAddress, c.returnStackBasePointer, + c.function.source.FunctionDefinition.DebugName(), c.returnAddress, c.returnStackBasePointer, ) } @@ -524,8 +529,7 @@ func (e *moduleEngine) InitializeFuncrefGlobals(globals []*wasm.GlobalInstance) } } -// Call implements the same method as documented on wasm.ModuleEngine. -func (e *moduleEngine) Call(ctx context.Context, callCtx *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { +func (e *moduleEngine) NewCallEngine(callCtx *wasm.CallContext, f *wasm.FunctionInstance) (ce wasm.CallEngine, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to // code on close aren't locked, neither is this read. compiled := e.functions[f.Idx] @@ -535,14 +539,16 @@ func (e *moduleEngine) Call(ctx context.Context, callCtx *wasm.CallContext, f *w } return } + return e.newCallEngine(f, compiled), nil +} +// Call implements the same method as documented on wasm.ModuleEngine. +func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, params ...uint64) (results []uint64, err error) { paramCount := len(params) - if f.Type.ParamNumInUint64 != paramCount { - return nil, fmt.Errorf("expected %d params, but passed %d", f.Type.ParamNumInUint64, paramCount) + if ce.source.Type.ParamNumInUint64 != paramCount { + return nil, fmt.Errorf("expected %d params, but passed %d", ce.source.Type.ParamNumInUint64, paramCount) } - ce := e.newCallEngine() - // We ensure that this Call method never panics as // this Call method is indirectly invoked by embedders via store.CallFunction, // and we have to make sure that all the runtime errors, including the one happening inside @@ -555,25 +561,31 @@ func (e *moduleEngine) Call(ctx context.Context, callCtx *wasm.CallContext, f *w // TODO: ^^ Will not fail if the function was imported from a closed module. if v := recover(); v != nil { - builder := wasmdebug.NewErrorBuilder() - // Handle edge-case where the host function is called directly by Go. - if ce.globalContext.callFrameStackPointer == 0 { - def := compiled.source.Definition() - builder.AddFrame(def.DebugName(), def.ParamTypes(), def.ResultTypes()) - } - for i := uint64(0); i < ce.globalContext.callFrameStackPointer; i++ { - def := ce.callFrameStack[ce.globalContext.callFrameStackPointer-1-i].function.source.Definition() - builder.AddFrame(def.DebugName(), def.ParamTypes(), def.ResultTypes()) - } - err = builder.FromRecovered(v) + err = ce.recoverOnCall(v) } }() for _, v := range params { ce.pushValue(v) } - ce.execWasmFunction(ctx, callCtx, compiled) - results = wasm.PopValues(f.Type.ResultNumInUint64, ce.popValue) + ce.execWasmFunction(ctx, callCtx) + results = wasm.PopValues(ce.source.Type.ResultNumInUint64, ce.popValue) + return +} + +// recoverOnCall takes the recovered value `recoverOnCall`, and wraps it +// with the call frame stack traces. Also, reset the state of callEngine +// so that it can be used for the subsequent calls. +func (ce *callEngine) recoverOnCall(v interface{}) (err error) { + builder := wasmdebug.NewErrorBuilder() + for i := uint64(0); i < ce.callFrameStackPointer; i++ { + def := ce.callFrameStack[ce.callFrameStackPointer-1-i].function.source.FunctionDefinition + builder.AddFrame(def.DebugName(), def.ParamTypes(), def.ResultTypes()) + } + err = builder.FromRecovered(v) + + // Allows the reuse of CallEngine. + ce.stackPointer, ce.callFrameStackPointer = 0, 0 return } @@ -627,11 +639,13 @@ var ( initialCallFrameStackSize = 16 ) -func (e *moduleEngine) newCallEngine() *callEngine { +func (e *moduleEngine) newCallEngine(source *wasm.FunctionInstance, compiled *function) *callEngine { ce := &callEngine{ valueStack: make([]uint64, initialValueStackSize), callFrameStack: make([]callFrame, initialCallFrameStackSize), archContext: newArchContext(), + source: source, + compiled: compiled, } valueStackHeader := (*reflect.SliceHeader)(unsafe.Pointer(&ce.valueStack)) @@ -679,9 +693,9 @@ const ( builtinFunctionIndexBreakPoint ) -func (ce *callEngine) execWasmFunction(ctx context.Context, callCtx *wasm.CallContext, f *function) { +func (ce *callEngine) execWasmFunction(ctx context.Context, callCtx *wasm.CallContext) { // Push the initial callframe. - ce.callFrameStack[0] = callFrame{returnAddress: f.codeInitialAddress, function: f} + ce.callFrameStack[0] = callFrame{returnAddress: ce.compiled.codeInitialAddress, function: ce.compiled} ce.globalContext.callFrameStackPointer++ entry: @@ -693,7 +707,7 @@ entry: } // Call into the native code. - nativecall(frame.returnAddress, uintptr(unsafe.Pointer(ce)), f.moduleInstanceAddress) + nativecall(frame.returnAddress, uintptr(unsafe.Pointer(ce)), frame.function.moduleInstanceAddress) // Check the status code from Compiler code. switch status := ce.exitContext.statusCode; status { diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index 1be1c314..ec4b92e0 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -2,11 +2,13 @@ package compiler import ( "context" + "errors" "fmt" "runtime" "testing" "unsafe" + "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/internal/testing/enginetest" @@ -302,3 +304,60 @@ func TestCallEngine_builtinFunctionTableGrow(t *testing.T) { require.Equal(t, 1, len(table.References)) require.Equal(t, uintptr(0xff), table.References[0]) } + +func TestCallEngine_recoverOnCall(t *testing.T) { + ce := &callEngine{ + valueStack: make([]uint64, 100), + valueStackContext: valueStackContext{stackPointer: 3}, + globalContext: globalContext{callFrameStackPointer: 5}, + callFrameStack: []callFrame{ + {function: &function{source: &wasm.FunctionInstance{FunctionDefinition: newMockFunctionDefinition("1")}}}, + {function: &function{source: &wasm.FunctionInstance{FunctionDefinition: newMockFunctionDefinition("2")}}}, + {function: &function{source: &wasm.FunctionInstance{FunctionDefinition: newMockFunctionDefinition("3")}}}, + {function: &function{source: &wasm.FunctionInstance{FunctionDefinition: newMockFunctionDefinition("4")}}}, + {function: &function{source: &wasm.FunctionInstance{FunctionDefinition: newMockFunctionDefinition("5")}}}, + }, + } + + beforeRecoverValueStack, beforeRecoverCallFrameStack := ce.valueStack, ce.callFrameStack + + err := ce.recoverOnCall(errors.New("some error")) + require.EqualError(t, err, `some error (recovered by wazero) +wasm stack trace: + 5() + 4() + 3() + 2() + 1()`) + + // After recover, the stack pointers must be reset, but the underlying slices must be intact + // for the subsequent calls. + require.Equal(t, uint64(0), ce.stackPointer) + require.Equal(t, uint64(0), ce.callFrameStackPointer) + require.Equal(t, beforeRecoverValueStack, ce.valueStack) + require.Equal(t, beforeRecoverCallFrameStack, ce.callFrameStack) +} + +func newMockFunctionDefinition(name string) api.FunctionDefinition { + return &mockFunctionDefinition{debugName: name, FunctionDefinition: &wasm.FunctionDefinition{}} +} + +type mockFunctionDefinition struct { + debugName string + *wasm.FunctionDefinition +} + +// DebugName implements the same method as documented on api.FunctionDefinition. +func (f *mockFunctionDefinition) DebugName() string { + return f.debugName +} + +// ParamTypes implements api.FunctionDefinition ParamTypes. +func (f *mockFunctionDefinition) ParamTypes() []wasm.ValueType { + return []wasm.ValueType{} +} + +// ResultTypes implements api.FunctionDefinition ResultTypes. +func (f *mockFunctionDefinition) ResultTypes() []wasm.ValueType { + return []wasm.ValueType{} +} diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index da2e98b2..bd2c7e74 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -88,10 +88,15 @@ type callEngine struct { // frames are the function call stack. frames []*callFrame + + // compiled is the initial function for this call engine. + compiled *function + // source is the FunctionInstance from which compiled is created from. + source *wasm.FunctionInstance } -func (e *moduleEngine) newCallEngine() *callEngine { - return &callEngine{} +func (e *moduleEngine) newCallEngine(source *wasm.FunctionInstance, compiled *function) *callEngine { + return &callEngine{source: source, compiled: compiled} } func (ce *callEngine) pushValue(v uint64) { @@ -750,25 +755,27 @@ func (e *moduleEngine) InitializeFuncrefGlobals(globals []*wasm.GlobalInstance) } } -// Call implements the same method as documented on wasm.ModuleEngine. -func (e *moduleEngine) Call(ctx context.Context, m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { +func (e *moduleEngine) NewCallEngine(callCtx *wasm.CallContext, f *wasm.FunctionInstance) (ce wasm.CallEngine, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to // code on close aren't locked, neither is this read. compiled := e.functions[f.Idx] if compiled == nil { // Lazy check the cause as it could be because the module was already closed. - if err = m.FailIfClosed(); err == nil { - panic(fmt.Errorf("BUG: %s.codes[%d] was nil before close", e.name, f.Idx)) + if err = callCtx.FailIfClosed(); err == nil { + panic(fmt.Errorf("BUG: %s.func[%d] was nil before close", e.name, f.Idx)) } return } + return e.newCallEngine(f, compiled), nil +} - paramSignature := f.Type.ParamNumInUint64 +// Call implements the same method as documented on wasm.ModuleEngine. +func (ce *callEngine) Call(ctx context.Context, m *wasm.CallContext, params ...uint64) (results []uint64, err error) { + paramSignature := ce.source.Type.ParamNumInUint64 paramCount := len(params) if paramSignature != paramCount { return nil, fmt.Errorf("expected %d params, but passed %d", paramSignature, paramCount) } - ce := e.newCallEngine() defer func() { // If the module closed during the call, and the call didn't err for another reason, set an ExitError. if err == nil { @@ -777,14 +784,7 @@ func (e *moduleEngine) Call(ctx context.Context, m *wasm.CallContext, f *wasm.Fu // TODO: ^^ Will not fail if the function was imported from a closed module. if v := recover(); v != nil { - builder := wasmdebug.NewErrorBuilder() - frameCount := len(ce.frames) - for i := 0; i < frameCount; i++ { - frame := ce.popFrame() - def := frame.f.source.Definition() - builder.AddFrame(def.DebugName(), def.ParamTypes(), def.ResultTypes()) - } - err = builder.FromRecovered(v) + err = ce.recoverOnCall(v) } }() @@ -792,9 +792,27 @@ func (e *moduleEngine) Call(ctx context.Context, m *wasm.CallContext, f *wasm.Fu ce.pushValue(param) } - ce.callFunction(ctx, m, compiled) + ce.callFunction(ctx, m, ce.compiled) - results = wasm.PopValues(f.Type.ResultNumInUint64, ce.popValue) + results = wasm.PopValues(ce.source.Type.ResultNumInUint64, ce.popValue) + return +} + +// recoverOnCall takes the recovered value `recoverOnCall`, and wraps it +// with the call frame stack traces. Also, reset the state of callEngine +// so that it can be used for the subsequent calls. +func (ce *callEngine) recoverOnCall(v interface{}) (err error) { + builder := wasmdebug.NewErrorBuilder() + frameCount := len(ce.frames) + for i := 0; i < frameCount; i++ { + frame := ce.popFrame() + def := frame.f.source.FunctionDefinition + builder.AddFrame(def.DebugName(), def.ParamTypes(), def.ResultTypes()) + } + err = builder.FromRecovered(v) + + // Allows the reuse of CallEngine. + ce.stack, ce.frames = ce.stack[:0], ce.frames[:0] return } @@ -811,7 +829,7 @@ func (ce *callEngine) callFunction(ctx context.Context, callCtx *wasm.CallContex func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { callCtx = callCtx.WithMemory(ce.callerMemory()) if f.source.FunctionListener != nil { - ctx = f.source.FunctionListener.Before(ctx, f.source.Definition(), params) + ctx = f.source.FunctionListener.Before(ctx, f.source.FunctionDefinition, params) } frame := &callFrame{f: f} ce.pushFrame(frame) @@ -819,7 +837,7 @@ func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, ce.popFrame() if f.source.FunctionListener != nil { // TODO: This doesn't get the error due to use of panic to propagate them. - f.source.FunctionListener.After(ctx, f.source.Definition(), nil, results) + f.source.FunctionListener.After(ctx, f.source.FunctionDefinition, nil, results) } return } @@ -4301,10 +4319,10 @@ func i32Abs(v uint32) uint32 { } func (ce *callEngine) callNativeFuncWithListener(ctx context.Context, callCtx *wasm.CallContext, f *function, fnl experimental.FunctionListener) context.Context { - ctx = fnl.Before(ctx, f.source.Definition(), ce.peekValues(len(f.source.Type.Params))) + ctx = fnl.Before(ctx, f.source.FunctionDefinition, ce.peekValues(len(f.source.Type.Params))) ce.callNativeFunc(ctx, callCtx, f) // TODO: This doesn't get the error due to use of panic to propagate them. - fnl.After(ctx, f.source.Definition(), nil, ce.peekValues(len(f.source.Type.Results))) + fnl.After(ctx, f.source.FunctionDefinition, nil, ce.peekValues(len(f.source.Type.Results))) return ctx } diff --git a/internal/integration_test/bench/hostfunc_bench_test.go b/internal/integration_test/bench/hostfunc_bench_test.go index df02aa7b..95f525f1 100644 --- a/internal/integration_test/bench/hostfunc_bench_test.go +++ b/internal/integration_test/bench/hostfunc_bench_test.go @@ -4,6 +4,7 @@ import ( "context" _ "embed" "encoding/binary" + "fmt" "math" "testing" @@ -40,14 +41,14 @@ func BenchmarkHostFunctionCall(b *testing.B) { binary.LittleEndian.PutUint32(m.Memory.Buffer[offset:], math.Float32bits(val)) b.Run(callGoHostName, func(b *testing.B) { - callGoHost := m.Exports[callGoHostName].Function - if callGoHost == nil { - b.Fatal() + ce, err := getCallEngine(m, callGoHostName) + if err != nil { + b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - res, err := callGoHost.Call(testCtx, offset) + res, err := ce.Call(testCtx, m.CallCtx, offset) if err != nil { b.Fatal(err) } @@ -58,14 +59,14 @@ func BenchmarkHostFunctionCall(b *testing.B) { }) b.Run(callWasmHostName, func(b *testing.B) { - callWasmHost := m.Exports[callWasmHostName].Function - if callWasmHost == nil { - b.Fatal() + ce, err := getCallEngine(m, callWasmHostName) + if err != nil { + b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - res, err := callWasmHost.Call(testCtx, offset) + res, err := ce.Call(testCtx, m.CallCtx, offset) if err != nil { b.Fatal(err) } @@ -85,8 +86,11 @@ func TestBenchmarkFunctionCall(t *testing.T) { require.NoError(t, err) }) - callWasmHost := m.Exports[callWasmHostName].Function - callGoHost := m.Exports[callGoHostName].Function + callWasmHost, err := getCallEngine(m, callWasmHostName) + require.NoError(t, err) + + callGoHost, err := getCallEngine(m, callGoHostName) + require.NoError(t, err) require.NotNil(t, callWasmHost) require.NotNil(t, callGoHost) @@ -104,16 +108,16 @@ func TestBenchmarkFunctionCall(t *testing.T) { for _, f := range []struct { name string - f *wasm.FunctionInstance + ce wasm.CallEngine }{ - {name: "go", f: callGoHost}, - {name: "wasm", f: callWasmHost}, + {name: "go", ce: callGoHost}, + {name: "wasm", ce: callWasmHost}, } { f := f t.Run(f.name, func(t *testing.T) { for _, tc := range tests { binary.LittleEndian.PutUint32(mem[tc.offset:], math.Float32bits(tc.val)) - res, err := f.f.Call(context.Background(), uint64(tc.offset)) + res, err := f.ce.Call(context.Background(), m.CallCtx, uint64(tc.offset)) require.NoError(t, err) require.Equal(t, math.Float32bits(tc.val), uint32(res[0])) } @@ -121,6 +125,17 @@ func TestBenchmarkFunctionCall(t *testing.T) { } } +func getCallEngine(m *wasm.ModuleInstance, name string) (ce wasm.CallEngine, err error) { + f := m.Exports[name].Function + if f == nil { + err = fmt.Errorf("%s not found", name) + return + } + + ce, err = m.Engine.NewCallEngine(m.CallCtx, f) + return +} + func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { eng := compiler.NewEngine(context.Background(), wasm.Features20220419) diff --git a/internal/integration_test/engine/adhoc_test.go b/internal/integration_test/engine/adhoc_test.go index 119b9eb9..085bceb8 100644 --- a/internal/integration_test/engine/adhoc_test.go +++ b/internal/integration_test/engine/adhoc_test.go @@ -555,7 +555,7 @@ func testMemOps(t *testing.T, r wazero.Runtime) { require.NoError(t, err) require.Zero(t, results[0]) // should succeed and return the old size in pages. - // Any offset larger than the current size should be out of of bounds error even when it is less than memory capacity. + // Any offset larger than the current size should be out of bounds error even when it is less than memory capacity. _, err = memory.ExportedFunction("store").Call(testCtx, wasm.MemoryPagesToBytesNum(memoryCapacityPages)-8) require.Error(t, err) // Out of bounds error. diff --git a/internal/integration_test/filecache/filecache_test.go b/internal/integration_test/filecache/filecache_test.go index 9c31f968..8f87493f 100644 --- a/internal/integration_test/filecache/filecache_test.go +++ b/internal/integration_test/filecache/filecache_test.go @@ -47,7 +47,7 @@ func TestSpecTestCompilerCache(t *testing.T) { cmd.Stdout = buf cmd.Stderr = buf err = cmd.Run() - require.NoError(t, err) + require.NoError(t, err, buf.String()) exp = append(exp, "PASS\n") } diff --git a/internal/integration_test/vs/jit/jit_test.go b/internal/integration_test/vs/jit/jit_test.go index f0c4dd27..accd3802 100644 --- a/internal/integration_test/vs/jit/jit_test.go +++ b/internal/integration_test/vs/jit/jit_test.go @@ -31,7 +31,3 @@ func TestHostCall(t *testing.T) { func BenchmarkHostCall(b *testing.B) { vs.RunBenchmarkHostCall(b, runtime) } - -func TestBenchmarkHostCall_CompilerFastest(t *testing.T) { - vs.RunTestBenchmarkHostCall_CompilerFastest(t, runtime()) -} diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index f5777983..de243850 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -154,17 +154,27 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. fn := module.Functions[0] - results, err := me.Call(testCtx, module.CallCtx, fn, 1, 2) + + ce, err := me.NewCallEngine(module.CallCtx, fn) + require.NoError(t, err) + + results, err := ce.Call(testCtx, module.CallCtx, 1, 2) require.NoError(t, err) require.Equal(t, []uint64{1, 2}, results) t.Run("errs when not enough parameters", func(t *testing.T) { - _, err := me.Call(testCtx, module.CallCtx, fn) + ce, err := me.NewCallEngine(module.CallCtx, fn) + require.NoError(t, err) + + _, err = ce.Call(testCtx, module.CallCtx) require.EqualError(t, err, "expected 2 params, but passed 0") }) t.Run("errs when too many parameters", func(t *testing.T) { - _, err := me.Call(testCtx, module.CallCtx, fn, 1, 2, 3) + ce, err := me.NewCallEngine(module.CallCtx, fn) + require.NoError(t, err) + + _, err = ce.Call(testCtx, module.CallCtx, 1, 2, 3) require.EqualError(t, err, "expected 2 params, but passed 3") }) } @@ -364,7 +374,10 @@ func runTestModuleEngine_Call_HostFn_Mem(t *testing.T, et EngineTester, readMem tc := tt t.Run(tc.name, func(t *testing.T) { - results, err := tc.fn.Module.Engine.Call(testCtx, importing.CallCtx, tc.fn) + ce, err := tc.fn.Module.Engine.NewCallEngine(tc.fn.Module.CallCtx, tc.fn) + require.NoError(t, err) + + results, err := ce.Call(testCtx, importing.CallCtx) require.NoError(t, err) require.Equal(t, tc.expected, results[0]) }) @@ -416,7 +429,11 @@ func runTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester, hostDivBy *w t.Run(tc.name, func(t *testing.T) { m := tc.module f := tc.fn - results, err := f.Module.Engine.Call(testCtx, m, f, 1) + + ce, err := f.Module.Engine.NewCallEngine(m, f) + require.NoError(t, err) + + results, err := ce.Call(testCtx, m, 1) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -508,11 +525,15 @@ wasm stack trace: t.Run(tc.name, func(t *testing.T) { m := tc.module f := tc.fn - _, err := f.Module.Engine.Call(testCtx, m, f, tc.input...) + + ce, err := f.Module.Engine.NewCallEngine(m, f) + require.NoError(t, err) + + _, err = ce.Call(testCtx, m, tc.input...) require.EqualError(t, err, tc.expectedErr) // Ensure the module still works - results, err := f.Module.Engine.Call(testCtx, m, f, 1) + results, err := ce.Call(testCtx, m, 1) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -594,7 +615,9 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { require.Equal(t, make([]byte, wasmPhraseSize), buf) // Initialize the memory using Wasm. This copies the test phrase. - _, err = me.Call(testCtx, module.CallCtx, init) + initCallEngine, err := me.NewCallEngine(module.CallCtx, init) + require.NoError(t, err) + _, err = initCallEngine.Call(testCtx, module.CallCtx) require.NoError(t, err) // We expect the same []byte read earlier to now include the phrase in wasm. @@ -624,14 +647,18 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { require.Equal(t, hostPhraseTruncated, string(buf2)) // Now, we need to prove the other direction, that when Wasm changes the capacity, the host's buffer is unaffected. - _, err = me.Call(testCtx, module.CallCtx, grow, 1) + growCallEngine, err := me.NewCallEngine(module.CallCtx, grow) + require.NoError(t, err) + _, err = growCallEngine.Call(testCtx, module.CallCtx, 1) require.NoError(t, err) // The host buffer should still contain the same bytes as before grow require.Equal(t, hostPhraseTruncated, string(buf2)) // Re-initialize the memory in wasm, which overwrites the region. - _, err = me.Call(testCtx, module.CallCtx, init) + initCallEngine2, err := me.NewCallEngine(module.CallCtx, init) + require.NoError(t, err) + _, err = initCallEngine2.Call(testCtx, module.CallCtx) require.NoError(t, err) // The host was not affected because it is a different slice due to "memory.grow" affecting the underlying memory. diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index b6440716..982bb79a 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -140,22 +140,47 @@ func (m *CallContext) ExportedFunction(name string) api.Function { if err != nil { return nil } - if exp.Function.Module == m.module { - return exp.Function - } else { - return &importedFn{importingModule: m, importedFn: exp.Function} + + fi := exp.Function + ce, err := exp.Function.Module.Engine.NewCallEngine(m, fi) + if err != nil { + return nil } + + if exp.Function.Module == m.module { + return &function{fi: fi, ce: ce} + } else { + return &importedFn{importingModule: m, importedFn: fi, ce: ce} + } +} + +// function implements api.Function. This couples FunctionInstance with CallEngine so that +// it can be used to make function calls originating from the FunctionInstance. +type function struct { + fi *FunctionInstance + ce CallEngine +} + +// Definition implements the same method as documented on api.FunctionDefinition. +func (f *function) Definition() api.FunctionDefinition { + return f.fi.FunctionDefinition +} + +// Call implements the same method as documented on api.Function. +func (f *function) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { + return f.ce.Call(ctx, f.fi.Module.CallCtx, params...) } // importedFn implements api.Function and ensures the call context of an imported function is the importing module. type importedFn struct { + ce CallEngine importingModule *CallContext importedFn *FunctionInstance } // Definition implements the same method as documented on api.Function. func (f *importedFn) Definition() api.FunctionDefinition { - return f.importedFn.definition + return f.importedFn.FunctionDefinition } // Call implements the same method as documented on api.Function. @@ -164,17 +189,7 @@ func (f *importedFn) Call(ctx context.Context, params ...uint64) (ret []uint64, return nil, fmt.Errorf("directly calling host function is not supported") } mod := f.importingModule - return f.importedFn.Module.Engine.Call(ctx, mod, f.importedFn, params...) -} - -// Call implements the same method as documented on api.Function. -func (f *FunctionInstance) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { - if f.IsHostFunction { - return nil, fmt.Errorf("directly calling host function is not supported") - } - mod := f.Module - ret, err = mod.Engine.Call(ctx, mod.CallCtx, f, params...) - return + return f.ce.Call(ctx, mod, params...) } // ExportedGlobal implements the same method as documented on api.Module. diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index 28da20fb..5a9a3c6c 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -44,8 +44,8 @@ type ModuleEngine interface { // Name returns the name of the module this engine was compiled for. Name() string - // Call invokes a function instance f with given parameters. - Call(ctx context.Context, m *CallContext, f *FunctionInstance, params ...uint64) (results []uint64, err error) + // NewCallEngine returns a CallEngine for the given FunctionInstance. + NewCallEngine(callCtx *CallContext, f *FunctionInstance) (CallEngine, error) // CreateFuncElementInstance creates an ElementInstance whose references are engine-specific function pointers // corresponding to the given `indexes`. @@ -55,6 +55,13 @@ type ModuleEngine interface { InitializeFuncrefGlobals(globals []*GlobalInstance) } +// CallEngine implements function calls for a FunctionInstance. It manages its own call frame stack and value stack, +// internally, and shouldn't be used concurrently. +type CallEngine interface { + // Call invokes a function instance f with given parameters. + Call(ctx context.Context, m *CallContext, params ...uint64) (results []uint64, err error) +} + // TableInitEntry is normalized element segment used for initializing tables by engines. type TableInitEntry struct { TableIndex Index diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index fe575220..3b10945c 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -150,7 +150,7 @@ func newModuleVal(m api.Module) reflect.Value { // MustParseGoFuncCode parses Code from the go function or panics. // -// Exposing this simplifies definition of host functions in built-in host +// Exposing this simplifies FunctionDefinition of host functions in built-in host // modules and tests. func MustParseGoFuncCode(fn interface{}) *Code { _, _, code, err := parseGoFunc(fn) diff --git a/internal/wasm/module.go b/internal/wasm/module.go index a4e35020..00a0185c 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -603,7 +603,7 @@ func (m *ModuleInstance) BuildFunctions(mod *Module, listeners []experimental.Fu f.Module = m f.Idx = d.index f.Type = d.funcType - f.definition = d + f.FunctionDefinition = d if listeners != nil { f.FunctionListener = listeners[i] } diff --git a/internal/wasm/module_test.go b/internal/wasm/module_test.go index 8984388e..ef10495a 100644 --- a/internal/wasm/module_test.go +++ b/internal/wasm/module_test.go @@ -787,7 +787,7 @@ func TestModule_buildFunctions(t *testing.T) { instance := &ModuleInstance{Name: "counter", TypeIDs: []FunctionTypeID{0}} instance.BuildFunctions(m, nil) for i, f := range instance.Functions { - require.Equal(t, i, f.Definition().Index()) + require.Equal(t, i, f.FunctionDefinition.Index()) require.Equal(t, nopCode.Body, f.Body) } } diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 644bfd97..3f596a57 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -140,8 +140,8 @@ type ( // Idx holds the index of this function instance in the function index namespace (beginning with imports). Idx Index - // definition is known at compile time. - definition api.FunctionDefinition + // FunctionDefinition is known at compile time. + FunctionDefinition api.FunctionDefinition // FunctionListener holds a listener to notify when this function is called. FunctionListener experimentalapi.FunctionListener @@ -164,11 +164,6 @@ type ( FunctionTypeID uint32 ) -// Definition implements the same method as documented on api.FunctionDefinition. -func (f *FunctionInstance) Definition() api.FunctionDefinition { - return f.definition -} - // The wazero specific limitations described at RATIONALE.md. const maximumFunctionTypes = 1 << 27 @@ -399,14 +394,21 @@ func (s *Store) instantiate( } // Compile the default context for calls to this module. - m.CallCtx = NewCallContext(ns, m, sysCtx) + callCtx := NewCallContext(ns, m, sysCtx) + m.CallCtx = callCtx // Execute the start function. if module.StartSection != nil { funcIdx := *module.StartSection f := m.Functions[funcIdx] - _, err = f.Module.Engine.Call(ctx, m.CallCtx, f) + ce, err := f.Module.Engine.NewCallEngine(callCtx, f) + if err != nil { + return nil, fmt.Errorf("create call engine for start function[%s]: %v", + module.funcDesc(SectionIDFunction, funcIdx), err) + } + + _, err = ce.Call(ctx, callCtx) if exitErr, ok := err.(*sys.ExitError); ok { // Don't wrap an exit error! return nil, exitErr } else if err != nil { @@ -448,7 +450,7 @@ func resolveImports(module *Module, modules map[string]*ModuleInstance) ( expectedType := module.TypeSection[i.DescFunc] importedFunction := imported.Function - d := importedFunction.Definition() + d := importedFunction.FunctionDefinition if !expectedType.EqualsSignature(d.ParamTypes(), d.ResultTypes()) { actualType := &FunctionType{Params: d.ParamTypes(), Results: d.ResultTypes()} err = errorInvalidImport(i, idx, fmt.Errorf("signature mismatch: %s != %s", expectedType, actualType)) diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 8d73f4ee..24fd1d37 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -337,7 +337,7 @@ func TestCallContext_ExportedFunction(t *testing.T) { fn := importing.ExportedFunction("host.fn") require.NotNil(t, fn) - require.Equal(t, fn.(*importedFn).importedFn, imported.ExportedFunction("host_fn")) + require.Equal(t, fn.(*importedFn).importedFn, imported.ExportedFunction("host_fn").(*function).fi) require.Equal(t, fn.(*importedFn).importingModule, importing) }) } @@ -352,6 +352,11 @@ type mockModuleEngine struct { callFailIndex int } +type mockCallEngine struct { + f *FunctionInstance + callFailIndex int +} + func newStore() (*Store, *Namespace) { return NewStore(Features20191205, &mockEngine{shouldCompileFail: false, callFailIndex: -1}) } @@ -373,8 +378,12 @@ func (e *mockEngine) NewModuleEngine(_ string, _ *Module, _, _ []*FunctionInstan return &mockModuleEngine{callFailIndex: e.callFailIndex}, nil } +func (e *mockModuleEngine) NewCallEngine(callCtx *CallContext, f *FunctionInstance) (CallEngine, error) { + return &mockCallEngine{f: f, callFailIndex: e.callFailIndex}, nil +} + // CreateFuncElementInstance implements the same method as documented on wasm.ModuleEngine. -func (me *mockModuleEngine) CreateFuncElementInstance([]*Index) *ElementInstance { +func (e *mockModuleEngine) CreateFuncElementInstance([]*Index) *ElementInstance { return nil } @@ -386,19 +395,19 @@ func (e *mockModuleEngine) Name() string { return e.name } +// Close implements the same method as documented on wasm.ModuleEngine. +func (e *mockModuleEngine) Close(_ context.Context) { +} + // Call implements the same method as documented on wasm.ModuleEngine. -func (e *mockModuleEngine) Call(ctx context.Context, callCtx *CallContext, f *FunctionInstance, _ ...uint64) (results []uint64, err error) { - if e.callFailIndex >= 0 && f.Definition().Index() == Index(e.callFailIndex) { +func (ce *mockCallEngine) Call(ctx context.Context, callCtx *CallContext, _ ...uint64) (results []uint64, err error) { + if ce.callFailIndex >= 0 && ce.f.FunctionDefinition.Index() == Index(ce.callFailIndex) { err = errors.New("call failed") return } return } -// Close implements the same method as documented on wasm.ModuleEngine. -func (e *mockModuleEngine) Close(_ context.Context) { -} - func TestStore_getFunctionTypeID(t *testing.T) { t.Run("too many functions", func(t *testing.T) { s, _ := newStore() @@ -608,9 +617,9 @@ func Test_resolveImports(t *testing.T) { t.Run("func", func(t *testing.T) { t.Run("ok", func(t *testing.T) { f := &FunctionInstance{ - definition: &FunctionDefinition{funcType: &FunctionType{Results: []ValueType{ValueTypeF32}}}} + FunctionDefinition: &FunctionDefinition{funcType: &FunctionType{Results: []ValueType{ValueTypeF32}}}} g := &FunctionInstance{ - definition: &FunctionDefinition{funcType: &FunctionType{Results: []ValueType{ValueTypeI32}}}} + FunctionDefinition: &FunctionDefinition{funcType: &FunctionType{Results: []ValueType{ValueTypeI32}}}} modules := map[string]*ModuleInstance{ moduleName: { Exports: map[string]*ExportInstance{ @@ -642,7 +651,7 @@ func Test_resolveImports(t *testing.T) { t.Run("signature mismatch", func(t *testing.T) { modules := map[string]*ModuleInstance{ moduleName: {Exports: map[string]*ExportInstance{name: { - Function: &FunctionInstance{definition: &FunctionDefinition{funcType: &FunctionType{}}}, + Function: &FunctionInstance{FunctionDefinition: &FunctionDefinition{funcType: &FunctionType{}}}, }}, Name: moduleName}, } m := &Module{