diff --git a/api/wasm.go b/api/wasm.go index a7b8e8cc..c8cd093a 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -307,6 +307,19 @@ type Function interface { // The Module parameter is the calling module, used to access memory or // exported functions. See GoModuleFunc for an example. // +// The stack is includes any parameters encoded according to their ValueType. +// Its length is the max of parameter or result length. When there are results, +// write them in order beginning at index zero. Do not use the stack after the +// function returns. +// +// Here's a typical way to read three parameters and write back one. +// +// // read parameters off the stack in index order +// argv, argvBuf := uint32(stack[0]), uint32(stack[1]) +// +// // write results back to the stack in index order +// stack[0] = uint64(ErrnoSuccess) +// // This function can be non-deterministic or cause side effects. It also // has special properties not defined in the WebAssembly Core specification. // Notably, this uses the caller's memory (via Module.Memory). See @@ -317,25 +330,28 @@ type Function interface { // idiomatic as they can map go types to ValueType. This type is exposed for // those willing to trade usability and safety for performance. type GoModuleFunction interface { - Call(ctx context.Context, mod Module, params []uint64) []uint64 + Call(ctx context.Context, mod Module, stack []uint64) } // GoModuleFunc is a convenience for defining an inlined function. // // For example, the following returns a uint32 value read from parameter zero: // -// api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { -// ret, ok := mod.Memory().ReadUint32Le(ctx, uint32(params[0])) +// api.GoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) { +// offset := uint32(params[0]) // read the parameter from the stack +// +// ret, ok := mod.Memory().ReadUint32Le(ctx, offset) // if !ok { // panic("out of memory") // } -// return []uint64{uint64(ret)} -// } -type GoModuleFunc func(ctx context.Context, mod Module, params []uint64) []uint64 +// +// results[0] = uint64(ret) // add the result back to the stack. +// }) +type GoModuleFunc func(ctx context.Context, mod Module, stack []uint64) // Call implements GoModuleFunction.Call. -func (f GoModuleFunc) Call(ctx context.Context, mod Module, params []uint64) []uint64 { - return f(ctx, mod, params) +func (f GoModuleFunc) Call(ctx context.Context, mod Module, stack []uint64) { + f(ctx, mod, stack) } // GoFunction is an optimized form of GoModuleFunction which doesn't require @@ -344,23 +360,22 @@ func (f GoModuleFunc) Call(ctx context.Context, mod Module, params []uint64) []u // For example, this function does not need to use the importing module's // memory or exported functions. type GoFunction interface { - Call(ctx context.Context, params []uint64) []uint64 + Call(ctx context.Context, stack []uint64) } // GoFunc is a convenience for defining an inlined function. // // For example, the following returns the sum of two uint32 parameters: // -// api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { +// api.GoFunc(func(ctx context.Context, stack []uint64) { // x, y := uint32(params[0]), uint32(params[1]) -// sum := x + y -// return []uint64{sum} -// } -type GoFunc func(ctx context.Context, params []uint64) []uint64 +// results[0] = uint64(x + y) +// }) +type GoFunc func(ctx context.Context, stack []uint64) // Call implements GoFunction.Call. -func (f GoFunc) Call(ctx context.Context, params []uint64) []uint64 { - return f(ctx, params) +func (f GoFunc) Call(ctx context.Context, stack []uint64) { + f(ctx, stack) } // Global is a WebAssembly 1.0 (20191205) global exported from an instantiated module (wazero.Runtime InstantiateModule). diff --git a/builder_test.go b/builder_test.go index f7daee2c..47e4b83e 100644 --- a/builder_test.go +++ b/builder_test.go @@ -20,11 +20,11 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { return 0 } - gofunc1 := api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { - return []uint64{0} + gofunc1 := api.GoFunc(func(ctx context.Context, stack []uint64) { + stack[0] = 0 }) - gofunc2 := api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { - return []uint64{0} + gofunc2 := api.GoFunc(func(ctx context.Context, stack []uint64) { + stack[0] = 0 }) tests := []struct { diff --git a/imports/assemblyscript/assemblyscript.go b/imports/assemblyscript/assemblyscript.go index 4b1033ca..3492b6e4 100644 --- a/imports/assemblyscript/assemblyscript.go +++ b/imports/assemblyscript/assemblyscript.go @@ -154,11 +154,11 @@ var abortMessageEnabled = &wasm.HostFunc{ var abortMessageDisabled = abortMessageEnabled.WithGoModuleFunc(abort) // abortWithMessage implements functionAbort -func abortWithMessage(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - message := uint32(params[0]) - fileName := uint32(params[1]) - lineNumber := uint32(params[2]) - columnNumber := uint32(params[3]) +func abortWithMessage(ctx context.Context, mod api.Module, stack []uint64) { + message := uint32(stack[0]) + fileName := uint32(stack[1]) + lineNumber := uint32(stack[2]) + columnNumber := uint32(stack[3]) sysCtx := mod.(*wasm.CallContext).Sys mem := mod.Memory() // Don't panic if there was a problem reading the message @@ -167,12 +167,11 @@ func abortWithMessage(ctx context.Context, mod api.Module, params []uint64) (_ [ _, _ = fmt.Fprintf(sysCtx.Stderr(), "%s at %s:%d:%d\n", msg, fn, lineNumber, columnNumber) } } - abort(ctx, mod, params) - return + abort(ctx, mod, stack) } // abortWithMessage implements functionAbort ignoring the message. -func abort(ctx context.Context, mod api.Module, _ []uint64) (_ []uint64) { +func abort(ctx context.Context, mod api.Module, _ []uint64) { // AssemblyScript expects the exit code to be 255 // See https://github.com/AssemblyScript/assemblyscript/blob/v0.20.13/tests/compiler/wasi/abort.js#L14 exitCode := uint32(255) @@ -195,17 +194,15 @@ var traceStdout = &wasm.HostFunc{ ParamNames: []string{"message", "nArgs", "arg0", "arg1", "arg2", "arg3", "arg4"}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - traceTo(ctx, mod, params, mod.(*wasm.CallContext).Sys.Stdout()) - return + GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) { + traceTo(ctx, mod, stack, mod.(*wasm.CallContext).Sys.Stdout()) }), }, } // traceStderr implements trace to the configured Stderr. -var traceStderr = traceStdout.WithGoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - traceTo(ctx, mod, params, mod.(*wasm.CallContext).Sys.Stderr()) - return +var traceStderr = traceStdout.WithGoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) { + traceTo(ctx, mod, stack, mod.(*wasm.CallContext).Sys.Stderr()) }) // traceTo implements the function "trace" in AssemblyScript. e.g. @@ -277,16 +274,16 @@ var seed = &wasm.HostFunc{ ResultTypes: []api.ValueType{f64}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { + GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) { r := mod.(*wasm.CallContext).Sys.RandSource() buf := make([]byte, 8) _, err := io.ReadFull(r, buf) if err != nil { panic(fmt.Errorf("error reading random seed: %w", err)) } - // the caller interprets this as a float64 - raw := binary.LittleEndian.Uint64(buf) - return []uint64{raw} + + // the caller interprets the result as a float64 + stack[0] = binary.LittleEndian.Uint64(buf) }), }, } diff --git a/imports/emscripten/emscripten.go b/imports/emscripten/emscripten.go index 48f55d90..46980b37 100644 --- a/imports/emscripten/emscripten.go +++ b/imports/emscripten/emscripten.go @@ -147,12 +147,12 @@ var invokeI = &wasm.HostFunc{ }, } -func invokeIFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "v_i32", wasm.Index(params[0]), nil) +func invokeIFn(ctx context.Context, mod api.Module, stack []uint64) { + ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "v_i32", wasm.Index(stack[0]), nil) if err != nil { panic(err) } - return ret + stack[0] = ret[0] } var invokeIi = &wasm.HostFunc{ @@ -167,12 +167,12 @@ var invokeIi = &wasm.HostFunc{ }, } -func invokeIiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32_i32", wasm.Index(params[0]), params[1:]) +func invokeIiFn(ctx context.Context, mod api.Module, stack []uint64) { + ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32_i32", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret + stack[0] = ret[0] } var invokeIii = &wasm.HostFunc{ @@ -187,12 +187,12 @@ var invokeIii = &wasm.HostFunc{ }, } -func invokeIiiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32_i32", wasm.Index(params[0]), params[1:]) +func invokeIiiFn(ctx context.Context, mod api.Module, stack []uint64) { + ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32_i32", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret + stack[0] = ret[0] } var invokeIiii = &wasm.HostFunc{ @@ -207,12 +207,12 @@ var invokeIiii = &wasm.HostFunc{ }, } -func invokeIiiiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32_i32", wasm.Index(params[0]), params[1:]) +func invokeIiiiFn(ctx context.Context, mod api.Module, stack []uint64) { + ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32_i32", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret + stack[0] = ret[0] } var invokeIiiii = &wasm.HostFunc{ @@ -227,12 +227,12 @@ var invokeIiiii = &wasm.HostFunc{ }, } -func invokeIiiiiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32i32_i32", wasm.Index(params[0]), params[1:]) +func invokeIiiiiFn(ctx context.Context, mod api.Module, stack []uint64) { + ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32i32_i32", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret + stack[0] = ret[0] } var invokeV = &wasm.HostFunc{ @@ -247,12 +247,11 @@ var invokeV = &wasm.HostFunc{ }, } -func invokeVFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "v_v", wasm.Index(params[0]), nil) +func invokeVFn(ctx context.Context, mod api.Module, stack []uint64) { + _, err := callDynamic(ctx, mod.(*wasm.CallContext), "v_v", wasm.Index(stack[0]), nil) if err != nil { panic(err) } - return ret } var invokeVi = &wasm.HostFunc{ @@ -267,12 +266,11 @@ var invokeVi = &wasm.HostFunc{ }, } -func invokeViFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32_v", wasm.Index(params[0]), params[1:]) +func invokeViFn(ctx context.Context, mod api.Module, stack []uint64) { + _, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32_v", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret } var invokeVii = &wasm.HostFunc{ @@ -287,12 +285,11 @@ var invokeVii = &wasm.HostFunc{ }, } -func invokeViiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32_v", wasm.Index(params[0]), params[1:]) +func invokeViiFn(ctx context.Context, mod api.Module, stack []uint64) { + _, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32_v", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret } var invokeViii = &wasm.HostFunc{ @@ -307,12 +304,11 @@ var invokeViii = &wasm.HostFunc{ }, } -func invokeViiiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32_v", wasm.Index(params[0]), params[1:]) +func invokeViiiFn(ctx context.Context, mod api.Module, stack []uint64) { + _, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32_v", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret } var invokeViiii = &wasm.HostFunc{ @@ -327,12 +323,11 @@ var invokeViiii = &wasm.HostFunc{ }, } -func invokeViiiiFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32i32_v", wasm.Index(params[0]), params[1:]) +func invokeViiiiFn(ctx context.Context, mod api.Module, stack []uint64) { + _, err := callDynamic(ctx, mod.(*wasm.CallContext), "i32i32i32i32_v", wasm.Index(stack[0]), stack[1:]) if err != nil { panic(err) } - return ret } // callDynamic special cases dynamic calls needed for emscripten `invoke_` diff --git a/imports/wasi_snapshot_preview1/args.go b/imports/wasi_snapshot_preview1/args.go index 33ecf0d3..441c9dfc 100644 --- a/imports/wasi_snapshot_preview1/args.go +++ b/imports/wasi_snapshot_preview1/args.go @@ -52,11 +52,11 @@ var argsGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(argsGetFn), + GoFunc: wasiFunc(argsGetFn), }, } -func argsGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func argsGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys argv, argvBuf := uint32(params[0]), uint32(params[1]) return writeOffsetsAndNullTerminatedValues(ctx, mod.Memory(), sysCtx.Args(), argv, argvBuf) @@ -99,20 +99,21 @@ var argsSizesGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(argsSizesGetFn), + GoFunc: wasiFunc(argsSizesGetFn), }, } -func argsSizesGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func argsSizesGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys mem := mod.Memory() resultArgc, resultArgvLen := uint32(params[0]), uint32(params[1]) + // Write the Errno back to the stack if !mem.WriteUint32Le(ctx, resultArgc, uint32(len(sysCtx.Args()))) { - return errnoFault + return ErrnoFault } if !mem.WriteUint32Le(ctx, resultArgvLen, sysCtx.ArgsSize()) { - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } diff --git a/imports/wasi_snapshot_preview1/clock.go b/imports/wasi_snapshot_preview1/clock.go index 30bea72e..16a091fb 100644 --- a/imports/wasi_snapshot_preview1/clock.go +++ b/imports/wasi_snapshot_preview1/clock.go @@ -59,11 +59,11 @@ var clockResGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(clockResGetFn), + GoFunc: wasiFunc(clockResGetFn), }, } -func clockResGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func clockResGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys id, resultResolution := uint32(params[0]), uint32(params[1]) @@ -74,12 +74,13 @@ func clockResGetFn(ctx context.Context, mod api.Module, params []uint64) []uint6 case clockIDMonotonic: resolution = uint64(sysCtx.NanotimeResolution()) default: - return errnoInval + return ErrnoInval } + if !mod.Memory().WriteUint64Le(ctx, resultResolution, resolution) { - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } // clockTimeGet is the WASI function named functionClockTimeGet that returns @@ -121,15 +122,15 @@ var clockTimeGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(clockTimeGetFn), + GoFunc: wasiFunc(clockTimeGetFn), }, } -func clockTimeGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func clockTimeGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys id := uint32(params[0]) // TODO: precision is currently ignored. - _ = params[1] + // precision = params[1] resultTimestamp := uint32(params[2]) var val uint64 @@ -140,11 +141,11 @@ func clockTimeGetFn(ctx context.Context, mod api.Module, params []uint64) []uint case clockIDMonotonic: val = uint64(sysCtx.Nanotime(ctx)) default: - return errnoInval + return ErrnoInval } if !mod.Memory().WriteUint64Le(ctx, resultTimestamp, val) { - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } diff --git a/imports/wasi_snapshot_preview1/environ.go b/imports/wasi_snapshot_preview1/environ.go index 5516757e..5d9c0e4d 100644 --- a/imports/wasi_snapshot_preview1/environ.go +++ b/imports/wasi_snapshot_preview1/environ.go @@ -52,13 +52,14 @@ var environGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(environGetFn), + GoFunc: wasiFunc(environGetFn), }, } -func environGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func environGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys environ, environBuf := uint32(params[0]), uint32(params[1]) + return writeOffsetsAndNullTerminatedValues(ctx, mod.Memory(), sysCtx.Environ(), environ, environBuf) } @@ -101,20 +102,20 @@ var environSizesGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(environSizesGetFn), + GoFunc: wasiFunc(environSizesGetFn), }, } -func environSizesGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func environSizesGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys mem := mod.Memory() resultEnvironc, resultEnvironvLen := uint32(params[0]), uint32(params[1]) if !mem.WriteUint32Le(ctx, resultEnvironc, uint32(len(sysCtx.Environ()))) { - return errnoFault + return ErrnoFault } if !mem.WriteUint32Le(ctx, resultEnvironvLen, sysCtx.EnvironSize()) { - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } diff --git a/imports/wasi_snapshot_preview1/fs.go b/imports/wasi_snapshot_preview1/fs.go index 5ab65cfc..fe0238ab 100644 --- a/imports/wasi_snapshot_preview1/fs.go +++ b/imports/wasi_snapshot_preview1/fs.go @@ -92,19 +92,18 @@ var fdClose = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdCloseFn), + GoFunc: wasiFunc(fdCloseFn), }, } -func fdCloseFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdCloseFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys fd := uint32(params[0]) if ok := sysCtx.FS(ctx).CloseFile(ctx, fd); !ok { - return errnoBadf + return ErrnoBadf } - - return errnoSuccess + return ErrnoSuccess } // fdDatasync is the WASI function named functionFdDatasync which synchronizes @@ -162,19 +161,19 @@ var fdFdstatGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdFdstatGetFn), + GoFunc: wasiFunc(fdFdstatGetFn), }, } -func fdFdstatGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdFdstatGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys // TODO: actually write the fdstat! fd, _ := uint32(params[0]), uint32(params[1]) if _, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd); !ok { - return errnoBadf + return ErrnoBadf } - return errnoSuccess + return ErrnoSuccess } // fdFdstatSetFlags is the WASI function named functionFdFdstatSetFlags which @@ -251,7 +250,7 @@ var fdFilestatGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdFilestatGetFn), + GoFunc: wasiFunc(fdFilestatGetFn), }, } @@ -268,20 +267,20 @@ const ( wasiFiletypeSymbolicLink ) -func fdFilestatGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdFilestatGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { return fdFilestatGetFunc(ctx, mod, uint32(params[0]), uint32(params[1])) } -func fdFilestatGetFunc(ctx context.Context, mod api.Module, fd, resultBuf uint32) []uint64 { +func fdFilestatGetFunc(ctx context.Context, mod api.Module, fd, resultBuf uint32) Errno { sysCtx := mod.(*wasm.CallContext).Sys file, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) if !ok { - return errnoBadf + return ErrnoBadf } fileStat, err := file.File.Stat() if err != nil { - return errnoIo + return ErrnoIo } fileMode := fileStat.Mode() @@ -301,7 +300,7 @@ func fdFilestatGetFunc(ctx context.Context, mod api.Module, fd, resultBuf uint32 buf, ok := mod.Memory().Read(ctx, resultBuf, 64) if !ok { - return errnoFault + return ErrnoFault } buf[16] = uint8(wasiFileMode) @@ -312,7 +311,7 @@ func fdFilestatGetFunc(ctx context.Context, mod api.Module, fd, resultBuf uint32 binary.LittleEndian.PutUint64(buf[48:], mtim) binary.LittleEndian.PutUint64(buf[56:], mtim) - return errnoSuccess + return ErrnoSuccess } // fdFilestatSetSize is the WASI function named functionFdFilestatSetSize which @@ -349,11 +348,11 @@ var fdPread = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdPreadFn), + GoFunc: wasiFunc(fdPreadFn), }, } -func fdPreadFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdPreadFn(ctx context.Context, mod api.Module, params []uint64) Errno { return fdReadOrPread(ctx, mod, params, true) } @@ -396,29 +395,29 @@ var fdPrestatGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdPrestatGetFn), + GoFunc: wasiFunc(fdPrestatGetFn), }, } -func fdPrestatGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdPrestatGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys fd, resultPrestat := uint32(params[0]), uint32(params[1]) entry, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) if !ok { - return errnoBadf + return ErrnoBadf } // Zero-value 8-bit tag, and 3-byte zero-value paddings, which is uint32le(0) in short. if !mod.Memory().WriteUint32Le(ctx, resultPrestat, uint32(0)) { - return errnoFault - } - // Write the length of the directory name at offset 4. - if !mod.Memory().WriteUint32Le(ctx, resultPrestat+4, uint32(len(entry.Path))) { - return errnoFault + return ErrnoFault } - return errnoSuccess + // Write the length of the directory name at offset 4. + if !mod.Memory().WriteUint32Le(ctx, resultPrestat+4, uint32(len(entry.Path))) { + return ErrnoFault + } + return ErrnoSuccess } // fdPrestatDirName is the WASI function named functionFdPrestatDirName which @@ -459,30 +458,30 @@ var fdPrestatDirName = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdPrestatDirNameFn), + GoFunc: wasiFunc(fdPrestatDirNameFn), }, } -func fdPrestatDirNameFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdPrestatDirNameFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys fd, path, pathLen := uint32(params[0]), uint32(params[1]), uint32(params[2]) f, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) if !ok { - return errnoBadf + return ErrnoBadf } // Some runtimes may have another semantics. See /RATIONALE.md if uint32(len(f.Path)) < pathLen { - return errnoNametoolong + return ErrnoNametoolong } // TODO: fdPrestatDirName may have to return ErrnoNotdir if the type of the // prestat data of `fd` is not a PrestatDir. if !mod.Memory().Write(ctx, path, []byte(f.Path)[:pathLen]) { - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } // fdPwrite is the WASI function named functionFdPwrite which writes to a file @@ -551,15 +550,15 @@ var fdRead = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdReadFn), + GoFunc: wasiFunc(fdReadFn), }, } -func fdReadFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdReadFn(ctx context.Context, mod api.Module, params []uint64) Errno { return fdReadOrPread(ctx, mod, params, false) } -func fdReadOrPread(ctx context.Context, mod api.Module, params []uint64, isPread bool) []uint64 { +func fdReadOrPread(ctx context.Context, mod api.Module, params []uint64, isPread bool) Errno { sysCtx := mod.(*wasm.CallContext).Sys mem := mod.Memory() fd := uint32(params[0]) @@ -577,16 +576,16 @@ func fdReadOrPread(ctx context.Context, mod api.Module, params []uint64, isPread r := internalsys.FdReader(ctx, sysCtx, fd) if r == nil { - return errnoBadf + return ErrnoBadf } if isPread { if s, ok := r.(io.Seeker); ok { if _, err := s.Seek(offset, io.SeekStart); err != nil { - return errnoFault + return ErrnoFault } } else { - return errnoInval + return ErrnoInval } } @@ -595,31 +594,32 @@ func fdReadOrPread(ctx context.Context, mod api.Module, params []uint64, isPread iov := iovs + i*8 offset, ok := mem.ReadUint32Le(ctx, iov) if !ok { - return errnoFault + return ErrnoFault } l, ok := mem.ReadUint32Le(ctx, iov+4) if !ok { - return errnoFault + return ErrnoFault } b, ok := mem.Read(ctx, offset, l) if !ok { - return errnoFault + return ErrnoFault } n, err := r.Read(b) nread += uint32(n) shouldContinue, errno := fdRead_shouldContinueRead(uint32(n), l, err) - if errno != nil { + if errno != ErrnoSuccess { return errno } else if !shouldContinue { break } } if !mem.WriteUint32Le(ctx, resultSize, nread) { - return errnoFault + return ErrnoFault + } else { + return ErrnoSuccess } - return errnoSuccess } // fdRead_shouldContinueRead decides whether to continue reading the next iovec @@ -627,16 +627,16 @@ func fdReadOrPread(ctx context.Context, mod api.Module, params []uint64, isPread // // Note: When there are both bytes read (n) and an error, this continues. // See /RATIONALE.md "Why ignore the error returned by io.Reader when n > 1?" -func fdRead_shouldContinueRead(n, l uint32, err error) (bool, []uint64) { +func fdRead_shouldContinueRead(n, l uint32, err error) (bool, Errno) { if errors.Is(err, io.EOF) { - return false, nil // EOF isn't an error, and we shouldn't continue. + return false, ErrnoSuccess // EOF isn't an error, and we shouldn't continue. } else if err != nil && n == 0 { - return false, errnoIo + return false, ErrnoIo } else if err != nil { - return false, nil // Allow the caller to process n bytes. + return false, ErrnoSuccess // Allow the caller to process n bytes. } // Continue reading, unless there's a partial read or nothing to read. - return n == l && n != 0, nil + return n == l && n != 0, ErrnoSuccess } // fdReaddir is the WASI function named functionFdReaddir which reads directory @@ -704,11 +704,11 @@ var fdSeek = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdSeekFn), + GoFunc: wasiFunc(fdSeekFn), }, } -func fdSeekFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdSeekFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys fd := uint32(params[0]) offset := params[1] @@ -718,25 +718,24 @@ func fdSeekFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { var seeker io.Seeker // Check to see if the file descriptor is available if f, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd); !ok || f.File == nil { - return errnoBadf + return ErrnoBadf // fs.FS doesn't declare io.Seeker, but implementations such as os.File implement it. } else if seeker, ok = f.File.(io.Seeker); !ok { - return errnoBadf + return ErrnoBadf } if whence > io.SeekEnd /* exceeds the largest valid whence */ { - return errnoInval + return ErrnoInval } + newOffset, err := seeker.Seek(int64(offset), int(whence)) if err != nil { - return errnoIo + return ErrnoIo } - if !mod.Memory().WriteUint64Le(ctx, resultNewoffset, uint64(newOffset)) { - return errnoFault + return ErrnoFault } - - return errnoSuccess + return ErrnoSuccess } // fdSync is the WASI function named functionFdSync which synchronizes the data @@ -824,11 +823,11 @@ var fdWrite = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(fdWriteFn), + GoFunc: wasiFunc(fdWriteFn), }, } -func fdWriteFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func fdWriteFn(ctx context.Context, mod api.Module, params []uint64) Errno { fd := uint32(params[0]) iovs := uint32(params[1]) iovsCount := uint32(params[2]) @@ -837,7 +836,7 @@ func fdWriteFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { sysCtx := mod.(*wasm.CallContext).Sys writer := internalsys.FdWriter(ctx, sysCtx, fd) if writer == nil { - return errnoBadf + return ErrnoBadf } var err error @@ -846,13 +845,13 @@ func fdWriteFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { iov := iovs + i*8 offset, ok := mod.Memory().ReadUint32Le(ctx, iov) if !ok { - return errnoFault + return ErrnoFault } // Note: emscripten has been known to write zero length iovec. However, // it is not common in other compilers, so we don't optimize for it. l, ok := mod.Memory().ReadUint32Le(ctx, iov+4) if !ok { - return errnoFault + return ErrnoFault } var n int @@ -861,19 +860,19 @@ func fdWriteFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { } else { b, ok := mod.Memory().Read(ctx, offset, l) if !ok { - return errnoFault + return ErrnoFault } n, err = writer.Write(b) if err != nil { - return errnoIo + return ErrnoIo } } nwritten += uint32(n) } if !mod.Memory().WriteUint32Le(ctx, resultSize, nwritten) { - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } // pathCreateDirectory is the WASI function named functionPathCreateDirectory @@ -922,11 +921,11 @@ var pathFilestatGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(pathFilestatGetFn), + GoFunc: wasiFunc(pathFilestatGetFn), }, } -func pathFilestatGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func pathFilestatGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys fsc := sysCtx.FS(ctx) @@ -943,22 +942,22 @@ func pathFilestatGetFn(ctx context.Context, mod api.Module, params []uint64) []u // then join it with its parent b, ok := mod.Memory().Read(ctx, pathOffset, pathLen) if !ok { - return errnoNametoolong + return ErrnoNametoolong } pathName := string(b) if dir, ok := fsc.OpenedFile(ctx, fd); !ok { - return errnoBadf + return ErrnoBadf } else if dir.File == nil { // root } else if _, ok := dir.File.(fs.ReadDirFile); !ok { - return errnoNotdir + return ErrnoNotdir } else { pathName = path.Join(dir.Path, pathName) } // Sadly, we need to open the file to stat it. pathFd, errnoResult := openFile(ctx, fsc, pathName) - if errnoResult != nil { + if errnoResult != ErrnoSuccess { return errnoResult } @@ -1048,11 +1047,11 @@ var pathOpen = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(pathOpenFn), + GoFunc: wasiFunc(pathOpenFn), }, } -func pathOpenFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func pathOpenFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys fsc := sysCtx.FS(ctx) @@ -1067,24 +1066,24 @@ func pathOpenFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { resultOpenedFd := uint32(params[8]) if _, ok := fsc.OpenedFile(ctx, fd); !ok { - return errnoBadf + return ErrnoBadf } b, ok := mod.Memory().Read(ctx, path, pathLen) if !ok { - return errnoFault + return ErrnoFault } newFD, errnoResult := openFile(ctx, fsc, string(b)) - if errnoResult != nil { + if errnoResult != ErrnoSuccess { return errnoResult } if !mod.Memory().WriteUint32Le(ctx, resultOpenedFd, newFD) { _ = fsc.CloseFile(ctx, newFD) - return errnoFault + return ErrnoFault } - return errnoSuccess + return ErrnoSuccess } // pathReadlink is the WASI function named functionPathReadlink that reads the @@ -1141,26 +1140,27 @@ var pathUnlinkFile = stubFunction( // Note: Coercion isn't centralized in internalsys.FSContext because ABI use // different error codes. For example, wasi-filesystem and GOOS=js don't map to // these Errno. -func openFile(ctx context.Context, fsc *internalsys.FSContext, name string) (fd uint32, errnoResult []uint64) { +func openFile(ctx context.Context, fsc *internalsys.FSContext, name string) (fd uint32, errno Errno) { newFD, err := fsc.OpenFile(ctx, name) if err == nil { fd = newFD + errno = ErrnoSuccess return } // handle all the cases of FS.Open or internal to FSContext.OpenFile switch { case errors.Is(err, fs.ErrInvalid): - errnoResult = errnoInval + errno = ErrnoInval case errors.Is(err, fs.ErrNotExist): // fs.FS is allowed to return this instead of ErrInvalid on an invalid path - errnoResult = errnoNoent + errno = ErrnoNoent case errors.Is(err, fs.ErrExist): - errnoResult = errnoExist + errno = ErrnoExist case errors.Is(err, syscall.EBADF): // fsc.OpenFile currently returns this on out of file descriptors - errnoResult = errnoBadf + errno = ErrnoBadf default: - errnoResult = errnoIo + errno = ErrnoIo } return } diff --git a/imports/wasi_snapshot_preview1/fs_test.go b/imports/wasi_snapshot_preview1/fs_test.go index 932a254f..cb6a686a 100644 --- a/imports/wasi_snapshot_preview1/fs_test.go +++ b/imports/wasi_snapshot_preview1/fs_test.go @@ -925,7 +925,7 @@ func Test_fdRead_shouldContinueRead(t *testing.T) { n, l uint32 err error expectedOk bool - expectedErrno []uint64 + expectedErrno Errno }{ { name: "break when nothing to read", @@ -972,13 +972,13 @@ func Test_fdRead_shouldContinueRead(t *testing.T) { { name: "return ErrnoIo on error on nothing to read", err: io.ErrClosedPipe, - expectedErrno: errnoIo, + expectedErrno: ErrnoIo, }, { name: "return ErrnoIo on error on nothing read", l: 4, err: io.ErrClosedPipe, - expectedErrno: errnoIo, + expectedErrno: ErrnoIo, }, { // Special case, allows processing data before err name: "break on error on partial read", diff --git a/imports/wasi_snapshot_preview1/poll.go b/imports/wasi_snapshot_preview1/poll.go index d575c83d..7365df6d 100644 --- a/imports/wasi_snapshot_preview1/poll.go +++ b/imports/wasi_snapshot_preview1/poll.go @@ -55,18 +55,18 @@ var pollOneoff = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(pollOneoffFn), + GoFunc: wasiFunc(pollOneoffFn), }, } -func pollOneoffFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func pollOneoffFn(ctx context.Context, mod api.Module, params []uint64) Errno { in := uint32(params[0]) out := uint32(params[1]) nsubscriptions := uint32(params[2]) resultNevents := uint32(params[3]) if nsubscriptions == 0 { - return errnoInval + return ErrnoInval } mem := mod.Memory() @@ -74,17 +74,17 @@ func pollOneoffFn(ctx context.Context, mod api.Module, params []uint64) []uint64 // Ensure capacity prior to the read loop to reduce error handling. inBuf, ok := mem.Read(ctx, in, nsubscriptions*48) if !ok { - return errnoFault + return ErrnoFault } outBuf, ok := mem.Read(ctx, out, nsubscriptions*32) if !ok { - return errnoFault + return ErrnoFault } // Eagerly write the number of events which will equal subscriptions unless // there's a fault in parsing (not processing). if !mod.Memory().WriteUint32Le(ctx, resultNevents, nsubscriptions) { - return errnoFault + return ErrnoFault } // Loop through all subscriptions and write their output. @@ -102,7 +102,7 @@ func pollOneoffFn(ctx context.Context, mod api.Module, params []uint64) []uint64 // +8 past userdata +4 FD alignment errno = processFDEvent(ctx, mod, eventType, inBuf[inOffset+8+4:]) default: - return errnoInval + return ErrnoInval } // Write the event corresponding to the processed subscription. @@ -113,7 +113,7 @@ func pollOneoffFn(ctx context.Context, mod api.Module, params []uint64) []uint64 binary.LittleEndian.PutUint32(outBuf[outOffset+10:], uint32(eventType)) // TODO: When FD events are supported, write outOffset+16 } - return errnoSuccess + return ErrnoSuccess } // processClockEvent supports only relative name events, as that's what's used diff --git a/imports/wasi_snapshot_preview1/proc.go b/imports/wasi_snapshot_preview1/proc.go index 937390ee..b0c76763 100644 --- a/imports/wasi_snapshot_preview1/proc.go +++ b/imports/wasi_snapshot_preview1/proc.go @@ -29,11 +29,11 @@ var procExit = &wasm.HostFunc{ ParamNames: []string{"rval"}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(procExitFn), + GoFunc: wasiFunc(procExitFn), }, } -func procExitFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func procExitFn(ctx context.Context, mod api.Module, params []uint64) Errno { exitCode := uint32(params[0]) // Ensure other callers see the exit code. diff --git a/imports/wasi_snapshot_preview1/random.go b/imports/wasi_snapshot_preview1/random.go index 35ecd04f..66936105 100644 --- a/imports/wasi_snapshot_preview1/random.go +++ b/imports/wasi_snapshot_preview1/random.go @@ -42,24 +42,24 @@ var randomGet = &wasm.HostFunc{ ResultTypes: []api.ValueType{i32}, Code: &wasm.Code{ IsHostFunction: true, - GoFunc: api.GoModuleFunc(randomGetFn), + GoFunc: wasiFunc(randomGetFn), }, } -func randomGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { +func randomGetFn(ctx context.Context, mod api.Module, params []uint64) Errno { sysCtx := mod.(*wasm.CallContext).Sys randSource := sysCtx.RandSource() buf, bufLen := uint32(params[0]), uint32(params[1]) randomBytes, ok := mod.Memory().Read(ctx, buf, bufLen) if !ok { // out-of-range - return errnoFault + return ErrnoFault } // We can ignore the returned n as it only != byteCount on error if _, err := io.ReadAtLeast(randSource, randomBytes, int(bufLen)); err != nil { - return errnoIo + return ErrnoIo } - return errnoSuccess + return ErrnoSuccess } diff --git a/imports/wasi_snapshot_preview1/wasi.go b/imports/wasi_snapshot_preview1/wasi.go index 5c1d1e65..baf5a6be 100644 --- a/imports/wasi_snapshot_preview1/wasi.go +++ b/imports/wasi_snapshot_preview1/wasi.go @@ -215,39 +215,36 @@ func exportFunctions(builder wazero.HostModuleBuilder) { exporter.ExportHostFunc(sockShutdown) } -// Declare constants to avoid slice allocation per call. -var ( - errnoBadf = []uint64{uint64(ErrnoBadf)} - errnoExist = []uint64{uint64(ErrnoExist)} - errnoInval = []uint64{uint64(ErrnoInval)} - errnoIo = []uint64{uint64(ErrnoIo)} - errnoNoent = []uint64{uint64(ErrnoNoent)} - errnoNotdir = []uint64{uint64(ErrnoNotdir)} - errnoFault = []uint64{uint64(ErrnoFault)} - errnoNametoolong = []uint64{uint64(ErrnoNametoolong)} - errnoSuccess = []uint64{uint64(ErrnoSuccess)} -) - -func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, values []string, offsets, bytes uint32) []uint64 { +func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, values []string, offsets, bytes uint32) Errno { for _, value := range values { // Write current offset and advance it. if !mem.WriteUint32Le(ctx, offsets, bytes) { - return errnoFault + return ErrnoFault } offsets += 4 // size of uint32 // Write the next value to memory with a NUL terminator if !mem.Write(ctx, bytes, []byte(value)) { - return errnoFault + return ErrnoFault } bytes += uint32(len(value)) if !mem.WriteByte(ctx, bytes, 0) { - return errnoFault + return ErrnoFault } bytes++ } - return errnoSuccess + return ErrnoSuccess +} + +// wasiFunc special cases that all WASI functions return a single Errno +// result. The returned value will be written back to the stack at index zero. +type wasiFunc func(ctx context.Context, mod api.Module, params []uint64) Errno + +// Call implements the same method as documented on api.GoModuleFunction. +func (f wasiFunc) Call(ctx context.Context, mod api.Module, stack []uint64) { + // Write the result back onto the stack + stack[0] = uint64(f(ctx, mod, stack)) } // stubFunction stubs for GrainLang per #271. diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index d82317f5..a385bdbd 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -629,7 +629,13 @@ func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, param ce.initializeStack(tp, params) ce.execWasmFunction(ctx, callCtx) - results = ce.stack[:tp.ResultNumInUint64] + // This returns a safe copy of the results, instead of a slice view. If we + // returned a re-slice, the caller could accidentally or purposefully + // corrupt the stack of subsequent calls + if resultCount := tp.ResultNumInUint64; resultCount > 0 { + results = make([]uint64, resultCount) + copy(results, ce.stack[:resultCount]) + } return } @@ -817,17 +823,24 @@ entry: case nativeCallStatusCodeCallGoHostFunction: calleeHostFunction := ce.moduleContext.fn base := int(ce.stackBasePointerInBytes >> 3) - params := ce.stack[base : base+len(calleeHostFunction.source.Type.Params)] + + // In the compiler engine, ce.stack has enough capacity for the + // max of param or result length, so we don't need to grow when + // there are more results than parameters. + stackLen := calleeHostFunction.source.Type.ParamNumInUint64 + if resultLen := calleeHostFunction.source.Type.ResultNumInUint64; resultLen > stackLen { + stackLen = resultLen + } + stack := ce.stack[base : base+stackLen] + fn := calleeHostFunction.source.GoFunc - var results []uint64 switch fn := fn.(type) { case api.GoModuleFunction: - results = fn.Call(ctx, callCtx.WithMemory(ce.memoryInstance), params) + fn.Call(ctx, callCtx.WithMemory(ce.memoryInstance), stack) case api.GoFunction: - results = fn.Call(ctx, params) + fn.Call(ctx, stack) } - copy(ce.stack[base:], results) codeAddr, modAddr = ce.returnAddress, ce.moduleInstanceAddress goto entry case nativeCallStatusCodeCallBuiltInFunction: diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index d8ec88c4..5a0b0361 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -827,6 +827,9 @@ func (ce *callEngine) call(ctx context.Context, m *wasm.CallContext, tf *functio ce.callFunction(ctx, m, tf) + // This returns a safe copy of the results, instead of a slice view. If we + // returned a re-slice, the caller could accidentally or purposefully + // corrupt the stack of subsequent calls. results = wasm.PopValues(ft.ResultNumInUint64, ce.popValue) return } @@ -859,8 +862,9 @@ func (ce *callEngine) callFunction(ctx context.Context, callCtx *wasm.CallContex } } -func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { +func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, stack []uint64) { if f.source.Listener != nil { + params := stack[:f.source.Type.ParamNumInUint64] ctx = f.source.Listener.Before(ctx, f.source.Definition, params) } frame := &callFrame{f: f} @@ -869,17 +873,17 @@ func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, fn := f.source.GoFunc switch fn := fn.(type) { case api.GoModuleFunction: - results = fn.Call(ctx, callCtx.WithMemory(ce.callerMemory()), params) + fn.Call(ctx, callCtx.WithMemory(ce.callerMemory()), stack) case api.GoFunction: - results = fn.Call(ctx, params) + fn.Call(ctx, stack) } ce.popFrame() if f.source.Listener != nil { // TODO: This doesn't get the error due to use of panic to propagate them. + results := stack[:f.source.Type.ResultNumInUint64] f.source.Listener.After(ctx, f.source.Definition, nil, results) } - return } func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallContext, f *function) { @@ -4380,9 +4384,25 @@ func (ce *callEngine) popMemoryOffset(op *interpreterOp) uint32 { } func (ce *callEngine) callGoFuncWithStack(ctx context.Context, callCtx *wasm.CallContext, f *function) { - params := wasm.PopValues(f.source.Type.ParamNumInUint64, ce.popValue) - results := ce.callGoFunc(ctx, callCtx, f, params) - for _, v := range results { - ce.pushValue(v) + paramLen := f.source.Type.ParamNumInUint64 + resultLen := f.source.Type.ResultNumInUint64 + stackLen := paramLen + + // In the interpreter engine, ce.stack may only have capacity to store + // parameters. Grow when there are more results than parameters. + if growLen := resultLen - paramLen; growLen > 0 { + for i := 0; i < growLen; i++ { + ce.stack = append(ce.stack, 0) + } + stackLen += growLen + } + + // Pass the stack elements to the go function. + stack := ce.stack[len(ce.stack)-stackLen:] + ce.callGoFunc(ctx, callCtx, f, stack) + + // Shrink the stack when there were more parameters than results. + if shrinkLen := paramLen - resultLen; shrinkLen > 0 { + ce.stack = ce.stack[0 : len(ce.stack)-shrinkLen] } } diff --git a/internal/gojs/runtime.go b/internal/gojs/runtime.go index 014cc4cb..89068fc3 100644 --- a/internal/gojs/runtime.go +++ b/internal/gojs/runtime.go @@ -37,12 +37,11 @@ var WasmExit = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func wasmExit(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - code := uint32(params[0]) +func wasmExit(ctx context.Context, mod api.Module, stack []uint64) { + code := uint32(stack[0]) getState(ctx).clear() _ = mod.CloseWithExitCode(ctx, code) // TODO: should ours be signed bit (like -1 == 255)? - return } // WasmWrite implements runtime.wasmWrite which supports runtime.write and @@ -60,8 +59,8 @@ var WasmWrite = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func wasmWrite(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - fd, p, n := uint32(params[0]), uint32(params[1]), uint32(params[2]) +func wasmWrite(ctx context.Context, mod api.Module, stack []uint64) { + fd, p, n := uint32(stack[0]), uint32(stack[1]), uint32(stack[2]) var writer io.Writer @@ -78,7 +77,6 @@ func wasmWrite(ctx context.Context, mod api.Module, params []uint64) (_ []uint64 if _, err := writer.Write(mustRead(ctx, mod.Memory(), "p", p, n)); err != nil { panic(fmt.Errorf("error writing p: %w", err)) } - return } // ResetMemoryDataView signals wasm.OpcodeMemoryGrow happened, indicating any @@ -107,9 +105,9 @@ var Nanotime1 = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func nanotime1(ctx context.Context, mod api.Module, _ []uint64) []uint64 { +func nanotime1(ctx context.Context, mod api.Module, stack []uint64) { time := mod.(*wasm.CallContext).Sys.Nanotime(ctx) - return []uint64{api.EncodeI64(time)} + stack[0] = api.EncodeI64(time) } // Walltime implements runtime.walltime which supports time.Now. @@ -125,9 +123,10 @@ var Walltime = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func walltime(ctx context.Context, mod api.Module, _ []uint64) []uint64 { +func walltime(ctx context.Context, mod api.Module, stack []uint64) { sec, nsec := mod.(*wasm.CallContext).Sys.Walltime(ctx) - return []uint64{api.EncodeI64(sec), api.EncodeI32(nsec)} + stack[0] = api.EncodeI64(sec) + stack[1] = api.EncodeI32(nsec) } // ScheduleTimeoutEvent implements runtime.scheduleTimeoutEvent which supports @@ -164,9 +163,9 @@ var GetRandomData = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func getRandomData(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { +func getRandomData(ctx context.Context, mod api.Module, stack []uint64) { randSource := mod.(*wasm.CallContext).Sys.RandSource() - buf, bufLen := uint32(params[0]), uint32(params[1]) + buf, bufLen := uint32(stack[0]), uint32(stack[1]) r := mustRead(ctx, mod.Memory(), "r", buf, bufLen) @@ -175,5 +174,4 @@ func getRandomData(ctx context.Context, mod api.Module, params []uint64) (_ []ui } else if uint32(n) != bufLen { panic(fmt.Errorf("RandSource.Read(r /* len=%d */) read %d bytes", bufLen, n)) } - return } diff --git a/internal/gojs/spfunc/spfunc_test.go b/internal/gojs/spfunc/spfunc_test.go index afb98aad..f254fe2c 100644 --- a/internal/gojs/spfunc/spfunc_test.go +++ b/internal/gojs/spfunc/spfunc_test.go @@ -175,12 +175,12 @@ var spMem = []byte{ 10, 0, 0, 0, 0, 0, 0, 0, } -func i64i32i32i32i32_i64i32_withSP(_ context.Context, params []uint64) []uint64 { - vRef := params[0] - mAddr := uint32(params[1]) - mLen := uint32(params[2]) - argsArray := uint32(params[3]) - argsLen := uint32(params[4]) +func i64i32i32i32i32_i64i32_withSP(_ context.Context, stack []uint64) { + vRef := stack[0] + mAddr := uint32(stack[1]) + mLen := uint32(stack[2]) + argsArray := uint32(stack[3]) + argsLen := uint32(stack[4]) if vRef != 1 { panic("vRef") @@ -198,7 +198,10 @@ func i64i32i32i32i32_i64i32_withSP(_ context.Context, params []uint64) []uint64 panic("argsLen") } - return []uint64{10, 20, 8} + // set results + stack[0] = 10 + stack[1] = 20 + stack[2] = 8 } func TestMustCallFromSP(t *testing.T) { diff --git a/internal/gojs/syscall.go b/internal/gojs/syscall.go index f182d582..66ecd321 100644 --- a/internal/gojs/syscall.go +++ b/internal/gojs/syscall.go @@ -49,11 +49,10 @@ var FinalizeRef = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func finalizeRef(ctx context.Context, params []uint64) (_ []uint64) { - id := uint32(params[0]) // 32-bits of the ref are the ID +func finalizeRef(ctx context.Context, stack []uint64) { + id := uint32(stack[0]) // 32-bits of the ref are the ID getState(ctx).values.decrement(id) - return } // StringVal implements js.stringVal, which is used to load the string for @@ -73,11 +72,12 @@ var StringVal = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func stringVal(ctx context.Context, mod api.Module, params []uint64) []uint64 { - xAddr, xLen := uint32(params[0]), uint32(params[1]) +func stringVal(ctx context.Context, mod api.Module, stack []uint64) { + xAddr, xLen := uint32(stack[0]), uint32(stack[1]) x := string(mustRead(ctx, mod.Memory(), "x", xAddr, xLen)) - return []uint64{storeRef(ctx, x)} + + stack[0] = storeRef(ctx, x) } // ValueGet implements js.valueGet, which is used to load a js.Value property @@ -98,10 +98,10 @@ var ValueGet = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func valueGet(ctx context.Context, mod api.Module, params []uint64) []uint64 { - vRef := params[0] - pAddr := uint32(params[1]) - pLen := uint32(params[2]) +func valueGet(ctx context.Context, mod api.Module, stack []uint64) { + vRef := stack[0] + pAddr := uint32(stack[1]) + pLen := uint32(stack[2]) p := string(mustRead(ctx, mod.Memory(), "p", pAddr, pLen)) v := loadValue(ctx, ref(vRef)) @@ -122,8 +122,7 @@ func valueGet(ctx context.Context, mod api.Module, params []uint64) []uint64 { panic(fmt.Errorf("TODO: valueGet(v=%v, p=%s)", v, p)) } - xRef := storeRef(ctx, result) - return []uint64{xRef} + stack[0] = storeRef(ctx, result) } // ValueSet implements js.valueSet, which is used to store a js.Value property @@ -143,11 +142,11 @@ var ValueSet = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func valueSet(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - vRef := params[0] - pAddr := uint32(params[1]) - pLen := uint32(params[2]) - xRef := params[3] +func valueSet(ctx context.Context, mod api.Module, stack []uint64) { + vRef := stack[0] + pAddr := uint32(stack[1]) + pLen := uint32(stack[2]) + xRef := stack[3] v := loadValue(ctx, ref(vRef)) p := string(mustRead(ctx, mod.Memory(), "p", pAddr, pLen)) @@ -196,15 +195,14 @@ var ValueIndex = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func valueIndex(ctx context.Context, params []uint64) []uint64 { - vRef := params[0] - i := uint32(params[1]) +func valueIndex(ctx context.Context, stack []uint64) { + vRef := stack[0] + i := uint32(stack[1]) v := loadValue(ctx, ref(vRef)) result := v.(*objectArray).slice[i] - xRef := storeRef(ctx, result) - return []uint64{xRef} + stack[0] = storeRef(ctx, result) } // ValueSetIndex is stubbed as it is only used for js.ValueOf when the input is @@ -231,12 +229,12 @@ var ValueCall = spfunc.MustCallFromSP(true, &wasm.HostFunc{ }, }) -func valueCall(ctx context.Context, mod api.Module, params []uint64) []uint64 { - vRef := params[0] - mAddr := uint32(params[1]) - mLen := uint32(params[2]) - argsArray := uint32(params[3]) - argsLen := uint32(params[4]) +func valueCall(ctx context.Context, mod api.Module, stack []uint64) { + vRef := stack[0] + mAddr := uint32(stack[1]) + mLen := uint32(stack[2]) + argsArray := uint32(stack[3]) + argsLen := uint32(stack[4]) this := ref(vRef) v := loadValue(ctx, this) @@ -256,7 +254,7 @@ func valueCall(ctx context.Context, mod api.Module, params []uint64) []uint64 { } sp = refreshSP(mod) - return []uint64{xRef, uint64(ok), uint64(sp)} + stack[0], stack[1], stack[2] = xRef, uint64(ok), uint64(sp) } // ValueInvoke is stubbed as it isn't used in Go's main source tree. @@ -281,10 +279,10 @@ var ValueNew = spfunc.MustCallFromSP(true, &wasm.HostFunc{ }, }) -func valueNew(ctx context.Context, mod api.Module, params []uint64) []uint64 { - vRef := params[0] - argsArray := uint32(params[1]) - argsLen := uint32(params[2]) +func valueNew(ctx context.Context, mod api.Module, stack []uint64) { + vRef := stack[0] + argsArray := uint32(stack[1]) + argsLen := uint32(stack[2]) args := loadArgs(ctx, mod, argsArray, argsLen) ref := ref(vRef) @@ -328,7 +326,7 @@ func valueNew(ctx context.Context, mod api.Module, params []uint64) []uint64 { } sp = refreshSP(mod) - return []uint64{xRef, uint64(ok), uint64(sp)} + stack[0], stack[1], stack[2] = xRef, uint64(ok), uint64(sp) } // ValueLength implements js.valueLength, which is used to load the length @@ -348,13 +346,13 @@ var ValueLength = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func valueLength(ctx context.Context, params []uint64) []uint64 { - vRef := params[0] +func valueLength(ctx context.Context, stack []uint64) { + vRef := stack[0] v := loadValue(ctx, ref(vRef)) l := uint32(len(v.(*objectArray).slice)) - return []uint64{uint64(l)} + stack[0] = uint64(l) } // ValuePrepareString implements js.valuePrepareString, which is used to load @@ -376,8 +374,8 @@ var ValuePrepareString = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func valuePrepareString(ctx context.Context, params []uint64) []uint64 { - vRef := params[0] +func valuePrepareString(ctx context.Context, stack []uint64) { + vRef := stack[0] v := loadValue(ctx, ref(vRef)) s := valueString(v) @@ -385,7 +383,7 @@ func valuePrepareString(ctx context.Context, params []uint64) []uint64 { sRef := storeRef(ctx, s) sLen := uint32(len(s)) - return []uint64{sRef, uint64(sLen)} + stack[0], stack[1] = sRef, uint64(sLen) } // ValueLoadString implements js.valueLoadString, which is used copy a string @@ -405,16 +403,15 @@ var ValueLoadString = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func valueLoadString(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - vRef := params[0] - bAddr := uint32(params[1]) - bLen := uint32(params[2]) +func valueLoadString(ctx context.Context, mod api.Module, stack []uint64) { + vRef := stack[0] + bAddr := uint32(stack[1]) + bLen := uint32(stack[2]) v := loadValue(ctx, ref(vRef)) s := valueString(v) b := mustRead(ctx, mod.Memory(), "b", bAddr, bLen) copy(b, s) - return } // ValueInstanceOf is stubbed as it isn't used in Go's main source tree. @@ -444,11 +441,11 @@ var CopyBytesToGo = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func copyBytesToGo(ctx context.Context, mod api.Module, params []uint64) []uint64 { - dstAddr := uint32(params[0]) - dstLen := uint32(params[1]) - _ /* unknown */ = uint32(params[2]) - srcRef := params[3] +func copyBytesToGo(ctx context.Context, mod api.Module, stack []uint64) { + dstAddr := uint32(stack[0]) + dstLen := uint32(stack[1]) + _ /* unknown */ = uint32(stack[2]) + srcRef := stack[3] dst := mustRead(ctx, mod.Memory(), "dst", dstAddr, dstLen) // nolint v := loadValue(ctx, ref(srcRef)) @@ -459,7 +456,7 @@ func copyBytesToGo(ctx context.Context, mod api.Module, params []uint64) []uint6 ok = 1 } - return []uint64{uint64(n), uint64(ok)} + stack[0], stack[1] = uint64(n), uint64(ok) } // CopyBytesToJS copies linear memory to a JavaScript managed byte array. @@ -485,11 +482,11 @@ var CopyBytesToJS = spfunc.MustCallFromSP(false, &wasm.HostFunc{ }, }) -func copyBytesToJS(ctx context.Context, mod api.Module, params []uint64) []uint64 { - dstRef := params[0] - srcAddr := uint32(params[1]) - srcLen := uint32(params[2]) - _ /* unknown */ = uint32(params[3]) +func copyBytesToJS(ctx context.Context, mod api.Module, stack []uint64) { + dstRef := stack[0] + srcAddr := uint32(stack[1]) + srcLen := uint32(stack[2]) + _ /* unknown */ = uint32(stack[3]) src := mustRead(ctx, mod.Memory(), "src", srcAddr, srcLen) // nolint v := loadValue(ctx, ref(dstRef)) @@ -502,7 +499,7 @@ func copyBytesToJS(ctx context.Context, mod api.Module, params []uint64) []uint6 ok = 1 } - return []uint64{uint64(n), uint64(ok)} + stack[0], stack[1] = uint64(n), uint64(ok) } // refreshSP refreshes the stack pointer, which is needed prior to storeValue diff --git a/internal/integration_test/bench/hostfunc_bench_test.go b/internal/integration_test/bench/hostfunc_bench_test.go index cee9d982..b4b3bba4 100644 --- a/internal/integration_test/bench/hostfunc_bench_test.go +++ b/internal/integration_test/bench/hostfunc_bench_test.go @@ -148,12 +148,12 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { CodeSection: []*wasm.Code{ { IsHostFunction: true, - GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { - ret, ok := mod.Memory().ReadUint32Le(ctx, uint32(params[0])) + GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) { + ret, ok := mod.Memory().ReadUint32Le(ctx, uint32(stack[0])) if !ok { panic("couldn't read memory") } - return []uint64{uint64(ret)} + stack[0] = uint64(ret) }), }, wasm.MustParseGoReflectFuncCode( diff --git a/internal/integration_test/vs/runtime.go b/internal/integration_test/vs/runtime.go index 7ed957f5..04794964 100644 --- a/internal/integration_test/vs/runtime.go +++ b/internal/integration_test/vs/runtime.go @@ -79,8 +79,8 @@ func (m *wazeroModule) Memory() []byte { return m.mod.Memory().(*wasm.MemoryInstance).Buffer } -func (r *wazeroRuntime) log(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { - offset, byteCount := uint32(params[0]), uint32(params[1]) +func (r *wazeroRuntime) log(ctx context.Context, mod api.Module, stack []uint64) { + offset, byteCount := uint32(stack[0]), uint32(stack[1]) buf, ok := mod.Memory().Read(ctx, offset, byteCount) if !ok { @@ -89,8 +89,6 @@ func (r *wazeroRuntime) log(ctx context.Context, mod api.Module, params []uint64 if err := r.logFn(buf); err != nil { panic(err) } - - return } func (r *wazeroRuntime) Compile(ctx context.Context, cfg *RuntimeConfig) (err error) { @@ -107,8 +105,8 @@ func (r *wazeroRuntime) Compile(ctx context.Context, cfg *RuntimeConfig) (err er } else if cfg.EnvFReturnValue != 0 { if r.env, err = r.runtime.NewHostModuleBuilder("env"). NewFunctionBuilder(). - WithGoFunction(api.GoFunc(func(context.Context, []uint64) []uint64 { - return []uint64{cfg.EnvFReturnValue} + WithGoFunction(api.GoFunc(func(ctx context.Context, stack []uint64) { + stack[0] = cfg.EnvFReturnValue }), []api.ValueType{api.ValueTypeI64}, []api.ValueType{api.ValueTypeI64}). Export("f"). Compile(ctx); err != nil { diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 75a1d29d..6c1fcefd 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -519,6 +519,14 @@ func runTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester, hostDivBy *w results, err := ce.Call(testCtx, m, []uint64{1}) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) + + results2, err := ce.Call(testCtx, m, []uint64{1}) + require.NoError(t, err) + require.Equal(t, results, results2) + + // Ensure the result slices are unique + results[0] = 255 + require.Equal(t, uint64(1), results2[0]) }) } } diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index 51fa9f6f..c98ce1a0 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -37,8 +37,8 @@ type reflectGoModuleFunction struct { } // Call implements the same method as documented on api.GoModuleFunction. -func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, params []uint64) []uint64 { - return callGoFunc(ctx, mod, f.fn, params) +func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, stack []uint64) { + callGoFunc(ctx, mod, f.fn, stack) } // EqualTo is exposed for testing. @@ -72,11 +72,11 @@ func (f *reflectGoFunction) EqualTo(that interface{}) bool { } // Call implements the same method as documented on api.GoFunction. -func (f *reflectGoFunction) Call(ctx context.Context, params []uint64) []uint64 { +func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) { if f.pk == paramsKindNoContext { ctx = nil } - return callGoFunc(ctx, nil, f.fn, params) + callGoFunc(ctx, nil, f.fn, stack) } // PopValues pops the specified number of api.ValueType parameters off the @@ -101,12 +101,13 @@ func PopValues(count int, popper func() uint64) []uint64 { // callGoFunc executes the reflective function by converting params to Go // types. The results of the function call are converted back to api.ValueType. -func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params []uint64) []uint64 { +func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) { tp := fn.Type() var in []reflect.Value - if tp.NumIn() != 0 { - in = make([]reflect.Value, tp.NumIn()) + pLen := tp.NumIn() + if pLen != 0 { + in = make([]reflect.Value, pLen) i := 0 if ctx != nil { @@ -118,9 +119,13 @@ func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params [ i++ } - for _, raw := range params { - val := reflect.New(tp.In(i)).Elem() - k := tp.In(i).Kind() + for j := 0; i < pLen; i++ { + next := tp.In(i) + val := reflect.New(next).Elem() + k := next.Kind() + raw := stack[j] + j++ + switch k { case reflect.Float32: val.SetFloat(float64(math.Float32frombits(uint32(raw)))) @@ -134,30 +139,24 @@ func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params [ panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k)) } in[i] = val - i++ } } // Execute the host function and push back the call result onto the stack. - var results []uint64 - if tp.NumOut() > 0 { - results = make([]uint64, 0, tp.NumOut()) - } for i, ret := range fn.Call(in) { switch ret.Kind() { case reflect.Float32: - results = append(results, uint64(math.Float32bits(float32(ret.Float())))) + stack[i] = uint64(math.Float32bits(float32(ret.Float()))) case reflect.Float64: - results = append(results, math.Float64bits(ret.Float())) + stack[i] = math.Float64bits(ret.Float()) case reflect.Uint32, reflect.Uint64, reflect.Uintptr: - results = append(results, ret.Uint()) + stack[i] = ret.Uint() case reflect.Int32, reflect.Int64: - results = append(results, uint64(ret.Int())) + stack[i] = uint64(ret.Int()) default: panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind())) } } - return results } func newContextVal(ctx context.Context) reflect.Value { diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index 1acd60c0..07fc2e58 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -267,15 +267,27 @@ func Test_callGoFunc(t *testing.T) { _, _, code, err := parseGoReflectFunc(tc.input) require.NoError(t, err) - var results []uint64 + resultLen := len(tc.expectedResults) + stackLen := len(tc.inputParams) + if resultLen > stackLen { + stackLen = resultLen + } + stack := make([]uint64, stackLen) + copy(stack, tc.inputParams) + switch code.GoFunc.(type) { case api.GoFunction: - results = code.GoFunc.(api.GoFunction).Call(testCtx, tc.inputParams) + code.GoFunc.(api.GoFunction).Call(testCtx, stack) case api.GoModuleFunction: - results = code.GoFunc.(api.GoModuleFunction).Call(testCtx, callCtx, tc.inputParams) + code.GoFunc.(api.GoModuleFunction).Call(testCtx, callCtx, stack) default: t.Fatal("unexpected type.") } + + var results []uint64 + if resultLen > 0 { + results = stack[:resultLen] + } require.Equal(t, tc.expectedResults, results) }) }