From 3d25f48b4a066ef9def55ec435dbcca8a9a01b6f Mon Sep 17 00:00:00 2001 From: Crypt Keeper <64215+codefromthecrypt@users.noreply.github.com> Date: Fri, 18 Feb 2022 16:54:52 +0800 Subject: [PATCH] Removes requirement to pass a HostFunctionCallContext (#260) This allows users to decouple from wazero code when authoring host functions. Notably, this allows them to opt out of using a context, or only using a Go context instead of HostFunctionCallContext. This backfills docs on how to write host functions (in simple terms). Finally, this does not optimize engines to avoid propagating context or looking up memory if it would never be used. That could be done later. Signed-off-by: Adrian Cole --- README.md | 7 +- examples/simple_test.go | 3 +- internal/wasi/wasi.go | 8 +- internal/wasm/host.go | 135 +++++++++++++++--- internal/wasm/host_test.go | 123 ++++++++++++++++ internal/wasm/interpreter/interpreter.go | 29 ++-- internal/wasm/interpreter/interpreter_test.go | 1 + internal/wasm/jit/engine.go | 46 +++--- internal/wasm/jit/jit_amd64.go | 2 +- internal/wasm/jit/jit_amd64_test.go | 48 +++++-- internal/wasm/jit/jit_arm64_test.go | 12 +- internal/wasm/jit/jit_test.go | 1 + internal/wasm/store.go | 27 ++-- internal/wasm/store_test.go | 30 ++-- store.go | 34 +++++ tests/engine/adhoc_test.go | 83 +++++++---- 16 files changed, 441 insertions(+), 148 deletions(-) diff --git a/README.md b/README.md index 17ca34a2..733fef5a 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,8 @@ func main() { // Decode the binary as WebAssembly module. mod, _ := wazero.DecodeModuleBinary(source) - // Initialize the execution environment called "store" with Interpreter-based engine. - store := wazero.NewStore() - - // Instantiate the module, which returns its exported functions - functions, _ := store.Instantiate(mod) + // Instantiate the module with a Wasm Interpreter, to return its exported functions + functions, _ := wazero.NewStore().Instantiate(mod) // Get the factorial function fac, _ := functions.GetFunctionI64Return("fac") diff --git a/examples/simple_test.go b/examples/simple_test.go index 45636aea..72452252 100644 --- a/examples/simple_test.go +++ b/examples/simple_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/require" "github.com/tetratelabs/wazero" - "github.com/tetratelabs/wazero/wasm" ) // Test_Simple implements a basic function in go: hello. This is imported as the Wasm name "$hello" and run on start. @@ -20,7 +19,7 @@ func Test_Simple(t *testing.T) { require.NoError(t, err) stdout := new(bytes.Buffer) - goFunc := func(wasm.HostFunctionCallContext) { + goFunc := func() { _, _ = fmt.Fprintln(stdout, "hello!") } diff --git a/internal/wasi/wasi.go b/internal/wasi/wasi.go index 453b2e27..51122e57 100644 --- a/internal/wasi/wasi.go +++ b/internal/wasi/wasi.go @@ -113,7 +113,7 @@ const ( FunctionPathUnlinkFile = "path_unlink_file" FunctionPollOneoff = "poll_oneoff" - // ProcExit terminates the execution of the module with an exit code. + // FunctionProcExit terminates the execution of the module with an exit code. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit FunctionProcExit = "proc_exit" @@ -339,7 +339,7 @@ type SnapshotPreview1 interface { // // Note: ImportProcExit shows this signature in the WebAssembly 1.0 (MVP) Text Format. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit - ProcExit(ctx wasm.HostFunctionCallContext, rval uint32) + ProcExit(rval uint32) // TODO: ProcRaise // TODO: SchedYield @@ -505,8 +505,8 @@ func (a *wasiAPI) ClockTimeGet(ctx wasm.HostFunctionCallContext, id uint32, prec return wasi.ErrnoSuccess } -// ProcExit implements API.ProcExit -func (a *wasiAPI) ProcExit(ctx wasm.HostFunctionCallContext, exitCode uint32) { +// ProcExit implements SnapshotPreview1.ProcExit +func (a *wasiAPI) ProcExit(exitCode uint32) { // Panic in a host function is caught by the engines, and the value of the panic is returned as the error of the CallFunction. // See the document of API.ProcExit. panic(wasi.ExitCode(exitCode)) diff --git a/internal/wasm/host.go b/internal/wasm/host.go index 1a01b0d4..34b160f6 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -10,42 +10,131 @@ import ( publicwasm "github.com/tetratelabs/wazero/wasm" ) +// FunctionKind identifies the type of function that can be called. +type FunctionKind byte + +const ( + // FunctionKindWasm is not a host function: it is implemented in Wasm. + FunctionKindWasm FunctionKind = iota + // FunctionKindHostNoContext is a function implemented in Go, with a signature matching FunctionType. + FunctionKindHostNoContext + // FunctionKindHostGoContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is + // a context.Context. + FunctionKindHostGoContext + // FunctionKindHostFunctionCallContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is + // a HostFunctionCallContext. + FunctionKindHostFunctionCallContext +) + type HostFunction struct { - name string + name string + // functionKind is never FunctionKindWasm + functionKind FunctionKind functionType *FunctionType goFunc *reflect.Value } +func NewHostFunction(funcName string, goFunc interface{}) (hf *HostFunction, err error) { + hf = &HostFunction{name: funcName} + fn := reflect.ValueOf(goFunc) + hf.goFunc = &fn + hf.functionKind, hf.functionType, err = GetFunctionType(hf.name, hf.goFunc) + return +} + +// Below are reflection code to get the interface type used to parse functions and set values. + +var hostFunctionCallContextType = reflect.TypeOf((*publicwasm.HostFunctionCallContext)(nil)).Elem() +var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem() +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +// GetHostFunctionCallContextValue returns a reflect.Value for a context param[0], or nil if there isn't one. +func GetHostFunctionCallContextValue(fk FunctionKind, ctx *HostFunctionCallContext) *reflect.Value { + switch fk { + case FunctionKindHostNoContext: // no special param zero + case FunctionKindHostGoContext: + val := reflect.New(goContextType).Elem() + val.Set(reflect.ValueOf(ctx.Context())) + return &val + case FunctionKindHostFunctionCallContext: + val := reflect.New(hostFunctionCallContextType).Elem() + val.Set(reflect.ValueOf(ctx)) + return &val + } + return nil +} + // GetFunctionType returns the function type corresponding to the function signature or errs if invalid. -func GetFunctionType(name string, fn *reflect.Value) (*FunctionType, error) { +func GetFunctionType(name string, fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err error) { if fn.Kind() != reflect.Func { - return nil, fmt.Errorf("%s value is not a reflect.Func: %s", name, fn.String()) + err = fmt.Errorf("%s is a %s, but should be a Func", name, fn.Kind().String()) + return } p := fn.Type() - if p.NumIn() == 0 { // TODO: actually check the type - return nil, fmt.Errorf("%s must accept wasm.HostFunctionCallContext as the first param", name) - } - paramTypes := make([]ValueType, p.NumIn()-1) - for i := range paramTypes { - kind := p.In(i + 1).Kind() - if t, ok := getTypeOf(kind); !ok { - return nil, fmt.Errorf("%s param[%d] is unsupported: %s", name, i, kind.String()) - } else { - paramTypes[i] = t + pOffset := 0 + pCount := p.NumIn() + fk = FunctionKindHostNoContext + if pCount > 0 && p.In(0).Kind() == reflect.Interface { + p0 := p.In(0) + if p0.Implements(hostFunctionCallContextType) { + fk = FunctionKindHostFunctionCallContext + pOffset = 1 + pCount-- + } else if p0.Implements(goContextType) { + fk = FunctionKindHostGoContext + pOffset = 1 + pCount-- } } - - resultTypes := make([]ValueType, p.NumOut()) - for i := range resultTypes { - kind := p.Out(i).Kind() - if t, ok := getTypeOf(kind); !ok { - return nil, fmt.Errorf("%s result[%d] is unsupported: %s", name, i, kind.String()) - } else { - resultTypes[i] = t - } + rCount := p.NumOut() + switch rCount { + case 0, 1: // ok + default: + err = fmt.Errorf("%s has more than one result", name) + return } - return &FunctionType{Params: paramTypes, Results: resultTypes}, nil + + ft = &FunctionType{Params: make([]ValueType, pCount), Results: make([]ValueType, rCount)} + + for i := 0; i < len(ft.Params); i++ { + pI := p.In(i + pOffset) + if t, ok := getTypeOf(pI.Kind()); ok { + ft.Params[i] = t + continue + } + + // Now, we will definitely err, decide which message is best + var arg0Type reflect.Type + if hc := pI.Implements(hostFunctionCallContextType); hc { + arg0Type = hostFunctionCallContextType + } else if gc := pI.Implements(goContextType); gc { + arg0Type = goContextType + } + + if arg0Type != nil { + err = fmt.Errorf("%s param[%d] is a %s, which may be defined only once as param[0]", name, i+pOffset, arg0Type) + } else { + err = fmt.Errorf("%s param[%d] is unsupported: %s", name, i+pOffset, pI.Kind()) + } + return + } + + if rCount == 0 { + return + } + result := p.Out(0) + if t, ok := getTypeOf(result.Kind()); ok { + ft.Results[0] = t + return + } + + if e := result.Implements(errorType); e { + err = fmt.Errorf("%s result[0] is an error, which is unsupported", name) + } else { + err = fmt.Errorf("%s result[0] is unsupported: %s", name, result.Kind()) + } + return } func getTypeOf(kind reflect.Kind) (ValueType, bool) { diff --git a/internal/wasm/host_test.go b/internal/wasm/host_test.go index 1ffb03e9..0512877f 100644 --- a/internal/wasm/host_test.go +++ b/internal/wasm/host_test.go @@ -1,10 +1,14 @@ package internalwasm import ( + "context" "math" + "reflect" "testing" "github.com/stretchr/testify/require" + + publicwasm "github.com/tetratelabs/wazero/wasm" ) func TestMemoryInstance_HasLen(t *testing.T) { @@ -448,3 +452,122 @@ func TestMemoryInstance_WriteFloat64Le(t *testing.T) { }) } } + +func TestGetFunctionType(t *testing.T) { + i32, i64, f32, f64 := ValueTypeI32, ValueTypeI64, ValueTypeF32, ValueTypeF64 + + tests := []struct { + name string + inputFunc interface{} + expectedKind FunctionKind + expectedType *FunctionType + }{ + { + name: "nullary", + inputFunc: func() {}, + expectedKind: FunctionKindHostNoContext, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, + { + name: "wasm.HostFunctionCallContext void return", + inputFunc: func(publicwasm.HostFunctionCallContext) {}, + expectedKind: FunctionKindHostFunctionCallContext, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, + { + name: "context.Context void return", + inputFunc: func(context.Context) {}, + expectedKind: FunctionKindHostGoContext, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, + { + name: "all supported params and i32 result", + inputFunc: func(uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindHostNoContext, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, + { + name: "all supported params and i32 result - wasm.HostFunctionCallContext", + inputFunc: func(publicwasm.HostFunctionCallContext, uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindHostFunctionCallContext, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, + { + name: "all supported params and i32 result - context.Context", + inputFunc: func(context.Context, uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindHostGoContext, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + rVal := reflect.ValueOf(tc.inputFunc) + fk, ft, err := GetFunctionType("fn", &rVal) + require.NoError(t, err) + require.Equal(t, tc.expectedKind, fk) + require.Equal(t, tc.expectedType, ft) + }) + } +} + +func TestGetFunctionTypeErrors(t *testing.T) { + tests := []struct { + name string + input interface{} + expectedErr string + }{ + { + name: "not a func", + input: struct{}{}, + expectedErr: "fn is a struct, but should be a Func", + }, + { + name: "unsupported param", + input: func(uint32, string) {}, + expectedErr: "fn param[1] is unsupported: string", + }, + { + name: "unsupported result", + input: func() string { return "" }, + expectedErr: "fn result[0] is unsupported: string", + }, + { + name: "error result", + input: func() error { return nil }, + expectedErr: "fn result[0] is an error, which is unsupported", + }, + { + name: "multiple results", + input: func() (uint64, uint32) { return 0, 0 }, + expectedErr: "fn has more than one result", + }, + { + name: "multiple context types", + input: func(publicwasm.HostFunctionCallContext, context.Context) error { return nil }, + expectedErr: "fn param[1] is a context.Context, which may be defined only once as param[0]", + }, + { + name: "multiple context.Context", + input: func(context.Context, uint64, context.Context) error { return nil }, + expectedErr: "fn param[2] is a context.Context, which may be defined only once as param[0]", + }, + { + name: "multiple wasm.HostFunctionCallContext", + input: func(publicwasm.HostFunctionCallContext, uint64, publicwasm.HostFunctionCallContext) error { return nil }, + expectedErr: "fn param[2] is a wasm.HostFunctionCallContext, which may be defined only once as param[0]", + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + rVal := reflect.ValueOf(tc.input) + _, _, err := GetFunctionType("fn", &rVal) + require.EqualError(t, err, tc.expectedErr) + }) + } +} diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index 805a7802..0f9aedbe 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -117,13 +117,7 @@ type interpreterOp struct { func (it *interpreter) Compile(f *wasm.FunctionInstance) error { funcaddr := f.Address - if f.IsHostFunction() { - ret := &interpreterFunction{ - hostFn: f.HostFunction, funcInstance: f, - } - it.functions[funcaddr] = ret - return nil - } else { + if f.FunctionKind == wasm.FunctionKindWasm { ir, err := wazeroir.Compile(f) if err != nil { return fmt.Errorf("failed to compile Wasm to wazeroir: %w", err) @@ -138,6 +132,12 @@ func (it *interpreter) Compile(f *wasm.FunctionInstance) error { cb(fn) } delete(it.onCompilationDoneCallbacks, funcaddr) + } else { + ret := &interpreterFunction{ + hostFn: f.HostFunction, funcInstance: f, + } + it.functions[funcaddr] = ret + return nil } return nil } @@ -495,7 +495,12 @@ func (it *interpreter) Call(ctx *wasm.HostFunctionCallContext, f *wasm.FunctionI func (it *interpreter) callHostFunc(ctx *wasm.HostFunctionCallContext, f *interpreterFunction) { tp := f.hostFn.Type() in := make([]reflect.Value, tp.NumIn()) - for i := len(in) - 1; i >= 1; i-- { + + wasmParamOffset := 0 + if f.funcInstance.FunctionKind != wasm.FunctionKindHostNoContext { + wasmParamOffset = 1 + } + for i := len(in) - 1; i >= wasmParamOffset; i-- { val := reflect.New(tp.In(i)).Elem() raw := it.pop() kind := tp.In(i).Kind() @@ -512,12 +517,14 @@ func (it *interpreter) callHostFunc(ctx *wasm.HostFunctionCallContext, f *interp in[i] = val } - val := reflect.New(tp.In(0)).Elem() if len(it.frames) > 0 { ctx = ctx.WithMemory(it.frames[len(it.frames)-1].f.funcInstance.ModuleInstance.Memory) } - val.Set(reflect.ValueOf(ctx)) - in[0] = val + + // Handle any special parameter zero + if val := wasm.GetHostFunctionCallContextValue(f.funcInstance.FunctionKind, ctx); val != nil { + in[0] = *val + } frame := &interpreterFrame{f: f} it.pushFrame(frame) diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index dc9743a6..59e4fce1 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -52,6 +52,7 @@ func TestInterpreter_CallHostFunc(t *testing.T) { module := &wasm.ModuleInstance{Memory: memory} it := interpreter{functions: map[wasm.FunctionAddress]*interpreterFunction{ 0: {hostFn: &hostFn, funcInstance: &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindHostFunctionCallContext, FunctionType: &wasm.TypeInstance{ Type: &wasm.FunctionType{ Params: []wasm.ValueType{}, diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index c631c56f..36137162 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -66,7 +66,7 @@ type ( // when making function calls or returning from them. callFrameStackPointer uint64 // previousCallFrameStackPointer is to support re-entrant execution. - // This is updated whenever exntering engine.execFunction. + // This is updated whenever exntering engine.execWasmFunction. // If this is the initial call into Wasm, the value equals zero, // but if this is the recursive function call from the host function, the value becomes non-zero. previousCallFrameStackPointer uint64 @@ -138,7 +138,7 @@ type ( // That is, callFrameTop().returnAddress or returnStackBasePointer are not set // until it makes a function call. callFrame struct { - // Set when making function call from this function frame, or for the initial function frame to call from engine.execFunction. + // Set when making function call from this function frame, or for the initial function frame to call from engine.execWasmFunction. returnAddress uintptr // Set when making function call from this function frame. returnStackBasePointer uint64 @@ -388,10 +388,10 @@ func (e *engine) Call(ctx *wasm.HostFunctionCallContext, f *wasm.FunctionInstanc return } - if compiled.source.IsHostFunction() { - e.execHostFunction(compiled.source.HostFunction, ctx) + if f.FunctionKind == wasm.FunctionKindWasm { + e.execWasmFunction(ctx, compiled) } else { - e.execFunction(ctx, compiled) + e.execHostFunction(f.FunctionKind, compiled.source.HostFunction, ctx) } // Note the top value is the tail of the results, @@ -450,8 +450,7 @@ const ( // execHostFunction executes the given host function represented as *reflect.Value. // -// The arguments to the function are popped from the stack stack following the convension of -// Wasm stack machine. +// The arguments to the function are popped from the stack following the convention of the Wasm stack machine. // For example, if the host function F requires the (x1 uint32, x2 float32) parameters, and // the stack is [..., A, B], then the function is called as F(A, B) where A and B are interpreted // as uint32 and float32 respectively. @@ -459,13 +458,20 @@ const ( // After the execution, the result of host function is pushed onto the stack. // // ctx parameter is passed to the host function as a first argument. -func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallContext) { +func (e *engine) execHostFunction(fk wasm.FunctionKind, f *reflect.Value, ctx *wasm.HostFunctionCallContext) { + // TODO: the signature won't ever change for a host function once instantiated. For this reason, we should be able + // to optimize below based on known possible outcomes. This includes knowledge about if it has a context param[0] + // and which type (if any) it returns. tp := f.Type() in := make([]reflect.Value, tp.NumIn()) // We pop the value and pass them as arguments in a reverse order according to the // stack machine convention. - for i := len(in) - 1; i >= 1; i-- { + wasmParamOffset := 0 + if fk != wasm.FunctionKindHostNoContext { + wasmParamOffset = 1 + } + for i := len(in) - 1; i >= wasmParamOffset; i-- { val := reflect.New(tp.In(i)).Elem() raw := e.popValue() kind := tp.In(i).Kind() @@ -482,12 +488,12 @@ func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallCo in[i] = val } - // Host function must receive wasm.HostFunctionCallContext as a first argument. - val := reflect.New(tp.In(0)).Elem() - val.Set(reflect.ValueOf(ctx)) - in[0] = val + // Handle any special parameter zero + if val := wasm.GetHostFunctionCallContextValue(fk, ctx); val != nil { + in[0] = *val + } - // Excute the host function and push back the call result onto the stack. + // Execute the host function and push back the call result onto the stack. for _, ret := range f.Call(in) { switch ret.Kind() { case reflect.Float64, reflect.Float32: @@ -502,7 +508,7 @@ func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallCo } } -func (e *engine) execFunction(ctx *wasm.HostFunctionCallContext, f *compiledFunction) { +func (e *engine) execWasmFunction(ctx *wasm.HostFunctionCallContext, f *compiledFunction) { // We continuously execute functions until we reach the previous top frame // to support recursive Wasm function executions. e.globalContext.previousCallFrameStackPointer = e.globalContext.callFrameStackPointer @@ -533,12 +539,12 @@ jitentry: fn := e.compiledFunctions[e.exitContext.functionCallAddress] callerCompiledFunction := e.callFrameAt(1).compiledFunction if buildoptions.IsDebugMode { - if !fn.source.IsHostFunction() { + if fn.source.FunctionKind == wasm.FunctionKindWasm { panic("jitCallStatusCodeCallHostFunction is only for host functions") } } saved := e.globalContext.previousCallFrameStackPointer - e.execHostFunction(fn.source.HostFunction, + e.execHostFunction(fn.source.FunctionKind, fn.source.HostFunction, ctx.WithMemory(callerCompiledFunction.source.ModuleInstance.Memory), ) e.globalContext.previousCallFrameStackPointer = saved @@ -636,10 +642,10 @@ func (e *engine) builtinFunctionMemoryGrow(mem *wasm.MemoryInstance) { func (e *engine) Compile(f *wasm.FunctionInstance) (err error) { var compiled *compiledFunction - if f.IsHostFunction() { - compiled, err = compileHostFunction(f) - } else { + if f.FunctionKind == wasm.FunctionKindWasm { compiled, err = compileWasmFunction(f) + } else { + compiled, err = compileHostFunction(f) } if err != nil { return fmt.Errorf("failed to compile function: %w", err) diff --git a/internal/wasm/jit/jit_amd64.go b/internal/wasm/jit/jit_amd64.go index 58449841..c4805825 100644 --- a/internal/wasm/jit/jit_amd64.go +++ b/internal/wasm/jit/jit_amd64.go @@ -4829,7 +4829,7 @@ func (c *amd64Compiler) callFunction(addr wasm.FunctionAddress, addrReg int16, f } // returnFunction adds instructions to return from the current callframe back to the caller's frame. -// If this is the current one is the origin, we return back to the engine.execFunction with the Returned status. +// If this is the current one is the origin, we return back to the engine.execWasmFunction with the Returned status. // Otherwise, we jump into the callers' return address stored in callFrame.returnAddress while setting // up all the necessary change on the engine's state. // diff --git a/internal/wasm/jit/jit_amd64_test.go b/internal/wasm/jit/jit_amd64_test.go index 0e68d8fa..26ae2942 100644 --- a/internal/wasm/jit/jit_amd64_test.go +++ b/internal/wasm/jit/jit_amd64_test.go @@ -29,7 +29,7 @@ func (j *jitEnv) requireNewCompiler(t *testing.T) *amd64Compiler { return &amd64Compiler{builder: b, locationStack: newValueLocationStack(), labels: map[string]*labelInfo{}, - f: &wasm.FunctionInstance{ModuleInstance: j.moduleInstance}, + f: &wasm.FunctionInstance{ModuleInstance: j.moduleInstance, FunctionKind: wasm.FunctionKindWasm}, } } @@ -506,9 +506,9 @@ func TestAmd64Compiler_compileBrTable(t *testing.T) { } func TestAmd64Compiler_pushFunctionInputs(t *testing.T) { - f := &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: &wasm.FunctionType{ - Params: []wasm.ValueType{wasm.ValueTypeF64, wasm.ValueTypeI32}, - }}} + f := &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: &wasm.FunctionType{Params: []wasm.ValueType{wasm.ValueTypeF64, wasm.ValueTypeI32}}}} compiler := &amd64Compiler{locationStack: newValueLocationStack(), f: f} compiler.pushFunctionParams() require.Equal(t, uint64(len(f.FunctionType.Type.Params)), compiler.locationStack.sp) @@ -5835,7 +5835,10 @@ func TestAmd64Compiler_compileGlobalGet(t *testing.T) { globals := []*wasm.GlobalInstance{nil, {Val: globalValue, Type: &wasm.GlobalType{ValType: tp}}, nil} env.addGlobals(globals...) // Compiler needs global type information at compilation time. - compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Globals: globals}} + compiler.f = &wasm.FunctionInstance{ + ModuleInstance: &wasm.ModuleInstance{Globals: globals}, + FunctionKind: wasm.FunctionKindWasm, + } // Emit the code. err := compiler.compilePreamble() @@ -5966,7 +5969,7 @@ func TestAmd64Compiler_callFunction(t *testing.T) { expectedValue := uint32(0) moduleInstanceToExpectedValueInMemory := map[*wasm.ModuleInstance]uint32{} for i := 0; i < numCalls; i++ { - // Each function takes one arguments, adds the value with 100 + i and returns the result. + // Each function takes one argument, adds the value with 100 + i and returns the result. addTargetValue := uint32(100 + i) moduleInstance := &wasm.ModuleInstance{ Memory: &wasm.MemoryInstance{Buffer: make([]byte, 1024)}, @@ -5974,13 +5977,17 @@ func TestAmd64Compiler_callFunction(t *testing.T) { moduleInstanceToExpectedValueInMemory[moduleInstance] = addTargetValue compiler := env.requireNewCompiler(t) - compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, ModuleInstance: moduleInstance} + compiler.f = &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + ModuleInstance: moduleInstance, + } err := compiler.compilePreamble() require.NoError(t, err) expectedValue += addTargetValue - err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: uint32(addTargetValue)}) + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: addTargetValue}) require.NoError(t, err) err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI32}) @@ -6079,7 +6086,11 @@ func TestAmd64Compiler_compileCall(t *testing.T) { { // Call target function takes three i32 arguments and does ADD 2 times. compiler := env.requireNewCompiler(t) - compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, ModuleInstance: &wasm.ModuleInstance{}} + compiler.f = &wasm.FunctionInstance{ + ModuleInstance: &wasm.ModuleInstance{}, + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + } err := compiler.compilePreamble() require.NoError(t, err) for i := 0; i < 2; i++ { @@ -6102,7 +6113,11 @@ func TestAmd64Compiler_compileCall(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{ Functions: []*wasm.FunctionInstance{ - {FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, Address: targetFunctionAddress}, + { + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + Address: targetFunctionAddress, + }, }, }} @@ -6149,7 +6164,10 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { targetOperation := &wazeroir.OperationCallIndirect{} // Ensure that the module instance has the type information for targetOperation.TypeIndex. - compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{Type: &wasm.FunctionType{}}}}} + compiler.f = &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{Type: &wasm.FunctionType{}}}}, + } // Place the offfset value. loc := compiler.locationStack.pushValueLocationOnStack() @@ -6180,8 +6198,10 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { targetOperation := &wazeroir.OperationCallIndirect{} targetOffset := &wazeroir.OperationConstI32{Value: uint32(0)} // Ensure that the module instance has the type information for targetOperation.TypeIndex, - compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{ - Type: &wasm.FunctionType{}, TypeID: 1000}}}} + compiler.f = &wasm.FunctionInstance{ + ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{Type: &wasm.FunctionType{}, TypeID: 1000}}}, + FunctionKind: wasm.FunctionKindWasm, + } // and the typeID doesn't match the table[targetOffset]'s type ID. table[0] = wasm.TableElement{FunctionTypeID: wasm.UninitializedTableElementTypeID} @@ -6285,7 +6305,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { require.NoError(t, err) // Ensure that the module instance has the type information for targetOperation.TypeIndex, - compiler.f = &wasm.FunctionInstance{ModuleInstance: moduleInstance} + compiler.f = &wasm.FunctionInstance{ModuleInstance: moduleInstance, FunctionKind: wasm.FunctionKindWasm} // and the typeID matches the table[targetOffset]'s type ID. // Place the offfset value. Here we try calling a function of functionaddr == table[i].FunctionAddress. diff --git a/internal/wasm/jit/jit_arm64_test.go b/internal/wasm/jit/jit_arm64_test.go index a9189f3d..bd85d42b 100644 --- a/internal/wasm/jit/jit_arm64_test.go +++ b/internal/wasm/jit/jit_arm64_test.go @@ -43,7 +43,10 @@ func requirePushTwoFloat32Consts(t *testing.T, x1, x2 float32, compiler *arm64Co } func (j *jitEnv) requireNewCompiler(t *testing.T) *arm64Compiler { - cmp, err := newCompiler(&wasm.FunctionInstance{ModuleInstance: j.moduleInstance}, nil) + cmp, err := newCompiler(&wasm.FunctionInstance{ + ModuleInstance: j.moduleInstance, + FunctionKind: wasm.FunctionKindWasm, + }, nil) require.NoError(t, err) ret, ok := cmp.(*arm64Compiler) require.True(t, ok) @@ -1692,8 +1695,11 @@ func TestArm64Compiler_compileCall(t *testing.T) { expectedValue += addTargetValue compiler := env.requireNewCompiler(t) - compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, - ModuleInstance: &wasm.ModuleInstance{}} + compiler.f = &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + ModuleInstance: &wasm.ModuleInstance{}, + } err := compiler.compilePreamble() require.NoError(t, err) diff --git a/internal/wasm/jit/jit_test.go b/internal/wasm/jit/jit_test.go index 6f0b834f..562fb47c 100644 --- a/internal/wasm/jit/jit_test.go +++ b/internal/wasm/jit/jit_test.go @@ -115,6 +115,7 @@ func (j *jitEnv) exec(code []byte) { codeSegment: code, codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), source: &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, FunctionType: &wasm.TypeInstance{Type: &wasm.FunctionType{}}, }, } diff --git a/internal/wasm/store.go b/internal/wasm/store.go index e3855500..a2fc99f1 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -109,10 +109,13 @@ type ( FunctionType *TypeInstance // LocalTypes holds types of locals. LocalTypes []ValueType + // FunctionKind describes how this function should be called. + FunctionKind FunctionKind // HostFunction holds the runtime representation of host functions. - // If this is not nil, all the above fields are ignored as they are specific to non-host functions. + // This is nil when FunctionKind == FunctionKindWasm. Otherwise, all the above fields are ignored as they are + // specific to Wasm functions. HostFunction *reflect.Value - // Address is the funcaddr(https://www.w3.org/TR/wasm-core-1/#syntax-funcaddr) of this function insntance. + // Address is the funcaddr(https://www.w3.org/TR/wasm-core-1/#syntax-funcaddr) of this function instance. // More precisely, this equals the index of this function instance in store.FunctionInstances. // All function calls are made via funcaddr at runtime, not the index (scoped to a module). // @@ -190,7 +193,7 @@ type ( // and the index to Store.Functions. FunctionAddress uint64 - // FunctionTypeID is an uniquely assigned integer for a function type. + // FunctionTypeID is a uniquely assigned integer for a function type. // This is wazero specific runtime object and specific to a store, // and used at runtime to do type-checks on indirect function calls. FunctionTypeID uint32 @@ -249,10 +252,6 @@ func (m *ModuleInstance) GetFunction(name string, rv ValueType) (*FunctionInstan return exp.Function, nil } -func (f *FunctionInstance) IsHostFunction() bool { - return f.HostFunction != nil -} - func NewStore(engine Engine) *Store { return &Store{ ModuleInstances: map[string]*ModuleInstance{}, @@ -652,6 +651,7 @@ func (s *Store) buildFunctionInstances(module *Module, target *ModuleInstance) ( f := &FunctionInstance{ Name: name, + FunctionKind: FunctionKindWasm, FunctionType: typeInstance, Body: module.CodeSection[codeIndex].Body, LocalTypes: module.CodeSection[codeIndex].LocalTypes, @@ -885,6 +885,7 @@ func (s *Store) AddHostFunction(moduleName string, hf *HostFunction) error { f := &FunctionInstance{ Name: fmt.Sprintf("%s.%s", moduleName, hf.name), HostFunction: hf.goFunc, + FunctionKind: hf.functionKind, FunctionType: typeInstance, ModuleInstance: m, } @@ -904,18 +905,6 @@ func (s *Store) AddHostFunction(moduleName string, hf *HostFunction) error { return nil } -func NewHostFunction(funcName string, goFunc interface{}) (*HostFunction, error) { - hf := &HostFunction{name: funcName} - fn := reflect.ValueOf(goFunc) - hf.goFunc = &fn - ft, err := GetFunctionType(hf.name, hf.goFunc) - if err != nil { - return nil, err - } - hf.functionType = ft - return hf, nil -} - func (s *Store) AddGlobal(moduleName, name string, value uint64, valueType ValueType, mutable bool) error { g := &GlobalInstance{ Val: value, diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index c3e2556b..e3d0998c 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "math" "os" - "reflect" "strconv" "testing" @@ -40,6 +39,7 @@ func TestStore_CallFunction(t *testing.T) { fn: { Kind: ExportKindFunc, Function: &FunctionInstance{ + FunctionKind: FunctionKindWasm, FunctionType: &TypeInstance{ Type: &FunctionType{ Params: []ValueType{}, @@ -89,11 +89,11 @@ func TestStore_CallFunction(t *testing.T) { func TestStore_AddHostFunction(t *testing.T) { s := NewStore(nopEngineInstance) - hostFunction := func(wasm.HostFunctionCallContext) { - } - hf := newHostFunction(t, "fn", hostFunction) - err := s.AddHostFunction("test", hf) + hf, err := NewHostFunction("fn", func(wasm.HostFunctionCallContext) { + }) + require.NoError(t, err) + err = s.AddHostFunction("test", hf) require.NoError(t, err) // The function was added to the store, prefixed by the owning module name @@ -116,21 +116,13 @@ func TestStore_AddHostFunction(t *testing.T) { require.Equal(t, map[string]*ExportInstance{"fn": exp}, m.Exports) } -func newHostFunction(t *testing.T, name string, hostFunction interface{}) *HostFunction { - hf := &HostFunction{name: name} - goFn := reflect.ValueOf(hostFunction) - hf.goFunc = &goFn - ft, err := GetFunctionType(hf.name, hf.goFunc) - require.NoError(t, err) - hf.functionType = ft - return hf -} - func TestStore_ExportImportedHostFunction(t *testing.T) { s := NewStore(nopEngineInstance) - hostFunction := func(wasm.HostFunctionCallContext) { - } - err := s.AddHostFunction("", newHostFunction(t, "host_fn", hostFunction)) + + hf, err := NewHostFunction("host_fn", func(wasm.HostFunctionCallContext) { + }) + require.NoError(t, err) + err = s.AddHostFunction("", hf) require.NoError(t, err) t.Run("ModuleInstance is the importing module", func(t *testing.T) { @@ -212,7 +204,7 @@ func TestStore_addHostFunction(t *testing.T) { t.Run("ok", func(t *testing.T) { s := NewStore(nopEngineInstance) for i := 0; i < 10; i++ { - f := &FunctionInstance{} + f := &FunctionInstance{FunctionKind: FunctionKindHostNoContext} require.Len(t, s.Functions, i) err := s.addFunctionInstance(f) diff --git a/store.go b/store.go index 9cd20d8f..1558203d 100644 --- a/store.go +++ b/store.go @@ -20,10 +20,44 @@ func NewEngineJIT() *Engine { // TODO: compiler? } // HostFunctions are functions written in Go, which a WebAssembly Module can import. +// +// Noting a context exception described later, all parameters or result types must match WebAssembly 1.0 (MVP) value +// types. This means uint32, uint64, float32 or float64. Up to one result can be returned. +// +// Ex. This is a valid host function: +// +// addInts := func(x uint32, uint32) uint32 { +// return x + y +// } +// +// Host functions may also have an initial parameter (param[0]) of type context.Context or wasm.HostFunctionCallContext. +// +// Ex. This uses a Go Context: +// +// addInts := func(ctx context.Context, x uint32, uint32) uint32 { +// // add a little extra if we put some in the context! +// return x + y + ctx.Value(extraKey).(uint32) +// } +// +// The most sophisticated context is wasm.HostFunctionCallContext, which allows access to the Go context, but also +// allows writing to memory. This is important because there are only numeric types in Wasm. The only way to share other +// data is via writing memory and sharing offsets. +// +// Ex. This reads the parameters from! +// +// addInts := func(ctx wasm.HostFunctionCallContext, offset uint32) uint32 { +// x, _ := ctx.Memory().ReadUint32Le(offset) +// y, _ := ctx.Memory().ReadUint32Le(offset + 4) // 32 bits == 4 bytes! +// return x + y +// } +// +// See https://www.w3.org/TR/wasm-core-1/#value-types%E2%91%A0 type HostFunctions struct { nameToHostFunction map[string]*internalwasm.HostFunction } +// NewHostFunctions returns host functions to export. The map key is the name to export and the value is the function. +// See HostFunctions documentation for notes on writing host functions. func NewHostFunctions(nameToGoFunc map[string]interface{}) (ret *HostFunctions, err error) { ret = &HostFunctions{make(map[string]*internalwasm.HostFunction, len(nameToGoFunc))} for name, goFunc := range nameToGoFunc { diff --git a/tests/engine/adhoc_test.go b/tests/engine/adhoc_test.go index 345fb590..b6644774 100644 --- a/tests/engine/adhoc_test.go +++ b/tests/engine/adhoc_test.go @@ -3,6 +3,7 @@ package adhoc import ( "context" _ "embed" + "fmt" "math" "runtime" "sync" @@ -59,7 +60,7 @@ func runTests(t *testing.T, newEngine func() *wazero.Engine) { importedAndExportedFunc(t, newEngine) }) t.Run("host function with float32 type", func(t *testing.T) { - hostFuncWithFloatParam(t, newEngine) + hostFunctions(t, newEngine) }) } @@ -273,8 +274,8 @@ func importedAndExportedFunc(t *testing.T, newEngine func() *wazero.Engine) { require.Equal(t, []byte{0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0}, memory.Buffer[0:10]) } -// hostFuncWithFloatParam fails if a float parameter corrupts a host function value -func hostFuncWithFloatParam(t *testing.T, newEngine func() *wazero.Engine) { +// hostFunctions ensures arg0 is optionally a context, and fails if a float parameter corrupts a host function value +func hostFunctions(t *testing.T, newEngine func() *wazero.Engine) { ctx := context.Background() mod, err := wazero.DecodeModuleText([]byte(`(module $test ;; these imports return the input param @@ -295,41 +296,69 @@ func hostFuncWithFloatParam(t *testing.T, newEngine func() *wazero.Engine) { )`)) require.NoError(t, err) - hostFuncs, err := wazero.NewHostFunctions(map[string]interface{}{ - "identity_f32": func(ctx publicwasm.HostFunctionCallContext, value float32) float32 { + floatFuncs, err := wazero.NewHostFunctions(map[string]interface{}{ + "identity_f32": func(value float32) float32 { return value }, - "identity_f64": func(ctx publicwasm.HostFunctionCallContext, value float64) float64 { + "identity_f64": func(value float64) float64 { return value }}) require.NoError(t, err) - store, err := wazero.NewStoreWithConfig(&wazero.StoreConfig{ - Engine: newEngine(), - ModuleToHostFunctions: map[string]*wazero.HostFunctions{"host": hostFuncs}, - }) + floatFuncsGoContext, err := wazero.NewHostFunctions(map[string]interface{}{ + "identity_f32": func(funcCtx context.Context, value float32) float32 { + require.Equal(t, ctx, funcCtx) + return value + }, + "identity_f64": func(funcCtx context.Context, value float64) float64 { + require.Equal(t, ctx, funcCtx) + return value + }}) require.NoError(t, err) - m, err := store.Instantiate(mod) + floatFuncsHostFunctionCallContext, err := wazero.NewHostFunctions(map[string]interface{}{ + "identity_f32": func(funcCtx publicwasm.HostFunctionCallContext, value float32) float32 { + require.Equal(t, ctx, funcCtx.Context()) + return value + }, + "identity_f64": func(funcCtx publicwasm.HostFunctionCallContext, value float64) float64 { + require.Equal(t, ctx, funcCtx.Context()) + return value + }}) require.NoError(t, err) - t.Run("host function with f32 param", func(t *testing.T) { - fn, ok := m.GetFunctionF32Return("call->test.identity_f32") - require.True(t, ok) - - f32 := float32(math.MaxFloat32) - result, err := fn(ctx, uint64(math.Float32bits(f32))) // float bits are a uint32 value, call requires uint64 + for k, v := range map[string]*wazero.HostFunctions{ + "": floatFuncs, + " - context.Context": floatFuncsGoContext, + " - wasm.HostFunctionCallContext": floatFuncsHostFunctionCallContext, + } { + store, err := wazero.NewStoreWithConfig(&wazero.StoreConfig{ + Engine: newEngine(), + ModuleToHostFunctions: map[string]*wazero.HostFunctions{"host": v}, + }) require.NoError(t, err) - require.Equal(t, f32, result) - }) - t.Run("host function with f64 param", func(t *testing.T) { - fn, ok := m.GetFunctionF64Return("call->test.identity_f64") - require.True(t, ok) - - f64 := math.MaxFloat64 - result, err := fn(ctx, math.Float64bits(f64)) + m, err := store.Instantiate(mod) require.NoError(t, err) - require.Equal(t, f64, result) - }) + + t.Run(fmt.Sprintf("host function with f32 param%s", k), func(t *testing.T) { + fn, ok := m.GetFunctionF32Return("call->test.identity_f32") + require.True(t, ok) + + f32 := float32(math.MaxFloat32) + result, err := fn(ctx, uint64(math.Float32bits(f32))) // float bits are a uint32 value, call requires uint64 + require.NoError(t, err) + require.Equal(t, f32, result) + }) + + t.Run(fmt.Sprintf("host function with f64 param%s", k), func(t *testing.T) { + fn, ok := m.GetFunctionF64Return("call->test.identity_f64") + require.True(t, ok) + + f64 := math.MaxFloat64 + result, err := fn(ctx, math.Float64bits(f64)) + require.NoError(t, err) + require.Equal(t, f64, result) + }) + } }