From 45ccab589bc8dba1f40df51656b0d9ecb14d291d Mon Sep 17 00:00:00 2001 From: Crypt Keeper <64215+codefromthecrypt@users.noreply.github.com> Date: Tue, 19 Apr 2022 16:52:57 +0800 Subject: [PATCH] Refactors API to ensure context propagation (#482) This is an API breaking change that does a few things: * Stop encouraging practice that can break context propagation: * Stops caching `context.Context` in `wazero.RuntimeConfig` * Stops caching `context.Context` in `api.Module` * Fixes context propagation in function calls: * Changes `api.Function`'s arg0 from `api.Module` to `context.Context` * Adds `context.Context` parameter in instantiation (propagates to .start) * Allows context propagation for heavy operations like compile: * Adds `context.Context` as the initial parameter of `CompileModule` The design we had earlier was a good start, but this is the only way to ensure coherence when users start correlating or tracing. While adding a `context.Context` parameter may seem difficult, wazero is a low-level library and WebAssembly is notoriously difficult to troubleshoot. In other words, it will be easier to explain to users to pass (even nil) as the context parameter vs try to figure out things without coherent context. Signed-off-by: Adrian Cole --- README.md | 13 ++- api/wasm.go | 24 +---- builder.go | 41 ++++--- builder_test.go | 12 +-- config.go | 21 ---- config_test.go | 19 ---- example_test.go | 8 +- examples/basic/add.go | 10 +- examples/import-go/age-calculator.go | 13 ++- examples/multiple-results/multiple-results.go | 40 +++---- examples/replace-import/replace-import.go | 13 ++- examples/wasi/cat.go | 9 +- internal/integration_test/bench/bench_test.go | 36 ++++--- .../integration_test/engine/adhoc_test.go | 63 ++++++----- .../integration_test/engine/hammer_test.go | 14 +-- .../post1_0/multi-value/multi_value_test.go | 38 ++++--- .../sign_extension_ops_test.go | 12 ++- .../integration_test/spectest/spec_test.go | 17 +-- .../integration_test/vs/bench_fac_test.go | 14 ++- internal/integration_test/vs/codec_test.go | 6 +- internal/testing/enginetest/enginetest.go | 45 ++++---- internal/wasm/call_context.go | 48 ++++----- internal/wasm/call_context_test.go | 49 --------- internal/wasm/engine.go | 16 ++- internal/wasm/gofunc.go | 81 ++++++++------ internal/wasm/gofunc_test.go | 60 +++++++++-- internal/wasm/interpreter/interpreter.go | 31 +++--- internal/wasm/interpreter/interpreter_test.go | 12 ++- internal/wasm/jit/engine.go | 24 +++-- internal/wasm/jit/engine_test.go | 21 ++-- internal/wasm/jit/jit_impl_arm64.go | 4 +- internal/wasm/store.go | 8 +- internal/wasm/store_test.go | 63 +---------- internal/wazeroir/compiler.go | 3 +- internal/wazeroir/compiler_test.go | 10 +- wasi/example_test.go | 9 +- wasi/usage_test.go | 6 +- wasi/wasi.go | 5 +- wasi/wasi_bench_test.go | 3 +- wasi/wasi_test.go | 100 +++++++++--------- wasm.go | 84 ++++++++------- wasm_test.go | 94 ++++++++-------- 42 files changed, 592 insertions(+), 607 deletions(-) diff --git a/README.md b/README.md index 1074200b..5a9e46b9 100644 --- a/README.md +++ b/README.md @@ -17,16 +17,19 @@ For the impatient, here's how invoking a factorial function looks in wazero: ```golang func main() { + // Choose the context to use for function calls. + ctx := context.Background() + // Read a WebAssembly binary containing an exported "fac" function. // * Ex. (func (export "fac") (param i64) (result i64) ... source, _ := os.ReadFile("./tests/bench/testdata/fac.wasm") // Instantiate the module and return its exported functions - module, _ := wazero.NewRuntime().InstantiateModuleFromCode(source) + module, _ := wazero.NewRuntime().InstantiateModuleFromCode(ctx, source) defer module.Close() // Discover 7! is 5040 - fmt.Println(module.ExportedFunction("fac").Call(nil, 7)) + fmt.Println(module.ExportedFunction("fac").Call(ctx, 7)) } ``` @@ -56,7 +59,7 @@ env, err := r.NewModuleBuilder("env"). ExportFunction("log_i32", func(v uint32) { fmt.Println("log_i32 >>", v) }). - Instantiate() + Instantiate(ctx) if err != nil { log.Fatal(err) } @@ -73,11 +76,11 @@ bundles an implementation. That way, you don't have to write these functions. For example, here's how you can allow WebAssembly modules to read "/work/home/a.txt" as "/a.txt" or "./a.txt": ```go -wm, err := wasi.InstantiateSnapshotPreview1(r) +wm, err := wasi.InstantiateSnapshotPreview1(ctx, r) defer wm.Close() config := wazero.ModuleConfig().WithFS(os.DirFS("/work/home")) -module, err := r.InstantiateModule(binary, config) +module, err := r.InstantiateModule(ctx, binary, config) defer module.Close() ... ``` diff --git a/api/wasm.go b/api/wasm.go index 6e58d1ba..fb79160f 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -84,14 +84,6 @@ type Module interface { // with the exitCode. CloseWithExitCode(exitCode uint32) error - // Context returns any propagated context from the Runtime or a prior function call. - // - // The returned context is always non-nil; it defaults to context.Background. - Context() context.Context - - // WithContext allows callers to override the propagated context, for example, to add values to it. - WithContext(ctx context.Context) Module - // Memory returns a memory defined in this module or nil if there are none wasn't. Memory() Memory @@ -129,23 +121,11 @@ type Function interface { // encoded according to ResultTypes. An error is returned for any failure looking up or invoking the function // including signature mismatch. // - // If `m` is nil, it defaults to the module the function was defined in. - // - // To override context propagation, use Module.WithContext - // fn = m.ExportedFunction("fib") - // results, err := fn(m.WithContext(ctx), 5) - // --snip-- - // - // To ensure context propagation in a host function body, pass the `ctx` parameter: - // hostFunction := func(m api.Module, offset, byteCount uint32) uint32 { - // fn = m.ExportedFunction("__read") - // results, err := fn(m, offset, byteCount) - // --snip-- - // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: If Module.Close or Module.CloseWithExitCode were invoked during this call, the error returned may be a // sys.ExitError. Interpreting this is specific to the module. For example, some "main" functions always call a // function that exits. - Call(m Module, params ...uint64) ([]uint64, error) + Call(ctx context.Context, params ...uint64) ([]uint64, error) } // Global is a WebAssembly 1.0 (20191205) global exported from an instantiated module (wazero.Runtime InstantiateModule). diff --git a/builder.go b/builder.go index c2213ae1..37441f36 100644 --- a/builder.go +++ b/builder.go @@ -1,6 +1,7 @@ package wazero import ( + "context" "fmt" "github.com/tetratelabs/wazero/api" @@ -13,19 +14,20 @@ import ( // // Ex. Below defines and instantiates a module named "env" with one function: // +// ctx := context.Background() // hello := func() { // fmt.Fprintln(stdout, "hello!") // } -// env, _ := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate() +// env, _ := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate(ctx) // // If the same module may be instantiated multiple times, it is more efficient to separate steps. Ex. // -// env, _ := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Build() +// env, _ := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Build(ctx) // -// env1, _ := r.InstantiateModuleWithConfig(env, NewModuleConfig().WithName("env.1")) +// env1, _ := r.InstantiateModuleWithConfig(ctx, env, NewModuleConfig().WithName("env.1")) // defer env1.Close() // -// env2, _ := r.InstantiateModuleWithConfig(env, NewModuleConfig().WithName("env.2")) +// env2, _ := r.InstantiateModuleWithConfig(ctx, env, NewModuleConfig().WithName("env.2")) // defer env2.Close() // // Note: Builder methods do not return errors, to allow chaining. Any validation errors are deferred until Build. @@ -56,11 +58,8 @@ type ModuleBuilder interface { // return x + y + m.Value(extraKey).(uint32) // } // - // The most sophisticated context is api.Module, which allows access to the Go context, but also - // allows writing to memory. This is important because there are only numeric types in Wasm. The only way to share other - // data is via writing memory and sharing offsets. - // - // Ex. This reads the parameters from! + // Ex. This uses an api.Module to reads the parameters from memory. This is important because there are only numeric + // types in Wasm. The only way to share other data is via writing memory and sharing offsets. // // addInts := func(m api.Module, offset uint32) uint32 { // x, _ := m.Memory().ReadUint32Le(offset) @@ -68,6 +67,14 @@ type ModuleBuilder interface { // return x + y // } // + // If both parameters exist, they must be in order at positions zero and one. + // + // Ex. This uses propagates context properly when calling other functions exported in the api.Module: + // callRead := func(ctx context.Context, m api.Module, offset, byteCount uint32) uint32 { + // fn = m.ExportedFunction("__read") + // results, err := fn(ctx, offset, byteCount) + // --snip-- + // // Note: If a function is already exported with the same name, this overwrites it. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#host-functions%E2%91%A2 ExportFunction(name string, goFunc interface{}) ModuleBuilder @@ -144,12 +151,12 @@ type ModuleBuilder interface { ExportGlobalF64(name string, v float64) ModuleBuilder // Build returns a module to instantiate, or returns an error if any of the configuration is invalid. - Build() (*CompiledCode, error) + Build(ctx context.Context) (*CompiledCode, error) // Instantiate is a convenience that calls Build, then Runtime.InstantiateModule // // Note: Fields in the builder are copied during instantiation: Later changes do not affect the instantiated result. - Instantiate() (api.Module, error) + Instantiate(ctx context.Context) (api.Module, error) } // moduleBuilder implements ModuleBuilder @@ -237,7 +244,7 @@ func (b *moduleBuilder) ExportGlobalF64(name string, v float64) ModuleBuilder { } // Build implements ModuleBuilder.Build -func (b *moduleBuilder) Build() (*CompiledCode, error) { +func (b *moduleBuilder) Build(ctx context.Context) (*CompiledCode, error) { // Verify the maximum limit here, so we don't have to pass it to wasm.NewHostModule maxLimit := b.r.memoryMaxPages for name, mem := range b.nameToMemory { @@ -252,7 +259,7 @@ func (b *moduleBuilder) Build() (*CompiledCode, error) { return nil, err } - if err = b.r.store.Engine.CompileModule(module); err != nil { + if err = b.r.store.Engine.CompileModule(ctx, module); err != nil { return nil, err } @@ -260,15 +267,15 @@ func (b *moduleBuilder) Build() (*CompiledCode, error) { } // Instantiate implements ModuleBuilder.Instantiate -func (b *moduleBuilder) Instantiate() (api.Module, error) { - if module, err := b.Build(); err != nil { +func (b *moduleBuilder) Instantiate(ctx context.Context) (api.Module, error) { + if module, err := b.Build(ctx); err != nil { return nil, err } else { - if err = b.r.store.Engine.CompileModule(module.module); err != nil { + if err = b.r.store.Engine.CompileModule(ctx, module.module); err != nil { return nil, err } // *wasm.ModuleInstance cannot be tracked, so we release the cache inside of this function. defer module.Close() - return b.r.InstantiateModuleWithConfig(module, NewModuleConfig().WithName(b.moduleName)) + return b.r.InstantiateModuleWithConfig(ctx, module, NewModuleConfig().WithName(b.moduleName)) } } diff --git a/builder_test.go b/builder_test.go index 4a18bd7d..877883f5 100644 --- a/builder_test.go +++ b/builder_test.go @@ -344,7 +344,7 @@ func TestNewModuleBuilder_Build(t *testing.T) { t.Run(tc.name, func(t *testing.T) { b := tc.input(NewRuntime()).(*moduleBuilder) - m, err := b.Build() + m, err := b.Build(testCtx) require.NoError(t, err) requireHostModuleEquals(t, tc.expected, m.module) @@ -352,7 +352,7 @@ func TestNewModuleBuilder_Build(t *testing.T) { require.Equal(t, b.r.store.Engine, m.compiledEngine) // Built module must be instantiable by Engine. - _, err = b.r.InstantiateModule(m) + _, err = b.r.InstantiateModule(testCtx, m) require.NoError(t, err) }) } @@ -385,7 +385,7 @@ func TestNewModuleBuilder_Build_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - _, e := tc.input(NewRuntime()).Build() + _, e := tc.input(NewRuntime()).Build(testCtx) require.EqualError(t, e, tc.expectedErr) }) } @@ -394,7 +394,7 @@ func TestNewModuleBuilder_Build_Errors(t *testing.T) { // TestNewModuleBuilder_Instantiate ensures Runtime.InstantiateModule is called on success. func TestNewModuleBuilder_Instantiate(t *testing.T) { r := NewRuntime() - m, err := r.NewModuleBuilder("env").Instantiate() + m, err := r.NewModuleBuilder("env").Instantiate(testCtx) require.NoError(t, err) // If this was instantiated, it would be added to the store under the same name @@ -404,10 +404,10 @@ func TestNewModuleBuilder_Instantiate(t *testing.T) { // TestNewModuleBuilder_Instantiate_Errors ensures errors propagate from Runtime.InstantiateModule func TestNewModuleBuilder_Instantiate_Errors(t *testing.T) { r := NewRuntime() - _, err := r.NewModuleBuilder("env").Instantiate() + _, err := r.NewModuleBuilder("env").Instantiate(testCtx) require.NoError(t, err) - _, err = r.NewModuleBuilder("env").Instantiate() + _, err = r.NewModuleBuilder("env").Instantiate(testCtx) require.EqualError(t, err, "module env has already been instantiated") } diff --git a/config.go b/config.go index 7d73a47a..f6986389 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,6 @@ package wazero import ( - "context" "errors" "fmt" "io" @@ -18,14 +17,12 @@ import ( type RuntimeConfig struct { enabledFeatures wasm.Features newEngine func(wasm.Features) wasm.Engine - ctx context.Context memoryMaxPages uint32 } // engineLessConfig helps avoid copy/pasting the wrong defaults. var engineLessConfig = &RuntimeConfig{ enabledFeatures: wasm.Features20191205, - ctx: context.Background(), memoryMaxPages: wasm.MemoryMaxPages, } @@ -34,7 +31,6 @@ func (c *RuntimeConfig) clone() *RuntimeConfig { return &RuntimeConfig{ enabledFeatures: c.enabledFeatures, newEngine: c.newEngine, - ctx: c.ctx, memoryMaxPages: c.memoryMaxPages, } } @@ -56,23 +52,6 @@ func NewRuntimeConfigInterpreter() *RuntimeConfig { return ret } -// WithContext sets the default context used to initialize the module. Defaults to context.Background if nil. -// -// Notes: -// * If the Module defines a start function, this is used to invoke it. -// * This is the outer-most ancestor of api.Module Context() during api.Function invocations. -// * This is the default context of api.Function when callers pass nil. -// -// See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#start-function%E2%91%A0 -func (c *RuntimeConfig) WithContext(ctx context.Context) *RuntimeConfig { - if ctx == nil { - ctx = context.Background() - } - ret := c.clone() - ret.ctx = ctx - return ret -} - // WithMemoryMaxPages reduces the maximum number of pages a module can define from 65536 pages (4GiB) to a lower value. // // Notes: diff --git a/config_test.go b/config_test.go index 517f6ee7..2111dcc6 100644 --- a/config_test.go +++ b/config_test.go @@ -1,7 +1,6 @@ package wazero import ( - "context" "io" "math" "testing" @@ -17,24 +16,6 @@ func TestRuntimeConfig(t *testing.T) { with func(*RuntimeConfig) *RuntimeConfig expected *RuntimeConfig }{ - { - name: "WithContext", - with: func(c *RuntimeConfig) *RuntimeConfig { - return c.WithContext(context.TODO()) - }, - expected: &RuntimeConfig{ - ctx: context.TODO(), - }, - }, - { - name: "WithContext - nil", - with: func(c *RuntimeConfig) *RuntimeConfig { - return c.WithContext(nil) //nolint - }, - expected: &RuntimeConfig{ - ctx: context.Background(), - }, - }, { name: "WithMemoryMaxPages", with: func(c *RuntimeConfig) *RuntimeConfig { diff --git a/example_test.go b/example_test.go index 730ced0a..4d6431c1 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package wazero import ( + "context" _ "embed" "fmt" "log" @@ -10,11 +11,14 @@ import ( // // See https://github.com/tetratelabs/wazero/tree/main/examples for more examples. func Example() { + // Choose the context to use for function calls. + ctx := context.Background() + // Create a new WebAssembly Runtime. r := NewRuntime() // Add a module to the runtime named "wasm/math" which exports one function "add", implemented in WebAssembly. - mod, err := r.InstantiateModuleFromCode([]byte(`(module $wasm/math + mod, err := r.InstantiateModuleFromCode(ctx, []byte(`(module $wasm/math (func $add (param i32 i32) (result i32) local.get 0 local.get 1 @@ -31,7 +35,7 @@ func Example() { add := mod.ExportedFunction("add") x, y := uint64(1), uint64(2) - results, err := add.Call(nil, x, y) + results, err := add.Call(ctx, x, y) if err != nil { log.Fatal(err) } diff --git a/examples/basic/add.go b/examples/basic/add.go index 0a229418..44d5a14f 100644 --- a/examples/basic/add.go +++ b/examples/basic/add.go @@ -1,6 +1,7 @@ package add import ( + "context" _ "embed" "fmt" "log" @@ -13,11 +14,14 @@ import ( // main implements a basic function in both Go and WebAssembly. func main() { + // Choose the context to use for function calls. + ctx := context.Background() + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Add a module to the runtime named "wasm/math" which exports one function "add", implemented in WebAssembly. - wasm, err := r.InstantiateModuleFromCode([]byte(`(module $wasm/math + wasm, err := r.InstantiateModuleFromCode(ctx, []byte(`(module $wasm/math (func $add (param i32 i32) (result i32) local.get 0 local.get 1 @@ -34,7 +38,7 @@ func main() { host, err := r.NewModuleBuilder("host/math"). ExportFunction("add", func(v1, v2 uint32) uint32 { return v1 + v2 - }).Instantiate() + }).Instantiate(ctx) if err != nil { log.Fatal(err) } @@ -46,7 +50,7 @@ func main() { // Call the same function in both modules and print the results to the console. for _, mod := range []api.Module{wasm, host} { add := mod.ExportedFunction("add") - results, err := add.Call(nil, x, y) + results, err := add.Call(ctx, x, y) if err != nil { log.Fatal(err) } diff --git a/examples/import-go/age-calculator.go b/examples/import-go/age-calculator.go index 0d080cc1..c55b4a63 100644 --- a/examples/import-go/age-calculator.go +++ b/examples/import-go/age-calculator.go @@ -1,6 +1,7 @@ package age_calculator import ( + "context" _ "embed" "fmt" "log" @@ -16,6 +17,10 @@ import ( // // See README.md for a full description. func main() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Instantiate a module named "env" that exports functions to get the @@ -35,7 +40,7 @@ func main() { } return uint32(time.Now().Year()) }). - Instantiate() + Instantiate(ctx) if err != nil { log.Fatal(err) } @@ -46,7 +51,7 @@ func main() { // // Note: The import syntax in both Text and Binary format is the same // regardless of if the function was defined in Go or WebAssembly. - ageCalculator, err := r.InstantiateModuleFromCode([]byte(` + ageCalculator, err := r.InstantiateModuleFromCode(ctx, []byte(` ;; Define the optional module name. '$' prefixing is a part of the text format. (module $age-calculator @@ -91,14 +96,14 @@ func main() { } // First, try calling the "get_age" function and printing to the console externally. - results, err := ageCalculator.ExportedFunction("get_age").Call(nil, birthYear) + results, err := ageCalculator.ExportedFunction("get_age").Call(ctx, birthYear) if err != nil { log.Fatal(err) } fmt.Println("println >>", results[0]) // First, try calling the "log_age" function and printing to the console externally. - _, err = ageCalculator.ExportedFunction("log_age").Call(nil, birthYear) + _, err = ageCalculator.ExportedFunction("log_age").Call(ctx, birthYear) if err != nil { log.Fatal(err) } diff --git a/examples/multiple-results/multiple-results.go b/examples/multiple-results/multiple-results.go index e5f041dd..aa8b375a 100644 --- a/examples/multiple-results/multiple-results.go +++ b/examples/multiple-results/multiple-results.go @@ -1,6 +1,7 @@ package multiple_results import ( + "context" _ "embed" "fmt" "log" @@ -24,18 +25,21 @@ import ( // * multiValueHostFunctions // See https://github.com/WebAssembly/spec/blob/main/proposals/multi-value/Overview.md func main() { - // Create a portable WebAssembly Runtime. + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. runtime := wazero.NewRuntime() // Add a module that uses offset parameters for multiple results, with functions defined in WebAssembly. - wasm, err := resultOffsetWasmFunctions(runtime) + wasm, err := resultOffsetWasmFunctions(ctx, runtime) if err != nil { log.Fatal(err) } defer wasm.Close() // Add a module that uses offset parameters for multiple results, with functions defined in Go. - host, err := resultOffsetHostFunctions(runtime) + host, err := resultOffsetHostFunctions(ctx, runtime) if err != nil { log.Fatal(err) } @@ -48,14 +52,14 @@ func main() { ) // Add a module that uses multiple results values, with functions defined in WebAssembly. - wasmWithMultiValue, err := multiValueWasmFunctions(runtimeWithMultiValue) + wasmWithMultiValue, err := multiValueWasmFunctions(ctx, runtimeWithMultiValue) if err != nil { log.Fatal(err) } defer wasmWithMultiValue.Close() // Add a module that uses multiple results values, with functions defined in Go. - hostWithMultiValue, err := multiValueHostFunctions(runtimeWithMultiValue) + hostWithMultiValue, err := multiValueHostFunctions(ctx, runtimeWithMultiValue) if err != nil { log.Fatal(err) } @@ -64,7 +68,7 @@ func main() { // Call the same function in all modules and print the results to the console. for _, mod := range []api.Module{wasm, host, wasmWithMultiValue, hostWithMultiValue} { getAge := mod.ExportedFunction("call_get_age") - results, err := getAge.Call(nil) + results, err := getAge.Call(ctx) if err != nil { log.Fatal(err) } @@ -78,8 +82,8 @@ func main() { // // To return a value in WASM written to a result parameter, you have to define memory and pass a location to write // the result. At the end of your function, you load that location. -func resultOffsetWasmFunctions(r wazero.Runtime) (api.Module, error) { - return r.InstantiateModuleFromCode([]byte(`(module $result-offset/wasm +func resultOffsetWasmFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { + return r.InstantiateModuleFromCode(ctx, []byte(`(module $result-offset/wasm ;; To use result parameters, we need scratch memory. Allocate the least possible: 1 page (64KB). (memory 1 1) @@ -112,7 +116,7 @@ func resultOffsetWasmFunctions(r wazero.Runtime) (api.Module, error) { // // To return a value in WASM written to a result parameter, you have to define memory and pass a location to write // the result. At the end of your function, you load that location. -func resultOffsetHostFunctions(r wazero.Runtime) (api.Module, error) { +func resultOffsetHostFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { return r.NewModuleBuilder("result-offset/host"). // To use result parameters, we need scratch memory. Allocate the least possible: 1 page (64KB). ExportMemoryWithMax("mem", 1, 1). @@ -125,17 +129,17 @@ func resultOffsetHostFunctions(r wazero.Runtime) (api.Module, error) { }). // Now, define a function that shows the Wasm mechanics returning something written to a result parameter. // The caller provides a memory offset to the callee, so that it knows where to write the second result. - ExportFunction("call_get_age", func(m api.Module) (age uint64) { + ExportFunction("call_get_age", func(ctx context.Context, m api.Module) (age uint64) { resultOffsetAge := uint32(8) // arbitrary memory offset (in bytes) - _, _ = m.ExportedFunction("get_age").Call(m, uint64(resultOffsetAge)) + _, _ = m.ExportedFunction("get_age").Call(ctx, uint64(resultOffsetAge)) age, _ = m.Memory().ReadUint64Le(resultOffsetAge) return - }).Instantiate() + }).Instantiate(ctx) } // multiValueWasmFunctions defines Wasm functions that illustrate multiple results using the "multiple-results" feature. -func multiValueWasmFunctions(r wazero.Runtime) (api.Module, error) { - return r.InstantiateModuleFromCode([]byte(`(module $multi-value/wasm +func multiValueWasmFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { + return r.InstantiateModuleFromCode(ctx, []byte(`(module $multi-value/wasm ;; Define a function that returns two results (func $get_age (result (;age;) i64 (;errno;) i32) @@ -155,7 +159,7 @@ func multiValueWasmFunctions(r wazero.Runtime) (api.Module, error) { } // multiValueHostFunctions defines Wasm functions that illustrate multiple results using the "multiple-results" feature. -func multiValueHostFunctions(r wazero.Runtime) (api.Module, error) { +func multiValueHostFunctions(ctx context.Context, r wazero.Runtime) (api.Module, error) { return r.NewModuleBuilder("multi-value/host"). // Define a function that returns two results ExportFunction("get_age", func() (age uint64, errno uint32) { @@ -164,8 +168,8 @@ func multiValueHostFunctions(r wazero.Runtime) (api.Module, error) { return }). // Now, define a function that returns only the first result. - ExportFunction("call_get_age", func(m api.Module) (age uint64) { - results, _ := m.ExportedFunction("get_age").Call(m) + ExportFunction("call_get_age", func(ctx context.Context, m api.Module) (age uint64) { + results, _ := m.ExportedFunction("get_age").Call(ctx) return results[0] - }).Instantiate() + }).Instantiate(ctx) } diff --git a/examples/replace-import/replace-import.go b/examples/replace-import/replace-import.go index e46554ab..48ffeaf1 100644 --- a/examples/replace-import/replace-import.go +++ b/examples/replace-import/replace-import.go @@ -1,6 +1,7 @@ package replace_import import ( + "context" _ "embed" "fmt" "log" @@ -11,20 +12,24 @@ import ( // main shows how you can replace a module import when it doesn't match instantiated modules. func main() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Instantiate a function that closes the module under "assemblyscript.abort". host, err := r.NewModuleBuilder("assemblyscript"). ExportFunction("abort", func(m api.Module, messageOffset, fileNameOffset, line, col uint32) { _ = m.CloseWithExitCode(255) - }).Instantiate() + }).Instantiate(ctx) if err != nil { log.Fatal(err) } defer host.Close() // Compile code that needs the function "env.abort". - code, err := r.CompileModule([]byte(`(module $needs-import + code, err := r.CompileModule(ctx, []byte(`(module $needs-import (import "env" "abort" (func $~lib/builtins/abort (param i32 i32 i32 i32))) (export "abort" (func 0)) ;; exports the import for testing @@ -35,7 +40,7 @@ func main() { defer code.Close() // Instantiate the module, replacing the import "env.abort" with "assemblyscript.abort". - mod, err := r.InstantiateModuleWithConfig(code, wazero.NewModuleConfig(). + mod, err := r.InstantiateModuleWithConfig(ctx, code, wazero.NewModuleConfig(). WithImport("env", "abort", "assemblyscript", "abort")) if err != nil { log.Fatal(err) @@ -43,6 +48,6 @@ func main() { defer mod.Close() // Since the above worked, the exported function closes the module. - _, err = mod.ExportedFunction("abort").Call(nil, 0, 0, 0, 0) + _, err = mod.ExportedFunction("abort").Call(ctx, 0, 0, 0, 0) fmt.Println(err) } diff --git a/examples/wasi/cat.go b/examples/wasi/cat.go index 6c5a7691..1da8a25a 100644 --- a/examples/wasi/cat.go +++ b/examples/wasi/cat.go @@ -1,6 +1,7 @@ package wasi_example import ( + "context" "embed" _ "embed" "io/fs" @@ -24,6 +25,10 @@ var catWasm []byte // This is a basic introduction to the WebAssembly System Interface (WASI). // See https://github.com/WebAssembly/WASI func main() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Since wazero uses fs.FS, we can use standard libraries to do things like trim the leading path. @@ -36,7 +41,7 @@ func main() { config := wazero.NewModuleConfig().WithStdout(os.Stdout).WithFS(rooted) // Instantiate WASI, which implements system I/O such as console output. - wm, err := wasi.InstantiateSnapshotPreview1(r) + wm, err := wasi.InstantiateSnapshotPreview1(ctx, r) if err != nil { log.Fatal(err) } @@ -45,7 +50,7 @@ func main() { // InstantiateModuleFromCodeWithConfig runs the "_start" function which is what TinyGo compiles "main" to. // * Set the program name (arg[0]) to "wasi" and add args to write "test.txt" to stdout twice. // * We use "/test.txt" or "./test.txt" because WithFS by default maps the workdir "." to "/". - cat, err := r.InstantiateModuleFromCodeWithConfig(catWasm, config.WithArgs("wasi", os.Args[1])) + cat, err := r.InstantiateModuleFromCodeWithConfig(ctx, catWasm, config.WithArgs("wasi", os.Args[1])) if err != nil { log.Fatal(err) } diff --git a/internal/integration_test/bench/bench_test.go b/internal/integration_test/bench/bench_test.go index 9574461e..a2b921ff 100644 --- a/internal/integration_test/bench/bench_test.go +++ b/internal/integration_test/bench/bench_test.go @@ -1,6 +1,7 @@ package bench import ( + "context" _ "embed" "fmt" "math/rand" @@ -12,6 +13,9 @@ import ( "github.com/tetratelabs/wazero/wasi" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + // caseWasm was compiled from TinyGo testdata/case.go //go:embed testdata/case.wasm var caseWasm []byte @@ -46,14 +50,14 @@ func BenchmarkInitialization(b *testing.B) { } func runInitializationBench(b *testing.B, r wazero.Runtime) { - compiled, err := r.CompileModule(caseWasm) + compiled, err := r.CompileModule(testCtx, caseWasm) if err != nil { b.Fatal(err) } defer compiled.Close() b.ResetTimer() for i := 0; i < b.N; i++ { - mod, err := r.InstantiateModule(compiled) + mod, err := r.InstantiateModule(testCtx, compiled) if err != nil { b.Fatal(err) } @@ -77,7 +81,7 @@ func runBase64Benches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("base64_%d_per_exec", numPerExec), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := base64.Call(nil, numPerExec); err != nil { + if _, err := base64.Call(testCtx, numPerExec); err != nil { b.Fatal(err) } } @@ -93,7 +97,7 @@ func runFibBenches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("fib_for_%d", num), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := fibonacci.Call(nil, num); err != nil { + if _, err := fibonacci.Call(testCtx, num); err != nil { b.Fatal(err) } } @@ -109,7 +113,7 @@ func runStringManipulationBenches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("string_manipulation_size_%d", initialSize), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := stringManipulation.Call(nil, initialSize); err != nil { + if _, err := stringManipulation.Call(testCtx, initialSize); err != nil { b.Fatal(err) } } @@ -125,7 +129,7 @@ func runReverseArrayBenches(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("reverse_array_size_%d", arraySize), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := reverseArray.Call(nil, arraySize); err != nil { + if _, err := reverseArray.Call(testCtx, arraySize); err != nil { b.Fatal(err) } } @@ -141,7 +145,7 @@ func runRandomMatMul(b *testing.B, m api.Module) { b.ResetTimer() b.Run(fmt.Sprintf("random_mat_mul_size_%d", matrixSize), func(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := randomMatMul.Call(nil, matrixSize); err != nil { + if _, err := randomMatMul.Call(testCtx, matrixSize); err != nil { b.Fatal(err) } } @@ -153,7 +157,7 @@ func instantiateHostFunctionModuleWithEngine(b *testing.B, engine *wazero.Runtim r := createRuntime(b, engine) // InstantiateModuleFromCode runs the "_start" function which is what TinyGo compiles "main" to. - m, err := r.InstantiateModuleFromCode(caseWasm) + m, err := r.InstantiateModuleFromCode(testCtx, caseWasm) if err != nil { b.Fatal(err) } @@ -161,30 +165,32 @@ func instantiateHostFunctionModuleWithEngine(b *testing.B, engine *wazero.Runtim } func createRuntime(b *testing.B, engine *wazero.RuntimeConfig) wazero.Runtime { - getRandomString := func(ctx api.Module, retBufPtr uint32, retBufSize uint32) { - results, err := ctx.ExportedFunction("allocate_buffer").Call(ctx, 10) + getRandomString := func(ctx context.Context, m api.Module, retBufPtr uint32, retBufSize uint32) { + results, err := m.ExportedFunction("allocate_buffer").Call(ctx, 10) if err != nil { b.Fatal(err) } offset := uint32(results[0]) - ctx.Memory().WriteUint32Le(retBufPtr, offset) - ctx.Memory().WriteUint32Le(retBufSize, 10) + m.Memory().WriteUint32Le(retBufPtr, offset) + m.Memory().WriteUint32Le(retBufSize, 10) b := make([]byte, 10) _, _ = rand.Read(b) - ctx.Memory().Write(offset, b) + m.Memory().Write(offset, b) } r := wazero.NewRuntimeWithConfig(engine) - _, err := r.NewModuleBuilder("env").ExportFunction("get_random_string", getRandomString).Instantiate() + _, err := r.NewModuleBuilder("env"). + ExportFunction("get_random_string", getRandomString). + Instantiate(testCtx) if err != nil { b.Fatal(err) } // Note: host_func.go doesn't directly use WASI, but TinyGo needs to be initialized as a WASI Command. // Add WASI to satisfy import tests - _, err = wasi.InstantiateSnapshotPreview1(r) + _, err = wasi.InstantiateSnapshotPreview1(testCtx, r) if err != nil { b.Fatal(err) } diff --git a/internal/integration_test/engine/adhoc_test.go b/internal/integration_test/engine/adhoc_test.go index e9647ce9..e31217cf 100644 --- a/internal/integration_test/engine/adhoc_test.go +++ b/internal/integration_test/engine/adhoc_test.go @@ -15,6 +15,9 @@ import ( "github.com/tetratelabs/wazero/sys" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + var tests = map[string]func(t *testing.T, r wazero.Runtime){ "huge stack": testHugeStack, "unreachable": testUnreachable, @@ -37,17 +40,13 @@ func TestEngineInterpreter(t *testing.T) { runAllTests(t, tests, wazero.NewRuntimeConfigInterpreter()) } -type configContextKey string - -var configContext = context.WithValue(context.Background(), configContextKey("wa"), "zero") - func runAllTests(t *testing.T, tests map[string]func(t *testing.T, r wazero.Runtime), config *wazero.RuntimeConfig) { for name, testf := range tests { name := name // pin testf := testf // pin t.Run(name, func(t *testing.T) { t.Parallel() - testf(t, wazero.NewRuntimeWithConfig(config.WithContext(configContext))) + testf(t, wazero.NewRuntimeWithConfig(config)) }) } } @@ -62,14 +61,14 @@ var ( ) func testHugeStack(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(hugestackWasm) + module, err := r.InstantiateModuleFromCode(testCtx, hugestackWasm) require.NoError(t, err) defer module.Close() fn := module.ExportedFunction("main") require.NotNil(t, fn) - _, err = fn.Call(nil) + _, err = fn.Call(testCtx) require.NoError(t, err) } @@ -78,14 +77,14 @@ func testUnreachable(t *testing.T, r wazero.Runtime) { panic("panic in host function") } - _, err := r.NewModuleBuilder("host").ExportFunction("cause_unreachable", callUnreachable).Instantiate() + _, err := r.NewModuleBuilder("host").ExportFunction("cause_unreachable", callUnreachable).Instantiate(testCtx) require.NoError(t, err) - module, err := r.InstantiateModuleFromCode(unreachableWasm) + module, err := r.InstantiateModuleFromCode(testCtx, unreachableWasm) require.NoError(t, err) defer module.Close() - _, err = module.ExportedFunction("main").Call(nil) + _, err = module.ExportedFunction("main").Call(testCtx) exp := `panic in host function (recovered by wazero) wasm stack trace: host.cause_unreachable() @@ -97,18 +96,18 @@ wasm stack trace: func testRecursiveEntry(t *testing.T, r wazero.Runtime) { hostfunc := func(mod api.Module) { - _, err := mod.ExportedFunction("called_by_host_func").Call(nil) + _, err := mod.ExportedFunction("called_by_host_func").Call(testCtx) require.NoError(t, err) } - _, err := r.NewModuleBuilder("env").ExportFunction("host_func", hostfunc).Instantiate() + _, err := r.NewModuleBuilder("env").ExportFunction("host_func", hostfunc).Instantiate(testCtx) require.NoError(t, err) - module, err := r.InstantiateModuleFromCode(recursiveWasm) + module, err := r.InstantiateModuleFromCode(testCtx, recursiveWasm) require.NoError(t, err) defer module.Close() - _, err = module.ExportedFunction("main").Call(nil, 1) + _, err = module.ExportedFunction("main").Call(testCtx, 1) require.NoError(t, err) } @@ -130,11 +129,11 @@ func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { return 0 } - host, err := r.NewModuleBuilder("").ExportFunction("store_int", storeInt).Instantiate() + host, err := r.NewModuleBuilder("").ExportFunction("store_int", storeInt).Instantiate(testCtx) require.NoError(t, err) defer host.Close() - module, err := r.InstantiateModuleFromCode([]byte(`(module $test + module, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $test (import "" "store_int" (func $store_int (param $offset i32) (param $val i64) (result (;errno;) i32))) (memory $memory 1 1) @@ -147,7 +146,7 @@ func testImportedAndExportedFunc(t *testing.T, r wazero.Runtime) { // Call store_int and ensure it didn't return an error code. fn := module.ExportedFunction("store_int") - results, err := fn.Call(nil, 1, math.MaxUint64) + results, err := fn.Call(testCtx, 1, math.MaxUint64) require.NoError(t, err) require.Equal(t, uint64(0), results[0]) @@ -166,7 +165,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { return p + 1 }, "go_context": func(ctx context.Context, p uint32) uint32 { - require.Equal(t, configContext, ctx) + require.Equal(t, testCtx, ctx) return p + 1 }, "module_context": func(module api.Module, p uint32) uint32 { @@ -175,14 +174,14 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { }, } - imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate() + imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) defer imported.Close() for test := range fns { t.Run(test, func(t *testing.T) { // Instantiate a module that uses Wasm code to call the host function. - importing, err = r.InstantiateModuleFromCode([]byte(fmt.Sprintf(`(module $%[1]s + importing, err = r.InstantiateModuleFromCode(testCtx, []byte(fmt.Sprintf(`(module $%[1]s (import "%[2]s" "%[3]s" (func $%[3]s (param i32) (result i32))) (func $call_%[3]s (param i32) (result i32) local.get 0 call $%[3]s) (export "call->%[3]s" (func $call_%[3]s)) @@ -190,7 +189,7 @@ func testHostFunctionContextParameter(t *testing.T, r wazero.Runtime) { require.NoError(t, err) defer importing.Close() - results, err := importing.ExportedFunction("call->"+test).Call(nil, math.MaxUint32-1) + results, err := importing.ExportedFunction("call->"+test).Call(testCtx, math.MaxUint32-1) require.NoError(t, err) require.Equal(t, uint64(math.MaxUint32), results[0]) }) @@ -217,7 +216,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { }, } - imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate() + imported, err := r.NewModuleBuilder(importedName).ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) defer imported.Close() @@ -248,7 +247,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { } { t.Run(test.name, func(t *testing.T) { // Instantiate a module that uses Wasm code to call the host function. - importing, err := r.InstantiateModuleFromCode([]byte(fmt.Sprintf(`(module $%[1]s + importing, err := r.InstantiateModuleFromCode(testCtx, []byte(fmt.Sprintf(`(module $%[1]s (import "%[2]s" "%[3]s" (func $%[3]s (param %[3]s) (result %[3]s))) (func $call_%[3]s (param %[3]s) (result %[3]s) local.get 0 call $%[3]s) (export "call->%[3]s" (func $call_%[3]s)) @@ -256,7 +255,7 @@ func testHostFunctionNumericParameter(t *testing.T, r wazero.Runtime) { require.NoError(t, err) defer importing.Close() - results, err := importing.ExportedFunction("call->"+test.name).Call(nil, test.input) + results, err := importing.ExportedFunction("call->"+test.name).Call(testCtx, test.input) require.NoError(t, err) require.Equal(t, test.expected, results[0]) }) @@ -340,19 +339,19 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { // Create the host module, which exports the function that closes the importing module. importedCode, err = r.NewModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", closeAndReturn).Build() + ExportFunction("return_input", closeAndReturn).Build(testCtx) require.NoError(t, err) - imported, err = r.InstantiateModule(importedCode) + imported, err = r.InstantiateModule(testCtx, importedCode) require.NoError(t, err) defer imported.Close() // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") - importingCode, err = r.CompileModule(source) + importingCode, err = r.CompileModule(testCtx, source) require.NoError(t, err) - importing, err = r.InstantiateModule(importingCode) + importing, err = r.InstantiateModule(testCtx, importingCode) require.NoError(t, err) defer importing.Close() @@ -369,14 +368,14 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { } // Functions that return after being closed should have an exit error. - _, err = importing.ExportedFunction(tc.function).Call(nil, 5) + _, err = importing.ExportedFunction(tc.function).Call(testCtx, 5) require.Equal(t, expectedErr, err) }) } } func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { - compiled, err := r.CompileModule([]byte(`(module $test + compiled, err := r.CompileModule(testCtx, []byte(`(module $test (memory 1) (func $store i32.const 1 ;; memory offset @@ -390,7 +389,7 @@ func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { // Instantiate multiple modules with the same source (*CompiledCode). for i := 0; i < 100; i++ { - module, err := r.InstantiateModuleWithConfig(compiled, wazero.NewModuleConfig().WithName(strconv.Itoa(i))) + module, err := r.InstantiateModuleWithConfig(testCtx, compiled, wazero.NewModuleConfig().WithName(strconv.Itoa(i))) require.NoError(t, err) defer module.Close() @@ -403,7 +402,7 @@ func testMultipleInstantiation(t *testing.T, r wazero.Runtime) { f := module.ExportedFunction("store") require.NotNil(t, f) - _, err = f.Call(nil) + _, err = f.Call(testCtx) require.NoError(t, err) // After the call, the value must be set properly. diff --git a/internal/integration_test/engine/hammer_test.go b/internal/integration_test/engine/hammer_test.go index 99b43d7e..c798084d 100644 --- a/internal/integration_test/engine/hammer_test.go +++ b/internal/integration_test/engine/hammer_test.go @@ -35,7 +35,7 @@ func closeImportingModuleWhileInUse(t *testing.T, r wazero.Runtime) { // Prove a module can be redefined even with in-flight calls. source := callReturnImportSource(imported.Name(), importing.Name()) - importing, err := r.InstantiateModuleFromCode(source) + importing, err := r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) return imported, importing }) @@ -50,12 +50,12 @@ func closeImportedModuleWhileInUse(t *testing.T, r wazero.Runtime) { // Redefine the imported module, with a function that no longer blocks. imported, err := r.NewModuleBuilder(imported.Name()).ExportFunction("return_input", func(x uint32) uint32 { return x - }).Instantiate() + }).Instantiate(testCtx) require.NoError(t, err) // Redefine the importing module, which should link to the redefined host module. source := callReturnImportSource(imported.Name(), importing.Name()) - importing, err = r.InstantiateModuleFromCode(source) + importing, err = r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) return imported, importing @@ -78,13 +78,13 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported // Create the host module, which exports the blocking function. imported, err := r.NewModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", blockAndReturn).Instantiate() + ExportFunction("return_input", blockAndReturn).Instantiate(testCtx) require.NoError(t, err) defer imported.Close() // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") - importing, err := r.InstantiateModuleFromCode(source) + importing, err := r.InstantiateModuleFromCode(testCtx, source) require.NoError(t, err) defer importing.Close() @@ -110,12 +110,12 @@ func closeModuleWhileInUse(t *testing.T, r wazero.Runtime, closeFn func(imported } func requireFunctionCall(t *testing.T, fn api.Function) { - res, err := fn.Call(nil, 3) + res, err := fn.Call(testCtx, 3) require.NoError(t, err) require.Equal(t, uint64(3), res[0]) } func requireFunctionCallExits(t *testing.T, moduleName string, fn api.Function) { - _, err := fn.Call(nil, 3) + _, err := fn.Call(testCtx, 3) require.Equal(t, sys.NewExitError(moduleName, 0), err) } diff --git a/internal/integration_test/post1_0/multi-value/multi_value_test.go b/internal/integration_test/post1_0/multi-value/multi_value_test.go index e241177b..2b637753 100644 --- a/internal/integration_test/post1_0/multi-value/multi_value_test.go +++ b/internal/integration_test/post1_0/multi-value/multi_value_test.go @@ -1,6 +1,7 @@ package multi_value import ( + "context" _ "embed" "testing" @@ -9,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestMultiValue_JIT(t *testing.T) { if !wazero.JITSupported { t.Skip() @@ -28,32 +32,32 @@ func testMultiValue(t *testing.T, newRuntimeConfig func() *wazero.RuntimeConfig) t.Run("disabled", func(t *testing.T) { // multi-value is disabled by default. r := wazero.NewRuntimeWithConfig(newRuntimeConfig()) - _, err := r.InstantiateModuleFromCode(multiValueWasm) + _, err := r.InstantiateModuleFromCode(testCtx, multiValueWasm) require.Error(t, err) }) t.Run("enabled", func(t *testing.T) { r := wazero.NewRuntimeWithConfig(newRuntimeConfig().WithFeatureMultiValue(true)) - module, err := r.InstantiateModuleFromCode(multiValueWasm) + module, err := r.InstantiateModuleFromCode(testCtx, multiValueWasm) require.NoError(t, err) defer module.Close() swap := module.ExportedFunction("swap") - results, err := swap.Call(nil, 100, 200) + results, err := swap.Call(testCtx, 100, 200) require.NoError(t, err) require.Equal(t, []uint64{200, 100}, results) add64UWithCarry := module.ExportedFunction("add64_u_with_carry") - results, err = add64UWithCarry.Call(nil, 0x8000000000000000, 0x8000000000000000, 0) + results, err = add64UWithCarry.Call(testCtx, 0x8000000000000000, 0x8000000000000000, 0) require.NoError(t, err) require.Equal(t, []uint64{0, 1}, results) add64USaturated := module.ExportedFunction("add64_u_saturated") - results, err = add64USaturated.Call(nil, 1230, 23) + results, err = add64USaturated.Call(testCtx, 1230, 23) require.NoError(t, err) require.Equal(t, []uint64{1253}, results) fac := module.ExportedFunction("fac") - results, err = fac.Call(nil, 25) + results, err = fac.Call(testCtx, 25) require.NoError(t, err) require.Equal(t, []uint64{7034535277573963776}, results) @@ -86,7 +90,7 @@ func testMultiValue(t *testing.T, newRuntimeConfig func() *wazero.RuntimeConfig) var brWasm []byte func testBr(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(brWasm) + module, err := r.InstantiateModuleFromCode(testCtx, brWasm) require.NoError(t, err) defer module.Close() @@ -109,7 +113,7 @@ func testBr(t *testing.T, r wazero.Runtime) { var callWasm []byte func testCall(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(callWasm) + module, err := r.InstantiateModuleFromCode(testCtx, callWasm) require.NoError(t, err) defer module.Close() @@ -130,7 +134,7 @@ func testCall(t *testing.T, r wazero.Runtime) { var callIndirectWasm []byte func testCallIndirect(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(callIndirectWasm) + module, err := r.InstantiateModuleFromCode(testCtx, callIndirectWasm) require.NoError(t, err) defer module.Close() @@ -141,7 +145,7 @@ func testCallIndirect(t *testing.T, r wazero.Runtime) { {name: "type-all-i32-i64", expected: []uint64{2, 1}}, }) - _, err = module.ExportedFunction("dispatch").Call(nil, 32, 2) + _, err = module.ExportedFunction("dispatch").Call(testCtx, 32, 2) require.EqualError(t, err, `wasm error: invalid table access wasm stack trace: call_indirect.wast.[16](i32,i64) i64`) @@ -152,12 +156,12 @@ wasm stack trace: var facWasm []byte func testFac(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(facWasm) + module, err := r.InstantiateModuleFromCode(testCtx, facWasm) require.NoError(t, err) defer module.Close() fac := module.ExportedFunction("fac-ssa") - results, err := fac.Call(nil, 25) + results, err := fac.Call(testCtx, 25) require.NoError(t, err) require.Equal(t, []uint64{7034535277573963776}, results) } @@ -167,7 +171,7 @@ func testFac(t *testing.T, r wazero.Runtime) { var funcWasm []byte func testFunc(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(funcWasm) + module, err := r.InstantiateModuleFromCode(testCtx, funcWasm) require.NoError(t, err) defer module.Close() @@ -196,7 +200,7 @@ func testFunc(t *testing.T, r wazero.Runtime) { }) fac := module.ExportedFunction("large-sig") - results, err := fac.Call(nil, + results, err := fac.Call(testCtx, 0, 1, api.EncodeF32(2), api.EncodeF32(3), 4, api.EncodeF64(5), api.EncodeF32(6), 7, 8, 9, api.EncodeF32(10), api.EncodeF64(11), @@ -215,7 +219,7 @@ func testFunc(t *testing.T, r wazero.Runtime) { var ifWasm []byte func testIf(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(ifWasm) + module, err := r.InstantiateModuleFromCode(testCtx, ifWasm) require.NoError(t, err) defer module.Close() @@ -267,7 +271,7 @@ func testIf(t *testing.T, r wazero.Runtime) { var loopWasm []byte func testLoop(t *testing.T, r wazero.Runtime) { - module, err := r.InstantiateModuleFromCode(loopWasm) + module, err := r.InstantiateModuleFromCode(testCtx, loopWasm) require.NoError(t, err) defer module.Close() @@ -296,7 +300,7 @@ func testFunctions(t *testing.T, module api.Module, tests []funcTest) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - results, err := module.ExportedFunction(tc.name).Call(nil, tc.params...) + results, err := module.ExportedFunction(tc.name).Call(testCtx, tc.params...) require.NoError(t, err) if tc.expected == nil { require.Equal(t, 0, len(results), "expected no results") diff --git a/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go b/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go index 1f25dcd0..598615a0 100644 --- a/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go +++ b/internal/integration_test/post1_0/sign-extension-ops/sign_extension_ops_test.go @@ -1,6 +1,7 @@ package sign_extension_ops import ( + "context" _ "embed" "fmt" "testing" @@ -9,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestSignExtensionOps_JIT(t *testing.T) { if !wazero.JITSupported { t.Skip() @@ -45,12 +49,12 @@ func testSignExtensionOps(t *testing.T, newRuntimeConfig func() *wazero.RuntimeC t.Run("disabled", func(t *testing.T) { // Sign-extension is disabled by default. r := wazero.NewRuntimeWithConfig(newRuntimeConfig()) - _, err := r.InstantiateModuleFromCode(signExtend) + _, err := r.InstantiateModuleFromCode(testCtx, signExtend) require.Error(t, err) }) t.Run("enabled", func(t *testing.T) { r := wazero.NewRuntimeWithConfig(newRuntimeConfig().WithFeatureSignExtensionOps(true)) - module, err := r.InstantiateModuleFromCode(signExtend) + module, err := r.InstantiateModuleFromCode(testCtx, signExtend) require.NoError(t, err) signExtend32from8Name, signExtend32from16Name := "i32.extend8_s", "i32.extend16_s" @@ -83,7 +87,7 @@ func testSignExtensionOps(t *testing.T, newRuntimeConfig func() *wazero.RuntimeC fn := module.ExportedFunction(tc.funcname) require.NotNil(t, fn) - actual, err := fn.Call(nil, uint64(uint32(tc.in))) + actual, err := fn.Call(testCtx, uint64(uint32(tc.in))) require.NoError(t, err) require.Equal(t, tc.expected, int32(actual[0])) }) @@ -131,7 +135,7 @@ func testSignExtensionOps(t *testing.T, newRuntimeConfig func() *wazero.RuntimeC fn := module.ExportedFunction(tc.funcname) require.NotNil(t, fn) - actual, err := fn.Call(nil, uint64(tc.in)) + actual, err := fn.Call(testCtx, uint64(tc.in)) require.NoError(t, err) require.Equal(t, tc.expected, int64(actual[0])) }) diff --git a/internal/integration_test/spectest/spec_test.go b/internal/integration_test/spectest/spec_test.go index 88934e1c..8c9ada20 100644 --- a/internal/integration_test/spectest/spec_test.go +++ b/internal/integration_test/spectest/spec_test.go @@ -23,6 +23,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasmruntime" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + //go:embed testdata/*.wasm //go:embed testdata/*.json var testcases embed.FS @@ -267,10 +270,10 @@ func addSpectestModule(t *testing.T, store *wasm.Store) { mod.TableSection = &wasm.Table{Min: 10, Max: &tableLimitMax} mod.ExportSection = append(mod.ExportSection, &wasm.Export{Name: "table", Index: 0, Type: wasm.ExternTypeTable}) - err = store.Engine.CompileModule(mod) + err = store.Engine.CompileModule(testCtx, mod) require.NoError(t, err) - _, err = store.Instantiate(context.Background(), mod, mod.NameSection.ModuleName, wasm.DefaultSysContext()) + _, err = store.Instantiate(testCtx, mod, mod.NameSection.ModuleName, wasm.DefaultSysContext()) require.NoError(t, err) } @@ -339,11 +342,11 @@ func runTest(t *testing.T, newEngine func(wasm.Features) wasm.Engine) { } } - err = store.Engine.CompileModule(mod) + err = store.Engine.CompileModule(testCtx, mod) require.NoError(t, err, msg) moduleName = strings.TrimPrefix(moduleName, "$") - _, err = store.Instantiate(context.Background(), mod, moduleName, nil) + _, err = store.Instantiate(testCtx, mod, moduleName, nil) lastInstantiatedModuleName = moduleName require.NoError(t, err) case "register": @@ -468,12 +471,12 @@ func requireInstantiationError(t *testing.T, store *wasm.Store, buf []byte, msg mod.AssignModuleID(buf) - err = store.Engine.CompileModule(mod) + err = store.Engine.CompileModule(testCtx, mod) if err != nil { return } - _, err = store.Instantiate(context.Background(), mod, t.Name(), nil) + _, err = store.Instantiate(testCtx, mod, t.Name(), nil) require.Error(t, err, msg) } @@ -521,6 +524,6 @@ func requireValueEq(t *testing.T, actual, expected uint64, valType wasm.ValueTyp // TODO: This is likely already covered with unit tests! func callFunction(s *wasm.Store, moduleName, funcName string, params ...uint64) ([]uint64, []wasm.ValueType, error) { fn := s.Module(moduleName).ExportedFunction(funcName) - results, err := fn.Call(nil, params...) + results, err := fn.Call(testCtx, params...) return results, fn.ResultTypes(), err } diff --git a/internal/integration_test/vs/bench_fac_test.go b/internal/integration_test/vs/bench_fac_test.go index 52bf038c..2d264487 100644 --- a/internal/integration_test/vs/bench_fac_test.go +++ b/internal/integration_test/vs/bench_fac_test.go @@ -5,6 +5,7 @@ package vs import ( + "context" _ "embed" "errors" "fmt" @@ -19,6 +20,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + // ensureJITFastest is overridable via ldflags. Ex. // -ldflags '-X github.com/tetratelabs/wazero/vs.ensureJITFastest=true' var ensureJITFastest = "false" @@ -38,7 +42,7 @@ func TestFac(t *testing.T) { defer mod.Close() for i := 0; i < 10000; i++ { - res, err := fn.Call(nil, in) + res, err := fn.Call(testCtx, in) require.NoError(t, err) require.Equal(t, expValue, res[0]) } @@ -50,7 +54,7 @@ func TestFac(t *testing.T) { defer mod.Close() for i := 0; i < 10000; i++ { - res, err := fn.Call(nil, in) + res, err := fn.Call(testCtx, in) require.NoError(t, err) require.Equal(t, expValue, res[0]) } @@ -230,7 +234,7 @@ func interpreterFacInvoke(b *testing.B) { defer mod.Close() b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err = fn.Call(nil, facArgumentU64); err != nil { + if _, err = fn.Call(testCtx, facArgumentU64); err != nil { b.Fatal(err) } } @@ -244,7 +248,7 @@ func jitFacInvoke(b *testing.B) { defer mod.Close() b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err = fn.Call(nil, facArgumentU64); err != nil { + if _, err = fn.Call(testCtx, facArgumentU64); err != nil { b.Fatal(err) } } @@ -298,7 +302,7 @@ func goWasm3FacInvoke(b *testing.B) { func newWazeroFacBench(config *wazero.RuntimeConfig) (api.Module, api.Function, error) { r := wazero.NewRuntimeWithConfig(config) - m, err := r.InstantiateModuleFromCode(facWasm) + m, err := r.InstantiateModuleFromCode(testCtx, facWasm) if err != nil { return nil, nil, err } diff --git a/internal/integration_test/vs/codec_test.go b/internal/integration_test/vs/codec_test.go index 6c745ed8..d145a002 100644 --- a/internal/integration_test/vs/codec_test.go +++ b/internal/integration_test/vs/codec_test.go @@ -110,17 +110,17 @@ func TestExampleUpToDate(t *testing.T) { r := wazero.NewRuntimeWithConfig(wazero.NewRuntimeConfig().WithFinishedFeatures()) // Add WASI to satisfy import tests - wm, err := wasi.InstantiateSnapshotPreview1(r) + wm, err := wasi.InstantiateSnapshotPreview1(testCtx, r) require.NoError(t, err) defer wm.Close() // Decode and instantiate the module - module, err := r.InstantiateModuleFromCode(exampleBinary) + module, err := r.InstantiateModuleFromCode(testCtx, exampleBinary) require.NoError(t, err) defer module.Close() // Call the swap function as a smoke test - results, err := module.ExportedFunction("swap").Call(nil, 1, 2) + results, err := module.ExportedFunction("swap").Call(testCtx, 1, 2) require.NoError(t, err) require.Equal(t, []uint64{2, 1}, results) }) diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 6071b29b..5e868c72 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -29,6 +29,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasmdebug" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + type EngineTester interface { NewEngine(enabledFeatures wasm.Features) wasm.Engine InitTable(me wasm.ModuleEngine, initTableLen uint32, initTableIdxToFnIdx map[wasm.Index]wasm.Index) []interface{} @@ -44,7 +47,7 @@ func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { t.Run("sets module name", func(t *testing.T) { m := &wasm.Module{} - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) me, err := e.NewModuleEngine(t.Name(), m, nil, nil, nil, nil) require.NoError(t, err) @@ -76,7 +79,7 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { CodeSection: []*wasm.Code{{Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) // To use the function, we first need to add it to a module. @@ -91,17 +94,17 @@ func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { linkModuleToEngine(module, me) // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. - results, err := me.Call(module.CallCtx, fn, 3) + results, err := me.Call(testCtx, module.CallCtx, fn, 3) require.NoError(t, err) require.Equal(t, uint64(3), results[0]) t.Run("errs when not enough parameters", func(t *testing.T) { - _, err := me.Call(module.CallCtx, fn) + _, err := me.Call(testCtx, module.CallCtx, fn) require.EqualError(t, err, "expected 1 params, but passed 0") }) t.Run("errs when too many parameters", func(t *testing.T) { - _, err := me.Call(module.CallCtx, fn, 1, 2) + _, err := me.Call(testCtx, module.CallCtx, fn, 1, 2) require.EqualError(t, err, "expected 1 params, but passed 2") }) } @@ -117,7 +120,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { CodeSection: []*wasm.Code{}, ID: wasm.ModuleID{0}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) // Instantiate the module, which has nothing but an empty table. @@ -139,7 +142,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{1}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) moduleFunctions := []*wasm.FunctionInstance{ @@ -170,7 +173,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{2}, } - err := e.CompileModule(importedModule) + err := e.CompileModule(testCtx, importedModule) require.NoError(t, err) importedModuleInstance := &wasm.ModuleInstance{} @@ -194,7 +197,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { CodeSection: []*wasm.Code{}, ID: wasm.ModuleID{3}, } - err = e.CompileModule(importingModule) + err = e.CompileModule(testCtx, importingModule) require.NoError(t, err) tableInit := map[wasm.Index]wasm.Index{0: 2} @@ -217,7 +220,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{4}, } - err := e.CompileModule(importedModule) + err := e.CompileModule(testCtx, importedModule) require.NoError(t, err) importedModuleInstance := &wasm.ModuleInstance{} importedFunctions := []*wasm.FunctionInstance{ @@ -241,7 +244,7 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { ID: wasm.ModuleID{5}, } - err = e.CompileModule(importingModule) + err = e.CompileModule(testCtx, importingModule) require.NoError(t, err) importingModuleInstance := &wasm.ModuleInstance{} @@ -284,11 +287,11 @@ func runTestModuleEngine_Call_HostFn_CallContext(t *testing.T, et EngineTester) TypeSection: []*wasm.FunctionType{sig}, } - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) module := &wasm.ModuleInstance{Memory: memory} - modCtx := wasm.NewCallContext(context.Background(), wasm.NewStore(features, e), module, nil) + modCtx := wasm.NewCallContext(wasm.NewStore(features, e), module, nil) f := &wasm.FunctionInstance{ GoFunc: &hostFn, @@ -303,7 +306,7 @@ func runTestModuleEngine_Call_HostFn_CallContext(t *testing.T, et EngineTester) t.Run("defaults to module memory when call stack empty", func(t *testing.T) { // When calling a host func directly, there may be no stack. This ensures the module's memory is used. - results, err := me.Call(modCtx, f, 3) + results, err := me.Call(testCtx, modCtx, f, 3) require.NoError(t, err) require.Equal(t, uint64(3), results[0]) require.Same(t, memory, ctxMemory) @@ -351,7 +354,7 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { t.Run(tc.name, func(t *testing.T) { m := tc.module f := tc.fn - results, err := f.Module.Engine.Call(m, f, 1) + results, err := f.Module.Engine.Call(testCtx, m, f, 1) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -475,11 +478,11 @@ wasm stack trace: t.Run(tc.name, func(t *testing.T) { m := tc.module f := tc.fn - _, err := f.Module.Engine.Call(m, f, tc.input...) + _, err := f.Module.Engine.Call(testCtx, m, f, tc.input...) require.EqualError(t, err, tc.expectedErr) // Ensure the module still works - results, err := f.Module.Engine.Call(m, f, 1) + results, err := f.Module.Engine.Call(testCtx, m, f, 1) require.NoError(t, err) require.Equal(t, uint64(1), results[0]) }) @@ -515,7 +518,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo ID: wasm.ModuleID{0}, } - err := e.CompileModule(hostFnModule) + err := e.CompileModule(testCtx, hostFnModule) require.NoError(t, err) hostFn := &wasm.FunctionInstance{GoFunc: &hostFnVal, Kind: wasm.FunctionKindGoNoContext, Type: ft} hostFnModuleInstance := &wasm.ModuleInstance{Name: "host"} @@ -536,7 +539,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo ID: wasm.ModuleID{1}, } - err = e.CompileModule(importedModule) + err = e.CompileModule(testCtx, importedModule) require.NoError(t, err) // To use the function, we first need to add it to a module. @@ -560,7 +563,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo ImportSection: []*wasm.Import{{}}, ID: wasm.ModuleID{2}, } - err = e.CompileModule(importingModule) + err = e.CompileModule(testCtx, importingModule) require.NoError(t, err) // Add the exported function. @@ -592,7 +595,7 @@ func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.Mo func linkModuleToEngine(module *wasm.ModuleInstance, me wasm.ModuleEngine) { module.Engine = me // for JIT, links the module to the module-engine compiled from it (moduleInstanceEngineOffset). // callEngineModuleContextModuleInstanceAddressOffset - module.CallCtx = wasm.NewCallContext(context.Background(), nil, module, nil) + module.CallCtx = wasm.NewCallContext(nil, module, nil) } // addFunction assigns and adds a function to the module. diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index ddc09ec7..034130f9 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -12,9 +12,9 @@ import ( // compile time check to ensure CallContext implements api.Module var _ api.Module = &CallContext{} -func NewCallContext(ctx context.Context, store *Store, instance *ModuleInstance, Sys *SysContext) *CallContext { +func NewCallContext(store *Store, instance *ModuleInstance, Sys *SysContext) *CallContext { zero := uint64(0) - return &CallContext{ctx: ctx, memory: instance.Memory, module: instance, store: store, Sys: Sys, closed: &zero} + return &CallContext{memory: instance.Memory, module: instance, store: store, Sys: Sys, closed: &zero} } // CallContext is a function call context bound to a module. This is important as one module's functions can call @@ -24,8 +24,11 @@ func NewCallContext(ctx context.Context, store *Store, instance *ModuleInstance, // functionality like trace propagation. // Note: this also implements api.Module in order to simplify usage as a host function parameter. type CallContext struct { - // ctx is returned by Context and overridden WithContext - ctx context.Context // TODO: remove in next PR + // TODO: We've never found a great name for this. It is only used for function calls, hence CallContext, but it + // moves on a different axis than, for example, the context.Context. context.Context is the same root for the whole + // call stack, where the CallContext can change depending on where memory is defined and who defines the calling + // function. When we rename this again, we should try to capture as many key points possible on the docs. + module *ModuleInstance // memory is returned by Memory and overridden WithMemory memory api.Memory @@ -60,7 +63,7 @@ func (m *CallContext) Name() string { // WithMemory allows overriding memory without re-allocation when the result would be the same. func (m *CallContext) WithMemory(memory *MemoryInstance) *CallContext { if memory != nil && memory != m.memory { // only re-allocate if it will change the effective memory - return &CallContext{module: m.module, memory: memory, ctx: m.ctx, Sys: m.Sys, closed: m.closed} + return &CallContext{module: m.module, memory: memory, Sys: m.Sys, closed: m.closed} } return m } @@ -70,19 +73,6 @@ func (m *CallContext) String() string { return fmt.Sprintf("Module[%s]", m.Name()) } -// Context implements the same method as documented on api.Module -func (m *CallContext) Context() context.Context { - return m.ctx -} - -// WithContext implements the same method as documented on api.Module -func (m *CallContext) WithContext(ctx context.Context) api.Module { - if ctx != nil && ctx != m.ctx { // only re-allocate if it will change the effective context - return &CallContext{module: m.module, memory: m.memory, ctx: ctx, Sys: m.Sys, closed: m.closed} - } - return m -} - // Close implements the same method as documented on api.Module. func (m *CallContext) Close() (err error) { return m.CloseWithExitCode(0) @@ -145,11 +135,12 @@ func (f *importedFn) ResultTypes() []api.ValueType { } // Call implements the same method as documented on api.Function -func (f *importedFn) Call(m api.Module, params ...uint64) (ret []uint64, err error) { - if m == nil { - return f.importedFn.Call(f.importingModule, params...) +func (f *importedFn) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { + if ctx == nil { + ctx = context.Background() } - return f.importedFn.Call(m, params...) + mod := f.importingModule + return f.importedFn.Module.Engine.Call(ctx, mod, f.importedFn, params...) } // ParamTypes implements the same method as documented on api.Function @@ -163,15 +154,12 @@ func (f *FunctionInstance) ResultTypes() []api.ValueType { } // Call implements the same method as documented on api.Function -func (f *FunctionInstance) Call(m api.Module, params ...uint64) (ret []uint64, err error) { - mod := f.Module - modCtx, ok := m.(*CallContext) - if ok { - // TODO: check if the importing context is correct - } else { // allow nil to substitute for the defining module - modCtx = mod.CallCtx +func (f *FunctionInstance) Call(ctx context.Context, params ...uint64) (ret []uint64, err error) { + if ctx == nil { + ctx = context.Background() } - return mod.Engine.Call(modCtx, f, params...) + mod := f.Module + return mod.Engine.Call(ctx, mod.CallCtx, f, params...) } // ExportedGlobal implements api.Module ExportedGlobal diff --git a/internal/wasm/call_context_test.go b/internal/wasm/call_context_test.go index bfd12938..3cc9e8ca 100644 --- a/internal/wasm/call_context_test.go +++ b/internal/wasm/call_context_test.go @@ -8,55 +8,6 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) -func TestCallContext_WithContext(t *testing.T) { - type key string - tests := []struct { - name string - mod *CallContext - ctx context.Context - expectSame bool - }{ - { - name: "nil->nil: same", - mod: &CallContext{}, - ctx: nil, - expectSame: true, - }, - { - name: "nil->ctx: not same", - mod: &CallContext{}, - ctx: context.WithValue(context.Background(), key("a"), "b"), - expectSame: false, - }, - { - name: "ctx->nil: same", - mod: &CallContext{ctx: context.Background()}, - ctx: nil, - expectSame: true, - }, - { - name: "ctx1->ctx2: not same", - mod: &CallContext{ctx: context.Background()}, - ctx: context.WithValue(context.Background(), key("a"), "b"), - expectSame: false, - }, - } - - for _, tt := range tests { - tc := tt - - t.Run(tc.name, func(t *testing.T) { - mod2 := tc.mod.WithContext(tc.ctx) - if tc.expectSame { - require.Same(t, tc.mod, mod2) - } else { - require.NotSame(t, tc.mod, mod2) - require.Equal(t, tc.ctx, mod2.Context()) - } - }) - } -} - func TestCallContext_WithMemory(t *testing.T) { tests := []struct { name string diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index 80e1a428..79b72ae5 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -1,10 +1,12 @@ package wasm +import "context" + // Engine is a Store-scoped mechanism to compile functions declared or imported by a module. // This is a top-level type implemented by an interpreter or JIT compiler. type Engine interface { // CompileModule implements the same method as documented on wasm.Engine. - CompileModule(module *Module) error + CompileModule(ctx context.Context, module *Module) error // NewModuleEngine compiles down the function instances in a module, and returns ModuleEngine for the module. // @@ -17,7 +19,13 @@ type Engine interface { // // Note: Input parameters must be pre-validated with wasm.Module Validate, to ensure no fields are invalid // due to reasons such as out-of-bounds. - NewModuleEngine(name string, module *Module, importedFunctions, moduleFunctions []*FunctionInstance, table *TableInstance, tableInit map[Index]Index) (ModuleEngine, error) + NewModuleEngine( + name string, + module *Module, + importedFunctions, moduleFunctions []*FunctionInstance, + table *TableInstance, + tableInit map[Index]Index, + ) (ModuleEngine, error) // DeleteCompiledModule releases compilation caches for the given module (source). // Note: it is safe to call this function for a module from which module instances are instantiated even when these module instances @@ -31,7 +39,5 @@ type ModuleEngine interface { Name() string // Call invokes a function instance f with given parameters. - // Returns the results from the function. - // The ctx's context.Context will be the outer-most ancestor of the argument to api.Function. - Call(ctx *CallContext, f *FunctionInstance, params ...uint64) (results []uint64, err error) + Call(ctx context.Context, m *CallContext, f *FunctionInstance, params ...uint64) (results []uint64, err error) } diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index f3331054..a6099635 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -21,8 +21,11 @@ const ( // a context.Context. FunctionKindGoContext // FunctionKindGoModule is a function implemented in Go, with a signature matching FunctionType, except arg - // zero is a Module. + // zero is an api.Module. FunctionKindGoModule + // FunctionKindGoContextModule is a function implemented in Go, with a signature matching FunctionType, except arg + // zero is a context.Context and arg one is an api.Module. + FunctionKindGoContextModule ) // Below are reflection code to get the interface type used to parse functions and set values. @@ -31,22 +34,6 @@ var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem() var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() -// getGoFuncCallContextValue returns a reflect.Value for a context param[0], or nil if there isn't one. -func getGoFuncCallContextValue(fk FunctionKind, ctx *CallContext) *reflect.Value { - switch fk { - case FunctionKindGoNoContext: // no special param zero - case FunctionKindGoContext: - val := reflect.New(goContextType).Elem() - val.Set(reflect.ValueOf(ctx.Context())) - return &val - case FunctionKindGoModule: - val := reflect.New(moduleType).Elem() - val.Set(reflect.ValueOf(ctx)) - return &val - } - return nil -} - // PopGoFuncParams pops the correct number of parameters off the stack into a parameter slice for use in CallGoFunc // // For example, if the host function F requires the (x1 uint32, x2 float32) parameters, and @@ -55,7 +42,11 @@ func getGoFuncCallContextValue(fk FunctionKind, ctx *CallContext) *reflect.Value func PopGoFuncParams(f *FunctionInstance, popParam func() uint64) []uint64 { // First, determine how many values we need to pop paramCount := f.GoFunc.Type().NumIn() - if f.Kind != FunctionKindGoNoContext { + switch f.Kind { + case FunctionKindGoNoContext: + case FunctionKindGoContextModule: + paramCount -= 2 + default: paramCount-- } @@ -82,22 +73,30 @@ func PopValues(count int, popper func() uint64) []uint64 { // * callCtx is passed to the host function as a first argument. // // Note: ctx must use the caller's memory, which might be different from the defining module on an imported function. -func CallGoFunc(callCtx *CallContext, f *FunctionInstance, params []uint64) []uint64 { +func CallGoFunc(ctx context.Context, callCtx *CallContext, f *FunctionInstance, params []uint64) []uint64 { tp := f.GoFunc.Type() var in []reflect.Value if tp.NumIn() != 0 { in = make([]reflect.Value, tp.NumIn()) - wasmParamOffset := 0 - if f.Kind != FunctionKindGoNoContext { - wasmParamOffset = 1 + i := 0 + switch f.Kind { + case FunctionKindGoContext: + in[0] = newContextVal(ctx) + i = 1 + case FunctionKindGoModule: + in[0] = newModuleVal(callCtx) + i = 1 + case FunctionKindGoContextModule: + in[0] = newContextVal(ctx) + in[1] = newModuleVal(callCtx) + i = 2 } - for i, raw := range params { - inI := i + wasmParamOffset - val := reflect.New(tp.In(inI)).Elem() - switch tp.In(inI).Kind() { + for _, raw := range params { + val := reflect.New(tp.In(i)).Elem() + switch tp.In(i).Kind() { case reflect.Float32: val.SetFloat(float64(math.Float32frombits(uint32(raw)))) case reflect.Float64: @@ -107,12 +106,8 @@ func CallGoFunc(callCtx *CallContext, f *FunctionInstance, params []uint64) []ui case reflect.Int32, reflect.Int64: val.SetInt(int64(raw)) } - in[inI] = val - } - - // Handle any special parameter zero - if val := getGoFuncCallContextValue(f.Kind, callCtx); val != nil { - in[0] = *val + in[i] = val + i++ } } @@ -138,6 +133,18 @@ func CallGoFunc(callCtx *CallContext, f *FunctionInstance, params []uint64) []ui return results } +func newContextVal(ctx context.Context) reflect.Value { + val := reflect.New(goContextType).Elem() + val.Set(reflect.ValueOf(ctx)) + return val +} + +func newModuleVal(m api.Module) reflect.Value { + val := reflect.New(moduleType).Elem() + val.Set(reflect.ValueOf(m)) + return val +} + // getFunctionType returns the function type corresponding to the function signature or errs if invalid. func getFunctionType(fn *reflect.Value, enabledFeatures Features) (fk FunctionKind, ft *FunctionType, err error) { p := fn.Type() @@ -147,8 +154,13 @@ func getFunctionType(fn *reflect.Value, enabledFeatures Features) (fk FunctionKi return } + fk = kind(p) pOffset := 0 - if fk = kind(p); fk != FunctionKindGoNoContext { + switch fk { + case FunctionKindGoNoContext: + case FunctionKindGoContextModule: + pOffset = 2 + default: pOffset = 1 } @@ -212,6 +224,9 @@ func kind(p reflect.Type) FunctionKind { if p0.Implements(moduleType) { return FunctionKindGoModule } else if p0.Implements(goContextType) { + if pCount >= 2 && p.In(1).Implements(moduleType) { + return FunctionKindGoContextModule + } return FunctionKindGoContext } } diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index 24b07926..fe629047 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -10,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestGetFunctionType(t *testing.T) { var tests = []struct { name string @@ -35,6 +38,12 @@ func TestGetFunctionType(t *testing.T) { expectedKind: FunctionKindGoContext, expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, }, + { + name: "context.Context and api.Module void return", + inputFunc: func(context.Context, api.Module) {}, + expectedKind: FunctionKindGoContextModule, + expectedType: &FunctionType{Params: []ValueType{}, Results: []ValueType{}}, + }, { name: "all supported params and i32 result", inputFunc: func(uint32, uint64, float32, float64) uint32 { return 0 }, @@ -59,6 +68,12 @@ func TestGetFunctionType(t *testing.T) { expectedKind: FunctionKindGoContext, expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, }, + { + name: "all supported params and i32 result - context.Context and api.Module", + inputFunc: func(context.Context, api.Module, uint32, uint64, float32, float64) uint32 { return 0 }, + expectedKind: FunctionKindGoContextModule, + expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64}, Results: []ValueType{i32}}, + }, } for _, tt := range tests { tc := tt @@ -201,6 +216,10 @@ func TestPopGoFuncParams(t *testing.T) { name: "context.Context", inputFunc: func(context.Context) {}, }, + { + name: "context.Context and api.Module", + inputFunc: func(context.Context, api.Module) {}, + }, { name: "all supported params", inputFunc: func(uint32, uint64, float32, float64) {}, @@ -216,6 +235,11 @@ func TestPopGoFuncParams(t *testing.T) { inputFunc: func(context.Context, uint32, uint64, float32, float64) {}, expected: []uint64{4, 5, 6, 7}, }, + { + name: "all supported params - context.Context and api.Module", + inputFunc: func(context.Context, api.Module, uint32, uint64, float32, float64) {}, + expected: []uint64{4, 5, 6, 7}, + }, } for _, tt := range tests { @@ -233,9 +257,7 @@ func TestPopGoFuncParams(t *testing.T) { } func TestCallGoFunc(t *testing.T) { - expectedCtx, cancel := context.WithCancel(context.Background()) // arbitrary non-default context - defer cancel() - callCtx := &CallContext{ctx: expectedCtx} + callCtx := &CallContext{} var tests = []struct { name string @@ -255,7 +277,14 @@ func TestCallGoFunc(t *testing.T) { { name: "context.Context void return", inputFunc: func(ctx context.Context) { - require.Equal(t, expectedCtx, ctx) + require.Equal(t, testCtx, ctx) + }, + }, + { + name: "context.Context and api.Module void return", + inputFunc: func(ctx context.Context, m api.Module) { + require.Equal(t, testCtx, ctx) + require.Equal(t, callCtx, m) }, }, { @@ -318,7 +347,26 @@ func TestCallGoFunc(t *testing.T) { { name: "all supported params and i32 result - context.Context", inputFunc: func(ctx context.Context, w uint32, x uint64, y float32, z float64) uint32 { - require.Equal(t, expectedCtx, ctx) + require.Equal(t, testCtx, ctx) + require.Equal(t, uint32(math.MaxUint32), w) + require.Equal(t, uint64(math.MaxUint64), x) + require.Equal(t, float32(math.MaxFloat32), y) + require.Equal(t, math.MaxFloat64, z) + return 100 + }, + inputParams: []uint64{ + math.MaxUint32, + math.MaxUint64, + api.EncodeF32(math.MaxFloat32), + api.EncodeF64(math.MaxFloat64), + }, + expectedResults: []uint64{100}, + }, + { + name: "all supported params and i32 result - context.Context and api.Module", + inputFunc: func(ctx context.Context, m api.Module, w uint32, x uint64, y float32, z float64) uint32 { + require.Equal(t, testCtx, ctx) + require.Equal(t, callCtx, m) require.Equal(t, uint32(math.MaxUint32), w) require.Equal(t, uint64(math.MaxUint64), x) require.Equal(t, float32(math.MaxFloat32), y) @@ -342,7 +390,7 @@ func TestCallGoFunc(t *testing.T) { fk, _, err := getFunctionType(&goFunc, FeaturesFinished) require.NoError(t, err) - results := CallGoFunc(callCtx, &FunctionInstance{Kind: fk, GoFunc: &goFunc}, tc.inputParams) + results := CallGoFunc(testCtx, callCtx, &FunctionInstance{Kind: fk, GoFunc: &goFunc}, tc.inputParams) require.Equal(t, tc.expectedResults, results) }) } diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index 52c72105..e473ff19 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -1,6 +1,7 @@ package interpreter import ( + "context" "encoding/binary" "fmt" "math" @@ -172,7 +173,7 @@ type interpreterOp struct { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *engine) CompileModule(module *wasm.Module) error { +func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { if _, ok := e.getCodes(module); ok { // cache hit! return nil } @@ -186,7 +187,7 @@ func (e *engine) CompileModule(module *wasm.Module) error { funcs = append(funcs, &code{hostFn: hf}) } } else { - irs, err := wazeroir.CompileFunctions(e.enabledFeatures, module) + irs, err := wazeroir.CompileFunctions(ctx, e.enabledFeatures, module) if err != nil { return err } @@ -513,7 +514,7 @@ func (me *moduleEngine) Name() string { } // Call implements the same method as documented on wasm.ModuleEngine. -func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { +func (me *moduleEngine) Call(ctx context.Context, m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to // code on close aren't locked, neither is this read. compiled := me.functions[f.Index] @@ -554,27 +555,27 @@ func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, para for _, param := range params { ce.pushValue(param) } - ce.callNativeFunc(m, compiled) + ce.callNativeFunc(ctx, m, compiled) results = wasm.PopValues(len(f.Type.Results), ce.popValue) } else { - results = ce.callGoFunc(m, compiled, params) + results = ce.callGoFunc(ctx, m, compiled, params) } return } -func (ce *callEngine) callGoFunc(ctx *wasm.CallContext, f *function, params []uint64) (results []uint64) { +func (ce *callEngine) callGoFunc(ctx context.Context, callCtx *wasm.CallContext, f *function, params []uint64) (results []uint64) { if len(ce.frames) > 0 { // Use the caller's memory, which might be different from the defining module on an imported function. - ctx = ctx.WithMemory(ce.frames[len(ce.frames)-1].f.source.Module.Memory) + callCtx = callCtx.WithMemory(ce.frames[len(ce.frames)-1].f.source.Module.Memory) } frame := &callFrame{f: f} ce.pushFrame(frame) - results = wasm.CallGoFunc(ctx, f.source, params) + results = wasm.CallGoFunc(ctx, callCtx, f.source, params) ce.popFrame() return } -func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { +func (ce *callEngine) callNativeFunc(ctx context.Context, callCtx *wasm.CallContext, f *function) { frame := &callFrame{f: f} moduleInst := f.source.Module memoryInst := moduleInst.Memory @@ -621,9 +622,9 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { { f := functions[op.us[0]] if f.hostFn != nil { - ce.callGoFuncWithStack(ctx, f) + ce.callGoFuncWithStack(ctx, callCtx, f) } else { - ce.callNativeFunc(ctx, f) + ce.callNativeFunc(ctx, callCtx, f) } frame.pc++ } @@ -642,9 +643,9 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { // Call in. if targetcode.hostFn != nil { - ce.callGoFuncWithStack(ctx, f) + ce.callGoFuncWithStack(ctx, callCtx, f) } else { - ce.callNativeFunc(ctx, targetcode) + ce.callNativeFunc(ctx, callCtx, targetcode) } frame.pc++ } @@ -1555,9 +1556,9 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.CallContext, f *function) { ce.popFrame() } -func (ce *callEngine) callGoFuncWithStack(ctx *wasm.CallContext, f *function) { +func (ce *callEngine) callGoFuncWithStack(ctx context.Context, callCtx *wasm.CallContext, f *function) { params := wasm.PopGoFuncParams(f.source, ce.popValue) - results := ce.callGoFunc(ctx, f, params) + results := ce.callGoFunc(ctx, callCtx, f, params) for _, v := range results { ce.pushValue(v) } diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index 0c09ae6c..51e0f656 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -1,6 +1,7 @@ package interpreter import ( + "context" "fmt" "math" "testing" @@ -12,6 +13,9 @@ import ( "github.com/tetratelabs/wazero/internal/wazeroir" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestInterpreter_CallEngine_PushFrame(t *testing.T) { f1 := &callFrame{} f2 := &callFrame{} @@ -135,7 +139,7 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { {kind: wazeroir.OperationKindBr, us: []uint64{math.MaxUint64}}, }, } - ce.callNativeFunc(&wasm.CallContext{}, f) + ce.callNativeFunc(testCtx, &wasm.CallContext{}, f) require.Equal(t, tc.expected, int32(uint32(ce.popValue()))) }) } @@ -187,7 +191,7 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { {kind: wazeroir.OperationKindBr, us: []uint64{math.MaxUint64}}, }, } - ce.callNativeFunc(&wasm.CallContext{}, f) + ce.callNativeFunc(testCtx, &wasm.CallContext{}, f) require.Equal(t, tc.expected, int64(ce.popValue())) }) } @@ -220,7 +224,7 @@ func TestInterpreter_Compile(t *testing.T) { ID: wasm.ModuleID{}, } - err := e.CompileModule(errModule) + err := e.CompileModule(testCtx, errModule) require.EqualError(t, err, "failed to lower func[2/3] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, all the compiled functions including succeeded ones must be released. @@ -241,7 +245,7 @@ func TestInterpreter_Compile(t *testing.T) { }, ID: wasm.ModuleID{}, } - err := e.CompileModule(okModule) + err := e.CompileModule(testCtx, okModule) require.NoError(t, err) compiled, ok := e.codes[okModule.ID] diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index a922261e..1ca3af89 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -1,6 +1,7 @@ package jit import ( + "context" "fmt" "reflect" "runtime" @@ -42,8 +43,8 @@ type ( // function calls originating from the same moduleEngine.Call execution. callEngine struct { // These contexts are read and written by JITed code. - // Note: we embed these structs so we can reduce the costs to access fields inside of them. - // Also, that eases the calculation of offsets to each field. + // Note: structs are embedded to reduce the costs to access fields inside them. Also, this eases field offset + // calculation. globalContext moduleContext valueStackContext @@ -393,7 +394,7 @@ func (e *engine) DeleteCompiledModule(module *wasm.Module) { } // CompileModule implements the same method as documented on wasm.Engine. -func (e *engine) CompileModule(module *wasm.Module) error { +func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error { if _, ok := e.getCodes(module); ok { // cache hit! return nil } @@ -416,7 +417,7 @@ func (e *engine) CompileModule(module *wasm.Module) error { funcs = append(funcs, compiled) } } else { - irs, err := wazeroir.CompileFunctions(e.enabledFeatures, module) + irs, err := wazeroir.CompileFunctions(ctx, e.enabledFeatures, module) if err != nil { return err } @@ -496,12 +497,12 @@ func (me *moduleEngine) Name() string { } // Call implements the same method as documented on wasm.ModuleEngine. -func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { +func (me *moduleEngine) Call(ctx context.Context, callCtx *wasm.CallContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to // code on close aren't locked, neither is this read. compiled := me.functions[f.Index] if compiled == nil { // Lazy check the cause as it could be because the module was already closed. - if err = m.FailIfClosed(); err == nil { + if err = callCtx.FailIfClosed(); err == nil { panic(fmt.Errorf("BUG: %s.func[%d] was nil before close", me.name, f.Index)) } return @@ -522,7 +523,7 @@ func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, para defer func() { // If the module closed during the call, and the call didn't err for another reason, set an ExitError. if err == nil { - err = m.FailIfClosed() + err = callCtx.FailIfClosed() } // TODO: ^^ Will not fail if the function was imported from a closed module. @@ -545,10 +546,10 @@ func (me *moduleEngine) Call(m *wasm.CallContext, f *wasm.FunctionInstance, para for _, v := range params { ce.pushValue(v) } - ce.execWasmFunction(m, compiled) + ce.execWasmFunction(ctx, callCtx, compiled) results = wasm.PopValues(len(f.Type.Results), ce.popValue) } else { - results = wasm.CallGoFunc(m, compiled.source, params) + results = wasm.CallGoFunc(ctx, callCtx, compiled.source, params) } return } @@ -648,7 +649,7 @@ const ( builtinFunctionIndexBreakPoint ) -func (ce *callEngine) execWasmFunction(ctx *wasm.CallContext, f *function) { +func (ce *callEngine) execWasmFunction(ctx context.Context, callCtx *wasm.CallContext, f *function) { // Push the initial callframe. ce.callFrameStack[0] = callFrame{returnAddress: f.codeInitialAddress, function: f} ce.globalContext.callFrameStackPointer++ @@ -675,8 +676,9 @@ jitentry: callerFunction := ce.callFrameAt(1).function params := wasm.PopGoFuncParams(calleeHostFunction.source, ce.popValue) results := wasm.CallGoFunc( + ctx, // Use the caller's memory, which might be different from the defining module on an imported function. - ctx.WithMemory(callerFunction.source.Module.Memory), + callCtx.WithMemory(callerFunction.source.Module.Memory), calleeHostFunction.source, params, ) diff --git a/internal/wasm/jit/engine_test.go b/internal/wasm/jit/engine_test.go index 4f277180..b6353548 100644 --- a/internal/wasm/jit/engine_test.go +++ b/internal/wasm/jit/engine_test.go @@ -13,6 +13,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasm" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + // Ensures that the offset consts do not drift when we manipulate the target structs. func TestJIT_VerifyOffsetValue(t *testing.T) { var me moduleEngine @@ -176,11 +179,11 @@ func TestJIT_CompileModule(t *testing.T) { ID: wasm.ModuleID{}, } - err := e.CompileModule(okModule) + err := e.CompileModule(testCtx, okModule) require.NoError(t, err) // Compiling same module shouldn't be compiled again, but instead should be cached. - err = e.CompileModule(okModule) + err = e.CompileModule(testCtx, okModule) require.NoError(t, err) compiled, ok := e.codes[okModule.ID] @@ -206,7 +209,7 @@ func TestJIT_CompileModule(t *testing.T) { } e := et.NewEngine(wasm.Features20191205).(*engine) - err := e.CompileModule(errModule) + err := e.CompileModule(testCtx, errModule) require.EqualError(t, err, "failed to lower func[2/3] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, the compiled functions must not be cached. @@ -215,7 +218,7 @@ func TestJIT_CompileModule(t *testing.T) { }) } -// TestReleasecode_Panic tests that an unexpected panic has some identifying information in it. +// TestJIT_Releasecode_Panic tests that an unexpected panic has some identifying information in it. func TestJIT_Releasecode_Panic(t *testing.T) { captured := require.CapturePanic(func() { releaseCode(&code{ @@ -256,10 +259,10 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { }}, map[string]*wasm.Memory{}, map[string]*wasm.Global{}, enabledFeatures) require.NoError(t, err) - err = store.Engine.CompileModule(hm) + err = store.Engine.CompileModule(testCtx, hm) require.NoError(t, err) - _, err = store.Instantiate(context.Background(), hm, hostModuleName, nil) + _, err = store.Instantiate(testCtx, hm, hostModuleName, nil) require.NoError(t, err) const valueStackCorruption = "value_stack_corruption" @@ -311,16 +314,16 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { ID: wasm.ModuleID{1}, } - err = store.Engine.CompileModule(m) + err = store.Engine.CompileModule(testCtx, m) require.NoError(t, err) - mi, err := store.Instantiate(context.Background(), m, t.Name(), nil) + mi, err := store.Instantiate(testCtx, m, t.Name(), nil) require.NoError(t, err) for _, fnName := range []string{valueStackCorruption, callStackCorruption} { fnName := fnName t.Run(fnName, func(t *testing.T) { - ret, err := mi.ExportedFunction(fnName).Call(nil) + ret, err := mi.ExportedFunction(fnName).Call(testCtx) require.NoError(t, err) require.Equal(t, uint32(expectedReturnValue), uint32(ret[0])) diff --git a/internal/wasm/jit/jit_impl_arm64.go b/internal/wasm/jit/jit_impl_arm64.go index 5198a9a8..1c6994e0 100644 --- a/internal/wasm/jit/jit_impl_arm64.go +++ b/internal/wasm/jit/jit_impl_arm64.go @@ -3056,8 +3056,8 @@ func (c *arm64Compiler) compileReservedMemoryRegisterInitialization() { } } -// compileModuleContextInitialization adds instructions to initialize ce.CallContext's fields based on -// ce.CallContext.ModuleInstanceAddress. +// compileModuleContextInitialization adds instructions to initialize ce.moduleContext's fields based on +// ce.moduleContext.ModuleInstanceAddress. // This is called in two cases: in function preamble, and on the return from (non-Go) function calls. func (c *arm64Compiler) compileModuleContextInitialization() error { c.markRegisterUsed(arm64CallingConventionModuleInstanceAddressRegister) diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 7de5a42e..39ac5949 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -245,6 +245,10 @@ func NewStore(enabledFeatures Features, engine Engine) *Store { // // Note: Module.Validate must be called prior to instantiation. func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sys *SysContext) (*CallContext, error) { + if ctx == nil { + ctx = context.Background() + } + if err := s.requireModuleName(name); err != nil { return nil, err } @@ -298,13 +302,13 @@ func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sy m.applyData(module.DataSection) // Build the default context for calls to this module. - m.CallCtx = NewCallContext(ctx, s, m, sys) + m.CallCtx = NewCallContext(s, m, sys) // Execute the start function. if module.StartSection != nil { funcIdx := *module.StartSection f := m.Functions[funcIdx] - if _, err = f.Module.Engine.Call(m.CallCtx, f); err != nil { + if _, err = f.Module.Engine.Call(ctx, m.CallCtx, f); err != nil { s.deleteModule(name) return nil, fmt.Errorf("start %s failed: %w", module.funcDesc(funcSection, funcIdx), err) } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 3db49f9c..31714a07 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -99,15 +99,12 @@ func TestStore_Instantiate(t *testing.T) { ) require.NoError(t, err) - type key string - ctx := context.WithValue(context.Background(), key("a"), "b") // arbitrary non-default context sys := &SysContext{} - mod, err := s.Instantiate(ctx, m, "", sys) + mod, err := s.Instantiate(testCtx, m, "", sys) require.NoError(t, err) defer mod.Close() t.Run("CallContext defaults", func(t *testing.T) { - require.Equal(t, ctx, mod.ctx) require.Equal(t, s.modules[""], mod.module) require.Equal(t, s.modules[""].Memory, mod.memory) require.Equal(t, s, mod.store) @@ -369,58 +366,6 @@ func TestCallContext_ExportedFunction(t *testing.T) { }) } -func TestFunctionInstance_Call(t *testing.T) { - store := NewStore(Features20191205, &mockEngine{shouldCompileFail: false, callFailIndex: -1}) - - // Add the host module - functionName := "fn" - - // This is a fake engine, so we don't capture inside the function body. - m, err := NewHostModule( - "host", - map[string]interface{}{functionName: func(api.Module) {}}, - map[string]*Memory{}, - map[string]*Global{}, - Features20191205, - ) - require.NoError(t, err) - - // Add the host module - imported, err := store.Instantiate(context.Background(), m, "host", nil) - require.NoError(t, err) - defer imported.Close() - - // Make a module to import the function - importing, err := store.Instantiate(context.Background(), &Module{ - TypeSection: []*FunctionType{{}}, - ImportSection: []*Import{{ - Type: ExternTypeFunc, - Module: imported.Name(), - Name: functionName, - DescFunc: 0, - }}, - MemorySection: &Memory{Min: 1}, - ExportSection: []*Export{{Type: ExternTypeFunc, Name: functionName, Index: 0}}, - }, "test", nil) - require.NoError(t, err) - defer imported.Close() - - fn := importing.ExportedFunction(functionName) - // me.ctx will hold the last seen context. - me := imported.module.Engine.(*mockModuleEngine) - t.Run("nil defaults to current module", func(t *testing.T) { - _, err := fn.Call(nil) - require.NoError(t, err) - require.Equal(t, importing, me.ctx) - }) - t.Run("override current module context", func(t *testing.T) { - ctx := importing.WithContext(context.TODO()) - _, err := fn.Call(ctx) - require.NoError(t, err) - require.Equal(t, ctx, me.ctx) - }) -} - type mockEngine struct { shouldCompileFail bool callFailIndex int @@ -428,7 +373,6 @@ type mockEngine struct { type mockModuleEngine struct { name string - ctx *CallContext callFailIndex int } @@ -448,7 +392,7 @@ func (e *mockEngine) NewModuleEngine(_ string, _ *Module, _, _ []*FunctionInstan func (e *mockEngine) DeleteCompiledModule(*Module) {} // CompileModule implements the same method as documented on wasm.Engine. -func (e *mockEngine) CompileModule(module *Module) error { return nil } +func (e *mockEngine) CompileModule(_ context.Context, _ *Module) error { return nil } // Name implements the same method as documented on wasm.ModuleEngine. func (e *mockModuleEngine) Name() string { @@ -456,12 +400,11 @@ func (e *mockModuleEngine) Name() string { } // Call implements the same method as documented on wasm.ModuleEngine. -func (e *mockModuleEngine) Call(ctx *CallContext, f *FunctionInstance, _ ...uint64) (results []uint64, err error) { +func (e *mockModuleEngine) Call(ctx context.Context, callCtx *CallContext, f *FunctionInstance, _ ...uint64) (results []uint64, err error) { if e.callFailIndex >= 0 && f.Index == Index(e.callFailIndex) { err = errors.New("call failed") return } - e.ctx = ctx return } diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 2eb28008..30b3bf07 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -2,6 +2,7 @@ package wazeroir import ( "bytes" + "context" "encoding/binary" "fmt" "math" @@ -180,7 +181,7 @@ type CompilationResult struct { HasTable bool } -func CompileFunctions(enabledFeatures wasm.Features, module *wasm.Module) ([]*CompilationResult, error) { +func CompileFunctions(_ context.Context, enabledFeatures wasm.Features, module *wasm.Module) ([]*CompilationResult, error) { functions, globals, mem, table, err := module.AllDeclarations() if err != nil { return nil, err diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index 029a896d..9e34130a 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -1,6 +1,7 @@ package wazeroir import ( + "context" "testing" "github.com/tetratelabs/wazero/api" @@ -9,6 +10,9 @@ import ( "github.com/tetratelabs/wazero/internal/wasm/text" ) +// ctx is an arbitrary, non-default context. +var ctx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + var ( f64, i32 = wasm.ValueTypeF64, wasm.ValueTypeI32 i32_i32 = &wasm.FunctionType{Params: []wasm.ValueType{i32}, Results: []wasm.ValueType{i32}} @@ -65,7 +69,7 @@ func TestCompile(t *testing.T) { enabledFeatures = wasm.FeaturesFinished } - res, err := CompileFunctions(enabledFeatures, tc.module) + res, err := CompileFunctions(ctx, enabledFeatures, tc.module) require.NoError(t, err) require.Equal(t, tc.expected, res[0]) }) @@ -395,7 +399,7 @@ func TestCompile_MultiValue(t *testing.T) { if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - res, err := CompileFunctions(enabledFeatures, tc.module) + res, err := CompileFunctions(ctx, enabledFeatures, tc.module) require.NoError(t, err) require.Equal(t, tc.expected, res[0]) }) @@ -406,7 +410,7 @@ func requireCompilationResult(t *testing.T, enabledFeatures wasm.Features, expec if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - res, err := CompileFunctions(enabledFeatures, module) + res, err := CompileFunctions(ctx, enabledFeatures, module) require.NoError(t, err) require.Equal(t, expected, res[0]) } diff --git a/wasi/example_test.go b/wasi/example_test.go index 93f5b15a..461b2f51 100644 --- a/wasi/example_test.go +++ b/wasi/example_test.go @@ -1,6 +1,7 @@ package wasi import ( + "context" "fmt" "log" "os" @@ -13,10 +14,14 @@ import ( // // See https://github.com/tetratelabs/wazero/tree/main/examples/wasi for another example. func Example() { + // Choose the context to use for function calls. + ctx := context.Background() + + // Create a new WebAssembly Runtime. r := wazero.NewRuntime() // Instantiate WASI, which implements system I/O such as console output. - wm, err := InstantiateSnapshotPreview1(r) + wm, err := InstantiateSnapshotPreview1(ctx, r) if err != nil { log.Fatal(err) } @@ -26,7 +31,7 @@ func Example() { config := wazero.NewModuleConfig().WithStdout(os.Stdout) // InstantiateModuleFromCodeWithConfig runs the "_start" function which is like a "main" function. - _, err = r.InstantiateModuleFromCodeWithConfig([]byte(` + _, err = r.InstantiateModuleFromCodeWithConfig(ctx, []byte(` (module (import "wasi_snapshot_preview1" "proc_exit" (func $wasi.proc_exit (param $rval i32))) diff --git a/wasi/usage_test.go b/wasi/usage_test.go index a70930c8..b3048649 100644 --- a/wasi/usage_test.go +++ b/wasi/usage_test.go @@ -20,17 +20,17 @@ func TestInstantiateModuleWithConfig(t *testing.T) { // Configure WASI to write stdout to a buffer, so that we can verify it later. sys := wazero.NewModuleConfig().WithStdout(stdout) - wm, err := InstantiateSnapshotPreview1(r) + wm, err := InstantiateSnapshotPreview1(testCtx, r) require.NoError(t, err) defer wm.Close() - compiled, err := r.CompileModule(wasiArg) + compiled, err := r.CompileModule(testCtx, wasiArg) require.NoError(t, err) defer compiled.Close() // Re-use the same module many times. for _, tc := range []string{"a", "b", "c"} { - mod, err := r.InstantiateModuleWithConfig(compiled, sys.WithArgs(tc).WithName(tc)) + mod, err := r.InstantiateModuleWithConfig(testCtx, compiled, sys.WithArgs(tc).WithName(tc)) require.NoError(t, err) // Ensure the scoped configuration applied. As the args are null-terminated, we append zero (NUL). diff --git a/wasi/wasi.go b/wasi/wasi.go index 18295028..6d9db6ec 100644 --- a/wasi/wasi.go +++ b/wasi/wasi.go @@ -5,6 +5,7 @@ package wasi import ( + "context" crand "crypto/rand" "errors" "fmt" @@ -25,9 +26,9 @@ const ModuleSnapshotPreview1 = "wasi_snapshot_preview1" // InstantiateSnapshotPreview1 instantiates ModuleSnapshotPreview1, so that other modules can import them. // // Note: All WASI functions return a single Errno result, ErrnoSuccess on success. -func InstantiateSnapshotPreview1(r wazero.Runtime) (api.Module, error) { +func InstantiateSnapshotPreview1(ctx context.Context, r wazero.Runtime) (api.Module, error) { _, fns := snapshotPreview1Functions() - return r.NewModuleBuilder(ModuleSnapshotPreview1).ExportFunctions(fns).Instantiate() + return r.NewModuleBuilder(ModuleSnapshotPreview1).ExportFunctions(fns).Instantiate(ctx) } const ( diff --git a/wasi/wasi_bench_test.go b/wasi/wasi_bench_test.go index 92e7bbaf..7820a187 100644 --- a/wasi/wasi_bench_test.go +++ b/wasi/wasi_bench_test.go @@ -1,7 +1,6 @@ package wasi import ( - "context" "testing" "github.com/tetratelabs/wazero/internal/testing/require" @@ -59,7 +58,7 @@ func Benchmark_EnvironGet(b *testing.B) { } func newCtx(buf []byte, sys *wasm.SysContext) *wasm.CallContext { - return wasm.NewCallContext(context.Background(), nil, &wasm.ModuleInstance{ + return wasm.NewCallContext(nil, &wasm.ModuleInstance{ Memory: &wasm.MemoryInstance{Min: 1, Buffer: buf}, }, sys) } diff --git a/wasi/wasi_test.go b/wasi/wasi_test.go index 88920159..5c6993ef 100644 --- a/wasi/wasi_test.go +++ b/wasi/wasi_test.go @@ -2,6 +2,7 @@ package wasi import ( "bytes" + "context" _ "embed" "errors" "fmt" @@ -21,6 +22,9 @@ import ( "github.com/tetratelabs/wazero/sys" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestSnapshotPreview1_ArgsGet(t *testing.T) { sysCtx, err := newSysContext([]string{"a", "bc"}, nil, nil) require.NoError(t, err) @@ -54,7 +58,7 @@ func TestSnapshotPreview1_ArgsGet(t *testing.T) { t.Run(functionArgsGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(argv), uint64(argvBuf)) + results, err := fn.Call(testCtx, uint64(argv), uint64(argvBuf)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -147,7 +151,7 @@ func TestSnapshotPreview1_ArgsSizesGet(t *testing.T) { t.Run(functionArgsSizesGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(resultArgc), uint64(resultArgvBufSize)) + results, err := fn.Call(testCtx, uint64(resultArgc), uint64(resultArgvBufSize)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -239,7 +243,7 @@ func TestSnapshotPreview1_EnvironGet(t *testing.T) { t.Run(functionEnvironGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(resultEnviron), uint64(resultEnvironBuf)) + results, err := fn.Call(testCtx, uint64(resultEnviron), uint64(resultEnvironBuf)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -331,7 +335,7 @@ func TestSnapshotPreview1_EnvironSizesGet(t *testing.T) { t.Run(functionEnvironSizesGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(resultEnvironc), uint64(resultEnvironBufSize)) + results, err := fn.Call(testCtx, uint64(resultEnvironc), uint64(resultEnvironBufSize)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -399,7 +403,7 @@ func TestSnapshotPreview1_ClockResGet(t *testing.T) { }) t.Run(functionClockResGet, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -435,7 +439,7 @@ func TestSnapshotPreview1_ClockTimeGet(t *testing.T) { t.Run(functionClockTimeGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(resultTimestamp)) + results, err := fn.Call(testCtx, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(resultTimestamp)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -476,7 +480,7 @@ func TestSnapshotPreview1_ClockTimeGet_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - results, err := fn.Call(mod, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(tc.resultTimestamp)) + results, err := fn.Call(testCtx, 0 /* TODO: id */, 0 /* TODO: precision */, uint64(tc.resultTimestamp)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoFault, errno, ErrnoName(errno)) @@ -495,7 +499,7 @@ func TestSnapshotPreview1_FdAdvise(t *testing.T) { }) t.Run(functionFdAdvise, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -513,7 +517,7 @@ func TestSnapshotPreview1_FdAllocate(t *testing.T) { }) t.Run(functionFdAllocate, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -566,7 +570,7 @@ func TestSnapshotPreview1_FdClose(t *testing.T) { mod, fn, _ := setupFD() defer mod.Close() - results, err := fn.Call(mod, uint64(fdToClose)) + results, err := fn.Call(testCtx, uint64(fdToClose)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -593,7 +597,7 @@ func TestSnapshotPreview1_FdDatasync(t *testing.T) { }) t.Run(functionFdDatasync, func(t *testing.T) { - results, err := fn.Call(mod, 0) + results, err := fn.Call(testCtx, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -617,7 +621,7 @@ func TestSnapshotPreview1_FdFdstatSetFlags(t *testing.T) { }) t.Run(functionFdFdstatSetFlags, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -635,7 +639,7 @@ func TestSnapshotPreview1_FdFdstatSetRights(t *testing.T) { }) t.Run(functionFdFdstatSetRights, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -653,7 +657,7 @@ func TestSnapshotPreview1_FdFilestatGet(t *testing.T) { }) t.Run(functionFdFilestatGet, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -671,7 +675,7 @@ func TestSnapshotPreview1_FdFilestatSetSize(t *testing.T) { }) t.Run(functionFdFilestatSetSize, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -689,7 +693,7 @@ func TestSnapshotPreview1_FdFilestatSetTimes(t *testing.T) { }) t.Run(functionFdFilestatSetTimes, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -707,7 +711,7 @@ func TestSnapshotPreview1_FdPread(t *testing.T) { }) t.Run(functionFdPread, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -748,7 +752,7 @@ func TestSnapshotPreview1_FdPrestatGet(t *testing.T) { t.Run(functionFdPrestatDirName, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(fd), uint64(resultPrestat)) + results, err := fn.Call(testCtx, uint64(fd), uint64(resultPrestat)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -833,7 +837,7 @@ func TestSnapshotPreview1_FdPrestatDirName(t *testing.T) { t.Run(functionFdPrestatDirName, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(fd), uint64(path), uint64(pathLen)) + results, err := fn.Call(testCtx, uint64(fd), uint64(path), uint64(pathLen)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -915,7 +919,7 @@ func TestSnapshotPreview1_FdPwrite(t *testing.T) { }) t.Run(functionFdPwrite, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -956,7 +960,7 @@ func TestSnapshotPreview1_FdRead(t *testing.T) { }}, {functionFdRead, func(_ *snapshotPreview1, mod api.Module, fn api.Function) fdReadFn { return func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - results, err := fn.Call(mod, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) + results, err := fn.Call(testCtx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) require.NoError(t, err) return Errno(results[0]) } @@ -1094,7 +1098,7 @@ func TestSnapshotPreview1_FdReaddir(t *testing.T) { }) t.Run(functionFdReaddir, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1112,7 +1116,7 @@ func TestSnapshotPreview1_FdRenumber(t *testing.T) { }) t.Run(functionFdRenumber, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1143,7 +1147,7 @@ func TestSnapshotPreview1_FdSeek(t *testing.T) { }}, {functionFdSeek, func() fdSeekFn { return func(ctx api.Module, fd uint32, offset uint64, whence, resultNewoffset uint32) Errno { - results, err := fn.Call(mod, uint64(fd), offset, uint64(whence), uint64(resultNewoffset)) + results, err := fn.Call(testCtx, uint64(fd), offset, uint64(whence), uint64(resultNewoffset)) require.NoError(t, err) return Errno(results[0]) } @@ -1287,7 +1291,7 @@ func TestSnapshotPreview1_FdSync(t *testing.T) { }) t.Run(functionFdSync, func(t *testing.T) { - results, err := fn.Call(mod, 0) + results, err := fn.Call(testCtx, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1305,7 +1309,7 @@ func TestSnapshotPreview1_FdTell(t *testing.T) { }) t.Run(functionFdTell, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1346,7 +1350,7 @@ func TestSnapshotPreview1_FdWrite(t *testing.T) { }}, {functionFdWrite, func(_ *snapshotPreview1, mod api.Module, fn api.Function) fdWriteFn { return func(ctx api.Module, fd, iovs, iovsCount, resultSize uint32) Errno { - results, err := fn.Call(mod, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) + results, err := fn.Call(testCtx, uint64(fd), uint64(iovs), uint64(iovsCount), uint64(resultSize)) require.NoError(t, err) return Errno(results[0]) } @@ -1478,7 +1482,7 @@ func TestSnapshotPreview1_PathCreateDirectory(t *testing.T) { }) t.Run(functionPathCreateDirectory, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1496,7 +1500,7 @@ func TestSnapshotPreview1_PathFilestatGet(t *testing.T) { }) t.Run(functionPathFilestatGet, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1514,7 +1518,7 @@ func TestSnapshotPreview1_PathFilestatSetTimes(t *testing.T) { }) t.Run(functionPathFilestatSetTimes, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1532,7 +1536,7 @@ func TestSnapshotPreview1_PathLink(t *testing.T) { }) t.Run(functionPathLink, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1597,7 +1601,7 @@ func TestSnapshotPreview1_PathOpen(t *testing.T) { t.Run(functionPathOpen, func(t *testing.T) { _, mod, fn := setup() - results, err := fn.Call(mod, uint64(workdirFD), uint64(dirflags), uint64(path), uint64(pathLen), uint64(oflags), fsRightsBase, fsRightsInheriting, uint64(fdFlags), uint64(resultOpenedFd)) + results, err := fn.Call(testCtx, uint64(workdirFD), uint64(dirflags), uint64(path), uint64(pathLen), uint64(oflags), fsRightsBase, fsRightsInheriting, uint64(fdFlags), uint64(resultOpenedFd)) require.NoError(t, err) errno := Errno(results[0]) verify(errno, mod) @@ -1682,7 +1686,7 @@ func TestSnapshotPreview1_PathReadlink(t *testing.T) { }) t.Run(functionPathReadlink, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1700,7 +1704,7 @@ func TestSnapshotPreview1_PathRemoveDirectory(t *testing.T) { }) t.Run(functionPathRemoveDirectory, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1718,7 +1722,7 @@ func TestSnapshotPreview1_PathRename(t *testing.T) { }) t.Run(functionPathRename, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1736,7 +1740,7 @@ func TestSnapshotPreview1_PathSymlink(t *testing.T) { }) t.Run(functionPathSymlink, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1754,7 +1758,7 @@ func TestSnapshotPreview1_PathUnlinkFile(t *testing.T) { }) t.Run(functionPathUnlinkFile, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1772,7 +1776,7 @@ func TestSnapshotPreview1_PollOneoff(t *testing.T) { }) t.Run(functionPollOneoff, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1804,7 +1808,7 @@ func TestSnapshotPreview1_ProcExit(t *testing.T) { defer mod.Close() // When ProcExit is called, store.Callfunction returns immediately, returning the exit code as the error. - _, err := fn.Call(nil, uint64(tc.exitCode)) + _, err := fn.Call(testCtx, uint64(tc.exitCode)) require.Equal(t, tc.exitCode, err.(*sys.ExitError).ExitCode()) }) } @@ -1821,7 +1825,7 @@ func TestSnapshotPreview1_ProcRaise(t *testing.T) { }) t.Run(functionProcRaise, func(t *testing.T) { - results, err := fn.Call(mod, 0) + results, err := fn.Call(testCtx, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1839,7 +1843,7 @@ func TestSnapshotPreview1_SchedYield(t *testing.T) { }) t.Run(functionSchedYield, func(t *testing.T) { - results, err := fn.Call(mod) + results, err := fn.Call(testCtx) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1883,7 +1887,7 @@ func TestSnapshotPreview1_RandomGet(t *testing.T) { t.Run(functionRandomGet, func(t *testing.T) { maskMemory(t, mod, len(expectedMemory)) - results, err := fn.Call(mod, uint64(offset), uint64(length)) + results, err := fn.Call(testCtx, uint64(offset), uint64(length)) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Zero(t, errno, ErrnoName(errno)) @@ -1953,7 +1957,7 @@ func TestSnapshotPreview1_SockRecv(t *testing.T) { }) t.Run(functionSockRecv, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1971,7 +1975,7 @@ func TestSnapshotPreview1_SockSend(t *testing.T) { }) t.Run(functionSockSend, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0, 0, 0, 0) + results, err := fn.Call(testCtx, 0, 0, 0, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -1989,7 +1993,7 @@ func TestSnapshotPreview1_SockShutdown(t *testing.T) { }) t.Run(functionSockShutdown, func(t *testing.T) { - results, err := fn.Call(mod, 0, 0) + results, err := fn.Call(testCtx, 0, 0) require.NoError(t, err) errno := Errno(results[0]) // results[0] is the errno require.Equal(t, ErrnoNosys, errno, ErrnoName(errno)) @@ -2011,10 +2015,10 @@ func instantiateModule(t *testing.T, wasifunction, wasiimport string, sysCtx *wa // The package `wazero` has a simpler interface for adding host modules, but we can't use that as it would create an // import cycle. Instead, we export wasm.NewHostModule and use it here. a, fns := snapshotPreview1Functions() - _, err := r.NewModuleBuilder("wasi_snapshot_preview1").ExportFunctions(fns).Instantiate() + _, err := r.NewModuleBuilder("wasi_snapshot_preview1").ExportFunctions(fns).Instantiate(testCtx) require.NoError(t, err) - compiled, err := r.CompileModule([]byte(fmt.Sprintf(`(module + compiled, err := r.CompileModule(testCtx, []byte(fmt.Sprintf(`(module %[2]s (memory 1 1) ;; just an arbitrary size big enough for tests (export "memory" (memory 0)) @@ -2023,7 +2027,7 @@ func instantiateModule(t *testing.T, wasifunction, wasiimport string, sysCtx *wa require.NoError(t, err) defer compiled.Close() - mod, err := r.InstantiateModuleWithConfig(compiled, wazero.NewModuleConfig().WithName(t.Name())) + mod, err := r.InstantiateModuleWithConfig(testCtx, compiled, wazero.NewModuleConfig().WithName(t.Name())) require.NoError(t, err) if sysCtx != nil { diff --git a/wasm.go b/wasm.go index 58818aaa..62c90940 100644 --- a/wasm.go +++ b/wasm.go @@ -16,9 +16,10 @@ import ( // Runtime allows embedding of WebAssembly 1.0 (20191205) modules. // // Ex. +// ctx := context.Background() // r := wazero.NewRuntime() -// code, _ := r.CompileModule(source) -// module, _ := r.InstantiateModule(code) +// compiled, _ := r.CompileModule(ctx, source) +// module, _ := r.InstantiateModule(ctx, compiled) // defer module.Close() // // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/ @@ -27,10 +28,11 @@ type Runtime interface { // // Ex. Below defines and instantiates a module named "env" with one function: // + // ctx := context.Background() // hello := func() { // fmt.Fprintln(stdout, "hello!") // } - // _, err := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate() + // _, err := r.NewModuleBuilder("env").ExportFunction("hello", hello).Instantiate(ctx) NewModuleBuilder(moduleName string) ModuleBuilder // Module returns exports from an instantiated module or nil if there aren't any. @@ -43,38 +45,46 @@ type Runtime interface { // * Improve performance when the same module is instantiated multiple times under different names // * Reduce the amount of errors that can occur during InstantiateModule. // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: The resulting module name defaults to what was binary from the custom name section. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#name-section%E2%91%A0 - CompileModule(source []byte) (*CompiledCode, error) + CompileModule(ctx context.Context, source []byte) (*CompiledCode, error) // InstantiateModuleFromCode instantiates a module from the WebAssembly 1.0 (20191205) text or binary source or // errs if invalid. // // Ex. - // module, _ := wazero.NewRuntime().InstantiateModuleFromCode(source) + // ctx := context.Background() + // module, _ := wazero.NewRuntime().InstantiateModuleFromCode(ctx, source) // defer module.Close() // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: This is a convenience utility that chains CompileModule with InstantiateModule. To instantiate the same // source multiple times, use CompileModule as InstantiateModule avoids redundant decoding and/or compilation. - InstantiateModuleFromCode(source []byte) (api.Module, error) + InstantiateModuleFromCode(ctx context.Context, source []byte) (api.Module, error) // InstantiateModuleFromCodeWithConfig is a convenience function that chains CompileModule to // InstantiateModuleWithConfig. // // Ex. To only change the module name: - // wasm, _ := wazero.NewRuntime().InstantiateModuleFromCodeWithConfig(source, wazero.NewModuleConfig(). + // ctx := context.Background() + // r := wazero.NewRuntime() + // wasm, _ := r.InstantiateModuleFromCodeWithConfig(ctx, source, wazero.NewModuleConfig(). // WithName("wasm") // ) // defer wasm.Close() - InstantiateModuleFromCodeWithConfig(source []byte, config *ModuleConfig) (api.Module, error) + // + // Note: When `ctx` is nil, it defaults to context.Background. + InstantiateModuleFromCodeWithConfig(ctx context.Context, source []byte, config *ModuleConfig) (api.Module, error) // InstantiateModule instantiates the module namespace or errs if the configuration was invalid. // // Ex. + // ctx := context.Background() // r := wazero.NewRuntime() - // code, _ := r.CompileModule(source) - // defer code.Close() - // module, _ := r.InstantiateModule(code) + // compiled, _ := r.CompileModule(ctx, source) + // defer compiled.Close() + // module, _ := r.InstantiateModule(ctx, compiled) // defer module.Close() // // While CompiledCode is pre-validated, there are a few situations which can cause an error: @@ -82,26 +92,28 @@ type Runtime interface { // * The module has a table element initializer that resolves to an index outside the Table minimum size. // * The module has a start function, and it failed to execute. // - // Note: The last value of RuntimeConfig.WithContext is used for any start function. - InstantiateModule(code *CompiledCode) (api.Module, error) + // Note: When `ctx` is nil, it defaults to context.Background. + InstantiateModule(ctx context.Context, compiled *CompiledCode) (api.Module, error) // InstantiateModuleWithConfig is like InstantiateModule, except you can override configuration such as the module // name or ENV variables. // // For example, you can use this to define different args depending on the importing module. // + // ctx := context.Background() // r := wazero.NewRuntime() - // wasi, _ := r.InstantiateModule(wazero.WASISnapshotPreview1()) - // code, _ := r.CompileModule(source) + // wasi, _ := wasi.InstantiateSnapshotPreview1(r) + // compiled, _ := r.CompileModule(ctx, source) // // // Initialize base configuration: // config := wazero.NewModuleConfig().WithStdout(buf) // // // Assign different configuration on each instantiation - // module, _ := r.InstantiateModuleWithConfig(code, config.WithName("rotate").WithArgs("rotate", "angle=90", "dir=cw")) + // module, _ := r.InstantiateModuleWithConfig(ctx, compiled, config.WithName("rotate").WithArgs("rotate", "angle=90", "dir=cw")) // + // Note: when `ctx` is nil, it defaults to context.Background. // Note: Config is copied during instantiation: Later changes to config do not affect the instantiated result. - InstantiateModuleWithConfig(code *CompiledCode, config *ModuleConfig) (mod api.Module, err error) + InstantiateModuleWithConfig(ctx context.Context, compiled *CompiledCode, config *ModuleConfig) (mod api.Module, err error) } func NewRuntime() Runtime { @@ -111,7 +123,6 @@ func NewRuntime() Runtime { // NewRuntimeWithConfig returns a runtime with the given configuration. func NewRuntimeWithConfig(config *RuntimeConfig) Runtime { return &runtime{ - ctx: config.ctx, store: wasm.NewStore(config.enabledFeatures, config.newEngine(config.enabledFeatures)), enabledFeatures: config.enabledFeatures, memoryMaxPages: config.memoryMaxPages, @@ -121,7 +132,6 @@ func NewRuntimeWithConfig(config *RuntimeConfig) Runtime { // runtime allows decoupling of public interfaces from internal representation. type runtime struct { enabledFeatures wasm.Features - ctx context.Context store *wasm.Store memoryMaxPages uint32 } @@ -132,7 +142,7 @@ func (r *runtime) Module(moduleName string) api.Module { } // CompileModule implements Runtime.CompileModule -func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { +func (r *runtime) CompileModule(ctx context.Context, source []byte) (*CompiledCode, error) { if source == nil { return nil, errors.New("source == nil") } @@ -166,7 +176,7 @@ func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { internal.AssignModuleID(source) - if err = r.store.Engine.CompileModule(internal); err != nil { + if err = r.store.Engine.CompileModule(ctx, internal); err != nil { return nil, err } @@ -174,47 +184,47 @@ func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { } // InstantiateModuleFromCode implements Runtime.InstantiateModuleFromCode -func (r *runtime) InstantiateModuleFromCode(source []byte) (api.Module, error) { - if code, err := r.CompileModule(source); err != nil { +func (r *runtime) InstantiateModuleFromCode(ctx context.Context, source []byte) (api.Module, error) { + if compiled, err := r.CompileModule(ctx, source); err != nil { return nil, err } else { // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. - defer code.Close() - return r.InstantiateModule(code) + defer compiled.Close() + return r.InstantiateModule(ctx, compiled) } } // InstantiateModuleFromCodeWithConfig implements Runtime.InstantiateModuleFromCodeWithConfig -func (r *runtime) InstantiateModuleFromCodeWithConfig(source []byte, config *ModuleConfig) (api.Module, error) { - if code, err := r.CompileModule(source); err != nil { +func (r *runtime) InstantiateModuleFromCodeWithConfig(ctx context.Context, source []byte, config *ModuleConfig) (api.Module, error) { + if compiled, err := r.CompileModule(ctx, source); err != nil { return nil, err } else { // *wasm.ModuleInstance for the source cannot be tracked, so we release the cache inside of this function. - defer code.Close() - return r.InstantiateModuleWithConfig(code, config) + defer compiled.Close() + return r.InstantiateModuleWithConfig(ctx, compiled, config) } } // InstantiateModule implements Runtime.InstantiateModule -func (r *runtime) InstantiateModule(code *CompiledCode) (mod api.Module, err error) { - return r.InstantiateModuleWithConfig(code, NewModuleConfig()) +func (r *runtime) InstantiateModule(ctx context.Context, compiled *CompiledCode) (mod api.Module, err error) { + return r.InstantiateModuleWithConfig(ctx, compiled, NewModuleConfig()) } // InstantiateModuleWithConfig implements Runtime.InstantiateModuleWithConfig -func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *ModuleConfig) (mod api.Module, err error) { +func (r *runtime) InstantiateModuleWithConfig(ctx context.Context, compiled *CompiledCode, config *ModuleConfig) (mod api.Module, err error) { var sysCtx *wasm.SysContext if sysCtx, err = config.toSysContext(); err != nil { return } name := config.name - if name == "" && code.module.NameSection != nil && code.module.NameSection.ModuleName != "" { - name = code.module.NameSection.ModuleName + if name == "" && compiled.module.NameSection != nil && compiled.module.NameSection.ModuleName != "" { + name = compiled.module.NameSection.ModuleName } - module := config.replaceImports(code.module) + module := config.replaceImports(compiled.module) - mod, err = r.store.Instantiate(r.ctx, module, name, sysCtx) + mod, err = r.store.Instantiate(ctx, module, name, sysCtx) if err != nil { return } @@ -224,7 +234,7 @@ func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *Module if start == nil { continue } - if _, err = start.Call(mod.WithContext(r.ctx)); err != nil { + if _, err = start.Call(ctx); err != nil { if _, ok := err.(*sys.ExitError); ok { return } diff --git a/wasm_test.go b/wasm_test.go index 125ac204..915ac609 100644 --- a/wasm_test.go +++ b/wasm_test.go @@ -15,6 +15,9 @@ import ( "github.com/tetratelabs/wazero/sys" ) +// testCtx is an arbitrary, non-default context. Non-nil also prevents linter errors. +var testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary") + func TestRuntime_DecodeModule(t *testing.T) { tests := []struct { name string @@ -54,7 +57,7 @@ func TestRuntime_DecodeModule(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - code, err := r.CompileModule(tc.source) + code, err := r.CompileModule(testCtx, tc.source) require.NoError(t, err) defer code.Close() if tc.expectedName != "" { @@ -115,7 +118,7 @@ func TestRuntime_DecodeModule_Errors(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - _, err := tc.runtime.CompileModule(tc.source) + _, err := tc.runtime.CompileModule(testCtx, tc.source) require.EqualError(t, err, tc.expectedErr) }) } @@ -151,7 +154,7 @@ func TestModule_Memory(t *testing.T) { r := NewRuntime() t.Run(tc.name, func(t *testing.T) { // Instantiate the module and get the export of the above memory - module, err := tc.builder(r).Instantiate() + module, err := tc.builder(r).Instantiate(testCtx) require.NoError(t, err) defer module.Close() @@ -224,14 +227,14 @@ func TestModule_Global(t *testing.T) { if tc.module != nil { code = &CompiledCode{module: tc.module} } else { - code, _ = tc.builder(r).Build() + code, _ = tc.builder(r).Build(testCtx) } - err := r.store.Engine.CompileModule(code.module) + err := r.store.Engine.CompileModule(testCtx, code.module) require.NoError(t, err) // Instantiate the module and get the export of the above global - module, err := r.InstantiateModule(code) + module, err := r.InstantiateModule(testCtx, code) require.NoError(t, err) defer module.Close() @@ -253,26 +256,20 @@ func TestModule_Global(t *testing.T) { } func TestFunction_Context(t *testing.T) { - type key string - runtimeCtx := context.WithValue(context.Background(), key("wa"), "zero") - config := NewRuntimeConfig().WithContext(runtimeCtx) - - notStoreCtx := context.WithValue(context.Background(), key("wazer"), "o") - tests := []struct { name string ctx context.Context expected context.Context }{ { - name: "nil defaults to runtime context", + name: "nil defaults to context.Background", ctx: nil, - expected: runtimeCtx, + expected: context.Background(), }, { - name: "set overrides runtime context", - ctx: notStoreCtx, - expected: notStoreCtx, + name: "set context", + ctx: testCtx, + expected: testCtx, }, } @@ -280,49 +277,48 @@ func TestFunction_Context(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - r := NewRuntimeWithConfig(config) + r := NewRuntime() // Define a host function so that we can catch the context propagated from a module function call functionName := "fn" expectedResult := uint64(math.MaxUint64) - hostFn := func(ctx api.Module) uint64 { - require.Equal(t, tc.expected, ctx.Context()) + hostFn := func(ctx context.Context) uint64 { + require.Equal(t, tc.expected, ctx) return expectedResult } source, closer := requireImportAndExportFunction(t, r, hostFn, functionName) defer closer() // nolint // Instantiate the module and get the export of the above hostFn - module, err := r.InstantiateModuleFromCodeWithConfig(source, NewModuleConfig().WithName(t.Name())) + module, err := r.InstantiateModuleFromCodeWithConfig(tc.ctx, source, NewModuleConfig().WithName(t.Name())) require.NoError(t, err) defer module.Close() // This fails if the function wasn't invoked, or had an unexpected context. - results, err := module.ExportedFunction(functionName).Call(module.WithContext(tc.ctx)) + results, err := module.ExportedFunction(functionName).Call(tc.ctx) require.NoError(t, err) require.Equal(t, expectedResult, results[0]) }) } } -func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { - type key string - runtimeCtx := context.WithValue(context.Background(), key("wa"), "zero") - config := NewRuntimeConfig().WithContext(runtimeCtx) - r := NewRuntimeWithConfig(config) +func TestRuntime_InstantiateModule_UsesContext(t *testing.T) { + r := NewRuntime() // Define a function that will be set as the start function var calledStart bool - start := func(ctx api.Module) { + start := func(ctx context.Context) { calledStart = true - require.Equal(t, runtimeCtx, ctx.Context()) + require.Equal(t, testCtx, ctx) } - env, err := r.NewModuleBuilder("env").ExportFunction("start", start).Instantiate() + env, err := r.NewModuleBuilder("env"). + ExportFunction("start", start). + Instantiate(testCtx) require.NoError(t, err) defer env.Close() - code, err := r.CompileModule([]byte(`(module $runtime_test.go + code, err := r.CompileModule(testCtx, []byte(`(module $runtime_test.go (import "env" "start" (func $start)) (start $start) )`)) @@ -330,7 +326,7 @@ func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { defer code.Close() // Instantiate the module, which calls the start function. This will fail if the context wasn't as intended. - m, err := r.InstantiateModule(code) + m, err := r.InstantiateModule(testCtx, code) require.NoError(t, err) defer m.Close() @@ -341,7 +337,7 @@ func TestRuntime_NewModule_UsesConfiguredContext(t *testing.T) { func TestInstantiateModuleFromCode_DoesntEnforce_Start(t *testing.T) { r := NewRuntime() - mod, err := r.InstantiateModuleFromCode([]byte(`(module $wasi_test.go + mod, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $wasi_test.go (memory 1) (export "memory" (memory 0)) )`)) @@ -349,24 +345,24 @@ func TestInstantiateModuleFromCode_DoesntEnforce_Start(t *testing.T) { require.NoError(t, mod.Close()) } -func TestInstantiateModuleFromCode_UsesRuntimeContext(t *testing.T) { - type key string - config := NewRuntimeConfig().WithContext(context.WithValue(context.Background(), key("wa"), "zero")) - r := NewRuntimeWithConfig(config) +func TestRuntime_InstantiateModuleFromCode_UsesContext(t *testing.T) { + r := NewRuntime() // Define a function that will be re-exported as the WASI function: _start var calledStart bool - start := func(ctx api.Module) { + start := func(ctx context.Context) { calledStart = true - require.Equal(t, config.ctx, ctx.Context()) + require.Equal(t, testCtx, ctx) } - host, err := r.NewModuleBuilder("").ExportFunction("start", start).Instantiate() + host, err := r.NewModuleBuilder(""). + ExportFunction("start", start). + Instantiate(testCtx) require.NoError(t, err) defer host.Close() // Start the module as a WASI command. This will fail if the context wasn't as intended. - mod, err := r.InstantiateModuleFromCode([]byte(`(module $start + mod, err := r.InstantiateModuleFromCode(testCtx, []byte(`(module $start (import "" "start" (func $start)) (memory 1) (export "_start" (func $start)) @@ -382,7 +378,7 @@ func TestInstantiateModuleFromCode_UsesRuntimeContext(t *testing.T) { // different names. This pattern is used in wapc-go. func TestInstantiateModuleWithConfig_WithName(t *testing.T) { r := NewRuntime() - base, err := r.CompileModule([]byte(`(module $0 (memory 1))`)) + base, err := r.CompileModule(testCtx, []byte(`(module $0 (memory 1))`)) require.NoError(t, err) defer base.Close() @@ -390,14 +386,14 @@ func TestInstantiateModuleWithConfig_WithName(t *testing.T) { // Use the same runtime to instantiate multiple modules internal := r.(*runtime).store - m1, err := r.InstantiateModuleWithConfig(base, NewModuleConfig().WithName("1")) + m1, err := r.InstantiateModuleWithConfig(testCtx, base, NewModuleConfig().WithName("1")) require.NoError(t, err) defer m1.Close() require.Nil(t, internal.Module("0")) require.Equal(t, internal.Module("1"), m1) - m2, err := r.InstantiateModuleWithConfig(base, NewModuleConfig().WithName("2")) + m2, err := r.InstantiateModuleWithConfig(testCtx, base, NewModuleConfig().WithName("2")) require.NoError(t, err) defer m2.Close() @@ -412,15 +408,15 @@ func TestInstantiateModuleWithConfig_ExitError(t *testing.T) { require.NoError(t, m.CloseWithExitCode(2)) } - _, err := r.NewModuleBuilder("env").ExportFunction("_start", start).Instantiate() + _, err := r.NewModuleBuilder("env").ExportFunction("_start", start).Instantiate(testCtx) // Ensure the exit error propagated and didn't wrap. require.Equal(t, err, sys.NewExitError("env", 2)) } // requireImportAndExportFunction re-exports a host function because only host functions can see the propagated context. -func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx api.Module) uint64, functionName string) ([]byte, func() error) { - mod, err := r.NewModuleBuilder("host").ExportFunction(functionName, hostFn).Instantiate() +func requireImportAndExportFunction(t *testing.T, r Runtime, hostFn func(ctx context.Context) uint64, functionName string) ([]byte, func() error) { + mod, err := r.NewModuleBuilder("host").ExportFunction(functionName, hostFn).Instantiate(testCtx) require.NoError(t, err) return []byte(fmt.Sprintf( @@ -434,7 +430,7 @@ func TestCompiledCode_Close(t *testing.T) { var cs []*CompiledCode for i := 0; i < 10; i++ { m := &wasm.Module{} - err := e.CompileModule(m) + err := e.CompileModule(testCtx, m) require.NoError(t, err) cs = append(cs, &CompiledCode{module: m, compiledEngine: e}) } @@ -465,7 +461,7 @@ func (e *mockEngine) DeleteCompiledModule(module *wasm.Module) { delete(e.cachedModules, module) } -func (e *mockEngine) CompileModule(module *wasm.Module) error { +func (e *mockEngine) CompileModule(_ context.Context, module *wasm.Module) error { e.cachedModules[module] = struct{}{} return nil }