diff --git a/README.md b/README.md index 1074200b..5a9e46b9 100644 --- a/README.md +++ b/README.md @@ -17,16 +17,19 @@ For the impatient, here's how invoking a factorial function looks in wazero: ```golang func main() { + // Choose the context to use for function calls. + ctx := context.Background() + // Read a WebAssembly binary containing an exported "fac" function. // * Ex. (func (export "fac") (param i64) (result i64) ... source, _ := os.ReadFile("./tests/bench/testdata/fac.wasm") // Instantiate the module and return its exported functions - module, _ := wazero.NewRuntime().InstantiateModuleFromCode(source) + module, _ := wazero.NewRuntime().InstantiateModuleFromCode(ctx, source) defer module.Close() // Discover 7! is 5040 - fmt.Println(module.ExportedFunction("fac").Call(nil, 7)) + fmt.Println(module.ExportedFunction("fac").Call(ctx, 7)) } ``` @@ -56,7 +59,7 @@ env, err := r.NewModuleBuilder("env"). ExportFunction("log_i32", func(v uint32) { fmt.Println("log_i32 >>", v) }). - Instantiate() + Instantiate(ctx) if err != nil { log.Fatal(err) } @@ -73,11 +76,11 @@ bundles an implementation. That way, you don't have to write these functions. For example, here's how you can allow WebAssembly modules to read "/work/home/a.txt" as "/a.txt" or "./a.txt": ```go -wm, err := wasi.InstantiateSnapshotPreview1(r) +wm, err := wasi.InstantiateSnapshotPreview1(ctx, r) defer wm.Close() config := wazero.ModuleConfig().WithFS(os.DirFS("/work/home")) -module, err := r.InstantiateModule(binary, config) +module, err := r.InstantiateModule(ctx, binary, config) defer module.Close() ... ``` diff --git a/api/wasm.go b/api/wasm.go index 6e58d1ba..fb79160f 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -84,14 +84,6 @@ type Module interface { // with the exitCode. CloseWithExitCode(exitCode uint32) error - // Context returns any propagated context from the Runtime or a prior function call. - // - // The returned context is always non-nil; it defaults to context.Background. - Context() context.Context - - // WithContext allows callers to override the propagated context, for example, to add values to it. - WithContext(ctx context.Context) Module - // Memory returns a memory defined in this module or nil if there are none wasn't. Memory() Memory @@ -129,23 +121,11 @@ type Function interface { // encoded according to ResultTypes. An error is returned for any failure looking up or invoking the function // including signature mismatch. // - // If `m` is nil, it defaults to the module the function was defined in. - // - // To override context propagation, use Module.WithContext - // fn = m.ExportedFunction("fib") - // results, err := fn(m.WithContext(ctx), 5) - // --snip-- - // - // To ensure context propagation in a host function body, pass the `ctx` parameter: - // hostFunction := func(m api.Module, offset, byteCount uint32) uint32 { - // fn = m.ExportedFunction("__read") - // results, err := fn(m, offset, byteCount) - // --snip-- - // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: If Module.Close or Module.CloseWithExitCode were invoked during this call, the error returned may be a // sys.ExitError. Interpreting this is specific to the module. For example, some "main" functions always call a // function that exits. - Call(m Module, params ...uint64) ([]uint64, error) + Call(ctx context.Context, params ...uint64) ([]uint64, error) } // Global is a WebAssembly 1.0 (20191205) global exported from an instantiated module (wazero.Runtime InstantiateModule). diff --git a/builder.go b/builder.go index c2213ae1..37441f36 100644 --- a/builder.go +++ b/builder.go @@ -1,6 +1,7 @@ package wazero import ( + "context" "fmt" "github.com/tetratelabs/wazero/api" @@ -13,19 +14,20 @@ import ( // // Ex. Below defines and instantiates a module named "env" with one function: // +// ctx := context.Background() // hello := func() { // fmt.Fprintln(stdout, "hello!") // } -// env, _ := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate() +// env, _ := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate(ctx) // // If the same module may be instantiated multiple times, it is more efficient to separate steps. Ex. // -// env, _ := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Build() +// env, _ := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Build(ctx) // -// env1, _ := r.InstantiateModuleWithConfig(env, NewModuleConfig().WithName("env.1")) +// env1, _ := r.InstantiateModuleWithConfig(ctx, env, NewModuleConfig().WithName("env.1")) // defer env1.Close() // -// env2, _ := r.InstantiateModuleWithConfig(env, NewModuleConfig().WithName("env.2")) +// env2, _ := r.InstantiateModuleWithConfig(ctx, env, NewModuleConfig().WithName("env.2")) // defer env2.Close() // // Note: Builder methods do not return errors, to allow chaining. Any validation errors are deferred until Build. @@ -56,11 +58,8 @@ type ModuleBuilder interface { // return x + y + m.Value(extraKey).(uint32) // } // - // The most sophisticated context is api.Module, which allows access to the Go context, but also - // allows writing to memory. This is important because there are only numeric types in Wasm. The only way to share other - // data is via writing memory and sharing offsets. - // - // Ex. This reads the parameters from! + // Ex. This uses an api.Module to reads the parameters from memory. This is important because there are only numeric + // types in Wasm. The only way to share other data is via writing memory and sharing offsets. // // addInts := func(m api.Module, offset uint32) uint32 { // x, _ := m.Memory().ReadUint32Le(offset) @@ -68,6 +67,14 @@ type ModuleBuilder interface { // return x + y // } // + // If both parameters exist, they must be in order at positions zero and one. + // + // Ex. This uses propagates context properly when calling other functions exported in the api.Module: + // callRead := func(ctx context.Context, m api.Module, offset, byteCount uint32) uint32 { + // fn = m.ExportedFunction("__read") + // results, err := fn(ctx, offset, byteCount) + // --snip-- + // // Note: If a function is already exported with the same name, this overwrites it. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 ExportFunction(name string, goFunc interface{}) ModuleBuilder @@ -144,12 +151,12 @@ type ModuleBuilder interface { ExportGlobalF64(name string, v float64) ModuleBuilder // Build returns a module to instantiate, or returns an error if any of the configuration is invalid. - Build() (*CompiledCode, error) + Build(ctx context.Context) (*CompiledCode, error) // Instantiate is a convenience that calls Build, then Runtime.InstantiateModule // // Note: Fields in the builder are copied during instantiation: Later changes do not affect the instantiated result. - Instantiate() (api.Module, error) + Instantiate(ctx context.Context) (api.Module, error) } // moduleBuilder implements ModuleBuilder @@ -237,7 +244,7 @@ func (b *moduleBuilder) ExportGlobalF64(name string, v float64) ModuleBuilder { } // Build implements ModuleBuilder.Build -func (b *moduleBuilder) Build() (*CompiledCode, error) { +func (b *moduleBuilder) Build(ctx context.Context) (*CompiledCode, error) { // Verify the maximum limit here, so we don't have to pass it to wasm.NewHostModule maxLimit := b.r.memoryMaxPages for name, mem := range b.nameToMemory { @@ -252,7 +259,7 @@ func (b *moduleBuilder) Build() (*CompiledCode, error) { return nil, err } - if err = b.r.store.Engine.CompileModule(module); err != nil { + if err = b.r.store.Engine.CompileModule(ctx, module); err != nil { return nil, err } @@ -260,15 +267,15 @@ func (b *moduleBuilder) Build() (*CompiledCode, error) { } // Instantiate implements ModuleBuilder.Instantiate -func (b *moduleBuilder) Instantiate() (api.Module, error) { - if module, err := b.Build(); err != nil { +func (b *moduleBuilder) Instantiate(ctx context.Context) (api.Module, error) { + if module, err := b.Build(ctx); err != nil { return nil, err } else { - if err = b.r.store.Engine.CompileModule(module.module); err != nil { + if err = b.r.store.Engine.CompileModule(ctx, module.module); err != nil { return nil, err } // *wasm.ModuleInstance cannot be tracked, so we release the cache inside of this function. defer module.Close() - return b.r.InstantiateModuleWithConfig(module, NewModuleConfig().WithName(b.moduleName)) + return b.r.InstantiateModuleWithConfig(ctx, module, NewModuleConfig().WithName(b.moduleName)) } } diff --git a/builder_test.go b/builder_test.go index 4a18bd7d..877883f5 100644 --- a/builder_test.go +++ b/builder_test.go @@ -344,7 +344,7 @@ func TestNewModuleBuilder_Build(t *testing.T) { t.Run(tc.name, func(t *testing.T) { b := tc.input(NewRuntime()).(*moduleBuilder) - m, err := b.Build() + m, err := b.Build(testCtx) require.NoError(t, err) requireHostModuleEquals(t, tc.expected, m.module) @@ -352,7 +352,7 @@ func TestNewModuleBuilder_Build(t *testing.T) { require.Equal(t, b.r.store.Engine, m.compiledEngine) // Built module must be instantiable by Engine. - _, err = b.r.InstantiateModule(m) + _, err = b.r.InstantiateModule(testCtx, m) require.NoError(t, err) }) } @@ -385,7 +385,7 @@ func TestNewModuleBuilder_Build_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - _, e := tc.input(NewRuntime()).Build() + _, e := tc.input(NewRuntime()).Build(testCtx) require.EqualError(t, e, tc.expectedErr) }) } @@ -394,7 +394,7 @@ func TestNewModuleBuilder_Build_Errors(t *testing.T) { // TestNewModuleBuilder_Instantiate ensures Runtime.InstantiateModule is called on success. func TestNewModuleBuilder_Instantiate(t *testing.T) { r := NewRuntime() - m, err := r.NewModuleBuilder("env").Instantiate() + m, err := r.NewModuleBuilder("env").Instantiate(testCtx) require.NoError(t, err) // If this was instantiated, it would be added to the store under the same name @@ -404,10 +404,10 @@ func TestNewModuleBuilder_Instantiate(t *testing.T) { // TestNewModuleBuilder_Instantiate_Errors ensures errors propagate from Runtime.InstantiateModule func TestNewModuleBuilder_Instantiate_Errors(t *testing.T) { r := NewRuntime() - _, err := r.NewModuleBuilder("env").Instantiate() + _, err := r.NewModuleBuilder("env").Instantiate(testCtx) require.NoError(t, err) - _, err = r.NewModuleBuilder("env").Instantiate() + _, err = r.NewModuleBuilder("env").Instantiate(testCtx) require.EqualError(t, err, "module env has already been instantiated") } diff --git a/config.go b/config.go index 7d73a47a..f6986389 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,6 @@ package wazero import ( - "context" "errors" "fmt" "io" @@ -18,14 +17,12 @@ import ( type RuntimeConfig struct { enabledFeatures wasm.Features newEngine func(wasm.Features) wasm.Engine - ctx context.Context memoryMaxPages uint32 } // engineLessConfig helps avoid copy/pasting the wrong defaults. var engineLessConfig = &RuntimeConfig{ enabledFeatures: wasm.Features20191205, - ctx: context.Background(), memoryMaxPages: wasm.MemoryMaxPages, } @@ -34,7 +31,6 @@ func (c *RuntimeConfig) clone() *RuntimeConfig { return &RuntimeConfig{ enabledFeatures: c.enabledFeatures, newEngine: c.newEngine, - ctx: c.ctx, memoryMaxPages: c.memoryMaxPages, } } @@ -56,23 +52,6 @@ func NewRuntimeConfigInterpreter() *RuntimeConfig { return ret } -// WithContext sets the default context used to initialize the module. Defaults to context.Background if nil. -// -// Notes: -// * If the Module defines a start function, this is used to invoke it. -// * This is the outer-most ancestor of api.Module Context() during api.Function invocations. -// * This is the default context of api.Function when callers pass nil. -// -// See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#start-function%E2%91%A0 -func (c *RuntimeConfig) WithContext(ctx context.Context) *RuntimeConfig { - if ctx == nil { - ctx = context.Background() - } - ret := c.clone() - ret.ctx = ctx - return ret -} - // WithMemoryMaxPages reduces the maximum number of pages a module can define from 65536 pages (4GiB) to a lower value. // // Notes: diff --git a/config_test.go b/config_test.go index 517f6ee7..2111dcc6 100644 --- a/config_test.go +++ b/config_test.go @@ -1,7 +1,6 @@ package wazero import ( - "context" "io" "math" "testing" @@ -17,24 +16,6 @@ func TestRuntimeConfig(t *testing.T) { with func(*RuntimeConfig) *RuntimeConfig expected *RuntimeConfig }{ - { - name: "WithContext", - with: func(c *RuntimeConfig) *RuntimeConfig { - return c.WithContext(context.TODO()) - }, - expected: &RuntimeConfig{ - ctx: context.TODO(), - }, - }, - { - name: "WithContext - nil", - with: func(c *RuntimeConfig) *RuntimeConfig { - return c.WithContext(nil) //nolint - }, - expected: &RuntimeConfig{ - ctx: context.Background(), - }, - }, { name: "WithMemoryMaxPages", with: func(c *RuntimeConfig) *RuntimeConfig { diff --git a/example_test.go b/example_test.go index 730ced0a..4d6431c1 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package wazero import ( + "context" _ "embed" "fmt" "log" @@ -10,11 +11,14 @@ import ( // // See https://github.com/tetratelabs/wazero/tree/main/examples for more examples. func Example() { + // Choose the context to use for function calls. + ctx := context.Background() + // Create a new WebAssembly Runtime. r := NewRuntime() // Add a module to the runtime named "wasm/math" which exports one function "add", implemented in WebAssembly. - mod, err := r.InstantiateModuleFromCode([]byte(`(module $wasm/math + mod, err := r.InstantiateModuleFromCode(ctx, []byte(`(module $wasm/math (func $add (param i32 i32) (result i32) local.get 0 local.get 1 @@ -31,7 +35,7 @@ func Example() { add := mod.ExportedFunction("add") x, y := uint64(1), uint64(2) - results, err := add.Call(nil, x, y) + results, err := add.Call(ctx, x, y) if err != nil { log.Fatal(err) } diff --git a/examples/basic/add.go b/examples/basic/add.go index 0a229418..44d5a14f 100644 --- a/examples/basic/add.go +++ b/examples/basic/add.go @@ -1,6 +1,7 @@ package add import ( + "context" _ "embed" "fmt" "log" @@ -13,11 +14,14 @@ import ( // main implements a basic function in both Go and WebAssembly. func main() { + // Choose the context to use for function calls. + ctx := context.Background() + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Add a module to the runtime named "wasm/math" which exports one function "add", implemented in WebAssembly. - wasm, err := r.InstantiateModuleFromCode([]byte(`(module $wasm/math + wasm, err := r.InstantiateModuleFromCode(ctx, []byte(`(module $wasm/math (func $add (param i32 i32) (result i32) local.get 0 local.get 1 @@ -34,7 +38,7 @@ func main() { host, err := r.NewModuleBuilder("host/math"). ExportFunction("add", func(v1, v2 uint32) uint32 { return v1 + v2 - }).Instantiate() + }).Instantiate(ctx) if err != nil { log.Fatal(err) } @@ -46,7 +50,7 @@ func main() { // Call the same function in both modules and print the results to the console. for _, mod := range []api.Module{wasm, host} { add := mod.ExportedFunction("add") - results, err := add.Call(nil, x, y) + results, err := add.Call(ctx, x, y) if err != nil { log.Fatal(err) } diff --git a/examples/import-go/age-calculator.go b/examples/import-go/age-calculator.go index 0d080cc1..c55b4a63 100644 --- a/examples/import-go/age-calculator.go +++ b/examples/import-go/age-calculator.go @@ -1,6 +1,7 @@ package age_calculator import ( + "context" _ "embed" "fmt" "log" @@ -16,6 +17,10 @@ import ( // // See README.md for a full description. func main() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Instantiate a module named "env" that exports functions to get the @@ -35,7 +40,7 @@ func main() { } return uint32(time.Now().Year()) }). - Instantiate() + Instantiate(ctx) if err != nil { log.Fatal(err) } @@ -46,7 +51,7 @@ func main() { // // Note: The import syntax in both Text and Binary format is the same // regardless of if the function was defined in Go or WebAssembly. - ageCalculator, err := r.InstantiateModuleFromCode([]byte(` + ageCalculator, err := r.InstantiateModuleFromCode(ctx, []byte(` ;; Define the optional module name. '$' prefixing is a part of the text format. (module $age-calculator @@ -91,14 +96,14 @@ func main() { } // First, try calling the "get_age" function and printing to the console externally. - results, err := ageCalculator.ExportedFunction("get_age").Call(nil, birthYear) + results, err := ageCalculator.ExportedFunction("get_age").Call(ctx, birthYear) if err != nil { log.Fatal(err) } fmt.Println("println >>", results[0]) // First, try calling the "log_age" function and printing to the console externally. - _, err = ageCalculator.ExportedFunction("log_age").Call(nil, birthYear) + _, err = ageCalculator.ExportedFunction("log_age").Call(ctx, birthYear) if err != nil { log.Fatal(err) } diff --git a/examples/multiple-results/multiple-results.go b/examples/multiple-results/multiple-results.go index e5f041dd..aa8b375a 100644 --- a/examples/multiple-results/multiple-results.go +++ b/examples/multiple-results/multiple-results.go @@ -1,6 +1,7 @@ package multiple_results import ( + "context" _ "embed" "fmt" "log" @@ -24,18 +25,21 @@ import ( // * multiValueHostFunctions // See https://github.com/WebAssembly/spec/blob/main/proposals/multi-value/Overview.md func main() { - // Create a portable WebAssembly Runtime. + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. runtime := wazero.NewRuntime() // Add a module that uses offset parameters for multiple results, with functions defined in WebAssembly. - wasm, err := resultOffsetWasmFunctions(runtime) + wasm, err := resultOffsetWasmFunctions(ctx, runtime) if err != nil { log.Fatal(err) } defer wasm.Close() // Add a module that uses offset parameters for multiple results, with functions defined in Go. - host, err := resultOffsetHostFunctions(runtime) + host, err := resultOffsetHostFunctions(ctx, runtime) if err != nil { log.Fatal(err) } @@ -48,14 +52,14 @@ func main() { ) // Add a module that uses multiple results values, with functions defined in WebAssembly. - wasmWithMultiValue, err := multiValueWasmFunctions(runtimeWithMultiValue) + wasmWithMultiValue, err := multiValueWasmFunctions(ctx, runtimeWithMultiValue) if err != nil { log.Fatal(err) } defer wasmWithMultiValue.Close() // Add a module that uses multiple results values, with functions defined in Go. - hostWithMultiValue, err := multiValueHostFunctions(runtimeWithMultiValue) + hostWithMultiValue, err := multiValueHostFunctions(ctx, runtimeWithMultiValue) if err != nil { log.Fatal(err) } @@ -64,7 +68,7 @@ func main() { // Call the same function in all modules and print the results to the console. for _, mod := range []api.Module{wasm, host, wasmWithMultiValue, hostWithMultiValue} { getAge := mod.ExportedFunction("call_get_age") - results, err := getAge.Call(nil) + results, err := getAge.Call(ctx) if err != nil { log.Fatal(err) } @@ -78,8 +82,8 @@ func main() { // // To return a value in WASM written to a result parameter, you have to define memory and pass a location to write // the result. At the end of your function, you load that location. -func resultOffsetWasmFunctions(r wazero.Runtime) (api.Module, error) { - return r.InstantiateModuleFromCode([]byte(`(module $result-offset/wasm +func resultOffsetWasmFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { + return r.InstantiateModuleFromCode(ctx, []byte(`(module $result-offset/wasm ;; To use result parameters, we need scratch memory. Allocate the least possible: 1 page (64KB). (memory 1 1) @@ -112,7 +116,7 @@ func resultOffsetWasmFunctions(r wazero.Runtime) (api.Module, error) { // // To return a value in WASM written to a result parameter, you have to define memory and pass a location to write // the result. At the end of your function, you load that location. -func resultOffsetHostFunctions(r wazero.Runtime) (api.Module, error) { +func resultOffsetHostFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { return r.NewModuleBuilder("result-offset/host"). // To use result parameters, we need scratch memory. Allocate the least possible: 1 page (64KB). ExportMemoryWithMax("mem", 1, 1). @@ -125,17 +129,17 @@ func resultOffsetHostFunctions(r wazero.Runtime) (api.Module, error) { }). // Now, define a function that shows the Wasm mechanics returning something written to a result parameter. // The caller provides a memory offset to the callee, so that it knows where to write the second result. - ExportFunction("call_get_age", func(m api.Module) (age uint64) { + ExportFunction("call_get_age", func(ctx context.Context, m api.Module) (age uint64) { resultOffsetAge := uint32(8) // arbitrary memory offset (in bytes) - _, _ = m.ExportedFunction("get_age").Call(m, uint64(resultOffsetAge)) + _, _ = m.ExportedFunction("get_age").Call(ctx, uint64(resultOffsetAge)) age, _ = m.Memory().ReadUint64Le(resultOffsetAge) return - }).Instantiate() + }).Instantiate(ctx) } // multiValueWasmFunctions defines Wasm functions that illustrate multiple results using the "multiple-results" feature. -func multiValueWasmFunctions(r wazero.Runtime) (api.Module, error) { - return r.InstantiateModuleFromCode([]byte(`(module $multi-value/wasm +func multiValueWasmFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { + return r.InstantiateModuleFromCode(ctx, []byte(`(module $multi-value/wasm ;; Define a function that returns two results (func $get_age (result (;age;) i64 (;errno;) i32) @@ -155,7 +159,7 @@ func multiValueWasmFunctions(r wazero.Runtime) (api.Module, error) { } // multiValueHostFunctions defines Wasm functions that illustrate multiple results using the "multiple-results" feature. -func multiValueHostFunctions(r wazero.Runtime) (api.Module, error) { +func multiValueHostFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { return r.NewModuleBuilder("multi-value/host"). // Define a function that returns two results ExportFunction("get_age", func() (age uint64, errno uint32) { @@ -164,8 +168,8 @@ func multiValueHostFunctions(r wazero.Runtime) (api.Module, error) { return }). // Now, define a function that returns only the first result. - ExportFunction("call_get_age", func(m api.Module) (age uint64) { - results, _ := m.ExportedFunction("get_age").Call(m) + ExportFunction("call_get_age", func(ctx context.Context, m api.Module) (age uint64) { + results, _ := m.ExportedFunction("get_age").Call(ctx) return results[0] - }).Instantiate() + }).Instantiate(ctx) } diff --git a/examples/replace-import/replace-import.go b/examples/replace-import/replace-import.go index e46554ab..48ffeaf1 100644 --- a/examples/replace-import/replace-import.go +++ b/examples/replace-import/replace-import.go @@ -1,6 +1,7 @@ package replace_import import ( + "context" _ "embed" "fmt" "log" @@ -11,20 +12,24 @@ import ( // main shows how you can replace a module import when it doesn't match instantiated modules. func main() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Instantiate a function that closes the module under "assemblyscript.abort". host, err := r.NewModuleBuilder("assemblyscript"). ExportFunction("abort", func(m api.Module, messageOffset, fileNameOffset, line, col uint32) { _ = m.CloseWithExitCode(255) - }).Instantiate() + }).Instantiate(ctx) if err != nil { log.Fatal(err) } defer host.Close() // Compile code that needs the function "env.abort". - code, err := r.CompileModule([]byte(`(module $needs-import + code, err := r.CompileModule(ctx, []byte(`(module $needs-import (import "env" "abort" (func $~lib/builtins/abort (param i32 i32 i32 i32))) (export "abort" (func 0)) ;; exports the import for testing @@ -35,7 +40,7 @@ func main() { defer code.Close() // Instantiate the module, replacing the import "env.abort" with "assemblyscript.abort". - mod, err := r.InstantiateModuleWithConfig(code, wazero.NewModuleConfig(). + mod, err := r.InstantiateModuleWithConfig(ctx, code, wazero.NewModuleConfig(). WithImport("env", "abort", "assemblyscript", "abort")) if err != nil { log.Fatal(err) @@ -43,6 +48,6 @@ func main() { defer mod.Close() // Since the above worked, the exported function closes the module. - _, err = mod.ExportedFunction("abort").Call(nil, 0, 0, 0, 0) + _, err = mod.ExportedFunction("abort").Call(ctx, 0, 0, 0, 0) fmt.Println(err) } diff --git a/examples/wasi/cat.go b/examples/wasi/cat.go index 6c5a7691..1da8a25a 100644 --- a/examples/wasi/cat.go +++ b/examples/wasi/cat.go @@ -1,6 +1,7 @@ package wasi_example import ( + "context" "embed" _ "embed" "io/fs" @@ -24,6 +25,10 @@ var catWasm []byte // This is a basic introduction to the WebAssembly System Interface (WASI). // See https://github.com/WebAssembly/WASI func main() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Since wazero uses fs.FS, we can use standard libraries to do things like trim the leading path. @@ -36,7 +41,7 @@ func main() { config := wazero.NewModuleConfig().WithStdout(os.Stdout).WithFS(rooted) // Instantiate WASI, which implements system I/O such as console output. - wm, err := wasi.InstantiateSnapshotPreview1(r) + wm, err := wasi.InstantiateSnapshotPreview1(ctx, r) if err != nil { log.Fatal(err) } @@ -45,7 +50,7 @@ func main() { // InstantiateModuleFromCodeWithConfig runs the "_start" function which is what TinyGo compiles "main" to. // * Set the program name (arg[0]) to "wasi" and add args to write "test.txt" to stdout twice. // * We use "/test.txt" or "./test.txt" because WithFS by default maps the workdir "." to "/". - cat, err := r.InstantiateModuleFromCodeWithConfig(catWasm, config.WithArgs("wasi", os.Args[1])) + cat, err := r.InstantiateModuleFromCodeWithConfig(ctx, catWasm, config.WithArgs("wasi", os.Args[1])) if err != nil { log.Fatal(err) } diff --git a/internal/integration_test/bench/bench_test.go b/internal/integration_test/bench/bench_test.go index 9574461e..a2b921ff 100644 --- a/internal/integration_test/bench/bench_test.go +++ b/internal/integration_test/bench/bench_test.go @@ -1,6 +1,7 @@ package bench import ( + "context" _ "embed" "fmt" "math/rand" @@ -12,6 +13,9 @@ import ( "github.com/tetratelabs/wazero/wasi" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + // caseWasm was compiled from TinyGo testdata/case.go //go:embed testdata/case.wasm var caseWasm []byte @@ -46,14 +50,14 @@ func BenchmarkInitialization(b *testing.B) { } func runInitializationBench(b *testing.B, r wazero.Runtime) { - compiled, err := r.CompileModule(caseWasm) + compiled, err := r.CompileModule(testCtx, caseWasm) if err != nil { b.Fatal(err) } defer compiled.Close() b.ResetTimer() for i := 0; i < b.N; i++ { - mod, err := r.InstantiateModule(compiled) + mod, err := r.InstantiateModule(testCtx, compiled) if err != nil { b.Fatal(err) } @@ -77,7 +81,7 @@ func runBase64Benches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("base64_%d_per_exec", numPerExec), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := base64.Call(nil, numPerExec); err != nil { + if _, err := base64.Call(testCtx, numPerExec); err != nil { b.Fatal(err) } } @@ -93,7 +97,7 @@ func runFibBenches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("fib_for_%d", num), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := fibonacci.Call(nil, num); err != nil { + if _, err := fibonacci.Call(testCtx, num); err != nil { b.Fatal(err) } } @@ -109,7 +113,7 @@ func runStringManipulationBenches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("string_manipulation_size_%d", initialSize), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := stringManipulation.Call(nil, initialSize); err != nil { + if _, err := stringManipulation.Call(testCtx, initialSize); err != nil { b.Fatal(err) } } @@ -125,7 +129,7 @@ func runReverseArrayBenches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("reverse_array_size_%d", arraySize), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := reverseArray.Call(nil, arraySize); err != nil { + if _, err := reverseArray.Call(testCtx, arraySize); err != nil { b.Fatal(err) } } @@ -141,7 +145,7 @@ func runRandomMatMul(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("random_mat_mul_size_%d", matrixSize), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := randomMatMul.Call(nil, matrixSize); err != nil { + if _, err := randomMatMul.Call(testCtx, matrixSize); err != nil { b.Fatal(err) } } @@ -153,7 +157,7 @@ func instantiateHostFunctionModuleWithEngine(b *testing.B, engine *wazero.Runtim r := createRuntime(b, engine) // InstantiateModuleFromCode runs the "_start" function which is what TinyGo compiles "main" to. - m, err := r.InstantiateModuleFromCode(caseWasm) + m, err := r.InstantiateModuleFromCode(testCtx, caseWasm) if err != nil { b.Fatal(err) } @@ -161,30 +165,32 @@ func instantiateHostFunctionModuleWithEngine(b *testing.B, engine *wazero.Runtim } func createRuntime(b *testing.B, engine *wazero.RuntimeConfig) wazero.Runtime { - getRandomString := func(ctx api.Module, retBufPtr uint32, retBufSize uint32) { - results, err := ctx.ExportedFunction("allocate_buffer").Call(ctx, 10) + getRandomString := func(ctx context.Context, m api.Module, retBufPtr uint32, retBufSize uint32) { + results, err := m.ExportedFunction("allocate_buffer").Call(ctx, 10) if err != nil { b.Fatal(err) } offset := uint32(results[0]) - ctx.Memory().WriteUint32Le(retBufPtr, offset) - ctx.Memory().WriteUint32Le(retBufSize, 10) + m.Memory().WriteUint32Le(retBufPtr, offset) + m.Memory().WriteUint32Le(retBufSize, 10) b := make([]byte, 10) _, _ = rand.Read(b) - ctx.Memory().Write(offset, b) + m.Memory().Write(offset, b) } r := wazero.NewRuntimeWithConfig(engine) - _, err := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Instantiate() + _, err := r.NewModuleBuilder("env"). + ExportFunction("get_random_string", getRandomString). + Instantiate(testCtx) if err != nil { b.Fatal(err) } // Note: host_func.go doesn't directly use WASI, but TinyGo needs to be initialized as a WASI Command. // Add WASI to satisfy import tests - _, err = wasi.InstantiateSnapshotPreview1(r) + _, err = wasi.InstantiateSnapshotPreview1(testCtx, r) if err != nil { b.Fatal(err) } diff --git a/internal/integration_test/engine/adhoc_test.go b/internal/integration_test/engine/adhoc_test.go index e9647ce9..e31217cf 100644 --- a/internal/integration_test/engine/adhoc_test.go +++ b/internal/integration_test/engine/adhoc_test.go @@ -15,6 +15,9 @@ import ( "github.com/tetratelabs/wazero/sys" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + var tests = map[string]func(t *testing.T, r wazero.Runtime){ "huge stack": testHugeStack, "unreachable": testUnreachable, @@ -37,17 +40,13 @@ func TestEngineInterpreter(t *testing.T) { runAllTests(t, tests, wazero.NewRuntimeConfigInterpreter()) } -type configContextKey string - -var configContext = context.WithValue(context.Background(), configContextKey("wa"), "zero") - func runAllTests(t *testing.T, tests map[string]func(t *testing.T, r wazero.Runtime), config *wazero.RuntimeConfig) { for name, testf := range tests { name := name // pin testf := testf // pin t.Run(name, func(t *testing.T) { t.Parallel() - testf(t, wazero.NewRuntimeWithConfig(config.WithContext(configContext))) + testf(t, wazero.NewRuntimeWithConfig(config)) }) } } @@ -62,14 +61,14 @@ var ( ) func testHugeStack(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(hugestackWasm) + module, err := r.InstantiateModuleFromCode(testCtx, hugestackWasm) require.NoError(t, err) defer module.Close() fn := module.ExportedFunction("main") require.NotNil(t, fn) - _, err = fn.Call(nil) + _, err = fn.Call(testCtx) require.NoError(t, err) } @@ -78,14 +77,14 @@ func testUnreachable(t *testing.T, r wazero.Runtime) { panic("panic in host function") } - _, err := r.NewModuleBuilder("host").ExportFunction("cause_unreachable", callUnreachable).Instantiate() + _, err := r.NewModuleBuilder("host").ExportFunction("cause_unreachable", callUnreachable).Instantiate(testCtx) require.NoError(t, err) - module, err := r.InstantiateModuleFromCode(unreachableWasm) + module, err := r.InstantiateModuleFromCode(testCtx, unreachableWasm) require.NoError(t, err) defer module.Close() - _, err = module.ExportedFunction("main").Call(nil) + _, err = module.ExportedFunction("main").Call(testCtx) exp := `panic in host function (recovered by wazero) wasm stack trace: host.cause_unreachable() @@ -97,18 +96,18 @@ wasm stack trace: func testRecursiveEntry(t *testing.T, r wazero.Runtime) { hostfunc := func(mod api.Module) { - _, err := mod.ExportedFunction("called_by_host_func").Call(nil) + _, err := mod.ExportedFunction("called_by_host_func").Call(testCtx) require.NoError(t, err) } - _, err := r.NewModuleBuilder("env").ExportFunction("host_func", hostfunc).Instantiate() + _, err := r.NewModuleBuilder("env").ExportFunction("host_func", hostfunc).Instantiate(testCtx) require.NoError(t, err) - module, err := r.InstantiateModuleFromCode(recursiveWasm) + module, err := r.InstantiateModuleFromCode(testCtx, recursiveWasm) require.NoError(t, err) defer module.Close() - _, err = module.ExportedFunction("main").Call(nil, 1) + _, err = module.ExportedFunction("main").Call(testCtx, 1) require.NoError(t, err) } @@ -130,11 +129,11 @@ func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { return 0 } - host, err := r.NewModuleBuilder("").ExportFunction("store_int", storeInt).Instantiate() + host, err := r.NewModuleBuilder("").ExportFunction("store_int", storeInt).Instantiate(testCtx) require.NoError(t, err) defer host.Close() - module, err := r.InstantiateModuleFromCode([]byte(`(module $test + module, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $test (import "" "store_int" (func $store_int (param $offset i32) (param $val i64) (result (;errno;) i32))) (memory $memory 1 1) @@ -147,7 +146,7 @@ func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { // Call store_int and ensure it didn't return an error code. fn := module.ExportedFunction("store_int") - results, err := fn.Call(nil, 1, math.MaxUint64) + results, err := fn.Call(testCtx, 1, math.MaxUint64) require.NoError(t, err) require.Equal(t, uint64(0), results[0]) @@ -166,7 +165,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { return p + 1 }, "go_context": func(ctx context.Context, p uint32) uint32 { - require.Equal(t, configContext, ctx) + require.Equal(t, testCtx, ctx) return p + 1 }, "module_context": func(module api.Module, p uint32) uint32 { @@ -175,14 +174,14 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { }, } - imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate() + imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) defer imported.Close() for test := range fns { t.Run(test, func(t *testing.T) { // Instantiate a module that uses Wasm code to call the host function. - importing, err = r.InstantiateModuleFromCode([]byte(fmt.Sprintf(`(module $%[1]s + importing, err = r.InstantiateModuleFromCode(testCtx, []byte(fmt.Sprintf(`(module $%[1]s (import "%[2]s" "%[3]s" (func $%[3]s (param i32) (result i32))) (func $call_%[3]s (param i32) (result i32) local.get 0 call $%[3]s) (export "call->%[3]s" (func $call_%[3]s)) @@ -190,7 +189,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { require.NoError(t, err) defer importing.Close() - results, err := importing.ExportedFunction("call->"+test).Call(nil, math.MaxUint32-1) + results, err := importing.ExportedFunction("call->"+test).Call(testCtx, math.MaxUint32-1) require.NoError(t, err) require.Equal(t, uint64(math.MaxUint32), results[0]) }) @@ -217,7 +216,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { }, } - imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate() + imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) defer imported.Close() @@ -248,7 +247,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { } { t.Run(test.name, func(t *testing.T) { // Instantiate a module that uses Wasm code to call the host function. - importing, err := r.InstantiateModuleFromCode([]byte(fmt.Sprintf(`(module $%[1]s + importing, err := r.InstantiateModuleFromCode(testCtx, []byte(fmt.Sprintf(`(module $%[1]s (import "%[2]s" "%[3]s" (func $%[3]s (param %[3]s) (result %[3]s))) (func $call_%[3]s (param %[3]s) (result %[3]s) local.get 0 call $%[3]s) (export "call->%[3]s" (func $call_%[3]s)) @@ -256,7 +255,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { require.NoError(t, err) defer importing.Close() - results, err := importing.ExportedFunction("call->"+test.name).Call(nil, test.input) + results, err := importing.ExportedFunction("call->"+test.name).Call(testCtx, test.input) require.NoError(t, err) require.Equal(t, test.expected, results[0]) }) @@ -340,19 +339,19 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { // Create the host module, which exports the function that closes the importing module. importedCode, err = r.NewModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", closeAndReturn).Build() + ExportFunction("return_input", closeAndReturn).Build(testCtx) require.NoError(t, err) - imported, err = r.InstantiateModule(importedCode) + imported, err = r.InstantiateModule(testCtx, importedCode) require.NoError(t, err) defer imported.Close() // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") - importingCode, err = r.CompileModule(source) + importingCode, err = r.CompileModule(testCtx, source) require.NoError(t, err) - importing, err = r.InstantiateModule(importingCode) + importing, err = r.InstantiateModule(testCtx, importingCode) require.NoError(t, err) defer importing.Close() @@ -369,14 +368,14 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { } // Functions that return after being closed should have an exit error. - _, err = importing.ExportedFunction(tc.function).Call(nil, 5) + _, err = importing.ExportedFunction(tc.function).Call(testCtx, 5) require.Equal(t, expectedErr, err) }) } } func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { - compiled, err := r.CompileModule([]byte(`(module $test + compiled, err := r.CompileModule(testCtx, []byte(`(module $test (memory 1) (func $store i32.const 1 ;; memory offset @@ -390,7 +389,7 @@ func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { // Instantiate multiple modules with the same source (*CompiledCode). for i := 0; i < 100; i++ { - module, err := r.InstantiateModuleWithConfig(compiled, wazero.NewModuleConfig().WithName(strconv.Itoa(i))) + module, err := r.InstantiateModuleWithConfig(testCtx, compiled, wazero.NewModuleConfig().WithName(strconv.Itoa(i))) require.NoError(t, err) defer module.Close() @@ -403,7 +402,7 @@ func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { f := module.ExportedFunction("store") require.NotNil(t, f) - _, err = f.Call(nil) + _, err = f.Call(testCtx) require.NoError(t, err) // After the call, the value must be set properly. diff --git a/internal/integration_test/engine/hammer_test.go b/internal/integration_test/engine/hammer_test.go index 99b43d7e..c798084d 100644 --- a/internal/integration_test/engine/hammer_test.go +++ b/internal/integration_test/engine/hammer_test.go @@ -35,7 +35,7 @@ func closeImportingModuleWhileInUse(t *testing.T, r wazero.Runtime) { // Prove a module can be redefined even with in-flight calls. source := callReturnImportSource(imported.Name(), importing.Name()) - importing, err := r.InstantiateModuleFromCode(source) + importing, err := r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) return imported, importing }) @@ -50,12 +50,12 @@ func closeImportedModuleWhileInUse(t *testing.T, r wazero.Runtime) { // Redefine the imported module, with a function that no longer blocks. imported, err := r.NewModuleBuilder(imported.Name()).ExportFunction("return_input", func(x uint32) uint32 { return x - }).Instantiate() + }).Instantiate(testCtx) require.NoError(t, err) // Redefine the importing module, which should link to the redefined host module. source := callReturnImportSource(imported.Name(), importing.Name()) - importing, err = r.InstantiateModuleFromCode(source) + importing, err = r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) return imported, importing @@ -78,13 +78,13 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported // Create the host module, which exports the blocking function. imported, err := r.NewModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", blockAndReturn).Instantiate() + ExportFunction("return_input", blockAndReturn).Instantiate(testCtx) require.NoError(t, err) defer imported.Close() // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") - importing, err := r.InstantiateModuleFromCode(source) + importing, err := r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) defer importing.Close() @@ -110,12 +110,12 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported } func requireFunctionCall(t *testing.T, fn api.Function) { - res, err := fn.Call(nil, 3) + res, err := fn.Call(testCtx, 3) require.NoError(t, err) require.Equal(t, uint64(3), res[0]) } func requireFunctionCallExits(t *testing.T, moduleName string, fn api.Function) { - _, err := fn.Call(nil, 3) + _, err := fn.Call(testCtx, 3) require.Equal(t, sys.NewExitError(moduleName, 0), err) } diff --git a/internal/integration_test/post1_0/multi-value/multi_value_test.go b/internal/integration_test/post1_0/multi-value/multi_value_test.go index e241177b..2b637753 100644 --- a/internal/integration_test/post1_0/multi-value/multi_value_test.go +++ b/internal/integration_test/post1_0/multi-value/multi_value_test.go @@ -1,6 +1,7 @@ package multi_value import ( + "context" _ "embed" "testing" @@ -9,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestMultiValue_JIT(t *testing.T) { if !wazero.JITSupported { t.Skip() @@ -28,32 +32,32 @@ func testMultiValue(t *testing.T, newRuntimeConfig func() *wazero.RuntimeConfig) t.Run("disabled", func(t *testing.T) { // multi-value is disabled by default. r := wazero.NewRuntimeWithConfig(newRuntimeConfig()) - _, err := r.InstantiateModuleFromCode(multiValueWasm) + _, err := r.InstantiateModuleFromCode(testCtx, multiValueWasm) require.Error(t, err) }) t.Run("enabled", func(t *testing.T) { r := wazero.NewRuntimeWithConfig(newRuntimeConfig().WithFeatureMultiValue(true)) - module, err := r.InstantiateModuleFromCode(multiValueWasm) + module, err := r.InstantiateModuleFromCode(testCtx, multiValueWasm) require.NoError(t, err) defer module.Close() swap := module.ExportedFunction("swap") - results, err := swap.Call(nil, 100, 200) + results, err := swap.Call(testCtx, 100, 200) require.NoError(t, err) require.Equal(t, []uint64{200, 100}, results) add64UWithCarry := module.ExportedFunction("add64_u_with_carry") - results, err = add64UWithCarry.Call(nil, 0x8000000000000000, 0x8000000000000000, 0) + results, err = add64UWithCarry.Call(testCtx, 0x8000000000000000, 0x8000000000000000, 0) require.NoError(t, err) require.Equal(t, []uint64{0, 1}, results) add64USaturated := module.ExportedFunction("add64_u_saturated") - results, err = add64USaturated.Call(nil, 1230, 23) + results, err = add64USaturated.Call(testCtx, 1230, 23) require.NoError(t, err) require.Equal(t, []uint64{1253}, results) fac := module.ExportedFunction("fac") - results, err = fac.Call(nil, 25) + results, err = fac.Call(testCtx, 25) require.NoError(t, err) require.Equal(t, []uint64{7034535277573963776}, results) @@ -86,7 +90,7 @@ func testMultiValue(t *testing.T, newRuntimeConfig func() *wazero.RuntimeConfig) var brWasm []byte func testBr(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(brWasm) + module, err := r.InstantiateModuleFromCode(testCtx, brWasm) require.NoError(t, err) defer module.Close() @@ -109,7 +113,7 @@ func testBr(t *testing.T, r wazero.Runtime) { var callWasm []byte func testCall(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(callWasm) + module, err := r.InstantiateModuleFromCode(testCtx, callWasm) require.NoError(t, err) defer module.Close() @@ -130,7 +134,7 @@ func testCall(t *testing.T, r wazero.Runtime) { var callIndirectWasm []byte func testCallIndirect(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(callIndirectWasm) + module, err := r.InstantiateModuleFromCode(testCtx, callIndirectWasm) require.NoError(t, err) defer module.Close() @@ -141,7 +145,7 @@ func testCallIndirect(t *testing.T, r wazero.Runtime) { {name: "type-all-i32-i64", expected: []uint64{2, 1}}, }) - _, err = module.ExportedFunction("dispatch").Call(nil, 32, 2) + _, err = module.ExportedFunction("dispatch").Call(testCtx, 32, 2) require.EqualError(t, err, `wasm error: invalid table access wasm stack trace: call_indirect.wast.[16](i32,i64) i64`) @@ -152,12 +156,12 @@ wasm stack trace: var facWasm []byte func testFac(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(facWasm) + module, err := r.InstantiateModuleFromCode(testCtx, facWasm) require.NoError(t, err) defer module.Close() fac := module.ExportedFunction("fac-ssa") - results, err := fac.Call(nil, 25) + results, err := fac.Call(testCtx, 25) require.NoError(t, err) require.Equal(t, []uint64{7034535277573963776}, results) } @@ -167,7 +171,7 @@ func testFac(t *testing.T, r wazero.Runtime) { var funcWasm []byte func testFunc(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(funcWasm) + module, err := r.InstantiateModuleFromCode(testCtx, funcWasm) require.NoError(t, err) defer module.Close() @@ -196,7 +200,7 @@ func testFunc(t *testing.T, r wazero.Runtime) { }) fac := module.ExportedFunction("large-sig") - results, err := fac.Call(nil, + results, err := fac.Call(testCtx, 0, 1, api.EncodeF32(2), api.EncodeF32(3), 4, api.EncodeF64(5), api.EncodeF32(6), 7, 8, 9, api.EncodeF32(10), api.EncodeF64(11), @@ -215,7 +219,7 @@ func testFunc(t *testing.T, r wazero.Runtime) { var ifWasm []byte func testIf(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(ifWasm) + module, err := r.InstantiateModuleFromCode(testCtx, ifWasm) require.NoError(t, err) defer module.Close() @@ -267,7 +271,7 @@ func testIf(t *testing.T, r wazero.Runtime) { var loopWasm []byte func testLoop(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(loopWasm) + module, err := r.InstantiateModuleFromCode(testCtx, loopWasm) require.NoError(t, err) defer module.Close() @@ -296,7 +300,7 @@ func testFunctions(t *testing.T, module api.Module, tests []funcTest) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - results, err := module.ExportedFunction(tc.name).Call(nil, tc.params...) + results, err := module.ExportedFunction(tc.name).Call(testCtx, tc.params...) require.NoError(t, err) if tc.expected == nil { require.Equal(t, 0, len(results), "expected no results") diff --git a/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go b/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go index 1f25dcd0..598615a0 100644 --- a/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go +++ b/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go @@ -1,6 +1,7 @@ package sign_extension_ops import ( + "context" _ "embed" "fmt" "testing" @@ -9,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestSignExtensionOps_JIT(t *testing.T) { if !wazero.JITSupported { t.Skip() @@ -45,12 +49,12 @@ func testSignExtensionOps(t *testing.T, newRuntimeConfig func() *wazero.RuntimeC t.Run("disabled", func(t *testing.T) { // Sign-extension is disabled by default. r := wazero.NewRuntimeWithConfig(newRuntimeConfig()) - _, err := r.InstantiateModuleFromCode(signExtend) + _, err := r.InstantiateModuleFromCode(testCtx, signExtend) require.Error(t, err) }) t.Run("enabled", func(t *testing.T) { r := wazero.NewRuntimeWithConfig(newRuntimeConfig().WithFeatureSignExtensionOps(true)) - module, err := r.InstantiateModuleFromCode(signExtend) + module, err := r.InstantiateModuleFromCode(testCtx, signExtend) require.NoError(t, err) signExtend32from8Name, signExtend32from16Name := "i32.extend8_s", "i32.extend16_s" @@ -83,7 +87,7 @@ func testSignExtensionOps(t *testing.T, newRuntimeConfig func() *wazero.RuntimeC fn := module.ExportedFunction(tc.funcname) require.NotNil(t, fn) - actual, err := fn.Call(nil, uint64(uint32(tc.in))) + actual, err := fn.Call(testCtx, uint64(uint32(tc.in))) require.NoError(t, err) require.Equal(t, tc.expected, int32(actual[0])) }) @@ -131,7 +135,7 @@ func testSignExtensionOps(t *testing.T, newRuntimeConfig func() *wazero.RuntimeC fn := module.ExportedFunction(tc.funcname) require.NotNil(t, fn) - actual, err := fn.Call(nil, uint64(tc.in)) + actual, err := fn.Call(testCtx, uint64(tc.in)) require.NoError(t, err) require.Equal(t, tc.expected, int64(actual[0])) }) diff --git a/internal/integration_test/spectest/spec_test.go b/internal/integration_test/spectest/spec_test.go index 88934e1c..8c9ada20 100644 --- a/internal/integration_test/spectest/spec_test.go +++ b/internal/integration_test/spectest/spec_test.go @@ -23,6 +23,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasmruntime" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + //go:embed testdata/*.wasm //go:embed testdata/*.json var testcases embed.FS @@ -267,10 +270,10 @@ func addSpectestModule(t *testing.T, store *wasm.Store) { mod.TableSection = &wasm.Table{Min: 10, Max: &tableLimitMax} mod.ExportSection = append(mod.ExportSection, &wasm.Export{Name: "table", Index: 0, Type: wasm.ExternTypeTable}) - err = store.Engine.CompileModule(mod) + err = store.Engine.CompileModule(testCtx, mod) require.NoError(t, err) - _, err = store.Instantiate(context.Background(), mod, mod.NameSection.ModuleName, wasm.DefaultSysContext()) + _, err = store.Instantiate(testCtx, mod, mod.NameSection.ModuleName, wasm.DefaultSysContext()) require.NoError(t, err) } @@ -339,11 +342,11 @@ func runTest(t *testing.T, newEngine func(wasm.Features) wasm.Engine) { } } - err = store.Engine.CompileModule(mod) + err = store.Engine.CompileModule(testCtx, mod) require.NoError(t, err, msg) moduleName = strings.TrimPrefix(moduleName, "$") - _, err = store.Instantiate(context.Background(), mod, moduleName, nil) + _, err = store.Instantiate(testCtx, mod, moduleName, nil) lastInstantiatedModuleName = moduleName require.NoError(t, err) case "register": @@ -468,12 +471,12 @@ func requireInstantiationError(t *testing.T, store *wasm.Store, buf []byte, msg mod.AssignModuleID(buf) - err = store.Engine.CompileModule(mod) + err = store.Engine.CompileModule(testCtx, mod) if err != nil { return } - _, err = store.Instantiate(context.Background(), mod, t.Name(), nil) + _, err = store.Instantiate(testCtx, mod, t.Name(), nil) require.Error(t, err, msg) } @@ -521,6 +524,6 @@ func requireValueEq(t *testing.T, actual, expected uint64, valType wasm.ValueTyp // TODO: This is likely already covered with unit tests! func callFunction(s *wasm.Store, moduleName, funcName string, params ...uint64) ([]uint64, []wasm.ValueType, error) { fn := s.Module(moduleName).ExportedFunction(funcName) - results, err := fn.Call(nil, params...) + results, err := fn.Call(testCtx, params...) return results, fn.ResultTypes(), err } diff --git a/internal/integration_test/vs/bench_fac_test.go b/internal/integration_test/vs/bench_fac_test.go index 52bf038c..2d264487 100644 --- a/internal/integration_test/vs/bench_fac_test.go +++ b/internal/integration_test/vs/bench_fac_test.go @@ -5,6 +5,7 @@ package vs import ( + "context" _ "embed" "errors" "fmt" @@ -19,6 +20,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + // ensureJITFastest is overridable via ldflags. Ex. // -ldflags '-X github.com/tetratelabs/wazero/vs.ensureJITFastest=true' var ensureJITFastest = "false" @@ -38,7 +42,7 @@ func TestFac(t *testing.T) { defer mod.Close() for i := 0; i < 10000; i++ { - res, err := fn.Call(nil, in) + res, err := fn.Call(testCtx, in) require.NoError(t, err) require.Equal(t, expValue, res[0]) } @@ -50,7 +54,7 @@ func TestFac(t *testing.T) { defer mod.Close() for i := 0; i < 10000; i++ { - res, err := fn.Call(nil, in) + res, err := fn.Call(testCtx, in) require.NoError(t, err) require.Equal(t, expValue, res[0]) } @@ -230,7 +234,7 @@ func interpreterFacInvoke(b *testing.B) { defer mod.Close() b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err = fn.Call(nil, facArgumentU64); err != nil { + if _, err = fn.Call(testCtx, facArgumentU64); err != nil { b.Fatal(err) } } @@ -244,7 +248,7 @@ func jitFacInvoke(b *testing.B) { defer mod.Close() b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err = fn.Call(nil, facArgumentU64); err != nil { + if _, err = fn.Call(testCtx, facArgumentU64); err != nil { b.Fatal(err) } } @@ -298,7 +302,7 @@ func goWasm3FacInvoke(b *testing.B) { func newWazeroFacBench(config *wazero.RuntimeConfig) (api.Module, api.Function, error) { r := wazero.NewRuntimeWithConfig(config) - m, err := r.InstantiateModuleFromCode(facWasm) + m, err := r.InstantiateModuleFromCode(testCtx, facWasm) if err != nil { return nil, nil, err } diff --git a/internal/integration_test/vs/codec_test.go b/internal/integration_test/vs/codec_test.go index 6c745ed8..d145a002 100644 --- a/internal/integration_test/vs/codec_test.go +++ b/internal/integration_test/vs/codec_test.go @@ -110,17 +110,17 @@ func TestExampleUpToDate(t *testing.T) { r := wazero.NewRuntimeWithConfig(wazero.NewRuntimeConfig().WithFinishedFeatures()) // Add WASI to satisfy import tests - wm, err := wasi.InstantiateSnapshotPreview1(r) + wm, err := wasi.InstantiateSnapshotPreview1(testCtx, r) require.NoError(t, err) defer wm.Close() // Decode and instantiate the module - module, err := r.InstantiateModuleFromCode(exampleBinary) + module, err := r.InstantiateModuleFromCode(testCtx, exampleBinary) require.NoError(t, err) defer module.Close() // Call the swap function as a smoke test - results, err := module.ExportedFunction("swap").Call(nil, 1, 2) + results, err := module.ExportedFunction("swap").Call(testCtx, 1, 2) require.NoError(t, err) require.Equal(t, []uint64{2, 1}, results) }) diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 6071b29b..5e868c72 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -29,6 +29,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasmdebug" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + type EngineTester interface { NewEngine(enabledFeatures wasm.Features) wasm.Engine InitTable(me wasm.ModuleEngine, initTableLen uint32, initTableIdxToFnIdx map[wasm.Index]wasm.Index) []interface{} @@ -44,7 +47,7 @@ func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { t.Run("sets module name", func(t *testing.T) { m := &wasm.Module{} - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) me, err := e.NewModuleEngine(t.Name(), m, nil, nil, nil, nil) require.NoError(t, err) @@ -76,7 +79,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { CodeSection: []*wasm.Code{{Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) // To use the function, we first need to add it to a module. @@ -91,17 +94,17 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { linkModuleToEngine(module, me) // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. - results, err := me.Call(module.CallCtx, fn, 3) + results, err := me.Call(testCtx, module.CallCtx, fn, 3) require.NoError(t, err) require.Equal(t, uint64(3), results[0]) t.Run("errs when not enough parameters", func(t *testing.T) { - _, err := me.Call(module.CallCtx, fn) + _, err := me.Call(testCtx, module.CallCtx, fn) require.EqualError(t, err, "expected 1 params, but passed 0") }) t.Run("errs when too many parameters", func(t *testing.T) { - _, err := me.Call(module.CallCtx, fn, 1, 2) + _, err := me.Call(testCtx, module.CallCtx, fn, 1, 2) require.EqualError(t, err, "expected 1 params, but passed 2") }) } @@ -117,7 +120,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { CodeSection: []*wasm.Code{}, ID: wasm.ModuleID{0}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) // Instantiate the module, which has nothing but an empty table. @@ -139,7 +142,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{1}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) moduleFunctions := []*wasm.FunctionInstance{ @@ -170,7 +173,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{2}, } - err := e.CompileModule(importedModule) + err := e.CompileModule(testCtx, importedModule) require.NoError(t, err) importedModuleInstance := &wasm.ModuleInstance{} @@ -194,7 +197,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { CodeSection: []*wasm.Code{}, ID: wasm.ModuleID{3}, } - err = e.CompileModule(importingModule) + err = e.CompileModule(testCtx, importingModule) require.NoError(t, err) tableInit := map[wasm.Index]wasm.Index{0: 2} @@ -217,7 +220,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{4}, } - err := e.CompileModule(importedModule) + err := e.CompileModule(testCtx, importedModule) require.NoError(t, err) importedModuleInstance := &wasm.ModuleInstance{} importedFunctions := []*wasm.FunctionInstance{ @@ -241,7 +244,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{5}, } - err = e.CompileModule(importingModule) + err = e.CompileModule(testCtx, importingModule) require.NoError(t, err) importingModuleInstance := &wasm.ModuleInstance{} @@ -284,11 +287,11 @@ func runTestModuleEngine_Call_HostFn_CallContext(t *testing.T, et EngineTester) TypeSection: []*wasm.FunctionType{sig}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) module := &wasm.ModuleInstance{Memory: memory} - modCtx := wasm.NewCallContext(context.Background(), wasm.NewStore(features, e), module, nil) + modCtx := wasm.NewCallContext(wasm.NewStore(features, e), module, nil) f := &wasm.FunctionInstance{ GoFunc: &hostFn, @@ -303,7 +306,7 @@ func runTestModuleEngine_Call_HostFn_CallContext(t *testing.T, et EngineTester) t.Run("defaults to module memory when call stack empty", func(t *testing.T) { // When calling a host func directly, there may be no stack. This ensures the module's memory is used. - results, err := me.Call(modCtx, f, 3) + results, err := me.Call(testCtx, modCtx, f, 3) require.NoError(t, err) require.Equal(t, uint64(3), results[0]) require.Same(t, memory, ctxMemory) @@ -351,7 +354,7 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { t.Run(tc.name, func(t *testing.T) { m := tc.module f := tc.fn - results, err := f.Module.Engine.Call(m, f, 1) + results, err := f.Module.Engine.Call(testCtx, m, f, 1) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -475,11 +478,11 @@ wasm stack trace: t.Run(tc.name, func(t *testing.T) { m := tc.module f := tc.fn - _, err := f.Module.Engine.Call(m, f, tc.input...) + _, err := f.Module.Engine.Call(testCtx, m, f, tc.input...) require.EqualError(t, err, tc.expectedErr) // Ensure the module still works - results, err := f.Module.Engine.Call(m, f, 1) + results, err := f.Module.Engine.Call(testCtx, m, f, 1) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -515,7 +518,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo ID: wasm.ModuleID{0}, } - err := e.CompileModule(hostFnModule) + err := e.CompileModule(testCtx, hostFnModule) require.NoError(t, err) hostFn := &wasm.FunctionInstance{GoFunc: &hostFnVal, Kind: wasm.FunctionKindGoNoContext, Type: ft} hostFnModuleInstance := &wasm.ModuleInstance{Name: "host"} @@ -536,7 +539,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo ID: wasm.ModuleID{1}, } - err = e.CompileModule(importedModule) + err = e.CompileModule(testCtx, importedModule) require.NoError(t, err) // To use the function, we first need to add it to a module. @@ -560,7 +563,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo ImportSection: []*wasm.Import{{}}, ID: wasm.ModuleID{2}, } - err = e.CompileModule(importingModule) + err = e.CompileModule(testCtx, importingModule) require.NoError(t, err) // Add the exported function. @@ -592,7 +595,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo func linkModuleToEngine(module *wasm.ModuleInstance, me wasm.ModuleEngine) { module.Engine = me // for JIT, links the module to the module-engine compiled from it (moduleInstanceEngineOffset). // callEngineModuleContextModuleInstanceAddressOffset - module.CallCtx = wasm.NewCallContext(context.Background(), nil, module, nil) + module.CallCtx = wasm.NewCallContext(nil, module, nil) } // addFunction assigns and adds a function to the module. diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index ddc09ec7..034130f9 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -12,9 +12,9 @@ import ( // compile time check to ensure CallContext implements api.Module var _ api.Module = &CallContext{} -func NewCallContext(ctx context.Context, store *Store, instance *ModuleInstance, Sys *SysContext) *CallContext { +func NewCallContext(store *Store, instance *ModuleInstance, Sys *SysContext) *CallContext { zero := uint64(0) - return &CallContext{ctx: ctx, memory: instance.Memory, module: instance, store: store, Sys: Sys, closed: &zero} + return &CallContext{memory: instance.Memory, module: instance, store: store, Sys: Sys, closed: &zero} } // CallContext is a function call context bound to a module. This is important as one module's functions can call @@ -24,8 +24,11 @@ func NewCallContext(ctx context.Context, store *Store, instance *ModuleInstance, // functionality like trace propagation. // Note: this also implements api.Module in order to simplify usage as a host function parameter. type CallContext struct { - // ctx is returned by Context and overridden WithContext - ctx context.Context // TODO: remove in next PR + // TODO: We've never found a great name for this. It is only used for function calls, hence CallContext, but it + // moves on a different axis than, for example, the context.Context. context.Context is the same root for the whole + // call stack, where the CallContext can change depending on where memory is defined and who defines the calling + // function. When we rename this again, we should try to capture as many key points possible on the docs. + module *ModuleInstance // memory is returned by Memory and overridden WithMemory memory api.Memory @@ -60,7 +63,7 @@ func (m *CallContext) Name() string { // WithMemory allows overriding memory without re-allocation when the result would be the same. func (m *CallContext) WithMemory(memory *MemoryInstance) *CallContext { if memory != nil && memory != m.memory { // only re-allocate if it will change the effective memory - return &CallContext{module: m.module, memory: memory, ctx: m.ctx, Sys: m.Sys, closed: m.closed} + return &CallContext{module: m.module, memory: memory, Sys: m.Sys, closed: m.closed} } return m } @@ -70,19 +73,6 @@ func (m *CallContext) String() string { return fmt.Sprintf("Module[%s]", m.Name()) } -// Context implements the same method as documented on api.Module -func (m *CallContext) Context() context.Context { - return m.ctx -} - -// WithContext implements the same method as documented on api.Module -func (m *CallContext) WithContext(ctx context.Context) api.Module { - if ctx != nil && ctx != m.ctx { // only re-allocate if it will change the effective context - return &CallContext{module: m.module, memory: m.memory, ctx: ctx, Sys: m.Sys, closed: m.closed} - } - return m -} - // Close implements the same method as documented on api.Module. func (m *CallContext) Close() (err error) { return m.CloseWithExitCode(0) @@ -145,11 +135,12 @@ func (f *importedFn) ResultTypes() []api.ValueType { } // Call implements the same method as documented on api.Function -func (f *importedFn) Call(m api.Module, params ...uint64) (ret []uint64, err error) { - if m == nil { - return f.importedFn.Call(f.importingModule, params...) +func (f *importedFn) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { + if ctx == nil { + ctx = context.Background() } - return f.importedFn.Call(m, params...) + mod := f.importingModule + return f.importedFn.Module.Engine.Call(ctx, mod, f.importedFn, params...) } // ParamTypes implements the same method as documented on api.Function @@ -163,15 +154,12 @@ func (f *FunctionInstance) ResultTypes() []api.ValueType { } // Call implements the same method as documented on api.Function -func (f *FunctionInstance) Call(m api.Module, params ...uint64) (ret []uint64, err error) { - mod := f.Module - modCtx, ok := m.(*CallContext) - if ok { - // TODO: check if the importing context is correct - } else { // allow nil to substitute for the defining module - modCtx = mod.CallCtx +func (f *FunctionInstance) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { + if ctx == nil { + ctx = context.Background() } - return mod.Engine.Call(modCtx, f, params...) + mod := f.Module + return mod.Engine.Call(ctx, mod.CallCtx, f, params...) } // ExportedGlobal implements api.Module ExportedGlobal diff --git a/internal/wasm/call_context_test.go b/internal/wasm/call_context_test.go index bfd12938..3cc9e8ca 100644 --- a/internal/wasm/call_context_test.go +++ b/internal/wasm/call_context_test.go @@ -8,55 +8,6 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) -func TestCallContext_WithContext(t *testing.T) { - type key string - tests := []struct { - name string - mod *CallContext - ctx context.Context - expectSame bool - }{ - { - name: "nil->nil: same", - mod: &CallContext{}, - ctx: nil, - expectSame: true, - }, - { - name: "nil->ctx: not same", - mod: &CallContext{}, - ctx: context.WithValue(context.Background(), key("a"), "b"), - expectSame: false, - }, - { - name: "ctx->nil: same", - mod: &CallContext{ctx: context.Background()}, - ctx: nil, - expectSame: true, - }, - { - name: "ctx1->ctx2: not same", - mod: &CallContext{ctx: context.Background()}, - ctx: context.WithValue(context.Background(), key("a"), "b"), - expectSame: false, - }, - } - - for _, tt := range tests { - tc := tt - - t.Run(tc.name, func(t *testing.T) { - mod2 := tc.mod.WithContext(tc.ctx) - if tc.expectSame { - require.Same(t, tc.mod, mod2) - } else { - require.NotSame(t, tc.mod, mod2) - require.Equal(t, tc.ctx, mod2.Context()) - } - }) - } -} - func TestCallContext_WithMemory(t *testing.T) { tests := []struct { name string diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index 80e1a428..79b72ae5 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -1,10 +1,12 @@ package wasm +import "context" + // Engine is a Store-scoped mechanism to compile functions declared or imported by a module. // This is a top-level type implemented by an interpreter or JIT compiler. type Engine interface { // CompileModule implements the same method as documented on wasm.Engine. - CompileModule(module *Module) error + CompileModule(ctx context.Context, module *Module) error // NewModuleEngine compiles down the function instances in a module, and returns ModuleEngine for the module. // @@ -17,7 +19,13 @@ type Engine interface { // // Note: Input parameters must be pre-validated with wasm.Module Validate, to ensure no fields are invalid // due to reasons such as out-of-bounds. - NewModuleEngine(name string, module *Module, importedFunctions, moduleFunctions []*FunctionInstance, table *TableInstance, tableInit map[Index]Index) (ModuleEngine, error) + NewModuleEngine( + name string, + module *Module, + importedFunctions, moduleFunctions []*FunctionInstance, + table *TableInstance, + tableInit map[Index]Index, + ) (ModuleEngine, error) // DeleteCompiledModule releases compilation caches for the given module (source). // Note: it is safe to call this function for a module from which module instances are instantiated even when these module instances @@ -31,7 +39,5 @@ type ModuleEngine interface { Name() string // Call invokes a function instance f with given parameters. - // Returns the results from the function. - // The ctx's context.Context will be the outer-most ancestor of the argument to api.Function. - Call(ctx *CallContext, f *FunctionInstance, params ...uint64) (results []uint64, err error) + Call(ctx context.Context, m *CallContext, f *FunctionInstance, params ...uint64) (results []uint64, err error) } diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index f3331054..a6099635 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -21,8 +21,11 @@ const ( // a context.Context. FunctionKindGoContext // FunctionKindGoModule is a function implemented in Go, with a signature matching FunctionType, except arg - // zero is a Module. + // zero is an api.Module. FunctionKindGoModule + // FunctionKindGoContextModule is a function implemented in Go, with a signature matching FunctionType, except arg + // zero is a context.Context and arg one is an api.Module. + FunctionKindGoContextModule ) // Below are reflection code to get the interface type used to parse functions and set values. @@ -31,22 +34,6 @@ var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem() var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() -// getGoFuncCallContextValue returns a reflect.Value for a context param[0], or nil if there isn't one. -func getGoFuncCallContextValue(fk FunctionKind, ctx *CallContext) *reflect.Value { - switch fk { - case FunctionKindGoNoContext: // no special param zero - case FunctionKindGoContext: - val := reflect.New(goContextType).Elem() - val.Set(reflect.ValueOf(ctx.Context())) - return &val - case FunctionKindGoModule: - val := reflect.New(moduleType).Elem() - val.Set(reflect.ValueOf(ctx)) - return &val - } - return nil -} - // PopGoFuncParams pops the correct number of parameters off the stack into a parameter slice for use in CallGoFunc // // For example, if the host function F requires the (x1 uint32, x2 float32) parameters, and @@ -55,7 +42,11 @@ func getGoFuncCallContextValue(fk FunctionKind, ctx *CallContext) *reflect.Value func PopGoFuncParams(f *FunctionInstance, popParam func() uint64) []uint64 { // First, determine how many values we need to pop paramCount := f.GoFunc.Type().NumIn() - if f.Kind != FunctionKindGoNoContext { + switch f.Kind { + case FunctionKindGoNoContext: + case FunctionKindGoContextModule: + paramCount -= 2 + default: paramCount-- } @@ -82,22 +73,30 @@ func PopValues(count int, popper func() uint64) []uint64 { // * callCtx is passed to the host function as a first argument. // // Note: ctx must use the caller's memory, which might be different from the defining module on an imported function. -func CallGoFunc(callCtx *CallContext, f *FunctionInstance, params []uint64) []uint64 { +func CallGoFunc(ctx context.Context, callCtx *CallContext, f *FunctionInstance, params []uint64) []uint64 { tp := f.GoFunc.Type() var in []reflect.Value if tp.NumIn() != 0 { in = make([]reflect.Value, tp.NumIn()) - wasmParamOffset := 0 - if f.Kind != FunctionKindGoNoContext { - wasmParamOffset = 1 + i := 0 + switch f.Kind { + case FunctionKindGoContext: + in[0] = newContextVal(ctx) + i = 1 + case FunctionKindGoModule: + in[0] = newModuleVal(callCtx) + i = 1 + case FunctionKindGoContextModule: + in[0] = newContextVal(ctx) + in[1] = newModuleVal(callCtx) + i = 2 } - for i, raw := range params { - inI := i + wasmParamOffset - val := reflect.New(tp.In(inI)).Elem() - switch tp.In(inI).Kind() { + for _, raw := range params { + val := reflect.New(tp.In(i)).Elem() + switch tp.In(i).Kind() { case reflect.Float32: val.SetFloat(float64(math.Float32frombits(uint32(raw)))) case reflect.Float64: @@ -107,12 +106,8 @@ func CallGoFunc(callCtx *CallContext, f *FunctionInstance, params []uint64) []ui case reflect.Int32, reflect.Int64: val.SetInt(int64(raw)) } - in[inI] = val - } - - // Handle any special parameter zero - if val := getGoFuncCallContextValue(f.Kind, callCtx); val != nil { - in[0] = *val + in[i] = val + i++ } } @@ -138,6 +133,18 @@ func CallGoFunc(callCtx *CallContext, f *FunctionInstance, params []uint64) []ui return results } +func newContextVal(ctx context.Context) reflect.Value { + val := reflect.New(goContextType).Elem() + val.Set(reflect.ValueOf(ctx)) + return val +} + +func newModuleVal(m api.Module) reflect.Value { + val := reflect.New(moduleType).Elem() + val.Set(reflect.ValueOf(m)) + return val +} + // getFunctionType returns the function type corresponding to the function signature or errs if invalid. func getFunctionType(fn *reflect.Value, enabledFeatures Features) (fk FunctionKind, ft *FunctionType, err error) { p := fn.Type() @@ -147,8 +154,13 @@ func getFunctionType(fn *reflect.Value, enabledFeatures Features) (fk FunctionKi return } + fk = kind(p) pOffset := 0 - if fk = kind(p); fk != FunctionKindGoNoContext { + switch fk { + case FunctionKindGoNoContext: + case FunctionKindGoContextModule: + pOffset = 2 + default: pOffset = 1 } @@ -212,6 +224,9 @@ func kind(p reflect.Type) FunctionKind { if p0.Implements(moduleType) { return FunctionKindGoModule } else if p0.Implements(goContextType) { + if pCount >= 2 && p.In(1).Implements(moduleType) { + return FunctionKindGoContextModule + } return FunctionKindGoContext } } diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index 24b07926..fe629047 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -10,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestGetFunctionType(t *testing.T) { var tests = []struct { name string @@ -35,6 +38,12 @@ func TestGetFunctionType(t *testing.T) { expectedKind: FunctionKindGoContext, expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, }, + { + name: "context.Context and api.Module void return", + inputFunc: func(context.Context, api.Module) {}, + expectedKind: FunctionKindGoContextModule, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, { name: "all supported params and i32 result", inputFunc: func(uint32, uint64, float32, float64) uint32 { return 0 }, @@ -59,6 +68,12 @@ func TestGetFunctionType(t *testing.T) { expectedKind: FunctionKindGoContext, expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, }, + { + name: "all supported params and i32 result - context.Context and api.Module", + inputFunc: func(context.Context, api.Module, uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindGoContextModule, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, } for _, tt := range tests { tc := tt @@ -201,6 +216,10 @@ func TestPopGoFuncParams(t *testing.T) { name: "context.Context", inputFunc: func(context.Context) {}, }, + { + name: "context.Context and api.Module", + inputFunc: func(context.Context, api.Module) {}, + }, { name: "all supported params", inputFunc: func(uint32, uint64, float32, float64) {}, @@ -216,6 +235,11 @@ func TestPopGoFuncParams(t *testing.T) { inputFunc: func(context.Context, uint32, uint64, float32, float64) {}, expected: []uint64{4, 5, 6, 7}, }, + { + name: "all supported params - context.Context and api.Module", + inputFunc: func(context.Context, api.Module, uint32, uint64, float32, float64) {}, + expected: []uint64{4, 5, 6, 7}, + }, } for _, tt := range tests { @@ -233,9 +257,7 @@ func TestPopGoFuncParams(t *testing.T) { } func TestCallGoFunc(t *testing.T) { - expectedCtx, cancel := context.WithCancel(context.Background()) // arbitrary non-default context - defer cancel() - callCtx := &CallContext{ctx: expectedCtx} + callCtx := &CallContext{} var tests = []struct { name string @@ -255,7 +277,14 @@ func TestCallGoFunc(t *testing.T) { { name: "context.Context void return", inputFunc: func(ctx context.Context) { - require.Equal(t, expectedCtx, ctx) + require.Equal(t, testCtx, ctx) + }, + }, + { + name: "context.Context and api.Module void return", + inputFunc: func(ctx context.Context, m api.Module) { + require.Equal(t, testCtx, ctx) + require.Equal(t, callCtx, m) }, }, { @@ -318,7 +347,26 @@ func TestCallGoFunc(t *testing.T) { { name: "all supported params and i32 result - context.Context", inputFunc: func(ctx context.Context, w uint32, x uint64, y float32, z float64) uint32 { - require.Equal(t, expectedCtx, ctx) + require.Equal(t, testCtx, ctx) + require.Equal(t, uint32(math.MaxUint32), w) + require.Equal(t, uint64(math.MaxUint64), x) + require.Equal(t, float32(math.MaxFloat32), y) + require.Equal(t, math.MaxFloat64, z) + return 100 + }, + inputParams: []uint64{ + math.MaxUint32, + math.MaxUint64, + api.EncodeF32(math.MaxFloat32), + api.EncodeF64(math.MaxFloat64), + }, + expectedResults: []uint64{100}, + }, + { + name: "all supported params and i32 result - context.Context and api.Module", + inputFunc: func(ctx context.Context, m api.Module, w uint32, x uint64, y float32, z float64) uint32 { + require.Equal(t, testCtx, ctx) + require.Equal(t, callCtx, m) require.Equal(t, uint32(math.MaxUint32), w) require.Equal(t, uint64(math.MaxUint64), x) require.Equal(t, float32(math.MaxFloat32), y) @@ -342,7 +390,7 @@ func TestCallGoFunc(t *testing.T) { fk, _, err := getFunctionType(&goFunc, FeaturesFinished) require.NoError(t, err) - results := CallGoFunc(callCtx, &FunctionInstance{Kind: fk, GoFunc: &goFunc}, tc.inputParams) + results := CallGoFunc(testCtx, callCtx, &FunctionInstance{Kind: fk, GoFunc: &goFunc}, tc.inputParams) require.Equal(t, tc.expectedResults, results) }) } diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index 52c72105..e473ff19 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -1,6 +1,7 @@ package interpreter import ( + "context" "encoding/binary" "fmt" "math" @@ -172,7 +173,7 @@ type interpreterOp struct { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *engine) CompileModule(module *wasm.Module) error { +func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { if _, ok := e.getCodes(module); ok { // cache hit! return nil } @@ -186,7 +187,7 @@ func (e *engine) CompileModule(module *wasm.Module) error { funcs = append(funcs, &code{hostFn: hf}) } } else { - irs, err := wazeroir.CompileFunctions(e.enabledFeatures, module) + irs, err := wazeroir.CompileFunctions(ctx, e.enabledFeatures, module) if err != nil { return err } @@ -513,7 +514,7 @@ func (me *moduleEngine) Name() string { } // Call implements the same method as documented on wasm.ModuleEngine. -func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { +func (me *moduleEngine) Call(ctx context.Context, m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to // code on close aren't locked, neither is this read. compiled := me.functions[f.Index] @@ -554,27 +555,27 @@ func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, para for _, param := range params { ce.pushValue(param) } - ce.callNativeFunc(m, compiled) + ce.callNativeFunc(ctx, m, compiled) results = wasm.PopValues(len(f.Type.Results), ce.popValue) } else { - results = ce.callGoFunc(m, compiled, params) + results = ce.callGoFunc(ctx, m, compiled, params) } return } -func (ce *callEngine) callGoFunc(ctx *wasm.CallContext, f *function, params []uint64) (results []uint64) { +func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { if len(ce.frames) > 0 { // Use the caller's memory, which might be different from the defining module on an imported function. - ctx = ctx.WithMemory(ce.frames[len(ce.frames)-1].f.source.Module.Memory) + callCtx = callCtx.WithMemory(ce.frames[len(ce.frames)-1].f.source.Module.Memory) } frame := &callFrame{f: f} ce.pushFrame(frame) - results = wasm.CallGoFunc(ctx, f.source, params) + results = wasm.CallGoFunc(ctx, callCtx, f.source, params) ce.popFrame() return } -func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { +func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallContext, f *function) { frame := &callFrame{f: f} moduleInst := f.source.Module memoryInst := moduleInst.Memory @@ -621,9 +622,9 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { { f := functions[op.us[0]] if f.hostFn != nil { - ce.callGoFuncWithStack(ctx, f) + ce.callGoFuncWithStack(ctx, callCtx, f) } else { - ce.callNativeFunc(ctx, f) + ce.callNativeFunc(ctx, callCtx, f) } frame.pc++ } @@ -642,9 +643,9 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { // Call in. if targetcode.hostFn != nil { - ce.callGoFuncWithStack(ctx, f) + ce.callGoFuncWithStack(ctx, callCtx, f) } else { - ce.callNativeFunc(ctx, targetcode) + ce.callNativeFunc(ctx, callCtx, targetcode) } frame.pc++ } @@ -1555,9 +1556,9 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { ce.popFrame() } -func (ce *callEngine) callGoFuncWithStack(ctx *wasm.CallContext, f *function) { +func (ce *callEngine) callGoFuncWithStack(ctx context.Context, callCtx *wasm.CallContext, f *function) { params := wasm.PopGoFuncParams(f.source, ce.popValue) - results := ce.callGoFunc(ctx, f, params) + results := ce.callGoFunc(ctx, callCtx, f, params) for _, v := range results { ce.pushValue(v) } diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index 0c09ae6c..51e0f656 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -1,6 +1,7 @@ package interpreter import ( + "context" "fmt" "math" "testing" @@ -12,6 +13,9 @@ import ( "github.com/tetratelabs/wazero/internal/wazeroir" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestInterpreter_CallEngine_PushFrame(t *testing.T) { f1 := &callFrame{} f2 := &callFrame{} @@ -135,7 +139,7 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { {kind: wazeroir.OperationKindBr, us: []uint64{math.MaxUint64}}, }, } - ce.callNativeFunc(&wasm.CallContext{}, f) + ce.callNativeFunc(testCtx, &wasm.CallContext{}, f) require.Equal(t, tc.expected, int32(uint32(ce.popValue()))) }) } @@ -187,7 +191,7 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { {kind: wazeroir.OperationKindBr, us: []uint64{math.MaxUint64}}, }, } - ce.callNativeFunc(&wasm.CallContext{}, f) + ce.callNativeFunc(testCtx, &wasm.CallContext{}, f) require.Equal(t, tc.expected, int64(ce.popValue())) }) } @@ -220,7 +224,7 @@ func TestInterpreter_Compile(t *testing.T) { ID: wasm.ModuleID{}, } - err := e.CompileModule(errModule) + err := e.CompileModule(testCtx, errModule) require.EqualError(t, err, "failed to lower func[2/3] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, all the compiled functions including succeeded ones must be released. @@ -241,7 +245,7 @@ func TestInterpreter_Compile(t *testing.T) { }, ID: wasm.ModuleID{}, } - err := e.CompileModule(okModule) + err := e.CompileModule(testCtx, okModule) require.NoError(t, err) compiled, ok := e.codes[okModule.ID] diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index a922261e..1ca3af89 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -1,6 +1,7 @@ package jit import ( + "context" "fmt" "reflect" "runtime" @@ -42,8 +43,8 @@ type ( // function calls originating from the same moduleEngine.Call execution. callEngine struct { // These contexts are read and written by JITed code. - // Note: we embed these structs so we can reduce the costs to access fields inside of them. - // Also, that eases the calculation of offsets to each field. + // Note: structs are embedded to reduce the costs to access fields inside them. Also, this eases field offset + // calculation. globalContext moduleContext valueStackContext @@ -393,7 +394,7 @@ func (e *engine) DeleteCompiledModule(module *wasm.Module) { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *engine) CompileModule(module *wasm.Module) error { +func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { if _, ok := e.getCodes(module); ok { // cache hit! return nil } @@ -416,7 +417,7 @@ func (e *engine) CompileModule(module *wasm.Module) error { funcs = append(funcs, compiled) } } else { - irs, err := wazeroir.CompileFunctions(e.enabledFeatures, module) + irs, err := wazeroir.CompileFunctions(ctx, e.enabledFeatures, module) if err != nil { return err } @@ -496,12 +497,12 @@ func (me *moduleEngine) Name() string { } // Call implements the same method as documented on wasm.ModuleEngine. -func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { +func (me *moduleEngine) Call(ctx context.Context, callCtx *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to // code on close aren't locked, neither is this read. compiled := me.functions[f.Index] if compiled == nil { // Lazy check the cause as it could be because the module was already closed. - if err = m.FailIfClosed(); err == nil { + if err = callCtx.FailIfClosed(); err == nil { panic(fmt.Errorf("BUG: %s.func[%d] was nil before close", me.name, f.Index)) } return @@ -522,7 +523,7 @@ func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, para defer func() { // If the module closed during the call, and the call didn't err for another reason, set an ExitError. if err == nil { - err = m.FailIfClosed() + err = callCtx.FailIfClosed() } // TODO: ^^ Will not fail if the function was imported from a closed module. @@ -545,10 +546,10 @@ func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, para for _, v := range params { ce.pushValue(v) } - ce.execWasmFunction(m, compiled) + ce.execWasmFunction(ctx, callCtx, compiled) results = wasm.PopValues(len(f.Type.Results), ce.popValue) } else { - results = wasm.CallGoFunc(m, compiled.source, params) + results = wasm.CallGoFunc(ctx, callCtx, compiled.source, params) } return } @@ -648,7 +649,7 @@ const ( builtinFunctionIndexBreakPoint ) -func (ce *callEngine) execWasmFunction(ctx *wasm.CallContext, f *function) { +func (ce *callEngine) execWasmFunction(ctx context.Context, callCtx *wasm.CallContext, f *function) { // Push the initial callframe. ce.callFrameStack[0] = callFrame{returnAddress: f.codeInitialAddress, function: f} ce.globalContext.callFrameStackPointer++ @@ -675,8 +676,9 @@ jitentry: callerFunction := ce.callFrameAt(1).function params := wasm.PopGoFuncParams(calleeHostFunction.source, ce.popValue) results := wasm.CallGoFunc( + ctx, // Use the caller's memory, which might be different from the defining module on an imported function. - ctx.WithMemory(callerFunction.source.Module.Memory), + callCtx.WithMemory(callerFunction.source.Module.Memory), calleeHostFunction.source, params, ) diff --git a/internal/wasm/jit/engine_test.go b/internal/wasm/jit/engine_test.go index 4f277180..b6353548 100644 --- a/internal/wasm/jit/engine_test.go +++ b/internal/wasm/jit/engine_test.go @@ -13,6 +13,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasm" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + // Ensures that the offset consts do not drift when we manipulate the target structs. func TestJIT_VerifyOffsetValue(t *testing.T) { var me moduleEngine @@ -176,11 +179,11 @@ func TestJIT_CompileModule(t *testing.T) { ID: wasm.ModuleID{}, } - err := e.CompileModule(okModule) + err := e.CompileModule(testCtx, okModule) require.NoError(t, err) // Compiling same module shouldn't be compiled again, but instead should be cached. - err = e.CompileModule(okModule) + err = e.CompileModule(testCtx, okModule) require.NoError(t, err) compiled, ok := e.codes[okModule.ID] @@ -206,7 +209,7 @@ func TestJIT_CompileModule(t *testing.T) { } e := et.NewEngine(wasm.Features20191205).(*engine) - err := e.CompileModule(errModule) + err := e.CompileModule(testCtx, errModule) require.EqualError(t, err, "failed to lower func[2/3] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, the compiled functions must not be cached. @@ -215,7 +218,7 @@ func TestJIT_CompileModule(t *testing.T) { }) } -// TestReleasecode_Panic tests that an unexpected panic has some identifying information in it. +// TestJIT_Releasecode_Panic tests that an unexpected panic has some identifying information in it. func TestJIT_Releasecode_Panic(t *testing.T) { captured := require.CapturePanic(func() { releaseCode(&code{ @@ -256,10 +259,10 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { }}, map[string]*wasm.Memory{}, map[string]*wasm.Global{}, enabledFeatures) require.NoError(t, err) - err = store.Engine.CompileModule(hm) + err = store.Engine.CompileModule(testCtx, hm) require.NoError(t, err) - _, err = store.Instantiate(context.Background(), hm, hostModuleName, nil) + _, err = store.Instantiate(testCtx, hm, hostModuleName, nil) require.NoError(t, err) const valueStackCorruption = "value_stack_corruption" @@ -311,16 +314,16 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { ID: wasm.ModuleID{1}, } - err = store.Engine.CompileModule(m) + err = store.Engine.CompileModule(testCtx, m) require.NoError(t, err) - mi, err := store.Instantiate(context.Background(), m, t.Name(), nil) + mi, err := store.Instantiate(testCtx, m, t.Name(), nil) require.NoError(t, err) for _, fnName := range []string{valueStackCorruption, callStackCorruption} { fnName := fnName t.Run(fnName, func(t *testing.T) { - ret, err := mi.ExportedFunction(fnName).Call(nil) + ret, err := mi.ExportedFunction(fnName).Call(testCtx) require.NoError(t, err) require.Equal(t, uint32(expectedReturnValue), uint32(ret[0])) diff --git a/internal/wasm/jit/jit_impl_arm64.go b/internal/wasm/jit/jit_impl_arm64.go index 5198a9a8..1c6994e0 100644 --- a/internal/wasm/jit/jit_impl_arm64.go +++ b/internal/wasm/jit/jit_impl_arm64.go @@ -3056,8 +3056,8 @@ func (c *arm64Compiler) compileReservedMemoryRegisterInitialization() { } } -// compileModuleContextInitialization adds instructions to initialize ce.CallContext's fields based on -// ce.CallContext.ModuleInstanceAddress. +// compileModuleContextInitialization adds instructions to initialize ce.moduleContext's fields based on +// ce.moduleContext.ModuleInstanceAddress. // This is called in two cases: in function preamble, and on the return from (non-Go) function calls. func (c *arm64Compiler) compileModuleContextInitialization() error { c.markRegisterUsed(arm64CallingConventionModuleInstanceAddressRegister) diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 7de5a42e..39ac5949 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -245,6 +245,10 @@ func NewStore(enabledFeatures Features, engine Engine) *Store { // // Note: Module.Validate must be called prior to instantiation. func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sys *SysContext) (*CallContext, error) { + if ctx == nil { + ctx = context.Background() + } + if err := s.requireModuleName(name); err != nil { return nil, err } @@ -298,13 +302,13 @@ func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sy m.applyData(module.DataSection) // Build the default context for calls to this module. - m.CallCtx = NewCallContext(ctx, s, m, sys) + m.CallCtx = NewCallContext(s, m, sys) // Execute the start function. if module.StartSection != nil { funcIdx := *module.StartSection f := m.Functions[funcIdx] - if _, err = f.Module.Engine.Call(m.CallCtx, f); err != nil { + if _, err = f.Module.Engine.Call(ctx, m.CallCtx, f); err != nil { s.deleteModule(name) return nil, fmt.Errorf("start %s failed: %w", module.funcDesc(funcSection, funcIdx), err) } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 3db49f9c..31714a07 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -99,15 +99,12 @@ func TestStore_Instantiate(t *testing.T) { ) require.NoError(t, err) - type key string - ctx := context.WithValue(context.Background(), key("a"), "b") // arbitrary non-default context sys := &SysContext{} - mod, err := s.Instantiate(ctx, m, "", sys) + mod, err := s.Instantiate(testCtx, m, "", sys) require.NoError(t, err) defer mod.Close() t.Run("CallContext defaults", func(t *testing.T) { - require.Equal(t, ctx, mod.ctx) require.Equal(t, s.modules[""], mod.module) require.Equal(t, s.modules[""].Memory, mod.memory) require.Equal(t, s, mod.store) @@ -369,58 +366,6 @@ func TestCallContext_ExportedFunction(t *testing.T) { }) } -func TestFunctionInstance_Call(t *testing.T) { - store := NewStore(Features20191205, &mockEngine{shouldCompileFail: false, callFailIndex: -1}) - - // Add the host module - functionName := "fn" - - // This is a fake engine, so we don't capture inside the function body. - m, err := NewHostModule( - "host", - map[string]interface{}{functionName: func(api.Module) {}}, - map[string]*Memory{}, - map[string]*Global{}, - Features20191205, - ) - require.NoError(t, err) - - // Add the host module - imported, err := store.Instantiate(context.Background(), m, "host", nil) - require.NoError(t, err) - defer imported.Close() - - // Make a module to import the function - importing, err := store.Instantiate(context.Background(), &Module{ - TypeSection: []*FunctionType{{}}, - ImportSection: []*Import{{ - Type: ExternTypeFunc, - Module: imported.Name(), - Name: functionName, - DescFunc: 0, - }}, - MemorySection: &Memory{Min: 1}, - ExportSection: []*Export{{Type: ExternTypeFunc, Name: functionName, Index: 0}}, - }, "test", nil) - require.NoError(t, err) - defer imported.Close() - - fn := importing.ExportedFunction(functionName) - // me.ctx will hold the last seen context. - me := imported.module.Engine.(*mockModuleEngine) - t.Run("nil defaults to current module", func(t *testing.T) { - _, err := fn.Call(nil) - require.NoError(t, err) - require.Equal(t, importing, me.ctx) - }) - t.Run("override current module context", func(t *testing.T) { - ctx := importing.WithContext(context.TODO()) - _, err := fn.Call(ctx) - require.NoError(t, err) - require.Equal(t, ctx, me.ctx) - }) -} - type mockEngine struct { shouldCompileFail bool callFailIndex int @@ -428,7 +373,6 @@ type mockEngine struct { type mockModuleEngine struct { name string - ctx *CallContext callFailIndex int } @@ -448,7 +392,7 @@ func (e *mockEngine) NewModuleEngine(_ string, _ *Module, _, _ []*FunctionInstan func (e *mockEngine) DeleteCompiledModule(*Module) {} // CompileModule implements the same method as documented on wasm.Engine. -func (e *mockEngine) CompileModule(module *Module) error { return nil } +func (e *mockEngine) CompileModule(_ context.Context, _ *Module) error { return nil } // Name implements the same method as documented on wasm.ModuleEngine. func (e *mockModuleEngine) Name() string { @@ -456,12 +400,11 @@ func (e *mockModuleEngine) Name() string { } // Call implements the same method as documented on wasm.ModuleEngine. -func (e *mockModuleEngine) Call(ctx *CallContext, f *FunctionInstance, _ ...uint64) (results []uint64, err error) { +func (e *mockModuleEngine) Call(ctx context.Context, callCtx *CallContext, f *FunctionInstance, _ ...uint64) (results []uint64, err error) { if e.callFailIndex >= 0 && f.Index == Index(e.callFailIndex) { err = errors.New("call failed") return } - e.ctx = ctx return } diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 2eb28008..30b3bf07 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -2,6 +2,7 @@ package wazeroir import ( "bytes" + "context" "encoding/binary" "fmt" "math" @@ -180,7 +181,7 @@ type CompilationResult struct { HasTable bool } -func CompileFunctions(enabledFeatures wasm.Features, module *wasm.Module) ([]*CompilationResult, error) { +func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module *wasm.Module) ([]*CompilationResult, error) { functions, globals, mem, table, err := module.AllDeclarations() if err != nil { return nil, err diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index 029a896d..9e34130a 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -1,6 +1,7 @@ package wazeroir import ( + "context" "testing" "github.com/tetratelabs/wazero/api" @@ -9,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasm/text" ) +// ctx is an arbitrary, non-default context. +var ctx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + var ( f64, i32 = wasm.ValueTypeF64, wasm.ValueTypeI32 i32_i32 = &wasm.FunctionType{Params: []wasm.ValueType{i32}, Results: []wasm.ValueType{i32}} @@ -65,7 +69,7 @@ func TestCompile(t *testing.T) { enabledFeatures = wasm.FeaturesFinished } - res, err := CompileFunctions(enabledFeatures, tc.module) + res, err := CompileFunctions(ctx, enabledFeatures, tc.module) require.NoError(t, err) require.Equal(t, tc.expected, res[0]) }) @@ -395,7 +399,7 @@ func TestCompile_MultiValue(t *testing.T) { if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - res, err := CompileFunctions(enabledFeatures, tc.module) + res, err := CompileFunctions(ctx, enabledFeatures, tc.module) require.NoError(t, err) require.Equal(t, tc.expected, res[0]) }) @@ -406,7 +410,7 @@ func requireCompilationResult(t *testing.T, enabledFeatures wasm.Features, expec if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - res, err := CompileFunctions(enabledFeatures, module) + res, err := CompileFunctions(ctx, enabledFeatures, module) require.NoError(t, err) require.Equal(t, expected, res[0]) } diff --git a/wasi/example_test.go b/wasi/example_test.go index 93f5b15a..461b2f51 100644 --- a/wasi/example_test.go +++ b/wasi/example_test.go @@ -1,6 +1,7 @@ package wasi import ( + "context" "fmt" "log" "os" @@ -13,10 +14,14 @@ import ( // // See https://github.com/tetratelabs/wazero/tree/main/examples/wasi for another example. func Example() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Instantiate WASI, which implements system I/O such as console output. - wm, err := InstantiateSnapshotPreview1(r) + wm, err := InstantiateSnapshotPreview1(ctx, r) if err != nil { log.Fatal(err) } @@ -26,7 +31,7 @@ func Example() { config := wazero.NewModuleConfig().WithStdout(os.Stdout) // InstantiateModuleFromCodeWithConfig runs the "_start" function which is like a "main" function. - _, err = r.InstantiateModuleFromCodeWithConfig([]byte(` + _, err = r.InstantiateModuleFromCodeWithConfig(ctx, []byte(` (module (import "wasi_snapshot_preview1" "proc_exit" (func $wasi.proc_exit (param $rval i32))) diff --git a/wasi/usage_test.go b/wasi/usage_test.go index a70930c8..b3048649 100644 --- a/wasi/usage_test.go +++ b/wasi/usage_test.go @@ -20,17 +20,17 @@ func TestInstantiateModuleWithConfig(t *testing.T) { // Configure WASI to write stdout to a buffer, so that we can verify it later. sys := wazero.NewModuleConfig().WithStdout(stdout) - wm, err := InstantiateSnapshotPreview1(r) + wm, err := InstantiateSnapshotPreview1(testCtx, r) require.NoError(t, err) defer wm.Close() - compiled, err := r.CompileModule(wasiArg) + compiled, err := r.CompileModule(testCtx, wasiArg) require.NoError(t, err) defer compiled.Close() // Re-use the same module many times. for _, tc := range []string{"a", "b", "c"} { - mod, err := r.InstantiateModuleWithConfig(compiled, sys.WithArgs(tc).WithName(tc)) + mod, err := r.InstantiateModuleWithConfig(testCtx, compiled, sys.WithArgs(tc).WithName(tc)) require.NoError(t, err) // Ensure the scoped configuration applied. As the args are null-terminated, we append zero (NUL). diff --git a/wasi/wasi.go b/wasi/wasi.go index 18295028..6d9db6ec 100644 --- a/wasi/wasi.go +++ b/wasi/wasi.go @@ -5,6 +5,7 @@ package wasi import ( + "context" crand "crypto/rand" "errors" "fmt" @@ -25,9 +26,9 @@ const ModuleSnapshotPreview1 = "wasi_snapshot_preview1" // InstantiateSnapshotPreview1 instantiates ModuleSnapshotPreview1, so that other modules can import them. // // Note: All WASI functions return a single Errno result, ErrnoSuccess on success. -func InstantiateSnapshotPreview1(r wazero.Runtime) (api.Module, error) { +func InstantiateSnapshotPreview1(ctx context.Context, r wazero.Runtime) (api.Module, error) { _, fns := snapshotPreview1Functions() - return r.NewModuleBuilder(ModuleSnapshotPreview1).ExportFunctions(fns).Instantiate() + return r.NewModuleBuilder(ModuleSnapshotPreview1).ExportFunctions(fns).Instantiate(ctx) } const ( diff --git a/wasi/wasi_bench_test.go b/wasi/wasi_bench_test.go index 92e7bbaf..7820a187 100644 --- a/wasi/wasi_bench_test.go +++ b/wasi/wasi_bench_test.go @@ -1,7 +1,6 @@ package wasi import ( - "context" "testing" "github.com/tetratelabs/wazero/internal/testing/require" @@ -59,7 +58,7 @@ func Benchmark_EnvironGet(b *testing.B) { } func newCtx(buf []byte, sys *wasm.SysContext) *wasm.CallContext { - return wasm.NewCallContext(context.Background(), nil, &wasm.ModuleInstance{ + return wasm.NewCallContext(nil, &wasm.ModuleInstance{ Memory: &wasm.MemoryInstance{Min: 1, Buffer: buf}, }, sys) } diff --git a/wasi/wasi_test.go b/wasi/wasi_test.go index 88920159..5c6993ef 100644 --- a/wasi/wasi_test.go +++ b/wasi/wasi_test.go @@ -2,6 +2,7 @@ package wasi import ( "bytes" + "context" _ "embed" "errors" "fmt" @@ -21,6 +22,9 @@ import ( "github.com/tetratelabs/wazero/sys" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestSnapshotPreview1_ArgsGet(t *testing.T) { sysCtx, err := newSysContext([]string{"a", "bc"}, nil, nil) require.NoError(t, err) @@ -54,7 +58,7 @@ func TestSnapshotPreview1_ArgsGet(t *testing.T) { t.Run(functionArgsGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(argv), uint64(argvBuf)) + results, err := fn.Call(testCtx, uint64(argv), uint64(argvBuf)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -147,7 +151,7 @@ func TestSnapshotPreview1_ArgsSizesGet(t *testing.T) { t.Run(functionArgsSizesGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(resultArgc), uint64(resultArgvBufSize)) + results, err := fn.Call(testCtx, uint64(resultArgc), uint64(resultArgvBufSize)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -239,7 +243,7 @@ func TestSnapshotPreview1_EnvironGet(t *testing.T) { t.Run(functionEnvironGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(resultEnviron), uint64(resultEnvironBuf)) + results, err := fn.Call(testCtx, uint64(resultEnviron), uint64(resultEnvironBuf)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -331,7 +335,7 @@ func TestSnapshotPreview1_EnvironSizesGet(t *testing.T) { t.Run(functionEnvironSizesGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(resultEnvironc), uint64(resultEnvironBufSize)) + results, err := fn.Call(testCtx, uint64(resultEnvironc), uint64(resultEnvironBufSize)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -399,7 +403,7 @@ func TestSnapshotPreview1_ClockResGet(t *testing.T) { }) t.Run(functionClockResGet, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -435,7 +439,7 @@ func TestSnapshotPreview1_ClockTimeGet(t *testing.T) { t.Run(functionClockTimeGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(resultTimestamp)) + results, err := fn.Call(testCtx, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(resultTimestamp)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -476,7 +480,7 @@ func TestSnapshotPreview1_ClockTimeGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - results, err := fn.Call(mod, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(tc.resultTimestamp)) + results, err := fn.Call(testCtx, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(tc.resultTimestamp)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) @@ -495,7 +499,7 @@ func TestSnapshotPreview1_FdAdvise(t *testing.T) { }) t.Run(functionFdAdvise, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -513,7 +517,7 @@ func TestSnapshotPreview1_FdAllocate(t *testing.T) { }) t.Run(functionFdAllocate, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -566,7 +570,7 @@ func TestSnapshotPreview1_FdClose(t *testing.T) { mod, fn, _ := setupFD() defer mod.Close() - results, err := fn.Call(mod, uint64(fdToClose)) + results, err := fn.Call(testCtx, uint64(fdToClose)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -593,7 +597,7 @@ func TestSnapshotPreview1_FdDatasync(t *testing.T) { }) t.Run(functionFdDatasync, func(t *testing.T) { - results, err := fn.Call(mod, 0) + results, err := fn.Call(testCtx, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -617,7 +621,7 @@ func TestSnapshotPreview1_FdFdstatSetFlags(t *testing.T) { }) t.Run(functionFdFdstatSetFlags, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -635,7 +639,7 @@ func TestSnapshotPreview1_FdFdstatSetRights(t *testing.T) { }) t.Run(functionFdFdstatSetRights, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -653,7 +657,7 @@ func TestSnapshotPreview1_FdFilestatGet(t *testing.T) { }) t.Run(functionFdFilestatGet, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -671,7 +675,7 @@ func TestSnapshotPreview1_FdFilestatSetSize(t *testing.T) { }) t.Run(functionFdFilestatSetSize, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -689,7 +693,7 @@ func TestSnapshotPreview1_FdFilestatSetTimes(t *testing.T) { }) t.Run(functionFdFilestatSetTimes, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -707,7 +711,7 @@ func TestSnapshotPreview1_FdPread(t *testing.T) { }) t.Run(functionFdPread, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -748,7 +752,7 @@ func TestSnapshotPreview1_FdPrestatGet(t *testing.T) { t.Run(functionFdPrestatDirName, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(fd), uint64(resultPrestat)) + results, err := fn.Call(testCtx, uint64(fd), uint64(resultPrestat)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -833,7 +837,7 @@ func TestSnapshotPreview1_FdPrestatDirName(t *testing.T) { t.Run(functionFdPrestatDirName, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(fd), uint64(path), uint64(pathLen)) + results, err := fn.Call(testCtx, uint64(fd), uint64(path), uint64(pathLen)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -915,7 +919,7 @@ func TestSnapshotPreview1_FdPwrite(t *testing.T) { }) t.Run(functionFdPwrite, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -956,7 +960,7 @@ func TestSnapshotPreview1_FdRead(t *testing.T) { }}, {functionFdRead, func(_ *snapshotPreview1, mod api.Module, fn api.Function) fdReadFn { return func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - results, err := fn.Call(mod, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) + results, err := fn.Call(testCtx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) require.NoError(t, err) return Errno(results[0]) } @@ -1094,7 +1098,7 @@ func TestSnapshotPreview1_FdReaddir(t *testing.T) { }) t.Run(functionFdReaddir, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1112,7 +1116,7 @@ func TestSnapshotPreview1_FdRenumber(t *testing.T) { }) t.Run(functionFdRenumber, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1143,7 +1147,7 @@ func TestSnapshotPreview1_FdSeek(t *testing.T) { }}, {functionFdSeek, func() fdSeekFn { return func(ctx api.Module, fd uint32, offset uint64, whence, resultNewoffset uint32) Errno { - results, err := fn.Call(mod, uint64(fd), offset, uint64(whence), uint64(resultNewoffset)) + results, err := fn.Call(testCtx, uint64(fd), offset, uint64(whence), uint64(resultNewoffset)) require.NoError(t, err) return Errno(results[0]) } @@ -1287,7 +1291,7 @@ func TestSnapshotPreview1_FdSync(t *testing.T) { }) t.Run(functionFdSync, func(t *testing.T) { - results, err := fn.Call(mod, 0) + results, err := fn.Call(testCtx, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1305,7 +1309,7 @@ func TestSnapshotPreview1_FdTell(t *testing.T) { }) t.Run(functionFdTell, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1346,7 +1350,7 @@ func TestSnapshotPreview1_FdWrite(t *testing.T) { }}, {functionFdWrite, func(_ *snapshotPreview1, mod api.Module, fn api.Function) fdWriteFn { return func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - results, err := fn.Call(mod, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) + results, err := fn.Call(testCtx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) require.NoError(t, err) return Errno(results[0]) } @@ -1478,7 +1482,7 @@ func TestSnapshotPreview1_PathCreateDirectory(t *testing.T) { }) t.Run(functionPathCreateDirectory, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1496,7 +1500,7 @@ func TestSnapshotPreview1_PathFilestatGet(t *testing.T) { }) t.Run(functionPathFilestatGet, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1514,7 +1518,7 @@ func TestSnapshotPreview1_PathFilestatSetTimes(t *testing.T) { }) t.Run(functionPathFilestatSetTimes, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1532,7 +1536,7 @@ func TestSnapshotPreview1_PathLink(t *testing.T) { }) t.Run(functionPathLink, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1597,7 +1601,7 @@ func TestSnapshotPreview1_PathOpen(t *testing.T) { t.Run(functionPathOpen, func(t *testing.T) { _, mod, fn := setup() - results, err := fn.Call(mod, uint64(workdirFD), uint64(dirflags), uint64(path), uint64(pathLen), uint64(oflags), fsRightsBase, fsRightsInheriting, uint64(fdFlags), uint64(resultOpenedFd)) + results, err := fn.Call(testCtx, uint64(workdirFD), uint64(dirflags), uint64(path), uint64(pathLen), uint64(oflags), fsRightsBase, fsRightsInheriting, uint64(fdFlags), uint64(resultOpenedFd)) require.NoError(t, err) errno := Errno(results[0]) verify(errno, mod) @@ -1682,7 +1686,7 @@ func TestSnapshotPreview1_PathReadlink(t *testing.T) { }) t.Run(functionPathReadlink, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1700,7 +1704,7 @@ func TestSnapshotPreview1_PathRemoveDirectory(t *testing.T) { }) t.Run(functionPathRemoveDirectory, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1718,7 +1722,7 @@ func TestSnapshotPreview1_PathRename(t *testing.T) { }) t.Run(functionPathRename, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1736,7 +1740,7 @@ func TestSnapshotPreview1_PathSymlink(t *testing.T) { }) t.Run(functionPathSymlink, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1754,7 +1758,7 @@ func TestSnapshotPreview1_PathUnlinkFile(t *testing.T) { }) t.Run(functionPathUnlinkFile, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1772,7 +1776,7 @@ func TestSnapshotPreview1_PollOneoff(t *testing.T) { }) t.Run(functionPollOneoff, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1804,7 +1808,7 @@ func TestSnapshotPreview1_ProcExit(t *testing.T) { defer mod.Close() // When ProcExit is called, store.Callfunction returns immediately, returning the exit code as the error. - _, err := fn.Call(nil, uint64(tc.exitCode)) + _, err := fn.Call(testCtx, uint64(tc.exitCode)) require.Equal(t, tc.exitCode, err.(*sys.ExitError).ExitCode()) }) } @@ -1821,7 +1825,7 @@ func TestSnapshotPreview1_ProcRaise(t *testing.T) { }) t.Run(functionProcRaise, func(t *testing.T) { - results, err := fn.Call(mod, 0) + results, err := fn.Call(testCtx, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1839,7 +1843,7 @@ func TestSnapshotPreview1_SchedYield(t *testing.T) { }) t.Run(functionSchedYield, func(t *testing.T) { - results, err := fn.Call(mod) + results, err := fn.Call(testCtx) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1883,7 +1887,7 @@ func TestSnapshotPreview1_RandomGet(t *testing.T) { t.Run(functionRandomGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(offset), uint64(length)) + results, err := fn.Call(testCtx, uint64(offset), uint64(length)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -1953,7 +1957,7 @@ func TestSnapshotPreview1_SockRecv(t *testing.T) { }) t.Run(functionSockRecv, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1971,7 +1975,7 @@ func TestSnapshotPreview1_SockSend(t *testing.T) { }) t.Run(functionSockSend, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1989,7 +1993,7 @@ func TestSnapshotPreview1_SockShutdown(t *testing.T) { }) t.Run(functionSockShutdown, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -2011,10 +2015,10 @@ func instantiateModule(t *testing.T, wasifunction, wasiimport string, sysCtx *wa // The package `wazero` has a simpler interface for adding host modules, but we can't use that as it would create an // import cycle. Instead, we export wasm.NewHostModule and use it here. a, fns := snapshotPreview1Functions() - _, err := r.NewModuleBuilder("wasi_snapshot_preview1").ExportFunctions(fns).Instantiate() + _, err := r.NewModuleBuilder("wasi_snapshot_preview1").ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) - compiled, err := r.CompileModule([]byte(fmt.Sprintf(`(module + compiled, err := r.CompileModule(testCtx, []byte(fmt.Sprintf(`(module %[2]s (memory 1 1) ;; just an arbitrary size big enough for tests (export "memory" (memory 0)) @@ -2023,7 +2027,7 @@ func instantiateModule(t *testing.T, wasifunction, wasiimport string, sysCtx *wa require.NoError(t, err) defer compiled.Close() - mod, err := r.InstantiateModuleWithConfig(compiled, wazero.NewModuleConfig().WithName(t.Name())) + mod, err := r.InstantiateModuleWithConfig(testCtx, compiled, wazero.NewModuleConfig().WithName(t.Name())) require.NoError(t, err) if sysCtx != nil { diff --git a/wasm.go b/wasm.go index 58818aaa..62c90940 100644 --- a/wasm.go +++ b/wasm.go @@ -16,9 +16,10 @@ import ( // Runtime allows embedding of WebAssembly 1.0 (20191205) modules. // // Ex. +// ctx := context.Background() // r := wazero.NewRuntime() -// code, _ := r.CompileModule(source) -// module, _ := r.InstantiateModule(code) +// compiled, _ := r.CompileModule(ctx, source) +// module, _ := r.InstantiateModule(ctx, compiled) // defer module.Close() // // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/ @@ -27,10 +28,11 @@ type Runtime interface { // // Ex. Below defines and instantiates a module named "env" with one function: // + // ctx := context.Background() // hello := func() { // fmt.Fprintln(stdout, "hello!") // } - // _, err := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate() + // _, err := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate(ctx) NewModuleBuilder(moduleName string) ModuleBuilder // Module returns exports from an instantiated module or nil if there aren't any. @@ -43,38 +45,46 @@ type Runtime interface { // * Improve performance when the same module is instantiated multiple times under different names // * Reduce the amount of errors that can occur during InstantiateModule. // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: The resulting module name defaults to what was binary from the custom name section. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#name-section%E2%91%A0 - CompileModule(source []byte) (*CompiledCode, error) + CompileModule(ctx context.Context, source []byte) (*CompiledCode, error) // InstantiateModuleFromCode instantiates a module from the WebAssembly 1.0 (20191205) text or binary source or // errs if invalid. // // Ex. - // module, _ := wazero.NewRuntime().InstantiateModuleFromCode(source) + // ctx := context.Background() + // module, _ := wazero.NewRuntime().InstantiateModuleFromCode(ctx, source) // defer module.Close() // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: This is a convenience utility that chains CompileModule with InstantiateModule. To instantiate the same // source multiple times, use CompileModule as InstantiateModule avoids redundant decoding and/or compilation. - InstantiateModuleFromCode(source []byte) (api.Module, error) + InstantiateModuleFromCode(ctx context.Context, source []byte) (api.Module, error) // InstantiateModuleFromCodeWithConfig is a convenience function that chains CompileModule to // InstantiateModuleWithConfig. // // Ex. To only change the module name: - // wasm, _ := wazero.NewRuntime().InstantiateModuleFromCodeWithConfig(source, wazero.NewModuleConfig(). + // ctx := context.Background() + // r := wazero.NewRuntime() + // wasm, _ := r.InstantiateModuleFromCodeWithConfig(ctx, source, wazero.NewModuleConfig(). // WithName("wasm") // ) // defer wasm.Close() - InstantiateModuleFromCodeWithConfig(source []byte, config *ModuleConfig) (api.Module, error) + // + // Note: When `ctx` is nil, it defaults to context.Background. + InstantiateModuleFromCodeWithConfig(ctx context.Context, source []byte, config *ModuleConfig) (api.Module, error) // InstantiateModule instantiates the module namespace or errs if the configuration was invalid. // // Ex. + // ctx := context.Background() // r := wazero.NewRuntime() - // code, _ := r.CompileModule(source) - // defer code.Close() - // module, _ := r.InstantiateModule(code) + // compiled, _ := r.CompileModule(ctx, source) + // defer compiled.Close() + // module, _ := r.InstantiateModule(ctx, compiled) // defer module.Close() // // While CompiledCode is pre-validated, there are a few situations which can cause an error: @@ -82,26 +92,28 @@ type Runtime interface { // * The module has a table element initializer that resolves to an index outside the Table minimum size. // * The module has a start function, and it failed to execute. // - // Note: The last value of RuntimeConfig.WithContext is used for any start function. - InstantiateModule(code *CompiledCode) (api.Module, error) + // Note: When `ctx` is nil, it defaults to context.Background. + InstantiateModule(ctx context.Context, compiled *CompiledCode) (api.Module, error) // InstantiateModuleWithConfig is like InstantiateModule, except you can override configuration such as the module // name or ENV variables. // // For example, you can use this to define different args depending on the importing module. // + // ctx := context.Background() // r := wazero.NewRuntime() - // wasi, _ := r.InstantiateModule(wazero.WASISnapshotPreview1()) - // code, _ := r.CompileModule(source) + // wasi, _ := wasi.InstantiateSnapshotPreview1(r) + // compiled, _ := r.CompileModule(ctx, source) // // // Initialize base configuration: // config := wazero.NewModuleConfig().WithStdout(buf) // // // Assign different configuration on each instantiation - // module, _ := r.InstantiateModuleWithConfig(code, config.WithName("rotate").WithArgs("rotate", "angle=90", "dir=cw")) + // module, _ := r.InstantiateModuleWithConfig(ctx, compiled, config.WithName("rotate").WithArgs("rotate", "angle=90", "dir=cw")) // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: Config is copied during instantiation: Later changes to config do not affect the instantiated result. - InstantiateModuleWithConfig(code *CompiledCode, config *ModuleConfig) (mod api.Module, err error) + InstantiateModuleWithConfig(ctx context.Context, compiled *CompiledCode, config *ModuleConfig) (mod api.Module, err error) } func NewRuntime() Runtime { @@ -111,7 +123,6 @@ func NewRuntime() Runtime { // NewRuntimeWithConfig returns a runtime with the given configuration. func NewRuntimeWithConfig(config *RuntimeConfig) Runtime { return &runtime{ - ctx: config.ctx, store: wasm.NewStore(config.enabledFeatures, config.newEngine(config.enabledFeatures)), enabledFeatures: config.enabledFeatures, memoryMaxPages: config.memoryMaxPages, @@ -121,7 +132,6 @@ func NewRuntimeWithConfig(config *RuntimeConfig) Runtime { // runtime allows decoupling of public interfaces from internal representation. type runtime struct { enabledFeatures wasm.Features - ctx context.Context store *wasm.Store memoryMaxPages uint32 } @@ -132,7 +142,7 @@ func (r *runtime) Module(moduleName string) api.Module { } // CompileModule implements Runtime.CompileModule -func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { +func (r *runtime) CompileModule(ctx context.Context, source []byte) (*CompiledCode, error) { if source == nil { return nil, errors.New("source == nil") } @@ -166,7 +176,7 @@ func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { internal.AssignModuleID(source) - if err = r.store.Engine.CompileModule(internal); err != nil { + if err = r.store.Engine.CompileModule(ctx, internal); err != nil { return nil, err } @@ -174,47 +184,47 @@ func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { } // InstantiateModuleFromCode implements Runtime.InstantiateModuleFromCode -func (r *runtime) InstantiateModuleFromCode(source []byte) (api.Module, error) { - if code, err := r.CompileModule(source); err != nil { +func (r *runtime) InstantiateModuleFromCode(ctx context.Context, source []byte) (api.Module, error) { + if compiled, err := r.CompileModule(ctx, source); err != nil { return nil, err } else { // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. - defer code.Close() - return r.InstantiateModule(code) + defer compiled.Close() + return r.InstantiateModule(ctx, compiled) } } // InstantiateModuleFromCodeWithConfig implements Runtime.InstantiateModuleFromCodeWithConfig -func (r *runtime) InstantiateModuleFromCodeWithConfig(source []byte, config *ModuleConfig) (api.Module, error) { - if code, err := r.CompileModule(source); err != nil { +func (r *runtime) InstantiateModuleFromCodeWithConfig(ctx context.Context, source []byte, config *ModuleConfig) (api.Module, error) { + if compiled, err := r.CompileModule(ctx, source); err != nil { return nil, err } else { // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. - defer code.Close() - return r.InstantiateModuleWithConfig(code, config) + defer compiled.Close() + return r.InstantiateModuleWithConfig(ctx, compiled, config) } } // InstantiateModule implements Runtime.InstantiateModule -func (r *runtime) InstantiateModule(code *CompiledCode) (mod api.Module, err error) { - return r.InstantiateModuleWithConfig(code, NewModuleConfig()) +func (r *runtime) InstantiateModule(ctx context.Context, compiled *CompiledCode) (mod api.Module, err error) { + return r.InstantiateModuleWithConfig(ctx, compiled, NewModuleConfig()) } // InstantiateModuleWithConfig implements Runtime.InstantiateModuleWithConfig -func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *ModuleConfig) (mod api.Module, err error) { +func (r *runtime) InstantiateModuleWithConfig(ctx context.Context, compiled *CompiledCode, config *ModuleConfig) (mod api.Module, err error) { var sysCtx *wasm.SysContext if sysCtx, err = config.toSysContext(); err != nil { return } name := config.name - if name == "" && code.module.NameSection != nil && code.module.NameSection.ModuleName != "" { - name = code.module.NameSection.ModuleName + if name == "" && compiled.module.NameSection != nil && compiled.module.NameSection.ModuleName != "" { + name = compiled.module.NameSection.ModuleName } - module := config.replaceImports(code.module) + module := config.replaceImports(compiled.module) - mod, err = r.store.Instantiate(r.ctx, module, name, sysCtx) + mod, err = r.store.Instantiate(ctx, module, name, sysCtx) if err != nil { return } @@ -224,7 +234,7 @@ func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *Module if start == nil { continue } - if _, err = start.Call(mod.WithContext(r.ctx)); err != nil { + if _, err = start.Call(ctx); err != nil { if _, ok := err.(*sys.ExitError); ok { return } diff --git a/wasm_test.go b/wasm_test.go index 125ac204..915ac609 100644 --- a/wasm_test.go +++ b/wasm_test.go @@ -15,6 +15,9 @@ import ( "github.com/tetratelabs/wazero/sys" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestRuntime_DecodeModule(t *testing.T) { tests := []struct { name string @@ -54,7 +57,7 @@ func TestRuntime_DecodeModule(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - code, err := r.CompileModule(tc.source) + code, err := r.CompileModule(testCtx, tc.source) require.NoError(t, err) defer code.Close() if tc.expectedName != "" { @@ -115,7 +118,7 @@ func TestRuntime_DecodeModule_Errors(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - _, err := tc.runtime.CompileModule(tc.source) + _, err := tc.runtime.CompileModule(testCtx, tc.source) require.EqualError(t, err, tc.expectedErr) }) } @@ -151,7 +154,7 @@ func TestModule_Memory(t *testing.T) { r := NewRuntime() t.Run(tc.name, func(t *testing.T) { // Instantiate the module and get the export of the above memory - module, err := tc.builder(r).Instantiate() + module, err := tc.builder(r).Instantiate(testCtx) require.NoError(t, err) defer module.Close() @@ -224,14 +227,14 @@ func TestModule_Global(t *testing.T) { if tc.module != nil { code = &CompiledCode{module: tc.module} } else { - code, _ = tc.builder(r).Build() + code, _ = tc.builder(r).Build(testCtx) } - err := r.store.Engine.CompileModule(code.module) + err := r.store.Engine.CompileModule(testCtx, code.module) require.NoError(t, err) // Instantiate the module and get the export of the above global - module, err := r.InstantiateModule(code) + module, err := r.InstantiateModule(testCtx, code) require.NoError(t, err) defer module.Close() @@ -253,26 +256,20 @@ func TestModule_Global(t *testing.T) { } func TestFunction_Context(t *testing.T) { - type key string - runtimeCtx := context.WithValue(context.Background(), key("wa"), "zero") - config := NewRuntimeConfig().WithContext(runtimeCtx) - - notStoreCtx := context.WithValue(context.Background(), key("wazer"), "o") - tests := []struct { name string ctx context.Context expected context.Context }{ { - name: "nil defaults to runtime context", + name: "nil defaults to context.Background", ctx: nil, - expected: runtimeCtx, + expected: context.Background(), }, { - name: "set overrides runtime context", - ctx: notStoreCtx, - expected: notStoreCtx, + name: "set context", + ctx: testCtx, + expected: testCtx, }, } @@ -280,49 +277,48 @@ func TestFunction_Context(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - r := NewRuntimeWithConfig(config) + r := NewRuntime() // Define a host function so that we can catch the context propagated from a module function call functionName := "fn" expectedResult := uint64(math.MaxUint64) - hostFn := func(ctx api.Module) uint64 { - require.Equal(t, tc.expected, ctx.Context()) + hostFn := func(ctx context.Context) uint64 { + require.Equal(t, tc.expected, ctx) return expectedResult } source, closer := requireImportAndExportFunction(t, r, hostFn, functionName) defer closer() // nolint // Instantiate the module and get the export of the above hostFn - module, err := r.InstantiateModuleFromCodeWithConfig(source, NewModuleConfig().WithName(t.Name())) + module, err := r.InstantiateModuleFromCodeWithConfig(tc.ctx, source, NewModuleConfig().WithName(t.Name())) require.NoError(t, err) defer module.Close() // This fails if the function wasn't invoked, or had an unexpected context. - results, err := module.ExportedFunction(functionName).Call(module.WithContext(tc.ctx)) + results, err := module.ExportedFunction(functionName).Call(tc.ctx) require.NoError(t, err) require.Equal(t, expectedResult, results[0]) }) } } -func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { - type key string - runtimeCtx := context.WithValue(context.Background(), key("wa"), "zero") - config := NewRuntimeConfig().WithContext(runtimeCtx) - r := NewRuntimeWithConfig(config) +func TestRuntime_InstantiateModule_UsesContext(t *testing.T) { + r := NewRuntime() // Define a function that will be set as the start function var calledStart bool - start := func(ctx api.Module) { + start := func(ctx context.Context) { calledStart = true - require.Equal(t, runtimeCtx, ctx.Context()) + require.Equal(t, testCtx, ctx) } - env, err := r.NewModuleBuilder("env").ExportFunction("start", start).Instantiate() + env, err := r.NewModuleBuilder("env"). + ExportFunction("start", start). + Instantiate(testCtx) require.NoError(t, err) defer env.Close() - code, err := r.CompileModule([]byte(`(module $runtime_test.go + code, err := r.CompileModule(testCtx, []byte(`(module $runtime_test.go (import "env" "start" (func $start)) (start $start) )`)) @@ -330,7 +326,7 @@ func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { defer code.Close() // Instantiate the module, which calls the start function. This will fail if the context wasn't as intended. - m, err := r.InstantiateModule(code) + m, err := r.InstantiateModule(testCtx, code) require.NoError(t, err) defer m.Close() @@ -341,7 +337,7 @@ func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { func TestInstantiateModuleFromCode_DoesntEnforce_Start(t *testing.T) { r := NewRuntime() - mod, err := r.InstantiateModuleFromCode([]byte(`(module $wasi_test.go + mod, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $wasi_test.go (memory 1) (export "memory" (memory 0)) )`)) @@ -349,24 +345,24 @@ func TestInstantiateModuleFromCode_DoesntEnforce_Start(t *testing.T) { require.NoError(t, mod.Close()) } -func TestInstantiateModuleFromCode_UsesRuntimeContext(t *testing.T) { - type key string - config := NewRuntimeConfig().WithContext(context.WithValue(context.Background(), key("wa"), "zero")) - r := NewRuntimeWithConfig(config) +func TestRuntime_InstantiateModuleFromCode_UsesContext(t *testing.T) { + r := NewRuntime() // Define a function that will be re-exported as the WASI function: _start var calledStart bool - start := func(ctx api.Module) { + start := func(ctx context.Context) { calledStart = true - require.Equal(t, config.ctx, ctx.Context()) + require.Equal(t, testCtx, ctx) } - host, err := r.NewModuleBuilder("").ExportFunction("start", start).Instantiate() + host, err := r.NewModuleBuilder(""). + ExportFunction("start", start). + Instantiate(testCtx) require.NoError(t, err) defer host.Close() // Start the module as a WASI command. This will fail if the context wasn't as intended. - mod, err := r.InstantiateModuleFromCode([]byte(`(module $start + mod, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $start (import "" "start" (func $start)) (memory 1) (export "_start" (func $start)) @@ -382,7 +378,7 @@ func TestInstantiateModuleFromCode_UsesRuntimeContext(t *testing.T) { // different names. This pattern is used in wapc-go. func TestInstantiateModuleWithConfig_WithName(t *testing.T) { r := NewRuntime() - base, err := r.CompileModule([]byte(`(module $0 (memory 1))`)) + base, err := r.CompileModule(testCtx, []byte(`(module $0 (memory 1))`)) require.NoError(t, err) defer base.Close() @@ -390,14 +386,14 @@ func TestInstantiateModuleWithConfig_WithName(t *testing.T) { // Use the same runtime to instantiate multiple modules internal := r.(*runtime).store - m1, err := r.InstantiateModuleWithConfig(base, NewModuleConfig().WithName("1")) + m1, err := r.InstantiateModuleWithConfig(testCtx, base, NewModuleConfig().WithName("1")) require.NoError(t, err) defer m1.Close() require.Nil(t, internal.Module("0")) require.Equal(t, internal.Module("1"), m1) - m2, err := r.InstantiateModuleWithConfig(base, NewModuleConfig().WithName("2")) + m2, err := r.InstantiateModuleWithConfig(testCtx, base, NewModuleConfig().WithName("2")) require.NoError(t, err) defer m2.Close() @@ -412,15 +408,15 @@ func TestInstantiateModuleWithConfig_ExitError(t *testing.T) { require.NoError(t, m.CloseWithExitCode(2)) } - _, err := r.NewModuleBuilder("env").ExportFunction("_start", start).Instantiate() + _, err := r.NewModuleBuilder("env").ExportFunction("_start", start).Instantiate(testCtx) // Ensure the exit error propagated and didn't wrap. require.Equal(t, err, sys.NewExitError("env", 2)) } // requireImportAndExportFunction re-exports a host function because only host functions can see the propagated context. -func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx api.Module) uint64, functionName string) ([]byte, func() error) { - mod, err := r.NewModuleBuilder("host").ExportFunction(functionName, hostFn).Instantiate() +func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx context.Context) uint64, functionName string) ([]byte, func() error) { + mod, err := r.NewModuleBuilder("host").ExportFunction(functionName, hostFn).Instantiate(testCtx) require.NoError(t, err) return []byte(fmt.Sprintf( @@ -434,7 +430,7 @@ func TestCompiledCode_Close(t *testing.T) { var cs []*CompiledCode for i := 0; i < 10; i++ { m := &wasm.Module{} - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) cs = append(cs, &CompiledCode{module: m, compiledEngine: e}) } @@ -465,7 +461,7 @@ func (e *mockEngine) DeleteCompiledModule(module *wasm.Module) { delete(e.cachedModules, module) } -func (e *mockEngine) CompileModule(module *wasm.Module) error { +func (e *mockEngine) CompileModule(_ context.Context, module *wasm.Module) error { e.cachedModules[module] = struct{}{} return nil }