diff --git a/RATIONALE.md b/RATIONALE.md index fffaa7bb..fd3852dd 100644 --- a/RATIONALE.md +++ b/RATIONALE.md @@ -438,7 +438,7 @@ value (possibly `PWD`). Those unable to control the compiled code should only use absolute paths in configuration. See -* https://github.com/golang/go/blob/go1.19beta1/src/syscall/fs_js.go#L324 +* https://github.com/golang/go/blob/go1.19rc2/src/syscall/fs_js.go#L324 * https://github.com/WebAssembly/wasi-libc/pull/214#issue-673090117 ### FdPrestatDirName diff --git a/api/wasm.go b/api/wasm.go index 08c5a5a8..4db74777 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "math" + "reflect" ) // ExternType classifies imports and exports with their respective types. @@ -227,12 +228,16 @@ type FunctionDefinition interface { // is possible. ExportNames() []string - // IsHostFunction returns true if the function was implemented by the - // embedder (ex via wazero.ModuleBuilder) instead of a wasm binary. + // GoFunc is present when the function was implemented by the embedder + // (ex via wazero.ModuleBuilder) instead of a wasm binary. + // + // This function can be non-deterministic or cause side effects. It also + // has special properties not defined in the WebAssembly Core + // specification. Notably, it uses the caller's memory, which might be + // different from its defining module. // - // Note: Host functions can be non-deterministic or cause side effects. // See https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 - IsHostFunction() bool + GoFunc() *reflect.Value // ParamTypes are the possibly empty sequence of value types accepted by a // function with this signature. diff --git a/assemblyscript/assemblyscript.go b/assemblyscript/assemblyscript.go index d234989b..1e783346 100644 --- a/assemblyscript/assemblyscript.go +++ b/assemblyscript/assemblyscript.go @@ -87,7 +87,7 @@ func NewFunctionExporter() FunctionExporter { } type functionExporter struct { - abortFn, traceFn *wasm.Func + abortFn, traceFn *wasm.HostFunc } // WithAbortMessageDisabled implements FunctionExporter.WithAbortMessageDisabled @@ -130,7 +130,7 @@ var abortMessageEnabled = wasm.NewGoFunc( abortWithMessage, ) -var abortMessageDisabled = abortMessageEnabled.WithGoFunc(abort) +var abortMessageDisabled = abortMessageEnabled.MustGoFunc(abort) // abortWithMessage implements fnAbort func abortWithMessage( @@ -176,7 +176,7 @@ var traceStdout = wasm.NewGoFunc(functionTrace, "~lib/builtins/trace", ) // traceStderr implements trace to the configured Stderr. -var traceStderr = traceStdout.WithGoFunc(func( +var traceStderr = traceStdout.MustGoFunc(func( ctx context.Context, mod api.Module, message uint32, nArgs uint32, arg0, arg1, arg2, arg3, arg4 float64, ) { traceTo(ctx, mod, message, nArgs, arg0, arg1, arg2, arg3, arg4, mod.(*wasm.CallContext).Sys.Stderr()) diff --git a/builder_test.go b/builder_test.go index c1c772ed..5b75bef7 100644 --- a/builder_test.go +++ b/builder_test.go @@ -2,7 +2,6 @@ package wazero import ( "math" - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -19,11 +18,9 @@ func TestNewModuleBuilder_Compile(t *testing.T) { uint32_uint32 := func(uint32) uint32 { return 0 } - fnUint32_uint32 := reflect.ValueOf(uint32_uint32) uint64_uint32 := func(uint64) uint32 { return 0 } - fnUint64_uint32 := reflect.ValueOf(uint64_uint32) tests := []struct { name string @@ -54,7 +51,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -73,7 +70,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -93,7 +90,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -114,7 +111,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}, {GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -138,7 +135,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}, {GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -163,7 +160,7 @@ func TestNewModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{{GoFunc: &fnUint32_uint32}, {GoFunc: &fnUint64_uint32}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -464,8 +461,16 @@ func requireHostModuleEquals(t *testing.T, expected, actual *wasm.Module) { require.Equal(t, expected.NameSection, actual.NameSection) // Special case because reflect.Value can't be compared with Equals + // TODO: This is copy/paste with /internal/wasm/host_test.go require.Equal(t, len(expected.CodeSection), len(actual.CodeSection)) - for _, c := range expected.CodeSection { - require.Equal(t, c.GoFunc.Type(), c.GoFunc.Type()) + for i, c := range expected.CodeSection { + actualCode := actual.CodeSection[i] + require.True(t, actualCode.IsHostFunction) + require.Equal(t, c.Kind, actualCode.Kind) + require.Equal(t, c.GoFunc.Type(), actualCode.GoFunc.Type()) + + // Not wasm + require.Nil(t, actualCode.Body) + require.Nil(t, actualCode.LocalTypes) } } diff --git a/emscripten/emscripten.go b/emscripten/emscripten.go index 78f689ec..ab8d2781 100644 --- a/emscripten/emscripten.go +++ b/emscripten/emscripten.go @@ -71,10 +71,10 @@ func (e *functionExporter) ExportFunctions(builder wazero.ModuleBuilder) { // and https://emscripten.org/docs/api_reference/emscripten.h.html#abi-functions const functionNotifyMemoryGrowth = "emscripten_notify_memory_growth" -var notifyMemoryGrowth = &wasm.Func{ +var notifyMemoryGrowth = &wasm.HostFunc{ ExportNames: []string{functionNotifyMemoryGrowth}, Name: functionNotifyMemoryGrowth, ParamTypes: []wasm.ValueType{wasm.ValueTypeI32}, ParamNames: []string{"memory_index"}, - Code: &wasm.Code{Body: []byte{wasm.OpcodeEnd}}, + Code: &wasm.Code{IsHostFunction: true, Body: []byte{wasm.OpcodeEnd}}, } diff --git a/experimental/log_listener.go b/experimental/log_listener.go index 94c1d43d..93056dee 100644 --- a/experimental/log_listener.go +++ b/experimental/log_listener.go @@ -63,14 +63,14 @@ func (l *loggingListener) writeIndented(before bool, err error, vals []uint64, i message.WriteByte('\t') } if before { - if l.fnd.IsHostFunction() { + if l.fnd.GoFunc() != nil { message.WriteString("==> ") } else { message.WriteString("--> ") } l.writeFuncEnter(&message, vals) } else { // after - if l.fnd.IsHostFunction() { + if l.fnd.GoFunc() != nil { message.WriteString("<== ") } else { message.WriteString("<-- ") diff --git a/experimental/log_listener_test.go b/experimental/log_listener_test.go index c7961b51..f41c8939 100644 --- a/experimental/log_listener_test.go +++ b/experimental/log_listener_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "math" - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -263,7 +262,7 @@ func Test_loggingListener(t *testing.T) { var out bytes.Buffer lf := experimental.NewLoggingListenerFactory(&out) - fnV := reflect.ValueOf(func() {}) + fn := func() {} for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { @@ -284,7 +283,7 @@ func Test_loggingListener(t *testing.T) { } if tc.isHostFunc { - m.CodeSection = []*wasm.Code{{GoFunc: &fnV}} + m.CodeSection = []*wasm.Code{wasm.MustParseGoFuncCode(fn)} } else { m.CodeSection = []*wasm.Code{{Body: []byte{wasm.OpcodeEnd}}} } diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index e08a7301..f958512b 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -424,16 +424,16 @@ func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { for funcIndex, ir := range irs { var compiled *code if ir.GoFunc != nil { - sig := module.TypeSection[module.FunctionSection[funcIndex]] - if compiled, err = compileHostFunction(sig); err != nil { + if compiled, err = compileHostFunction(ir); err != nil { def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] - return fmt.Errorf("error compiling host func[%s]: %w", def.DebugName(), err) - } - } else { - if compiled, err = compileWasmFunction(e.enabledFeatures, ir); err != nil { - def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] - return fmt.Errorf("error compiling wasm func[%s]: %w", def.DebugName(), err) + return fmt.Errorf("error compiling host go func[%s]: %w", def.DebugName(), err) } + } else if ir.IsHostFunction && ir.UsesMemory { + def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] + return fmt.Errorf("error compiling host wasm func[%s]: memory access not yet supported", def.DebugName()) + } else if compiled, err = compileWasmFunction(e.enabledFeatures, ir); err != nil { + def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] + return fmt.Errorf("error compiling wasm func[%s]: %w", def.DebugName(), err) } // As this uses mmap, we need to munmap on the compiled machine code when it's GCed. @@ -811,12 +811,13 @@ func (ce *callEngine) builtinFunctionTableGrow(ctx context.Context, tables []*wa ce.pushValue(uint64(res)) } -func compileHostFunction(sig *wasm.FunctionType) (*code, error) { - compiler, err := newCompiler(&wazeroir.CompilationResult{Signature: sig}) +func compileHostFunction(ir *wazeroir.CompilationResult) (*code, error) { + compiler, err := newCompiler(ir) if err != nil { return nil, err } + // TODO: Consider when no memory is used (ex !ir.UsesMemory) if err = compiler.compileHostFunction(); err != nil { return nil, err } diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index 14008132..176a4934 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -23,6 +23,11 @@ var et = &engineTester{} // engineTester implements enginetest.EngineTester. type engineTester struct{} +// IsCompiler implements the same method as documented on enginetest.EngineTester. +func (e *engineTester) IsCompiler() bool { + return true +} + // ListenerFactory implements the same method as documented on enginetest.EngineTester. func (e *engineTester) ListenerFactory() experimental.FunctionListenerFactory { return nil diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index 999dbba5..72224004 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -809,10 +809,7 @@ func (ce *callEngine) callFunction(ctx context.Context, callCtx *wasm.CallContex } func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { - if len(ce.frames) > 0 { - // Use the caller's memory, which might be different from the defining module on an imported function. - callCtx = callCtx.WithMemory(ce.frames[len(ce.frames)-1].f.source.Module.Memory) - } + callCtx = callCtx.WithMemory(ce.callerMemory()) if f.source.FunctionListener != nil { ctx = f.source.FunctionListener.Before(ctx, params) } @@ -830,11 +827,19 @@ func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallContext, f *function) { frame := &callFrame{f: f} moduleInst := f.source.Module - memoryInst := moduleInst.Memory + functions := moduleInst.Engine.(*moduleEngine).functions + var memoryInst *wasm.MemoryInstance + if f.source.IsHostFunction { + if memoryInst = ce.callerMemory(); memoryInst == nil { + // If there was no caller, someone called a host function directly. + memoryInst = callCtx.Memory().(*wasm.MemoryInstance) + } + } else { + memoryInst = moduleInst.Memory + } globals := moduleInst.Globals tables := moduleInst.Tables typeIDs := f.source.Module.TypeIDs - functions := f.source.Module.Engine.(*moduleEngine).functions dataInstances := f.source.Module.DataInstances elementInstances := f.source.Module.ElementInstances ce.pushFrame(frame) @@ -4085,6 +4090,18 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallCont ce.popFrame() } +// callerMemory returns the caller context memory or nil if the host function +// was called directly. +func (ce *callEngine) callerMemory() *wasm.MemoryInstance { + if len(ce.frames) > 0 { + lastFrameSource := ce.frames[len(ce.frames)-1].f.source + if !lastFrameSource.IsHostFunction { + return lastFrameSource.Module.Memory + } + } + return nil +} + func WasmCompatMax32bits(v1, v2 uint32) uint64 { return uint64(math.Float32bits(moremath.WasmCompatMax32( math.Float32frombits(v1), diff --git a/internal/engine/interpreter/interpreter_test.go b/internal/engine/interpreter/interpreter_test.go index 7d531438..c395d7af 100644 --- a/internal/engine/interpreter/interpreter_test.go +++ b/internal/engine/interpreter/interpreter_test.go @@ -70,6 +70,12 @@ var listenerFactory = experimental.NewLoggingListenerFactory(&functionLog) // engineTester implements enginetest.EngineTester. type engineTester struct{} +// IsCompiler implements enginetest.EngineTester NewEngine. +func (e engineTester) IsCompiler() bool { + return false +} + +// ListenerFactory implements enginetest.EngineTester NewEngine. func (e engineTester) ListenerFactory() experimental.FunctionListenerFactory { return listenerFactory } @@ -117,31 +123,15 @@ func TestInterpreter_Engine_NewModuleEngine_InitTable(t *testing.T) { func TestInterpreter_ModuleEngine_Call(t *testing.T) { defer functionLog.Reset() enginetest.RunTestModuleEngine_Call(t, et) - require.Equal(t, `--> .$0(1,2) + require.Equal(t, ` +--> .$0(1,2) <-- (1,2) -`, functionLog.String()) +`, "\n"+functionLog.String()) } func TestInterpreter_ModuleEngine_Call_HostFn(t *testing.T) { defer functionLog.Reset() enginetest.RunTestModuleEngine_Call_HostFn(t, et) - require.Equal(t, `==> .$0(3) -<== (3) ---> imported.wasm_div_by(1) -<-- (1) -==> host.host_div_by(1) -<== (1) ---> imported.call->host_div_by(1) - ==> host.host_div_by(1) - <== (1) -<-- (1) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) - <== (1) - <-- (1) -<-- (1) -`, functionLog.String()) } func TestInterpreter_ModuleEngine_Call_Errors(t *testing.T) { @@ -154,57 +144,58 @@ func TestInterpreter_ModuleEngine_Call_Errors(t *testing.T) { // ==> host.host_div_by(4294967295) // instead of seeing a return like // <== DivByZero - require.Equal(t, `==> host.host_div_by(1) + require.Equal(t, ` +==> host.div_by.go(1) <== (1) -==> host.host_div_by(1) +==> host.div_by.go(1) <== (1) ---> imported.wasm_div_by(1) +--> imported.div_by.wasm(1) <-- (1) ---> imported.wasm_div_by(1) +--> imported.div_by.wasm(1) <-- (1) ---> imported.wasm_div_by(0) ---> imported.wasm_div_by(1) +--> imported.div_by.wasm(0) +--> imported.div_by.wasm(1) <-- (1) -==> host.host_div_by(4294967295) -==> host.host_div_by(1) +==> host.div_by.go(4294967295) +==> host.div_by.go(1) <== (1) -==> host.host_div_by(0) -==> host.host_div_by(1) +==> host.div_by.go(0) +==> host.div_by.go(1) <== (1) ---> imported.call->host_div_by(4294967295) - ==> host.host_div_by(4294967295) ---> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> imported.call->div_by.go(4294967295) + ==> host.div_by.go(4294967295) +--> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) ---> importing.call_import->call->host_div_by(0) - --> imported.call->host_div_by(0) - ==> host.host_div_by(0) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> importing.call_import->call->div_by.go(0) + --> imported.call->div_by.go(0) + ==> host.div_by.go(0) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) <-- (1) ---> importing.call_import->call->host_div_by(4294967295) - --> imported.call->host_div_by(4294967295) - ==> host.host_div_by(4294967295) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> importing.call_import->call->div_by.go(4294967295) + --> imported.call->div_by.go(4294967295) + ==> host.div_by.go(4294967295) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) <-- (1) ---> importing.call_import->call->host_div_by(0) - --> imported.call->host_div_by(0) - ==> host.host_div_by(0) ---> importing.call_import->call->host_div_by(1) - --> imported.call->host_div_by(1) - ==> host.host_div_by(1) +--> importing.call_import->call->div_by.go(0) + --> imported.call->div_by.go(0) + ==> host.div_by.go(0) +--> importing.call_import->call->div_by.go(1) + --> imported.call->div_by.go(1) + ==> host.div_by.go(1) <== (1) <-- (1) <-- (1) -`, functionLog.String()) +`, "\n"+functionLog.String()) } func TestInterpreter_ModuleEngine_Memory(t *testing.T) { diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 8e9bf062..073159d0 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -20,15 +20,19 @@ import ( "context" "errors" "math" - "reflect" "testing" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/testing/require" + "github.com/tetratelabs/wazero/internal/u64" "github.com/tetratelabs/wazero/internal/wasm" ) +const ( + i32, i64 = wasm.ValueTypeI32, wasm.ValueTypeI64 +) + var ( // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") @@ -37,6 +41,9 @@ var ( ) type EngineTester interface { + // IsCompiler returns true if this engine is a compiler. + IsCompiler() bool + NewEngine(enabledFeatures wasm.Features) wasm.Engine ListenerFactory() experimental.FunctionListenerFactory @@ -70,14 +77,14 @@ func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { func RunTestEngine_InitializeFuncrefGlobals(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20220419) - i64 := wasm.ValueTypeI64 + i64 := i64 m := &wasm.Module{ TypeSection: []*wasm.FunctionType{{Params: []wasm.ValueType{i64}, Results: []wasm.ValueType{i64}}}, FunctionSection: []uint32{0, 0, 0}, CodeSection: []*wasm.Code{ - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}, - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}, - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{i64}}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{i64}}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{i64}}, }, } m.BuildFunctionDefinitions() @@ -92,7 +99,7 @@ func RunTestEngine_InitializeFuncrefGlobals(t *testing.T, et EngineTester) { nullRefVal := wasm.GlobalInstanceNullFuncRefValue globals := []*wasm.GlobalInstance{ - {Val: 10, Type: &wasm.GlobalType{ValType: wasm.ValueTypeI32}}, + {Val: 10, Type: &wasm.GlobalType{ValType: i32}}, {Val: uint64(nullRefVal), Type: &wasm.GlobalType{ValType: wasm.ValueTypeFuncref}}, {Val: uint64(2), Type: &wasm.GlobalType{ValType: wasm.ValueTypeFuncref}}, {Val: uint64(1), Type: &wasm.GlobalType{ValType: wasm.ValueTypeFuncref}}, @@ -115,7 +122,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { // Define a basic function which defines two parameters and two results. // This is used to test results when incorrect arity is used. - i64 := wasm.ValueTypeI64 + i64 := i64 m := &wasm.Module{ TypeSection: []*wasm.FunctionType{ { @@ -227,7 +234,8 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []uint32{0, 0, 0, 0}, CodeSection: []*wasm.Code{ - {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, }, ID: wasm.ModuleID{2}, } @@ -326,55 +334,75 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { }) } -func runTestModuleEngine_Call_HostFn_ModuleContext(t *testing.T, et EngineTester) { - features := wasm.Features20191205 - e := et.NewEngine(features) +func runTestModuleEngine_Call_HostFn_Mem(t *testing.T, et EngineTester, readMem *wasm.Code) { + e := et.NewEngine(wasm.Features20191205) + host, importing, close := setupCallMemTests(t, e, readMem, et.ListenerFactory()) + defer close() - sig := &wasm.FunctionType{ - Params: []wasm.ValueType{wasm.ValueTypeI64}, - Results: []wasm.ValueType{wasm.ValueTypeI64}, - ParamNumInUint64: 1, ResultNumInUint64: 1, + hostMemoryVal := uint64(3) + host.Memory = &wasm.MemoryInstance{Buffer: u64.LeBytes(hostMemoryVal), Min: 1, Cap: 1, Max: 1} + importingMemoryVal := uint64(6) + importing.Memory = &wasm.MemoryInstance{Buffer: u64.LeBytes(importingMemoryVal), Min: 1, Cap: 1, Max: 1} + + tests := []struct { + name string + module *wasm.CallContext + fn *wasm.FunctionInstance + expected uint64 + }{ + { + name: readMemName, + module: host.CallCtx, + fn: host.Exports[readMemName].Function, + expected: hostMemoryVal, + }, + { + name: callReadMemName, + module: host.CallCtx, + fn: host.Exports[callReadMemName].Function, + expected: hostMemoryVal, + }, + { + name: callImportReadMemName, + module: importing.CallCtx, + fn: importing.Exports[callImportReadMemName].Function, + expected: importingMemoryVal, + }, + { + name: callImportCallReadMemName, + module: importing.CallCtx, + fn: importing.Exports[callImportCallReadMemName].Function, + expected: importingMemoryVal, + }, } + for _, tt := range tests { + tc := tt - memory := &wasm.MemoryInstance{} - var mMemory api.Memory - host := reflect.ValueOf(func(m api.Module, v uint64) uint64 { - mMemory = m.Memory() - return v - }) - - m := &wasm.Module{ - TypeSection: []*wasm.FunctionType{sig}, - FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &host}}, + t.Run(tc.name, func(t *testing.T) { + results, err := tc.fn.Module.Engine.Call(testCtx, tc.module, tc.fn) + require.NoError(t, err) + require.Equal(t, tc.expected, results[0]) + }) } - m.BuildFunctionDefinitions() - err := e.CompileModule(testCtx, m) - require.NoError(t, err) - - module := &wasm.ModuleInstance{Memory: memory, TypeIDs: []wasm.FunctionTypeID{0}} - _, ns := wasm.NewStore(features, e) - modCtx := wasm.NewCallContext(ns, module, nil) - - fns := module.BuildFunctions(m, buildListeners(et.ListenerFactory(), m)) - me, err := e.NewModuleEngine(t.Name(), m, nil, fns, nil, nil) - require.NoError(t, err) - - t.Run("defaults to module memory when call stack empty", func(t *testing.T) { - // When calling a host func directly, there may be no stack. This ensures the module's memory is used. - results, err := me.Call(testCtx, modCtx, fns[0], 3) - require.NoError(t, err) - require.Equal(t, uint64(3), results[0]) - require.Same(t, memory, mMemory) - }) } func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { - runTestModuleEngine_Call_HostFn_ModuleContext(t, et) // TODO: refactor to use the same test interface. + t.Run("wasm", func(t *testing.T) { + runTestModuleEngine_Call_HostFn(t, et, hostDivByWasm) + if !et.IsCompiler() { // TODO: Support host wasm func that uses caller memory in compiler. + runTestModuleEngine_Call_HostFn_Mem(t, et, hostReadMemWasm) + } + }) + t.Run("go", func(t *testing.T) { + runTestModuleEngine_Call_HostFn(t, et, hostDivByGo) + runTestModuleEngine_Call_HostFn_Mem(t, et, hostReadMemGo) + }) +} +func runTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester, hostDivBy *wasm.Code) { e := et.NewEngine(wasm.Features20191205) - host, imported, importing, close := setupCallTests(t, e, et.ListenerFactory()) + host, imported, importing, close := setupCallTests(t, e, hostDivBy, et.ListenerFactory()) defer close() // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. @@ -384,24 +412,24 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { fn *wasm.FunctionInstance }{ { - name: wasmFnName, + name: divByWasmName, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, }, { - name: hostFnName, + name: divByGoName, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, }, { - name: callHostFnName, + name: callDivByGoName, module: imported.CallCtx, - fn: imported.Exports[callHostFnName].Function, + fn: imported.Exports[callDivByGoName].Function, }, { - name: callImportCallHostFnName, + name: callImportCallDivByGoName, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, }, } for _, tt := range tests { @@ -420,7 +448,7 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { func RunTestModuleEngine_Call_Errors(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) - host, imported, importing, close := setupCallTests(t, e, et.ListenerFactory()) + host, imported, importing, close := setupCallTests(t, e, hostDivByGo, et.ListenerFactory()) defer close() tests := []struct { @@ -434,99 +462,99 @@ func RunTestModuleEngine_Call_Errors(t *testing.T, et EngineTester) { name: "host function not enough parameters", input: []uint64{}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `expected 1 params, but passed 0`, }, { name: "host function too many parameters", input: []uint64{1, 2}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `expected 1 params, but passed 2`, }, { name: "wasm function not enough parameters", input: []uint64{}, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, expectedErr: `expected 1 params, but passed 0`, }, { name: "wasm function too many parameters", input: []uint64{1, 2}, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, expectedErr: `expected 1 params, but passed 2`, }, { name: "wasm function panics with wasmruntime.Error", input: []uint64{0}, module: imported.CallCtx, - fn: imported.Exports[wasmFnName].Function, + fn: imported.Exports[divByWasmName].Function, expectedErr: `wasm error: integer divide by zero wasm stack trace: - imported.wasm_div_by(i32) i32`, + imported.div_by.wasm(i32) i32`, }, { name: "host function that panics", input: []uint64{math.MaxUint32}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32`, + host.div_by.go(i32) i32`, }, { name: "host function panics with runtime.Error", input: []uint64{0}, module: host.CallCtx, - fn: host.Exports[hostFnName].Function, + fn: host.Exports[divByGoName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32`, + host.div_by.go(i32) i32`, }, { name: "wasm calls host function that panics", input: []uint64{math.MaxUint32}, module: imported.CallCtx, - fn: imported.Exports[callHostFnName].Function, + fn: imported.Exports[callDivByGoName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32`, }, { name: "wasm calls imported wasm that calls host function panics with runtime.Error", input: []uint64{0}, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32 - importing.call_import->call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32 + importing.call_import->call->div_by.go(i32) i32`, }, { name: "wasm calls imported wasm that calls host function that panics", input: []uint64{math.MaxUint32}, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32 - importing.call_import->call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32 + importing.call_import->call->div_by.go(i32) i32`, }, { name: "wasm calls imported wasm calls host function panics with runtime.Error", input: []uint64{0}, module: importing.CallCtx, - fn: importing.Exports[callImportCallHostFnName].Function, + fn: importing.Exports[callImportCallDivByGoName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - host.host_div_by(i32) i32 - imported.call->host_div_by(i32) i32 - importing.call_import->call->host_div_by(i32) i32`, + host.div_by.go(i32) i32 + imported.call->div_by.go(i32) i32 + importing.call_import->call->div_by.go(i32) i32`, }, } for _, tt := range tests { @@ -665,35 +693,61 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { } const ( - wasmFnName = "wasm_div_by" - hostFnName = "host_div_by" - callHostFnName = "call->" + hostFnName - callImportCallHostFnName = "call_import->" + callHostFnName + divByWasmName = "div_by.wasm" + divByGoName = "div_by.go" + callDivByGoName = "call->" + divByGoName + callImportCallDivByGoName = "call_import->" + callDivByGoName ) -// (func (export "wasm_div_by") (param i32) (result i32) (i32.div_u (i32.const 1) (local.get 0))) -var wasmFnBody = []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeLocalGet, 0, wasm.OpcodeI32DivU, wasm.OpcodeEnd} - -func divBy(d uint32) uint32 { +func divByGo(d uint32) uint32 { if d == math.MaxUint32 { panic(errors.New("host-function panic")) } return 1 / d // go panics if d == 0 } -func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListenerFactory) (*wasm.ModuleInstance, *wasm.ModuleInstance, *wasm.ModuleInstance, func()) { - i32 := wasm.ValueTypeI32 +var hostDivByGo = wasm.MustParseGoFuncCode(divByGo) + +// (func (export "div_by.wasm") (param i32) (result i32) (i32.div_u (i32.const 1) (local.get 0))) +var divByWasm = []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeLocalGet, 0, wasm.OpcodeI32DivU, wasm.OpcodeEnd} +var hostDivByWasm = &wasm.Code{IsHostFunction: true, Body: divByWasm} + +const ( + readMemName = "read_mem" + callReadMemName = "call->read_mem" + callImportReadMemName = "call_import->read_mem" + callImportCallReadMemName = "call_import->call->read_mem" +) + +func readMemGo(ctx context.Context, m api.Module) uint64 { + ret, ok := m.Memory().ReadUint64Le(ctx, 0) + if !ok { + panic("couldn't read memory") + } + return ret +} + +var hostReadMemGo = wasm.MustParseGoFuncCode(readMemGo) + +// (func (export "wasm_read_mem") (result i64) i32.const 0 i64.load) +var readMemWasm = []byte{wasm.OpcodeI32Const, 0, wasm.OpcodeI64Load, 0x3, 0x0, wasm.OpcodeEnd} +var hostReadMemWasm = &wasm.Code{IsHostFunction: true, Body: readMemWasm} + +func setupCallTests(t *testing.T, e wasm.Engine, divBy *wasm.Code, fnlf experimental.FunctionListenerFactory) (*wasm.ModuleInstance, *wasm.ModuleInstance, *wasm.ModuleInstance, func()) { ft := &wasm.FunctionType{Params: []wasm.ValueType{i32}, Results: []wasm.ValueType{i32}, ParamNumInUint64: 1, ResultNumInUint64: 1} - hostFnVal := reflect.ValueOf(divBy) + divByName := divByWasmName + if divBy.GoFunc != nil { + divByName = divByGoName + } hostModule := &wasm.Module{ TypeSection: []*wasm.FunctionType{ft}, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{{GoFunc: &hostFnVal}}, - ExportSection: []*wasm.Export{{Name: hostFnName, Type: wasm.ExternTypeFunc, Index: 0}}, + CodeSection: []*wasm.Code{divBy}, + ExportSection: []*wasm.Export{{Name: divByGoName, Type: wasm.ExternTypeFunc, Index: 0}}, NameSection: &wasm.NameSection{ ModuleName: "host", - FunctionNames: wasm.NameMap{{Index: wasm.Index(0), Name: hostFnName}}, + FunctionNames: wasm.NameMap{{Index: wasm.Index(0), Name: divByName}}, }, ID: wasm.ModuleID{0}, } @@ -703,7 +757,7 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe host := &wasm.ModuleInstance{Name: hostModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} host.Functions = host.BuildFunctions(hostModule, buildListeners(fnlf, hostModule)) host.BuildExports(hostModule.ExportSection) - hostFn := host.Exports[hostFnName].Function + hostFn := host.Exports[divByGoName].Function hostME, err := e.NewModuleEngine(host.Name, hostModule, nil, host.Functions, nil, nil) require.NoError(t, err) @@ -714,19 +768,19 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe TypeSection: []*wasm.FunctionType{ft}, FunctionSection: []uint32{0, 0}, CodeSection: []*wasm.Code{ - {Body: wasmFnBody}, + {Body: divByWasm}, {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, byte(0), // Calling imported host function ^. wasm.OpcodeEnd}}, }, ExportSection: []*wasm.Export{ - {Name: wasmFnName, Type: wasm.ExternTypeFunc, Index: 1}, - {Name: callHostFnName, Type: wasm.ExternTypeFunc, Index: 2}, + {Name: divByWasmName, Type: wasm.ExternTypeFunc, Index: 1}, + {Name: callDivByGoName, Type: wasm.ExternTypeFunc, Index: 2}, }, NameSection: &wasm.NameSection{ ModuleName: "imported", FunctionNames: wasm.NameMap{ - {Index: wasm.Index(1), Name: wasmFnName}, - {Index: wasm.Index(2), Name: callHostFnName}, + {Index: wasm.Index(1), Name: divByWasmName}, + {Index: wasm.Index(2), Name: callDivByGoName}, }, }, ID: wasm.ModuleID{1}, @@ -739,7 +793,7 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe importedFunctions := imported.BuildFunctions(importedModule, buildListeners(fnlf, importedModule)) imported.Functions = append([]*wasm.FunctionInstance{hostFn}, importedFunctions...) imported.BuildExports(importedModule.ExportSection) - callHostFn := imported.Exports[callHostFnName].Function + callHostFn := imported.Exports[callDivByGoName].Function // Compile the imported module importedMe, err := e.NewModuleEngine(imported.Name, importedModule, []*wasm.FunctionInstance{hostFn}, importedFunctions, nil, nil) @@ -755,11 +809,11 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 0 /* only one imported function */, wasm.OpcodeEnd}}, }, ExportSection: []*wasm.Export{ - {Name: callImportCallHostFnName, Type: wasm.ExternTypeFunc, Index: 1}, + {Name: callImportCallDivByGoName, Type: wasm.ExternTypeFunc, Index: 1}, }, NameSection: &wasm.NameSection{ ModuleName: "importing", - FunctionNames: wasm.NameMap{{Index: wasm.Index(1), Name: callImportCallHostFnName}}, + FunctionNames: wasm.NameMap{{Index: wasm.Index(1), Name: callImportCallDivByGoName}}, }, ID: wasm.ModuleID{2}, } @@ -785,6 +839,76 @@ func setupCallTests(t *testing.T, e wasm.Engine, fnlf experimental.FunctionListe } } +func setupCallMemTests(t *testing.T, e wasm.Engine, readMem *wasm.Code, fnlf experimental.FunctionListenerFactory) (*wasm.ModuleInstance, *wasm.ModuleInstance, func()) { + ft := &wasm.FunctionType{Results: []wasm.ValueType{i64}, ResultNumInUint64: 1} + + callReadMem := &wasm.Code{ // shows indirect calls still use the same memory + IsHostFunction: true, + Body: []byte{wasm.OpcodeCall, 0, wasm.OpcodeEnd}, + } + hostModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{ft}, + FunctionSection: []wasm.Index{0, 0}, + CodeSection: []*wasm.Code{readMem, callReadMem}, + ExportSection: []*wasm.Export{ + {Name: readMemName, Type: wasm.ExternTypeFunc, Index: 0}, + {Name: callReadMemName, Type: wasm.ExternTypeFunc, Index: 1}, + }, + NameSection: &wasm.NameSection{ + ModuleName: "host", + FunctionNames: wasm.NameMap{{Index: 0, Name: readMemName}, {Index: 1, Name: callReadMemName}}, + }, + ID: wasm.ModuleID{0}, + } + hostModule.BuildFunctionDefinitions() + err := e.CompileModule(testCtx, hostModule) + require.NoError(t, err) + host := &wasm.ModuleInstance{Name: hostModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} + host.Functions = host.BuildFunctions(hostModule, buildListeners(fnlf, hostModule)) + host.BuildExports(hostModule.ExportSection) + readMemFn := host.Exports[readMemName].Function + callReadMemFn := host.Exports[callReadMemName].Function + + hostME, err := e.NewModuleEngine(host.Name, hostModule, nil, host.Functions, nil, nil) + require.NoError(t, err) + linkModuleToEngine(host, hostME) + + importingModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{ft}, + ExportSection: []*wasm.Export{ + {Name: callImportReadMemName, Type: wasm.ExternTypeFunc, Index: 0}, + {Name: callImportCallReadMemName, Type: wasm.ExternTypeFunc, Index: 1}, + }, + NameSection: &wasm.NameSection{ + ModuleName: "importing", + FunctionNames: wasm.NameMap{ + {Index: 0, Name: callImportReadMemName}, + {Index: 1, Name: callImportCallReadMemName}, + }, + }, + ID: wasm.ModuleID{1}, + } + importingModule.BuildFunctionDefinitions() + err = e.CompileModule(testCtx, importingModule) + require.NoError(t, err) + + // Add the exported function. + importing := &wasm.ModuleInstance{Name: importingModule.NameSection.ModuleName, TypeIDs: []wasm.FunctionTypeID{0}} + importingFunctions := importing.BuildFunctions(importingModule, buildListeners(fnlf, importingModule)) + importing.Functions = append([]*wasm.FunctionInstance{readMemFn, callReadMemFn}, importingFunctions...) + importing.BuildExports(importingModule.ExportSection) + + // Compile the importing module + importingMe, err := e.NewModuleEngine(importing.Name, importingModule, []*wasm.FunctionInstance{readMemFn}, importingFunctions, nil, nil) + require.NoError(t, err) + linkModuleToEngine(importing, importingMe) + + return host, importing, func() { + e.DeleteCompiledModule(hostModule) + e.DeleteCompiledModule(importingModule) + } +} + // linkModuleToEngine assigns fields that wasm.Store would on instantiation. These includes fields both interpreter and // Compiler needs as well as fields only needed by Compiler. // diff --git a/internal/wasm/binary/encoder_test.go b/internal/wasm/binary/encoder_test.go index 12e7cec3..3f8925e7 100644 --- a/internal/wasm/binary/encoder_test.go +++ b/internal/wasm/binary/encoder_test.go @@ -1,9 +1,9 @@ package binary import ( - "reflect" "testing" + "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/wasm" @@ -211,10 +211,13 @@ func TestModule_Encode(t *testing.T) { func TestModule_Encode_HostFunctionSection_Unsupported(t *testing.T) { // We don't currently have an approach to serialize reflect.Value pointers - fn := reflect.ValueOf(func(wasm.Module) {}) + fn := func(api.Module) {} captured := require.CapturePanic(func() { - EncodeModule(&wasm.Module{CodeSection: []*wasm.Code{{GoFunc: &fn}}}) + EncodeModule(&wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(fn)}, + }) }) require.EqualError(t, captured, "BUG: GoFunc is not encodable") } diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index 97823bf7..58dd236a 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -64,8 +64,9 @@ func (m *Module) validateFunctionWithMaxStackValues( declaredFunctionIndexes map[Index]struct{}, ) error { functionType := m.TypeSection[m.FunctionSection[idx]] - body := m.CodeSection[idx].Body - localTypes := m.CodeSection[idx].LocalTypes + code := m.CodeSection[idx] + body := code.Body + localTypes := code.LocalTypes types := m.TypeSection // We start with the outermost control block which is for function return if the code branches into it. @@ -90,8 +91,8 @@ func (m *Module) validateFunctionWithMaxStackValues( } if OpcodeI32Load <= op && op <= OpcodeI64Store32 { - if memory == nil { - return fmt.Errorf("unknown memory access") + if memory == nil && !code.IsHostFunction { + return fmt.Errorf("memory must exist for %s", InstructionName(op)) } pc++ align, _, read, err := readMemArg(pc, body) @@ -272,8 +273,8 @@ func (m *Module) validateFunctionWithMaxStackValues( } } } else if OpcodeMemorySize <= op && op <= OpcodeMemoryGrow { - if memory == nil { - return fmt.Errorf("unknown memory access") + if memory == nil && !code.IsHostFunction { + return fmt.Errorf("memory must exist for %s", InstructionName(op)) } pc++ val, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) @@ -1100,7 +1101,7 @@ func (m *Module) validateFunctionWithMaxStackValues( OpcodeVecV128Load32x2s, OpcodeVecV128Load32x2u, OpcodeVecV128Load8Splat, OpcodeVecV128Load16Splat, OpcodeVecV128Load32Splat, OpcodeVecV128Load64Splat, OpcodeVecV128Load32zero, OpcodeVecV128Load64zero: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } pc++ @@ -1138,7 +1139,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } valueTypeStack.push(ValueTypeV128) case OpcodeVecV128Store: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } pc++ @@ -1157,7 +1158,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("cannot pop the operand for %s: %v", OpcodeVecV128StoreName, err) } case OpcodeVecV128Load8Lane, OpcodeVecV128Load16Lane, OpcodeVecV128Load32Lane, OpcodeVecV128Load64Lane: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } attr := vecLoadLanes[vecOpcode] @@ -1185,7 +1186,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } valueTypeStack.push(ValueTypeV128) case OpcodeVecV128Store8Lane, OpcodeVecV128Store16Lane, OpcodeVecV128Store32Lane, OpcodeVecV128Store64Lane: - if memory == nil { + if memory == nil && !code.IsHostFunction { return fmt.Errorf("memory must exist for %s", VectorInstructionName(vecOpcode)) } attr := vecStoreLanes[vecOpcode] diff --git a/internal/wasm/function_definition.go b/internal/wasm/function_definition.go index 81a108b0..238deea8 100644 --- a/internal/wasm/function_definition.go +++ b/internal/wasm/function_definition.go @@ -1,6 +1,8 @@ package wasm import ( + "reflect" + "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/wasmdebug" ) @@ -65,10 +67,11 @@ func (m *Module) BuildFunctionDefinitions() { } for codeIndex, typeIndex := range m.FunctionSection { + code := m.CodeSection[codeIndex] m.FunctionDefinitionSection = append(m.FunctionDefinitionSection, &FunctionDefinition{ - index: Index(codeIndex) + importCount, - funcType: m.TypeSection[typeIndex], - isHostFunction: m.CodeSection[codeIndex].GoFunc != nil, + index: Index(codeIndex) + importCount, + funcType: m.TypeSection[typeIndex], + goFunc: code.GoFunc, }) } @@ -103,15 +106,15 @@ func (m *Module) BuildFunctionDefinitions() { // FunctionDefinition implements api.FunctionDefinition type FunctionDefinition struct { - moduleName string - index Index - name string - debugName string - isHostFunction bool - funcType *FunctionType - importDesc *[2]string - exportNames []string - paramNames []string + moduleName string + index Index + name string + debugName string + goFunc *reflect.Value + funcType *FunctionType + importDesc *[2]string + exportNames []string + paramNames []string } // ModuleName implements the same method as documented on api.FunctionDefinition. @@ -147,9 +150,9 @@ func (f *FunctionDefinition) ExportNames() []string { return f.exportNames } -// IsHostFunction implements the same method as documented on api.FunctionDefinition. -func (f *FunctionDefinition) IsHostFunction() bool { - return f.isHostFunction +// GoFunc implements the same method as documented on api.FunctionDefinition. +func (f *FunctionDefinition) GoFunc() *reflect.Value { + return f.goFunc } // ParamNames implements the same method as documented on api.FunctionDefinition. diff --git a/internal/wasm/function_definition_test.go b/internal/wasm/function_definition_test.go index 1c4bfc3e..8fe3c1aa 100644 --- a/internal/wasm/function_definition_test.go +++ b/internal/wasm/function_definition_test.go @@ -1,7 +1,6 @@ package wasm import ( - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -10,7 +9,7 @@ import ( func TestModule_BuildFunctionDefinitions(t *testing.T) { nopCode := &Code{Body: []byte{OpcodeEnd}} - fnV := reflect.ValueOf(func() {}) + fn := func() {} tests := []struct { name string m *Module @@ -32,18 +31,18 @@ func TestModule_BuildFunctionDefinitions(t *testing.T) { expectedExports: map[string]api.FunctionDefinition{}, }, { - name: "host func", + name: "host func go", m: &Module{ TypeSection: []*FunctionType{v_v}, FunctionSection: []Index{0}, - CodeSection: []*Code{{GoFunc: &fnV}}, + CodeSection: []*Code{MustParseGoFuncCode(fn)}, }, expected: []*FunctionDefinition{ { - index: 0, - debugName: ".$0", - isHostFunction: true, - funcType: v_v, + index: 0, + debugName: ".$0", + goFunc: MustParseGoFuncCode(fn).GoFunc, + funcType: v_v, }, }, expectedExports: map[string]api.FunctionDefinition{}, diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index 2f1ab806..fe575220 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -148,16 +148,30 @@ func newModuleVal(m api.Module) reflect.Value { return val } -// getFunctionType returns the function type corresponding to the function signature or errs if invalid. -func getFunctionType(fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err error) { - p := fn.Type() +// MustParseGoFuncCode parses Code from the go function or panics. +// +// Exposing this simplifies definition of host functions in built-in host +// modules and tests. +func MustParseGoFuncCode(fn interface{}) *Code { + _, _, code, err := parseGoFunc(fn) + if err != nil { + panic(err) + } + return code +} - if fn.Kind() != reflect.Func { - err = fmt.Errorf("kind != func: %s", fn.Kind().String()) +func parseGoFunc(fn interface{}) (params, results []ValueType, code *Code, err error) { + fnV := reflect.ValueOf(fn) + p := fnV.Type() + + if fnV.Kind() != reflect.Func { + err = fmt.Errorf("kind != func: %s", fnV.Kind().String()) return } - fk = kind(p) + fk := kind(p) + code = &Code{IsHostFunction: true, Kind: fk, GoFunc: &fnV} + pOffset := 0 switch fk { case FunctionKindGoNoContext: @@ -167,13 +181,14 @@ func getFunctionType(fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err pOffset = 1 } - rCount := p.NumOut() - - ft = &FunctionType{Params: make([]ValueType, p.NumIn()-pOffset), Results: make([]ValueType, rCount)} - for i := 0; i < len(ft.Params); i++ { + pCount := p.NumIn() - pOffset + if pCount > 0 { + params = make([]ValueType, pCount) + } + for i := 0; i < len(params); i++ { pI := p.In(i + pOffset) if t, ok := getTypeOf(pI.Kind()); ok { - ft.Params[i] = t + params[i] = t continue } @@ -193,10 +208,14 @@ func getFunctionType(fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err return } - for i := 0; i < len(ft.Results); i++ { + rCount := p.NumOut() + if rCount > 0 { + results = make([]ValueType, rCount) + } + for i := 0; i < len(results); i++ { rI := p.Out(i) if t, ok := getTypeOf(rI.Kind()); ok { - ft.Results[i] = t + results[i] = t continue } diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index d9c696ad..afe7c3e3 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -3,7 +3,6 @@ package wasm import ( "context" "math" - "reflect" "testing" "unsafe" @@ -14,7 +13,7 @@ import ( // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") -func TestGetFunctionType(t *testing.T) { +func Test_parseGoFunc(t *testing.T) { var tests = []struct { name string inputFunc interface{} @@ -25,25 +24,25 @@ func TestGetFunctionType(t *testing.T) { name: "nullary", inputFunc: func() {}, expectedKind: FunctionKindGoNoContext, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "wasm.Module void return", inputFunc: func(api.Module) {}, expectedKind: FunctionKindGoModule, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "context.Context void return", inputFunc: func(context.Context) {}, expectedKind: FunctionKindGoContext, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "context.Context and api.Module void return", inputFunc: func(context.Context, api.Module) {}, expectedKind: FunctionKindGoContextModule, - expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + expectedType: &FunctionType{}, }, { name: "all supported params and i32 result", @@ -85,16 +84,15 @@ func TestGetFunctionType(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - rVal := reflect.ValueOf(tc.inputFunc) - fk, ft, err := getFunctionType(&rVal) + paramTypes, resultTypes, code, err := parseGoFunc(tc.inputFunc) require.NoError(t, err) - require.Equal(t, tc.expectedKind, fk) - require.Equal(t, tc.expectedType, ft) + require.Equal(t, tc.expectedKind, code.Kind) + require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes}) }) } } -func TestGetFunctionTypeErrors(t *testing.T) { +func Test_parseGoFunc_Errors(t *testing.T) { tests := []struct { name string input interface{} @@ -142,8 +140,7 @@ func TestGetFunctionTypeErrors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - rVal := reflect.ValueOf(tc.input) - _, _, err := getFunctionType(&rVal) + _, _, _, err := parseGoFunc(tc.input) require.EqualError(t, err, tc.expectedErr) }) } @@ -247,11 +244,10 @@ func TestPopGoFuncParams(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - goFunc := reflect.ValueOf(tc.inputFunc) - fk, _, err := getFunctionType(&goFunc) + _, _, code, err := parseGoFunc(tc.inputFunc) require.NoError(t, err) - vals := PopGoFuncParams(&FunctionInstance{Kind: fk, GoFunc: &goFunc}, (&stack{stackVals}).pop) + vals := PopGoFuncParams(&FunctionInstance{Kind: code.Kind, GoFunc: code.GoFunc}, (&stack{stackVals}).pop) require.Equal(t, tc.expected, vals) }) } @@ -400,11 +396,20 @@ func TestCallGoFunc(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - goFunc := reflect.ValueOf(tc.inputFunc) - fk, _, err := getFunctionType(&goFunc) + paramTypes, resultTypes, code, err := parseGoFunc(tc.inputFunc) require.NoError(t, err) - results := CallGoFunc(testCtx, callCtx, &FunctionInstance{Kind: fk, GoFunc: &goFunc}, tc.inputParams) + results := CallGoFunc( + testCtx, + callCtx, + &FunctionInstance{ + IsHostFunction: code.IsHostFunction, + Kind: code.Kind, + Type: &FunctionType{Params: paramTypes, Results: resultTypes}, + GoFunc: code.GoFunc, + }, + tc.inputParams, + ) require.Equal(t, tc.expectedResults, results) }) } diff --git a/internal/wasm/host.go b/internal/wasm/host.go index 7cf4eec5..4a5ecb70 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -2,16 +2,15 @@ package wasm import ( "fmt" - "reflect" "sort" "strings" "github.com/tetratelabs/wazero/internal/wasmdebug" ) -// Func is a function with an inlined type, typically used for NewHostModule. +// HostFunc is a function with an inlined type, used for NewHostModule. // Any corresponding FunctionType will be reused or added to the Module. -type Func struct { +type HostFunc struct { // ExportNames is equivalent to the same method on api.FunctionDefinition. ExportNames []string @@ -31,35 +30,39 @@ type Func struct { Code *Code } -// NewGoFunc returns a Func for the given parameters or panics. -func NewGoFunc(exportName string, name string, paramNames []string, fn interface{}) *Func { - fnV := reflect.ValueOf(fn) - _, ft, err := getFunctionType(&fnV) - if err != nil { - panic(err) - } - return &Func{ +// NewGoFunc returns a HostFunc for the given parameters or panics. +func NewGoFunc(exportName string, name string, paramNames []string, fn interface{}) *HostFunc { + return (&HostFunc{ ExportNames: []string{exportName}, Name: name, - ParamTypes: ft.Params, - ResultTypes: ft.Results, ParamNames: paramNames, - Code: &Code{GoFunc: &fnV}, + }).MustGoFunc(fn) +} + +// MustGoFunc calls WithGoFunc or panics on error. +func (f *HostFunc) MustGoFunc(fn interface{}) *HostFunc { + if ret, err := f.WithGoFunc(fn); err != nil { + panic(err) + } else { + return ret } } // WithGoFunc returns a copy of the function, replacing its Code.GoFunc. -func (f *Func) WithGoFunc(fn interface{}) *Func { +func (f *HostFunc) WithGoFunc(fn interface{}) (*HostFunc, error) { ret := *f - fnV := reflect.ValueOf(fn) - ret.Code = &Code{GoFunc: &fnV} - return &ret + var err error + ret.ParamTypes, ret.ResultTypes, ret.Code, err = parseGoFunc(fn) + return &ret, err } // WithWasm returns a copy of the function, replacing its Code.Body. -func (f *Func) WithWasm(body []byte) *Func { +func (f *HostFunc) WithWasm(body []byte) *HostFunc { ret := *f - ret.Code = &Code{Body: body} + ret.Code = &Code{IsHostFunction: true, Body: body} + if f.Code != nil { + ret.Code.LocalTypes = f.Code.LocalTypes + } return &ret } @@ -138,41 +141,46 @@ func addFuncs( m.NameSection = &NameSection{} } moduleName := m.NameSection.ModuleName - nameToFunc := make(map[string]*Func, len(nameToGoFunc)) + nameToFunc := make(map[string]*HostFunc, len(nameToGoFunc)) + sortedExportNames := make([]string, len(nameToFunc)) + for k := range nameToGoFunc { + sortedExportNames = append(sortedExportNames, k) + } + + // Sort names for consistent iteration + sort.Strings(sortedExportNames) + funcNames := make([]string, len(nameToFunc)) - for k, v := range nameToGoFunc { - if hf, ok := v.(*Func); !ok { - fn := reflect.ValueOf(v) - _, ft, ftErr := getFunctionType(&fn) + for _, k := range sortedExportNames { + v := nameToGoFunc[k] + if hf, ok := v.(*HostFunc); ok { + nameToFunc[hf.Name] = hf + funcNames = append(funcNames, hf.Name) + } else { + params, results, code, ftErr := parseGoFunc(v) if ftErr != nil { return fmt.Errorf("func[%s.%s] %w", moduleName, k, ftErr) } - hf = &Func{ + hf = &HostFunc{ ExportNames: []string{k}, Name: k, - ParamTypes: ft.Params, - ResultTypes: ft.Results, - Code: &Code{GoFunc: &fn}, + ParamTypes: params, + ResultTypes: results, + Code: code, } if names := funcToNames[k]; names != nil { namesLen := len(names) - if namesLen > 1 && namesLen-1 != len(ft.Params) { - return fmt.Errorf("func[%s.%s] has %d params, but %d param names", moduleName, k, namesLen-1, len(ft.Params)) + if namesLen > 1 && namesLen-1 != len(params) { + return fmt.Errorf("func[%s.%s] has %d params, but %d param names", moduleName, k, namesLen-1, len(params)) } hf.Name = names[0] hf.ParamNames = names[1:] } nameToFunc[k] = hf funcNames = append(funcNames, k) - } else { - nameToFunc[hf.Name] = hf - funcNames = append(funcNames, hf.Name) } } - // Sort names for consistent iteration - sort.Strings(funcNames) - funcCount := uint32(len(nameToFunc)) m.NameSection.FunctionNames = make([]*NameAssoc, 0, funcCount) m.FunctionSection = make([]Index, 0, funcCount) diff --git a/internal/wasm/host_test.go b/internal/wasm/host_test.go index 03981b03..43a63afb 100644 --- a/internal/wasm/host_test.go +++ b/internal/wasm/host_test.go @@ -1,7 +1,6 @@ package wasm import ( - "reflect" "testing" "github.com/tetratelabs/wazero/api" @@ -32,11 +31,8 @@ func swap(x, y uint32) (uint32, uint32) { func TestNewHostModule(t *testing.T) { a := wasiAPI{} functionArgsSizesGet := "args_sizes_get" - fnArgsSizesGet := reflect.ValueOf(a.ArgsSizesGet) functionFdWrite := "fd_write" - fnFdWrite := reflect.ValueOf(a.FdWrite) functionSwap := "swap" - fnSwap := reflect.ValueOf(swap) tests := []struct { name, moduleName string @@ -67,7 +63,7 @@ func TestNewHostModule(t *testing.T) { {Params: []ValueType{i32, i32, i32, i32}, Results: []ValueType{i32}}, }, FunctionSection: []Index{0, 1}, - CodeSection: []*Code{{GoFunc: &fnArgsSizesGet}, {GoFunc: &fnFdWrite}}, + CodeSection: []*Code{MustParseGoFuncCode(a.ArgsSizesGet), MustParseGoFuncCode(a.FdWrite)}, ExportSection: []*Export{ {Name: "args_sizes_get", Type: ExternTypeFunc, Index: 0}, {Name: "fd_write", Type: ExternTypeFunc, Index: 1}, @@ -90,7 +86,7 @@ func TestNewHostModule(t *testing.T) { expected: &Module{ TypeSection: []*FunctionType{{Params: []ValueType{i32, i32}, Results: []ValueType{i32, i32}}}, FunctionSection: []Index{0}, - CodeSection: []*Code{{GoFunc: &fnSwap}}, + CodeSection: []*Code{MustParseGoFuncCode(swap)}, ExportSection: []*Export{{Name: "swap", Type: ExternTypeFunc, Index: 0}}, NameSection: &NameSection{ModuleName: "swapper", FunctionNames: NameMap{{Index: 0, Name: "swap"}}}, }, @@ -152,7 +148,7 @@ func TestNewHostModule(t *testing.T) { {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, }, FunctionSection: []Index{0}, - CodeSection: []*Code{{GoFunc: &fnArgsSizesGet}}, + CodeSection: []*Code{MustParseGoFuncCode(a.ArgsSizesGet)}, GlobalSection: []*Global{ { Type: &GlobalType{ValType: i32}, @@ -201,9 +197,17 @@ func requireHostModuleEquals(t *testing.T, expected, actual *Module) { require.Equal(t, expected.NameSection, actual.NameSection) // Special case because reflect.Value can't be compared with Equals + // TODO: This is copy/paste with /builder_test.go require.Equal(t, len(expected.CodeSection), len(actual.CodeSection)) - for _, c := range expected.CodeSection { - require.Equal(t, c.GoFunc.Type(), c.GoFunc.Type()) + for i, c := range expected.CodeSection { + actualCode := actual.CodeSection[i] + require.True(t, actualCode.IsHostFunction) + require.Equal(t, c.Kind, actualCode.Kind) + require.Equal(t, c.GoFunc.Type(), actualCode.GoFunc.Type()) + + // Not wasm + require.Nil(t, actualCode.Body) + require.Nil(t, actualCode.LocalTypes) } } diff --git a/internal/wasm/module.go b/internal/wasm/module.go index 446cd438..e9a5c3c6 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -587,16 +587,13 @@ func (m *ModuleInstance) BuildFunctions(mod *Module, listeners []experimental.Fu fns = make([]*FunctionInstance, 0, len(mod.FunctionDefinitionSection)) for i := range mod.FunctionSection { code := mod.CodeSection[i] - fnKind := FunctionKindWasm - if fn := code.GoFunc; fn != nil { - fnKind = kind(fn.Type()) - } fns = append(fns, &FunctionInstance{ - Kind: fnKind, - Body: code.Body, - GoFunc: code.GoFunc, - TypeID: m.TypeIDs[mod.FunctionSection[i]], - LocalTypes: code.LocalTypes, + IsHostFunction: code.IsHostFunction, + Kind: code.Kind, + LocalTypes: code.LocalTypes, + Body: code.Body, + GoFunc: code.GoFunc, + TypeID: m.TypeIDs[mod.FunctionSection[i]], }) } @@ -804,14 +801,17 @@ type Export struct { // Code is an entry in the Module.CodeSection containing the locals and body of the function. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-code type Code struct { + // IsHostFunction returns true if the function was implemented by the + // embedder (ex via wazero.ModuleBuilder) instead of a wasm binary. + // + // Notably, host functions can use the caller's memory, which might be + // different from its defining module. + // + // See https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 + IsHostFunction bool - // GoFunc is a host function defined in Go. - // - // When present, LocalTypes and Body must be nil. - // - // Note: This has no serialization format, so is not encodable. - // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 - GoFunc *reflect.Value + // Kind describes how this function should be called. + Kind FunctionKind // LocalTypes are any function-scoped variables in insertion order. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-local @@ -820,6 +820,14 @@ type Code struct { // Body is a sequence of expressions ending in OpcodeEnd // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-expr Body []byte + + // GoFunc is a host function defined in Go. + // + // When present, LocalTypes and Body must be nil. + // + // Note: This has no serialization format, so is not encodable. + // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 + GoFunc *reflect.Value } type DataSegment struct { diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 391375ab..545a4a5b 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -108,6 +108,10 @@ type ( // FunctionInstance represents a function instance in a Store. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#function-instances%E2%91%A0 FunctionInstance struct { + // IsHostFunction is the data returned by the same field documented on + // wasm.Code. + IsHostFunction bool + // Kind describes how this function should be called. Kind FunctionKind diff --git a/internal/watzero/internal/func_parser_test.go b/internal/watzero/internal/func_parser_test.go index 9d73d7a0..5e581dcb 100644 --- a/internal/watzero/internal/func_parser_test.go +++ b/internal/watzero/internal/func_parser_test.go @@ -76,7 +76,7 @@ func TestFuncParser(t *testing.T) { name: "i32.load", source: "(func i32.const 8 i32.load)", expected: &wasm.Code{Body: []byte{ - wasm.OpcodeI32Const, 8, // dynamic memory offset to load + wasm.OpcodeI32Const, 8, // memory offset to load wasm.OpcodeI32Load, 0x2, 0x0, // load alignment=2 (natural alignment) staticOffset=0 wasm.OpcodeEnd, }}, @@ -100,7 +100,7 @@ func TestFuncParser(t *testing.T) { name: "i64.load", source: "(func i32.const 8 i64.load)", expected: &wasm.Code{Body: []byte{ - wasm.OpcodeI32Const, 8, // dynamic memory offset to load + wasm.OpcodeI32Const, 8, // memory offset to load wasm.OpcodeI64Load, 0x3, 0x0, // load alignment=3 (natural alignment) staticOffset=0 wasm.OpcodeEnd, }}, diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 3129a016..a6538d0c 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -177,12 +177,17 @@ func (c *compiler) resetUnreachable() { } type CompilationResult struct { - // GoFunc is present when the result is a host function. - // In this case, other fields can be ignored. + // IsHostFunction is the data returned by the same field documented on + // wasm.Code. + IsHostFunction bool + + // GoFunc is the data returned by the same field documented on wasm.Code. + // In this case, IsHostFunction is true and other fields can be ignored. GoFunc *reflect.Value // Operations holds wazeroir operations compiled from Wasm instructions in a Wasm function. Operations []Operation + // LabelCallers maps Label.String() to the number of callers to that label. // Here "callers" means that the call-sites which jumps to the label with br, br_if or br_table // instructions. @@ -209,6 +214,8 @@ type CompilationResult struct { TableTypes []wasm.ValueType // HasMemory is true if the module from which this function is compiled has memory declaration. HasMemory bool + // UsesMemory is true if this function might use memory. + UsesMemory bool // HasTable is true if the module from which this function is compiled has table declaration. HasTable bool // HasDataInstances is true if the module has data instances which might be used by memory.init or data.drop instructions. @@ -239,7 +246,13 @@ func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module * sig := module.TypeSection[typeID] code := module.CodeSection[funcIndex] if code.GoFunc != nil { - ret = append(ret, &CompilationResult{GoFunc: code.GoFunc}) + ret = append(ret, &CompilationResult{ + IsHostFunction: true, + // Assume the function might use memory if it has a parameter for the api.Module + UsesMemory: code.Kind == wasm.FunctionKindGoModule || code.Kind == wasm.FunctionKindGoContextModule, + GoFunc: code.GoFunc, + Signature: sig, + }) continue } r, err := compile(enabledFeatures, sig, code.Body, code.LocalTypes, module.TypeSection, functions, globals) @@ -247,6 +260,7 @@ func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module * def := module.FunctionDefinitionSection[uint32(funcIndex)+module.ImportFuncCount()] return nil, fmt.Errorf("failed to lower func[%s] to wazeroir: %w", def.DebugName(), err) } + r.IsHostFunction = code.IsHostFunction r.Globals = globals r.Functions = functions r.Types = module.TypeSection @@ -1045,11 +1059,13 @@ operatorSwitch: &OperationStore32{Arg: imm}, ) case wasm.OpcodeMemorySize: + c.result.UsesMemory = true c.pc++ // Skip the reserved one byte. c.emit( &OperationMemorySize{}, ) case wasm.OpcodeMemoryGrow: + c.result.UsesMemory = true c.pc++ // Skip the reserved one byte. c.emit( &OperationMemoryGrow{}, @@ -1672,6 +1688,7 @@ operatorSwitch: &OperationITruncFromF{InputType: Float64, OutputType: SignedUint64, NonTrapping: true}, ) case wasm.OpcodeMiscMemoryInit: + c.result.UsesMemory = true dataIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) @@ -1690,11 +1707,13 @@ operatorSwitch: &OperationDataDrop{DataIndex: dataIndex}, ) case wasm.OpcodeMiscMemoryCopy: + c.result.UsesMemory = true c.pc += 2 // +2 to skip two memory indexes which are fixed to zero. c.emit( &OperationMemoryCopy{}, ) case wasm.OpcodeMiscMemoryFill: + c.result.UsesMemory = true c.pc += 1 // +1 to skip the memory index which is fixed to zero. c.emit( &OperationMemoryFill{}, @@ -3072,6 +3091,7 @@ func (c *compiler) stackLenInUint64(ceil int) (ret int) { } func (c *compiler) readMemoryArg(tag string) (*MemoryArg, error) { + c.result.UsesMemory = true r := bytes.NewReader(c.body[c.pc+1:]) alignment, num, err := leb128.DecodeUint32(r) if err != nil { diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index 450f2476..bb2d2781 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -51,11 +51,57 @@ func TestCompile(t *testing.T) { }, LabelCallers: map[string]uint32{}, Functions: []uint32{0}, - Types: []*wasm.FunctionType{{}}, - Signature: &wasm.FunctionType{}, + Types: []*wasm.FunctionType{v_v}, + Signature: v_v, TableTypes: []wasm.RefType{}, }, }, + { + name: "host wasm nullary", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{{IsHostFunction: true, Body: []byte{wasm.OpcodeEnd}}}, + }, + expected: &CompilationResult{ + IsHostFunction: true, + Operations: []Operation{ // begin with params: [] + &OperationBr{Target: &BranchTarget{}}, // return! + }, + LabelCallers: map[string]uint32{}, + Functions: []uint32{0}, + Types: []*wasm.FunctionType{v_v}, + Signature: v_v, + TableTypes: []wasm.RefType{}, + }, + }, + { + name: "host go nullary", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func() {})}, + }, + expected: &CompilationResult{IsHostFunction: true}, + }, + { + name: "host go api.Module uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func(api.Module) {})}, + }, + expected: &CompilationResult{IsHostFunction: true, UsesMemory: true}, + }, + { + name: "host go context.Context api.Module uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func(context.Context, api.Module) {})}, + }, + expected: &CompilationResult{IsHostFunction: true, UsesMemory: true}, + }, { name: "identity", module: requireModuleText(t, `(module @@ -82,6 +128,61 @@ func TestCompile(t *testing.T) { TableTypes: []wasm.RefType{}, }, }, + { + name: "uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{{Body: []byte{ + wasm.OpcodeI32Const, 8, // memory offset to load + wasm.OpcodeI32Load, 0x2, 0x0, // load alignment=2 (natural alignment) staticOffset=0 + wasm.OpcodeDrop, + wasm.OpcodeEnd, + }}}, + }, + expected: &CompilationResult{ + Operations: []Operation{ // begin with params: [] + &OperationConstI32{Value: 8}, // [8] + &OperationLoad{Type: UnsignedTypeI32, Arg: &MemoryArg{Alignment: 2, Offset: 0}}, // [x] + &OperationDrop{Depth: &InclusiveRange{}}, // [] + &OperationBr{Target: &BranchTarget{}}, // return! + }, + LabelCallers: map[string]uint32{}, + Types: []*wasm.FunctionType{v_v}, + Functions: []uint32{0}, + Signature: v_v, + TableTypes: []wasm.RefType{}, + UsesMemory: true, + }, + }, + { + name: "host uses memory", + module: &wasm.Module{ + TypeSection: []*wasm.FunctionType{v_v}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{{IsHostFunction: true, Body: []byte{ + wasm.OpcodeI32Const, 8, // memory offset to load + wasm.OpcodeI32Load, 0x2, 0x0, // load alignment=2 (natural alignment) staticOffset=0 + wasm.OpcodeDrop, + wasm.OpcodeEnd, + }}}, + }, + expected: &CompilationResult{ + IsHostFunction: true, + Operations: []Operation{ // begin with params: [] + &OperationConstI32{Value: 8}, // [8] + &OperationLoad{Type: UnsignedTypeI32, Arg: &MemoryArg{Alignment: 2, Offset: 0}}, // [x] + &OperationDrop{Depth: &InclusiveRange{}}, // [] + &OperationBr{Target: &BranchTarget{}}, // return! + }, + LabelCallers: map[string]uint32{}, + Types: []*wasm.FunctionType{v_v}, + Functions: []uint32{0}, + Signature: v_v, + TableTypes: []wasm.RefType{}, + UsesMemory: true, + }, + }, { name: "memory.grow", // Ex to expose ops to grow memory module: requireModuleText(t, `(module @@ -107,6 +208,7 @@ func TestCompile(t *testing.T) { ResultNumInUint64: 1, }, TableTypes: []wasm.RefType{}, + UsesMemory: true, }, }, } @@ -124,7 +226,16 @@ func TestCompile(t *testing.T) { } res, err := CompileFunctions(ctx, enabledFeatures, tc.module) require.NoError(t, err) - require.Equal(t, tc.expected, res[0]) + + fn := res[0] + if fn.GoFunc != nil { // can't compare functions + // Special case because reflect.Value can't be compared with Equals + require.True(t, fn.IsHostFunction) + require.Equal(t, tc.expected.UsesMemory, fn.UsesMemory) + require.Equal(t, &tc.module.CodeSection[0].GoFunc, &fn.GoFunc) + } else { + require.Equal(t, tc.expected, fn) + } }) } } @@ -244,6 +355,7 @@ func TestCompile_BulkMemoryOperations(t *testing.T) { &OperationBr{Target: &BranchTarget{}}, // return! }, HasMemory: true, + UsesMemory: true, HasDataInstances: true, LabelCallers: map[string]uint32{}, Signature: v_v, @@ -742,7 +854,7 @@ func TestCompile_Refs(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: tc.body}}, } @@ -810,7 +922,7 @@ func TestCompile_TableGetOrSet(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: tc.body}}, TableSection: []*wasm.Table{{}}, @@ -879,7 +991,7 @@ func TestCompile_TableGrowFillSize(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: tc.body}}, TableSection: []*wasm.Table{{}}, @@ -933,7 +1045,7 @@ func TestCompile_Locals(t *testing.T) { { name: "local.get - non func param - v128", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{ Body: []byte{ @@ -996,7 +1108,7 @@ func TestCompile_Locals(t *testing.T) { { name: "local.set - non func param - v128", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{ Body: []byte{ @@ -1070,7 +1182,7 @@ func TestCompile_Locals(t *testing.T) { { name: "local.tee - non func param", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{ Body: []byte{ @@ -2443,7 +2555,7 @@ func TestCompile_Vec(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { module := &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, MemorySection: &wasm.Memory{}, CodeSection: []*wasm.Code{{Body: tc.body}}, @@ -2477,7 +2589,7 @@ func TestCompile_unreachable_Br_BrIf_BrTable(t *testing.T) { { name: "br", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeBr, 0, // Return the function -> the followings are unreachable. @@ -2492,7 +2604,7 @@ func TestCompile_unreachable_Br_BrIf_BrTable(t *testing.T) { { name: "br_if", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeBr, 0, // Return the function -> the followings are unreachable. @@ -2508,7 +2620,7 @@ func TestCompile_unreachable_Br_BrIf_BrTable(t *testing.T) { { name: "br_table", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeBr, 0, // Return the function -> the followings are unreachable. @@ -2543,7 +2655,7 @@ func TestCompile_drop_vectors(t *testing.T) { { name: "basic", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeVecPrefix, @@ -2581,7 +2693,7 @@ func TestCompile_select_vectors(t *testing.T) { { name: "non typed", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeVecPrefix, @@ -2607,7 +2719,7 @@ func TestCompile_select_vectors(t *testing.T) { { name: "typed", mod: &wasm.Module{ - TypeSection: []*wasm.FunctionType{{}}, + TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, CodeSection: []*wasm.Code{{Body: []byte{ wasm.OpcodeVecPrefix, diff --git a/wasi_snapshot_preview1/wasi.go b/wasi_snapshot_preview1/wasi.go index 45747cbf..f1007038 100644 --- a/wasi_snapshot_preview1/wasi.go +++ b/wasi_snapshot_preview1/wasi.go @@ -181,13 +181,16 @@ func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, va } // stubFunction stubs for GrainLang per #271. -func stubFunction(name string, paramTypes []wasm.ValueType, paramNames []string) *wasm.Func { - return &wasm.Func{ +func stubFunction(name string, paramTypes []wasm.ValueType, paramNames []string) *wasm.HostFunc { + return &wasm.HostFunc{ Name: name, ExportNames: []string{name}, ParamTypes: paramTypes, ParamNames: paramNames, ResultTypes: []wasm.ValueType{i32}, - Code: &wasm.Code{Body: []byte{wasm.OpcodeI32Const, byte(ErrnoNosys), wasm.OpcodeEnd}}, + Code: &wasm.Code{ + IsHostFunction: true, + Body: []byte{wasm.OpcodeI32Const, byte(ErrnoNosys), wasm.OpcodeEnd}, + }, } }