diff --git a/api/wasm.go b/api/wasm.go index 2d812fac..b96da949 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "math" - "reflect" ) // ExternType classifies imports and exports with their respective types. @@ -95,9 +94,10 @@ const ( // (func (import "env" "f") (param externref) (result externref)) // // This can be defined in Go as: - // r.NewHostModuleBuilder("env").ExportFunctions(map[string]interface{}{ - // "f": func(externref uintptr) (resultExternRef uintptr) { return }, - // }) + // r.NewHostModuleBuilder("env"). + // NewFunctionBuilder(). + // WithFunc(func(context.Context, _ uintptr) (_ uintptr) { return }). + // Export("f") // // Note: The usage of this type is toggled with api.CoreFeatureBulkMemoryOperations. ValueTypeExternref ValueType = 0x6f @@ -251,16 +251,11 @@ type FunctionDefinition interface { // and not the imported function name. DebugName() string - // GoFunc is present when the function was implemented by the embedder - // (ex via wazero.HostModuleBuilder) instead of a wasm binary. + // GoFunction is non-nil when implemented by the embedder instead of a wasm + // binary, e.g. via wazero.HostModuleBuilder // - // This function can be non-deterministic or cause side effects. It also - // has special properties not defined in the WebAssembly Core - // specification. Notably, it uses the caller's memory, which might be - // different from its defining module. - // - // See https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 - GoFunc() *reflect.Value + // The expected results are nil, GoFunction or GoModuleFunction. + GoFunction() any // ParamTypes are the possibly empty sequence of value types accepted by a // function with this signature. @@ -301,6 +296,66 @@ type Function interface { Call(ctx context.Context, params ...uint64) ([]uint64, error) } +// GoModuleFunction is a Function implemented in Go instead of a wasm binary. +// The Module parameter is the calling module, used to access memory or +// exported functions. See GoModuleFunc for an example. +// +// This function can be non-deterministic or cause side effects. It also +// has special properties not defined in the WebAssembly Core specification. +// Notably, this uses the caller's memory (via Module.Memory). See +// https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 +// +// Most end users will not define functions directly with this, as they will +// use reflection or code generators instead. These approaches are more +// idiomatic as they can map go types to ValueType. This type is exposed for +// those willing to trade usability and safety for performance. +type GoModuleFunction interface { + Call(ctx context.Context, mod Module, params []uint64) []uint64 +} + +// GoModuleFunc is a convenience for defining an inlined function. +// +// For example, the following returns a uint32 value read from parameter zero: +// +// api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { +// ret, ok := mod.Memory().ReadUint32Le(ctx, uint32(params[0])) +// if !ok { +// panic("out of memory") +// } +// return []uint64{uint64(ret)} +// } +type GoModuleFunc func(ctx context.Context, mod Module, params []uint64) []uint64 + +// Call implements GoModuleFunction.Call. +func (f GoModuleFunc) Call(ctx context.Context, mod Module, params []uint64) []uint64 { + return f(ctx, mod, params) +} + +// GoFunction is an optimized form of GoModuleFunction which doesn't require +// the Module parameter. See GoFunc for an example. +// +// For example, this function does not need to use the importing module's +// memory or exported functions. +type GoFunction interface { + Call(ctx context.Context, params []uint64) []uint64 +} + +// GoFunc is a convenience for defining an inlined function. +// +// For example, the following returns the sum of two uint32 parameters: +// +// api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { +// x, y := uint32(params[0]), uint32(params[1]) +// sum := x + y +// return []uint64{sum} +// } +type GoFunc func(ctx context.Context, params []uint64) []uint64 + +// Call implements GoFunction.Call. +func (f GoFunc) Call(ctx context.Context, params []uint64) []uint64 { + return f(ctx, params) +} + // Global is a WebAssembly 1.0 (20191205) global exported from an instantiated module (wazero.Runtime InstantiateModule). // // For example, if the value is not mutable, you can read it once: diff --git a/builder.go b/builder.go index 3aadc49a..c8ac5e02 100644 --- a/builder.go +++ b/builder.go @@ -7,6 +7,125 @@ import ( "github.com/tetratelabs/wazero/internal/wasm" ) +// HostFunctionBuilder defines a host function (in Go), so that a +// WebAssembly binary (e.g. %.wasm file) can import and use it. +// +// Here's an example of an addition function: +// +// hostModuleBuilder.NewFunctionBuilder(). +// WithFunc(func(cxt context.Context, x, y uint32) uint32 { +// return x + y +// }). +// Export("add") +// +// # Memory +// +// All host functions act on the importing api.Module, including any memory +// exported in its binary (%.wasm file). If you are reading or writing memory, +// it is sand-boxed Wasm memory defined by the guest. +// +// Below, `m` is the importing module, defined in Wasm. `fn` is a host function +// added via Export. This means that `x` was read from memory defined in Wasm, +// not arbitrary memory in the process. +// +// fn := func(ctx context.Context, m api.Module, offset uint32) uint32 { +// x, _ := m.Memory().ReadUint32Le(ctx, offset) +// return x +// } +type HostFunctionBuilder interface { + // WithGoFunction is an advanced feature for those who need higher + // performance than WithFunc at the cost of more complexity. + // + // Here's an example addition function: + // + // builder.WithGoFunction(api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { + // x, y := uint32(params[0]), uint32(params[1]) + // sum := x + y + // return []uint64{sum} + // }, []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}) + // + // As you can see above, defining in this way implies knowledge of which + // WebAssembly api.ValueType is appropriate for each parameter and result. + // + // See WithGoModuleFunction if you also need to access the calling module. + WithGoFunction(fn api.GoFunction, params, results []api.ValueType) HostFunctionBuilder + + // WithGoModuleFunction is an advanced feature for those who need higher + // performance than WithFunc at the cost of more complexity. + // + // Here's an example addition function that loads operands from memory: + // + // builder.WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { + // mem := m.Memory() + // offset := uint32(params[0]) + // + // x, _ := mem.ReadUint32Le(ctx, offset) + // y, _ := mem.ReadUint32Le(ctx, offset + 4) // 32 bits == 4 bytes! + // sum := x + y + // + // return []uint64{sum} + // }, []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}) + // + // As you can see above, defining in this way implies knowledge of which + // WebAssembly api.ValueType is appropriate for each parameter and result. + // + // See WithGoFunction if you don't need access to the calling module. + WithGoModuleFunction(fn api.GoModuleFunction, params, results []api.ValueType) HostFunctionBuilder + + // WithFunc uses reflect.Value to map a go `func` to a WebAssembly + // compatible Signature. An input that isn't a `func` will fail to + // instantiate. + // + // Here's an example of an addition function: + // + // builder.WithFunc(func(cxt context.Context, x, y uint32) uint32 { + // return x + y + // }) + // + // # Defining a function + // + // Except for the context.Context and optional api.Module, all parameters + // or result types must map to WebAssembly numeric value types. This means + // uint32, int32, uint64, int32 float32 or float64. + // + // api.Module may be specified as the second parameter, usually to access + // memory. This is important because there are only numeric types in Wasm. + // The only way to share other data is via writing memory and sharing + // offsets. + // + // builder.WithFunc(func(ctx context.Context, m api.Module, offset uint32) uint32 { + // mem := m.Memory() + // x, _ := mem.ReadUint32Le(ctx, offset) + // y, _ := mem.ReadUint32Le(ctx, offset + 4) // 32 bits == 4 bytes! + // return x + y + // }) + // + // This example propagates context properly when calling other functions + // exported in the api.Module: + // + // builder.WithFunc(func(ctx context.Context, m api.Module, offset, byteCount uint32) uint32 { + // fn = m.ExportedFunction("__read") + // results, err := fn(ctx, offset, byteCount) + // --snip-- + WithFunc(interface{}) HostFunctionBuilder + + // WithName defines the optional module-local name of this function, e.g. + // "random_get" + // + // Note: This is not required to match the Export name. + WithName(name string) HostFunctionBuilder + + // WithParameterNames defines optional parameter names of the function + // signature, e.x. "buf", "buf_len" + // + // Note: When defined, names must be provided for all parameters. + WithParameterNames(names ...string) HostFunctionBuilder + + // Export exports this to the HostModuleBuilder as the given name, e.g. + // "random_get" + Export(name string) HostModuleBuilder +} + // HostModuleBuilder is a way to define host functions (in Go), so that a // WebAssembly binary (e.g. %.wasm file) can import and use them. // @@ -24,36 +143,20 @@ import ( // fmt.Fprintln(stdout, "hello!") // } // env, _ := r.NewHostModuleBuilder("env"). -// ExportFunction("hello", hello). +// NewFunctionBuilder().WithFunc(hello).Export("hello"). // Instantiate(ctx, r) // // If the same module may be instantiated multiple times, it is more efficient // to separate steps. Here's an example: // // compiled, _ := r.NewHostModuleBuilder("env"). -// ExportFunction("get_random_string", getRandomString). +// NewFunctionBuilder().WithFunc(getRandomString).Export("get_random_string"). // Compile(ctx) // // env1, _ := r.InstantiateModule(ctx, compiled, wazero.NewModuleConfig().WithName("env.1")) -// // env2, _ := r.InstantiateModule(ctx, compiled, wazero.NewModuleConfig().WithName("env.2")) // -// # Memory -// -// All host functions act on the importing api.Module, including any memory -// exported in its binary (%.wasm file). If you are reading or writing memory, -// it is sand-boxed Wasm memory defined by the guest. -// -// Below, `m` is the importing module, defined in Wasm. `fn` is a host function -// added via ExportFunction. This means that `x` was read from memory defined -// in Wasm, not arbitrary memory in the process. -// -// fn := func(ctx context.Context, m api.Module, offset uint32) uint32 { -// x, _ := m.Memory().ReadUint32Le(ctx, offset) -// return x -// } -// -// See ExportFunction for valid host function signatures and other details. +// See HostFunctionBuilder for valid host function signatures and other details. // // # Notes // @@ -66,71 +169,8 @@ import ( type HostModuleBuilder interface { // Note: until golang/go#5860, we can't use example tests to embed code in interface godocs. - // ExportFunction adds a function written in Go, which a WebAssembly module can import. - // If a function is already exported with the same name, this overwrites it. - // - // # Parameters - // - // - exportName - The name to export. Ex "random_get" - // - goFunc - The `func` to export. - // - names - If present, the first is the api.FunctionDefinition name. - // If any follow, they must match the count of goFunc's parameters. - // - // Here's an example: - // // Just export the function, and use "abort" in stack traces. - // builder.ExportFunction("abort", env.abort) - // // Ensure "~lib/builtins/abort" is used in stack traces. - // builder.ExportFunction("abort", env.abort, "~lib/builtins/abort") - // // Allow function listeners to know the param names for logging, etc. - // builder.ExportFunction("abort", env.abort, "~lib/builtins/abort", - // "message", "fileName", "lineNumber", "columnNumber") - // - // # Valid Signature - // - // Noting a context exception described later, all parameters or result - // types must match WebAssembly 1.0 (20191205) value types. This means - // uint32, uint64, float32 or float64. Up to one result can be returned. - // - // For example, this is a valid host function: - // - // addInts := func(x, y uint32) uint32 { - // return x + y - // } - // - // Host functions may also have an initial parameter (param[0]) of type - // context.Context or api.Module. - // - // For example, this uses a Go Context: - // - // addInts := func(ctx context.Context, x, y uint32) uint32 { - // // add a little extra if we put some in the context! - // return x + y + ctx.Value(extraKey).(uint32) - // } - // - // The example below uses an api.Module to read parameters from memory. - // This is important because there are only numeric types in Wasm. The - // only way to share other data is via writing memory and sharing offsets. - // - // addInts := func(ctx context.Context, m api.Module, offset uint32) uint32 { - // x, _ := m.Memory().ReadUint32Le(ctx, offset) - // y, _ := m.Memory().ReadUint32Le(ctx, offset + 4) // 32 bits == 4 bytes! - // return x + y - // } - // - // If both parameters exist, they must be in order at positions 0 and 1. - // - // This example propagates context properly when calling other functions - // exported in the api.Module: - // callRead := func(ctx context.Context, m api.Module, offset, byteCount uint32) uint32 { - // fn = m.ExportedFunction("__read") - // results, err := fn(ctx, offset, byteCount) - // --snip-- - // - // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 - ExportFunction(exportName string, goFunc interface{}, names ...string) HostModuleBuilder - - // ExportFunctions is a convenience that calls ExportFunction for each key/value in the provided map. - ExportFunctions(nameToGoFunc map[string]interface{}) HostModuleBuilder + // NewFunctionBuilder begins the definition of a host function. + NewFunctionBuilder() HostFunctionBuilder // Compile returns a CompiledModule that can instantiated in any namespace (Namespace). // @@ -150,7 +190,7 @@ type HostModuleBuilder interface { // fmt.Fprintln(stdout, "hello!") // } // env, _ := r.NewHostModuleBuilder("env"). - // ExportFunction("hello", hello). + // NewFunctionBuilder().WithFunc(hello).Export("hello"). // Instantiate(ctx, r) // // # Notes @@ -179,21 +219,84 @@ func (r *runtime) NewHostModuleBuilder(moduleName string) HostModuleBuilder { } } -// ExportFunction implements HostModuleBuilder.ExportFunction -func (b *hostModuleBuilder) ExportFunction(exportName string, goFunc interface{}, names ...string) HostModuleBuilder { - b.nameToGoFunc[exportName] = goFunc - if len(names) > 0 { - b.funcToNames[exportName] = names - } - return b +// hostFunctionBuilder implements HostFunctionBuilder +type hostFunctionBuilder struct { + b *hostModuleBuilder + fn interface{} + name string + paramNames []string } -// ExportFunctions implements HostModuleBuilder.ExportFunctions -func (b *hostModuleBuilder) ExportFunctions(nameToGoFunc map[string]interface{}) HostModuleBuilder { - for k, v := range nameToGoFunc { - b.ExportFunction(k, v) +// WithGoFunction implements HostFunctionBuilder.WithGoFunction +func (h *hostFunctionBuilder) WithGoFunction(fn api.GoFunction, params, results []api.ValueType) HostFunctionBuilder { + h.fn = &wasm.HostFunc{ + ParamTypes: params, + ResultTypes: results, + Code: &wasm.Code{IsHostFunction: true, GoFunc: fn}, } - return b + return h +} + +// WithGoModuleFunction implements HostFunctionBuilder.WithGoModuleFunction +func (h *hostFunctionBuilder) WithGoModuleFunction(fn api.GoModuleFunction, params, results []api.ValueType) HostFunctionBuilder { + h.fn = &wasm.HostFunc{ + ParamTypes: params, + ResultTypes: results, + Code: &wasm.Code{IsHostFunction: true, GoFunc: fn}, + } + return h +} + +// WithFunc implements HostFunctionBuilder.WithFunc +func (h *hostFunctionBuilder) WithFunc(fn interface{}) HostFunctionBuilder { + h.fn = fn + return h +} + +// WithName implements HostFunctionBuilder.WithName +func (h *hostFunctionBuilder) WithName(name string) HostFunctionBuilder { + h.name = name + return h +} + +// WithParameterNames implements HostFunctionBuilder.WithParameterNames +func (h *hostFunctionBuilder) WithParameterNames(names ...string) HostFunctionBuilder { + h.paramNames = names + return h +} + +// Export implements HostFunctionBuilder.Export +func (h *hostFunctionBuilder) Export(exportName string) HostModuleBuilder { + if h.name == "" { + h.name = exportName + } + if fn, ok := h.fn.(*wasm.HostFunc); ok { + if fn.Name == "" { + fn.Name = h.name + } + fn.ParamNames = h.paramNames + fn.ExportNames = []string{exportName} + } + h.b.nameToGoFunc[exportName] = h.fn + if len(h.paramNames) > 0 { + h.b.funcToNames[exportName] = append([]string{h.name}, h.paramNames...) + } + return h.b +} + +// ExportHostFunc implements wasm.HostFuncExporter +func (b *hostModuleBuilder) ExportHostFunc(fn *wasm.HostFunc) { + b.nameToGoFunc[fn.ExportNames[0]] = fn +} + +// ExportProxyFunc implements wasm.ProxyFuncExporter +func (b *hostModuleBuilder) ExportProxyFunc(fn *wasm.ProxyFunc) { + b.nameToGoFunc[fn.Name()] = fn +} + +// NewFunctionBuilder implements HostModuleBuilder.NewFunctionBuilder +func (b *hostModuleBuilder) NewFunctionBuilder() HostFunctionBuilder { + return &hostFunctionBuilder{b: b} } // Compile implements HostModuleBuilder.Compile diff --git a/builder_test.go b/builder_test.go index e90b66d5..f7daee2c 100644 --- a/builder_test.go +++ b/builder_test.go @@ -1,6 +1,7 @@ package wazero import ( + "context" "testing" "github.com/tetratelabs/wazero/api" @@ -12,13 +13,20 @@ import ( func TestNewHostModuleBuilder_Compile(t *testing.T) { i32, i64 := api.ValueTypeI32, api.ValueTypeI64 - uint32_uint32 := func(uint32) uint32 { + uint32_uint32 := func(context.Context, uint32) uint32 { return 0 } - uint64_uint32 := func(uint64) uint32 { + uint64_uint32 := func(context.Context, uint64) uint32 { return 0 } + gofunc1 := api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { + return []uint64{0} + }) + gofunc2 := api.GoFunc(func(ctx context.Context, params []uint64) []uint64 { + return []uint64{0} + }) + tests := []struct { name string input func(Runtime) HostModuleBuilder @@ -39,16 +47,17 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { expected: &wasm.Module{NameSection: &wasm.NameSection{ModuleName: "env"}}, }, { - name: "ExportFunction", + name: "WithFunc", input: func(r Runtime) HostModuleBuilder { - return r.NewHostModuleBuilder("").ExportFunction("1", uint32_uint32) + return r.NewHostModuleBuilder(""). + NewFunctionBuilder().WithFunc(uint32_uint32).Export("1") }, expected: &wasm.Module{ TypeSection: []*wasm.FunctionType{ {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32)}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(uint32_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -58,16 +67,19 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { }, }, { - name: "ExportFunction with names", + name: "WithFunc WithName WithParameterNames", input: func(r Runtime) HostModuleBuilder { - return r.NewHostModuleBuilder("").ExportFunction("1", uint32_uint32, "get", "x") + return r.NewHostModuleBuilder("").NewFunctionBuilder(). + WithFunc(uint32_uint32). + WithName("get").WithParameterNames("x"). + Export("1") }, expected: &wasm.Module{ TypeSection: []*wasm.FunctionType{ {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32)}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(uint32_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -78,16 +90,18 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { }, }, { - name: "ExportFunction overwrites existing", + name: "WithFunc overwrites existing", input: func(r Runtime) HostModuleBuilder { - return r.NewHostModuleBuilder("").ExportFunction("1", uint32_uint32).ExportFunction("1", uint64_uint32) + return r.NewHostModuleBuilder(""). + NewFunctionBuilder().WithFunc(uint32_uint32).Export("1"). + NewFunctionBuilder().WithFunc(uint64_uint32).Export("1") }, expected: &wasm.Module{ TypeSection: []*wasm.FunctionType{ {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint64_uint32)}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, }, @@ -97,10 +111,12 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { }, }, { - name: "ExportFunction twice", + name: "WithFunc twice", input: func(r Runtime) HostModuleBuilder { // Intentionally out of order - return r.NewHostModuleBuilder("").ExportFunction("2", uint64_uint32).ExportFunction("1", uint32_uint32) + return r.NewHostModuleBuilder(""). + NewFunctionBuilder().WithFunc(uint64_uint32).Export("2"). + NewFunctionBuilder().WithFunc(uint32_uint32).Export("1") }, expected: &wasm.Module{ TypeSection: []*wasm.FunctionType{ @@ -108,7 +124,7 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(uint32_uint32), wasm.MustParseGoReflectFuncCode(uint64_uint32)}, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -119,37 +135,92 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { }, }, { - name: "ExportFunctions", + name: "WithGoFunction", input: func(r Runtime) HostModuleBuilder { - return r.NewHostModuleBuilder("").ExportFunctions(map[string]interface{}{ - "1": uint32_uint32, - "2": uint64_uint32, - }) + return r.NewHostModuleBuilder(""). + NewFunctionBuilder(). + WithGoFunction(gofunc1, []api.ValueType{i32}, []api.ValueType{i32}). + Export("1") }, expected: &wasm.Module{ TypeSection: []*wasm.FunctionType{ {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, - {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, - FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{ + {IsHostFunction: true, GoFunc: gofunc1}, + }, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, - {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, }, NameSection: &wasm.NameSection{ - FunctionNames: wasm.NameMap{{Index: 0, Name: "1"}, {Index: 1, Name: "2"}}, + FunctionNames: wasm.NameMap{{Index: 0, Name: "1"}}, }, }, }, { - name: "ExportFunctions overwrites", + name: "WithGoFunction WithName WithParameterNames", input: func(r Runtime) HostModuleBuilder { - b := r.NewHostModuleBuilder("").ExportFunction("1", uint64_uint32) - return b.ExportFunctions(map[string]interface{}{ - "1": uint32_uint32, - "2": uint64_uint32, - }) + return r.NewHostModuleBuilder("").NewFunctionBuilder(). + WithGoFunction(gofunc1, []api.ValueType{i32}, []api.ValueType{i32}). + WithName("get").WithParameterNames("x"). + Export("1") + }, + expected: &wasm.Module{ + TypeSection: []*wasm.FunctionType{ + {Params: []api.ValueType{i32}, Results: []api.ValueType{i32}}, + }, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{ + {IsHostFunction: true, GoFunc: gofunc1}, + }, + ExportSection: []*wasm.Export{ + {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, + }, + NameSection: &wasm.NameSection{ + FunctionNames: wasm.NameMap{{Index: 0, Name: "get"}}, + LocalNames: []*wasm.NameMapAssoc{{Index: 0, NameMap: wasm.NameMap{{Index: 0, Name: "x"}}}}, + }, + }, + }, + { + name: "WithGoFunction overwrites existing", + input: func(r Runtime) HostModuleBuilder { + return r.NewHostModuleBuilder(""). + NewFunctionBuilder(). + WithGoFunction(gofunc1, []api.ValueType{i32}, []api.ValueType{i32}). + Export("1"). + NewFunctionBuilder(). + WithGoFunction(gofunc2, []api.ValueType{i64}, []api.ValueType{i32}). + Export("1") + }, + expected: &wasm.Module{ + TypeSection: []*wasm.FunctionType{ + {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, + }, + FunctionSection: []wasm.Index{0}, + CodeSection: []*wasm.Code{ + {IsHostFunction: true, GoFunc: gofunc2}, + }, + ExportSection: []*wasm.Export{ + {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, + }, + NameSection: &wasm.NameSection{ + FunctionNames: wasm.NameMap{{Index: 0, Name: "1"}}, + }, + }, + }, + { + name: "WithGoFunction twice", + input: func(r Runtime) HostModuleBuilder { + // Intentionally out of order + return r.NewHostModuleBuilder(""). + NewFunctionBuilder(). + WithGoFunction(gofunc2, []api.ValueType{i64}, []api.ValueType{i32}). + Export("2"). + NewFunctionBuilder(). + WithGoFunction(gofunc1, []api.ValueType{i32}, []api.ValueType{i32}). + Export("1") }, expected: &wasm.Module{ TypeSection: []*wasm.FunctionType{ @@ -157,7 +228,10 @@ func TestNewHostModuleBuilder_Compile(t *testing.T) { {Params: []api.ValueType{i64}, Results: []api.ValueType{i32}}, }, FunctionSection: []wasm.Index{0, 1}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(uint32_uint32), wasm.MustParseGoFuncCode(uint64_uint32)}, + CodeSection: []*wasm.Code{ + {IsHostFunction: true, GoFunc: gofunc1}, + {IsHostFunction: true, GoFunc: gofunc2}, + }, ExportSection: []*wasm.Export{ {Name: "1", Type: wasm.ExternTypeFunc, Index: 0}, {Name: "2", Type: wasm.ExternTypeFunc, Index: 1}, @@ -204,12 +278,12 @@ func TestNewHostModuleBuilder_Compile_Errors(t *testing.T) { { name: "error compiling", // should fail due to missing result. input: func(rt Runtime) HostModuleBuilder { - return rt.NewHostModuleBuilder(""). - ExportFunction("fn", &wasm.HostFunc{ + return rt.NewHostModuleBuilder("").NewFunctionBuilder(). + WithFunc(&wasm.HostFunc{ ExportNames: []string{"fn"}, ResultTypes: []wasm.ValueType{wasm.ValueTypeI32}, Code: &wasm.Code{IsHostFunction: true, Body: []byte{wasm.OpcodeEnd}}, - }) + }).Export("fn") }, expectedErr: `invalid function[0] export["fn"]: not enough results have () @@ -275,8 +349,7 @@ func requireHostModuleEquals(t *testing.T, expected, actual *wasm.Module) { for i, c := range expected.CodeSection { actualCode := actual.CodeSection[i] require.True(t, actualCode.IsHostFunction) - require.Equal(t, c.Kind, actualCode.Kind) - require.Equal(t, c.GoFunc.Type(), actualCode.GoFunc.Type()) + require.Equal(t, c.GoFunc, actualCode.GoFunc) // Not wasm require.Nil(t, actualCode.Body) diff --git a/examples/allocation/rust/greet.go b/examples/allocation/rust/greet.go index 1001a7ef..f84ab4be 100644 --- a/examples/allocation/rust/greet.go +++ b/examples/allocation/rust/greet.go @@ -31,7 +31,7 @@ func main() { // Instantiate a Go-defined module named "env" that exports a function to // log to the console. _, err := r.NewHostModuleBuilder("env"). - ExportFunction("log", logString). + NewFunctionBuilder().WithFunc(logString).Export("log"). Instantiate(ctx, r) if err != nil { log.Panicln(err) diff --git a/examples/allocation/tinygo/greet.go b/examples/allocation/tinygo/greet.go index 63056bed..c0962c91 100644 --- a/examples/allocation/tinygo/greet.go +++ b/examples/allocation/tinygo/greet.go @@ -32,7 +32,7 @@ func main() { // Instantiate a Go-defined module named "env" that exports a function to // log to the console. _, err := r.NewHostModuleBuilder("env"). - ExportFunction("log", logString). + NewFunctionBuilder().WithFunc(logString).Export("log"). Instantiate(ctx, r) if err != nil { log.Panicln(err) diff --git a/examples/allocation/zig/greet.go b/examples/allocation/zig/greet.go index 4015984e..49ffdbb7 100644 --- a/examples/allocation/zig/greet.go +++ b/examples/allocation/zig/greet.go @@ -37,7 +37,7 @@ func run() error { // Instantiate a Go-defined module named "env" that exports a function to // log to the console. _, err := r.NewHostModuleBuilder("env"). - ExportFunction("log", logString). + NewFunctionBuilder().WithFunc(logString).Export("log"). Instantiate(ctx, r) if err != nil { return err diff --git a/examples/import-go/age-calculator.go b/examples/import-go/age-calculator.go index f93a672c..6efc8690 100644 --- a/examples/import-go/age-calculator.go +++ b/examples/import-go/age-calculator.go @@ -34,20 +34,24 @@ func main() { // Instantiate a Go-defined module named "env" that exports functions to // get the current year and log to the console. // - // Note: As noted on ExportFunction documentation, function signatures are - // constrained to a subset of numeric types. + // Note: As noted on wazero.HostFunctionBuilder documentation, function + // signatures are constrained to a subset of numeric types. // Note: "env" is a module name conventionally used for arbitrary // host-defined functions, but any name would do. _, err := r.NewHostModuleBuilder("env"). - ExportFunction("log_i32", func(v uint32) { + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, v uint32) { fmt.Println("log_i32 >>", v) }). - ExportFunction("current_year", func() uint32 { + Export("log_i32"). + NewFunctionBuilder(). + WithFunc(func(context.Context) uint32 { if envYear, err := strconv.ParseUint(os.Getenv("CURRENT_YEAR"), 10, 64); err == nil { return uint32(envYear) // Allow env-override to prevent annual test maintenance! } return uint32(time.Now().Year()) }). + Export("current_year"). Instantiate(ctx, r) if err != nil { log.Panicln(err) diff --git a/examples/multiple-results/multiple-results.go b/examples/multiple-results/multiple-results.go index 253fbeef..341a6df9 100644 --- a/examples/multiple-results/multiple-results.go +++ b/examples/multiple-results/multiple-results.go @@ -111,11 +111,14 @@ func multiValueFromImportedHostWasmFunctions(ctx context.Context, r wazero.Runti // Instantiate the host module with the exported `get_age` function which returns multiple results. if _, err := r.NewHostModuleBuilder("multi-value/host"). // Define a function that returns two results - ExportFunction("get_age", func() (age uint64, errno uint32) { + NewFunctionBuilder(). + WithFunc(func(context.Context) (age uint64, errno uint32) { age = 37 errno = 0 return - }).Instantiate(ctx, r); err != nil { + }). + Export("get_age"). + Instantiate(ctx, r); err != nil { return nil, err } // Then, creates the module which imports the `get_age` function from the `multi-value/host` module above. diff --git a/examples/namespace/counter.go b/examples/namespace/counter.go index ad42d8f2..0956e17b 100644 --- a/examples/namespace/counter.go +++ b/examples/namespace/counter.go @@ -58,7 +58,7 @@ type counter struct { counter uint32 } -func (e *counter) getAndIncrement() (ret uint32) { +func (e *counter) getAndIncrement(context.Context) (ret uint32) { ret = e.counter e.counter++ return @@ -72,7 +72,7 @@ func instantiateWithEnv(ctx context.Context, r wazero.Runtime, module wazero.Com // Instantiate a new "env" module which exports a stateful function. c := &counter{} _, err := r.NewHostModuleBuilder("env"). - ExportFunction("next_i32", c.getAndIncrement). + NewFunctionBuilder().WithFunc(c.getAndIncrement).Export("next_i32"). Instantiate(ctx, ns) if err != nil { log.Panicln(err) diff --git a/experimental/listener_example_test.go b/experimental/listener_example_test.go index 90c5f654..0e3b4276 100644 --- a/experimental/listener_example_test.go +++ b/experimental/listener_example_test.go @@ -36,7 +36,7 @@ func (u uniqGoFuncs) callees() []string { // NewListener implements FunctionListenerFactory.NewListener func (u uniqGoFuncs) NewListener(def api.FunctionDefinition) FunctionListener { - if def.GoFunc() == nil { + if def.GoFunction() == nil { return nil // only track go funcs } return u diff --git a/experimental/logging/log_listener.go b/experimental/logging/log_listener.go index 831a1df6..7606ca58 100644 --- a/experimental/logging/log_listener.go +++ b/experimental/logging/log_listener.go @@ -64,14 +64,14 @@ func (l *loggingListener) writeIndented(before bool, err error, vals []uint64, i message.WriteByte('\t') } if before { - if l.fnd.GoFunc() != nil { + if l.fnd.GoFunction() != nil { message.WriteString("==> ") } else { message.WriteString("--> ") } l.writeFuncEnter(&message, vals) } else { // after - if l.fnd.GoFunc() != nil { + if l.fnd.GoFunction() != nil { message.WriteString("<== ") } else { message.WriteString("<-- ") diff --git a/experimental/logging/log_listener_test.go b/experimental/logging/log_listener_test.go index 4db93587..7941dda3 100644 --- a/experimental/logging/log_listener_test.go +++ b/experimental/logging/log_listener_test.go @@ -266,7 +266,7 @@ func Test_loggingListener(t *testing.T) { var out bytes.Buffer lf := logging.NewLoggingListenerFactory(&out) - fn := func() {} + fn := func(context.Context) {} for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { @@ -287,7 +287,7 @@ func Test_loggingListener(t *testing.T) { } if tc.isHostFunc { - m.CodeSection = []*wasm.Code{wasm.MustParseGoFuncCode(fn)} + m.CodeSection = []*wasm.Code{wasm.MustParseGoReflectFuncCode(fn)} } else { m.CodeSection = []*wasm.Code{{Body: []byte{wasm.OpcodeEnd}}} } diff --git a/imports/assemblyscript/assemblyscript.go b/imports/assemblyscript/assemblyscript.go index 493e4970..b3c49bfb 100644 --- a/imports/assemblyscript/assemblyscript.go +++ b/imports/assemblyscript/assemblyscript.go @@ -26,6 +26,7 @@ package assemblyscript import ( "context" + "encoding/binary" "fmt" "io" "strconv" @@ -34,12 +35,13 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" - "github.com/tetratelabs/wazero/internal/ieee754" "github.com/tetratelabs/wazero/internal/wasm" "github.com/tetratelabs/wazero/sys" ) const ( + i32, f64 = wasm.ValueTypeI32, wasm.ValueTypeF64 + functionAbort = "abort" functionTrace = "trace" functionSeed = "seed" @@ -119,9 +121,10 @@ func (e *functionExporter) WithTraceToStderr() FunctionExporter { // ExportFunctions implements FunctionExporter.ExportFunctions func (e *functionExporter) ExportFunctions(builder wazero.HostModuleBuilder) { - builder.ExportFunction(functionAbort, e.abortFn) - builder.ExportFunction(functionTrace, e.traceFn) - builder.ExportFunction(functionSeed, seed) + exporter := builder.(wasm.HostFuncExporter) + exporter.ExportHostFunc(e.abortFn) + exporter.ExportHostFunc(e.traceFn) + exporter.ExportHostFunc(seed) } // abort is called on unrecoverable errors. This is typically present in Wasm @@ -137,18 +140,25 @@ func (e *functionExporter) ExportFunctions(builder wazero.HostModuleBuilder) { // (import "env" "abort" (func $~lib/builtins/abort (param i32 i32 i32 i32))) // // See https://github.com/AssemblyScript/assemblyscript/blob/fa14b3b03bd4607efa52aaff3132bea0c03a7989/std/assembly/wasi/index.ts#L18 -var abortMessageEnabled = wasm.NewGoFunc( - "abort", "~lib/builtins/abort", - []string{"message", "fileName", "lineNumber", "columnNumber"}, - abortWithMessage, -) +var abortMessageEnabled = &wasm.HostFunc{ + ExportNames: []string{functionAbort}, + Name: "~lib/builtins/abort", + ParamTypes: []api.ValueType{i32, i32, i32, i32}, + ParamNames: []string{"message", "fileName", "lineNumber", "columnNumber"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(abortWithMessage), + }, +} -var abortMessageDisabled = abortMessageEnabled.MustGoFunc(abort) +var abortMessageDisabled = abortMessageEnabled.WithGoModuleFunc(abort) -// abortWithMessage implements fnAbort -func abortWithMessage( - ctx context.Context, mod api.Module, message, fileName, lineNumber, columnNumber uint32, -) { +// abortWithMessage implements functionAbort +func abortWithMessage(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + message := uint32(params[0]) + fileName := uint32(params[1]) + lineNumber := uint32(params[2]) + columnNumber := uint32(params[3]) sysCtx := mod.(*wasm.CallContext).Sys mem := mod.Memory() // Don't panic if there was a problem reading the message @@ -157,13 +167,12 @@ func abortWithMessage( _, _ = fmt.Fprintf(sysCtx.Stderr(), "%s at %s:%d:%d\n", msg, fn, lineNumber, columnNumber) } } - abort(ctx, mod, message, fileName, lineNumber, columnNumber) + abort(ctx, mod, params) + return } -// abortWithMessage implements fnAbort ignoring the message. -func abort( - ctx context.Context, mod api.Module, message, fileName, lineNumber, columnNumber uint32, -) { +// abortWithMessage implements functionAbort ignoring the message. +func abort(ctx context.Context, mod api.Module, _ []uint64) (_ []uint64) { // AssemblyScript expects the exit code to be 255 // See https://github.com/AssemblyScript/assemblyscript/blob/v0.20.13/tests/compiler/wasi/abort.js#L14 exitCode := uint32(255) @@ -179,20 +188,24 @@ func abort( var traceDisabled = traceStdout.WithWasm([]byte{wasm.OpcodeEnd}) // traceStdout implements trace to the configured Stdout. -var traceStdout = wasm.NewGoFunc(functionTrace, "~lib/builtins/trace", - []string{"message", "nArgs", "arg0", "arg1", "arg2", "arg3", "arg4"}, - func( - ctx context.Context, mod api.Module, message uint32, nArgs uint32, arg0, arg1, arg2, arg3, arg4 float64, - ) { - traceTo(ctx, mod, message, nArgs, arg0, arg1, arg2, arg3, arg4, mod.(*wasm.CallContext).Sys.Stdout()) +var traceStdout = &wasm.HostFunc{ + ExportNames: []string{functionTrace}, + Name: "~lib/builtins/trace", + ParamTypes: []api.ValueType{i32, i32, f64, f64, f64, f64, f64}, + ParamNames: []string{"message", "nArgs", "arg0", "arg1", "arg2", "arg3", "arg4"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + traceTo(ctx, mod, params, mod.(*wasm.CallContext).Sys.Stdout()) + return + }), }, -) +} // traceStderr implements trace to the configured Stderr. -var traceStderr = traceStdout.MustGoFunc(func( - ctx context.Context, mod api.Module, message uint32, nArgs uint32, arg0, arg1, arg2, arg3, arg4 float64, -) { - traceTo(ctx, mod, message, nArgs, arg0, arg1, arg2, arg3, arg4, mod.(*wasm.CallContext).Sys.Stderr()) +var traceStderr = traceStdout.WithGoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + traceTo(ctx, mod, params, mod.(*wasm.CallContext).Sys.Stderr()) + return }) // traceTo implements the function "trace" in AssemblyScript. e.g. @@ -205,10 +218,15 @@ var traceStderr = traceStdout.MustGoFunc(func( // (import "env" "trace" (func $~lib/builtins/trace (param i32 i32 f64 f64 f64 f64 f64))) // // See https://github.com/AssemblyScript/assemblyscript/blob/fa14b3b03bd4607efa52aaff3132bea0c03a7989/std/assembly/wasi/index.ts#L61 -func traceTo( - ctx context.Context, mod api.Module, message uint32, nArgs uint32, arg0, arg1, arg2, arg3, arg4 float64, - writer io.Writer, -) { +func traceTo(ctx context.Context, mod api.Module, params []uint64, writer io.Writer) { + message := uint32(params[0]) + nArgs := uint32(params[1]) + arg0 := api.DecodeF64(params[2]) + arg1 := api.DecodeF64(params[3]) + arg2 := api.DecodeF64(params[4]) + arg3 := api.DecodeF64(params[5]) + arg4 := api.DecodeF64(params[6]) + msg, ok := readAssemblyScriptString(ctx, mod.Memory(), message) if !ok { return // don't panic if unable to trace @@ -253,16 +271,25 @@ func formatFloat(f float64) string { // (import "env" "seed" (func $~lib/builtins/seed (result f64))) // // See https://github.com/AssemblyScript/assemblyscript/blob/fa14b3b03bd4607efa52aaff3132bea0c03a7989/std/assembly/wasi/index.ts#L111 -var seed = wasm.NewGoFunc(functionSeed, "~lib/builtins/seed", []string{}, - func(mod api.Module) float64 { - randSource := mod.(*wasm.CallContext).Sys.RandSource() - v, err := ieee754.DecodeFloat64(randSource) - if err != nil { - panic(fmt.Errorf("error reading random seed: %w", err)) - } - return v +var seed = &wasm.HostFunc{ + ExportNames: []string{functionSeed}, + Name: "~lib/builtins/seed", + ResultTypes: []api.ValueType{f64}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { + r := mod.(*wasm.CallContext).Sys.RandSource() + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + panic(fmt.Errorf("error reading random seed: %w", err)) + } + // the caller interprets this as a float64 + raw := binary.LittleEndian.Uint64(buf) + return []uint64{raw} + }), }, -) +} // readAssemblyScriptString reads a UTF-16 string created by AssemblyScript. func readAssemblyScriptString(ctx context.Context, mem api.Memory, offset uint32) (string, bool) { diff --git a/imports/assemblyscript/assemblyscript_example_test.go b/imports/assemblyscript/assemblyscript_example_test.go index 5c046123..4c2178a3 100644 --- a/imports/assemblyscript/assemblyscript_example_test.go +++ b/imports/assemblyscript/assemblyscript_example_test.go @@ -32,7 +32,9 @@ func Example_functionExporter() { // First construct your own module builder for "env" envBuilder := r.NewHostModuleBuilder("env"). - ExportFunction("get_int", func() uint32 { return 1 }) + NewFunctionBuilder(). + WithFunc(func(context.Context) uint32 { return 1 }). + Export("get_int") // Now, add AssemblyScript special function imports into it. assemblyscript.NewFunctionExporter(). diff --git a/imports/emscripten/emscripten.go b/imports/emscripten/emscripten.go index 38fb0ea0..0c1e8099 100644 --- a/imports/emscripten/emscripten.go +++ b/imports/emscripten/emscripten.go @@ -62,7 +62,8 @@ type functionExporter struct{} // ExportFunctions implements FunctionExporter.ExportFunctions func (e *functionExporter) ExportFunctions(builder wazero.HostModuleBuilder) { - builder.ExportFunction(notifyMemoryGrowth.Name, notifyMemoryGrowth) + exporter := builder.(wasm.HostFuncExporter) + exporter.ExportHostFunc(notifyMemoryGrowth) } // emscriptenNotifyMemoryGrowth is called when wasm is compiled with diff --git a/imports/emscripten/emscripten_example_test.go b/imports/emscripten/emscripten_example_test.go index 1baf3790..7c09cd4d 100644 --- a/imports/emscripten/emscripten_example_test.go +++ b/imports/emscripten/emscripten_example_test.go @@ -39,7 +39,9 @@ func Example_functionExporter() { // Next, construct your own module builder for "env" with any functions // you need. envBuilder := r.NewHostModuleBuilder("env"). - ExportFunction("get_int", func() uint32 { return 1 }) + NewFunctionBuilder(). + WithFunc(func(context.Context) uint32 { return 1 }). + Export("get_int") // Now, add Emscripten special function imports into it. emscripten.NewFunctionExporter().ExportFunctions(envBuilder) diff --git a/imports/go/gojs.go b/imports/go/gojs.go index 1b99b969..0e09ca52 100644 --- a/imports/go/gojs.go +++ b/imports/go/gojs.go @@ -15,6 +15,7 @@ import ( "github.com/tetratelabs/wazero" . "github.com/tetratelabs/wazero/internal/gojs" + "github.com/tetratelabs/wazero/internal/wasm" ) // WithRoundTripper sets the http.RoundTripper used to Run Wasm. @@ -87,31 +88,35 @@ func Run(ctx context.Context, r wazero.Runtime, compiled wazero.CompiledModule, } // hostModuleBuilder returns a new wazero.HostModuleBuilder -func hostModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder { - return r.NewHostModuleBuilder("go"). - ExportFunction(GetRandomData.Name(), GetRandomData). - ExportFunction(Nanotime1.Name(), Nanotime1). - ExportFunction(WasmExit.Name(), WasmExit). - ExportFunction(CopyBytesToJS.Name(), CopyBytesToJS). - ExportFunction(ValueCall.Name(), ValueCall). - ExportFunction(ValueGet.Name(), ValueGet). - ExportFunction(ValueIndex.Name(), ValueIndex). - ExportFunction(ValueLength.Name(), ValueLength). - ExportFunction(ValueNew.Name(), ValueNew). - ExportFunction(ValueSet.Name(), ValueSet). - ExportFunction(WasmWrite.Name(), WasmWrite). - ExportFunction(ResetMemoryDataView.Name, ResetMemoryDataView). - ExportFunction(Walltime.Name(), Walltime). - ExportFunction(ScheduleTimeoutEvent.Name, ScheduleTimeoutEvent). - ExportFunction(ClearTimeoutEvent.Name, ClearTimeoutEvent). - ExportFunction(FinalizeRef.Name(), FinalizeRef). - ExportFunction(StringVal.Name(), StringVal). - ExportFunction(ValueDelete.Name, ValueDelete). - ExportFunction(ValueSetIndex.Name, ValueSetIndex). - ExportFunction(ValueInvoke.Name, ValueInvoke). - ExportFunction(ValuePrepareString.Name(), ValuePrepareString). - ExportFunction(ValueInstanceOf.Name, ValueInstanceOf). - ExportFunction(ValueLoadString.Name(), ValueLoadString). - ExportFunction(CopyBytesToGo.Name(), CopyBytesToGo). - ExportFunction(Debug.Name, Debug) +func hostModuleBuilder(r wazero.Runtime) (builder wazero.HostModuleBuilder) { + builder = r.NewHostModuleBuilder("go") + hfExporter := builder.(wasm.HostFuncExporter) + pfExporter := builder.(wasm.ProxyFuncExporter) + + pfExporter.ExportProxyFunc(GetRandomData) + pfExporter.ExportProxyFunc(Nanotime1) + pfExporter.ExportProxyFunc(WasmExit) + pfExporter.ExportProxyFunc(CopyBytesToJS) + pfExporter.ExportProxyFunc(ValueCall) + pfExporter.ExportProxyFunc(ValueGet) + pfExporter.ExportProxyFunc(ValueIndex) + pfExporter.ExportProxyFunc(ValueLength) + pfExporter.ExportProxyFunc(ValueNew) + pfExporter.ExportProxyFunc(ValueSet) + pfExporter.ExportProxyFunc(WasmWrite) + hfExporter.ExportHostFunc(ResetMemoryDataView) + pfExporter.ExportProxyFunc(Walltime) + hfExporter.ExportHostFunc(ScheduleTimeoutEvent) + hfExporter.ExportHostFunc(ClearTimeoutEvent) + pfExporter.ExportProxyFunc(FinalizeRef) + pfExporter.ExportProxyFunc(StringVal) + hfExporter.ExportHostFunc(ValueDelete) + hfExporter.ExportHostFunc(ValueSetIndex) + hfExporter.ExportHostFunc(ValueInvoke) + pfExporter.ExportProxyFunc(ValuePrepareString) + hfExporter.ExportHostFunc(ValueInstanceOf) + pfExporter.ExportProxyFunc(ValueLoadString) + pfExporter.ExportProxyFunc(CopyBytesToGo) + hfExporter.ExportHostFunc(Debug) + return } diff --git a/imports/wasi_snapshot_preview1/args.go b/imports/wasi_snapshot_preview1/args.go index 8d80774a..33ecf0d3 100644 --- a/imports/wasi_snapshot_preview1/args.go +++ b/imports/wasi_snapshot_preview1/args.go @@ -21,23 +21,21 @@ const ( // encoding to api.Memory // - argsSizesGet result argc * 4 bytes are written to this offset // - argvBuf: offset to write the null terminated arguments to api.Memory -// - argsSizesGet result argv_buf_size bytes are written to this offset +// - argsSizesGet result argv_len bytes are written to this offset // // Result (Errno) // // The return value is ErrnoSuccess except the following error conditions: // - ErrnoFault: there is not enough memory to write results // -// For example, if argsSizesGet wrote argc=2 and argvBufSize=5 for arguments: +// For example, if argsSizesGet wrote argc=2 and argvLen=5 for arguments: // "a" and "bc" parameters argv=7 and argvBuf=1, this function writes the below // to api.Memory: // -// argvBufSize uint32le uint32le -// +----------------+ +--------+ +--------+ -// | | | | | | -// -// []byte{?, 'a', 0, 'b', 'c', 0, ?, 1, 0, 0, 0, 3, 0, 0, 0, ?} -// +// argvLen uint32le uint32le +// +----------------+ +--------+ +--------+ +// | | | | | | +// []byte{?, 'a', 0, 'b', 'c', 0, ?, 1, 0, 0, 0, 3, 0, 0, 0, ?} // argvBuf --^ ^ ^ // argv --| | // offset that begins "a" --+ | @@ -46,14 +44,23 @@ const ( // See argsSizesGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#args_get // See https://en.wikipedia.org/wiki/Null-terminated_string -var argsGet = wasm.NewGoFunc( - functionArgsGet, functionArgsGet, - []string{"argv", "argv_buf"}, - func(ctx context.Context, mod api.Module, argv, argvBuf uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - return writeOffsetsAndNullTerminatedValues(ctx, mod.Memory(), sysCtx.Args(), argv, argvBuf) +var argsGet = &wasm.HostFunc{ + ExportNames: []string{functionArgsGet}, + Name: functionArgsGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"argv", "argv_buf"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(argsGetFn), }, -) +} + +func argsGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + argv, argvBuf := uint32(params[0]), uint32(params[1]) + return writeOffsetsAndNullTerminatedValues(ctx, mod.Memory(), sysCtx.Args(), argv, argvBuf) +} // argsSizesGet is the WASI function named functionArgsSizesGet that reads // command-line argument sizes. @@ -61,7 +68,7 @@ var argsGet = wasm.NewGoFunc( // # Parameters // // - resultArgc: offset to write the argument count to api.Memory -// - resultArgvBufSize: offset to write the null-terminated argument length to +// - resultArgvLen: offset to write the null-terminated argument length to // api.Memory // // Result (Errno) @@ -70,7 +77,7 @@ var argsGet = wasm.NewGoFunc( // - ErrnoFault: there is not enough memory to write results // // For example, if args are "a", "bc" and parameters resultArgc=1 and -// resultArgvBufSize=6, this function writes the below to api.Memory: +// resultArgvLen=6, this function writes the below to api.Memory: // // uint32le uint32le // +--------+ +--------+ @@ -78,25 +85,34 @@ var argsGet = wasm.NewGoFunc( // []byte{?, 2, 0, 0, 0, ?, 5, 0, 0, 0, ?} // resultArgc --^ ^ // 2 args --+ | -// resultArgvBufSize --| +// resultArgvLen --| // len([]byte{'a',0,'b',c',0}) --+ // // See argsGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#args_sizes_get // See https://en.wikipedia.org/wiki/Null-terminated_string -var argsSizesGet = wasm.NewGoFunc( - functionArgsSizesGet, functionArgsSizesGet, - []string{"result.argc", "result.argv_buf_size"}, - func(ctx context.Context, mod api.Module, resultArgc, resultArgvBufSize uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - mem := mod.Memory() - - if !mem.WriteUint32Le(ctx, resultArgc, uint32(len(sysCtx.Args()))) { - return ErrnoFault - } - if !mem.WriteUint32Le(ctx, resultArgvBufSize, sysCtx.ArgsSize()) { - return ErrnoFault - } - return ErrnoSuccess +var argsSizesGet = &wasm.HostFunc{ + ExportNames: []string{functionArgsSizesGet}, + Name: functionArgsSizesGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"result.argc", "result.argv_len"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(argsSizesGetFn), }, -) +} + +func argsSizesGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + mem := mod.Memory() + resultArgc, resultArgvLen := uint32(params[0]), uint32(params[1]) + + if !mem.WriteUint32Le(ctx, resultArgc, uint32(len(sysCtx.Args()))) { + return errnoFault + } + if !mem.WriteUint32Le(ctx, resultArgvLen, sysCtx.ArgsSize()) { + return errnoFault + } + return errnoSuccess +} diff --git a/imports/wasi_snapshot_preview1/args_test.go b/imports/wasi_snapshot_preview1/args_test.go index d847bf80..33a3bd1f 100644 --- a/imports/wasi_snapshot_preview1/args_test.go +++ b/imports/wasi_snapshot_preview1/args_test.go @@ -114,12 +114,12 @@ func Test_argsSizesGet(t *testing.T) { mod, r, log := requireProxyModule(t, wazero.NewModuleConfig().WithArgs("a", "bc")) defer r.Close(testCtx) - resultArgc := uint32(1) // arbitrary offset - resultArgvBufSize := uint32(6) // arbitrary offset + resultArgc := uint32(1) // arbitrary offset + resultArgvLen := uint32(6) // arbitrary offset expectedMemory := []byte{ '?', // resultArgc is after this 0x2, 0x0, 0x0, 0x0, // little endian-encoded arg count - '?', // resultArgvBufSize is after this + '?', // resultArgvLen is after this 0x5, 0x0, 0x0, 0x0, // little endian-encoded size of null terminated strings '?', // stopped after encoding } @@ -127,10 +127,10 @@ func Test_argsSizesGet(t *testing.T) { maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke argsSizesGet and check the memory side effects. - requireErrno(t, ErrnoSuccess, mod, functionArgsSizesGet, uint64(resultArgc), uint64(resultArgvBufSize)) + requireErrno(t, ErrnoSuccess, mod, functionArgsSizesGet, uint64(resultArgc), uint64(resultArgvLen)) require.Equal(t, ` ---> proxy.args_sizes_get(result.argc=1,result.argv_buf_size=6) - ==> wasi_snapshot_preview1.args_sizes_get(result.argc=1,result.argv_buf_size=6) +--> proxy.args_sizes_get(result.argc=1,result.argv_len=6) + ==> wasi_snapshot_preview1.args_sizes_get(result.argc=1,result.argv_len=6) <== ESUCCESS <-- (0) `, "\n"+log.String()) @@ -148,50 +148,50 @@ func Test_argsSizesGet_Errors(t *testing.T) { validAddress := uint32(0) // arbitrary valid address as arguments to args_sizes_get. We chose 0 here. tests := []struct { - name string - argc, argvBufSize uint32 - expectedLog string + name string + argc, argvLen uint32 + expectedLog string }{ { - name: "out-of-memory argc", - argc: memorySize, - argvBufSize: validAddress, + name: "out-of-memory argc", + argc: memorySize, + argvLen: validAddress, expectedLog: ` ---> proxy.args_sizes_get(result.argc=65536,result.argv_buf_size=0) - ==> wasi_snapshot_preview1.args_sizes_get(result.argc=65536,result.argv_buf_size=0) +--> proxy.args_sizes_get(result.argc=65536,result.argv_len=0) + ==> wasi_snapshot_preview1.args_sizes_get(result.argc=65536,result.argv_len=0) <== EFAULT <-- (21) `, }, { - name: "out-of-memory argvBufSize", - argc: validAddress, - argvBufSize: memorySize, + name: "out-of-memory argvLen", + argc: validAddress, + argvLen: memorySize, expectedLog: ` ---> proxy.args_sizes_get(result.argc=0,result.argv_buf_size=65536) - ==> wasi_snapshot_preview1.args_sizes_get(result.argc=0,result.argv_buf_size=65536) +--> proxy.args_sizes_get(result.argc=0,result.argv_len=65536) + ==> wasi_snapshot_preview1.args_sizes_get(result.argc=0,result.argv_len=65536) <== EFAULT <-- (21) `, }, { - name: "argc exceeds the maximum valid address by 1", - argc: memorySize - 4 + 1, // 4 is the size of uint32, the type of the count of args - argvBufSize: validAddress, + name: "argc exceeds the maximum valid address by 1", + argc: memorySize - 4 + 1, // 4 is the size of uint32, the type of the count of args + argvLen: validAddress, expectedLog: ` ---> proxy.args_sizes_get(result.argc=65533,result.argv_buf_size=0) - ==> wasi_snapshot_preview1.args_sizes_get(result.argc=65533,result.argv_buf_size=0) +--> proxy.args_sizes_get(result.argc=65533,result.argv_len=0) + ==> wasi_snapshot_preview1.args_sizes_get(result.argc=65533,result.argv_len=0) <== EFAULT <-- (21) `, }, { - name: "argvBufSize exceeds the maximum valid size by 1", - argc: validAddress, - argvBufSize: memorySize - 4 + 1, // 4 is count of bytes to encode uint32le + name: "argvLen exceeds the maximum valid size by 1", + argc: validAddress, + argvLen: memorySize - 4 + 1, // 4 is count of bytes to encode uint32le expectedLog: ` ---> proxy.args_sizes_get(result.argc=0,result.argv_buf_size=65533) - ==> wasi_snapshot_preview1.args_sizes_get(result.argc=0,result.argv_buf_size=65533) +--> proxy.args_sizes_get(result.argc=0,result.argv_len=65533) + ==> wasi_snapshot_preview1.args_sizes_get(result.argc=0,result.argv_len=65533) <== EFAULT <-- (21) `, @@ -204,7 +204,7 @@ func Test_argsSizesGet_Errors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { defer log.Reset() - requireErrno(t, ErrnoFault, mod, functionArgsSizesGet, uint64(tc.argc), uint64(tc.argvBufSize)) + requireErrno(t, ErrnoFault, mod, functionArgsSizesGet, uint64(tc.argc), uint64(tc.argvLen)) require.Equal(t, tc.expectedLog, "\n"+log.String()) }) } diff --git a/imports/wasi_snapshot_preview1/clock.go b/imports/wasi_snapshot_preview1/clock.go index 22d5500f..30bea72e 100644 --- a/imports/wasi_snapshot_preview1/clock.go +++ b/imports/wasi_snapshot_preview1/clock.go @@ -42,36 +42,45 @@ const ( // For example, if the resolution is 100ns, this function writes the below to // api.Memory: // -// uint64le -// +-------------------------------------+ -// | | -// []byte{?, 0x64, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, ?} +// uint64le +// +-------------------------------------+ +// | | +// []byte{?, 0x64, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, ?} // resultResolution --^ // // Note: This is similar to `clock_getres` in POSIX. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-clock_res_getid-clockid---errno-timestamp // See https://linux.die.net/man/3/clock_getres -var clockResGet = wasm.NewGoFunc( - functionClockResGet, functionClockResGet, - []string{"id", "result.resolution"}, - func(ctx context.Context, mod api.Module, id uint32, resultResolution uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - - var resolution uint64 // ns - switch id { - case clockIDRealtime: - resolution = uint64(sysCtx.WalltimeResolution()) - case clockIDMonotonic: - resolution = uint64(sysCtx.NanotimeResolution()) - default: - return ErrnoInval - } - if !mod.Memory().WriteUint64Le(ctx, resultResolution, resolution) { - return ErrnoFault - } - return ErrnoSuccess +var clockResGet = &wasm.HostFunc{ + ExportNames: []string{functionClockResGet}, + Name: functionClockResGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"id", "result.resolution"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(clockResGetFn), }, -) +} + +func clockResGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + id, resultResolution := uint32(params[0]), uint32(params[1]) + + var resolution uint64 // ns + switch id { + case clockIDRealtime: + resolution = uint64(sysCtx.WalltimeResolution()) + case clockIDMonotonic: + resolution = uint64(sysCtx.NanotimeResolution()) + default: + return errnoInval + } + if !mod.Memory().WriteUint64Le(ctx, resultResolution, resolution) { + return errnoFault + } + return errnoSuccess +} // clockTimeGet is the WASI function named functionClockTimeGet that returns // the time value of a name (time.Now). @@ -104,27 +113,38 @@ var clockResGet = wasm.NewGoFunc( // Note: This is similar to `clock_gettime` in POSIX. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-clock_time_getid-clockid-precision-timestamp---errno-timestamp // See https://linux.die.net/man/3/clock_gettime -var clockTimeGet = wasm.NewGoFunc( - functionClockTimeGet, functionClockTimeGet, - []string{"id", "precision", "result.timestamp"}, - func(ctx context.Context, mod api.Module, id uint32, precision uint64, resultTimestamp uint32) Errno { - // TODO: precision is currently ignored. - sysCtx := mod.(*wasm.CallContext).Sys - - var val uint64 - switch id { - case clockIDRealtime: - sec, nsec := sysCtx.Walltime(ctx) - val = (uint64(sec) * uint64(time.Second.Nanoseconds())) + uint64(nsec) - case clockIDMonotonic: - val = uint64(sysCtx.Nanotime(ctx)) - default: - return ErrnoInval - } - - if !mod.Memory().WriteUint64Le(ctx, resultTimestamp, val) { - return ErrnoFault - } - return ErrnoSuccess +var clockTimeGet = &wasm.HostFunc{ + ExportNames: []string{functionClockTimeGet}, + Name: functionClockTimeGet, + ParamTypes: []api.ValueType{i32, i64, i32}, + ParamNames: []string{"id", "precision", "result.timestamp"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(clockTimeGetFn), }, -) +} + +func clockTimeGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + id := uint32(params[0]) + // TODO: precision is currently ignored. + _ = params[1] + resultTimestamp := uint32(params[2]) + + var val uint64 + switch id { + case clockIDRealtime: + sec, nsec := sysCtx.Walltime(ctx) + val = (uint64(sec) * uint64(time.Second.Nanoseconds())) + uint64(nsec) + case clockIDMonotonic: + val = uint64(sysCtx.Nanotime(ctx)) + default: + return errnoInval + } + + if !mod.Memory().WriteUint64Le(ctx, resultTimestamp, val) { + return errnoFault + } + return errnoSuccess +} diff --git a/imports/wasi_snapshot_preview1/clock_test.go b/imports/wasi_snapshot_preview1/clock_test.go index 1021eefb..420c66f3 100644 --- a/imports/wasi_snapshot_preview1/clock_test.go +++ b/imports/wasi_snapshot_preview1/clock_test.go @@ -257,9 +257,9 @@ func Test_clockTimeGet_Errors(t *testing.T) { memorySize := mod.Memory().Size(testCtx) tests := []struct { - name string - resultTimestamp, argvBufSize uint32 - expectedLog string + name string + resultTimestamp, argvLen uint32 + expectedLog string }{ { name: "resultTimestamp out-of-memory", diff --git a/imports/wasi_snapshot_preview1/environ.go b/imports/wasi_snapshot_preview1/environ.go index 52404924..5516757e 100644 --- a/imports/wasi_snapshot_preview1/environ.go +++ b/imports/wasi_snapshot_preview1/environ.go @@ -22,38 +22,45 @@ const ( // - environSizesGet result environc * 4 bytes are written to this offset // - environBuf: offset to write the null-terminated variables to api.Memory // - the format is like os.Environ: null-terminated "key=val" entries -// - environSizesGet result environBufSize bytes are written to this offset +// - environSizesGet result environLen bytes are written to this offset // // Result (Errno) // // The return value is ErrnoSuccess except the following error conditions: // - ErrnoFault: there is not enough memory to write results // -// For example, if environSizesGet wrote environc=2 and environBufSize=9 for +// For example, if environSizesGet wrote environc=2 and environLen=9 for // environment variables: "a=b", "b=cd" and parameters environ=11 and // environBuf=1, this function writes the below to api.Memory: // -// environBufSize uint32le uint32le -// +------------------------------------+ +--------+ +--------+ -// | | | | | | -// []byte{?, 'a', '=', 'b', 0, 'b', '=', 'c', 'd', 0, ?, 1, 0, 0, 0, 5, 0, 0, 0, ?} -// -// environBuf --^ ^ ^ -// -// environ offset for "a=b" --+ | -// environ offset for "b=cd" --+ +// environLen uint32le uint32le +// +------------------------------------+ +--------+ +--------+ +// | | | | | | +// []byte{?, 'a', '=', 'b', 0, 'b', '=', 'c', 'd', 0, ?, 1, 0, 0, 0, 5, 0, 0, 0, ?} +// environBuf --^ ^ ^ +// environ offset for "a=b" --+ | +// environ offset for "b=cd" --+ // // See environSizesGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#environ_get // See https://en.wikipedia.org/wiki/Null-terminated_string -var environGet = wasm.NewGoFunc( - functionEnvironGet, functionEnvironGet, - []string{"environ", "environ_buf"}, - func(ctx context.Context, mod api.Module, environ uint32, environBuf uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - return writeOffsetsAndNullTerminatedValues(ctx, mod.Memory(), sysCtx.Environ(), environ, environBuf) +var environGet = &wasm.HostFunc{ + ExportNames: []string{functionEnvironGet}, + Name: functionEnvironGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"environ", "environ_buf"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(environGetFn), }, -) +} + +func environGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + environ, environBuf := uint32(params[0]), uint32(params[1]) + return writeOffsetsAndNullTerminatedValues(ctx, mod.Memory(), sysCtx.Environ(), environ, environBuf) +} // environSizesGet is the WASI function named functionEnvironSizesGet that // reads environment variable sizes. @@ -62,7 +69,7 @@ var environGet = wasm.NewGoFunc( // // - resultEnvironc: offset to write the count of environment variables to // api.Memory -// - resultEnvironBufSize: offset to write the null-terminated environment +// - resultEnvironvLen: offset to write the null-terminated environment // variable length to api.Memory // // Result (Errno) @@ -71,37 +78,43 @@ var environGet = wasm.NewGoFunc( // - ErrnoFault: there is not enough memory to write results // // For example, if environ are "a=b","b=cd" and parameters resultEnvironc=1 and -// resultEnvironBufSize=6, this function writes the below to api.Memory: +// resultEnvironvLen=6, this function writes the below to api.Memory: // -// uint32le uint32le -// +--------+ +--------+ -// | | | | -// []byte{?, 2, 0, 0, 0, ?, 9, 0, 0, 0, ?} -// -// resultEnvironc --^ ^ -// -// 2 variables --+ | -// resultEnvironBufSize --| -// len([]byte{'a','=','b',0, | -// 'b','=','c','d',0}) --+ +// uint32le uint32le +// +--------+ +--------+ +// | | | | +// []byte{?, 2, 0, 0, 0, ?, 9, 0, 0, 0, ?} +// resultEnvironc --^ ^ +// 2 variables --+ | +// resultEnvironvLen --| +// len([]byte{'a','=','b',0, | +// 'b','=','c','d',0}) --+ // // See environGet // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#environ_sizes_get // and https://en.wikipedia.org/wiki/Null-terminated_string -var environSizesGet = wasm.NewGoFunc( - functionEnvironSizesGet, functionEnvironSizesGet, - []string{"result.environc", "result.environBufSize"}, - func(ctx context.Context, mod api.Module, resultEnvironc uint32, resultEnvironBufSize uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - mem := mod.Memory() - - if !mem.WriteUint32Le(ctx, resultEnvironc, uint32(len(sysCtx.Environ()))) { - return ErrnoFault - } - if !mem.WriteUint32Le(ctx, resultEnvironBufSize, sysCtx.EnvironSize()) { - return ErrnoFault - } - - return ErrnoSuccess +var environSizesGet = &wasm.HostFunc{ + ExportNames: []string{functionEnvironSizesGet}, + Name: functionEnvironSizesGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"result.environc", "result.environv_len"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(environSizesGetFn), }, -) +} + +func environSizesGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + mem := mod.Memory() + resultEnvironc, resultEnvironvLen := uint32(params[0]), uint32(params[1]) + + if !mem.WriteUint32Le(ctx, resultEnvironc, uint32(len(sysCtx.Environ()))) { + return errnoFault + } + if !mem.WriteUint32Le(ctx, resultEnvironvLen, sysCtx.EnvironSize()) { + return errnoFault + } + return errnoSuccess +} diff --git a/imports/wasi_snapshot_preview1/environ_test.go b/imports/wasi_snapshot_preview1/environ_test.go index 1980022f..da7ec1da 100644 --- a/imports/wasi_snapshot_preview1/environ_test.go +++ b/imports/wasi_snapshot_preview1/environ_test.go @@ -54,7 +54,7 @@ func Test_environGet_Errors(t *testing.T) { expectedLog string }{ { - name: "out-of-memory environPtr", + name: "out-of-memory environ", environ: memorySize, environBuf: validAddress, expectedLog: ` @@ -65,7 +65,7 @@ func Test_environGet_Errors(t *testing.T) { `, }, { - name: "out-of-memory environBufPtr", + name: "out-of-memory environBuf", environ: validAddress, environBuf: memorySize, expectedLog: ` @@ -76,8 +76,8 @@ func Test_environGet_Errors(t *testing.T) { `, }, { - name: "environPtr exceeds the maximum valid address by 1", - // 4*envCount is the expected length for environPtr, 4 is the size of uint32 + name: "environ exceeds the maximum valid address by 1", + // 4*envCount is the expected length for environ, 4 is the size of uint32 environ: memorySize - 4*2 + 1, environBuf: validAddress, expectedLog: ` @@ -88,7 +88,7 @@ func Test_environGet_Errors(t *testing.T) { `, }, { - name: "environBufPtr exceeds the maximum valid address by 1", + name: "environBuf exceeds the maximum valid address by 1", environ: validAddress, // "a=bc", "b=cd" size = size of "a=bc0b=cd0" = 10 environBuf: memorySize - 10 + 1, @@ -118,12 +118,12 @@ func Test_environSizesGet(t *testing.T) { WithEnv("a", "b").WithEnv("b", "cd")) defer r.Close(testCtx) - resultEnvironc := uint32(1) // arbitrary offset - resultEnvironBufSize := uint32(6) // arbitrary offset + resultEnvironc := uint32(1) // arbitrary offset + resultEnvironvLen := uint32(6) // arbitrary offset expectedMemory := []byte{ '?', // resultEnvironc is after this 0x2, 0x0, 0x0, 0x0, // little endian-encoded environment variable count - '?', // resultEnvironBufSize is after this + '?', // resultEnvironvLen is after this 0x9, 0x0, 0x0, 0x0, // little endian-encoded size of null terminated strings '?', // stopped after encoding } @@ -131,10 +131,10 @@ func Test_environSizesGet(t *testing.T) { maskMemory(t, testCtx, mod, len(expectedMemory)) // Invoke environSizesGet and check the memory side effects. - requireErrno(t, ErrnoSuccess, mod, functionEnvironSizesGet, uint64(resultEnvironc), uint64(resultEnvironBufSize)) + requireErrno(t, ErrnoSuccess, mod, functionEnvironSizesGet, uint64(resultEnvironc), uint64(resultEnvironvLen)) require.Equal(t, ` ---> proxy.environ_sizes_get(result.environc=1,result.environBufSize=6) - ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=1,result.environBufSize=6) +--> proxy.environ_sizes_get(result.environc=1,result.environv_len=6) + ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=1,result.environv_len=6) <== ESUCCESS <-- (0) `, "\n"+log.String()) @@ -153,50 +153,50 @@ func Test_environSizesGet_Errors(t *testing.T) { validAddress := uint32(0) // arbitrary valid address as arguments to environ_sizes_get. We chose 0 here. tests := []struct { - name string - environc, environBufSize uint32 - expectedLog string + name string + environc, environLen uint32 + expectedLog string }{ { - name: "out-of-memory environCountPtr", - environc: memorySize, - environBufSize: validAddress, + name: "out-of-memory environCount", + environc: memorySize, + environLen: validAddress, expectedLog: ` ---> proxy.environ_sizes_get(result.environc=65536,result.environBufSize=0) - ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=65536,result.environBufSize=0) +--> proxy.environ_sizes_get(result.environc=65536,result.environv_len=0) + ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=65536,result.environv_len=0) <== EFAULT <-- (21) `, }, { - name: "out-of-memory environBufSizePtr", - environc: validAddress, - environBufSize: memorySize, + name: "out-of-memory environLen", + environc: validAddress, + environLen: memorySize, expectedLog: ` ---> proxy.environ_sizes_get(result.environc=0,result.environBufSize=65536) - ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=0,result.environBufSize=65536) +--> proxy.environ_sizes_get(result.environc=0,result.environv_len=65536) + ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=0,result.environv_len=65536) <== EFAULT <-- (21) `, }, { - name: "environCountPtr exceeds the maximum valid address by 1", - environc: memorySize - 4 + 1, // 4 is the size of uint32, the type of the count of environ - environBufSize: validAddress, + name: "environCount exceeds the maximum valid address by 1", + environc: memorySize - 4 + 1, // 4 is the size of uint32, the type of the count of environ + environLen: validAddress, expectedLog: ` ---> proxy.environ_sizes_get(result.environc=65533,result.environBufSize=0) - ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=65533,result.environBufSize=0) +--> proxy.environ_sizes_get(result.environc=65533,result.environv_len=0) + ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=65533,result.environv_len=0) <== EFAULT <-- (21) `, }, { - name: "environBufSizePtr exceeds the maximum valid size by 1", - environc: validAddress, - environBufSize: memorySize - 4 + 1, // 4 is count of bytes to encode uint32le + name: "environLen exceeds the maximum valid size by 1", + environc: validAddress, + environLen: memorySize - 4 + 1, // 4 is count of bytes to encode uint32le expectedLog: ` ---> proxy.environ_sizes_get(result.environc=0,result.environBufSize=65533) - ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=0,result.environBufSize=65533) +--> proxy.environ_sizes_get(result.environc=0,result.environv_len=65533) + ==> wasi_snapshot_preview1.environ_sizes_get(result.environc=0,result.environv_len=65533) <== EFAULT <-- (21) `, @@ -209,7 +209,7 @@ func Test_environSizesGet_Errors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { defer log.Reset() - requireErrno(t, ErrnoFault, mod, functionEnvironSizesGet, uint64(tc.environc), uint64(tc.environBufSize)) + requireErrno(t, ErrnoFault, mod, functionEnvironSizesGet, uint64(tc.environc), uint64(tc.environLen)) require.Equal(t, tc.expectedLog, "\n"+log.String()) }) } diff --git a/imports/wasi_snapshot_preview1/fs.go b/imports/wasi_snapshot_preview1/fs.go index 2e36632e..45e99629 100644 --- a/imports/wasi_snapshot_preview1/fs.go +++ b/imports/wasi_snapshot_preview1/fs.go @@ -81,18 +81,28 @@ var fdAllocate = stubFunction( // Note: This is similar to `close` in POSIX. // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#fd_close // and https://linux.die.net/man/3/close -var fdClose = wasm.NewGoFunc( - functionFdClose, functionFdClose, - []string{"fd"}, - func(ctx context.Context, mod api.Module, fd uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - if ok := sysCtx.FS(ctx).CloseFile(ctx, fd); !ok { - return ErrnoBadf - } - - return ErrnoSuccess +var fdClose = &wasm.HostFunc{ + ExportNames: []string{functionFdClose}, + Name: functionFdClose, + ParamTypes: []api.ValueType{i32}, + ParamNames: []string{"fd"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdCloseFn), }, -) +} + +func fdCloseFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + fd := uint32(params[0]) + + if ok := sysCtx.FS(ctx).CloseFile(ctx, fd); !ok { + return errnoBadf + } + + return errnoSuccess +} // fdDatasync is the WASI function named functionFdDatasync which synchronizes // the data of a file to disk. @@ -141,18 +151,28 @@ var fdDatasync = stubFunction( // well as additional fields. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fdstat // and https://linux.die.net/man/3/fsync -var fdFdstatGet = wasm.NewGoFunc( - functionFdFdstatGet, functionFdFdstatGet, - []string{"fd", "result.stat"}, - func(ctx context.Context, mod api.Module, fd uint32, resultStat uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - if _, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd); !ok { - return ErrnoBadf - } - // TODO: actually write the fdstat! - return ErrnoSuccess +var fdFdstatGet = &wasm.HostFunc{ + ExportNames: []string{functionFdFdstatGet}, + Name: functionFdFdstatGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"fd", "result.stat"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdFdstatGetFn), }, -) +} + +func fdFdstatGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + // TODO: actually write the fdstat! + fd, _ := uint32(params[0]), uint32(params[1]) + + if _, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd); !ok { + return errnoBadf + } + return errnoSuccess +} // fdFdstatSetFlags is the WASI function named functionFdFdstatSetFlags which // adjusts the flags associated with a file descriptor. @@ -244,28 +264,38 @@ var fdPread = stubFunction( // // See fdPrestatDirName and // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#prestat -var fdPrestatGet = wasm.NewGoFunc( - functionFdPrestatGet, functionFdPrestatGet, - []string{"fd", "result.prestat"}, - func(ctx context.Context, mod api.Module, fd uint32, resultPrestat uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - entry, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) - if !ok { - return ErrnoBadf - } - - // Zero-value 8-bit tag, and 3-byte zero-value paddings, which is uint32le(0) in short. - if !mod.Memory().WriteUint32Le(ctx, resultPrestat, uint32(0)) { - return ErrnoFault - } - // Write the length of the directory name at offset 4. - if !mod.Memory().WriteUint32Le(ctx, resultPrestat+4, uint32(len(entry.Path))) { - return ErrnoFault - } - - return ErrnoSuccess +var fdPrestatGet = &wasm.HostFunc{ + ExportNames: []string{functionFdPrestatGet}, + Name: functionFdPrestatGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"fd", "result.prestat"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdPrestatGetFn), }, -) +} + +func fdPrestatGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + fd, resultPrestat := uint32(params[0]), uint32(params[1]) + + entry, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) + if !ok { + return errnoBadf + } + + // Zero-value 8-bit tag, and 3-byte zero-value paddings, which is uint32le(0) in short. + if !mod.Memory().WriteUint32Le(ctx, resultPrestat, uint32(0)) { + return errnoFault + } + // Write the length of the directory name at offset 4. + if !mod.Memory().WriteUint32Le(ctx, resultPrestat+4, uint32(len(entry.Path))) { + return errnoFault + } + + return errnoSuccess +} // fdPrestatDirName is the WASI function named functionFdPrestatDirName which // returns the path of the pre-opened directory of a file descriptor. @@ -297,28 +327,39 @@ var fdPrestatGet = wasm.NewGoFunc( // // See fdPrestatGet // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_prestat_dir_name -var fdPrestatDirName = wasm.NewGoFunc( - functionFdPrestatDirName, functionFdPrestatDirName, - []string{"fd", "path", "path_len"}, - func(ctx context.Context, mod api.Module, fd uint32, pathPtr uint32, pathLen uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - f, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) - if !ok { - return ErrnoBadf - } - - // Some runtimes may have another semantics. See /RATIONALE.md - if uint32(len(f.Path)) < pathLen { - return ErrnoNametoolong - } - - // TODO: fdPrestatDirName may have to return ErrnoNotdir if the type of the prestat data of `fd` is not a PrestatDir. - if !mod.Memory().Write(ctx, pathPtr, []byte(f.Path)[:pathLen]) { - return ErrnoFault - } - return ErrnoSuccess +var fdPrestatDirName = &wasm.HostFunc{ + ExportNames: []string{functionFdPrestatDirName}, + Name: functionFdPrestatDirName, + ParamTypes: []api.ValueType{i32, i32, i32}, + ParamNames: []string{"fd", "path", "path_len"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdPrestatDirNameFn), }, -) +} + +func fdPrestatDirNameFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + fd, path, pathLen := uint32(params[0]), uint32(params[1]), uint32(params[2]) + + f, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd) + if !ok { + return errnoBadf + } + + // Some runtimes may have another semantics. See /RATIONALE.md + if uint32(len(f.Path)) < pathLen { + return errnoNametoolong + } + + // TODO: fdPrestatDirName may have to return ErrnoNotdir if the type of the + // prestat data of `fd` is not a PrestatDir. + if !mod.Memory().Write(ctx, path, []byte(f.Path)[:pathLen]) { + return errnoFault + } + return errnoSuccess +} // fdPwrite is the WASI function named functionFdPwrite which writes to a file // descriptor, without using and updating the file descriptor's offset. @@ -378,65 +419,78 @@ var fdPwrite = stubFunction(functionFdPwrite, // // See fdWrite // and https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_read -var fdRead = wasm.NewGoFunc( - functionFdRead, functionFdRead, - []string{"fd", "iovs", "iovs_len", "result.size"}, - func(ctx context.Context, mod api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - mem := mod.Memory() - reader := internalsys.FdReader(ctx, sysCtx, fd) - if reader == nil { - return ErrnoBadf - } - - var nread uint32 - for i := uint32(0); i < iovsCount; i++ { - iovPtr := iovs + i*8 - offset, ok := mem.ReadUint32Le(ctx, iovPtr) - if !ok { - return ErrnoFault - } - l, ok := mem.ReadUint32Le(ctx, iovPtr+4) - if !ok { - return ErrnoFault - } - b, ok := mem.Read(ctx, offset, l) - if !ok { - return ErrnoFault - } - - n, err := reader.Read(b) - nread += uint32(n) - - shouldContinue, errno := fdRead_shouldContinueRead(uint32(n), l, err) - if errno != 0 { - return errno - } else if !shouldContinue { - break - } - } - if !mem.WriteUint32Le(ctx, resultSize, nread) { - return ErrnoFault - } - return ErrnoSuccess +var fdRead = &wasm.HostFunc{ + ExportNames: []string{functionFdRead}, + Name: functionFdRead, + ParamTypes: []api.ValueType{i32, i32, i32, i32}, + ParamNames: []string{"fd", "iovs", "iovs_len", "result.size"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdReadFn), }, -) +} + +func fdReadFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + mem := mod.Memory() + fd := uint32(params[0]) + iovs := uint32(params[1]) + iovsCount := uint32(params[2]) + resultSize := uint32(params[3]) + + reader := internalsys.FdReader(ctx, sysCtx, fd) + if reader == nil { + return errnoBadf + } + + var nread uint32 + for i := uint32(0); i < iovsCount; i++ { + iov := iovs + i*8 + offset, ok := mem.ReadUint32Le(ctx, iov) + if !ok { + return errnoFault + } + l, ok := mem.ReadUint32Le(ctx, iov+4) + if !ok { + return errnoFault + } + b, ok := mem.Read(ctx, offset, l) + if !ok { + return errnoFault + } + + n, err := reader.Read(b) + nread += uint32(n) + + shouldContinue, errno := fdRead_shouldContinueRead(uint32(n), l, err) + if errno != nil { + return errno + } else if !shouldContinue { + break + } + } + if !mem.WriteUint32Le(ctx, resultSize, nread) { + return errnoFault + } + return errnoSuccess +} // fdRead_shouldContinueRead decides whether to continue reading the next iovec // based on the amount read (n/l) and a possible error returned from io.Reader. // // Note: When there are both bytes read (n) and an error, this continues. // See /RATIONALE.md "Why ignore the error returned by io.Reader when n > 1?" -func fdRead_shouldContinueRead(n, l uint32, err error) (bool, Errno) { +func fdRead_shouldContinueRead(n, l uint32, err error) (bool, []uint64) { if errors.Is(err, io.EOF) { - return false, 0 // EOF isn't an error, and we shouldn't continue. + return false, nil // EOF isn't an error, and we shouldn't continue. } else if err != nil && n == 0 { - return false, ErrnoIo + return false, errnoIo } else if err != nil { - return false, 0 // Allow the caller to process n bytes. + return false, nil // Allow the caller to process n bytes. } // Continue reading, unless there's a partial read or nothing to read. - return n == l && n != 0, 0 + return n == l && n != 0, nil } // fdReaddir is the WASI function named functionFdReaddir which reads directory @@ -496,35 +550,48 @@ var fdRenumber = stubFunction( // // See io.Seeker // and https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_seek -var fdSeek = wasm.NewGoFunc( - functionFdSeek, functionFdSeek, - []string{"fd", "offset", "whence", "result.newoffset"}, - func(ctx context.Context, mod api.Module, fd uint32, offset uint64, whence uint32, resultNewoffset uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - var seeker io.Seeker - // Check to see if the file descriptor is available - if f, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd); !ok || f.File == nil { - return ErrnoBadf - // fs.FS doesn't declare io.Seeker, but implementations such as os.File implement it. - } else if seeker, ok = f.File.(io.Seeker); !ok { - return ErrnoBadf - } - - if whence > io.SeekEnd /* exceeds the largest valid whence */ { - return ErrnoInval - } - newOffset, err := seeker.Seek(int64(offset), int(whence)) - if err != nil { - return ErrnoIo - } - - if !mod.Memory().WriteUint64Le(ctx, resultNewoffset, uint64(newOffset)) { - return ErrnoFault - } - - return ErrnoSuccess +var fdSeek = &wasm.HostFunc{ + ExportNames: []string{functionFdSeek}, + Name: functionFdSeek, + ParamTypes: []api.ValueType{i32, i64, i32, i32}, + ParamNames: []string{"fd", "offset", "whence", "result.newoffset"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdSeekFn), }, -) +} + +func fdSeekFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + fd := uint32(params[0]) + offset := params[1] + whence := uint32(params[2]) + resultNewoffset := uint32(params[3]) + + var seeker io.Seeker + // Check to see if the file descriptor is available + if f, ok := sysCtx.FS(ctx).OpenedFile(ctx, fd); !ok || f.File == nil { + return errnoBadf + // fs.FS doesn't declare io.Seeker, but implementations such as os.File implement it. + } else if seeker, ok = f.File.(io.Seeker); !ok { + return errnoBadf + } + + if whence > io.SeekEnd /* exceeds the largest valid whence */ { + return errnoInval + } + newOffset, err := seeker.Seek(int64(offset), int(whence)) + if err != nil { + return errnoIo + } + + if !mod.Memory().WriteUint64Le(ctx, resultNewoffset, uint64(newOffset)) { + return errnoFault + } + + return errnoSuccess +} // fdSync is the WASI function named functionFdSync which synchronizes the data // and metadata of a file to disk. @@ -603,52 +670,65 @@ var fdTell = stubFunction( // See fdRead // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#ciovec // and https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#fd_write -var fdWrite = wasm.NewGoFunc( - functionFdWrite, functionFdWrite, - []string{"fd", "iovs", "iovs_len", "result.size"}, - func(ctx context.Context, mod api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - sysCtx := mod.(*wasm.CallContext).Sys - writer := internalsys.FdWriter(ctx, sysCtx, fd) - if writer == nil { - return ErrnoBadf - } - - var err error - var nwritten uint32 - for i := uint32(0); i < iovsCount; i++ { - iovPtr := iovs + i*8 - offset, ok := mod.Memory().ReadUint32Le(ctx, iovPtr) - if !ok { - return ErrnoFault - } - // Note: emscripten has been known to write zero length iovec. However, - // it is not common in other compilers, so we don't optimize for it. - l, ok := mod.Memory().ReadUint32Le(ctx, iovPtr+4) - if !ok { - return ErrnoFault - } - - var n int - if writer == io.Discard { // special-case default - n = int(l) - } else { - b, ok := mod.Memory().Read(ctx, offset, l) - if !ok { - return ErrnoFault - } - n, err = writer.Write(b) - if err != nil { - return ErrnoIo - } - } - nwritten += uint32(n) - } - if !mod.Memory().WriteUint32Le(ctx, resultSize, nwritten) { - return ErrnoFault - } - return ErrnoSuccess +var fdWrite = &wasm.HostFunc{ + ExportNames: []string{functionFdWrite}, + Name: functionFdWrite, + ParamTypes: []api.ValueType{i32, i32, i32, i32}, + ParamNames: []string{"fd", "iovs", "iovs_len", "result.size"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(fdWriteFn), }, -) +} + +func fdWriteFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + fd := uint32(params[0]) + iovs := uint32(params[1]) + iovsCount := uint32(params[2]) + resultSize := uint32(params[3]) + + sysCtx := mod.(*wasm.CallContext).Sys + writer := internalsys.FdWriter(ctx, sysCtx, fd) + if writer == nil { + return errnoBadf + } + + var err error + var nwritten uint32 + for i := uint32(0); i < iovsCount; i++ { + iov := iovs + i*8 + offset, ok := mod.Memory().ReadUint32Le(ctx, iov) + if !ok { + return errnoFault + } + // Note: emscripten has been known to write zero length iovec. However, + // it is not common in other compilers, so we don't optimize for it. + l, ok := mod.Memory().ReadUint32Le(ctx, iov+4) + if !ok { + return errnoFault + } + + var n int + if writer == io.Discard { // special-case default + n = int(l) + } else { + b, ok := mod.Memory().Read(ctx, offset, l) + if !ok { + return errnoFault + } + n, err = writer.Write(b) + if err != nil { + return errnoIo + } + } + nwritten += uint32(n) + } + if !mod.Memory().WriteUint32Le(ctx, resultSize, nwritten) { + return errnoFault + } + return errnoSuccess +} // pathCreateDirectory is the WASI function named functionPathCreateDirectory // which creates a directory. @@ -743,38 +823,56 @@ var pathLink = stubFunction( // - The returned file descriptor is not guaranteed to be the lowest-number // // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#path_open -var pathOpen = wasm.NewGoFunc( - functionPathOpen, functionPathOpen, - []string{"fd", "dirflags", "path", "path_len", "oflags", "fs_rights_base", "fs_rights_inheriting", "fdflags", "result.opened_fd"}, - func(ctx context.Context, mod api.Module, fd, dirflags, pathPtr, pathLen, oflags uint32, _, _ uint64, - fdflags, resultOpenedFd uint32) (errno Errno) { - sysCtx := mod.(*wasm.CallContext).Sys - fsc := sysCtx.FS(ctx) - if _, ok := fsc.OpenedFile(ctx, fd); !ok { - return ErrnoBadf - } - - b, ok := mod.Memory().Read(ctx, pathPtr, pathLen) - if !ok { - return ErrnoFault - } - - if newFD, err := fsc.OpenFile(ctx, string(b)); err != nil { - switch { - case errors.Is(err, fs.ErrNotExist): - return ErrnoNoent - case errors.Is(err, fs.ErrExist): - return ErrnoExist - default: - return ErrnoIo - } - } else if !mod.Memory().WriteUint32Le(ctx, resultOpenedFd, newFD) { - _ = fsc.CloseFile(ctx, newFD) - return ErrnoFault - } - return ErrnoSuccess +var pathOpen = &wasm.HostFunc{ + ExportNames: []string{functionPathOpen}, + Name: functionPathOpen, + ParamTypes: []api.ValueType{i32, i32, i32, i32, i32, i64, i64, i32, i32}, + ParamNames: []string{"fd", "dirflags", "path", "path_len", "oflags", "fs_rights_base", "fs_rights_inheriting", "fdflags", "result.opened_fd"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(pathOpenFn), }, -) +} + +func pathOpenFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + fsc := sysCtx.FS(ctx) + + fd := uint32(params[0]) + _ /* dirflags */ = uint32(params[1]) + path := uint32(params[2]) + pathLen := uint32(params[3]) + _ /* oflags */ = uint32(params[4]) + // rights aren't used + _, _ = params[5], params[6] + _ /* fdflags */ = uint32(params[7]) + resultOpenedFd := uint32(params[8]) + + if _, ok := fsc.OpenedFile(ctx, fd); !ok { + return errnoBadf + } + + b, ok := mod.Memory().Read(ctx, path, pathLen) + if !ok { + return errnoFault + } + + if newFD, err := fsc.OpenFile(ctx, string(b)); err != nil { + switch { + case errors.Is(err, fs.ErrNotExist): + return errnoNoent + case errors.Is(err, fs.ErrExist): + return errnoExist + default: + return errnoIo + } + } else if !mod.Memory().WriteUint32Le(ctx, resultOpenedFd, newFD) { + _ = fsc.CloseFile(ctx, newFD) + return errnoFault + } + return errnoSuccess +} // pathReadlink is the WASI function named functionPathReadlink that reads the // contents of a symbolic link. diff --git a/imports/wasi_snapshot_preview1/fs_test.go b/imports/wasi_snapshot_preview1/fs_test.go index d96d1551..ffdcbdc3 100644 --- a/imports/wasi_snapshot_preview1/fs_test.go +++ b/imports/wasi_snapshot_preview1/fs_test.go @@ -621,7 +621,7 @@ func Test_fdRead_shouldContinueRead(t *testing.T) { n, l uint32 err error expectedOk bool - expectedErrno Errno + expectedErrno []uint64 }{ { name: "break when nothing to read", @@ -668,13 +668,13 @@ func Test_fdRead_shouldContinueRead(t *testing.T) { { name: "return ErrnoIo on error on nothing to read", err: io.ErrClosedPipe, - expectedErrno: ErrnoIo, + expectedErrno: errnoIo, }, { name: "return ErrnoIo on error on nothing read", l: 4, err: io.ErrClosedPipe, - expectedErrno: ErrnoIo, + expectedErrno: errnoIo, }, { // Special case, allows processing data before err name: "break on error on partial read", @@ -1170,7 +1170,7 @@ func Test_pathOpen(t *testing.T) { ) dirflags := uint32(0) - pathPtr := uint32(1) + path := uint32(1) pathLen := uint32(len(pathName)) oflags := uint32(0) fsRightsBase := uint64(1) // ignored: rights were removed from WASI. @@ -1186,7 +1186,7 @@ func Test_pathOpen(t *testing.T) { ok := mod.Memory().Write(testCtx, 0, initialMemory) require.True(t, ok) - requireErrno(t, ErrnoSuccess, mod, functionPathOpen, uint64(rootFD), uint64(dirflags), uint64(pathPtr), + requireErrno(t, ErrnoSuccess, mod, functionPathOpen, uint64(rootFD), uint64(dirflags), uint64(path), uint64(pathLen), uint64(oflags), fsRightsBase, fsRightsInheriting, uint64(fdflags), uint64(resultOpenedFd)) require.Equal(t, ` --> proxy.path_open(fd=3,dirflags=0,path=1,path_len=6,oflags=0,fs_rights_base=1,fs_rights_inheriting=2,fdflags=0,result.opened_fd=8) diff --git a/imports/wasi_snapshot_preview1/poll.go b/imports/wasi_snapshot_preview1/poll.go index b5603fef..d575c83d 100644 --- a/imports/wasi_snapshot_preview1/poll.go +++ b/imports/wasi_snapshot_preview1/poll.go @@ -47,61 +47,74 @@ const ( // // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#poll_oneoff // See https://linux.die.net/man/3/poll -var pollOneoff = wasm.NewGoFunc( - functionPollOneoff, functionPollOneoff, - []string{"in", "out", "nsubscriptions", "result.nevents"}, - func(ctx context.Context, mod api.Module, in, out, nsubscriptions, resultNevents uint32) Errno { - if nsubscriptions == 0 { - return ErrnoInval - } - - mem := mod.Memory() - - // Ensure capacity prior to the read loop to reduce error handling. - inBuf, ok := mem.Read(ctx, in, nsubscriptions*48) - if !ok { - return ErrnoFault - } - outBuf, ok := mem.Read(ctx, out, nsubscriptions*32) - if !ok { - return ErrnoFault - } - - // Eagerly write the number of events which will equal subscriptions unless - // there's a fault in parsing (not processing). - if !mod.Memory().WriteUint32Le(ctx, resultNevents, nsubscriptions) { - return ErrnoFault - } - - // Loop through all subscriptions and write their output. - for sub := uint32(0); sub < nsubscriptions; sub++ { - inOffset := sub * 48 - outOffset := sub * 32 - - var errno Errno - eventType := inBuf[inOffset+8] // +8 past userdata - switch eventType { - case eventTypeClock: // handle later - // +8 past userdata +8 name alignment - errno = processClockEvent(ctx, mod, inBuf[inOffset+8+8:]) - case eventTypeFdRead, eventTypeFdWrite: - // +8 past userdata +4 FD alignment - errno = processFDEvent(ctx, mod, eventType, inBuf[inOffset+8+4:]) - default: - return ErrnoInval - } - - // Write the event corresponding to the processed subscription. - // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-event-struct - copy(outBuf, inBuf[inOffset:inOffset+8]) // userdata - outBuf[outOffset+8] = byte(errno) // uint16, but safe as < 255 - outBuf[outOffset+9] = 0 - binary.LittleEndian.PutUint32(outBuf[outOffset+10:], uint32(eventType)) - // TODO: When FD events are supported, write outOffset+16 - } - return ErrnoSuccess +var pollOneoff = &wasm.HostFunc{ + ExportNames: []string{functionPollOneoff}, + Name: functionPollOneoff, + ParamTypes: []api.ValueType{i32, i32, i32, i32}, + ParamNames: []string{"in", "out", "nsubscriptions", "result.nevents"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(pollOneoffFn), }, -) +} + +func pollOneoffFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + in := uint32(params[0]) + out := uint32(params[1]) + nsubscriptions := uint32(params[2]) + resultNevents := uint32(params[3]) + + if nsubscriptions == 0 { + return errnoInval + } + + mem := mod.Memory() + + // Ensure capacity prior to the read loop to reduce error handling. + inBuf, ok := mem.Read(ctx, in, nsubscriptions*48) + if !ok { + return errnoFault + } + outBuf, ok := mem.Read(ctx, out, nsubscriptions*32) + if !ok { + return errnoFault + } + + // Eagerly write the number of events which will equal subscriptions unless + // there's a fault in parsing (not processing). + if !mod.Memory().WriteUint32Le(ctx, resultNevents, nsubscriptions) { + return errnoFault + } + + // Loop through all subscriptions and write their output. + for sub := uint32(0); sub < nsubscriptions; sub++ { + inOffset := sub * 48 + outOffset := sub * 32 + + var errno Errno + eventType := inBuf[inOffset+8] // +8 past userdata + switch eventType { + case eventTypeClock: // handle later + // +8 past userdata +8 name alignment + errno = processClockEvent(ctx, mod, inBuf[inOffset+8+8:]) + case eventTypeFdRead, eventTypeFdWrite: + // +8 past userdata +4 FD alignment + errno = processFDEvent(ctx, mod, eventType, inBuf[inOffset+8+4:]) + default: + return errnoInval + } + + // Write the event corresponding to the processed subscription. + // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-event-struct + copy(outBuf, inBuf[inOffset:inOffset+8]) // userdata + outBuf[outOffset+8] = byte(errno) // uint16, but safe as < 255 + outBuf[outOffset+9] = 0 + binary.LittleEndian.PutUint32(outBuf[outOffset+10:], uint32(eventType)) + // TODO: When FD events are supported, write outOffset+16 + } + return errnoSuccess +} // processClockEvent supports only relative name events, as that's what's used // to implement sleep in various compilers including Rust, Zig and TinyGo. diff --git a/imports/wasi_snapshot_preview1/proc.go b/imports/wasi_snapshot_preview1/proc.go index 9e656d1b..937390ee 100644 --- a/imports/wasi_snapshot_preview1/proc.go +++ b/imports/wasi_snapshot_preview1/proc.go @@ -22,19 +22,28 @@ const ( // - exitCode: exit code. // // See https://github.com/WebAssembly/WASI/blob/main/phases/snapshot/docs.md#proc_exit -var procExit = wasm.NewGoFunc( - functionProcExit, functionProcExit, - []string{"rval"}, - func(ctx context.Context, mod api.Module, exitCode uint32) { - // Ensure other callers see the exit code. - _ = mod.CloseWithExitCode(ctx, exitCode) - - // Prevent any code from executing after this function. For example, LLVM - // inserts unreachable instructions after calls to exit. - // See: https://github.com/emscripten-core/emscripten/issues/12322 - panic(sys.NewExitError(mod.Name(), exitCode)) +var procExit = &wasm.HostFunc{ + ExportNames: []string{functionProcExit}, + Name: functionProcExit, + ParamTypes: []api.ValueType{i32}, + ParamNames: []string{"rval"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(procExitFn), }, -) +} + +func procExitFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + exitCode := uint32(params[0]) + + // Ensure other callers see the exit code. + _ = mod.CloseWithExitCode(ctx, exitCode) + + // Prevent any code from executing after this function. For example, LLVM + // inserts unreachable instructions after calls to exit. + // See: https://github.com/emscripten-core/emscripten/issues/12322 + panic(sys.NewExitError(mod.Name(), exitCode)) +} // procRaise is stubbed and will never be supported, as it was removed. // diff --git a/imports/wasi_snapshot_preview1/random.go b/imports/wasi_snapshot_preview1/random.go index 885640ac..35ecd04f 100644 --- a/imports/wasi_snapshot_preview1/random.go +++ b/imports/wasi_snapshot_preview1/random.go @@ -34,23 +34,32 @@ const functionRandomGet = "random_get" // buf --^ // // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-random_getbuf-pointeru8-bufLen-size---errno -var randomGet = wasm.NewGoFunc( - functionRandomGet, functionRandomGet, - []string{"buf", "buf_len"}, - func(ctx context.Context, mod api.Module, buf uint32, bufLen uint32) (errno Errno) { - sysCtx := mod.(*wasm.CallContext).Sys - randSource := sysCtx.RandSource() - - randomBytes, ok := mod.Memory().Read(ctx, buf, bufLen) - if !ok { // out-of-range - return ErrnoFault - } - - // We can ignore the returned n as it only != byteCount on error - if _, err := io.ReadAtLeast(randSource, randomBytes, int(bufLen)); err != nil { - return ErrnoIo - } - - return ErrnoSuccess +var randomGet = &wasm.HostFunc{ + ExportNames: []string{functionRandomGet}, + Name: functionRandomGet, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"buf", "buf_len"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(randomGetFn), }, -) +} + +func randomGetFn(ctx context.Context, mod api.Module, params []uint64) []uint64 { + sysCtx := mod.(*wasm.CallContext).Sys + randSource := sysCtx.RandSource() + buf, bufLen := uint32(params[0]), uint32(params[1]) + + randomBytes, ok := mod.Memory().Read(ctx, buf, bufLen) + if !ok { // out-of-range + return errnoFault + } + + // We can ignore the returned n as it only != byteCount on error + if _, err := io.ReadAtLeast(randSource, randomBytes, int(bufLen)); err != nil { + return errnoIo + } + + return errnoSuccess +} diff --git a/imports/wasi_snapshot_preview1/testdata/wasi_arg.wat b/imports/wasi_snapshot_preview1/testdata/wasi_arg.wat index 32734ae8..f7fb72e9 100644 --- a/imports/wasi_snapshot_preview1/testdata/wasi_arg.wat +++ b/imports/wasi_snapshot_preview1/testdata/wasi_arg.wat @@ -10,7 +10,7 @@ ;; ;; See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-args_sizes_get---errno-size-size (import "wasi_snapshot_preview1" "args_sizes_get" - (func $wasi.args_sizes_get (param $result.argc i32) (param $result.argv_buf_size i32) (result (;errno;) i32))) + (func $wasi.args_sizes_get (param $result.argc i32) (param $result.argv_len i32) (result (;errno;) i32))) ;; fd_write write bytes to a file descriptor. ;; @@ -45,7 +45,7 @@ ;; Next, we need to know how many bytes were loaded, as that's how much we'll copy to the file. (call $wasi.args_sizes_get (global.get $ignored) ;; ignore $result.argc as we only read the argv_buf. - (i32.add (global.get $iovs) (i32.const 4)) ;; store $result.argv_buf_size as the length to copy + (i32.add (global.get $iovs) (i32.const 4)) ;; store $result.argv_len as the length to copy ) drop ;; ignore the errno returned diff --git a/imports/wasi_snapshot_preview1/wasi.go b/imports/wasi_snapshot_preview1/wasi.go index d611c02d..e36ff8c5 100644 --- a/imports/wasi_snapshot_preview1/wasi.go +++ b/imports/wasi_snapshot_preview1/wasi.go @@ -121,77 +121,91 @@ func (b *builder) Instantiate(ctx context.Context, ns wazero.Namespace) (api.Clo // exportFunctions adds all go functions that implement wasi. // These should be exported in the module named ModuleName. func exportFunctions(builder wazero.HostModuleBuilder) { + exporter := builder.(wasm.HostFuncExporter) + // Note: these are ordered per spec for consistency even if the resulting // map can't guarantee that. // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#functions - builder.ExportFunction(argsGet.Name, argsGet) - builder.ExportFunction(argsSizesGet.Name, argsSizesGet) - builder.ExportFunction(environGet.Name, environGet) - builder.ExportFunction(environSizesGet.Name, environSizesGet) - builder.ExportFunction(clockResGet.Name, clockResGet) - builder.ExportFunction(clockTimeGet.Name, clockTimeGet) - builder.ExportFunction(fdAdvise.Name, fdAdvise) - builder.ExportFunction(fdAllocate.Name, fdAllocate) - builder.ExportFunction(fdClose.Name, fdClose) - builder.ExportFunction(fdDatasync.Name, fdDatasync) - builder.ExportFunction(fdFdstatGet.Name, fdFdstatGet) - builder.ExportFunction(fdFdstatSetFlags.Name, fdFdstatSetFlags) - builder.ExportFunction(fdFdstatSetRights.Name, fdFdstatSetRights) - builder.ExportFunction(fdFilestatGet.Name, fdFilestatGet) - builder.ExportFunction(fdFilestatSetSize.Name, fdFilestatSetSize) - builder.ExportFunction(fdFilestatSetTimes.Name, fdFilestatSetTimes) - builder.ExportFunction(fdPread.Name, fdPread) - builder.ExportFunction(fdPrestatGet.Name, fdPrestatGet) - builder.ExportFunction(fdPrestatDirName.Name, fdPrestatDirName) - builder.ExportFunction(fdPwrite.Name, fdPwrite) - builder.ExportFunction(fdRead.Name, fdRead) - builder.ExportFunction(fdReaddir.Name, fdReaddir) - builder.ExportFunction(fdRenumber.Name, fdRenumber) - builder.ExportFunction(fdSeek.Name, fdSeek) - builder.ExportFunction(fdSync.Name, fdSync) - builder.ExportFunction(fdTell.Name, fdTell) - builder.ExportFunction(fdWrite.Name, fdWrite) - builder.ExportFunction(pathCreateDirectory.Name, pathCreateDirectory) - builder.ExportFunction(pathFilestatGet.Name, pathFilestatGet) - builder.ExportFunction(pathFilestatSetTimes.Name, pathFilestatSetTimes) - builder.ExportFunction(pathLink.Name, pathLink) - builder.ExportFunction(pathOpen.Name, pathOpen) - builder.ExportFunction(pathReadlink.Name, pathReadlink) - builder.ExportFunction(pathRemoveDirectory.Name, pathRemoveDirectory) - builder.ExportFunction(pathRename.Name, pathRename) - builder.ExportFunction(pathSymlink.Name, pathSymlink) - builder.ExportFunction(pathUnlinkFile.Name, pathUnlinkFile) - builder.ExportFunction(pollOneoff.Name, pollOneoff) - builder.ExportFunction(procExit.Name, procExit) - builder.ExportFunction(procRaise.Name, procRaise) - builder.ExportFunction(schedYield.Name, schedYield) - builder.ExportFunction(randomGet.Name, randomGet) - builder.ExportFunction(sockAccept.Name, sockAccept) - builder.ExportFunction(sockRecv.Name, sockRecv) - builder.ExportFunction(sockSend.Name, sockSend) - builder.ExportFunction(sockShutdown.Name, sockShutdown) + exporter.ExportHostFunc(argsGet) + exporter.ExportHostFunc(argsSizesGet) + exporter.ExportHostFunc(environGet) + exporter.ExportHostFunc(environSizesGet) + exporter.ExportHostFunc(clockResGet) + exporter.ExportHostFunc(clockTimeGet) + exporter.ExportHostFunc(fdAdvise) + exporter.ExportHostFunc(fdAllocate) + exporter.ExportHostFunc(fdClose) + exporter.ExportHostFunc(fdDatasync) + exporter.ExportHostFunc(fdFdstatGet) + exporter.ExportHostFunc(fdFdstatSetFlags) + exporter.ExportHostFunc(fdFdstatSetRights) + exporter.ExportHostFunc(fdFilestatGet) + exporter.ExportHostFunc(fdFilestatSetSize) + exporter.ExportHostFunc(fdFilestatSetTimes) + exporter.ExportHostFunc(fdPread) + exporter.ExportHostFunc(fdPrestatGet) + exporter.ExportHostFunc(fdPrestatDirName) + exporter.ExportHostFunc(fdPwrite) + exporter.ExportHostFunc(fdRead) + exporter.ExportHostFunc(fdReaddir) + exporter.ExportHostFunc(fdRenumber) + exporter.ExportHostFunc(fdSeek) + exporter.ExportHostFunc(fdSync) + exporter.ExportHostFunc(fdTell) + exporter.ExportHostFunc(fdWrite) + exporter.ExportHostFunc(pathCreateDirectory) + exporter.ExportHostFunc(pathFilestatGet) + exporter.ExportHostFunc(pathFilestatSetTimes) + exporter.ExportHostFunc(pathLink) + exporter.ExportHostFunc(pathOpen) + exporter.ExportHostFunc(pathReadlink) + exporter.ExportHostFunc(pathRemoveDirectory) + exporter.ExportHostFunc(pathRename) + exporter.ExportHostFunc(pathSymlink) + exporter.ExportHostFunc(pathUnlinkFile) + exporter.ExportHostFunc(pollOneoff) + exporter.ExportHostFunc(procExit) + exporter.ExportHostFunc(procRaise) + exporter.ExportHostFunc(schedYield) + exporter.ExportHostFunc(randomGet) + exporter.ExportHostFunc(sockAccept) + exporter.ExportHostFunc(sockRecv) + exporter.ExportHostFunc(sockSend) + exporter.ExportHostFunc(sockShutdown) } -func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, values []string, offsets, bytes uint32) Errno { +// Declare constants to avoid slice allocation per call. +var ( + errnoBadf = []uint64{uint64(ErrnoBadf)} + errnoExist = []uint64{uint64(ErrnoExist)} + errnoInval = []uint64{uint64(ErrnoInval)} + errnoIo = []uint64{uint64(ErrnoIo)} + errnoNoent = []uint64{uint64(ErrnoNoent)} + errnoFault = []uint64{uint64(ErrnoFault)} + errnoNametoolong = []uint64{uint64(ErrnoNametoolong)} + errnoSuccess = []uint64{uint64(ErrnoSuccess)} +) + +func writeOffsetsAndNullTerminatedValues(ctx context.Context, mem api.Memory, values []string, offsets, bytes uint32) []uint64 { for _, value := range values { // Write current offset and advance it. if !mem.WriteUint32Le(ctx, offsets, bytes) { - return ErrnoFault + return errnoFault } offsets += 4 // size of uint32 // Write the next value to memory with a NUL terminator if !mem.Write(ctx, bytes, []byte(value)) { - return ErrnoFault + return errnoFault } bytes += uint32(len(value)) if !mem.WriteByte(ctx, bytes, 0) { - return ErrnoFault + return errnoFault } bytes++ } - return ErrnoSuccess + return errnoSuccess } // stubFunction stubs for GrainLang per #271. diff --git a/internal/engine/compiler/compiler_test.go b/internal/engine/compiler/compiler_test.go index 29643565..0dc3c11d 100644 --- a/internal/engine/compiler/compiler_test.go +++ b/internal/engine/compiler/compiler_test.go @@ -206,7 +206,6 @@ func (j *compilerEnv) newFunction(codeSegment []byte) *function { codeInitialAddress: uintptr(unsafe.Pointer(&codeSegment[0])), moduleInstanceAddress: uintptr(unsafe.Pointer(j.moduleInstance)), source: &wasm.FunctionInstance{ - Kind: wasm.FunctionKindWasm, Type: &wasm.FunctionType{}, Module: j.moduleInstance, }, diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index 74814135..41a43ec3 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -309,7 +309,7 @@ const ( tableInstanceTableLenOffset = 8 // Offsets for wasm.FunctionInstance. - functionInstanceTypeIDOffset = 80 + functionInstanceTypeIDOffset = 88 // Offsets for wasm.MemoryInstance. memoryInstanceBufferOffset = 0 @@ -571,7 +571,7 @@ func (e *moduleEngine) NewCallEngine(callCtx *wasm.CallContext, f *wasm.Function } // Call implements the same method as documented on wasm.ModuleEngine. -func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, params ...uint64) (results []uint64, err error) { +func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, params []uint64) (results []uint64, err error) { tp := ce.initialFn.source.Type paramCount := len(params) @@ -784,12 +784,15 @@ entry: calleeHostFunction := ce.moduleContext.fn base := int(ce.stackBasePointerInBytes >> 3) params := ce.stack[base : base+len(calleeHostFunction.source.Type.Params)] - results := wasm.CallGoFunc( - ctx, - callCtx.WithMemory(ce.memoryInstance), - calleeHostFunction.source, - params, - ) + fn := calleeHostFunction.source.GoFunc + var results []uint64 + switch fn := fn.(type) { + case api.GoModuleFunction: + results = fn.Call(ctx, callCtx.WithMemory(ce.memoryInstance), params) + case api.GoFunction: + results = fn.Call(ctx, params) + } + copy(ce.stack[base:], results) codeAddr, modAddr = ce.returnAddress, ce.moduleInstanceAddress goto entry diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index 6085135f..121dc02c 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -194,7 +194,7 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) { const hostModuleName = "env" const hostFnName = "grow_and_shrink_goroutine_stack" - hm, err := wasm.NewHostModule(hostModuleName, map[string]interface{}{hostFnName: func() { + hm, err := wasm.NewHostModule(hostModuleName, map[string]interface{}{hostFnName: func(context.Context) { // This function aggressively grow the goroutine stack by recursively // calling the function many times. var callNum = 1000 diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index 9b6ca5cb..1f13cd6c 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -6,7 +6,6 @@ import ( "fmt" "math" "math/bits" - "reflect" "strings" "sync" "unsafe" @@ -172,13 +171,13 @@ type callFrame struct { type code struct { body []*interpreterOp - hostFn *reflect.Value + hostFn interface{} } type function struct { source *wasm.FunctionInstance body []*interpreterOp - hostFn *reflect.Value + hostFn interface{} } // functionFromUintptr resurrects the original *function from the given uintptr @@ -776,7 +775,7 @@ func (e *moduleEngine) NewCallEngine(callCtx *wasm.CallContext, f *wasm.Function } // Call implements the same method as documented on wasm.ModuleEngine. -func (ce *callEngine) Call(ctx context.Context, m *wasm.CallContext, params ...uint64) (results []uint64, err error) { +func (ce *callEngine) Call(ctx context.Context, m *wasm.CallContext, params []uint64) (results []uint64, err error) { paramSignature := ce.source.Type.ParamNumInUint64 paramCount := len(params) if paramSignature != paramCount { @@ -834,13 +833,20 @@ func (ce *callEngine) callFunction(ctx context.Context, callCtx *wasm.CallContex } func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { - callCtx = callCtx.WithMemory(ce.callerMemory()) if f.source.Listener != nil { ctx = f.source.Listener.Before(ctx, f.source.Definition, params) } frame := &callFrame{f: f} ce.pushFrame(frame) - results = wasm.CallGoFunc(ctx, callCtx, f.source, params) + + fn := f.source.GoFunc + switch fn := fn.(type) { + case api.GoModuleFunction: + results = fn.Call(ctx, callCtx.WithMemory(ce.callerMemory()), params) + case api.GoFunction: + results = fn.Call(ctx, params) + } + ce.popFrame() if f.source.Listener != nil { // TODO: This doesn't get the error due to use of panic to propagate them. @@ -4346,7 +4352,7 @@ func (ce *callEngine) popMemoryOffset(op *interpreterOp) uint32 { } func (ce *callEngine) callGoFuncWithStack(ctx context.Context, callCtx *wasm.CallContext, f *function) { - params := wasm.PopGoFuncParams(f.source, ce.popValue) + params := wasm.PopValues(f.source.Type.ParamNumInUint64, ce.popValue) results := ce.callGoFunc(ctx, callCtx, f, params) for _, v := range results { ce.pushValue(v) diff --git a/internal/gojs/crypto_test.go b/internal/gojs/crypto_test.go index c7f6231a..dfdf03c0 100644 --- a/internal/gojs/crypto_test.go +++ b/internal/gojs/crypto_test.go @@ -12,8 +12,8 @@ func Test_crypto(t *testing.T) { stdout, stderr, err := compileAndRun(testCtx, "crypto", wazero.NewModuleConfig()) - require.EqualError(t, err, `module "" closed with exit_code(0)`) require.Zero(t, stderr) + require.EqualError(t, err, `module "" closed with exit_code(0)`) require.Equal(t, `7a0c9f9f0d `, stdout) } diff --git a/internal/gojs/runtime.go b/internal/gojs/runtime.go index b1476a7a..014cc4cb 100644 --- a/internal/gojs/runtime.go +++ b/internal/gojs/runtime.go @@ -11,6 +11,8 @@ import ( ) const ( + i32, i64 = api.ValueTypeI32, api.ValueTypeI64 + functionWasmExit = "runtime.wasmExit" functionWasmWrite = "runtime.wasmWrite" functionResetMemoryDataView = "runtime.resetMemoryDataView" @@ -24,40 +26,60 @@ const ( // WasmExit implements runtime.wasmExit which supports runtime.exit. // // See https://github.com/golang/go/blob/go1.19/src/runtime/sys_wasm.go#L28 -var WasmExit = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionWasmExit, functionWasmExit, - []string{"code"}, - func(ctx context.Context, mod api.Module, code int32) { - getState(ctx).clear() - _ = mod.CloseWithExitCode(ctx, uint32(code)) // TODO: should ours be signed bit (like -1 == 255)? +var WasmExit = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionWasmExit}, + Name: functionWasmExit, + ParamTypes: []api.ValueType{i32}, + ParamNames: []string{"code"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(wasmExit), }, -)) +}) + +func wasmExit(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + code := uint32(params[0]) + + getState(ctx).clear() + _ = mod.CloseWithExitCode(ctx, code) // TODO: should ours be signed bit (like -1 == 255)? + return +} // WasmWrite implements runtime.wasmWrite which supports runtime.write and // runtime.writeErr. This implements `println`. // // See https://github.com/golang/go/blob/go1.19/src/runtime/os_js.go#L29 -var WasmWrite = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionWasmWrite, functionWasmWrite, - []string{"code"}, - func(ctx context.Context, mod api.Module, fd, p, n uint32) { - var writer io.Writer - - switch fd { - case 1: - writer = mod.(*wasm.CallContext).Sys.Stdout() - case 2: - writer = mod.(*wasm.CallContext).Sys.Stderr() - default: - // Keep things simple by expecting nothing past 2 - panic(fmt.Errorf("unexpected fd %d", fd)) - } - - if _, err := writer.Write(mustRead(ctx, mod.Memory(), "p", p, n)); err != nil { - panic(fmt.Errorf("error writing p: %w", err)) - } +var WasmWrite = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionWasmWrite}, + Name: functionWasmWrite, + ParamTypes: []api.ValueType{i32, i32, i32}, + ParamNames: []string{"fd", "p", "n"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(wasmWrite), }, -)) +}) + +func wasmWrite(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + fd, p, n := uint32(params[0]), uint32(params[1]), uint32(params[2]) + + var writer io.Writer + + switch fd { + case 1: + writer = mod.(*wasm.CallContext).Sys.Stdout() + case 2: + writer = mod.(*wasm.CallContext).Sys.Stderr() + default: + // Keep things simple by expecting nothing past 2 + panic(fmt.Errorf("unexpected fd %d", fd)) + } + + if _, err := writer.Write(mustRead(ctx, mod.Memory(), "p", p, n)); err != nil { + panic(fmt.Errorf("error writing p: %w", err)) + } + return +} // ResetMemoryDataView signals wasm.OpcodeMemoryGrow happened, indicating any // cached view of memory should be reset. @@ -75,22 +97,38 @@ var ResetMemoryDataView = &wasm.HostFunc{ // Nanotime1 implements runtime.nanotime which supports time.Since. // // See https://github.com/golang/go/blob/go1.19/src/runtime/sys_wasm.s#L184 -var Nanotime1 = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionNanotime1, functionNanotime1, - []string{}, - func(ctx context.Context, mod api.Module) int64 { - return mod.(*wasm.CallContext).Sys.Nanotime(ctx) - })) +var Nanotime1 = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionNanotime1}, + Name: functionNanotime1, + ResultTypes: []api.ValueType{i64}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(nanotime1), + }, +}) + +func nanotime1(ctx context.Context, mod api.Module, _ []uint64) []uint64 { + time := mod.(*wasm.CallContext).Sys.Nanotime(ctx) + return []uint64{api.EncodeI64(time)} +} // Walltime implements runtime.walltime which supports time.Now. // // See https://github.com/golang/go/blob/go1.19/src/runtime/sys_wasm.s#L188 -var Walltime = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionWalltime, functionWalltime, - []string{}, - func(ctx context.Context, mod api.Module) (sec int64, nsec int32) { - return mod.(*wasm.CallContext).Sys.Walltime(ctx) - })) +var Walltime = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionWalltime}, + Name: functionWalltime, + ResultTypes: []api.ValueType{i64, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(walltime), + }, +}) + +func walltime(ctx context.Context, mod api.Module, _ []uint64) []uint64 { + sec, nsec := mod.(*wasm.CallContext).Sys.Walltime(ctx) + return []uint64{api.EncodeI64(sec), api.EncodeI32(nsec)} +} // ScheduleTimeoutEvent implements runtime.scheduleTimeoutEvent which supports // runtime.notetsleepg used by runtime.signal_recv. @@ -115,18 +153,27 @@ var ClearTimeoutEvent = stubFunction(functionClearTimeoutEvent) // for runtime.fastrand. // // See https://github.com/golang/go/blob/go1.19/src/runtime/sys_wasm.s#L200 -var GetRandomData = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionGetRandomData, functionGetRandomData, - []string{"buf", "bufLen"}, - func(ctx context.Context, mod api.Module, buf, bufLen uint32) { - randSource := mod.(*wasm.CallContext).Sys.RandSource() - - r := mustRead(ctx, mod.Memory(), "r", buf, bufLen) - - if n, err := randSource.Read(r); err != nil { - panic(fmt.Errorf("RandSource.Read(r /* len=%d */) failed: %w", bufLen, err)) - } else if uint32(n) != bufLen { - panic(fmt.Errorf("RandSource.Read(r /* len=%d */) read %d bytes", bufLen, n)) - } +var GetRandomData = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionGetRandomData}, + Name: functionGetRandomData, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"buf", "bufLen"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(getRandomData), }, -)) +}) + +func getRandomData(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + randSource := mod.(*wasm.CallContext).Sys.RandSource() + buf, bufLen := uint32(params[0]), uint32(params[1]) + + r := mustRead(ctx, mod.Memory(), "r", buf, bufLen) + + if n, err := randSource.Read(r); err != nil { + panic(fmt.Errorf("RandSource.Read(r /* len=%d */) failed: %w", bufLen, err)) + } else if uint32(n) != bufLen { + panic(fmt.Errorf("RandSource.Read(r /* len=%d */) read %d bytes", bufLen, n)) + } + return +} diff --git a/internal/gojs/spfunc/spfunc_test.go b/internal/gojs/spfunc/spfunc_test.go index b2c27619..69846d86 100644 --- a/internal/gojs/spfunc/spfunc_test.go +++ b/internal/gojs/spfunc/spfunc_test.go @@ -12,6 +12,8 @@ import ( "github.com/tetratelabs/wazero/internal/wasm/binary" ) +const i32, i64 = api.ValueTypeI32, api.ValueTypeI64 + // testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") @@ -173,7 +175,13 @@ var spMem = []byte{ 10, 0, 0, 0, 0, 0, 0, 0, } -func i64i32i32i32i32_i64i32_withSP(vRef uint64, mAddr, mLen, argsArray, argsLen uint32) (xRef uint64, ok uint32, sp uint32) { +func i64i32i32i32i32_i64i32_withSP(_ context.Context, params []uint64) []uint64 { + vRef := params[0] + mAddr := uint32(params[1]) + mLen := uint32(params[2]) + argsArray := uint32(params[3]) + argsLen := uint32(params[4]) + if vRef != 1 { panic("vRef") } @@ -189,7 +197,8 @@ func i64i32i32i32i32_i64i32_withSP(vRef uint64, mAddr, mLen, argsArray, argsLen if argsLen != 5 { panic("argsLen") } - return 10, 20, 8 + + return []uint64{10, 20, 8} } func TestMustCallFromSP(t *testing.T) { @@ -197,12 +206,18 @@ func TestMustCallFromSP(t *testing.T) { defer r.Close(testCtx) funcName := "i64i32i32i32i32_i64i32_withSP" - im, err := r.NewHostModuleBuilder("go"). - ExportFunction(funcName, MustCallFromSP(true, wasm.NewGoFunc( - funcName, funcName, - []string{"v", "mAddr", "mLen", "argsArray", "argsLen"}, - i64i32i32i32i32_i64i32_withSP))). - Instantiate(testCtx, r) + builder := r.NewHostModuleBuilder("go") + builder.(wasm.ProxyFuncExporter).ExportProxyFunc(MustCallFromSP(true, &wasm.HostFunc{ + ExportNames: []string{funcName}, + Name: funcName, + ParamTypes: []api.ValueType{i64, i32, i32, i32, i32}, + ParamNames: []string{"v", "mAddr", "mLen", "argsArray", "argsLen"}, + ResultTypes: []api.ValueType{i64, i32, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoFunc(i64i32i32i32i32_i64i32_withSP), + }})) + im, err := builder.Instantiate(testCtx, r) require.NoError(t, err) callDef := im.ExportedFunction(funcName).Definition() diff --git a/internal/gojs/syscall.go b/internal/gojs/syscall.go index 4b78fa00..f182d582 100644 --- a/internal/gojs/syscall.go +++ b/internal/gojs/syscall.go @@ -38,27 +38,47 @@ const ( // runtime.SetFinalizer on the given reference. // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L61 -var FinalizeRef = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionFinalizeRef, functionFinalizeRef, - []string{"r"}, - func(ctx context.Context, mod api.Module, id uint32) { // 32-bits of the ref are the ID - getState(ctx).values.decrement(id) +var FinalizeRef = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionFinalizeRef}, + Name: functionFinalizeRef, + ParamTypes: []api.ValueType{i32}, + ParamNames: []string{"r"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoFunc(finalizeRef), }, -)) +}) + +func finalizeRef(ctx context.Context, params []uint64) (_ []uint64) { + id := uint32(params[0]) // 32-bits of the ref are the ID + + getState(ctx).values.decrement(id) + return +} // StringVal implements js.stringVal, which is used to load the string for // `js.ValueOf(x)`. For example, this is used when setting HTTP headers. // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L212 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L305-L308 -var StringVal = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionStringVal, functionStringVal, - []string{"xAddr", "xLen"}, - func(ctx context.Context, mod api.Module, xAddr, xLen uint32) uint64 { - x := string(mustRead(ctx, mod.Memory(), "x", xAddr, xLen)) - return storeRef(ctx, x) +var StringVal = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionStringVal}, + Name: functionStringVal, + ParamTypes: []api.ValueType{i32, i32}, + ParamNames: []string{"xAddr", "xLen"}, + ResultTypes: []api.ValueType{i64}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(stringVal), }, -)) +}) + +func stringVal(ctx context.Context, mod api.Module, params []uint64) []uint64 { + xAddr, xLen := uint32(params[0]), uint32(params[1]) + + x := string(mustRead(ctx, mod.Memory(), "x", xAddr, xLen)) + return []uint64{storeRef(ctx, x)} +} // ValueGet implements js.valueGet, which is used to load a js.Value property // by name, e.g. `v.Get("address")`. Notably, this is used by js.handleEvent to @@ -66,33 +86,45 @@ var StringVal = spfunc.MustCallFromSP(false, wasm.NewGoFunc( // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L295 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L311-L316 -var ValueGet = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionValueGet, functionValueGet, - []string{"v", "pAddr", "pLen"}, - func(ctx context.Context, mod api.Module, vRef uint64, pAddr, pLen uint32) uint64 { - p := string(mustRead(ctx, mod.Memory(), "p", pAddr, pLen)) - v := loadValue(ctx, ref(vRef)) +var ValueGet = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionValueGet}, + Name: functionValueGet, + ParamTypes: []api.ValueType{i64, i32, i32}, + ParamNames: []string{"v", "pAddr", "pLen"}, + ResultTypes: []api.ValueType{i64}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(valueGet), + }, +}) - var result interface{} - if g, ok := v.(jsGet); ok { - result = g.get(ctx, p) - } else if e, ok := v.(error); ok { - switch p { - case "message": // js (GOOS=js) error, can be anything. - result = e.Error() - case "code": // syscall (GOARCH=wasm) error, must match key in mapJSError in fs_js.go - result = mapJSError(e).Error() - default: - panic(fmt.Errorf("TODO: valueGet(v=%v, p=%s)", v, p)) - } - } else { +func valueGet(ctx context.Context, mod api.Module, params []uint64) []uint64 { + vRef := params[0] + pAddr := uint32(params[1]) + pLen := uint32(params[2]) + + p := string(mustRead(ctx, mod.Memory(), "p", pAddr, pLen)) + v := loadValue(ctx, ref(vRef)) + + var result interface{} + if g, ok := v.(jsGet); ok { + result = g.get(ctx, p) + } else if e, ok := v.(error); ok { + switch p { + case "message": // js (GOOS=js) error, can be anything. + result = e.Error() + case "code": // syscall (GOARCH=wasm) error, must match key in mapJSError in fs_js.go + result = mapJSError(e).Error() + default: panic(fmt.Errorf("TODO: valueGet(v=%v, p=%s)", v, p)) } + } else { + panic(fmt.Errorf("TODO: valueGet(v=%v, p=%s)", v, p)) + } - xRef := storeRef(ctx, result) - return xRef - }, -)) + xRef := storeRef(ctx, result) + return []uint64{xRef} +} // ValueSet implements js.valueSet, which is used to store a js.Value property // by name, e.g. `v.Set("address", a)`. Notably, this is used by js.handleEvent @@ -100,34 +132,46 @@ var ValueGet = spfunc.MustCallFromSP(false, wasm.NewGoFunc( // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L309 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L318-L322 -var ValueSet = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionValueSet, functionValueSet, - []string{"v", "pAddr", "pLen", "x"}, - func(ctx context.Context, mod api.Module, vRef uint64, pAddr, pLen uint32, xRef uint64) { - v := loadValue(ctx, ref(vRef)) - p := string(mustRead(ctx, mod.Memory(), "p", pAddr, pLen)) - x := loadValue(ctx, ref(xRef)) - if v == getState(ctx) { - switch p { - case "_pendingEvent": - if x == nil { // syscall_js.handleEvent - v.(*state)._pendingEvent = nil - return - } - } - } else if e, ok := v.(*event); ok { // syscall_js.handleEvent - switch p { - case "result": - e.result = x +var ValueSet = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionValueSet}, + Name: functionValueSet, + ParamTypes: []api.ValueType{i64, i32, i32, i64}, + ParamNames: []string{"v", "pAddr", "pLen", "x"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(valueSet), + }, +}) + +func valueSet(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + vRef := params[0] + pAddr := uint32(params[1]) + pLen := uint32(params[2]) + xRef := params[3] + + v := loadValue(ctx, ref(vRef)) + p := string(mustRead(ctx, mod.Memory(), "p", pAddr, pLen)) + x := loadValue(ctx, ref(xRef)) + if v == getState(ctx) { + switch p { + case "_pendingEvent": + if x == nil { // syscall_js.handleEvent + v.(*state)._pendingEvent = nil return } - } else if m, ok := v.(*object); ok { - m.properties[p] = x // e.g. opt.Set("method", req.Method) + } + } else if e, ok := v.(*event); ok { // syscall_js.handleEvent + switch p { + case "result": + e.result = x return } - panic(fmt.Errorf("TODO: valueSet(v=%v, p=%s, x=%v)", v, p, x)) - }, -)) + } else if m, ok := v.(*object); ok { + m.properties[p] = x // e.g. opt.Set("method", req.Method) + return + } + panic(fmt.Errorf("TODO: valueSet(v=%v, p=%s, x=%v)", v, p, x)) +} // ValueDelete is stubbed as it isn't used in Go's main source tree. // @@ -140,16 +184,28 @@ var ValueDelete = stubFunction(functionValueDelete) // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L334 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L331-L334 -var ValueIndex = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionValueIndex, functionValueIndex, - []string{"v", "i"}, - func(ctx context.Context, mod api.Module, vRef uint64, i uint32) (xRef uint64) { - v := loadValue(ctx, ref(vRef)) - result := v.(*objectArray).slice[i] - xRef = storeRef(ctx, result) - return +var ValueIndex = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionValueIndex}, + Name: functionValueIndex, + ParamTypes: []api.ValueType{i64, i32}, + ParamNames: []string{"v", "i"}, + ResultTypes: []api.ValueType{i64}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoFunc(valueIndex), }, -)) +}) + +func valueIndex(ctx context.Context, params []uint64) []uint64 { + vRef := params[0] + i := uint32(params[1]) + + v := loadValue(ctx, ref(vRef)) + result := v.(*objectArray).slice[i] + xRef := storeRef(ctx, result) + + return []uint64{xRef} +} // ValueSetIndex is stubbed as it is only used for js.ValueOf when the input is // []interface{}, which doesn't appear to occur in Go's source tree. @@ -163,29 +219,45 @@ var ValueSetIndex = stubFunction(functionValueSetIndex) // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L394 // // https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L343-L358 -var ValueCall = spfunc.MustCallFromSP(true, wasm.NewGoFunc( - functionValueCall, functionValueCall, - []string{"v", "mAddr", "mLen", "argsArray", "argsLen"}, - func(ctx context.Context, mod api.Module, vRef uint64, mAddr, mLen, argsArray, argsLen uint32) (xRef uint64, ok uint32, sp uint32) { - this := ref(vRef) - v := loadValue(ctx, this) - m := string(mustRead(ctx, mod.Memory(), "m", mAddr, mLen)) - args := loadArgs(ctx, mod, argsArray, argsLen) - - if c, isCall := v.(jsCall); !isCall { - panic(fmt.Errorf("TODO: valueCall(v=%v, m=%v, args=%v)", v, m, args)) - } else if result, err := c.call(ctx, mod, this, m, args...); err != nil { - xRef = storeRef(ctx, err) - ok = 0 - } else { - xRef = storeRef(ctx, result) - ok = 1 - } - - sp = refreshSP(mod) - return +var ValueCall = spfunc.MustCallFromSP(true, &wasm.HostFunc{ + ExportNames: []string{functionValueCall}, + Name: functionValueCall, + ParamTypes: []api.ValueType{i64, i32, i32, i32, i32}, + ParamNames: []string{"v", "mAddr", "mLen", "argsArray", "argsLen"}, + ResultTypes: []api.ValueType{i64, i32, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(valueCall), }, -)) +}) + +func valueCall(ctx context.Context, mod api.Module, params []uint64) []uint64 { + vRef := params[0] + mAddr := uint32(params[1]) + mLen := uint32(params[2]) + argsArray := uint32(params[3]) + argsLen := uint32(params[4]) + + this := ref(vRef) + v := loadValue(ctx, this) + m := string(mustRead(ctx, mod.Memory(), "m", mAddr, mLen)) + args := loadArgs(ctx, mod, argsArray, argsLen) + + var xRef uint64 + var ok, sp uint32 + if c, isCall := v.(jsCall); !isCall { + panic(fmt.Errorf("TODO: valueCall(v=%v, m=%v, args=%v)", v, m, args)) + } else if result, err := c.call(ctx, mod, this, m, args...); err != nil { + xRef = storeRef(ctx, err) + ok = 0 + } else { + xRef = storeRef(ctx, result) + ok = 1 + } + + sp = refreshSP(mod) + return []uint64{xRef, uint64(ok), uint64(sp)} +} // ValueInvoke is stubbed as it isn't used in Go's main source tree. // @@ -197,67 +269,93 @@ var ValueInvoke = stubFunction(functionValueInvoke) // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L432 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L380-L391 -var ValueNew = spfunc.MustCallFromSP(true, wasm.NewGoFunc( - functionValueNew, functionValueNew, - []string{"v", "argsArray", "argsLen"}, - func(ctx context.Context, mod api.Module, vRef uint64, argsArray, argsLen uint32) (xRef uint64, ok uint32, sp uint32) { - args := loadArgs(ctx, mod, argsArray, argsLen) - ref := ref(vRef) - v := loadValue(ctx, ref) +var ValueNew = spfunc.MustCallFromSP(true, &wasm.HostFunc{ + ExportNames: []string{functionValueNew}, + Name: functionValueNew, + ParamTypes: []api.ValueType{i64, i32, i32}, + ParamNames: []string{"v", "argsArray", "argsLen"}, + ResultTypes: []api.ValueType{i64, i32, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(valueNew), + }, +}) - switch ref { - case refArrayConstructor: - result := &objectArray{} - xRef = storeRef(ctx, result) - ok = 1 - case refUint8ArrayConstructor: - var result *byteArray - if n, ok := args[0].(float64); ok { - result = &byteArray{make([]byte, uint32(n))} - } else if n, ok := args[0].(uint32); ok { - result = &byteArray{make([]byte, n)} - } else if b, ok := args[0].(*byteArray); ok { - // In case of below, in HTTP, return the same ref - // uint8arrayWrapper := uint8Array.New(args[0]) - result = b - } else { - panic(fmt.Errorf("TODO: valueNew(v=%v, args=%v)", v, args)) - } - xRef = storeRef(ctx, result) - ok = 1 - case refObjectConstructor: - result := &object{properties: map[string]interface{}{}} - xRef = storeRef(ctx, result) - ok = 1 - case refHttpHeadersConstructor: - result := &headers{headers: http.Header{}} - xRef = storeRef(ctx, result) - ok = 1 - case refJsDateConstructor: - xRef = uint64(refJsDate) - ok = 1 - default: +func valueNew(ctx context.Context, mod api.Module, params []uint64) []uint64 { + vRef := params[0] + argsArray := uint32(params[1]) + argsLen := uint32(params[2]) + + args := loadArgs(ctx, mod, argsArray, argsLen) + ref := ref(vRef) + v := loadValue(ctx, ref) + + var xRef uint64 + var ok, sp uint32 + switch ref { + case refArrayConstructor: + result := &objectArray{} + xRef = storeRef(ctx, result) + ok = 1 + case refUint8ArrayConstructor: + var result *byteArray + if n, ok := args[0].(float64); ok { + result = &byteArray{make([]byte, uint32(n))} + } else if n, ok := args[0].(uint32); ok { + result = &byteArray{make([]byte, n)} + } else if b, ok := args[0].(*byteArray); ok { + // In case of below, in HTTP, return the same ref + // uint8arrayWrapper := uint8Array.New(args[0]) + result = b + } else { panic(fmt.Errorf("TODO: valueNew(v=%v, args=%v)", v, args)) } + xRef = storeRef(ctx, result) + ok = 1 + case refObjectConstructor: + result := &object{properties: map[string]interface{}{}} + xRef = storeRef(ctx, result) + ok = 1 + case refHttpHeadersConstructor: + result := &headers{headers: http.Header{}} + xRef = storeRef(ctx, result) + ok = 1 + case refJsDateConstructor: + xRef = uint64(refJsDate) + ok = 1 + default: + panic(fmt.Errorf("TODO: valueNew(v=%v, args=%v)", v, args)) + } - sp = refreshSP(mod) - return - }, -)) + sp = refreshSP(mod) + return []uint64{xRef, uint64(ok), uint64(sp)} +} // ValueLength implements js.valueLength, which is used to load the length // property of a value, e.g. `array.length`. // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L372 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L396-L397 -var ValueLength = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionValueLength, functionValueLength, - []string{"v"}, - func(ctx context.Context, mod api.Module, vRef uint64) uint32 { - v := loadValue(ctx, ref(vRef)) - return uint32(len(v.(*objectArray).slice)) +var ValueLength = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionValueLength}, + Name: functionValueLength, + ParamTypes: []api.ValueType{i64}, + ParamNames: []string{"v"}, + ResultTypes: []api.ValueType{i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoFunc(valueLength), }, -)) +}) + +func valueLength(ctx context.Context, params []uint64) []uint64 { + vRef := params[0] + + v := loadValue(ctx, ref(vRef)) + l := uint32(len(v.(*objectArray).slice)) + + return []uint64{uint64(l)} +} // ValuePrepareString implements js.valuePrepareString, which is used to load // the string for `o.String()` (via js.jsString) for string, boolean and @@ -266,17 +364,29 @@ var ValueLength = spfunc.MustCallFromSP(false, wasm.NewGoFunc( // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L531 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L402-L405 -var ValuePrepareString = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionValuePrepareString, functionValuePrepareString, - []string{"v"}, - func(ctx context.Context, mod api.Module, vRef uint64) (sRef uint64, sLen uint32) { - v := loadValue(ctx, ref(vRef)) - s := valueString(v) - sRef = storeRef(ctx, s) - sLen = uint32(len(s)) - return +var ValuePrepareString = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionValuePrepareString}, + Name: functionValuePrepareString, + ParamTypes: []api.ValueType{i64}, + ParamNames: []string{"v"}, + ResultTypes: []api.ValueType{i64, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoFunc(valuePrepareString), }, -)) +}) + +func valuePrepareString(ctx context.Context, params []uint64) []uint64 { + vRef := params[0] + + v := loadValue(ctx, ref(vRef)) + s := valueString(v) + + sRef := storeRef(ctx, s) + sLen := uint32(len(s)) + + return []uint64{sRef, uint64(sLen)} +} // ValueLoadString implements js.valueLoadString, which is used copy a string // value for `o.String()`. @@ -284,16 +394,28 @@ var ValuePrepareString = spfunc.MustCallFromSP(false, wasm.NewGoFunc( // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L533 // // https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L410-L412 -var ValueLoadString = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionValueLoadString, functionValueLoadString, - []string{"v", "bAddr", "bLen"}, - func(ctx context.Context, mod api.Module, vRef uint64, bAddr, bLen uint32) { - v := loadValue(ctx, ref(vRef)) - s := valueString(v) - b := mustRead(ctx, mod.Memory(), "b", bAddr, bLen) - copy(b, s) +var ValueLoadString = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionValueLoadString}, + Name: functionValueLoadString, + ParamTypes: []api.ValueType{i64, i32, i32}, + ParamNames: []string{"v", "bAddr", "bLen"}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(valueLoadString), }, -)) +}) + +func valueLoadString(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + vRef := params[0] + bAddr := uint32(params[1]) + bLen := uint32(params[2]) + + v := loadValue(ctx, ref(vRef)) + s := valueString(v) + b := mustRead(ctx, mod.Memory(), "b", bAddr, bLen) + copy(b, s) + return +} // ValueInstanceOf is stubbed as it isn't used in Go's main source tree. // @@ -310,19 +432,35 @@ var ValueInstanceOf = stubFunction(functionValueInstanceOf) // // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L569 // and https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L424-L433 -var CopyBytesToGo = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionCopyBytesToGo, functionCopyBytesToGo, - []string{"dstAddr", "dstLen", "src"}, - func(ctx context.Context, mod api.Module, dstAddr, dstLen, _ uint32, srcRef uint64) (n, ok uint32) { - dst := mustRead(ctx, mod.Memory(), "dst", dstAddr, dstLen) // nolint - v := loadValue(ctx, ref(srcRef)) - if src, isBuf := v.(*byteArray); isBuf { - n = uint32(copy(dst, src.slice)) - ok = 1 - } - return +var CopyBytesToGo = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionCopyBytesToGo}, + Name: functionCopyBytesToGo, + ParamTypes: []api.ValueType{i32, i32, i32, i64}, + ParamNames: []string{"dstAddr", "dstLen", "_", "src"}, + ResultTypes: []api.ValueType{i32, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(copyBytesToGo), }, -)) +}) + +func copyBytesToGo(ctx context.Context, mod api.Module, params []uint64) []uint64 { + dstAddr := uint32(params[0]) + dstLen := uint32(params[1]) + _ /* unknown */ = uint32(params[2]) + srcRef := params[3] + + dst := mustRead(ctx, mod.Memory(), "dst", dstAddr, dstLen) // nolint + v := loadValue(ctx, ref(srcRef)) + + var n, ok uint32 + if src, isBuf := v.(*byteArray); isBuf { + n = uint32(copy(dst, src.slice)) + ok = 1 + } + + return []uint64{uint64(n), uint64(ok)} +} // CopyBytesToJS copies linear memory to a JavaScript managed byte array. // For example, this is used to read an HTTP request body. @@ -335,21 +473,37 @@ var CopyBytesToGo = spfunc.MustCallFromSP(false, wasm.NewGoFunc( // See https://github.com/golang/go/blob/go1.19/src/syscall/js/js.go#L583 // // https://github.com/golang/go/blob/go1.19/misc/wasm/wasm_exec.js#L438-L448 -var CopyBytesToJS = spfunc.MustCallFromSP(false, wasm.NewGoFunc( - functionCopyBytesToJS, functionCopyBytesToJS, - []string{"dst", "srcAddr", "srcLen"}, - func(ctx context.Context, mod api.Module, dstRef uint64, srcAddr, srcLen, _ uint32) (n, ok uint32) { - src := mustRead(ctx, mod.Memory(), "src", srcAddr, srcLen) // nolint - v := loadValue(ctx, ref(dstRef)) - if dst, isBuf := v.(*byteArray); isBuf { - if dst != nil { // empty is possible on EOF - n = uint32(copy(dst.slice, src)) - } - ok = 1 - } - return +var CopyBytesToJS = spfunc.MustCallFromSP(false, &wasm.HostFunc{ + ExportNames: []string{functionCopyBytesToJS}, + Name: functionCopyBytesToJS, + ParamTypes: []api.ValueType{i64, i32, i32, i32}, + ParamNames: []string{"dst", "srcAddr", "srcLen", "_"}, + ResultTypes: []api.ValueType{i32, i32}, + Code: &wasm.Code{ + IsHostFunction: true, + GoFunc: api.GoModuleFunc(copyBytesToJS), }, -)) +}) + +func copyBytesToJS(ctx context.Context, mod api.Module, params []uint64) []uint64 { + dstRef := params[0] + srcAddr := uint32(params[1]) + srcLen := uint32(params[2]) + _ /* unknown */ = uint32(params[3]) + + src := mustRead(ctx, mod.Memory(), "src", srcAddr, srcLen) // nolint + v := loadValue(ctx, ref(dstRef)) + + var n, ok uint32 + if dst, isBuf := v.(*byteArray); isBuf { + if dst != nil { // empty is possible on EOF + n = uint32(copy(dst.slice, src)) + } + ok = 1 + } + + return []uint64{uint64(n), uint64(ok)} +} // refreshSP refreshes the stack pointer, which is needed prior to storeValue // when in an operation that can trigger a Go event handler. diff --git a/internal/integration_test/bench/bench_test.go b/internal/integration_test/bench/bench_test.go index 7fa8f48c..930af983 100644 --- a/internal/integration_test/bench/bench_test.go +++ b/internal/integration_test/bench/bench_test.go @@ -216,7 +216,7 @@ func createRuntime(b *testing.B, config wazero.RuntimeConfig) wazero.Runtime { r := wazero.NewRuntimeWithConfig(testCtx, config) _, err := r.NewHostModuleBuilder("env"). - ExportFunction("get_random_string", getRandomString). + NewFunctionBuilder().WithFunc(getRandomString).Export("get_random_string"). Instantiate(testCtx, r) if err != nil { b.Fatal(err) diff --git a/internal/integration_test/bench/hostfunc_bench_test.go b/internal/integration_test/bench/hostfunc_bench_test.go index 67d07f7a..cee9d982 100644 --- a/internal/integration_test/bench/hostfunc_bench_test.go +++ b/internal/integration_test/bench/hostfunc_bench_test.go @@ -16,9 +16,14 @@ import ( ) const ( - // callGoHostName is the name of exported function which calls the Go-implemented host function. + // callGoHostName is the name of exported function which calls the + // Go-implemented host function. callGoHostName = "call_go_host" - // callWasmHostName is the name of exported function which calls the Wasm-implemented host function. + // callGoReflectHostName is the name of exported function which calls the + // Go-implemented host function defined in reflection. + callGoReflectHostName = "call_go_reflect_host" + // callWasmHostName is the name of exported function which calls the + // Wasm-implemented host function. callWasmHostName = "call_wasm_host" ) @@ -35,46 +40,32 @@ func BenchmarkHostFunctionCall(b *testing.B) { } }) - const offset = 100 + const offset = uint64(100) const val = float32(1.1234) binary.LittleEndian.PutUint32(m.Memory.Buffer[offset:], math.Float32bits(val)) - b.Run(callGoHostName, func(b *testing.B) { - ce, err := getCallEngine(m, callGoHostName) - if err != nil { - b.Fatal(err) - } + for _, fn := range []string{callGoReflectHostName, callGoHostName, callWasmHostName} { + fn := fn - b.ResetTimer() - for i := 0; i < b.N; i++ { - res, err := ce.Call(testCtx, m.CallCtx, offset) + b.Run(fn, func(b *testing.B) { + ce, err := getCallEngine(m, fn) if err != nil { b.Fatal(err) } - if uint32(res[0]) != math.Float32bits(val) { - b.Fail() - } - } - }) - b.Run(callWasmHostName, func(b *testing.B) { - ce, err := getCallEngine(m, callWasmHostName) - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - res, err := ce.Call(testCtx, m.CallCtx, offset) - if err != nil { - b.Fatal(err) + b.ResetTimer() + for i := 0; i < b.N; i++ { + res, err := ce.Call(testCtx, m.CallCtx, []uint64{offset}) + if err != nil { + b.Fatal(err) + } + if uint32(res[0]) != math.Float32bits(val) { + b.Fail() + } } - if uint32(res[0]) != math.Float32bits(val) { - b.Fail() - } - } - }) + }) + } } func TestBenchmarkFunctionCall(t *testing.T) { @@ -92,8 +83,12 @@ func TestBenchmarkFunctionCall(t *testing.T) { callGoHost, err := getCallEngine(m, callGoHostName) require.NoError(t, err) + callGoReflectHost, err := getCallEngine(m, callGoReflectHostName) + require.NoError(t, err) + require.NotNil(t, callWasmHost) require.NotNil(t, callGoHost) + require.NotNil(t, callGoReflectHost) tests := []struct { offset uint32 @@ -111,13 +106,14 @@ func TestBenchmarkFunctionCall(t *testing.T) { ce wasm.CallEngine }{ {name: "go", ce: callGoHost}, + {name: "go-reflect", ce: callGoReflectHost}, {name: "wasm", ce: callWasmHost}, } { f := f t.Run(f.name, func(t *testing.T) { for _, tc := range tests { binary.LittleEndian.PutUint32(mem[tc.offset:], math.Float32bits(tc.val)) - res, err := f.ce.Call(context.Background(), m.CallCtx, uint64(tc.offset)) + res, err := f.ce.Call(context.Background(), m.CallCtx, []uint64{uint64(tc.offset)}) require.NoError(t, err) require.Equal(t, math.Float32bits(tc.val), uint32(res[0])) } @@ -148,9 +144,19 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { // Build the host module. hostModule := &wasm.Module{ TypeSection: []*wasm.FunctionType{ft}, - FunctionSection: []wasm.Index{0, 0}, + FunctionSection: []wasm.Index{0, 0, 0}, CodeSection: []*wasm.Code{ - wasm.MustParseGoFuncCode( + { + IsHostFunction: true, + GoFunc: api.GoModuleFunc(func(ctx context.Context, mod api.Module, params []uint64) []uint64 { + ret, ok := mod.Memory().ReadUint32Le(ctx, uint32(params[0])) + if !ok { + panic("couldn't read memory") + } + return []uint64{uint64(ret)} + }), + }, + wasm.MustParseGoReflectFuncCode( func(ctx context.Context, m api.Module, pos uint32) float32 { ret, ok := m.Memory().ReadUint32Le(ctx, pos) if !ok { @@ -171,7 +177,8 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { }, ExportSection: []*wasm.Export{ {Name: "go", Type: wasm.ExternTypeFunc, Index: 0}, - {Name: "wasm", Type: wasm.ExternTypeFunc, Index: 1}, + {Name: "go-reflect", Type: wasm.ExternTypeFunc, Index: 1}, + {Name: "wasm", Type: wasm.ExternTypeFunc, Index: 2}, }, ID: wasm.ModuleID{1, 2, 3, 4, 5}, } @@ -180,7 +187,9 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { host := &wasm.ModuleInstance{Name: "host", TypeIDs: []wasm.FunctionTypeID{0}} host.Functions = host.BuildFunctions(hostModule, nil) host.BuildExports(hostModule.ExportSection) - goFn, wasnFn := host.Exports["go"].Function, host.Exports["wasm"].Function + goFn := host.Exports["go"].Function + goReflectFn := host.Exports["go-reflect"].Function + wasnFn := host.Exports["wasm"].Function err := eng.CompileModule(testCtx, hostModule) requireNoError(err) @@ -196,15 +205,18 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { // Placeholders for imports from hostModule. {Type: wasm.ExternTypeFunc}, {Type: wasm.ExternTypeFunc}, + {Type: wasm.ExternTypeFunc}, }, - FunctionSection: []wasm.Index{0, 0}, + FunctionSection: []wasm.Index{0, 0, 0}, ExportSection: []*wasm.Export{ {Name: callGoHostName, Type: wasm.ExternTypeFunc, Index: 2}, - {Name: callWasmHostName, Type: wasm.ExternTypeFunc, Index: 3}, + {Name: callGoReflectHostName, Type: wasm.ExternTypeFunc, Index: 3}, + {Name: callWasmHostName, Type: wasm.ExternTypeFunc, Index: 4}, }, CodeSection: []*wasm.Code{ {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 0, wasm.OpcodeEnd}}, // Calling the index 0 = host.go. - {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 1, wasm.OpcodeEnd}}, // Calling the index 1 = host.wasm. + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 1, wasm.OpcodeEnd}}, // Calling the index 1 = host.go-reflect. + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 2, wasm.OpcodeEnd}}, // Calling the index 2 = host.wasm. }, // Indicates that this module has a memory so that compilers are able to assemble memory-related initialization. MemorySection: &wasm.Memory{Min: 1}, @@ -220,7 +232,7 @@ func setupHostCallBench(requireNoError func(error)) *wasm.ModuleInstance { importing.Functions = append([]*wasm.FunctionInstance{goFn, wasnFn}, importingFunctions...) importing.BuildExports(importingModule.ExportSection) - importingMe, err := eng.NewModuleEngine(importing.Name, importingModule, []*wasm.FunctionInstance{goFn, wasnFn}, importingFunctions, nil, nil) + importingMe, err := eng.NewModuleEngine(importing.Name, importingModule, []*wasm.FunctionInstance{goFn, goReflectFn, wasnFn}, importingFunctions, nil, nil) requireNoError(err) linkModuleToEngine(importing, importingMe) diff --git a/internal/integration_test/engine/adhoc_test.go b/internal/integration_test/engine/adhoc_test.go index 1d5971b3..5cbfcd22 100644 --- a/internal/integration_test/engine/adhoc_test.go +++ b/internal/integration_test/engine/adhoc_test.go @@ -92,12 +92,13 @@ func testReftypeImports(t *testing.T, r wazero.Runtime) { hostObj := &dog{name: "hello"} host, err := r.NewHostModuleBuilder("host"). - ExportFunctions(map[string]interface{}{ - "externref": func(externrefFromRefNull uintptr) uintptr { - require.Zero(t, externrefFromRefNull) - return uintptr(unsafe.Pointer(hostObj)) - }, - }).Instantiate(testCtx, r) + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, externrefFromRefNull uintptr) uintptr { + require.Zero(t, externrefFromRefNull) + return uintptr(unsafe.Pointer(hostObj)) + }). + Export("externref"). + Instantiate(testCtx, r) require.NoError(t, err) defer host.Close(testCtx) @@ -164,11 +165,13 @@ func testGlobalExtend(t *testing.T, r wazero.Runtime) { } func testUnreachable(t *testing.T, r wazero.Runtime) { - callUnreachable := func(nil api.Module) { + callUnreachable := func(context.Context) { panic("panic in host function") } - _, err := r.NewHostModuleBuilder("host").ExportFunction("cause_unreachable", callUnreachable).Instantiate(testCtx, r) + _, err := r.NewHostModuleBuilder("host"). + NewFunctionBuilder().WithFunc(callUnreachable).Export("cause_unreachable"). + Instantiate(testCtx, r) require.NoError(t, err) module, err := r.InstantiateModuleFromBinary(testCtx, unreachableWasm) @@ -186,12 +189,14 @@ wasm stack trace: } func testRecursiveEntry(t *testing.T, r wazero.Runtime) { - hostfunc := func(mod api.Module) { + hostfunc := func(ctx context.Context, mod api.Module) { _, err := mod.ExportedFunction("called_by_host_func").Call(testCtx) require.NoError(t, err) } - _, err := r.NewHostModuleBuilder("env").ExportFunction("host_func", hostfunc).Instantiate(testCtx, r) + _, err := r.NewHostModuleBuilder("env"). + NewFunctionBuilder().WithFunc(hostfunc).Export("host_func"). + Instantiate(testCtx, r) require.NoError(t, err) module, err := r.InstantiateModuleFromBinary(testCtx, recursiveWasm) @@ -214,7 +219,9 @@ func testHostFuncMemory(t *testing.T, r wazero.Runtime) { return 0 } - host, err := r.NewHostModuleBuilder("").ExportFunction("store_int", storeInt).Instantiate(testCtx, r) + host, err := r.NewHostModuleBuilder(""). + NewFunctionBuilder().WithFunc(storeInt).Export("store_int"). + Instantiate(testCtx, r) require.NoError(t, err) defer host.Close(testCtx) @@ -240,21 +247,24 @@ func testNestedGoContext(t *testing.T, r wazero.Runtime) { importingName := t.Name() + "-importing" var importing api.Module - fns := map[string]interface{}{ - "inner": func(ctx context.Context, p uint32) uint32 { + + imported, err := r.NewHostModuleBuilder(importedName). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, p uint32) uint32 { // We expect the initial context, testCtx, to be overwritten by "outer" when it called this. require.Equal(t, nestedCtx, ctx) return p + 1 - }, - "outer": func(ctx context.Context, module api.Module, p uint32) uint32 { + }). + Export("inner"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, module api.Module, p uint32) uint32 { require.Equal(t, testCtx, ctx) results, err := module.ExportedFunction("inner").Call(nestedCtx, uint64(p)) require.NoError(t, err) return uint32(results[0]) + 1 - }, - } - - imported, err := r.NewHostModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx, r) + }). + Export("outer"). + Instantiate(testCtx, r) require.NoError(t, err) defer imported.Close(testCtx) @@ -276,14 +286,11 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { var importing api.Module fns := map[string]interface{}{ - "no_context": func(p uint32) uint32 { - return p + 1 - }, - "go_context": func(ctx context.Context, p uint32) uint32 { + "ctx": func(ctx context.Context, p uint32) uint32 { require.Equal(t, testCtx, ctx) return p + 1 }, - "module_context": func(module api.Module, p uint32) uint32 { + "ctx mod": func(ctx context.Context, module api.Module, p uint32) uint32 { require.Equal(t, importing, module) return p + 1 }, @@ -292,7 +299,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { for test := range fns { t.Run(test, func(t *testing.T) { imported, err := r.NewHostModuleBuilder(importedName). - ExportFunction("return_input", fns[test]). + NewFunctionBuilder().WithFunc(fns[test]).Export("return_input"). Instantiate(testCtx, r) require.NoError(t, err) defer imported.Close(testCtx) @@ -316,16 +323,16 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { importingName := t.Name() + "-importing" fns := map[string]interface{}{ - "i32": func(p uint32) uint32 { + "i32": func(ctx context.Context, p uint32) uint32 { return p + 1 }, - "i64": func(p uint64) uint64 { + "i64": func(ctx context.Context, p uint64) uint64 { return p + 1 }, - "f32": func(p float32) float32 { + "f32": func(ctx context.Context, p float32) float32 { return p + 1 }, - "f64": func(p float64) float64 { + "f64": func(ctx context.Context, p float64) float64 { return p + 1 }, } @@ -362,7 +369,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { } { t.Run(test.name, func(t *testing.T) { imported, err := r.NewHostModuleBuilder(importedName). - ExportFunction("return_input", fns[test.name]). + NewFunctionBuilder().WithFunc(fns[test.name]).Export("return_input"). Instantiate(testCtx, r) require.NoError(t, err) defer imported.Close(testCtx) @@ -511,8 +518,9 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { } // Create the host module, which exports the function that closes the importing module. - importedCode, err = r.NewHostModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", closeAndReturn).Compile(testCtx) + importedCode, err = r.NewHostModuleBuilder(t.Name() + "-imported"). + NewFunctionBuilder().WithFunc(closeAndReturn).Export("return_input"). + Compile(testCtx) require.NoError(t, err) imported, err = r.InstantiateModule(testCtx, importedCode, moduleConfig) diff --git a/internal/integration_test/engine/hammer_test.go b/internal/integration_test/engine/hammer_test.go index 9b181bc8..3d6cefbb 100644 --- a/internal/integration_test/engine/hammer_test.go +++ b/internal/integration_test/engine/hammer_test.go @@ -1,6 +1,7 @@ package adhoc import ( + "context" "sync" "testing" @@ -49,9 +50,13 @@ func closeImportedModuleWhileInUse(t *testing.T, r wazero.Runtime) { require.NoError(t, imported.Close(testCtx)) // Redefine the imported module, with a function that no longer blocks. - imported, err := r.NewHostModuleBuilder(imported.Name()).ExportFunction("return_input", func(x uint32) uint32 { - return x - }).Instantiate(testCtx, r) + imported, err := r.NewHostModuleBuilder(imported.Name()). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, x uint32) uint32 { + return x + }). + Export("return_input"). + Instantiate(testCtx, r) require.NoError(t, err) // Redefine the importing module, which should link to the redefined host module. @@ -72,14 +77,15 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported // To know return path works on a closed module, we need to block calls. var calls sync.WaitGroup calls.Add(P) - blockAndReturn := func(x uint32) uint32 { + blockAndReturn := func(ctx context.Context, x uint32) uint32 { calls.Wait() return x } // Create the host module, which exports the blocking function. imported, err := r.NewHostModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", blockAndReturn).Instantiate(testCtx, r) + NewFunctionBuilder().WithFunc(blockAndReturn).Export("return_input"). + Instantiate(testCtx, r) require.NoError(t, err) defer imported.Close(testCtx) diff --git a/internal/integration_test/vs/runtime.go b/internal/integration_test/vs/runtime.go index 124b446b..7ed957f5 100644 --- a/internal/integration_test/vs/runtime.go +++ b/internal/integration_test/vs/runtime.go @@ -79,14 +79,18 @@ func (m *wazeroModule) Memory() []byte { return m.mod.Memory().(*wasm.MemoryInstance).Buffer } -func (r *wazeroRuntime) log(ctx context.Context, m api.Module, offset, byteCount uint32) { - buf, ok := m.Memory().Read(ctx, offset, byteCount) +func (r *wazeroRuntime) log(ctx context.Context, mod api.Module, params []uint64) (_ []uint64) { + offset, byteCount := uint32(params[0]), uint32(params[1]) + + buf, ok := mod.Memory().Read(ctx, offset, byteCount) if !ok { panic("out of memory reading log buffer") } if err := r.logFn(buf); err != nil { panic(err) } + + return } func (r *wazeroRuntime) Compile(ctx context.Context, cfg *RuntimeConfig) (err error) { @@ -94,17 +98,20 @@ func (r *wazeroRuntime) Compile(ctx context.Context, cfg *RuntimeConfig) (err er if cfg.LogFn != nil { r.logFn = cfg.LogFn if r.env, err = r.runtime.NewHostModuleBuilder("env"). - ExportFunction("log", r.log).Compile(ctx); err != nil { + NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc(r.log), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{}). + Export("log"). + Compile(ctx); err != nil { return err } } else if cfg.EnvFReturnValue != 0 { if r.env, err = r.runtime.NewHostModuleBuilder("env"). - ExportFunction("f", - // Note: accepting (context.Context, api.Module) is the slowest type of host function with wazero. - func(context.Context, api.Module, uint64) uint64 { - return cfg.EnvFReturnValue - }, - ).Compile(ctx); err != nil { + NewFunctionBuilder(). + WithGoFunction(api.GoFunc(func(context.Context, []uint64) []uint64 { + return []uint64{cfg.EnvFReturnValue} + }), []api.ValueType{api.ValueTypeI64}, []api.ValueType{api.ValueTypeI64}). + Export("f"). + Compile(ctx); err != nil { return err } } diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 48bfa77e..0b84c597 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -158,7 +158,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { ce, err := me.NewCallEngine(module.CallCtx, fn) require.NoError(t, err) - results, err := ce.Call(testCtx, module.CallCtx, 1, 2) + results, err := ce.Call(testCtx, module.CallCtx, []uint64{1, 2}) require.NoError(t, err) require.Equal(t, []uint64{1, 2}, results) @@ -166,7 +166,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { ce, err := me.NewCallEngine(module.CallCtx, fn) require.NoError(t, err) - _, err = ce.Call(testCtx, module.CallCtx) + _, err = ce.Call(testCtx, module.CallCtx, nil) require.EqualError(t, err, "expected 2 params, but passed 0") }) @@ -174,7 +174,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { ce, err := me.NewCallEngine(module.CallCtx, fn) require.NoError(t, err) - _, err = ce.Call(testCtx, module.CallCtx, 1, 2, 3) + _, err = ce.Call(testCtx, module.CallCtx, []uint64{1, 2, 3}) require.EqualError(t, err, "expected 2 params, but passed 3") }) } @@ -377,7 +377,7 @@ func runTestModuleEngine_Call_HostFn_Mem(t *testing.T, et EngineTester, readMem ce, err := tc.fn.Module.Engine.NewCallEngine(tc.fn.Module.CallCtx, tc.fn) require.NoError(t, err) - results, err := ce.Call(testCtx, importing.CallCtx) + results, err := ce.Call(testCtx, importing.CallCtx, nil) require.NoError(t, err) require.Equal(t, tc.expected, results[0]) }) @@ -433,7 +433,7 @@ func runTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester, hostDivBy *w ce, err := f.Module.Engine.NewCallEngine(m, f) require.NoError(t, err) - results, err := ce.Call(testCtx, m, 1) + results, err := ce.Call(testCtx, m, []uint64{1}) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -529,11 +529,11 @@ wasm stack trace: ce, err := f.Module.Engine.NewCallEngine(m, f) require.NoError(t, err) - _, err = ce.Call(testCtx, m, tc.input...) + _, err = ce.Call(testCtx, m, tc.input) require.EqualError(t, err, tc.expectedErr) // Ensure the module still works - results, err := ce.Call(testCtx, m, 1) + results, err := ce.Call(testCtx, m, []uint64{1}) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -617,7 +617,7 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { // Initialize the memory using Wasm. This copies the test phrase. initCallEngine, err := me.NewCallEngine(module.CallCtx, init) require.NoError(t, err) - _, err = initCallEngine.Call(testCtx, module.CallCtx) + _, err = initCallEngine.Call(testCtx, module.CallCtx, nil) require.NoError(t, err) // We expect the same []byte read earlier to now include the phrase in wasm. @@ -649,7 +649,7 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { // Now, we need to prove the other direction, that when Wasm changes the capacity, the host's buffer is unaffected. growCallEngine, err := me.NewCallEngine(module.CallCtx, grow) require.NoError(t, err) - _, err = growCallEngine.Call(testCtx, module.CallCtx, 1) + _, err = growCallEngine.Call(testCtx, module.CallCtx, []uint64{1}) require.NoError(t, err) // The host buffer should still contain the same bytes as before grow @@ -658,7 +658,7 @@ func RunTestModuleEngine_Memory(t *testing.T, et EngineTester) { // Re-initialize the memory in wasm, which overwrites the region. initCallEngine2, err := me.NewCallEngine(module.CallCtx, init) require.NoError(t, err) - _, err = initCallEngine2.Call(testCtx, module.CallCtx) + _, err = initCallEngine2.Call(testCtx, module.CallCtx, nil) require.NoError(t, err) // The host was not affected because it is a different slice due to "memory.grow" affecting the underlying memory. @@ -672,14 +672,14 @@ const ( callImportCallDivByGoName = "call_import->" + callDivByGoName ) -func divByGo(d uint32) uint32 { +func divByGo(_ context.Context, d uint32) uint32 { if d == math.MaxUint32 { panic(errors.New("host-function panic")) } return 1 / d // go panics if d == 0 } -var hostDivByGo = wasm.MustParseGoFuncCode(divByGo) +var hostDivByGo = wasm.MustParseGoReflectFuncCode(divByGo) // (func (export "div_by.wasm") (param i32) (result i32) (i32.div_u (i32.const 1) (local.get 0))) var divByWasm = []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeLocalGet, 0, wasm.OpcodeI32DivU, wasm.OpcodeEnd} @@ -700,7 +700,7 @@ func readMemGo(ctx context.Context, m api.Module) uint64 { return ret } -var hostReadMemGo = wasm.MustParseGoFuncCode(readMemGo) +var hostReadMemGo = wasm.MustParseGoReflectFuncCode(readMemGo) // (func (export "wasm_read_mem") (result i64) i32.const 0 i64.load) var readMemWasm = []byte{wasm.OpcodeI32Const, 0, wasm.OpcodeI64Load, 0x3, 0x0, wasm.OpcodeEnd} diff --git a/internal/testing/require/require.go b/internal/testing/require/require.go index d36fb420..94fd5b1d 100644 --- a/internal/testing/require/require.go +++ b/internal/testing/require/require.go @@ -22,6 +22,10 @@ type TestingT interface { Fatal(args ...interface{}) } +type EqualTo interface { + EqualTo(that interface{}) bool +} + // TODO: implement, test and document each function without using testify // Contains fails if `s` does not contain `substr` using strings.Contains. @@ -69,6 +73,18 @@ func Equal(t TestingT, expected, actual interface{}, formatWithArgs ...interface } else if et.Kind() < reflect.Array { fail(t, fmt.Sprintf("expected %v, but was %v", expected, actual), "", formatWithArgs...) return + } else if et.Kind() == reflect.Func { + // compare funcs by string pointer + expected := fmt.Sprintf("%v", expected) + actual := fmt.Sprintf("%v", actual) + if expected != actual { + fail(t, fmt.Sprintf("expected %s, but was %s", expected, actual), "", formatWithArgs...) + } + return + } else if eq, ok := actual.(EqualTo); ok { + if !eq.EqualTo(expected) { + fail(t, fmt.Sprintf("expected %v, but was %v", expected, actual), "", formatWithArgs...) + } } // If we have the same type, and it isn't a string, but the expected and actual values on a different line. diff --git a/internal/wasm/binary/code.go b/internal/wasm/binary/code.go index c0134c0a..09511024 100644 --- a/internal/wasm/binary/code.go +++ b/internal/wasm/binary/code.go @@ -84,7 +84,7 @@ func decodeCode(r *bytes.Reader) (*wasm.Code, error) { // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-code func encodeCode(c *wasm.Code) []byte { if c.GoFunc != nil { - panic("BUG: GoFunc is not encodable") + panic("BUG: GoFunction is not encodable") } // local blocks compress locals while preserving index order by grouping locals of the same type. diff --git a/internal/wasm/binary/encoder_test.go b/internal/wasm/binary/encoder_test.go index 3f8925e7..544d1c6a 100644 --- a/internal/wasm/binary/encoder_test.go +++ b/internal/wasm/binary/encoder_test.go @@ -1,9 +1,9 @@ package binary import ( + "context" "testing" - "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/wasm" @@ -211,13 +211,13 @@ func TestModule_Encode(t *testing.T) { func TestModule_Encode_HostFunctionSection_Unsupported(t *testing.T) { // We don't currently have an approach to serialize reflect.Value pointers - fn := func(api.Module) {} + fn := func(context.Context) {} captured := require.CapturePanic(func() { EncodeModule(&wasm.Module{ TypeSection: []*wasm.FunctionType{{}}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(fn)}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(fn)}, }) }) - require.EqualError(t, captured, "BUG: GoFunc is not encodable") + require.EqualError(t, captured, "BUG: GoFunction is not encodable") } diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index c6efbfb9..c4b43ce5 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -168,7 +168,7 @@ func (f *function) Definition() api.FunctionDefinition { // Call implements the same method as documented on api.Function. func (f *function) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { - return f.ce.Call(ctx, f.fi.Module.CallCtx, params...) + return f.ce.Call(ctx, f.fi.Module.CallCtx, params) } // importedFn implements api.Function and ensures the call context of an imported function is the importing module. @@ -189,7 +189,7 @@ func (f *importedFn) Call(ctx context.Context, params ...uint64) (ret []uint64, return nil, fmt.Errorf("directly calling host function is not supported") } mod := f.importingModule - return f.ce.Call(ctx, mod, params...) + return f.ce.Call(ctx, mod, params) } // GlobalVal is an internal hack to get the lower 64 bits of a global. diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index e4bd9fe9..27e349e6 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -59,7 +59,7 @@ type ModuleEngine interface { // internally, and shouldn't be used concurrently. type CallEngine interface { // Call invokes a function instance f with given parameters. - Call(ctx context.Context, m *CallContext, params ...uint64) (results []uint64, err error) + Call(ctx context.Context, m *CallContext, params []uint64) (results []uint64, err error) } // TableInitEntry is normalized element segment used for initializing tables by engines. diff --git a/internal/wasm/function_definition.go b/internal/wasm/function_definition.go index 238deea8..358ca329 100644 --- a/internal/wasm/function_definition.go +++ b/internal/wasm/function_definition.go @@ -1,8 +1,6 @@ package wasm import ( - "reflect" - "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/wasmdebug" ) @@ -110,7 +108,7 @@ type FunctionDefinition struct { index Index name string debugName string - goFunc *reflect.Value + goFunc interface{} funcType *FunctionType importDesc *[2]string exportNames []string @@ -151,7 +149,7 @@ func (f *FunctionDefinition) ExportNames() []string { } // GoFunc implements the same method as documented on api.FunctionDefinition. -func (f *FunctionDefinition) GoFunc() *reflect.Value { +func (f *FunctionDefinition) GoFunction() interface{} { return f.goFunc } diff --git a/internal/wasm/function_definition_test.go b/internal/wasm/function_definition_test.go index 8fe3c1aa..93cf53fc 100644 --- a/internal/wasm/function_definition_test.go +++ b/internal/wasm/function_definition_test.go @@ -1,6 +1,7 @@ package wasm import ( + "context" "testing" "github.com/tetratelabs/wazero/api" @@ -9,7 +10,7 @@ import ( func TestModule_BuildFunctionDefinitions(t *testing.T) { nopCode := &Code{Body: []byte{OpcodeEnd}} - fn := func() {} + fn := func(context.Context) {} tests := []struct { name string m *Module @@ -35,13 +36,13 @@ func TestModule_BuildFunctionDefinitions(t *testing.T) { m: &Module{ TypeSection: []*FunctionType{v_v}, FunctionSection: []Index{0}, - CodeSection: []*Code{MustParseGoFuncCode(fn)}, + CodeSection: []*Code{MustParseGoReflectFuncCode(fn)}, }, expected: []*FunctionDefinition{ { index: 0, debugName: ".$0", - goFunc: MustParseGoFuncCode(fn).GoFunc, + goFunc: MustParseGoReflectFuncCode(fn).GoFunc, funcType: v_v, }, }, diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index 5c31f8da..f527654d 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -1,7 +1,9 @@ package wasm import ( + "bytes" "context" + "errors" "fmt" "math" "reflect" @@ -9,53 +11,68 @@ import ( "github.com/tetratelabs/wazero/api" ) -// FunctionKind identifies the type of function that can be called. -type FunctionKind byte - -const ( - // FunctionKindWasm is not a Go function: it is implemented in Wasm. - FunctionKindWasm FunctionKind = iota - // FunctionKindGoNoContext is a function implemented in Go, with a signature matching FunctionType. - FunctionKindGoNoContext - // FunctionKindGoContext is a function implemented in Go, with a signature matching FunctionType, except arg zero is - // a context.Context. - FunctionKindGoContext - // FunctionKindGoModule is a function implemented in Go, with a signature matching FunctionType, except arg - // zero is an api.Module. - FunctionKindGoModule - // FunctionKindGoContextModule is a function implemented in Go, with a signature matching FunctionType, except arg - // zero is a context.Context and arg one is an api.Module. - FunctionKindGoContextModule -) - // Below are reflection code to get the interface type used to parse functions and set values. var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem() var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() -// PopGoFuncParams pops the correct number of parameters off the stack into a parameter slice for use in CallGoFunc -// -// For example, if the host function F requires the (x1 uint32, x2 float32) parameters, and -// the stack is [..., A, B], then the function is called as F(A, B) where A and B are interpreted -// as uint32 and float32 respectively. -func PopGoFuncParams(f *FunctionInstance, popParam func() uint64) []uint64 { - // First, determine how many values we need to pop - paramCount := f.GoFunc.Type().NumIn() - switch f.Kind { - case FunctionKindGoNoContext: - case FunctionKindGoContextModule: - paramCount -= 2 - default: - paramCount-- - } +// compile-time check to ensure reflectGoModuleFunction implements +// api.GoModuleFunction. +var _ api.GoModuleFunction = (*reflectGoModuleFunction)(nil) - return PopValues(paramCount, popParam) +type reflectGoModuleFunction struct { + fn *reflect.Value + params, results []ValueType } -// PopValues pops api.ValueType values from the stack and returns them in reverse order. +// Call implements the same method as documented on api.GoModuleFunction. +func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, params []uint64) []uint64 { + return callGoFunc(ctx, mod, f.fn, params) +} + +// EqualTo is exposed for testing. +func (f *reflectGoModuleFunction) EqualTo(that interface{}) bool { + if f2, ok := that.(*reflectGoModuleFunction); !ok { + return false + } else { + // TODO compare reflect pointers + return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results) + } +} + +// compile-time check to ensure reflectGoFunction implements api.GoFunction. +var _ api.GoFunction = (*reflectGoFunction)(nil) + +type reflectGoFunction struct { + fn *reflect.Value + params, results []ValueType +} + +// EqualTo is exposed for testing. +func (f *reflectGoFunction) EqualTo(that interface{}) bool { + if f2, ok := that.(*reflectGoFunction); !ok { + return false + } else { + // TODO compare reflect pointers + return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results) + } +} + +// Call implements the same method as documented on api.GoFunction. +func (f *reflectGoFunction) Call(ctx context.Context, params []uint64) []uint64 { + return callGoFunc(ctx, nil, f.fn, params) +} + +// PopValues pops the specified number of api.ValueType parameters off the +// stack into a parameter slice for use in api.GoFunction or api.GoModuleFunction. // -// Note: the popper intentionally doesn't return bool or error because the caller's stack depth is trusted. +// For example, if the host function F requires the (x1 uint32, x2 float32) +// parameters, and the stack is [..., A, B], then the function is called as +// F(A, B) where A and B are interpreted as uint32 and float32 respectively. +// +// Note: the popper intentionally doesn't return bool or error because the +// caller's stack depth is trusted. func PopValues(count int, popper func() uint64) []uint64 { if count == 0 { return nil @@ -67,33 +84,20 @@ func PopValues(count int, popper func() uint64) []uint64 { return params } -type HostFn func(ctx context.Context, mod api.Module, params ...uint64) ([]uint64, error) - -// CallGoFunc executes the FunctionInstance.GoFunc by converting params to Go types. The results of the function call -// are converted back to api.ValueType. -// -// * callCtx is passed to the host function as a first argument. -// -// Note: ctx must use the caller's memory, which might be different from the defining module on an imported function. -func CallGoFunc(ctx context.Context, callCtx *CallContext, f *FunctionInstance, params []uint64) []uint64 { - tp := f.GoFunc.Type() +// callGoFunc executes the reflective function by converting params to Go +// types. The results of the function call are converted back to api.ValueType. +func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params []uint64) []uint64 { + tp := fn.Type() var in []reflect.Value if tp.NumIn() != 0 { in = make([]reflect.Value, tp.NumIn()) - i := 0 - switch f.Kind { - case FunctionKindGoContext: - in[0] = newContextVal(ctx) - i = 1 - case FunctionKindGoModule: - in[0] = newModuleVal(callCtx) - i = 1 - case FunctionKindGoContextModule: - in[0] = newContextVal(ctx) - in[1] = newModuleVal(callCtx) - i = 2 + i := 1 + in[0] = newContextVal(ctx) + if mod != nil { + in[1] = newModuleVal(mod) + i++ } for _, raw := range params { @@ -121,7 +125,7 @@ func CallGoFunc(ctx context.Context, callCtx *CallContext, f *FunctionInstance, if tp.NumOut() > 0 { results = make([]uint64, 0, tp.NumOut()) } - for i, ret := range f.GoFunc.Call(in) { + for i, ret := range fn.Call(in) { switch ret.Kind() { case reflect.Float32: results = append(results, uint64(math.Float32bits(float32(ret.Float())))) @@ -150,19 +154,19 @@ func newModuleVal(m api.Module) reflect.Value { return val } -// MustParseGoFuncCode parses Code from the go function or panics. +// MustParseGoReflectFuncCode parses Code from the go function or panics. // // Exposing this simplifies FunctionDefinition of host functions in built-in host // modules and tests. -func MustParseGoFuncCode(fn interface{}) *Code { - _, _, code, err := parseGoFunc(fn) +func MustParseGoReflectFuncCode(fn interface{}) *Code { + _, _, code, err := parseGoReflectFunc(fn) if err != nil { panic(err) } return code } -func parseGoFunc(fn interface{}) (params, results []ValueType, code *Code, err error) { +func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code, err error) { fnV := reflect.ValueOf(fn) p := fnV.Type() @@ -171,16 +175,15 @@ func parseGoFunc(fn interface{}) (params, results []ValueType, code *Code, err e return } - fk := kind(p) - code = &Code{IsHostFunction: true, Kind: fk, GoFunc: &fnV} + needsMod, needsErr := needsModule(p) + if needsErr != nil { + err = needsErr + return + } - pOffset := 0 - switch fk { - case FunctionKindGoNoContext: - case FunctionKindGoContextModule: - pOffset = 2 - default: - pOffset = 1 + pOffset := 1 // ctx + if needsMod { + pOffset = 2 // ctx, mod } pCount := p.NumIn() - pOffset @@ -229,23 +232,32 @@ func parseGoFunc(fn interface{}) (params, results []ValueType, code *Code, err e } return } + + code = &Code{IsHostFunction: true} + if needsMod { + code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results} + } else { + code.GoFunc = &reflectGoFunction{fn: &fnV, params: params, results: results} + } return } -func kind(p reflect.Type) FunctionKind { +func needsModule(p reflect.Type) (bool, error) { pCount := p.NumIn() - if pCount > 0 && p.In(0).Kind() == reflect.Interface { + if pCount == 0 { + return false, errors.New("invalid signature: context.Context must be param[0]") + } + if p.In(0).Kind() == reflect.Interface { p0 := p.In(0) if p0.Implements(moduleType) { - return FunctionKindGoModule + return false, errors.New("invalid signature: api.Module parameter must be preceded by context.Context") } else if p0.Implements(goContextType) { if pCount >= 2 && p.In(1).Implements(moduleType) { - return FunctionKindGoContextModule + return true, nil } - return FunctionKindGoContext } } - return FunctionKindGoNoContext + return false, nil } func getTypeOf(kind reflect.Kind) (ValueType, bool) { diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index afe7c3e3..6e3bbb9d 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -15,78 +15,42 @@ var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") func Test_parseGoFunc(t *testing.T) { var tests = []struct { - name string - inputFunc interface{} - expectedKind FunctionKind - expectedType *FunctionType + name string + input interface{} + expectNeedsModule bool + expectedType *FunctionType }{ { - name: "nullary", - inputFunc: func() {}, - expectedKind: FunctionKindGoNoContext, + name: "(ctx) -> ()", + input: func(context.Context) {}, expectedType: &FunctionType{}, }, { - name: "wasm.Module void return", - inputFunc: func(api.Module) {}, - expectedKind: FunctionKindGoModule, - expectedType: &FunctionType{}, - }, - { - name: "context.Context void return", - inputFunc: func(context.Context) {}, - expectedKind: FunctionKindGoContext, - expectedType: &FunctionType{}, - }, - { - name: "context.Context and api.Module void return", - inputFunc: func(context.Context, api.Module) {}, - expectedKind: FunctionKindGoContextModule, - expectedType: &FunctionType{}, + name: "(ctx, mod) -> ()", + input: func(context.Context, api.Module) {}, + expectNeedsModule: true, + expectedType: &FunctionType{}, }, { name: "all supported params and i32 result", - inputFunc: func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, - expectedKind: FunctionKindGoNoContext, + input: func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, }, { - name: "all supported params and all supported results", - inputFunc: func(uint32, uint64, float32, float64, uintptr) (uint32, uint64, float32, float64, uintptr) { - return 0, 0, 0, 0, 0 - }, - expectedKind: FunctionKindGoNoContext, - expectedType: &FunctionType{ - Params: []ValueType{i32, i64, f32, f64, externref}, - Results: []ValueType{i32, i64, f32, f64, externref}, - }, - }, - { - name: "all supported params and i32 result - wasm.Module", - inputFunc: func(api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, - expectedKind: FunctionKindGoModule, - expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, - }, - { - name: "all supported params and i32 result - context.Context", - inputFunc: func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, - expectedKind: FunctionKindGoContext, - expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, - }, - { - name: "all supported params and i32 result - context.Context and api.Module", - inputFunc: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, - expectedKind: FunctionKindGoContextModule, - expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, + name: "all supported params and i32 result - context.Context and api.Module", + input: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 }, + expectNeedsModule: true, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}}, }, } for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - paramTypes, resultTypes, code, err := parseGoFunc(tc.inputFunc) + paramTypes, resultTypes, code, err := parseGoReflectFunc(tc.input) require.NoError(t, err) - require.Equal(t, tc.expectedKind, code.Kind) + _, isModuleFunc := code.GoFunc.(api.GoModuleFunction) + require.Equal(t, tc.expectNeedsModule, isModuleFunc) require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes}) }) } @@ -94,11 +58,20 @@ func Test_parseGoFunc(t *testing.T) { func Test_parseGoFunc_Errors(t *testing.T) { tests := []struct { - name string - input interface{} - allowErrorResult bool - expectedErr string + name string + input interface{} + expectedErr string }{ + { + name: "no context", + input: func() {}, + expectedErr: "invalid signature: context.Context must be param[0]", + }, + { + name: "module no context", + input: func(api.Module) {}, + expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context", + }, { name: "not a func", input: struct{}{}, @@ -106,23 +79,23 @@ func Test_parseGoFunc_Errors(t *testing.T) { }, { name: "unsupported param", - input: func(uint32, string) {}, - expectedErr: "param[1] is unsupported: string", + input: func(context.Context, uint32, string) {}, + expectedErr: "param[2] is unsupported: string", }, { name: "unsupported result", - input: func() string { return "" }, + input: func(context.Context) string { return "" }, expectedErr: "result[0] is unsupported: string", }, { name: "error result", - input: func() error { return nil }, + input: func(context.Context) error { return nil }, expectedErr: "result[0] is an error, which is unsupported", }, { - name: "multiple context types", + name: "incorrect order", input: func(api.Module, context.Context) error { return nil }, - expectedErr: "param[1] is a context.Context, which may be defined only once as param[0]", + expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context", }, { name: "multiple context.Context", @@ -131,8 +104,8 @@ func Test_parseGoFunc_Errors(t *testing.T) { }, { name: "multiple wasm.Module", - input: func(api.Module, uint64, api.Module) error { return nil }, - expectedErr: "param[2] is a api.Module, which may be defined only once as param[0]", + input: func(context.Context, api.Module, uint64, api.Module) error { return nil }, + expectedErr: "param[3] is a api.Module, which may be defined only once as param[0]", }, } @@ -140,7 +113,7 @@ func Test_parseGoFunc_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - _, _, _, err := parseGoFunc(tc.input) + _, _, _, err := parseGoReflectFunc(tc.input) require.EqualError(t, err, tc.expectedErr) }) } @@ -195,164 +168,31 @@ func TestPopValues(t *testing.T) { } } -func TestPopGoFuncParams(t *testing.T) { - stackVals := []uint64{1, 2, 3, 4, 5, 6, 7} - var tests = []struct { - name string - inputFunc interface{} - expected []uint64 - }{ - { - name: "nullary", - inputFunc: func() {}, - }, - { - name: "wasm.Module", - inputFunc: func(api.Module) {}, - }, - { - name: "context.Context", - inputFunc: func(context.Context) {}, - }, - { - name: "context.Context and api.Module", - inputFunc: func(context.Context, api.Module) {}, - }, - { - name: "all supported params", - inputFunc: func(uint32, uint64, float32, float64, uintptr) {}, - expected: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "all supported params - wasm.Module", - inputFunc: func(api.Module, uint32, uint64, float32, float64, uintptr) {}, - expected: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "all supported params - context.Context", - inputFunc: func(context.Context, uint32, uint64, float32, float64, uintptr) {}, - expected: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "all supported params - context.Context and api.Module", - inputFunc: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) {}, - expected: []uint64{3, 4, 5, 6, 7}, - }, - } - - for _, tt := range tests { - tc := tt - - t.Run(tc.name, func(t *testing.T) { - _, _, code, err := parseGoFunc(tc.inputFunc) - require.NoError(t, err) - - vals := PopGoFuncParams(&FunctionInstance{Kind: code.Kind, GoFunc: code.GoFunc}, (&stack{stackVals}).pop) - require.Equal(t, tc.expected, vals) - }) - } -} - -func TestCallGoFunc(t *testing.T) { +func Test_callGoFunc(t *testing.T) { tPtr := uintptr(unsafe.Pointer(t)) callCtx := &CallContext{} - callCtxPtr := uintptr(unsafe.Pointer(callCtx)) var tests = []struct { name string - inputFunc interface{} + input interface{} inputParams, expectedResults []uint64 }{ - { - name: "nullary", - inputFunc: func() {}, - }, - { - name: "wasm.Module void return", - inputFunc: func(m api.Module) { - require.Equal(t, callCtx, m) - }, - }, { name: "context.Context void return", - inputFunc: func(ctx context.Context) { + input: func(ctx context.Context) { require.Equal(t, testCtx, ctx) }, }, { name: "context.Context and api.Module void return", - inputFunc: func(ctx context.Context, m api.Module) { + input: func(ctx context.Context, m api.Module) { require.Equal(t, testCtx, ctx) require.Equal(t, callCtx, m) }, }, - { - name: "all supported params and i32 result", - inputFunc: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 { - require.Equal(t, tPtr, v) - require.Equal(t, uint32(math.MaxUint32), w) - require.Equal(t, uint64(math.MaxUint64), x) - require.Equal(t, float32(math.MaxFloat32), y) - require.Equal(t, math.MaxFloat64, z) - return 100 - }, - inputParams: []uint64{ - api.EncodeExternref(tPtr), - math.MaxUint32, - math.MaxUint64, - api.EncodeF32(math.MaxFloat32), - api.EncodeF64(math.MaxFloat64), - }, - expectedResults: []uint64{100}, - }, - { - name: "all supported params and all supported results", - inputFunc: func(v uintptr, w uint32, x uint64, y float32, z float64) (uintptr, uint32, uint64, float32, float64) { - require.Equal(t, tPtr, v) - require.Equal(t, uint32(math.MaxUint32), w) - require.Equal(t, uint64(math.MaxUint64), x) - require.Equal(t, float32(math.MaxFloat32), y) - require.Equal(t, math.MaxFloat64, z) - return uintptr(unsafe.Pointer(callCtx)), 100, 200, 300, 400 - }, - inputParams: []uint64{ - api.EncodeExternref(tPtr), - math.MaxUint32, - math.MaxUint64, - api.EncodeF32(math.MaxFloat32), - api.EncodeF64(math.MaxFloat64), - }, - expectedResults: []uint64{ - api.EncodeExternref(callCtxPtr), - api.EncodeI32(100), - 200, - api.EncodeF32(300), - api.EncodeF64(400), - }, - }, - { - name: "all supported params and i32 result - wasm.Module", - inputFunc: func(m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { - require.Equal(t, callCtx, m) - require.Equal(t, tPtr, v) - require.Equal(t, uint32(math.MaxUint32), w) - require.Equal(t, uint64(math.MaxUint64), x) - require.Equal(t, float32(math.MaxFloat32), y) - require.Equal(t, math.MaxFloat64, z) - return 100 - }, - inputParams: []uint64{ - api.EncodeExternref(tPtr), - math.MaxUint32, - math.MaxUint64, - api.EncodeF32(math.MaxFloat32), - api.EncodeF64(math.MaxFloat64), - }, - expectedResults: []uint64{100}, - }, { name: "all supported params and i32 result - context.Context", - inputFunc: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { + input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { require.Equal(t, testCtx, ctx) require.Equal(t, tPtr, v) require.Equal(t, uint32(math.MaxUint32), w) @@ -372,7 +212,7 @@ func TestCallGoFunc(t *testing.T) { }, { name: "all supported params and i32 result - context.Context and api.Module", - inputFunc: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { + input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 { require.Equal(t, testCtx, ctx) require.Equal(t, callCtx, m) require.Equal(t, tPtr, v) @@ -396,20 +236,18 @@ func TestCallGoFunc(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - paramTypes, resultTypes, code, err := parseGoFunc(tc.inputFunc) + _, _, code, err := parseGoReflectFunc(tc.input) require.NoError(t, err) - results := CallGoFunc( - testCtx, - callCtx, - &FunctionInstance{ - IsHostFunction: code.IsHostFunction, - Kind: code.Kind, - Type: &FunctionType{Params: paramTypes, Results: resultTypes}, - GoFunc: code.GoFunc, - }, - tc.inputParams, - ) + var results []uint64 + switch code.GoFunc.(type) { + case api.GoFunction: + results = code.GoFunc.(api.GoFunction).Call(testCtx, tc.inputParams) + case api.GoModuleFunction: + results = code.GoFunc.(api.GoModuleFunction).Call(testCtx, callCtx, tc.inputParams) + default: + t.Fatal("unexpected type.") + } require.Equal(t, tc.expectedResults, results) }) } diff --git a/internal/wasm/host.go b/internal/wasm/host.go index 1995f781..3faa6320 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -9,6 +9,10 @@ import ( "github.com/tetratelabs/wazero/internal/wasmdebug" ) +type ProxyFuncExporter interface { + ExportProxyFunc(*ProxyFunc) +} + // ProxyFunc is a function defined both in wasm and go. This is used to // optimize the Go signature or obviate calls based on what can be done // mechanically in wasm. @@ -27,6 +31,10 @@ func (p *ProxyFunc) Name() string { return p.Proxied.Name } +type HostFuncExporter interface { + ExportHostFunc(*HostFunc) +} + // HostFunc is a function with an inlined type, used for NewHostModule. // Any corresponding FunctionType will be reused or added to the Module. type HostFunc struct { @@ -49,18 +57,9 @@ type HostFunc struct { Code *Code } -// NewGoFunc returns a HostFunc for the given parameters or panics. -func NewGoFunc(exportName string, name string, paramNames []string, fn interface{}) *HostFunc { - return (&HostFunc{ - ExportNames: []string{exportName}, - Name: name, - ParamNames: paramNames, - }).MustGoFunc(fn) -} - -// MustGoFunc calls WithGoFunc or panics on error. -func (f *HostFunc) MustGoFunc(fn interface{}) *HostFunc { - if ret, err := f.WithGoFunc(fn); err != nil { +// MustGoReflectFunc calls WithGoReflectFunc or panics on error. +func (f *HostFunc) MustGoReflectFunc(fn interface{}) *HostFunc { + if ret, err := f.WithGoReflectFunc(fn); err != nil { panic(err) } else { return ret @@ -68,10 +67,24 @@ func (f *HostFunc) MustGoFunc(fn interface{}) *HostFunc { } // WithGoFunc returns a copy of the function, replacing its Code.GoFunc. -func (f *HostFunc) WithGoFunc(fn interface{}) (*HostFunc, error) { +func (f *HostFunc) WithGoFunc(fn api.GoFunc) *HostFunc { + ret := *f + ret.Code = &Code{IsHostFunction: true, GoFunc: fn} + return &ret +} + +// WithGoModuleFunc returns a copy of the function, replacing its Code.GoFunc. +func (f *HostFunc) WithGoModuleFunc(fn api.GoModuleFunc) *HostFunc { + ret := *f + ret.Code = &Code{IsHostFunction: true, GoFunc: fn} + return &ret +} + +// WithGoReflectFunc returns a copy of the function, replacing its Code.GoFunc. +func (f *HostFunc) WithGoReflectFunc(fn interface{}) (*HostFunc, error) { ret := *f var err error - ret.ParamTypes, ret.ResultTypes, ret.Code, err = parseGoFunc(fn) + ret.ParamTypes, ret.ResultTypes, ret.Code, err = parseGoReflectFunc(fn) return &ret, err } @@ -167,8 +180,8 @@ func addFuncs( nameToFunc[proxy.Name] = proxy funcNames = append(funcNames, proxy.Name) - } else { - params, results, code, ftErr := parseGoFunc(v) + } else { // reflection + params, results, code, ftErr := parseGoReflectFunc(v) if ftErr != nil { return fmt.Errorf("func[%s.%s] %w", moduleName, k, ftErr) } diff --git a/internal/wasm/host_test.go b/internal/wasm/host_test.go index 0a892913..dfdfb237 100644 --- a/internal/wasm/host_test.go +++ b/internal/wasm/host_test.go @@ -1,30 +1,26 @@ package wasm import ( + "context" "testing" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/testing/require" ) -// wasiAPI simulates the real WASI api -type wasiAPI struct { -} - -func (a *wasiAPI) ArgsSizesGet(ctx api.Module, resultArgc, resultArgvBufSize uint32) uint32 { +func argsSizesGet(ctx context.Context, mod api.Module, resultArgc, resultArgvBufSize uint32) uint32 { return 0 } -func (a *wasiAPI) FdWrite(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) uint32 { +func fdWrite(ctx context.Context, mod api.Module, fd, iovs, iovsCount, resultSize uint32) uint32 { return 0 } -func swap(x, y uint32) (uint32, uint32) { +func swap(ctx context.Context, x, y uint32) (uint32, uint32) { return y, x } func TestNewHostModule(t *testing.T) { - a := wasiAPI{} functionArgsSizesGet := "args_sizes_get" functionFdWrite := "fd_write" functionSwap := "swap" @@ -47,8 +43,8 @@ func TestNewHostModule(t *testing.T) { name: "funcs", moduleName: "wasi_snapshot_preview1", nameToGoFunc: map[string]interface{}{ - functionArgsSizesGet: a.ArgsSizesGet, - functionFdWrite: a.FdWrite, + functionArgsSizesGet: argsSizesGet, + functionFdWrite: fdWrite, }, expected: &Module{ TypeSection: []*FunctionType{ @@ -56,7 +52,7 @@ func TestNewHostModule(t *testing.T) { {Params: []ValueType{i32, i32, i32, i32}, Results: []ValueType{i32}}, }, FunctionSection: []Index{0, 1}, - CodeSection: []*Code{MustParseGoFuncCode(a.ArgsSizesGet), MustParseGoFuncCode(a.FdWrite)}, + CodeSection: []*Code{MustParseGoReflectFuncCode(argsSizesGet), MustParseGoReflectFuncCode(fdWrite)}, ExportSection: []*Export{ {Name: "args_sizes_get", Type: ExternTypeFunc, Index: 0}, {Name: "fd_write", Type: ExternTypeFunc, Index: 1}, @@ -79,7 +75,7 @@ func TestNewHostModule(t *testing.T) { expected: &Module{ TypeSection: []*FunctionType{{Params: []ValueType{i32, i32}, Results: []ValueType{i32, i32}}}, FunctionSection: []Index{0}, - CodeSection: []*Code{MustParseGoFuncCode(swap)}, + CodeSection: []*Code{MustParseGoReflectFuncCode(swap)}, ExportSection: []*Export{{Name: "swap", Type: ExternTypeFunc, Index: 0}}, NameSection: &NameSection{ModuleName: "swapper", FunctionNames: NameMap{{Index: 0, Name: "swap"}}}, }, @@ -118,8 +114,7 @@ func requireHostModuleEquals(t *testing.T, expected, actual *Module) { for i, c := range expected.CodeSection { actualCode := actual.CodeSection[i] require.True(t, actualCode.IsHostFunction) - require.Equal(t, c.Kind, actualCode.Kind) - require.Equal(t, c.GoFunc.Type(), actualCode.GoFunc.Type()) + require.Equal(t, c.GoFunc, actualCode.GoFunc) // Not wasm require.Nil(t, actualCode.Body) @@ -140,7 +135,7 @@ func TestNewHostModule_Errors(t *testing.T) { }, { name: "function has multiple results", - nameToGoFunc: map[string]interface{}{"fn": func() (uint32, uint32) { return 0, 0 }}, + nameToGoFunc: map[string]interface{}{"fn": func(context.Context) (uint32, uint32) { return 0, 0 }}, expectedErr: "func[.fn] multiple result types invalid as feature \"multi-value\" is disabled", }, } diff --git a/internal/wasm/module.go b/internal/wasm/module.go index 33f0c011..60f30b12 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "errors" "fmt" - "reflect" "sort" "strings" @@ -595,7 +594,6 @@ func (m *ModuleInstance) BuildFunctions(mod *Module, listeners []experimental.Fu code := mod.CodeSection[i] fns = append(fns, &FunctionInstance{ IsHostFunction: code.IsHostFunction, - Kind: code.Kind, LocalTypes: code.LocalTypes, Body: code.Body, GoFunc: code.GoFunc, @@ -816,9 +814,6 @@ type Code struct { // See https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0 IsHostFunction bool - // Kind describes how this function should be called. - Kind FunctionKind - // LocalTypes are any function-scoped variables in insertion order. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-local LocalTypes []ValueType @@ -827,13 +822,13 @@ type Code struct { // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-expr Body []byte - // GoFunc is a host function defined in Go. - // - // When present, LocalTypes and Body must be nil. + // GoFunc is non-nil when IsHostFunction and defined in go, either + // api.GoFunction or api.GoModuleFunction. When present, LocalTypes and Body must + // be nil. // // Note: This has no serialization format, so is not encodable. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 - GoFunc *reflect.Value + GoFunc interface{} } type DataSegment struct { diff --git a/internal/wasm/store.go b/internal/wasm/store.go index e3a510cf..05e52220 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "errors" "fmt" - "reflect" "sync" "github.com/tetratelabs/wazero/api" @@ -112,9 +111,6 @@ type ( // wasm.Code. IsHostFunction bool - // Kind describes how this function should be called. - Kind FunctionKind - // Type is the signature of this function. Type *FunctionType @@ -124,10 +120,12 @@ type ( // Body is the function body in WebAssembly Binary Format, set when Kind == FunctionKindWasm Body []byte - // GoFunc holds the runtime representation of host functions. - // This is nil when Kind == FunctionKindWasm. Otherwise, all the above fields are ignored as they are - // specific to Wasm functions. - GoFunc *reflect.Value + // GoFunc is non-nil when IsHostFunction and defined in go, either + // api.GoFunction or api.GoModuleFunction. + // + // Note: This has no serialization format, so is not encodable. + // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 + GoFunc interface{} // Fields above here are settable prior to instantiation. Below are set by the Store during instantiation. @@ -408,7 +406,7 @@ func (s *Store) instantiate( module.funcDesc(SectionIDFunction, funcIdx), err) } - _, err = ce.Call(ctx, callCtx) + _, err = ce.Call(ctx, callCtx, nil) if exitErr, ok := err.(*sys.ExitError); ok { // Don't wrap an exit error! return nil, exitErr } else if err != nil { diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 6bde3c6f..937a8297 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -91,7 +91,7 @@ func TestModuleInstance_Memory(t *testing.T) { func TestStore_Instantiate(t *testing.T) { s, ns := newStore() - m, err := NewHostModule("", map[string]interface{}{"fn": func(api.Module) {}}, nil, api.CoreFeaturesV1) + m, err := NewHostModule("", map[string]interface{}{"fn": func(context.Context) {}}, nil, api.CoreFeaturesV1) require.NoError(t, err) sysCtx := sys.DefaultContext(nil) @@ -169,7 +169,7 @@ func TestStore_CloseWithExitCode(t *testing.T) { func TestStore_hammer(t *testing.T) { const importedModuleName = "imported" - m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func(api.Module) {}}, nil, api.CoreFeaturesV1) + m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func(context.Context) {}}, nil, api.CoreFeaturesV1) require.NoError(t, err) s, ns := newStore() @@ -223,7 +223,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { const importedModuleName = "imported" const importingModuleName = "test" - m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func(api.Module) {}}, nil, api.CoreFeaturesV1) + m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func(context.Context) {}}, nil, api.CoreFeaturesV1) require.NoError(t, err) t.Run("Fails if module name already in use", func(t *testing.T) { @@ -314,7 +314,7 @@ func TestStore_Instantiate_Errors(t *testing.T) { } func TestCallContext_ExportedFunction(t *testing.T) { - host, err := NewHostModule("host", map[string]interface{}{"host_fn": func(api.Module) {}}, nil, api.CoreFeaturesV1) + host, err := NewHostModule("host", map[string]interface{}{"host_fn": func(context.Context) {}}, nil, api.CoreFeaturesV1) require.NoError(t, err) s, ns := newStore() @@ -400,7 +400,7 @@ func (e *mockModuleEngine) Close(_ context.Context) { } // Call implements the same method as documented on wasm.ModuleEngine. -func (ce *mockCallEngine) Call(ctx context.Context, callCtx *CallContext, _ ...uint64) (results []uint64, err error) { +func (ce *mockCallEngine) Call(ctx context.Context, callCtx *CallContext, _ []uint64) (results []uint64, err error) { if ce.callFailIndex >= 0 && ce.f.Definition.Index() == Index(ce.callFailIndex) { err = errors.New("call failed") return diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 41698fa0..1363239b 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -7,7 +7,6 @@ import ( "fmt" "math" "os" - "reflect" "strings" "github.com/tetratelabs/wazero/api" @@ -209,7 +208,7 @@ type CompilationResult struct { // GoFunc is the data returned by the same field documented on wasm.Code. // In this case, IsHostFunction is true and other fields can be ignored. - GoFunc *reflect.Value + GoFunc interface{} // Operations holds wazeroir operations compiled from Wasm instructions in a Wasm function. Operations []Operation @@ -270,12 +269,14 @@ func CompileFunctions(_ context.Context, enabledFeatures api.CoreFeatures, callF sig := module.TypeSection[typeID] code := module.CodeSection[funcIndex] if code.GoFunc != nil { + // Assume the function might use memory if it has a parameter for the api.Module + _, usesMemory := code.GoFunc.(api.GoModuleFunction) + ret = append(ret, &CompilationResult{ IsHostFunction: true, - // Assume the function might use memory if it has a parameter for the api.Module - UsesMemory: code.Kind == wasm.FunctionKindGoModule || code.Kind == wasm.FunctionKindGoContextModule, - GoFunc: code.GoFunc, - Signature: sig, + UsesMemory: usesMemory, + GoFunc: code.GoFunc, + Signature: sig, }) if len(sig.Params) > 0 && sig.ParamNumInUint64 == 0 { diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index cb573cb4..fe152fa8 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -82,25 +82,16 @@ func TestCompile(t *testing.T) { module: &wasm.Module{ TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func() {})}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(func(context.Context) {})}, }, expected: &CompilationResult{IsHostFunction: true}, }, - { - name: "host go api.Module uses memory", - module: &wasm.Module{ - TypeSection: []*wasm.FunctionType{v_v}, - FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func(api.Module) {})}, - }, - expected: &CompilationResult{IsHostFunction: true, UsesMemory: true}, - }, { name: "host go context.Context api.Module uses memory", module: &wasm.Module{ TypeSection: []*wasm.FunctionType{v_v}, FunctionSection: []wasm.Index{0}, - CodeSection: []*wasm.Code{wasm.MustParseGoFuncCode(func(context.Context, api.Module) {})}, + CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(func(context.Context, api.Module) {})}, }, expected: &CompilationResult{IsHostFunction: true, UsesMemory: true}, }, diff --git a/runtime.go b/runtime.go index d09a8391..6e2a29a0 100644 --- a/runtime.go +++ b/runtime.go @@ -27,10 +27,12 @@ type Runtime interface { // Below defines and instantiates a module named "env" with one function: // // ctx := context.Background() - // hello := func() { + // hello := func(context.Context) { // fmt.Fprintln(stdout, "hello!") // } - // _, err := r.NewHostModuleBuilder("env").ExportFunction("hello", hello).Instantiate(ctx, r) + // _, err := r.NewHostModuleBuilder("env"). + // NewFunctionBuilder().WithFunc(hello).Export("hello"). + // Instantiate(ctx, r) NewHostModuleBuilder(moduleName string) HostModuleBuilder // CompileModule decodes the WebAssembly binary (%.wasm) or errs if invalid. diff --git a/runtime_test.go b/runtime_test.go index 3357db99..b34a9316 100644 --- a/runtime_test.go +++ b/runtime_test.go @@ -333,7 +333,7 @@ func TestRuntime_InstantiateModule_UsesContext(t *testing.T) { } _, err := r.NewHostModuleBuilder("env"). - ExportFunction("start", start). + NewFunctionBuilder().WithFunc(start).Export("start"). Instantiate(testCtx, r) require.NoError(t, err) @@ -410,7 +410,7 @@ func TestRuntime_InstantiateModuleFromBinary_ErrorOnStart(t *testing.T) { } host, err := r.NewHostModuleBuilder(""). - ExportFunction("start", start). + NewFunctionBuilder().WithFunc(start).Export("start"). Instantiate(testCtx, r) require.NoError(t, err) @@ -461,7 +461,9 @@ func TestRuntime_InstantiateModule_ExitError(t *testing.T) { require.NoError(t, m.CloseWithExitCode(ctx, 2)) } - _, err := r.NewHostModuleBuilder("env").ExportFunction("exit", start).Instantiate(testCtx, r) + _, err := r.NewHostModuleBuilder("env"). + NewFunctionBuilder().WithFunc(start).Export("exit"). + Instantiate(testCtx, r) require.NoError(t, err) one := uint32(1) @@ -580,8 +582,8 @@ func TestHostFunctionWithCustomContext(t *testing.T) { } _, err := r.NewHostModuleBuilder("env"). - ExportFunction("host", start). - ExportFunction("host2", callFunc). + NewFunctionBuilder().WithFunc(start).Export("host"). + NewFunctionBuilder().WithFunc(callFunc).Export("host2"). Instantiate(hostCtx, r) require.NoError(t, err)