From 1689fc1bbfd67f074e0400fc1d68840f53ac6b13 Mon Sep 17 00:00:00 2001 From: Crypt Keeper <64215+codefromthecrypt@users.noreply.github.com> Date: Mon, 25 Jul 2022 09:12:44 +0800 Subject: [PATCH] Allows wasm-defined host functions to use memory in interpreter (#713) Before, we allowed stubbed host functions to be defined in wasm instead of Go. This improves performance and reduces a chance of side-effects vs Go. In fact, any pure function was supported in wasm, provided it only called pure functions. This changes internals so that a wasm-defined host function can use memory. Notably, host functions use the caller's memory, so this is simpler to initially support in the interpreter. This is needed to simplify and reduce performance hit of GOARCH=wasm, GOOS=js code, which perform a lot of memory reads and do not have idiomatic signatures. Note: wasm-defined host functions remain internal until we gain experience, at least conclusion of the wasm_exec host module. Signed-off-by: Adrian Cole --- RATIONALE.md | 2 +- api/wasm.go | 13 +- assemblyscript/assemblyscript.go | 6 +- builder_test.go | 27 +- emscripten/emscripten.go | 4 +- experimental/log_listener.go | 4 +- experimental/log_listener_test.go | 5 +- internal/engine/compiler/engine.go | 21 +- internal/engine/compiler/engine_test.go | 5 + internal/engine/interpreter/interpreter.go | 29 +- .../engine/interpreter/interpreter_test.go | 95 +++-- internal/testing/enginetest/enginetest.go | 332 ++++++++++++------ internal/wasm/binary/encoder_test.go | 9 +- internal/wasm/func_validation.go | 21 +- internal/wasm/function_definition.go | 33 +- internal/wasm/function_definition_test.go | 15 +- internal/wasm/gofunc.go | 45 ++- internal/wasm/gofunc_test.go | 43 ++- internal/wasm/host.go | 82 +++-- internal/wasm/host_test.go | 22 +- internal/wasm/module.go | 40 ++- internal/wasm/store.go | 4 + internal/watzero/internal/func_parser_test.go | 4 +- internal/wazeroir/compiler.go | 26 +- internal/wazeroir/compiler_test.go | 144 +++++++- wasi_snapshot_preview1/wasi.go | 9 +- 26 files changed, 688 insertions(+), 352 deletions(-) diff --git a/RATIONALE.md b/RATIONALE.md index fffaa7bb..fd3852dd 100644 --- a/RATIONALE.md +++ b/RATIONALE.md @@ -438,7 +438,7 @@ value (possibly `PWD`). Those unable to control the compiled code should only use absolute paths in configuration. See -* https://github.com/golang/go/blob/go1.19beta1/src/syscall/fs_js.go#L324 +* https://github.com/golang/go/blob/go1.19rc2/src/syscall/fs_js.go#L324 * https://github.com/WebAssembly/wasi-libc/pull/214#issue-673090117 ### FdPrestatDirName diff --git a/api/wasm.go b/api/wasm.go index 08c5a5a8..4db74777 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "math" + "reflect" ) // ExternType classifies imports and exports with their respective types. @@ -227,12 +228,16 @@ type FunctionDefinition interface { // is possible. ExportNames() []string - // IsHostFunction returns true if the function was implemented by the - // embedder (ex via wazero.ModuleBuilder) instead of a wasm binary. + // GoFunc is present when the function was implemented by the embedder + // (ex via wazero.ModuleBuilder) instead of a wasm binary. + // + // This function can be non-deterministic or cause side effects. It also + // has special properties not defined in the WebAssembly Core + // specification. Notably, it uses the caller's memory, which might be + // different from its defining module. // - // Note: Host functions can be non-deterministic or cause side effects. // See https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 - IsHostFunction() bool + GoFunc() *reflect.Value // ParamTypes are the possibly empty sequence of value types accepted by a // function with this signature. diff --git a/assemblyscript/assemblyscript.go b/assemblyscript/assemblyscript.go index d234989b..1e783346 100644 --- a/assemblyscript/assemblyscript.go +++ b/assemblyscript/assemblyscript.go @@ -87,7 +87,7 @@ func NewFunctionExporter() FunctionExporter { } type functionExporter struct { - abortFn, traceFn *wasm.Func + abortFn, traceFn *wasm.HostFunc } // WithAbortMessageDisabled implements FunctionExporter.WithAbortMessageDisabled @@ -130,7 +130,7 @@ var abortMessageEnabled = wasm.NewGoFunc( abortWithMessage, ) -var abortMessageDisabled = abortMessageEnabled.WithGoFunc(abort) +var abortMessageDisabled = abortMessageEnabled.MustGoFunc(abort) // abortWithMessage implements fnAbort func abortWithMessage( @@ -176,7 +176,7 @@ var traceStdout = wasm.NewGoFunc(functionTrace, "~lib/builtins/trace", ) // traceStderr implements trace to the configured Stderr. -var traceStderr = traceStdout.WithGoFunc(func( +var traceStderr = traceStdout.MustGoFunc(func( ctx context.Context, mod api.Module, message uint32, nArgs uint32, arg0, arg1, arg2, arg3, arg4 float64, ) { traceTo(ctx, mod, message, nArgs, arg0, arg1, arg2, arg3, arg4, mod.(*wasm.CallContext).Sys.Stderr()) diff --git a/builder_test.go b/builder_test.go index c1c772ed..5b75bef7 100644 --- a/builder_test.go +++ b/builder_test.go @@ -2,7 +2,6 @@ package wazero import ( "math" - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -19,11 +18,9 @@ func TestNewModuleBuilder_Compile(t *testing.T) { uint32_uint32 := func(uint32) uint32 { return 0 } - fnUint32_uint32 := reflect.ValueOf(uint32_uint32) uint64_uint32 := func(uint64) uint32 { return 0 } - fnUint64_uint32 := reflect.ValueOf(uint64_uint32) tests := []struct { name string @@ -54,7 +51,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -73,7 +70,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -93,7 +90,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -114,7 +111,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}, {GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -138,7 +135,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}, {GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -163,7 +160,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}, {GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -464,8 +461,16 @@ func requireHostModuleEquals(t *testing.T, expected, actual *wasm.Module) { require.Equal(t, expected.NameSection, actual.NameSection) // Special case because reflect.Value can't be compared with Equals + // TODO: This is copy/paste with /internal/wasm/host_test.go require.Equal(t, len(expected.CodeSection), len(actual.CodeSection)) - for _, c := range expected.CodeSection { - require.Equal(t, c.GoFunc.Type(), c.GoFunc.Type()) + for i, c := range expected.CodeSection { + actualCode := actual.CodeSection[i] + require.True(t, actualCode.IsHostFunction) + require.Equal(t, c.Kind, actualCode.Kind) + require.Equal(t, c.GoFunc.Type(), actualCode.GoFunc.Type()) + + // Not wasm + require.Nil(t, actualCode.Body) + require.Nil(t, actualCode.LocalTypes) } } diff --git a/emscripten/emscripten.go b/emscripten/emscripten.go index 78f689ec..ab8d2781 100644 --- a/emscripten/emscripten.go +++ b/emscripten/emscripten.go @@ -71,10 +71,10 @@ func (e *functionExporter) ExportFunctions(builder wazero.ModuleBuilder) { // and https://emscripten.org/docs/api_reference/emscripten.h.html#abi-functions const functionNotifyMemoryGrowth = "emscripten_notify_memory_growth" -var notifyMemoryGrowth = &wasm.Func{ +var notifyMemoryGrowth = &wasm.HostFunc{ ExportNames: []string{functionNotifyMemoryGrowth}, Name: functionNotifyMemoryGrowth, ParamTypes: []wasm.ValueType{wasm.ValueTypeI32}, ParamNames: []string{"memory_index"}, - Code: &wasm.Code{Body: []byte{wasm.OpcodeEnd}}, + Code: &wasm.Code{IsHostFunction: true, Body: []byte{wasm.OpcodeEnd}}, } diff --git a/experimental/log_listener.go b/experimental/log_listener.go index 94c1d43d..93056dee 100644 --- a/experimental/log_listener.go +++ b/experimental/log_listener.go @@ -63,14 +63,14 @@ func (l *loggingListener) writeIndented(before bool, err error, vals []uint64, i message.WriteByte('\t') } if before { - if l.fnd.IsHostFunction() { + if l.fnd.GoFunc() != nil { message.WriteString("==> ") } else { message.WriteString("--> ") } l.writeFuncEnter(&message, vals) } else { // after - if l.fnd.IsHostFunction() { + if l.fnd.GoFunc() != nil { message.WriteString("<== ") } else { message.WriteString("<-- ") diff --git a/experimental/log_listener_test.go b/experimental/log_listener_test.go index c7961b51..f41c8939 100644 --- a/experimental/log_listener_test.go +++ b/experimental/log_listener_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "math" - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -263,7 +262,7 @@ func Test_loggingListener(t *testing.T) { var out bytes.Buffer lf := experimental.NewLoggingListenerFactory(&out) - fnV := reflect.ValueOf(func() {}) + fn := func() {} for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { @@ -284,7 +283,7 @@ func Test_loggingListener(t *testing.T) { } if tc.isHostFunc { - m.CodeSection = []*wasm.Code{{GoFunc: &fnV}} + m.CodeSection = []*wasm.Code{wasm.MustParseGoFuncCode(fn)} } else { m.CodeSection = []*wasm.Code{{Body: []byte{wasm.OpcodeEnd}}} } diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index e08a7301..f958512b 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -424,16 +424,16 @@ func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { for funcIndex, ir := range irs { var compiled *code if ir.GoFunc != nil { - sig := module.TypeSection[module.FunctionSection[funcIndex]] - if compiled, err = compileHostFunction(sig); err != nil { + if compiled, err = compileHostFunction(ir); err != nil { def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] - return fmt.Errorf("error compiling host func[%s]: %w", def.DebugName(), err) - } - } else { - if compiled, err = compileWasmFunction(e.enabledFeatures, ir); err != nil { - def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] - return fmt.Errorf("error compiling wasm func[%s]: %w", def.DebugName(), err) + return fmt.Errorf("error compiling host go func[%s]: %w", def.DebugName(), err) } + } else if ir.IsHostFunction && ir.UsesMemory { + def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] + return fmt.Errorf("error compiling host wasm func[%s]: memory access not yet supported", def.DebugName()) + } else if compiled, err = compileWasmFunction(e.enabledFeatures, ir); err != nil { + def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] + return fmt.Errorf("error compiling wasm func[%s]: %w", def.DebugName(), err) } // As this uses mmap, we need to munmap on the compiled machine code when it's GCed. @@ -811,12 +811,13 @@ func (ce *callEngine) builtinFunctionTableGrow(ctx context.Context, tables []*wa ce.pushValue(uint64(res)) } -func compileHostFunction(sig *wasm.FunctionType) (*code, error) { - compiler, err := newCompiler(&wazeroir.CompilationResult{Signature: sig}) +func compileHostFunction(ir *wazeroir.CompilationResult) (*code, error) { + compiler, err := newCompiler(ir) if err != nil { return nil, err } + // TODO: Consider when no memory is used (ex !ir.UsesMemory) if err = compiler.compileHostFunction(); err != nil { return nil, err } diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index 14008132..176a4934 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -23,6 +23,11 @@ var et = &engineTester{} // engineTester implements enginetest.EngineTester. type engineTester struct{} +// IsCompiler implements the same method as documented on enginetest.EngineTester. +func (e *engineTester) IsCompiler() bool { + return true +} + // ListenerFactory implements the same method as documented on enginetest.EngineTester. func (e *engineTester) ListenerFactory() experimental.FunctionListenerFactory { return nil diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index 999dbba5..72224004 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -809,10 +809,7 @@ func (ce *callEngine) callFunction(ctx context.Context, callCtx *wasm.CallContex } func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { - if len(ce.frames) > 0 { - // Use the caller's memory, which might be different from the defining module on an imported function. - callCtx = callCtx.WithMemory(ce.frames[len(ce.frames)-1].f.source.Module.Memory) - } + callCtx = callCtx.WithMemory(ce.callerMemory()) if f.source.FunctionListener != nil { ctx = f.source.FunctionListener.Before(ctx, params) } @@ -830,11 +827,19 @@ func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallContext, f *function) { frame := &callFrame{f: f} moduleInst := f.source.Module - memoryInst := moduleInst.Memory + functions := moduleInst.Engine.(*moduleEngine).functions + var memoryInst *wasm.MemoryInstance + if f.source.IsHostFunction { + if memoryInst = ce.callerMemory(); memoryInst == nil { + // If there was no caller, someone called a host function directly. + memoryInst = callCtx.Memory().(*wasm.MemoryInstance) + } + } else { + memoryInst = moduleInst.Memory + } globals := moduleInst.Globals tables := moduleInst.Tables typeIDs := f.source.Module.TypeIDs - functions := f.source.Module.Engine.(*moduleEngine).functions dataInstances := f.source.Module.DataInstances elementInstances := f.source.Module.ElementInstances ce.pushFrame(frame) @@ -4085,6 +4090,18 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallCont ce.popFrame() } +// callerMemory returns the caller context memory or nil if the host function +// was called directly. +func (ce *callEngine) callerMemory() *wasm.MemoryInstance { + if len(ce.frames) > 0 { + lastFrameSource := ce.frames[len(ce.frames)-1].f.source + if !lastFrameSource.IsHostFunction { + return lastFrameSource.Module.Memory + } + } + return nil +} + func WasmCompatMax32bits(v1, v2 uint32) uint64 { return uint64(math.Float32bits(moremath.WasmCompatMax32( math.Float32frombits(v1), diff --git a/internal/engine/interpreter/interpreter_test.go b/internal/engine/interpreter/interpreter_test.go index 7d531438..c395d7af 100644 --- a/internal/engine/interpreter/interpreter_test.go +++ b/internal/engine/interpreter/interpreter_test.go @@ -70,6 +70,12 @@ var listenerFactory = experimental.NewLoggingListenerFactory(&functionLog) // engineTester implements enginetest.EngineTester. type engineTester struct{} +// IsCompiler implements enginetest.EngineTester NewEngine. +func (e engineTester) IsCompiler() bool { + return false +} + +// ListenerFactory implements enginetest.EngineTester NewEngine. func (e engineTester) ListenerFactory() experimental.FunctionListenerFactory { return listenerFactory } @@ -117,31 +123,15 @@ func TestInterpreter_Engine_NewModuleEngine_InitTable(t *testing.T) { func TestInterpreter_ModuleEngine_Call(t *testing.T) { defer functionLog.Reset() enginetest.RunTestModuleEngine_Call(t, et) - require.Equal(t, `--> .$0(1,2) + require.Equal(t, ` +--> .$0(1,2) <-- (1,2) -`, functionLog.String()) +`, "\n"+functionLog.String()) } func TestInterpreter_ModuleEngine_Call_HostFn(t *testing.T) { defer functionLog.Reset() enginetest.RunTestModuleEngine_Call_HostFn(t, et) - require.Equal(t, `==> .$0(3) -<== (3) ---> imported.wasm_div_by(1) -<-- (1) -==> host.host_div_by(1) -<== (1) ---> imported.call->host_div_by(1) - ==> host.host_div_by(1) - <== (1) -<-- (1) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) - <== (1) - <-- (1) -<-- (1) -`, functionLog.String()) } func TestInterpreter_ModuleEngine_Call_Errors(t *testing.T) { @@ -154,57 +144,58 @@ func TestInterpreter_ModuleEngine_Call_Errors(t *testing.T) { // ==> host.host_div_by(4294967295) // instead of seeing a return like // <== DivByZero - require.Equal(t, `==> host.host_div_by(1) + require.Equal(t, ` +==> host.div_by.go(1) <== (1) -==> host.host_div_by(1) +==> host.div_by.go(1) <== (1) ---> imported.wasm_div_by(1) +--> imported.div_by.wasm(1) <-- (1) ---> imported.wasm_div_by(1) +--> imported.div_by.wasm(1) <-- (1) ---> imported.wasm_div_by(0) ---> imported.wasm_div_by(1) +--> imported.div_by.wasm(0) +--> imported.div_by.wasm(1) <-- (1) -==> host.host_div_by(4294967295) -==> host.host_div_by(1) +==> host.div_by.go(4294967295) +==> host.div_by.go(1) <== (1) -==> host.host_div_by(0) -==> host.host_div_by(1) +==> host.div_by.go(0) +==> host.div_by.go(1) <== (1) ---> imported.call->host_div_by(4294967295) - ==> host.host_div_by(4294967295) ---> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> imported.call->div_by.go(4294967295) + ==> host.div_by.go(4294967295) +--> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) ---> importing.call_import->call->host_div_by(0) - --> imported.call->host_div_by(0) - ==> host.host_div_by(0) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> importing.call_import->call->div_by.go(0) + --> imported.call->div_by.go(0) + ==> host.div_by.go(0) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) <-- (1) ---> importing.call_import->call->host_div_by(4294967295) - --> imported.call->host_div_by(4294967295) - ==> host.host_div_by(4294967295) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> importing.call_import->call->div_by.go(4294967295) + --> imported.call->div_by.go(4294967295) + ==> host.div_by.go(4294967295) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) <-- (1) ---> importing.call_import->call->host_div_by(0) - --> imported.call->host_div_by(0) - ==> host.host_div_by(0) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> importing.call_import->call->div_by.go(0) + --> imported.call->div_by.go(0) + ==> host.div_by.go(0) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) <-- (1) -`, functionLog.String()) +`, "\n"+functionLog.String()) } func TestInterpreter_ModuleEngine_Memory(t *testing.T) { diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 8e9bf062..073159d0 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -20,15 +20,19 @@ import ( "context" "errors" "math" - "reflect" "testing" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/testing/require" + "github.com/tetratelabs/wazero/internal/u64" "github.com/tetratelabs/wazero/internal/wasm" ) +const ( + i32, i64 = wasm.ValueTypeI32, wasm.ValueTypeI64 +) + var ( // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") @@ -37,6 +41,9 @@ var ( ) type EngineTester interface { + // IsCompiler returns true if this engine is a compiler. + IsCompiler() bool + NewEngine(enabledFeatures wasm.Features) wasm.Engine ListenerFactory() experimental.FunctionListenerFactory @@ -70,14 +77,14 @@ func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { func RunTestEngine_InitializeFuncrefGlobals(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20220419) - i64 := wasm.ValueTypeI64 + i64 := i64 m := &wasm.Module{ TypeSection: []*wasm.FunctionType{{Params: []wasm.ValueType{i64}, Results: []wasm.ValueType{i64}}}, FunctionSection: []uint32{0, 0, 0}, CodeSection: []*wasm.Code{ - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}, - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}, - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{i64}}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{i64}}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{i64}}, }, } m.BuildFunctionDefinitions() @@ -92,7 +99,7 @@ func RunTestEngine_InitializeFuncrefGlobals(t *testing.T, et EngineTester) { nullRefVal := wasm.GlobalInstanceNullFuncRefValue globals := []*wasm.GlobalInstance{ - {Val: 10, Type: &wasm.GlobalType{ValType: wasm.ValueTypeI32}}, + {Val: 10, Type: &wasm.GlobalType{ValType: i32}}, {Val: uint64(nullRefVal), Type: &wasm.GlobalType{ValType: wasm.ValueTypeFuncref}}, {Val: uint64(2), Type: &wasm.GlobalType{ValType: wasm.ValueTypeFuncref}}, {Val: uint64(1), Type: &wasm.GlobalType{ValType: wasm.ValueTypeFuncref}}, @@ -115,7 +122,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { // Define a basic function which defines two parameters and two results. // This is used to test results when incorrect arity is used. - i64 := wasm.ValueTypeI64 + i64 := i64 m := &wasm.Module{ TypeSection: []*wasm.FunctionType{ { @@ -227,7 +234,8 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []uint32{0, 0, 0, 0}, CodeSection: []*wasm.Code{ - {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, }, ID: wasm.ModuleID{2}, } @@ -326,55 +334,75 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { }) } -func runTestModuleEngine_Call_HostFn_ModuleContext(t *testing.T, et EngineTester) { - features := wasm.Features20191205 - e := et.NewEngine(features) +func runTestModuleEngine_Call_HostFn_Mem(t *testing.T, et EngineTester, readMem *wasm.Code) { + e := et.NewEngine(wasm.Features20191205) + host, importing, close := setupCallMemTests(t, e, readMem, et.ListenerFactory()) + defer close() - sig := &wasm.FunctionType{ - Params: []wasm.ValueType{wasm.ValueTypeI64}, - Results: []wasm.ValueType{wasm.ValueTypeI64}, - ParamNumInUint64: 1, ResultNumInUint64: 1, + hostMemoryVal := uint64(3) + host.Memory = &wasm.MemoryInstance{Buffer: u64.LeBytes(hostMemoryVal), Min: 1, Cap: 1, Max: 1} + importingMemoryVal := uint64(6) + importing.Memory = &wasm.MemoryInstance{Buffer: u64.LeBytes(importingMemoryVal), Min: 1, Cap: 1, Max: 1} + + tests := []struct { + name string + module *wasm.CallContext + fn *wasm.FunctionInstance + expected uint64 + }{ + { + name: readMemName, + module: host.CallCtx, + fn: host.Exports[readMemName].Function, + expected: hostMemoryVal, + }, + { + name: callReadMemName, + module: host.CallCtx, + fn: host.Exports[callReadMemName].Function, + expected: hostMemoryVal, + }, + { + name: callImportReadMemName, + module: importing.CallCtx, + fn: importing.Exports[callImportReadMemName].Function, + expected: importingMemoryVal, + }, + { + name: callImportCallReadMemName, + module: importing.CallCtx, + fn: importing.Exports[callImportCallReadMemName].Function, + expected: importingMemoryVal, + }, } + for _, tt := range tests { + tc := tt - memory := &wasm.MemoryInstance{} - var mMemory api.Memory - host := reflect.ValueOf(func(m api.Module, v uint64) uint64 { - mMemory = m.Memory() - return v - }) - - m := &wasm.Module{ - TypeSection: []*wasm.FunctionType{sig}, - FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &host}}, + t.Run(tc.name, func(t *testing.T) { + results, err := tc.fn.Module.Engine.Call(testCtx, tc.module, tc.fn) + require.NoError(t, err) + require.Equal(t, tc.expected, results[0]) + }) } - m.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, m) - require.NoError(t, err) - - module := &wasm.ModuleInstance{Memory: memory, TypeIDs: []wasm.FunctionTypeID{0}} - _, ns := wasm.NewStore(features, e) - modCtx := wasm.NewCallContext(ns, module, nil) - - fns := module.BuildFunctions(m, buildListeners(et.ListenerFactory(), m)) - me, err := e.NewModuleEngine(t.Name(), m, nil, fns, nil, nil) - require.NoError(t, err) - - t.Run("defaults to module memory when call stack empty", func(t *testing.T) { - // When calling a host func directly, there may be no stack. This ensures the module's memory is used. - results, err := me.Call(testCtx, modCtx, fns[0], 3) - require.NoError(t, err) - require.Equal(t, uint64(3), results[0]) - require.Same(t, memory, mMemory) - }) } func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { - runTestModuleEngine_Call_HostFn_ModuleContext(t, et) // TODO: refactor to use the same test interface. + t.Run("wasm", func(t *testing.T) { + runTestModuleEngine_Call_HostFn(t, et, hostDivByWasm) + if !et.IsCompiler() { // TODO: Support host wasm func that uses caller memory in compiler. + runTestModuleEngine_Call_HostFn_Mem(t, et, hostReadMemWasm) + } + }) + t.Run("go", func(t *testing.T) { + runTestModuleEngine_Call_HostFn(t, et, hostDivByGo) + runTestModuleEngine_Call_HostFn_Mem(t, et, hostReadMemGo) + }) +} +func runTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester, hostDivBy *wasm.Code) { e := et.NewEngine(wasm.Features20191205) - host, imported, importing, close := setupCallTests(t, e, et.ListenerFactory()) + host, imported, importing, close := setupCallTests(t, e, hostDivBy, et.ListenerFactory()) defer close() // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. @@ -384,24 +412,24 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { fn *wasm.FunctionInstance }{ { - name: wasmFnName, + name: divByWasmName, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, }, { - name: hostFnName, + name: divByGoName, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, }, { - name: callHostFnName, + name: callDivByGoName, module: imported.CallCtx, - fn: imported.Exports[callHostFnName].Function, + fn: imported.Exports[callDivByGoName].Function, }, { - name: callImportCallHostFnName, + name: callImportCallDivByGoName, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, }, } for _, tt := range tests { @@ -420,7 +448,7 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { func RunTestModuleEngine_Call_Errors(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) - host, imported, importing, close := setupCallTests(t, e, et.ListenerFactory()) + host, imported, importing, close := setupCallTests(t, e, hostDivByGo, et.ListenerFactory()) defer close() tests := []struct { @@ -434,99 +462,99 @@ func RunTestModuleEngine_Call_Errors(t *testing.T, et EngineTester) { name: "host function not enough parameters", input: []uint64{}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `expected 1 params, but passed 0`, }, { name: "host function too many parameters", input: []uint64{1, 2}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `expected 1 params, but passed 2`, }, { name: "wasm function not enough parameters", input: []uint64{}, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, expectedErr: `expected 1 params, but passed 0`, }, { name: "wasm function too many parameters", input: []uint64{1, 2}, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, expectedErr: `expected 1 params, but passed 2`, }, { name: "wasm function panics with wasmruntime.Error", input: []uint64{0}, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, expectedErr: `wasm error: integer divide by zero wasm stack trace: - imported.wasm_div_by(i32) i32`, + imported.div_by.wasm(i32) i32`, }, { name: "host function that panics", input: []uint64{math.MaxUint32}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32`, + host.div_by.go(i32) i32`, }, { name: "host function panics with runtime.Error", input: []uint64{0}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32`, + host.div_by.go(i32) i32`, }, { name: "wasm calls host function that panics", input: []uint64{math.MaxUint32}, module: imported.CallCtx, - fn: imported.Exports[callHostFnName].Function, + fn: imported.Exports[callDivByGoName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32`, }, { name: "wasm calls imported wasm that calls host function panics with runtime.Error", input: []uint64{0}, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32 - importing.call_import->call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32 + importing.call_import->call->div_by.go(i32) i32`, }, { name: "wasm calls imported wasm that calls host function that panics", input: []uint64{math.MaxUint32}, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32 - importing.call_import->call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32 + importing.call_import->call->div_by.go(i32) i32`, }, { name: "wasm calls imported wasm calls host function panics with runtime.Error", input: []uint64{0}, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32 - importing.call_import->call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32 + importing.call_import->call->div_by.go(i32) i32`, }, } for _, tt := range tests { @@ -665,35 +693,61 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { } const ( - wasmFnName = "wasm_div_by" - hostFnName = "host_div_by" - callHostFnName = "call->" + hostFnName - callImportCallHostFnName = "call_import->" + callHostFnName + divByWasmName = "div_by.wasm" + divByGoName = "div_by.go" + callDivByGoName = "call->" + divByGoName + callImportCallDivByGoName = "call_import->" + callDivByGoName ) -// (func (export "wasm_div_by") (param i32) (result i32) (i32.div_u (i32.const 1) (local.get 0))) -var wasmFnBody = []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeLocalGet, 0, wasm.OpcodeI32DivU, wasm.OpcodeEnd} - -func divBy(d uint32) uint32 { +func divByGo(d uint32) uint32 { if d == math.MaxUint32 { panic(errors.New("host-function panic")) } return 1 / d // go panics if d == 0 } -func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListenerFactory) (*wasm.ModuleInstance, *wasm.ModuleInstance, *wasm.ModuleInstance, func()) { - i32 := wasm.ValueTypeI32 +var hostDivByGo = wasm.MustParseGoFuncCode(divByGo) + +// (func (export "div_by.wasm") (param i32) (result i32) (i32.div_u (i32.const 1) (local.get 0))) +var divByWasm = []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeLocalGet, 0, wasm.OpcodeI32DivU, wasm.OpcodeEnd} +var hostDivByWasm = &wasm.Code{IsHostFunction: true, Body: divByWasm} + +const ( + readMemName = "read_mem" + callReadMemName = "call->read_mem" + callImportReadMemName = "call_import->read_mem" + callImportCallReadMemName = "call_import->call->read_mem" +) + +func readMemGo(ctx context.Context, m api.Module) uint64 { + ret, ok := m.Memory().ReadUint64Le(ctx, 0) + if !ok { + panic("couldn't read memory") + } + return ret +} + +var hostReadMemGo = wasm.MustParseGoFuncCode(readMemGo) + +// (func (export "wasm_read_mem") (result i64) i32.const 0 i64.load) +var readMemWasm = []byte{wasm.OpcodeI32Const, 0, wasm.OpcodeI64Load, 0x3, 0x0, wasm.OpcodeEnd} +var hostReadMemWasm = &wasm.Code{IsHostFunction: true, Body: readMemWasm} + +func setupCallTests(t *testing.T, e wasm.Engine, divBy *wasm.Code, fnlf experimental.FunctionListenerFactory) (*wasm.ModuleInstance, *wasm.ModuleInstance, *wasm.ModuleInstance, func()) { ft := &wasm.FunctionType{Params: []wasm.ValueType{i32}, Results: []wasm.ValueType{i32}, ParamNumInUint64: 1, ResultNumInUint64: 1} - hostFnVal := reflect.ValueOf(divBy) + divByName := divByWasmName + if divBy.GoFunc != nil { + divByName = divByGoName + } hostModule := &wasm.Module{ TypeSection: []*wasm.FunctionType{ft}, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &hostFnVal}}, - ExportSection: []*wasm.Export{{Name: hostFnName, Type: wasm.ExternTypeFunc, Index: 0}}, + CodeSection: []*wasm.Code{divBy}, + ExportSection: []*wasm.Export{{Name: divByGoName, Type: wasm.ExternTypeFunc, Index: 0}}, NameSection: &wasm.NameSection{ ModuleName: "host", - FunctionNames: wasm.NameMap{{Index: wasm.Index(0), Name: hostFnName}}, + FunctionNames: wasm.NameMap{{Index: wasm.Index(0), Name: divByName}}, }, ID: wasm.ModuleID{0}, } @@ -703,7 +757,7 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe host := &wasm.ModuleInstance{Name: hostModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} host.Functions = host.BuildFunctions(hostModule, buildListeners(fnlf, hostModule)) host.BuildExports(hostModule.ExportSection) - hostFn := host.Exports[hostFnName].Function + hostFn := host.Exports[divByGoName].Function hostME, err := e.NewModuleEngine(host.Name, hostModule, nil, host.Functions, nil, nil) require.NoError(t, err) @@ -714,19 +768,19 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe TypeSection: []*wasm.FunctionType{ft}, FunctionSection: []uint32{0, 0}, CodeSection: []*wasm.Code{ - {Body: wasmFnBody}, + {Body: divByWasm}, {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, byte(0), // Calling imported host function ^. wasm.OpcodeEnd}}, }, ExportSection: []*wasm.Export{ - {Name: wasmFnName, Type: wasm.ExternTypeFunc, Index: 1}, - {Name: callHostFnName, Type: wasm.ExternTypeFunc, Index: 2}, + {Name: divByWasmName, Type: wasm.ExternTypeFunc, Index: 1}, + {Name: callDivByGoName, Type: wasm.ExternTypeFunc, Index: 2}, }, NameSection: &wasm.NameSection{ ModuleName: "imported", FunctionNames: wasm.NameMap{ - {Index: wasm.Index(1), Name: wasmFnName}, - {Index: wasm.Index(2), Name: callHostFnName}, + {Index: wasm.Index(1), Name: divByWasmName}, + {Index: wasm.Index(2), Name: callDivByGoName}, }, }, ID: wasm.ModuleID{1}, @@ -739,7 +793,7 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe importedFunctions := imported.BuildFunctions(importedModule, buildListeners(fnlf, importedModule)) imported.Functions = append([]*wasm.FunctionInstance{hostFn}, importedFunctions...) imported.BuildExports(importedModule.ExportSection) - callHostFn := imported.Exports[callHostFnName].Function + callHostFn := imported.Exports[callDivByGoName].Function // Compile the imported module importedMe, err := e.NewModuleEngine(imported.Name, importedModule, []*wasm.FunctionInstance{hostFn}, importedFunctions, nil, nil) @@ -755,11 +809,11 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 0 /* only one imported function */, wasm.OpcodeEnd}}, }, ExportSection: []*wasm.Export{ - {Name: callImportCallHostFnName, Type: wasm.ExternTypeFunc, Index: 1}, + {Name: callImportCallDivByGoName, Type: wasm.ExternTypeFunc, Index: 1}, }, NameSection: &wasm.NameSection{ ModuleName: "importing", - FunctionNames: wasm.NameMap{{Index: wasm.Index(1), Name: callImportCallHostFnName}}, + FunctionNames: wasm.NameMap{{Index: wasm.Index(1), Name: callImportCallDivByGoName}}, }, ID: wasm.ModuleID{2}, } @@ -785,6 +839,76 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe } } +func setupCallMemTests(t *testing.T, e wasm.Engine, readMem *wasm.Code, fnlf experimental.FunctionListenerFactory) (*wasm.ModuleInstance, *wasm.ModuleInstance, func()) { + ft := &wasm.FunctionType{Results: []wasm.ValueType{i64}, ResultNumInUint64: 1} + + callReadMem := &wasm.Code{ // shows indirect calls still use the same memory + IsHostFunction: true, + Body: []byte{wasm.OpcodeCall, 0, wasm.OpcodeEnd}, + } + hostModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{ft}, + FunctionSection: []wasm.Index{0, 0}, + CodeSection: []*wasm.Code{readMem, callReadMem}, + ExportSection: []*wasm.Export{ + {Name: readMemName, Type: wasm.ExternTypeFunc, Index: 0}, + {Name: callReadMemName, Type: wasm.ExternTypeFunc, Index: 1}, + }, + NameSection: &wasm.NameSection{ + ModuleName: "host", + FunctionNames: wasm.NameMap{{Index: 0, Name: readMemName}, {Index: 1, Name: callReadMemName}}, + }, + ID: wasm.ModuleID{0}, + } + hostModule.BuildFunctionDefinitions() + err := e.CompileModule(testCtx, hostModule) + require.NoError(t, err) + host := &wasm.ModuleInstance{Name: hostModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} + host.Functions = host.BuildFunctions(hostModule, buildListeners(fnlf, hostModule)) + host.BuildExports(hostModule.ExportSection) + readMemFn := host.Exports[readMemName].Function + callReadMemFn := host.Exports[callReadMemName].Function + + hostME, err := e.NewModuleEngine(host.Name, hostModule, nil, host.Functions, nil, nil) + require.NoError(t, err) + linkModuleToEngine(host, hostME) + + importingModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{ft}, + ExportSection: []*wasm.Export{ + {Name: callImportReadMemName, Type: wasm.ExternTypeFunc, Index: 0}, + {Name: callImportCallReadMemName, Type: wasm.ExternTypeFunc, Index: 1}, + }, + NameSection: &wasm.NameSection{ + ModuleName: "importing", + FunctionNames: wasm.NameMap{ + {Index: 0, Name: callImportReadMemName}, + {Index: 1, Name: callImportCallReadMemName}, + }, + }, + ID: wasm.ModuleID{1}, + } + importingModule.BuildFunctionDefinitions() + err = e.CompileModule(testCtx, importingModule) + require.NoError(t, err) + + // Add the exported function. + importing := &wasm.ModuleInstance{Name: importingModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} + importingFunctions := importing.BuildFunctions(importingModule, buildListeners(fnlf, importingModule)) + importing.Functions = append([]*wasm.FunctionInstance{readMemFn, callReadMemFn}, importingFunctions...) + importing.BuildExports(importingModule.ExportSection) + + // Compile the importing module + importingMe, err := e.NewModuleEngine(importing.Name, importingModule, []*wasm.FunctionInstance{readMemFn}, importingFunctions, nil, nil) + require.NoError(t, err) + linkModuleToEngine(importing, importingMe) + + return host, importing, func() { + e.DeleteCompiledModule(hostModule) + e.DeleteCompiledModule(importingModule) + } +} + // linkModuleToEngine assigns fields that wasm.Store would on instantiation. These includes fields both interpreter and // Compiler needs as well as fields only needed by Compiler. // diff --git a/internal/wasm/binary/encoder_test.go b/internal/wasm/binary/encoder_test.go index 12e7cec3..3f8925e7 100644 --- a/internal/wasm/binary/encoder_test.go +++ b/internal/wasm/binary/encoder_test.go @@ -1,9 +1,9 @@ package binary import ( - "reflect" "testing" + "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/wasm" @@ -211,10 +211,13 @@ func TestModule_Encode(t *testing.T) { func TestModule_Encode_HostFunctionSection_Unsupported(t *testing.T) { // We don't currently have an approach to serialize reflect.Value pointers - fn := reflect.ValueOf(func(wasm.Module) {}) + fn := func(api.Module) {} captured := require.CapturePanic(func() { - EncodeModule(&wasm.Module{CodeSection: []*wasm.Code{{GoFunc: &fn}}}) + EncodeModule(&wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(fn)}, + }) }) require.EqualError(t, captured, "BUG: GoFunc is not encodable") } diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index 97823bf7..58dd236a 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -64,8 +64,9 @@ func (m *Module) validateFunctionWithMaxStackValues( declaredFunctionIndexes map[Index]struct{}, ) error { functionType := m.TypeSection[m.FunctionSection[idx]] - body := m.CodeSection[idx].Body - localTypes := m.CodeSection[idx].LocalTypes + code := m.CodeSection[idx] + body := code.Body + localTypes := code.LocalTypes types := m.TypeSection // We start with the outermost control block which is for function return if the code branches into it. @@ -90,8 +91,8 @@ func (m *Module) validateFunctionWithMaxStackValues( } if OpcodeI32Load <= op && op <= OpcodeI64Store32 { - if memory == nil { - return fmt.Errorf("unknown memory access") + if memory == nil && !code.IsHostFunction { + return fmt.Errorf("memory must exist for %s", InstructionName(op)) } pc++ align, _, read, err := readMemArg(pc, body) @@ -272,8 +273,8 @@ func (m *Module) validateFunctionWithMaxStackValues( } } } else if OpcodeMemorySize <= op && op <= OpcodeMemoryGrow { - if memory == nil { - return fmt.Errorf("unknown memory access") + if memory == nil && !code.IsHostFunction { + return fmt.Errorf("memory must exist for %s", InstructionName(op)) } pc++ val, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) @@ -1100,7 +1101,7 @@ func (m *Module) validateFunctionWithMaxStackValues( OpcodeVecV128Load32x2s, OpcodeVecV128Load32x2u, OpcodeVecV128Load8Splat, OpcodeVecV128Load16Splat, OpcodeVecV128Load32Splat, OpcodeVecV128Load64Splat, OpcodeVecV128Load32zero, OpcodeVecV128Load64zero: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } pc++ @@ -1138,7 +1139,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } valueTypeStack.push(ValueTypeV128) case OpcodeVecV128Store: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } pc++ @@ -1157,7 +1158,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("cannot pop the operand for %s: %v", OpcodeVecV128StoreName, err) } case OpcodeVecV128Load8Lane, OpcodeVecV128Load16Lane, OpcodeVecV128Load32Lane, OpcodeVecV128Load64Lane: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } attr := vecLoadLanes[vecOpcode] @@ -1185,7 +1186,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } valueTypeStack.push(ValueTypeV128) case OpcodeVecV128Store8Lane, OpcodeVecV128Store16Lane, OpcodeVecV128Store32Lane, OpcodeVecV128Store64Lane: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } attr := vecStoreLanes[vecOpcode] diff --git a/internal/wasm/function_definition.go b/internal/wasm/function_definition.go index 81a108b0..238deea8 100644 --- a/internal/wasm/function_definition.go +++ b/internal/wasm/function_definition.go @@ -1,6 +1,8 @@ package wasm import ( + "reflect" + "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/wasmdebug" ) @@ -65,10 +67,11 @@ func (m *Module) BuildFunctionDefinitions() { } for codeIndex, typeIndex := range m.FunctionSection { + code := m.CodeSection[codeIndex] m.FunctionDefinitionSection = append(m.FunctionDefinitionSection, &FunctionDefinition{ - index: Index(codeIndex) + importCount, - funcType: m.TypeSection[typeIndex], - isHostFunction: m.CodeSection[codeIndex].GoFunc != nil, + index: Index(codeIndex) + importCount, + funcType: m.TypeSection[typeIndex], + goFunc: code.GoFunc, }) } @@ -103,15 +106,15 @@ func (m *Module) BuildFunctionDefinitions() { // FunctionDefinition implements api.FunctionDefinition type FunctionDefinition struct { - moduleName string - index Index - name string - debugName string - isHostFunction bool - funcType *FunctionType - importDesc *[2]string - exportNames []string - paramNames []string + moduleName string + index Index + name string + debugName string + goFunc *reflect.Value + funcType *FunctionType + importDesc *[2]string + exportNames []string + paramNames []string } // ModuleName implements the same method as documented on api.FunctionDefinition. @@ -147,9 +150,9 @@ func (f *FunctionDefinition) ExportNames() []string { return f.exportNames } -// IsHostFunction implements the same method as documented on api.FunctionDefinition. -func (f *FunctionDefinition) IsHostFunction() bool { - return f.isHostFunction +// GoFunc implements the same method as documented on api.FunctionDefinition. +func (f *FunctionDefinition) GoFunc() *reflect.Value { + return f.goFunc } // ParamNames implements the same method as documented on api.FunctionDefinition. diff --git a/internal/wasm/function_definition_test.go b/internal/wasm/function_definition_test.go index 1c4bfc3e..8fe3c1aa 100644 --- a/internal/wasm/function_definition_test.go +++ b/internal/wasm/function_definition_test.go @@ -1,7 +1,6 @@ package wasm import ( - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -10,7 +9,7 @@ import ( func TestModule_BuildFunctionDefinitions(t *testing.T) { nopCode := &Code{Body: []byte{OpcodeEnd}} - fnV := reflect.ValueOf(func() {}) + fn := func() {} tests := []struct { name string m *Module @@ -32,18 +31,18 @@ func TestModule_BuildFunctionDefinitions(t *testing.T) { expectedExports: map[string]api.FunctionDefinition{}, }, { - name: "host func", + name: "host func go", m: &Module{ TypeSection: []*FunctionType{v_v}, FunctionSection: []Index{0}, - CodeSection: []*Code{{GoFunc: &fnV}}, + CodeSection: []*Code{MustParseGoFuncCode(fn)}, }, expected: []*FunctionDefinition{ { - index: 0, - debugName: ".$0", - isHostFunction: true, - funcType: v_v, + index: 0, + debugName: ".$0", + goFunc: MustParseGoFuncCode(fn).GoFunc, + funcType: v_v, }, }, expectedExports: map[string]api.FunctionDefinition{}, diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index 2f1ab806..fe575220 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -148,16 +148,30 @@ func newModuleVal(m api.Module) reflect.Value { return val } -// getFunctionType returns the function type corresponding to the function signature or errs if invalid. -func getFunctionType(fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err error) { - p := fn.Type() +// MustParseGoFuncCode parses Code from the go function or panics. +// +// Exposing this simplifies definition of host functions in built-in host +// modules and tests. +func MustParseGoFuncCode(fn interface{}) *Code { + _, _, code, err := parseGoFunc(fn) + if err != nil { + panic(err) + } + return code +} - if fn.Kind() != reflect.Func { - err = fmt.Errorf("kind != func: %s", fn.Kind().String()) +func parseGoFunc(fn interface{}) (params, results []ValueType, code *Code, err error) { + fnV := reflect.ValueOf(fn) + p := fnV.Type() + + if fnV.Kind() != reflect.Func { + err = fmt.Errorf("kind != func: %s", fnV.Kind().String()) return } - fk = kind(p) + fk := kind(p) + code = &Code{IsHostFunction: true, Kind: fk, GoFunc: &fnV} + pOffset := 0 switch fk { case FunctionKindGoNoContext: @@ -167,13 +181,14 @@ func getFunctionType(fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err pOffset = 1 } - rCount := p.NumOut() - - ft = &FunctionType{Params: make([]ValueType, p.NumIn()-pOffset), Results: make([]ValueType, rCount)} - for i := 0; i < len(ft.Params); i++ { + pCount := p.NumIn() - pOffset + if pCount > 0 { + params = make([]ValueType, pCount) + } + for i := 0; i < len(params); i++ { pI := p.In(i + pOffset) if t, ok := getTypeOf(pI.Kind()); ok { - ft.Params[i] = t + params[i] = t continue } @@ -193,10 +208,14 @@ func getFunctionType(fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err return } - for i := 0; i < len(ft.Results); i++ { + rCount := p.NumOut() + if rCount > 0 { + results = make([]ValueType, rCount) + } + for i := 0; i < len(results); i++ { rI := p.Out(i) if t, ok := getTypeOf(rI.Kind()); ok { - ft.Results[i] = t + results[i] = t continue } diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index d9c696ad..afe7c3e3 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -3,7 +3,6 @@ package wasm import ( "context" "math" - "reflect" "testing" "unsafe" @@ -14,7 +13,7 @@ import ( // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") -func TestGetFunctionType(t *testing.T) { +func Test_parseGoFunc(t *testing.T) { var tests = []struct { name string inputFunc interface{} @@ -25,25 +24,25 @@ func TestGetFunctionType(t *testing.T) { name: "nullary", inputFunc: func() {}, expectedKind: FunctionKindGoNoContext, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "wasm.Module void return", inputFunc: func(api.Module) {}, expectedKind: FunctionKindGoModule, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "context.Context void return", inputFunc: func(context.Context) {}, expectedKind: FunctionKindGoContext, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "context.Context and api.Module void return", inputFunc: func(context.Context, api.Module) {}, expectedKind: FunctionKindGoContextModule, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "all supported params and i32 result", @@ -85,16 +84,15 @@ func TestGetFunctionType(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - rVal := reflect.ValueOf(tc.inputFunc) - fk, ft, err := getFunctionType(&rVal) + paramTypes, resultTypes, code, err := parseGoFunc(tc.inputFunc) require.NoError(t, err) - require.Equal(t, tc.expectedKind, fk) - require.Equal(t, tc.expectedType, ft) + require.Equal(t, tc.expectedKind, code.Kind) + require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes}) }) } } -func TestGetFunctionTypeErrors(t *testing.T) { +func Test_parseGoFunc_Errors(t *testing.T) { tests := []struct { name string input interface{} @@ -142,8 +140,7 @@ func TestGetFunctionTypeErrors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - rVal := reflect.ValueOf(tc.input) - _, _, err := getFunctionType(&rVal) + _, _, _, err := parseGoFunc(tc.input) require.EqualError(t, err, tc.expectedErr) }) } @@ -247,11 +244,10 @@ func TestPopGoFuncParams(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - goFunc := reflect.ValueOf(tc.inputFunc) - fk, _, err := getFunctionType(&goFunc) + _, _, code, err := parseGoFunc(tc.inputFunc) require.NoError(t, err) - vals := PopGoFuncParams(&FunctionInstance{Kind: fk, GoFunc: &goFunc}, (&stack{stackVals}).pop) + vals := PopGoFuncParams(&FunctionInstance{Kind: code.Kind, GoFunc: code.GoFunc}, (&stack{stackVals}).pop) require.Equal(t, tc.expected, vals) }) } @@ -400,11 +396,20 @@ func TestCallGoFunc(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - goFunc := reflect.ValueOf(tc.inputFunc) - fk, _, err := getFunctionType(&goFunc) + paramTypes, resultTypes, code, err := parseGoFunc(tc.inputFunc) require.NoError(t, err) - results := CallGoFunc(testCtx, callCtx, &FunctionInstance{Kind: fk, GoFunc: &goFunc}, tc.inputParams) + results := CallGoFunc( + testCtx, + callCtx, + &FunctionInstance{ + IsHostFunction: code.IsHostFunction, + Kind: code.Kind, + Type: &FunctionType{Params: paramTypes, Results: resultTypes}, + GoFunc: code.GoFunc, + }, + tc.inputParams, + ) require.Equal(t, tc.expectedResults, results) }) } diff --git a/internal/wasm/host.go b/internal/wasm/host.go index 7cf4eec5..4a5ecb70 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -2,16 +2,15 @@ package wasm import ( "fmt" - "reflect" "sort" "strings" "github.com/tetratelabs/wazero/internal/wasmdebug" ) -// Func is a function with an inlined type, typically used for NewHostModule. +// HostFunc is a function with an inlined type, used for NewHostModule. // Any corresponding FunctionType will be reused or added to the Module. -type Func struct { +type HostFunc struct { // ExportNames is equivalent to the same method on api.FunctionDefinition. ExportNames []string @@ -31,35 +30,39 @@ type Func struct { Code *Code } -// NewGoFunc returns a Func for the given parameters or panics. -func NewGoFunc(exportName string, name string, paramNames []string, fn interface{}) *Func { - fnV := reflect.ValueOf(fn) - _, ft, err := getFunctionType(&fnV) - if err != nil { - panic(err) - } - return &Func{ +// NewGoFunc returns a HostFunc for the given parameters or panics. +func NewGoFunc(exportName string, name string, paramNames []string, fn interface{}) *HostFunc { + return (&HostFunc{ ExportNames: []string{exportName}, Name: name, - ParamTypes: ft.Params, - ResultTypes: ft.Results, ParamNames: paramNames, - Code: &Code{GoFunc: &fnV}, + }).MustGoFunc(fn) +} + +// MustGoFunc calls WithGoFunc or panics on error. +func (f *HostFunc) MustGoFunc(fn interface{}) *HostFunc { + if ret, err := f.WithGoFunc(fn); err != nil { + panic(err) + } else { + return ret } } // WithGoFunc returns a copy of the function, replacing its Code.GoFunc. -func (f *Func) WithGoFunc(fn interface{}) *Func { +func (f *HostFunc) WithGoFunc(fn interface{}) (*HostFunc, error) { ret := *f - fnV := reflect.ValueOf(fn) - ret.Code = &Code{GoFunc: &fnV} - return &ret + var err error + ret.ParamTypes, ret.ResultTypes, ret.Code, err = parseGoFunc(fn) + return &ret, err } // WithWasm returns a copy of the function, replacing its Code.Body. -func (f *Func) WithWasm(body []byte) *Func { +func (f *HostFunc) WithWasm(body []byte) *HostFunc { ret := *f - ret.Code = &Code{Body: body} + ret.Code = &Code{IsHostFunction: true, Body: body} + if f.Code != nil { + ret.Code.LocalTypes = f.Code.LocalTypes + } return &ret } @@ -138,41 +141,46 @@ func addFuncs( m.NameSection = &NameSection{} } moduleName := m.NameSection.ModuleName - nameToFunc := make(map[string]*Func, len(nameToGoFunc)) + nameToFunc := make(map[string]*HostFunc, len(nameToGoFunc)) + sortedExportNames := make([]string, len(nameToFunc)) + for k := range nameToGoFunc { + sortedExportNames = append(sortedExportNames, k) + } + + // Sort names for consistent iteration + sort.Strings(sortedExportNames) + funcNames := make([]string, len(nameToFunc)) - for k, v := range nameToGoFunc { - if hf, ok := v.(*Func); !ok { - fn := reflect.ValueOf(v) - _, ft, ftErr := getFunctionType(&fn) + for _, k := range sortedExportNames { + v := nameToGoFunc[k] + if hf, ok := v.(*HostFunc); ok { + nameToFunc[hf.Name] = hf + funcNames = append(funcNames, hf.Name) + } else { + params, results, code, ftErr := parseGoFunc(v) if ftErr != nil { return fmt.Errorf("func[%s.%s] %w", moduleName, k, ftErr) } - hf = &Func{ + hf = &HostFunc{ ExportNames: []string{k}, Name: k, - ParamTypes: ft.Params, - ResultTypes: ft.Results, - Code: &Code{GoFunc: &fn}, + ParamTypes: params, + ResultTypes: results, + Code: code, } if names := funcToNames[k]; names != nil { namesLen := len(names) - if namesLen > 1 && namesLen-1 != len(ft.Params) { - return fmt.Errorf("func[%s.%s] has %d params, but %d param names", moduleName, k, namesLen-1, len(ft.Params)) + if namesLen > 1 && namesLen-1 != len(params) { + return fmt.Errorf("func[%s.%s] has %d params, but %d param names", moduleName, k, namesLen-1, len(params)) } hf.Name = names[0] hf.ParamNames = names[1:] } nameToFunc[k] = hf funcNames = append(funcNames, k) - } else { - nameToFunc[hf.Name] = hf - funcNames = append(funcNames, hf.Name) } } - // Sort names for consistent iteration - sort.Strings(funcNames) - funcCount := uint32(len(nameToFunc)) m.NameSection.FunctionNames = make([]*NameAssoc, 0, funcCount) m.FunctionSection = make([]Index, 0, funcCount) diff --git a/internal/wasm/host_test.go b/internal/wasm/host_test.go index 03981b03..43a63afb 100644 --- a/internal/wasm/host_test.go +++ b/internal/wasm/host_test.go @@ -1,7 +1,6 @@ package wasm import ( - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -32,11 +31,8 @@ func swap(x, y uint32) (uint32, uint32) { func TestNewHostModule(t *testing.T) { a := wasiAPI{} functionArgsSizesGet := "args_sizes_get" - fnArgsSizesGet := reflect.ValueOf(a.ArgsSizesGet) functionFdWrite := "fd_write" - fnFdWrite := reflect.ValueOf(a.FdWrite) functionSwap := "swap" - fnSwap := reflect.ValueOf(swap) tests := []struct { name, moduleName string @@ -67,7 +63,7 @@ func TestNewHostModule(t *testing.T) { {Params: []ValueType{i32, i32, i32, i32}, Results: []ValueType{i32}}, }, FunctionSection: []Index{0, 1}, - CodeSection: []*Code{{GoFunc: &fnArgsSizesGet}, {GoFunc: &fnFdWrite}}, + CodeSection: []*Code{MustParseGoFuncCode(a.ArgsSizesGet), MustParseGoFuncCode(a.FdWrite)}, ExportSection: []*Export{ {Name: "args_sizes_get", Type: ExternTypeFunc, Index: 0}, {Name: "fd_write", Type: ExternTypeFunc, Index: 1}, @@ -90,7 +86,7 @@ func TestNewHostModule(t *testing.T) { expected: &Module{ TypeSection: []*FunctionType{{Params: []ValueType{i32, i32}, Results: []ValueType{i32, i32}}}, FunctionSection: []Index{0}, - CodeSection: []*Code{{GoFunc: &fnSwap}}, + CodeSection: []*Code{MustParseGoFuncCode(swap)}, ExportSection: []*Export{{Name: "swap", Type: ExternTypeFunc, Index: 0}}, NameSection: &NameSection{ModuleName: "swapper", FunctionNames: NameMap{{Index: 0, Name: "swap"}}}, }, @@ -152,7 +148,7 @@ func TestNewHostModule(t *testing.T) { {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, }, FunctionSection: []Index{0}, - CodeSection: []*Code{{GoFunc: &fnArgsSizesGet}}, + CodeSection: []*Code{MustParseGoFuncCode(a.ArgsSizesGet)}, GlobalSection: []*Global{ { Type: &GlobalType{ValType: i32}, @@ -201,9 +197,17 @@ func requireHostModuleEquals(t *testing.T, expected, actual *Module) { require.Equal(t, expected.NameSection, actual.NameSection) // Special case because reflect.Value can't be compared with Equals + // TODO: This is copy/paste with /builder_test.go require.Equal(t, len(expected.CodeSection), len(actual.CodeSection)) - for _, c := range expected.CodeSection { - require.Equal(t, c.GoFunc.Type(), c.GoFunc.Type()) + for i, c := range expected.CodeSection { + actualCode := actual.CodeSection[i] + require.True(t, actualCode.IsHostFunction) + require.Equal(t, c.Kind, actualCode.Kind) + require.Equal(t, c.GoFunc.Type(), actualCode.GoFunc.Type()) + + // Not wasm + require.Nil(t, actualCode.Body) + require.Nil(t, actualCode.LocalTypes) } } diff --git a/internal/wasm/module.go b/internal/wasm/module.go index 446cd438..e9a5c3c6 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -587,16 +587,13 @@ func (m *ModuleInstance) BuildFunctions(mod *Module, listeners []experimental.Fu fns = make([]*FunctionInstance, 0, len(mod.FunctionDefinitionSection)) for i := range mod.FunctionSection { code := mod.CodeSection[i] - fnKind := FunctionKindWasm - if fn := code.GoFunc; fn != nil { - fnKind = kind(fn.Type()) - } fns = append(fns, &FunctionInstance{ - Kind: fnKind, - Body: code.Body, - GoFunc: code.GoFunc, - TypeID: m.TypeIDs[mod.FunctionSection[i]], - LocalTypes: code.LocalTypes, + IsHostFunction: code.IsHostFunction, + Kind: code.Kind, + LocalTypes: code.LocalTypes, + Body: code.Body, + GoFunc: code.GoFunc, + TypeID: m.TypeIDs[mod.FunctionSection[i]], }) } @@ -804,14 +801,17 @@ type Export struct { // Code is an entry in the Module.CodeSection containing the locals and body of the function. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-code type Code struct { + // IsHostFunction returns true if the function was implemented by the + // embedder (ex via wazero.ModuleBuilder) instead of a wasm binary. + // + // Notably, host functions can use the caller's memory, which might be + // different from its defining module. + // + // See https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 + IsHostFunction bool - // GoFunc is a host function defined in Go. - // - // When present, LocalTypes and Body must be nil. - // - // Note: This has no serialization format, so is not encodable. - // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 - GoFunc *reflect.Value + // Kind describes how this function should be called. + Kind FunctionKind // LocalTypes are any function-scoped variables in insertion order. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-local @@ -820,6 +820,14 @@ type Code struct { // Body is a sequence of expressions ending in OpcodeEnd // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-expr Body []byte + + // GoFunc is a host function defined in Go. + // + // When present, LocalTypes and Body must be nil. + // + // Note: This has no serialization format, so is not encodable. + // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 + GoFunc *reflect.Value } type DataSegment struct { diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 391375ab..545a4a5b 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -108,6 +108,10 @@ type ( // FunctionInstance represents a function instance in a Store. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#function-instances%E2%91%A0 FunctionInstance struct { + // IsHostFunction is the data returned by the same field documented on + // wasm.Code. + IsHostFunction bool + // Kind describes how this function should be called. Kind FunctionKind diff --git a/internal/watzero/internal/func_parser_test.go b/internal/watzero/internal/func_parser_test.go index 9d73d7a0..5e581dcb 100644 --- a/internal/watzero/internal/func_parser_test.go +++ b/internal/watzero/internal/func_parser_test.go @@ -76,7 +76,7 @@ func TestFuncParser(t *testing.T) { name: "i32.load", source: "(func i32.const 8 i32.load)", expected: &wasm.Code{Body: []byte{ - wasm.OpcodeI32Const, 8, // dynamic memory offset to load + wasm.OpcodeI32Const, 8, // memory offset to load wasm.OpcodeI32Load, 0x2, 0x0, // load alignment=2 (natural alignment) staticOffset=0 wasm.OpcodeEnd, }}, @@ -100,7 +100,7 @@ func TestFuncParser(t *testing.T) { name: "i64.load", source: "(func i32.const 8 i64.load)", expected: &wasm.Code{Body: []byte{ - wasm.OpcodeI32Const, 8, // dynamic memory offset to load + wasm.OpcodeI32Const, 8, // memory offset to load wasm.OpcodeI64Load, 0x3, 0x0, // load alignment=3 (natural alignment) staticOffset=0 wasm.OpcodeEnd, }}, diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 3129a016..a6538d0c 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -177,12 +177,17 @@ func (c *compiler) resetUnreachable() { } type CompilationResult struct { - // GoFunc is present when the result is a host function. - // In this case, other fields can be ignored. + // IsHostFunction is the data returned by the same field documented on + // wasm.Code. + IsHostFunction bool + + // GoFunc is the data returned by the same field documented on wasm.Code. + // In this case, IsHostFunction is true and other fields can be ignored. GoFunc *reflect.Value // Operations holds wazeroir operations compiled from Wasm instructions in a Wasm function. Operations []Operation + // LabelCallers maps Label.String() to the number of callers to that label. // Here "callers" means that the call-sites which jumps to the label with br, br_if or br_table // instructions. @@ -209,6 +214,8 @@ type CompilationResult struct { TableTypes []wasm.ValueType // HasMemory is true if the module from which this function is compiled has memory declaration. HasMemory bool + // UsesMemory is true if this function might use memory. + UsesMemory bool // HasTable is true if the module from which this function is compiled has table declaration. HasTable bool // HasDataInstances is true if the module has data instances which might be used by memory.init or data.drop instructions. @@ -239,7 +246,13 @@ func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module * sig := module.TypeSection[typeID] code := module.CodeSection[funcIndex] if code.GoFunc != nil { - ret = append(ret, &CompilationResult{GoFunc: code.GoFunc}) + ret = append(ret, &CompilationResult{ + IsHostFunction: true, + // Assume the function might use memory if it has a parameter for the api.Module + UsesMemory: code.Kind == wasm.FunctionKindGoModule || code.Kind == wasm.FunctionKindGoContextModule, + GoFunc: code.GoFunc, + Signature: sig, + }) continue } r, err := compile(enabledFeatures, sig, code.Body, code.LocalTypes, module.TypeSection, functions, globals) @@ -247,6 +260,7 @@ func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module * def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] return nil, fmt.Errorf("failed to lower func[%s] to wazeroir: %w", def.DebugName(), err) } + r.IsHostFunction = code.IsHostFunction r.Globals = globals r.Functions = functions r.Types = module.TypeSection @@ -1045,11 +1059,13 @@ operatorSwitch: &OperationStore32{Arg: imm}, ) case wasm.OpcodeMemorySize: + c.result.UsesMemory = true c.pc++ // Skip the reserved one byte. c.emit( &OperationMemorySize{}, ) case wasm.OpcodeMemoryGrow: + c.result.UsesMemory = true c.pc++ // Skip the reserved one byte. c.emit( &OperationMemoryGrow{}, @@ -1672,6 +1688,7 @@ operatorSwitch: &OperationITruncFromF{InputType: Float64, OutputType: SignedUint64, NonTrapping: true}, ) case wasm.OpcodeMiscMemoryInit: + c.result.UsesMemory = true dataIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) @@ -1690,11 +1707,13 @@ operatorSwitch: &OperationDataDrop{DataIndex: dataIndex}, ) case wasm.OpcodeMiscMemoryCopy: + c.result.UsesMemory = true c.pc += 2 // +2 to skip two memory indexes which are fixed to zero. c.emit( &OperationMemoryCopy{}, ) case wasm.OpcodeMiscMemoryFill: + c.result.UsesMemory = true c.pc += 1 // +1 to skip the memory index which is fixed to zero. c.emit( &OperationMemoryFill{}, @@ -3072,6 +3091,7 @@ func (c *compiler) stackLenInUint64(ceil int) (ret int) { } func (c *compiler) readMemoryArg(tag string) (*MemoryArg, error) { + c.result.UsesMemory = true r := bytes.NewReader(c.body[c.pc+1:]) alignment, num, err := leb128.DecodeUint32(r) if err != nil { diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index 450f2476..bb2d2781 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -51,11 +51,57 @@ func TestCompile(t *testing.T) { }, LabelCallers: map[string]uint32{}, Functions: []uint32{0}, - Types: []*wasm.FunctionType{{}}, - Signature: &wasm.FunctionType{}, + Types: []*wasm.FunctionType{v_v}, + Signature: v_v, TableTypes: []wasm.RefType{}, }, }, + { + name: "host wasm nullary", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{{IsHostFunction: true, Body: []byte{wasm.OpcodeEnd}}}, + }, + expected: &CompilationResult{ + IsHostFunction: true, + Operations: []Operation{ // begin with params: [] + &OperationBr{Target: &BranchTarget{}}, // return! + }, + LabelCallers: map[string]uint32{}, + Functions: []uint32{0}, + Types: []*wasm.FunctionType{v_v}, + Signature: v_v, + TableTypes: []wasm.RefType{}, + }, + }, + { + name: "host go nullary", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func() {})}, + }, + expected: &CompilationResult{IsHostFunction: true}, + }, + { + name: "host go api.Module uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func(api.Module) {})}, + }, + expected: &CompilationResult{IsHostFunction: true, UsesMemory: true}, + }, + { + name: "host go context.Context api.Module uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func(context.Context, api.Module) {})}, + }, + expected: &CompilationResult{IsHostFunction: true, UsesMemory: true}, + }, { name: "identity", module: requireModuleText(t, `(module @@ -82,6 +128,61 @@ func TestCompile(t *testing.T) { TableTypes: []wasm.RefType{}, }, }, + { + name: "uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{{Body: []byte{ + wasm.OpcodeI32Const, 8, // memory offset to load + wasm.OpcodeI32Load, 0x2, 0x0, // load alignment=2 (natural alignment) staticOffset=0 + wasm.OpcodeDrop, + wasm.OpcodeEnd, + }}}, + }, + expected: &CompilationResult{ + Operations: []Operation{ // begin with params: [] + &OperationConstI32{Value: 8}, // [8] + &OperationLoad{Type: UnsignedTypeI32, Arg: &MemoryArg{Alignment: 2, Offset: 0}}, // [x] + &OperationDrop{Depth: &InclusiveRange{}}, // [] + &OperationBr{Target: &BranchTarget{}}, // return! + }, + LabelCallers: map[string]uint32{}, + Types: []*wasm.FunctionType{v_v}, + Functions: []uint32{0}, + Signature: v_v, + TableTypes: []wasm.RefType{}, + UsesMemory: true, + }, + }, + { + name: "host uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{{IsHostFunction: true, Body: []byte{ + wasm.OpcodeI32Const, 8, // memory offset to load + wasm.OpcodeI32Load, 0x2, 0x0, // load alignment=2 (natural alignment) staticOffset=0 + wasm.OpcodeDrop, + wasm.OpcodeEnd, + }}}, + }, + expected: &CompilationResult{ + IsHostFunction: true, + Operations: []Operation{ // begin with params: [] + &OperationConstI32{Value: 8}, // [8] + &OperationLoad{Type: UnsignedTypeI32, Arg: &MemoryArg{Alignment: 2, Offset: 0}}, // [x] + &OperationDrop{Depth: &InclusiveRange{}}, // [] + &OperationBr{Target: &BranchTarget{}}, // return! + }, + LabelCallers: map[string]uint32{}, + Types: []*wasm.FunctionType{v_v}, + Functions: []uint32{0}, + Signature: v_v, + TableTypes: []wasm.RefType{}, + UsesMemory: true, + }, + }, { name: "memory.grow", // Ex to expose ops to grow memory module: requireModuleText(t, `(module @@ -107,6 +208,7 @@ func TestCompile(t *testing.T) { ResultNumInUint64: 1, }, TableTypes: []wasm.RefType{}, + UsesMemory: true, }, }, } @@ -124,7 +226,16 @@ func TestCompile(t *testing.T) { } res, err := CompileFunctions(ctx, enabledFeatures, tc.module) require.NoError(t, err) - require.Equal(t, tc.expected, res[0]) + + fn := res[0] + if fn.GoFunc != nil { // can't compare functions + // Special case because reflect.Value can't be compared with Equals + require.True(t, fn.IsHostFunction) + require.Equal(t, tc.expected.UsesMemory, fn.UsesMemory) + require.Equal(t, &tc.module.CodeSection[0].GoFunc, &fn.GoFunc) + } else { + require.Equal(t, tc.expected, fn) + } }) } } @@ -244,6 +355,7 @@ func TestCompile_BulkMemoryOperations(t *testing.T) { &OperationBr{Target: &BranchTarget{}}, // return! }, HasMemory: true, + UsesMemory: true, HasDataInstances: true, LabelCallers: map[string]uint32{}, Signature: v_v, @@ -742,7 +854,7 @@ func TestCompile_Refs(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: tc.body}}, } @@ -810,7 +922,7 @@ func TestCompile_TableGetOrSet(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: tc.body}}, TableSection: []*wasm.Table{{}}, @@ -879,7 +991,7 @@ func TestCompile_TableGrowFillSize(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: tc.body}}, TableSection: []*wasm.Table{{}}, @@ -933,7 +1045,7 @@ func TestCompile_Locals(t *testing.T) { { name: "local.get - non func param - v128", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{ Body: []byte{ @@ -996,7 +1108,7 @@ func TestCompile_Locals(t *testing.T) { { name: "local.set - non func param - v128", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{ Body: []byte{ @@ -1070,7 +1182,7 @@ func TestCompile_Locals(t *testing.T) { { name: "local.tee - non func param", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{ Body: []byte{ @@ -2443,7 +2555,7 @@ func TestCompile_Vec(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, MemorySection: &wasm.Memory{}, CodeSection: []*wasm.Code{{Body: tc.body}}, @@ -2477,7 +2589,7 @@ func TestCompile_unreachable_Br_BrIf_BrTable(t *testing.T) { { name: "br", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeBr, 0, // Return the function -> the followings are unreachable. @@ -2492,7 +2604,7 @@ func TestCompile_unreachable_Br_BrIf_BrTable(t *testing.T) { { name: "br_if", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeBr, 0, // Return the function -> the followings are unreachable. @@ -2508,7 +2620,7 @@ func TestCompile_unreachable_Br_BrIf_BrTable(t *testing.T) { { name: "br_table", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeBr, 0, // Return the function -> the followings are unreachable. @@ -2543,7 +2655,7 @@ func TestCompile_drop_vectors(t *testing.T) { { name: "basic", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeVecPrefix, @@ -2581,7 +2693,7 @@ func TestCompile_select_vectors(t *testing.T) { { name: "non typed", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeVecPrefix, @@ -2607,7 +2719,7 @@ func TestCompile_select_vectors(t *testing.T) { { name: "typed", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeVecPrefix, diff --git a/wasi_snapshot_preview1/wasi.go b/wasi_snapshot_preview1/wasi.go index 45747cbf..f1007038 100644 --- a/wasi_snapshot_preview1/wasi.go +++ b/wasi_snapshot_preview1/wasi.go @@ -181,13 +181,16 @@ func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, va } // stubFunction stubs for GrainLang per #271. -func stubFunction(name string, paramTypes []wasm.ValueType, paramNames []string) *wasm.Func { - return &wasm.Func{ +func stubFunction(name string, paramTypes []wasm.ValueType, paramNames []string) *wasm.HostFunc { + return &wasm.HostFunc{ Name: name, ExportNames: []string{name}, ParamTypes: paramTypes, ParamNames: paramNames, ResultTypes: []wasm.ValueType{i32}, - Code: &wasm.Code{Body: []byte{wasm.OpcodeI32Const, byte(ErrnoNosys), wasm.OpcodeEnd}}, + Code: &wasm.Code{ + IsHostFunction: true, + Body: []byte{wasm.OpcodeI32Const, byte(ErrnoNosys), wasm.OpcodeEnd}, + }, } }