diff --git a/README.md b/README.md index 17ca34a2..733fef5a 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,8 @@ func main() { // Decode the binary as WebAssembly module. mod, _ := wazero.DecodeModuleBinary(source) - // Initialize the execution environment called "store" with Interpreter-based engine. - store := wazero.NewStore() - - // Instantiate the module, which returns its exported functions - functions, _ := store.Instantiate(mod) + // Instantiate the module with a Wasm Interpreter, to return its exported functions + functions, _ := wazero.NewStore().Instantiate(mod) // Get the factorial function fac, _ := functions.GetFunctionI64Return("fac") diff --git a/examples/simple_test.go b/examples/simple_test.go index 45636aea..72452252 100644 --- a/examples/simple_test.go +++ b/examples/simple_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/require" "github.com/tetratelabs/wazero" - "github.com/tetratelabs/wazero/wasm" ) // Test_Simple implements a basic function in go: hello. This is imported as the Wasm name "$hello" and run on start. @@ -20,7 +19,7 @@ func Test_Simple(t *testing.T) { require.NoError(t, err) stdout := new(bytes.Buffer) - goFunc := func(wasm.HostFunctionCallContext) { + goFunc := func() { _, _ = fmt.Fprintln(stdout, "hello!") } diff --git a/internal/wasi/wasi.go b/internal/wasi/wasi.go index 453b2e27..51122e57 100644 --- a/internal/wasi/wasi.go +++ b/internal/wasi/wasi.go @@ -113,7 +113,7 @@ const ( FunctionPathUnlinkFile = "path_unlink_file" FunctionPollOneoff = "poll_oneoff" - // ProcExit terminates the execution of the module with an exit code. + // FunctionProcExit terminates the execution of the module with an exit code. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit FunctionProcExit = "proc_exit" @@ -339,7 +339,7 @@ type SnapshotPreview1 interface { // // Note: ImportProcExit shows this signature in the WebAssembly 1.0 (MVP) Text Format. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit - ProcExit(ctx wasm.HostFunctionCallContext, rval uint32) + ProcExit(rval uint32) // TODO: ProcRaise // TODO: SchedYield @@ -505,8 +505,8 @@ func (a *wasiAPI) ClockTimeGet(ctx wasm.HostFunctionCallContext, id uint32, prec return wasi.ErrnoSuccess } -// ProcExit implements API.ProcExit -func (a *wasiAPI) ProcExit(ctx wasm.HostFunctionCallContext, exitCode uint32) { +// ProcExit implements SnapshotPreview1.ProcExit +func (a *wasiAPI) ProcExit(exitCode uint32) { // Panic in a host function is caught by the engines, and the value of the panic is returned as the error of the CallFunction. // See the document of API.ProcExit. panic(wasi.ExitCode(exitCode)) diff --git a/internal/wasm/host.go b/internal/wasm/host.go index 1a01b0d4..34b160f6 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -10,42 +10,131 @@ import ( publicwasm "github.com/tetratelabs/wazero/wasm" ) +// FunctionKind identifies the type of function that can be called. +type FunctionKind byte + +const ( + // FunctionKindWasm is not a host function: it is implemented in Wasm. + FunctionKindWasm FunctionKind = iota + // FunctionKindHostNoContext is a function implemented in Go, with a signature matching FunctionType. + FunctionKindHostNoContext + // FunctionKindHostGoContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is + // a context.Context. + FunctionKindHostGoContext + // FunctionKindHostFunctionCallContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is + // a HostFunctionCallContext. + FunctionKindHostFunctionCallContext +) + type HostFunction struct { - name string + name string + // functionKind is never FunctionKindWasm + functionKind FunctionKind functionType *FunctionType goFunc *reflect.Value } +func NewHostFunction(funcName string, goFunc interface{}) (hf *HostFunction, err error) { + hf = &HostFunction{name: funcName} + fn := reflect.ValueOf(goFunc) + hf.goFunc = &fn + hf.functionKind, hf.functionType, err = GetFunctionType(hf.name, hf.goFunc) + return +} + +// Below are reflection code to get the interface type used to parse functions and set values. + +var hostFunctionCallContextType = reflect.TypeOf((*publicwasm.HostFunctionCallContext)(nil)).Elem() +var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem() +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +// GetHostFunctionCallContextValue returns a reflect.Value for a context param[0], or nil if there isn't one. +func GetHostFunctionCallContextValue(fk FunctionKind, ctx *HostFunctionCallContext) *reflect.Value { + switch fk { + case FunctionKindHostNoContext: // no special param zero + case FunctionKindHostGoContext: + val := reflect.New(goContextType).Elem() + val.Set(reflect.ValueOf(ctx.Context())) + return &val + case FunctionKindHostFunctionCallContext: + val := reflect.New(hostFunctionCallContextType).Elem() + val.Set(reflect.ValueOf(ctx)) + return &val + } + return nil +} + // GetFunctionType returns the function type corresponding to the function signature or errs if invalid. -func GetFunctionType(name string, fn *reflect.Value) (*FunctionType, error) { +func GetFunctionType(name string, fn *reflect.Value) (fk FunctionKind, ft *FunctionType, err error) { if fn.Kind() != reflect.Func { - return nil, fmt.Errorf("%s value is not a reflect.Func: %s", name, fn.String()) + err = fmt.Errorf("%s is a %s, but should be a Func", name, fn.Kind().String()) + return } p := fn.Type() - if p.NumIn() == 0 { // TODO: actually check the type - return nil, fmt.Errorf("%s must accept wasm.HostFunctionCallContext as the first param", name) - } - paramTypes := make([]ValueType, p.NumIn()-1) - for i := range paramTypes { - kind := p.In(i + 1).Kind() - if t, ok := getTypeOf(kind); !ok { - return nil, fmt.Errorf("%s param[%d] is unsupported: %s", name, i, kind.String()) - } else { - paramTypes[i] = t + pOffset := 0 + pCount := p.NumIn() + fk = FunctionKindHostNoContext + if pCount > 0 && p.In(0).Kind() == reflect.Interface { + p0 := p.In(0) + if p0.Implements(hostFunctionCallContextType) { + fk = FunctionKindHostFunctionCallContext + pOffset = 1 + pCount-- + } else if p0.Implements(goContextType) { + fk = FunctionKindHostGoContext + pOffset = 1 + pCount-- } } - - resultTypes := make([]ValueType, p.NumOut()) - for i := range resultTypes { - kind := p.Out(i).Kind() - if t, ok := getTypeOf(kind); !ok { - return nil, fmt.Errorf("%s result[%d] is unsupported: %s", name, i, kind.String()) - } else { - resultTypes[i] = t - } + rCount := p.NumOut() + switch rCount { + case 0, 1: // ok + default: + err = fmt.Errorf("%s has more than one result", name) + return } - return &FunctionType{Params: paramTypes, Results: resultTypes}, nil + + ft = &FunctionType{Params: make([]ValueType, pCount), Results: make([]ValueType, rCount)} + + for i := 0; i < len(ft.Params); i++ { + pI := p.In(i + pOffset) + if t, ok := getTypeOf(pI.Kind()); ok { + ft.Params[i] = t + continue + } + + // Now, we will definitely err, decide which message is best + var arg0Type reflect.Type + if hc := pI.Implements(hostFunctionCallContextType); hc { + arg0Type = hostFunctionCallContextType + } else if gc := pI.Implements(goContextType); gc { + arg0Type = goContextType + } + + if arg0Type != nil { + err = fmt.Errorf("%s param[%d] is a %s, which may be defined only once as param[0]", name, i+pOffset, arg0Type) + } else { + err = fmt.Errorf("%s param[%d] is unsupported: %s", name, i+pOffset, pI.Kind()) + } + return + } + + if rCount == 0 { + return + } + result := p.Out(0) + if t, ok := getTypeOf(result.Kind()); ok { + ft.Results[0] = t + return + } + + if e := result.Implements(errorType); e { + err = fmt.Errorf("%s result[0] is an error, which is unsupported", name) + } else { + err = fmt.Errorf("%s result[0] is unsupported: %s", name, result.Kind()) + } + return } func getTypeOf(kind reflect.Kind) (ValueType, bool) { diff --git a/internal/wasm/host_test.go b/internal/wasm/host_test.go index 1ffb03e9..0512877f 100644 --- a/internal/wasm/host_test.go +++ b/internal/wasm/host_test.go @@ -1,10 +1,14 @@ package internalwasm import ( + "context" "math" + "reflect" "testing" "github.com/stretchr/testify/require" + + publicwasm "github.com/tetratelabs/wazero/wasm" ) func TestMemoryInstance_HasLen(t *testing.T) { @@ -448,3 +452,122 @@ func TestMemoryInstance_WriteFloat64Le(t *testing.T) { }) } } + +func TestGetFunctionType(t *testing.T) { + i32, i64, f32, f64 := ValueTypeI32, ValueTypeI64, ValueTypeF32, ValueTypeF64 + + tests := []struct { + name string + inputFunc interface{} + expectedKind FunctionKind + expectedType *FunctionType + }{ + { + name: "nullary", + inputFunc: func() {}, + expectedKind: FunctionKindHostNoContext, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, + { + name: "wasm.HostFunctionCallContext void return", + inputFunc: func(publicwasm.HostFunctionCallContext) {}, + expectedKind: FunctionKindHostFunctionCallContext, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, + { + name: "context.Context void return", + inputFunc: func(context.Context) {}, + expectedKind: FunctionKindHostGoContext, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, + { + name: "all supported params and i32 result", + inputFunc: func(uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindHostNoContext, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, + { + name: "all supported params and i32 result - wasm.HostFunctionCallContext", + inputFunc: func(publicwasm.HostFunctionCallContext, uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindHostFunctionCallContext, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, + { + name: "all supported params and i32 result - context.Context", + inputFunc: func(context.Context, uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindHostGoContext, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + rVal := reflect.ValueOf(tc.inputFunc) + fk, ft, err := GetFunctionType("fn", &rVal) + require.NoError(t, err) + require.Equal(t, tc.expectedKind, fk) + require.Equal(t, tc.expectedType, ft) + }) + } +} + +func TestGetFunctionTypeErrors(t *testing.T) { + tests := []struct { + name string + input interface{} + expectedErr string + }{ + { + name: "not a func", + input: struct{}{}, + expectedErr: "fn is a struct, but should be a Func", + }, + { + name: "unsupported param", + input: func(uint32, string) {}, + expectedErr: "fn param[1] is unsupported: string", + }, + { + name: "unsupported result", + input: func() string { return "" }, + expectedErr: "fn result[0] is unsupported: string", + }, + { + name: "error result", + input: func() error { return nil }, + expectedErr: "fn result[0] is an error, which is unsupported", + }, + { + name: "multiple results", + input: func() (uint64, uint32) { return 0, 0 }, + expectedErr: "fn has more than one result", + }, + { + name: "multiple context types", + input: func(publicwasm.HostFunctionCallContext, context.Context) error { return nil }, + expectedErr: "fn param[1] is a context.Context, which may be defined only once as param[0]", + }, + { + name: "multiple context.Context", + input: func(context.Context, uint64, context.Context) error { return nil }, + expectedErr: "fn param[2] is a context.Context, which may be defined only once as param[0]", + }, + { + name: "multiple wasm.HostFunctionCallContext", + input: func(publicwasm.HostFunctionCallContext, uint64, publicwasm.HostFunctionCallContext) error { return nil }, + expectedErr: "fn param[2] is a wasm.HostFunctionCallContext, which may be defined only once as param[0]", + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + rVal := reflect.ValueOf(tc.input) + _, _, err := GetFunctionType("fn", &rVal) + require.EqualError(t, err, tc.expectedErr) + }) + } +} diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index 805a7802..0f9aedbe 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -117,13 +117,7 @@ type interpreterOp struct { func (it *interpreter) Compile(f *wasm.FunctionInstance) error { funcaddr := f.Address - if f.IsHostFunction() { - ret := &interpreterFunction{ - hostFn: f.HostFunction, funcInstance: f, - } - it.functions[funcaddr] = ret - return nil - } else { + if f.FunctionKind == wasm.FunctionKindWasm { ir, err := wazeroir.Compile(f) if err != nil { return fmt.Errorf("failed to compile Wasm to wazeroir: %w", err) @@ -138,6 +132,12 @@ func (it *interpreter) Compile(f *wasm.FunctionInstance) error { cb(fn) } delete(it.onCompilationDoneCallbacks, funcaddr) + } else { + ret := &interpreterFunction{ + hostFn: f.HostFunction, funcInstance: f, + } + it.functions[funcaddr] = ret + return nil } return nil } @@ -495,7 +495,12 @@ func (it *interpreter) Call(ctx *wasm.HostFunctionCallContext, f *wasm.FunctionI func (it *interpreter) callHostFunc(ctx *wasm.HostFunctionCallContext, f *interpreterFunction) { tp := f.hostFn.Type() in := make([]reflect.Value, tp.NumIn()) - for i := len(in) - 1; i >= 1; i-- { + + wasmParamOffset := 0 + if f.funcInstance.FunctionKind != wasm.FunctionKindHostNoContext { + wasmParamOffset = 1 + } + for i := len(in) - 1; i >= wasmParamOffset; i-- { val := reflect.New(tp.In(i)).Elem() raw := it.pop() kind := tp.In(i).Kind() @@ -512,12 +517,14 @@ func (it *interpreter) callHostFunc(ctx *wasm.HostFunctionCallContext, f *interp in[i] = val } - val := reflect.New(tp.In(0)).Elem() if len(it.frames) > 0 { ctx = ctx.WithMemory(it.frames[len(it.frames)-1].f.funcInstance.ModuleInstance.Memory) } - val.Set(reflect.ValueOf(ctx)) - in[0] = val + + // Handle any special parameter zero + if val := wasm.GetHostFunctionCallContextValue(f.funcInstance.FunctionKind, ctx); val != nil { + in[0] = *val + } frame := &interpreterFrame{f: f} it.pushFrame(frame) diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index dc9743a6..59e4fce1 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -52,6 +52,7 @@ func TestInterpreter_CallHostFunc(t *testing.T) { module := &wasm.ModuleInstance{Memory: memory} it := interpreter{functions: map[wasm.FunctionAddress]*interpreterFunction{ 0: {hostFn: &hostFn, funcInstance: &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindHostFunctionCallContext, FunctionType: &wasm.TypeInstance{ Type: &wasm.FunctionType{ Params: []wasm.ValueType{}, diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index c631c56f..36137162 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -66,7 +66,7 @@ type ( // when making function calls or returning from them. callFrameStackPointer uint64 // previousCallFrameStackPointer is to support re-entrant execution. - // This is updated whenever exntering engine.execFunction. + // This is updated whenever exntering engine.execWasmFunction. // If this is the initial call into Wasm, the value equals zero, // but if this is the recursive function call from the host function, the value becomes non-zero. previousCallFrameStackPointer uint64 @@ -138,7 +138,7 @@ type ( // That is, callFrameTop().returnAddress or returnStackBasePointer are not set // until it makes a function call. callFrame struct { - // Set when making function call from this function frame, or for the initial function frame to call from engine.execFunction. + // Set when making function call from this function frame, or for the initial function frame to call from engine.execWasmFunction. returnAddress uintptr // Set when making function call from this function frame. returnStackBasePointer uint64 @@ -388,10 +388,10 @@ func (e *engine) Call(ctx *wasm.HostFunctionCallContext, f *wasm.FunctionInstanc return } - if compiled.source.IsHostFunction() { - e.execHostFunction(compiled.source.HostFunction, ctx) + if f.FunctionKind == wasm.FunctionKindWasm { + e.execWasmFunction(ctx, compiled) } else { - e.execFunction(ctx, compiled) + e.execHostFunction(f.FunctionKind, compiled.source.HostFunction, ctx) } // Note the top value is the tail of the results, @@ -450,8 +450,7 @@ const ( // execHostFunction executes the given host function represented as *reflect.Value. // -// The arguments to the function are popped from the stack stack following the convension of -// Wasm stack machine. +// The arguments to the function are popped from the stack following the convention of the Wasm stack machine. // For example, if the host function F requires the (x1 uint32, x2 float32) parameters, and // the stack is [..., A, B], then the function is called as F(A, B) where A and B are interpreted // as uint32 and float32 respectively. @@ -459,13 +458,20 @@ const ( // After the execution, the result of host function is pushed onto the stack. // // ctx parameter is passed to the host function as a first argument. -func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallContext) { +func (e *engine) execHostFunction(fk wasm.FunctionKind, f *reflect.Value, ctx *wasm.HostFunctionCallContext) { + // TODO: the signature won't ever change for a host function once instantiated. For this reason, we should be able + // to optimize below based on known possible outcomes. This includes knowledge about if it has a context param[0] + // and which type (if any) it returns. tp := f.Type() in := make([]reflect.Value, tp.NumIn()) // We pop the value and pass them as arguments in a reverse order according to the // stack machine convention. - for i := len(in) - 1; i >= 1; i-- { + wasmParamOffset := 0 + if fk != wasm.FunctionKindHostNoContext { + wasmParamOffset = 1 + } + for i := len(in) - 1; i >= wasmParamOffset; i-- { val := reflect.New(tp.In(i)).Elem() raw := e.popValue() kind := tp.In(i).Kind() @@ -482,12 +488,12 @@ func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallCo in[i] = val } - // Host function must receive wasm.HostFunctionCallContext as a first argument. - val := reflect.New(tp.In(0)).Elem() - val.Set(reflect.ValueOf(ctx)) - in[0] = val + // Handle any special parameter zero + if val := wasm.GetHostFunctionCallContextValue(fk, ctx); val != nil { + in[0] = *val + } - // Excute the host function and push back the call result onto the stack. + // Execute the host function and push back the call result onto the stack. for _, ret := range f.Call(in) { switch ret.Kind() { case reflect.Float64, reflect.Float32: @@ -502,7 +508,7 @@ func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallCo } } -func (e *engine) execFunction(ctx *wasm.HostFunctionCallContext, f *compiledFunction) { +func (e *engine) execWasmFunction(ctx *wasm.HostFunctionCallContext, f *compiledFunction) { // We continuously execute functions until we reach the previous top frame // to support recursive Wasm function executions. e.globalContext.previousCallFrameStackPointer = e.globalContext.callFrameStackPointer @@ -533,12 +539,12 @@ jitentry: fn := e.compiledFunctions[e.exitContext.functionCallAddress] callerCompiledFunction := e.callFrameAt(1).compiledFunction if buildoptions.IsDebugMode { - if !fn.source.IsHostFunction() { + if fn.source.FunctionKind == wasm.FunctionKindWasm { panic("jitCallStatusCodeCallHostFunction is only for host functions") } } saved := e.globalContext.previousCallFrameStackPointer - e.execHostFunction(fn.source.HostFunction, + e.execHostFunction(fn.source.FunctionKind, fn.source.HostFunction, ctx.WithMemory(callerCompiledFunction.source.ModuleInstance.Memory), ) e.globalContext.previousCallFrameStackPointer = saved @@ -636,10 +642,10 @@ func (e *engine) builtinFunctionMemoryGrow(mem *wasm.MemoryInstance) { func (e *engine) Compile(f *wasm.FunctionInstance) (err error) { var compiled *compiledFunction - if f.IsHostFunction() { - compiled, err = compileHostFunction(f) - } else { + if f.FunctionKind == wasm.FunctionKindWasm { compiled, err = compileWasmFunction(f) + } else { + compiled, err = compileHostFunction(f) } if err != nil { return fmt.Errorf("failed to compile function: %w", err) diff --git a/internal/wasm/jit/jit_amd64.go b/internal/wasm/jit/jit_amd64.go index 58449841..c4805825 100644 --- a/internal/wasm/jit/jit_amd64.go +++ b/internal/wasm/jit/jit_amd64.go @@ -4829,7 +4829,7 @@ func (c *amd64Compiler) callFunction(addr wasm.FunctionAddress, addrReg int16, f } // returnFunction adds instructions to return from the current callframe back to the caller's frame. -// If this is the current one is the origin, we return back to the engine.execFunction with the Returned status. +// If this is the current one is the origin, we return back to the engine.execWasmFunction with the Returned status. // Otherwise, we jump into the callers' return address stored in callFrame.returnAddress while setting // up all the necessary change on the engine's state. // diff --git a/internal/wasm/jit/jit_amd64_test.go b/internal/wasm/jit/jit_amd64_test.go index 0e68d8fa..26ae2942 100644 --- a/internal/wasm/jit/jit_amd64_test.go +++ b/internal/wasm/jit/jit_amd64_test.go @@ -29,7 +29,7 @@ func (j *jitEnv) requireNewCompiler(t *testing.T) *amd64Compiler { return &amd64Compiler{builder: b, locationStack: newValueLocationStack(), labels: map[string]*labelInfo{}, - f: &wasm.FunctionInstance{ModuleInstance: j.moduleInstance}, + f: &wasm.FunctionInstance{ModuleInstance: j.moduleInstance, FunctionKind: wasm.FunctionKindWasm}, } } @@ -506,9 +506,9 @@ func TestAmd64Compiler_compileBrTable(t *testing.T) { } func TestAmd64Compiler_pushFunctionInputs(t *testing.T) { - f := &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: &wasm.FunctionType{ - Params: []wasm.ValueType{wasm.ValueTypeF64, wasm.ValueTypeI32}, - }}} + f := &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: &wasm.FunctionType{Params: []wasm.ValueType{wasm.ValueTypeF64, wasm.ValueTypeI32}}}} compiler := &amd64Compiler{locationStack: newValueLocationStack(), f: f} compiler.pushFunctionParams() require.Equal(t, uint64(len(f.FunctionType.Type.Params)), compiler.locationStack.sp) @@ -5835,7 +5835,10 @@ func TestAmd64Compiler_compileGlobalGet(t *testing.T) { globals := []*wasm.GlobalInstance{nil, {Val: globalValue, Type: &wasm.GlobalType{ValType: tp}}, nil} env.addGlobals(globals...) // Compiler needs global type information at compilation time. - compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Globals: globals}} + compiler.f = &wasm.FunctionInstance{ + ModuleInstance: &wasm.ModuleInstance{Globals: globals}, + FunctionKind: wasm.FunctionKindWasm, + } // Emit the code. err := compiler.compilePreamble() @@ -5966,7 +5969,7 @@ func TestAmd64Compiler_callFunction(t *testing.T) { expectedValue := uint32(0) moduleInstanceToExpectedValueInMemory := map[*wasm.ModuleInstance]uint32{} for i := 0; i < numCalls; i++ { - // Each function takes one arguments, adds the value with 100 + i and returns the result. + // Each function takes one argument, adds the value with 100 + i and returns the result. addTargetValue := uint32(100 + i) moduleInstance := &wasm.ModuleInstance{ Memory: &wasm.MemoryInstance{Buffer: make([]byte, 1024)}, @@ -5974,13 +5977,17 @@ func TestAmd64Compiler_callFunction(t *testing.T) { moduleInstanceToExpectedValueInMemory[moduleInstance] = addTargetValue compiler := env.requireNewCompiler(t) - compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, ModuleInstance: moduleInstance} + compiler.f = &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + ModuleInstance: moduleInstance, + } err := compiler.compilePreamble() require.NoError(t, err) expectedValue += addTargetValue - err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: uint32(addTargetValue)}) + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: addTargetValue}) require.NoError(t, err) err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI32}) @@ -6079,7 +6086,11 @@ func TestAmd64Compiler_compileCall(t *testing.T) { { // Call target function takes three i32 arguments and does ADD 2 times. compiler := env.requireNewCompiler(t) - compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, ModuleInstance: &wasm.ModuleInstance{}} + compiler.f = &wasm.FunctionInstance{ + ModuleInstance: &wasm.ModuleInstance{}, + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + } err := compiler.compilePreamble() require.NoError(t, err) for i := 0; i < 2; i++ { @@ -6102,7 +6113,11 @@ func TestAmd64Compiler_compileCall(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{ Functions: []*wasm.FunctionInstance{ - {FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, Address: targetFunctionAddress}, + { + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + Address: targetFunctionAddress, + }, }, }} @@ -6149,7 +6164,10 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { targetOperation := &wazeroir.OperationCallIndirect{} // Ensure that the module instance has the type information for targetOperation.TypeIndex. - compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{Type: &wasm.FunctionType{}}}}} + compiler.f = &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{Type: &wasm.FunctionType{}}}}, + } // Place the offfset value. loc := compiler.locationStack.pushValueLocationOnStack() @@ -6180,8 +6198,10 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { targetOperation := &wazeroir.OperationCallIndirect{} targetOffset := &wazeroir.OperationConstI32{Value: uint32(0)} // Ensure that the module instance has the type information for targetOperation.TypeIndex, - compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{ - Type: &wasm.FunctionType{}, TypeID: 1000}}}} + compiler.f = &wasm.FunctionInstance{ + ModuleInstance: &wasm.ModuleInstance{Types: []*wasm.TypeInstance{{Type: &wasm.FunctionType{}, TypeID: 1000}}}, + FunctionKind: wasm.FunctionKindWasm, + } // and the typeID doesn't match the table[targetOffset]'s type ID. table[0] = wasm.TableElement{FunctionTypeID: wasm.UninitializedTableElementTypeID} @@ -6285,7 +6305,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { require.NoError(t, err) // Ensure that the module instance has the type information for targetOperation.TypeIndex, - compiler.f = &wasm.FunctionInstance{ModuleInstance: moduleInstance} + compiler.f = &wasm.FunctionInstance{ModuleInstance: moduleInstance, FunctionKind: wasm.FunctionKindWasm} // and the typeID matches the table[targetOffset]'s type ID. // Place the offfset value. Here we try calling a function of functionaddr == table[i].FunctionAddress. diff --git a/internal/wasm/jit/jit_arm64_test.go b/internal/wasm/jit/jit_arm64_test.go index a9189f3d..bd85d42b 100644 --- a/internal/wasm/jit/jit_arm64_test.go +++ b/internal/wasm/jit/jit_arm64_test.go @@ -43,7 +43,10 @@ func requirePushTwoFloat32Consts(t *testing.T, x1, x2 float32, compiler *arm64Co } func (j *jitEnv) requireNewCompiler(t *testing.T) *arm64Compiler { - cmp, err := newCompiler(&wasm.FunctionInstance{ModuleInstance: j.moduleInstance}, nil) + cmp, err := newCompiler(&wasm.FunctionInstance{ + ModuleInstance: j.moduleInstance, + FunctionKind: wasm.FunctionKindWasm, + }, nil) require.NoError(t, err) ret, ok := cmp.(*arm64Compiler) require.True(t, ok) @@ -1692,8 +1695,11 @@ func TestArm64Compiler_compileCall(t *testing.T) { expectedValue += addTargetValue compiler := env.requireNewCompiler(t) - compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, - ModuleInstance: &wasm.ModuleInstance{}} + compiler.f = &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, + FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, + ModuleInstance: &wasm.ModuleInstance{}, + } err := compiler.compilePreamble() require.NoError(t, err) diff --git a/internal/wasm/jit/jit_test.go b/internal/wasm/jit/jit_test.go index 6f0b834f..562fb47c 100644 --- a/internal/wasm/jit/jit_test.go +++ b/internal/wasm/jit/jit_test.go @@ -115,6 +115,7 @@ func (j *jitEnv) exec(code []byte) { codeSegment: code, codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), source: &wasm.FunctionInstance{ + FunctionKind: wasm.FunctionKindWasm, FunctionType: &wasm.TypeInstance{Type: &wasm.FunctionType{}}, }, } diff --git a/internal/wasm/store.go b/internal/wasm/store.go index e3855500..a2fc99f1 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -109,10 +109,13 @@ type ( FunctionType *TypeInstance // LocalTypes holds types of locals. LocalTypes []ValueType + // FunctionKind describes how this function should be called. + FunctionKind FunctionKind // HostFunction holds the runtime representation of host functions. - // If this is not nil, all the above fields are ignored as they are specific to non-host functions. + // This is nil when FunctionKind == FunctionKindWasm. Otherwise, all the above fields are ignored as they are + // specific to Wasm functions. HostFunction *reflect.Value - // Address is the funcaddr(https://www.w3.org/TR/wasm-core-1/#syntax-funcaddr) of this function insntance. + // Address is the funcaddr(https://www.w3.org/TR/wasm-core-1/#syntax-funcaddr) of this function instance. // More precisely, this equals the index of this function instance in store.FunctionInstances. // All function calls are made via funcaddr at runtime, not the index (scoped to a module). // @@ -190,7 +193,7 @@ type ( // and the index to Store.Functions. FunctionAddress uint64 - // FunctionTypeID is an uniquely assigned integer for a function type. + // FunctionTypeID is a uniquely assigned integer for a function type. // This is wazero specific runtime object and specific to a store, // and used at runtime to do type-checks on indirect function calls. FunctionTypeID uint32 @@ -249,10 +252,6 @@ func (m *ModuleInstance) GetFunction(name string, rv ValueType) (*FunctionInstan return exp.Function, nil } -func (f *FunctionInstance) IsHostFunction() bool { - return f.HostFunction != nil -} - func NewStore(engine Engine) *Store { return &Store{ ModuleInstances: map[string]*ModuleInstance{}, @@ -652,6 +651,7 @@ func (s *Store) buildFunctionInstances(module *Module, target *ModuleInstance) ( f := &FunctionInstance{ Name: name, + FunctionKind: FunctionKindWasm, FunctionType: typeInstance, Body: module.CodeSection[codeIndex].Body, LocalTypes: module.CodeSection[codeIndex].LocalTypes, @@ -885,6 +885,7 @@ func (s *Store) AddHostFunction(moduleName string, hf *HostFunction) error { f := &FunctionInstance{ Name: fmt.Sprintf("%s.%s", moduleName, hf.name), HostFunction: hf.goFunc, + FunctionKind: hf.functionKind, FunctionType: typeInstance, ModuleInstance: m, } @@ -904,18 +905,6 @@ func (s *Store) AddHostFunction(moduleName string, hf *HostFunction) error { return nil } -func NewHostFunction(funcName string, goFunc interface{}) (*HostFunction, error) { - hf := &HostFunction{name: funcName} - fn := reflect.ValueOf(goFunc) - hf.goFunc = &fn - ft, err := GetFunctionType(hf.name, hf.goFunc) - if err != nil { - return nil, err - } - hf.functionType = ft - return hf, nil -} - func (s *Store) AddGlobal(moduleName, name string, value uint64, valueType ValueType, mutable bool) error { g := &GlobalInstance{ Val: value, diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index c3e2556b..e3d0998c 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "math" "os" - "reflect" "strconv" "testing" @@ -40,6 +39,7 @@ func TestStore_CallFunction(t *testing.T) { fn: { Kind: ExportKindFunc, Function: &FunctionInstance{ + FunctionKind: FunctionKindWasm, FunctionType: &TypeInstance{ Type: &FunctionType{ Params: []ValueType{}, @@ -89,11 +89,11 @@ func TestStore_CallFunction(t *testing.T) { func TestStore_AddHostFunction(t *testing.T) { s := NewStore(nopEngineInstance) - hostFunction := func(wasm.HostFunctionCallContext) { - } - hf := newHostFunction(t, "fn", hostFunction) - err := s.AddHostFunction("test", hf) + hf, err := NewHostFunction("fn", func(wasm.HostFunctionCallContext) { + }) + require.NoError(t, err) + err = s.AddHostFunction("test", hf) require.NoError(t, err) // The function was added to the store, prefixed by the owning module name @@ -116,21 +116,13 @@ func TestStore_AddHostFunction(t *testing.T) { require.Equal(t, map[string]*ExportInstance{"fn": exp}, m.Exports) } -func newHostFunction(t *testing.T, name string, hostFunction interface{}) *HostFunction { - hf := &HostFunction{name: name} - goFn := reflect.ValueOf(hostFunction) - hf.goFunc = &goFn - ft, err := GetFunctionType(hf.name, hf.goFunc) - require.NoError(t, err) - hf.functionType = ft - return hf -} - func TestStore_ExportImportedHostFunction(t *testing.T) { s := NewStore(nopEngineInstance) - hostFunction := func(wasm.HostFunctionCallContext) { - } - err := s.AddHostFunction("", newHostFunction(t, "host_fn", hostFunction)) + + hf, err := NewHostFunction("host_fn", func(wasm.HostFunctionCallContext) { + }) + require.NoError(t, err) + err = s.AddHostFunction("", hf) require.NoError(t, err) t.Run("ModuleInstance is the importing module", func(t *testing.T) { @@ -212,7 +204,7 @@ func TestStore_addHostFunction(t *testing.T) { t.Run("ok", func(t *testing.T) { s := NewStore(nopEngineInstance) for i := 0; i < 10; i++ { - f := &FunctionInstance{} + f := &FunctionInstance{FunctionKind: FunctionKindHostNoContext} require.Len(t, s.Functions, i) err := s.addFunctionInstance(f) diff --git a/store.go b/store.go index 9cd20d8f..1558203d 100644 --- a/store.go +++ b/store.go @@ -20,10 +20,44 @@ func NewEngineJIT() *Engine { // TODO: compiler? } // HostFunctions are functions written in Go, which a WebAssembly Module can import. +// +// Noting a context exception described later, all parameters or result types must match WebAssembly 1.0 (MVP) value +// types. This means uint32, uint64, float32 or float64. Up to one result can be returned. +// +// Ex. This is a valid host function: +// +// addInts := func(x uint32, uint32) uint32 { +// return x + y +// } +// +// Host functions may also have an initial parameter (param[0]) of type context.Context or wasm.HostFunctionCallContext. +// +// Ex. This uses a Go Context: +// +// addInts := func(ctx context.Context, x uint32, uint32) uint32 { +// // add a little extra if we put some in the context! +// return x + y + ctx.Value(extraKey).(uint32) +// } +// +// The most sophisticated context is wasm.HostFunctionCallContext, which allows access to the Go context, but also +// allows writing to memory. This is important because there are only numeric types in Wasm. The only way to share other +// data is via writing memory and sharing offsets. +// +// Ex. This reads the parameters from! +// +// addInts := func(ctx wasm.HostFunctionCallContext, offset uint32) uint32 { +// x, _ := ctx.Memory().ReadUint32Le(offset) +// y, _ := ctx.Memory().ReadUint32Le(offset + 4) // 32 bits == 4 bytes! +// return x + y +// } +// +// See https://www.w3.org/TR/wasm-core-1/#value-types%E2%91%A0 type HostFunctions struct { nameToHostFunction map[string]*internalwasm.HostFunction } +// NewHostFunctions returns host functions to export. The map key is the name to export and the value is the function. +// See HostFunctions documentation for notes on writing host functions. func NewHostFunctions(nameToGoFunc map[string]interface{}) (ret *HostFunctions, err error) { ret = &HostFunctions{make(map[string]*internalwasm.HostFunction, len(nameToGoFunc))} for name, goFunc := range nameToGoFunc { diff --git a/tests/engine/adhoc_test.go b/tests/engine/adhoc_test.go index 345fb590..b6644774 100644 --- a/tests/engine/adhoc_test.go +++ b/tests/engine/adhoc_test.go @@ -3,6 +3,7 @@ package adhoc import ( "context" _ "embed" + "fmt" "math" "runtime" "sync" @@ -59,7 +60,7 @@ func runTests(t *testing.T, newEngine func() *wazero.Engine) { importedAndExportedFunc(t, newEngine) }) t.Run("host function with float32 type", func(t *testing.T) { - hostFuncWithFloatParam(t, newEngine) + hostFunctions(t, newEngine) }) } @@ -273,8 +274,8 @@ func importedAndExportedFunc(t *testing.T, newEngine func() *wazero.Engine) { require.Equal(t, []byte{0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0}, memory.Buffer[0:10]) } -// hostFuncWithFloatParam fails if a float parameter corrupts a host function value -func hostFuncWithFloatParam(t *testing.T, newEngine func() *wazero.Engine) { +// hostFunctions ensures arg0 is optionally a context, and fails if a float parameter corrupts a host function value +func hostFunctions(t *testing.T, newEngine func() *wazero.Engine) { ctx := context.Background() mod, err := wazero.DecodeModuleText([]byte(`(module $test ;; these imports return the input param @@ -295,41 +296,69 @@ func hostFuncWithFloatParam(t *testing.T, newEngine func() *wazero.Engine) { )`)) require.NoError(t, err) - hostFuncs, err := wazero.NewHostFunctions(map[string]interface{}{ - "identity_f32": func(ctx publicwasm.HostFunctionCallContext, value float32) float32 { + floatFuncs, err := wazero.NewHostFunctions(map[string]interface{}{ + "identity_f32": func(value float32) float32 { return value }, - "identity_f64": func(ctx publicwasm.HostFunctionCallContext, value float64) float64 { + "identity_f64": func(value float64) float64 { return value }}) require.NoError(t, err) - store, err := wazero.NewStoreWithConfig(&wazero.StoreConfig{ - Engine: newEngine(), - ModuleToHostFunctions: map[string]*wazero.HostFunctions{"host": hostFuncs}, - }) + floatFuncsGoContext, err := wazero.NewHostFunctions(map[string]interface{}{ + "identity_f32": func(funcCtx context.Context, value float32) float32 { + require.Equal(t, ctx, funcCtx) + return value + }, + "identity_f64": func(funcCtx context.Context, value float64) float64 { + require.Equal(t, ctx, funcCtx) + return value + }}) require.NoError(t, err) - m, err := store.Instantiate(mod) + floatFuncsHostFunctionCallContext, err := wazero.NewHostFunctions(map[string]interface{}{ + "identity_f32": func(funcCtx publicwasm.HostFunctionCallContext, value float32) float32 { + require.Equal(t, ctx, funcCtx.Context()) + return value + }, + "identity_f64": func(funcCtx publicwasm.HostFunctionCallContext, value float64) float64 { + require.Equal(t, ctx, funcCtx.Context()) + return value + }}) require.NoError(t, err) - t.Run("host function with f32 param", func(t *testing.T) { - fn, ok := m.GetFunctionF32Return("call->test.identity_f32") - require.True(t, ok) - - f32 := float32(math.MaxFloat32) - result, err := fn(ctx, uint64(math.Float32bits(f32))) // float bits are a uint32 value, call requires uint64 + for k, v := range map[string]*wazero.HostFunctions{ + "": floatFuncs, + " - context.Context": floatFuncsGoContext, + " - wasm.HostFunctionCallContext": floatFuncsHostFunctionCallContext, + } { + store, err := wazero.NewStoreWithConfig(&wazero.StoreConfig{ + Engine: newEngine(), + ModuleToHostFunctions: map[string]*wazero.HostFunctions{"host": v}, + }) require.NoError(t, err) - require.Equal(t, f32, result) - }) - t.Run("host function with f64 param", func(t *testing.T) { - fn, ok := m.GetFunctionF64Return("call->test.identity_f64") - require.True(t, ok) - - f64 := math.MaxFloat64 - result, err := fn(ctx, math.Float64bits(f64)) + m, err := store.Instantiate(mod) require.NoError(t, err) - require.Equal(t, f64, result) - }) + + t.Run(fmt.Sprintf("host function with f32 param%s", k), func(t *testing.T) { + fn, ok := m.GetFunctionF32Return("call->test.identity_f32") + require.True(t, ok) + + f32 := float32(math.MaxFloat32) + result, err := fn(ctx, uint64(math.Float32bits(f32))) // float bits are a uint32 value, call requires uint64 + require.NoError(t, err) + require.Equal(t, f32, result) + }) + + t.Run(fmt.Sprintf("host function with f64 param%s", k), func(t *testing.T) { + fn, ok := m.GetFunctionF64Return("call->test.identity_f64") + require.True(t, ok) + + f64 := math.MaxFloat64 + result, err := fn(ctx, math.Float64bits(f64)) + require.NoError(t, err) + require.Equal(t, f64, result) + }) + } }