diff --git a/builder.go b/builder.go index 87e65a08..c52a3833 100644 --- a/builder.go +++ b/builder.go @@ -247,17 +247,18 @@ func (b *moduleBuilder) Build() (*CompiledCode, error) { } } - if module, err := wasm.NewHostModule( - b.moduleName, - b.nameToGoFunc, - b.nameToMemory, - b.nameToGlobal, - b.r.enabledFeatures, - ); err != nil { + module, err := wasm.NewHostModule(b.moduleName, b.nameToGoFunc, b.nameToMemory, b.nameToGlobal, b.r.enabledFeatures) + if err != nil { return nil, err - } else { - return &CompiledCode{module: module}, nil } + + if err = b.r.store.Engine.CompileModule(module); err != nil { + return nil, err + } + + ret := &CompiledCode{module: module} + ret.addCacheEntry(module, b.r.store.Engine) + return &CompiledCode{module: module}, nil } // Instantiate implements ModuleBuilder.Instantiate @@ -265,6 +266,9 @@ func (b *moduleBuilder) Instantiate() (api.Module, error) { if module, err := b.Build(); err != nil { return nil, err } else { + if err = b.r.store.Engine.CompileModule(module.module); err != nil { + return nil, err + } // *wasm.ModuleInstance cannot be tracked, so we release the cache inside of this function. defer module.Close() return b.r.InstantiateModuleWithConfig(module, NewModuleConfig().WithName(b.moduleName)) diff --git a/builder_test.go b/builder_test.go index 187103a8..279e9869 100644 --- a/builder_test.go +++ b/builder_test.go @@ -343,9 +343,14 @@ func TestNewModuleBuilder_Build(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - m, e := tc.input(NewRuntime()).Build() - require.NoError(t, e) + b := tc.input(NewRuntime()).(*moduleBuilder) + m, err := b.Build() + require.NoError(t, err) requireHostModuleEquals(t, tc.expected, m.module) + + // Built module must be instantiable by Engine. + _, err = b.r.InstantiateModule(m) + require.NoError(t, err) }) } } diff --git a/config.go b/config.go index ea37b8aa..bf17564f 100644 --- a/config.go +++ b/config.go @@ -155,10 +155,12 @@ type CompiledCode struct { } // Close releases all the allocated resources for this CompiledCode. +// +// Note: it is safe to call Close while having outstanding calls from Modules instantiated from this *CompiledCode. func (c *CompiledCode) Close() { for engine, modules := range c.cachedEngines { for module := range modules { - engine.ReleaseCompilationCache(module) + engine.DeleteCompiledModule(module) } } } diff --git a/internal/integration_test/engine/adhoc_test.go b/internal/integration_test/engine/adhoc_test.go index 749ad94f..e9647ce9 100644 --- a/internal/integration_test/engine/adhoc_test.go +++ b/internal/integration_test/engine/adhoc_test.go @@ -277,8 +277,9 @@ func callReturnImportSource(importedModule, importingModule string) []byte { func testCloseInFlight(t *testing.T, r wazero.Runtime) { tests := []struct { - name, function string - closeImporting, closeImported uint32 + name, function string + closeImporting, closeImported uint32 + closeImportingCode, closeImportedCode bool }{ { // Ex. WASI proc_exit or AssemblyScript abort handler. name: "importing", @@ -292,11 +293,33 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { closeImporting: 1, closeImported: 2, }, + { // Ex. WASI proc_exit or AssemblyScript abort handler. + name: "importing", + function: "call_return_import", + closeImporting: 1, + closeImportedCode: true, + }, + { // Ex. WASI proc_exit or AssemblyScript abort handler. + name: "importing", + function: "call_return_import", + closeImporting: 1, + closeImportedCode: true, + closeImportingCode: true, + }, + // TODO: A module that re-exports a function (ex "return_input") can call it after it is closed! + { // Ex. A function that stops the runtime. + name: "both", + function: "call_return_import", + closeImporting: 1, + closeImported: 2, + closeImportingCode: true, + }, } for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { + var importingCode, importedCode *wazero.CompiledCode var imported, importing api.Module var err error closeAndReturn := func(x uint32) uint32 { @@ -306,18 +329,30 @@ func testCloseInFlight(t *testing.T, r wazero.Runtime) { if tc.closeImported != 0 { require.NoError(t, imported.CloseWithExitCode(tc.closeImported)) } + if tc.closeImportedCode { + importedCode.Close() + } + if tc.closeImportingCode { + importingCode.Close() + } return x } // Create the host module, which exports the function that closes the importing module. - imported, err = r.NewModuleBuilder(t.Name()+"-imported"). - ExportFunction("return_input", closeAndReturn).Instantiate() + importedCode, err = r.NewModuleBuilder(t.Name()+"-imported"). + ExportFunction("return_input", closeAndReturn).Build() + require.NoError(t, err) + + imported, err = r.InstantiateModule(importedCode) require.NoError(t, err) defer imported.Close() // Import that module. source := callReturnImportSource(imported.Name(), t.Name()+"-importing") - importing, err = r.InstantiateModuleFromCode(source) + importingCode, err = r.CompileModule(source) + require.NoError(t, err) + + importing, err = r.InstantiateModule(importingCode) require.NoError(t, err) defer importing.Close() diff --git a/internal/integration_test/spectest/spec_test.go b/internal/integration_test/spectest/spec_test.go index c437edf5..20d77f55 100644 --- a/internal/integration_test/spectest/spec_test.go +++ b/internal/integration_test/spectest/spec_test.go @@ -267,6 +267,9 @@ func addSpectestModule(t *testing.T, store *wasm.Store) { mod.TableSection = &wasm.Table{Min: 10, Max: &tableLimitMax} mod.ExportSection = append(mod.ExportSection, &wasm.Export{Name: "table", Index: 0, Type: wasm.ExternTypeTable}) + err = store.Engine.CompileModule(mod) + require.NoError(t, err) + _, err = store.Instantiate(context.Background(), mod, mod.NameSection.ModuleName, wasm.DefaultSysContext()) require.NoError(t, err) } @@ -334,6 +337,8 @@ func runTest(t *testing.T, newEngine func(wasm.Features) wasm.Engine) { moduleName = c.Filename } } + err = store.Engine.CompileModule(mod) + require.NoError(t, err, msg) moduleName = strings.TrimPrefix(moduleName, "$") _, err = store.Instantiate(context.Background(), mod, moduleName, nil) lastInstantiatedModuleName = moduleName @@ -458,6 +463,11 @@ func requireInstantiationError(t *testing.T, store *wasm.Store, buf []byte, msg return } + err = store.Engine.CompileModule(mod) + if err != nil { + return + } + _, err = store.Instantiate(context.Background(), mod, t.Name(), nil) require.Error(t, err, msg) } diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index 6b828055..0174b40e 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -37,33 +37,57 @@ type EngineTester interface { func RunTestEngine_NewModuleEngine(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) + t.Run("error before instantiation", func(t *testing.T) { + _, err := e.NewModuleEngine("mymod", &wasm.Module{}, nil, nil, nil, nil) + require.EqualError(t, err, "source module for mymod must be compiled before instantiation") + }) + t.Run("sets module name", func(t *testing.T) { - me, err := e.NewModuleEngine(t.Name(), nil, nil, nil, nil, nil) + m := &wasm.Module{} + err := e.CompileModule(m) + require.NoError(t, err) + me, err := e.NewModuleEngine(t.Name(), m, nil, nil, nil, nil) require.NoError(t, err) - defer me.Close() require.Equal(t, t.Name(), me.Name()) }) } +func getFunctionInstance(module *wasm.Module, index wasm.Index, moduleInstance *wasm.ModuleInstance) *wasm.FunctionInstance { + c := module.ImportFuncCount() + typeIndex := module.FunctionSection[index] + return &wasm.FunctionInstance{ + Kind: wasm.FunctionKindWasm, + Module: moduleInstance, + Type: module.TypeSection[typeIndex], + Body: module.CodeSection[index].Body, + LocalTypes: module.CodeSection[index].LocalTypes, + Index: index + c, + } +} + func RunTestModuleEngine_Call(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) // Define a basic function which defines one parameter. This is used to test results when incorrect arity is used. i64 := wasm.ValueTypeI64 - fn := &wasm.FunctionInstance{ - Kind: wasm.FunctionKindWasm, - Type: &wasm.FunctionType{Params: []wasm.ValueType{i64}, Results: []wasm.ValueType{i64}}, - Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, + m := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{Params: []wasm.ValueType{i64}, Results: []wasm.ValueType{i64}}}, + FunctionSection: []uint32{0}, + CodeSection: []*wasm.Code{{Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeEnd}, LocalTypes: []wasm.ValueType{wasm.ValueTypeI64}}}, } + err := e.CompileModule(m) + require.NoError(t, err) + // To use the function, we first need to add it to a module. module := &wasm.ModuleInstance{Name: t.Name()} + fn := getFunctionInstance(m, 0, module) addFunction(module, "fn", fn) // Compile the module - me, err := e.NewModuleEngine(module.Name, nil, nil, module.Functions, nil, nil) + me, err := e.NewModuleEngine(module.Name, m, nil, module.Functions, nil, nil) + fn.Module.Engine = me require.NoError(t, err) - defer me.Close() linkModuleToEngine(module, me) // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. @@ -87,34 +111,46 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { t.Run("no table elements", func(t *testing.T) { table := &wasm.TableInstance{Min: 2, Table: make([]interface{}, 2)} - var importedFunctions []*wasm.FunctionInstance - var moduleFunctions []*wasm.FunctionInstance - var tableInit map[wasm.Index]wasm.Index + m := &wasm.Module{ + TypeSection: []*wasm.FunctionType{}, + FunctionSection: []uint32{}, + CodeSection: []*wasm.Code{}, + } + err := e.CompileModule(m) + require.NoError(t, err) // Instantiate the module, which has nothing but an empty table. - me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) + _, err = e.NewModuleEngine(t.Name(), m, nil, nil, table, nil) require.NoError(t, err) - defer me.Close() // Since there are no elements to initialize, we expect the table to be nil. require.Equal(t, table.Table, make([]interface{}, 2)) }) - t.Run("module-defined function", func(t *testing.T) { table := &wasm.TableInstance{Min: 2, Table: make([]interface{}, 2)} - var importedFunctions []*wasm.FunctionInstance + + m := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []uint32{0, 0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + }, + } + + err := e.CompileModule(m) + require.NoError(t, err) + moduleFunctions := []*wasm.FunctionInstance{ - {DebugName: "1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, + getFunctionInstance(m, 0, nil), + getFunctionInstance(m, 1, nil), + getFunctionInstance(m, 2, nil), + getFunctionInstance(m, 3, nil), } tableInit := map[wasm.Index]wasm.Index{0: 2} // Instantiate the module whose table points to its own functions. - me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) + me, err := e.NewModuleEngine(t.Name(), m, nil, moduleFunctions, table, tableInit) require.NoError(t, err) - defer me.Close() // The functions mapped to the table are defined in the same moduleEngine require.Equal(t, table.Table, et.InitTable(me, table.Min, tableInit)) @@ -122,24 +158,44 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { t.Run("imported function", func(t *testing.T) { table := &wasm.TableInstance{Min: 2, Table: make([]interface{}, 2)} + + importedModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []uint32{0, 0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + }, + } + + err := e.CompileModule(importedModule) + require.NoError(t, err) + + importedModuleInstance := &wasm.ModuleInstance{} importedFunctions := []*wasm.FunctionInstance{ - {DebugName: "1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, + getFunctionInstance(importedModule, 0, importedModuleInstance), + getFunctionInstance(importedModule, 1, importedModuleInstance), + getFunctionInstance(importedModule, 2, importedModuleInstance), + getFunctionInstance(importedModule, 3, importedModuleInstance), } var moduleFunctions []*wasm.FunctionInstance - tableInit := map[wasm.Index]wasm.Index{0: 2} // Imported functions are compiled before the importing module is instantiated. - imported, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) + imported, err := e.NewModuleEngine(t.Name(), importedModule, nil, importedFunctions, nil, nil) require.NoError(t, err) - defer imported.Close() + importedModuleInstance.Engine = imported // Instantiate the importing module, which is whose table is initialized. - importing, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) + importingModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{}, + FunctionSection: []uint32{}, + CodeSection: []*wasm.Code{}, + } + err = e.CompileModule(importingModule) + require.NoError(t, err) + + tableInit := map[wasm.Index]wasm.Index{0: 2} + importing, err := e.NewModuleEngine(t.Name(), importingModule, importedFunctions, moduleFunctions, table, tableInit) require.NoError(t, err) - defer importing.Close() // A moduleEngine's compiled function slice includes its imports, so the offsets is absolute. require.Equal(t, table.Table, et.InitTable(importing, table.Min, tableInit)) @@ -147,29 +203,53 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { t.Run("mixed functions", func(t *testing.T) { table := &wasm.TableInstance{Min: 2, Table: make([]interface{}, 2)} - importedFunctions := []*wasm.FunctionInstance{ - {DebugName: "1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, + + importedModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []uint32{0, 0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + }, } + + err := e.CompileModule(importedModule) + require.NoError(t, err) + importedModuleInstance := &wasm.ModuleInstance{} + importedFunctions := []*wasm.FunctionInstance{ + getFunctionInstance(importedModule, 0, importedModuleInstance), + getFunctionInstance(importedModule, 1, importedModuleInstance), + getFunctionInstance(importedModule, 2, importedModuleInstance), + getFunctionInstance(importedModule, 3, importedModuleInstance), + } + + // Imported functions are compiled before the importing module is instantiated. + imported, err := e.NewModuleEngine(t.Name(), importedModule, nil, importedFunctions, nil, nil) + require.NoError(t, err) + importedModuleInstance.Engine = imported + + importingModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []uint32{0, 0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, {Body: []byte{wasm.OpcodeEnd}}, + }, + } + + err = e.CompileModule(importingModule) + require.NoError(t, err) + + importingModuleInstance := &wasm.ModuleInstance{} moduleFunctions := []*wasm.FunctionInstance{ - {DebugName: "1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, + getFunctionInstance(importedModule, 0, importingModuleInstance), + getFunctionInstance(importedModule, 1, importingModuleInstance), + getFunctionInstance(importedModule, 2, importingModuleInstance), + getFunctionInstance(importedModule, 3, importingModuleInstance), } tableInit := map[wasm.Index]wasm.Index{0: 0, 1: 4} - // Imported functions are compiled before the importing module is instantiated. - imported, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) - require.NoError(t, err) - defer imported.Close() - // Instantiate the importing module, which is whose table is initialized. - importing, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, table, tableInit) + importing, err := e.NewModuleEngine(t.Name(), importingModule, importedFunctions, moduleFunctions, table, tableInit) require.NoError(t, err) - defer importing.Close() // A moduleEngine's compiled function slice includes its imports, so the offsets are absolute. require.Equal(t, table.Table, et.InitTable(importing, table.Min, tableInit)) @@ -177,6 +257,14 @@ func RunTestEngine_NewModuleEngine_InitTable(t *testing.T, et EngineTester) { } func runTestModuleEngine_Call_HostFn_ModuleContext(t *testing.T, et EngineTester) { + features := wasm.Features20191205 + e := et.NewEngine(features) + + sig := &wasm.FunctionType{ + Params: []wasm.ValueType{wasm.ValueTypeI64}, + Results: []wasm.ValueType{wasm.ValueTypeI64}, + } + memory := &wasm.MemoryInstance{} var ctxMemory api.Memory hostFn := reflect.ValueOf(func(ctx api.Module, v uint64) uint64 { @@ -184,27 +272,28 @@ func runTestModuleEngine_Call_HostFn_ModuleContext(t *testing.T, et EngineTester return v }) - features := wasm.Features20191205 - e := et.NewEngine(features) + m := &wasm.Module{ + HostFunctionSection: []*reflect.Value{&hostFn}, + FunctionSection: []wasm.Index{0}, + TypeSection: []*wasm.FunctionType{sig}, + } + + err := e.CompileModule(m) + require.NoError(t, err) + module := &wasm.ModuleInstance{Memory: memory} modCtx := wasm.NewModuleContext(context.Background(), wasm.NewStore(features, e), module, nil) f := &wasm.FunctionInstance{ GoFunc: &hostFn, Kind: wasm.FunctionKindGoModule, - Type: &wasm.FunctionType{ - Params: []wasm.ValueType{wasm.ValueTypeI64}, - Results: []wasm.ValueType{wasm.ValueTypeI64}, - }, + Type: sig, Module: module, Index: 0, } - module.Types = []*wasm.TypeInstance{{Type: f.Type}} - module.Functions = []*wasm.FunctionInstance{f} - me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, module.Functions, nil, nil) + me, err := e.NewModuleEngine(t.Name(), m, nil, []*wasm.FunctionInstance{f}, nil, nil) require.NoError(t, err) - defer me.Close() t.Run("defaults to module memory when call stack empty", func(t *testing.T) { // When calling a host func directly, there may be no stack. This ensures the module's memory is used. @@ -220,9 +309,8 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) - imported, importedMe, importing, importingMe := setupCallTests(t, e) - defer importingMe.Close() - defer importedMe.Close() + host, imported, importing, close := setupCallTests(t, e) + defer close() // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. tests := []struct { @@ -237,8 +325,8 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { }, { name: hostFnName, - module: imported.Ctx, - fn: imported.Exports[hostFnName].Function, + module: host.Ctx, + fn: host.Exports[hostFnName].Function, }, { name: callHostFnName, @@ -267,9 +355,8 @@ func RunTestModuleEngine_Call_HostFn(t *testing.T, et EngineTester) { func RunTestModuleEngine_Call_Errors(t *testing.T, et EngineTester) { e := et.NewEngine(wasm.Features20191205) - imported, importedMe, importing, importingMe := setupCallTests(t, e) - defer importingMe.Close() - defer importedMe.Close() + host, imported, importing, close := setupCallTests(t, e) + defer close() tests := []struct { name string @@ -281,15 +368,15 @@ func RunTestModuleEngine_Call_Errors(t *testing.T, et EngineTester) { { name: "host function not enough parameters", input: []uint64{}, - module: imported.Ctx, - fn: imported.Exports[hostFnName].Function, + module: host.Ctx, + fn: host.Exports[hostFnName].Function, expectedErr: `expected 1 params, but passed 0`, }, { name: "host function too many parameters", input: []uint64{1, 2}, - module: imported.Ctx, - fn: imported.Exports[hostFnName].Function, + module: host.Ctx, + fn: host.Exports[hostFnName].Function, expectedErr: `expected 1 params, but passed 2`, }, { @@ -318,20 +405,20 @@ wasm stack trace: { name: "host function that panics", input: []uint64{math.MaxUint32}, - module: imported.Ctx, - fn: imported.Exports[hostFnName].Function, + module: host.Ctx, + fn: host.Exports[hostFnName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - imported.host_div_by(i32) i32`, + host.host_div_by(i32) i32`, }, { name: "host function panics with runtime.Error", input: []uint64{0}, - module: imported.Ctx, - fn: imported.Exports[hostFnName].Function, + module: host.Ctx, + fn: host.Exports[hostFnName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - imported.host_div_by(i32) i32`, + host.host_div_by(i32) i32`, }, { name: "wasm calls host function that panics", @@ -340,7 +427,7 @@ wasm stack trace: fn: imported.Exports[callHostFnName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - imported.host_div_by(i32) i32 + host.host_div_by(i32) i32 imported.call->host_div_by(i32) i32`, }, { @@ -350,7 +437,7 @@ wasm stack trace: fn: importing.Exports[callImportCallHostFnName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - imported.host_div_by(i32) i32 + host.host_div_by(i32) i32 imported.call->host_div_by(i32) i32 importing.call_import->call->host_div_by(i32) i32`, }, @@ -361,7 +448,7 @@ wasm stack trace: fn: importing.Exports[callImportCallHostFnName].Function, expectedErr: `host-function panic (recovered by wazero) wasm stack trace: - imported.host_div_by(i32) i32 + host.host_div_by(i32) i32 imported.call->host_div_by(i32) i32 importing.call_import->call->host_div_by(i32) i32`, }, @@ -372,7 +459,7 @@ wasm stack trace: fn: importing.Exports[callImportCallHostFnName].Function, expectedErr: `runtime error: integer divide by zero (recovered by wazero) wasm stack trace: - imported.host_div_by(i32) i32 + host.host_div_by(i32) i32 imported.call->host_div_by(i32) i32 importing.call_import->call->host_div_by(i32) i32`, }, @@ -410,64 +497,80 @@ func divBy(d uint32) uint32 { return 1 / d // go panics if d == 0 } -func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, wasm.ModuleEngine, *wasm.ModuleInstance, wasm.ModuleEngine) { +func setupCallTests(t *testing.T, e wasm.Engine) (*wasm.ModuleInstance, *wasm.ModuleInstance, *wasm.ModuleInstance, func()) { i32 := wasm.ValueTypeI32 ft := &wasm.FunctionType{Params: []wasm.ValueType{i32}, Results: []wasm.ValueType{i32}} - wasmFn := &wasm.FunctionInstance{ - Kind: wasm.FunctionKindWasm, - Type: ft, - Body: wasmFnBody, - Index: 0, - } + hostFnVal := reflect.ValueOf(divBy) - hostFn := &wasm.FunctionInstance{ - Kind: wasm.FunctionKindGoNoContext, - Type: ft, - GoFunc: &hostFnVal, - Index: 1, + hostFnModule := &wasm.Module{ + HostFunctionSection: []*reflect.Value{&hostFnVal}, + TypeSection: []*wasm.FunctionType{ft}, + FunctionSection: []wasm.Index{0}, } - callHostFn := &wasm.FunctionInstance{ - Kind: wasm.FunctionKindWasm, - Type: ft, - Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, byte(hostFn.Index), wasm.OpcodeEnd}, - Index: 2, + + err := e.CompileModule(hostFnModule) + require.NoError(t, err) + hostFn := &wasm.FunctionInstance{GoFunc: &hostFnVal, Kind: wasm.FunctionKindGoNoContext, Type: ft} + hostFnModuleInstance := &wasm.ModuleInstance{Name: "host"} + addFunction(hostFnModuleInstance, hostFnName, hostFn) + hostFnME, err := e.NewModuleEngine(hostFnModuleInstance.Name, hostFnModule, nil, hostFnModuleInstance.Functions, nil, nil) + require.NoError(t, err) + linkModuleToEngine(hostFnModuleInstance, hostFnME) + + importedModule := &wasm.Module{ + ImportSection: []*wasm.Import{{}}, + TypeSection: []*wasm.FunctionType{ft}, + FunctionSection: []uint32{0, 0}, + CodeSection: []*wasm.Code{ + {Body: wasmFnBody}, + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, byte(0), // Calling imported host function ^. + wasm.OpcodeEnd}}, + }, } + err = e.CompileModule(importedModule) + require.NoError(t, err) + // To use the function, we first need to add it to a module. imported := &wasm.ModuleInstance{Name: "imported"} - addFunction(imported, wasmFnName, wasmFn) - addFunction(imported, hostFnName, hostFn) + addFunction(imported, wasmFnName, getFunctionInstance(importedModule, 0, imported)) + callHostFn := getFunctionInstance(importedModule, 1, imported) addFunction(imported, callHostFnName, callHostFn) // Compile the imported module - importedMe, err := e.NewModuleEngine(imported.Name, &wasm.Module{}, nil, imported.Functions, nil, nil) + importedMe, err := e.NewModuleEngine(imported.Name, importedModule, hostFnModuleInstance.Functions, imported.Functions, nil, nil) require.NoError(t, err) linkModuleToEngine(imported, importedMe) // To test stack traces, call the same function from another module - importing := &wasm.ModuleInstance{Name: "importing"} - - // Don't add imported functions yet as NewModuleEngine requires them split. - importedFunctions := []*wasm.FunctionInstance{callHostFn} + importingModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{ft}, + FunctionSection: []uint32{0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 0 /* only one imported function */, wasm.OpcodeEnd}}, + }, + ImportSection: []*wasm.Import{{}}, + } + err = e.CompileModule(importingModule) + require.NoError(t, err) // Add the exported function. - callImportedHostFn := &wasm.FunctionInstance{ - Kind: wasm.FunctionKindWasm, - Type: ft, - Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeCall, 0 /* only one imported function */, wasm.OpcodeEnd}, - Index: 1, // after import - } - addFunction(importing, callImportCallHostFnName, callImportedHostFn) + importing := &wasm.ModuleInstance{Name: "importing"} + addFunction(importing, callImportCallHostFnName, getFunctionInstance(importedModule, 0, importing)) // Compile the importing module - importingMe, err := e.NewModuleEngine(importing.Name, &wasm.Module{}, importedFunctions, importing.Functions, nil, nil) + importingMe, err := e.NewModuleEngine(importing.Name, importingModule, []*wasm.FunctionInstance{callHostFn}, importing.Functions, nil, nil) require.NoError(t, err) linkModuleToEngine(importing, importingMe) // Add the imported functions back to the importing module. - importing.Functions = append(importedFunctions, importing.Functions...) + importing.Functions = append([]*wasm.FunctionInstance{callHostFn}, importing.Functions...) - return imported, importedMe, importing, importingMe + return hostFnModuleInstance, imported, importing, func() { + e.DeleteCompiledModule(hostFnModule) + e.DeleteCompiledModule(importedModule) + e.DeleteCompiledModule(importingModule) + } } // linkModuleToEngine assigns fields that wasm.Store would on instantiation. These includes fields both interpreter and diff --git a/internal/wasm/engine.go b/internal/wasm/engine.go index 46138fc5..b54253d0 100644 --- a/internal/wasm/engine.go +++ b/internal/wasm/engine.go @@ -3,6 +3,9 @@ package wasm // Engine is a Store-scoped mechanism to compile functions declared or imported by a module. // This is a top-level type implemented by an interpreter or JIT compiler. type Engine interface { + // CompileModule implements the same method as documented on wasm.Engine. + CompileModule(module *Module) error + // NewModuleEngine compiles down the function instances in a module, and returns ModuleEngine for the module. // // * name is the name the module was instantiated with used for error handling. @@ -16,8 +19,10 @@ type Engine interface { // due to reasons such as out-of-bounds. NewModuleEngine(name string, module *Module, importedFunctions, moduleFunctions []*FunctionInstance, table *TableInstance, tableInit map[Index]Index) (ModuleEngine, error) - // ReleaseCompilationCache releases compilation caches for the given module (source). - ReleaseCompilationCache(module *Module) + // DeleteCompiledModule releases compilation caches for the given module (source). + // Note: it is safe to call this function for a module from which module instances are instantiated even when these module instances + // are having outstanding calls. + DeleteCompiledModule(module *Module) } // ModuleEngine implements function calls for a given module. @@ -29,7 +34,4 @@ type ModuleEngine interface { // Returns the results from the function. // The ctx's context.Context will be the outer-most ancestor of the argument to api.Function. Call(ctx *ModuleContext, f *FunctionInstance, params ...uint64) (results []uint64, err error) - - // Close releases all the function instances declared in this module. - Close() } diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index 43a107eb..bea11d33 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -729,7 +729,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("invalid numeric instruction 0x%x", op) } } else if op == OpcodeBlock { - bt, num, err := decodeBlockType(types, bytes.NewReader(body[pc+1:]), enabledFeatures) + bt, num, err := DecodeBlockType(types, bytes.NewReader(body[pc+1:]), enabledFeatures) if err != nil { return fmt.Errorf("read block: %w", err) } @@ -741,7 +741,7 @@ func (m *Module) validateFunctionWithMaxStackValues( valueTypeStack.pushStackLimit(len(bt.Params)) pc += num } else if op == OpcodeLoop { - bt, num, err := decodeBlockType(types, bytes.NewReader(body[pc+1:]), enabledFeatures) + bt, num, err := DecodeBlockType(types, bytes.NewReader(body[pc+1:]), enabledFeatures) if err != nil { return fmt.Errorf("read block: %w", err) } @@ -761,7 +761,7 @@ func (m *Module) validateFunctionWithMaxStackValues( valueTypeStack.pushStackLimit(len(bt.Params)) pc += num } else if op == OpcodeIf { - bt, num, err := decodeBlockType(types, bytes.NewReader(body[pc+1:]), enabledFeatures) + bt, num, err := DecodeBlockType(types, bytes.NewReader(body[pc+1:]), enabledFeatures) if err != nil { return fmt.Errorf("read block: %w", err) } @@ -1117,7 +1117,7 @@ type controlBlock struct { op Opcode } -func decodeBlockType(types []*FunctionType, r *bytes.Reader, enabledFeatures Features) (*FunctionType, uint64, error) { +func DecodeBlockType(types []*FunctionType, r *bytes.Reader, enabledFeatures Features) (*FunctionType, uint64, error) { return decodeBlockTypeImpl(func(index int64) (*FunctionType, error) { if index < 0 || (index >= int64(len(types))) { return nil, fmt.Errorf("type index out of range: %d", index) @@ -1126,16 +1126,6 @@ func decodeBlockType(types []*FunctionType, r *bytes.Reader, enabledFeatures Fea }, r, enabledFeatures) } -// DecodeBlockType is exported for use in the compiler -func DecodeBlockType(types []*TypeInstance, r *bytes.Reader, enabledFeatures Features) (*FunctionType, uint64, error) { - return decodeBlockTypeImpl(func(index int64) (*FunctionType, error) { - if index < 0 || (index >= int64(len(types))) { - return nil, fmt.Errorf("type index out of range: %d", index) - } - return types[index].Type, nil - }, r, enabledFeatures) -} - // decodeBlockTypeImpl decodes the type index from a positive 33-bit signed integer. Negative numbers indicate up to one // WebAssembly 1.0 (20191205) compatible result type. Positive numbers are decoded when `enabledFeatures` include // FeatureMultiValue and include an index in the Module.TypeSection. diff --git a/internal/wasm/host.go b/internal/wasm/host.go index 38053261..c330c163 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -67,6 +67,10 @@ func NewHostModule( return } +func (m *Module) IsHostModule() bool { + return len(m.HostFunctionSection) > 0 +} + func addFuncs(m *Module, nameToGoFunc map[string]interface{}, enabledFeatures Features) error { funcCount := uint32(len(nameToGoFunc)) funcNames := make([]string, 0, funcCount) diff --git a/internal/wasm/interpreter/interpreter.go b/internal/wasm/interpreter/interpreter.go index 2abafea4..4c144895 100644 --- a/internal/wasm/interpreter/interpreter.go +++ b/internal/wasm/interpreter/interpreter.go @@ -21,60 +21,39 @@ var callStackCeiling = buildoptions.CallStackCeiling // engine is an interpreter implementation of wasm.Engine type engine struct { - enabledFeatures wasm.Features - compiledFunctions map[*wasm.FunctionInstance]*compiledFunction // guarded by mutex. - cachedCompiledFunctionsPerModule map[*wasm.Module][]*compiledFunction // guarded by mutex. - mux sync.RWMutex + enabledFeatures wasm.Features + codes map[*wasm.Module][]*code // guarded by mutex. + mux sync.RWMutex } func NewEngine(enabledFeatures wasm.Features) wasm.Engine { return &engine{ - enabledFeatures: enabledFeatures, - compiledFunctions: make(map[*wasm.FunctionInstance]*compiledFunction), - cachedCompiledFunctionsPerModule: map[*wasm.Module][]*compiledFunction{}, + enabledFeatures: enabledFeatures, + codes: map[*wasm.Module][]*code{}, } } -// ReleaseCompilationCache implements the same method as documented on wasm.Engine. -func (e *engine) ReleaseCompilationCache(m *wasm.Module) { - e.deleteCachedCompiledFunctions(m) +// DeleteCompiledModule implements the same method as documented on wasm.Engine. +func (e *engine) DeleteCompiledModule(m *wasm.Module) { + e.deleteCodes(m) } -func (e *engine) deleteCompiledFunction(f *wasm.FunctionInstance) { +func (e *engine) deleteCodes(module *wasm.Module) { e.mux.Lock() defer e.mux.Unlock() - delete(e.compiledFunctions, f) + delete(e.codes, module) } -func (e *engine) getCompiledFunction(f *wasm.FunctionInstance) (cf *compiledFunction, ok bool) { +func (e *engine) addCodes(module *wasm.Module, fs []*code) { + e.mux.Lock() + defer e.mux.Unlock() + e.codes[module] = fs +} + +func (e *engine) getCodes(module *wasm.Module) (fs []*code, ok bool) { e.mux.RLock() defer e.mux.RUnlock() - cf, ok = e.compiledFunctions[f] - return -} - -func (e *engine) addCompiledFunction(f *wasm.FunctionInstance, cf *compiledFunction) { - e.mux.Lock() - defer e.mux.Unlock() - e.compiledFunctions[f] = cf -} - -func (e *engine) deleteCachedCompiledFunctions(module *wasm.Module) { - e.mux.Lock() - defer e.mux.Unlock() - delete(e.cachedCompiledFunctionsPerModule, module) -} - -func (e *engine) addCachedCompiledFunctions(module *wasm.Module, fs []*compiledFunction) { - e.mux.Lock() - defer e.mux.Unlock() - e.cachedCompiledFunctionsPerModule[module] = fs -} - -func (e *engine) getCachedCompiledFunctions(module *wasm.Module) (fs []*compiledFunction, ok bool) { - e.mux.RLock() - defer e.mux.RUnlock() - fs, ok = e.cachedCompiledFunctionsPerModule[module] + fs, ok = e.codes[module] return } @@ -83,9 +62,9 @@ type moduleEngine struct { // name is the name the module was instantiated with used for error handling. name string - // compiledFunctions are the compiled functions in a module instances. + // codes are the compiled functions in a module instances. // The index is module instance-scoped. - compiledFunctions []*compiledFunction + functions []*function // parentEngine holds *engine from which this module engine is created from. parentEngine *engine @@ -158,25 +137,28 @@ func (ce *callEngine) popFrame() (frame *callFrame) { } type callFrame struct { - // pc is the program counter representing the current position in compiledFunction.body. + // pc is the program counter representing the current position in code.body. pc uint64 // f is the compiled function used in this function frame. - f *compiledFunction + f *function } -type compiledFunction struct { - moduleEngine *moduleEngine - source *wasm.FunctionInstance - body []*interpreterOp - hostFn *reflect.Value +type code struct { + body []*interpreterOp + hostFn *reflect.Value } -func (c *compiledFunction) clone(me *moduleEngine, newSourceInstance *wasm.FunctionInstance) *compiledFunction { - return &compiledFunction{ - moduleEngine: me, - source: newSourceInstance, - body: c.body, - hostFn: c.hostFn, +type function struct { + source *wasm.FunctionInstance + body []*interpreterOp + hostFn *reflect.Value +} + +func (c *code) instantiate(f *wasm.FunctionInstance) *function { + return &function{ + source: f, + body: c.body, + hostFn: c.hostFn, } } @@ -188,78 +170,73 @@ type interpreterOp struct { rs []*wazeroir.InclusiveRange } +// CompileModule implements the same method as documented on wasm.Engine. +func (e *engine) CompileModule(module *wasm.Module) error { + if _, ok := e.getCodes(module); ok { // cache hit! + return nil + } + + funcs := make([]*code, 0, len(module.FunctionSection)) + if module.IsHostModule() { + // If this is the host module, there's nothing to do as the runtime reprsentation of + // host function in interpreter is its Go function itself as opposed to Wasm functions, + // which need to be compiled down to wazeroir. + for _, hf := range module.HostFunctionSection { + funcs = append(funcs, &code{hostFn: hf}) + } + } else { + irs, err := wazeroir.CompileFunctions(e.enabledFeatures, module) + if err != nil { + return err + } + for i, ir := range irs { + compiled, err := e.lowerIR(ir) + if err != nil { + return fmt.Errorf("function[%d/%d] failed to convert wazeroir operations: %w", i, len(module.FunctionSection)-1, err) + } + funcs = append(funcs, compiled) + } + } + e.addCodes(module, funcs) + return nil + +} + // NewModuleEngine implements the same method as documented on wasm.Engine. -func (e *engine) NewModuleEngine(name string, source *wasm.Module, importedFunctions, moduleFunctions []*wasm.FunctionInstance, table *wasm.TableInstance, tableInit map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { +func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunctions, moduleFunctions []*wasm.FunctionInstance, table *wasm.TableInstance, tableInit map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { imported := uint32(len(importedFunctions)) me := &moduleEngine{ name: name, - compiledFunctions: make([]*compiledFunction, 0, imported+uint32(len(moduleFunctions))), parentEngine: e, importedFunctionCount: imported, } - for idx, f := range importedFunctions { - cf, ok := e.getCompiledFunction(f) - if !ok { - return nil, fmt.Errorf("import[%d] func[%s]: uncompiled", idx, f.DebugName) - } - me.compiledFunctions = append(me.compiledFunctions, cf) + for _, f := range importedFunctions { + cf := f.Module.Engine.(*moduleEngine).functions[f.Index] + me.functions = append(me.functions, cf) } - if cached, ok := e.getCachedCompiledFunctions(source); ok { // cache hit - for i, c := range cached[len(importedFunctions):] { - cloned := c.clone(me, moduleFunctions[i]) - me.compiledFunctions = append(me.compiledFunctions, cloned) - } - } else { // cache miss - for i, f := range moduleFunctions { - var compiled *compiledFunction - if f.Kind == wasm.FunctionKindWasm { - ir, err := wazeroir.Compile(e.enabledFeatures, f) - if err != nil { - me.Close() - // TODO(Adrian): extract Module.funcDesc so that errors here have more context - return nil, fmt.Errorf("function[%d/%d] failed to lower to wazeroir: %w", i, len(moduleFunctions)-1, err) - } + codes, ok := e.getCodes(module) + if !ok { + return nil, fmt.Errorf("source module for %s must be compiled before instantiation", name) + } - compiled, err = e.lowerIROps(f, ir.Operations) - if err != nil { - me.Close() - return nil, fmt.Errorf("function[%d/%d] failed to convert wazeroir operations: %w", i, len(moduleFunctions)-1, err) - } - } else { - compiled = &compiledFunction{hostFn: f.GoFunc, source: f} - } - compiled.moduleEngine = me - me.compiledFunctions = append(me.compiledFunctions, compiled) - - // Add the compiled function to the store-wide engine as well so that - // the future importing module can refer the function instance. - e.addCompiledFunction(f, compiled) - } - - e.addCachedCompiledFunctions(source, me.compiledFunctions) + for i, c := range codes { + f := moduleFunctions[i] + insntantiatedcode := c.instantiate(f) + me.functions = append(me.functions, insntantiatedcode) } for elemIdx, funcidx := range tableInit { // Initialize any elements with compiled functions - table.Table[elemIdx] = me.compiledFunctions[funcidx] + table.Table[elemIdx] = me.functions[funcidx] } return me, nil } -// Release implements wasm.Engine Release -func (me *moduleEngine) Release() error { - // Release all the function instances declared in this module. - for _, cf := range me.compiledFunctions[me.importedFunctionCount:] { - me.parentEngine.deleteCompiledFunction(cf.source) - } - return nil -} - -// lowerIROps lowers the wazeroir operations to engine friendly struct. -func (e *engine) lowerIROps(f *wasm.FunctionInstance, - ops []wazeroir.Operation) (*compiledFunction, error) { - ret := &compiledFunction{source: f} +// lowerIR lowers the wazeroir operations to engine friendly struct. +func (e *engine) lowerIR(ir *wazeroir.CompilationResult) (*code, error) { + ops := ir.Operations + ret := &code{} labelAddress := map[string]uint64{} onLabelAddressResolved := map[string][]func(addr uint64){} for _, original := range ops { @@ -352,9 +329,8 @@ func (e *engine) lowerIROps(f *wasm.FunctionInstance, op.us = make([]uint64, 1) op.us = []uint64{uint64(o.FunctionIndex)} case *wazeroir.OperationCallIndirect: - op.us = make([]uint64, 2) - op.us[0] = uint64(o.TableIndex) - op.us[1] = uint64(f.Module.Types[o.TypeIndex].TypeID) + op.us = make([]uint64, 1) + op.us[0] = uint64(o.TypeIndex) case *wazeroir.OperationDrop: op.rs = make([]*wazeroir.InclusiveRange, 1) op.rs[0] = o.Depth @@ -538,11 +514,11 @@ func (me *moduleEngine) Name() string { // Call implements the same method as documented on wasm.ModuleEngine. func (me *moduleEngine) Call(m *wasm.ModuleContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to - // compiledFunctions on close aren't locked, neither is this read. - compiled := me.compiledFunctions[f.Index] + // codes on close aren't locked, neither is this read. + compiled := me.functions[f.Index] if compiled == nil { // Lazy check the cause as it could be because the module was already closed. if err = m.FailIfClosed(); err == nil { - panic(fmt.Errorf("BUG: %s.compiledFunctions[%d] was nil before close", me.name, f.Index)) + panic(fmt.Errorf("BUG: %s.codes[%d] was nil before close", me.name, f.Index)) } return } @@ -588,7 +564,7 @@ func (me *moduleEngine) Call(m *wasm.ModuleContext, f *wasm.FunctionInstance, pa return } -func (ce *callEngine) callHostFunc(ctx *wasm.ModuleContext, f *compiledFunction) { +func (ce *callEngine) callHostFunc(ctx *wasm.ModuleContext, f *function) { tp := f.hostFn.Type() in := make([]reflect.Value, tp.NumIn()) @@ -642,13 +618,14 @@ func (ce *callEngine) callHostFunc(ctx *wasm.ModuleContext, f *compiledFunction) ce.popFrame() } -func (ce *callEngine) callNativeFunc(ctx *wasm.ModuleContext, f *compiledFunction) { +func (ce *callEngine) callNativeFunc(ctx *wasm.ModuleContext, f *function) { frame := &callFrame{f: f} moduleInst := f.source.Module memoryInst := moduleInst.Memory globals := moduleInst.Globals table := moduleInst.Table - compiledFunctions := f.moduleEngine.compiledFunctions + typeIDs := f.source.Module.TypeIDs + functions := f.source.Module.Engine.(*moduleEngine).functions ce.pushFrame(frame) bodyLen := uint64(len(frame.f.body)) for frame.pc < bodyLen { @@ -686,7 +663,7 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.ModuleContext, f *compiledFunctio } case wazeroir.OperationKindCall: { - f := compiledFunctions[op.us[0]] + f := functions[op.us[0]] if f.hostFn != nil { ce.callHostFunc(ctx, f) } else { @@ -700,18 +677,18 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.ModuleContext, f *compiledFunctio if offset >= uint64(len(table.Table)) { panic(wasmruntime.ErrRuntimeInvalidTableAccess) } - targetCompiledFunction, ok := table.Table[offset].(*compiledFunction) + targetcode, ok := table.Table[offset].(*function) if !ok { panic(wasmruntime.ErrRuntimeInvalidTableAccess) - } else if uint64(targetCompiledFunction.source.TypeID) != op.us[1] { + } else if targetcode.source.TypeID != typeIDs[op.us[0]] { panic(wasmruntime.ErrRuntimeIndirectCallTypeMismatch) } // Call in. - if targetCompiledFunction.hostFn != nil { - ce.callHostFunc(ctx, targetCompiledFunction) + if targetcode.hostFn != nil { + ce.callHostFunc(ctx, targetcode) } else { - ce.callNativeFunc(ctx, targetCompiledFunction) + ce.callNativeFunc(ctx, targetcode) } frame.pc++ } @@ -1621,10 +1598,3 @@ func (ce *callEngine) callNativeFunc(ctx *wasm.ModuleContext, f *compiledFunctio } ce.popFrame() } - -// Close releases all the function instances declared in this module. -func (me *moduleEngine) Close() { - for _, cf := range me.compiledFunctions[me.importedFunctionCount:] { - me.parentEngine.deleteCompiledFunction(cf.source) - } -} diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index 6ede0c41..b47c3571 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -3,9 +3,7 @@ package interpreter import ( "fmt" "math" - "strconv" "testing" - "unsafe" "github.com/tetratelabs/wazero/internal/buildoptions" "github.com/tetratelabs/wazero/internal/testing/enginetest" @@ -61,7 +59,7 @@ func (e engineTester) InitTable(me wasm.ModuleEngine, initTableLen uint32, initT table := make([]interface{}, initTableLen) internal := me.(*moduleEngine) for idx, fnidx := range initTableIdxToFnIdx { - table[idx] = internal.compiledFunctions[fnidx] + table[idx] = internal.functions[fnidx] } return table } @@ -129,9 +127,8 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { tc := tc t.Run(fmt.Sprintf("%s(i32.const(0x%x))", wasm.InstructionName(tc.opcode), tc.in), func(t *testing.T) { ce := &callEngine{} - f := &compiledFunction{ - moduleEngine: &moduleEngine{}, - source: &wasm.FunctionInstance{Module: &wasm.ModuleInstance{}}, + f := &function{ + source: &wasm.FunctionInstance{Module: &wasm.ModuleInstance{Engine: &moduleEngine{}}}, body: []*interpreterOp{ {kind: wazeroir.OperationKindConstI32, us: []uint64{uint64(uint32(tc.in))}}, {kind: translateToIROperationKind(tc.opcode)}, @@ -182,9 +179,8 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { tc := tc t.Run(fmt.Sprintf("%s(i64.const(0x%x))", wasm.InstructionName(tc.opcode), tc.in), func(t *testing.T) { ce := &callEngine{} - f := &compiledFunction{ - moduleEngine: &moduleEngine{}, - source: &wasm.FunctionInstance{Module: &wasm.ModuleInstance{}}, + f := &function{ + source: &wasm.FunctionInstance{Module: &wasm.ModuleInstance{Engine: &moduleEngine{}}}, body: []*interpreterOp{ {kind: wazeroir.OperationKindConstI64, us: []uint64{uint64(tc.in)}}, {kind: translateToIROperationKind(tc.opcode)}, @@ -198,215 +194,81 @@ func TestInterpreter_CallEngine_callNativeFunc_signExtend(t *testing.T) { }) } -func TestInterpreter_EngineCompile_Errors(t *testing.T) { - t.Run("invalid import", func(t *testing.T) { +func TestInterpreter_Compile(t *testing.T) { + t.Run("uncompiled", func(t *testing.T) { e := et.NewEngine(wasm.Features20191205).(*engine) - _, err := e.NewModuleEngine(t.Name(), + _, err := e.NewModuleEngine("foo", &wasm.Module{}, - []*wasm.FunctionInstance{{Module: &wasm.ModuleInstance{Name: "uncompiled"}, DebugName: "uncompiled.fn"}}, + nil, // imports nil, // moduleFunctions nil, // table nil, // tableInit ) - require.EqualError(t, err, "import[0] func[uncompiled.fn]: uncompiled") + require.EqualError(t, err, "source module for foo must be compiled before instantiation") }) - - t.Run("release on compilation error", func(t *testing.T) { + t.Run("fail", func(t *testing.T) { e := et.NewEngine(wasm.Features20191205).(*engine) - importedFunctions := []*wasm.FunctionInstance{ - {DebugName: "1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, + errModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []wasm.Index{0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeCall}}, // Call instruction without immediate for call target index is invalid and should fail to compile. + }, } - // initialize the module-engine containing imported functions - _, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) - require.NoError(t, err) - - require.Equal(t, len(importedFunctions), len(e.compiledFunctions)) - - moduleFunctions := []*wasm.FunctionInstance{ - {DebugName: "ok1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "ok2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "invalid code", Type: &wasm.FunctionType{}, Body: []byte{ - wasm.OpcodeCall, // Call instruction without immediate for call target index is invalid and should fail to compile. - }, Module: &wasm.ModuleInstance{}}, - } - - _, err = e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, nil, nil) - require.EqualError(t, err, "function[2/2] failed to lower to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") + err := e.CompileModule(errModule) + require.EqualError(t, err, "failed to lower func[2/3] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") // On the compilation failure, all the compiled functions including succeeded ones must be released. - require.Equal(t, len(importedFunctions), len(e.compiledFunctions)) - for _, f := range moduleFunctions { - require.Nil(t, e.compiledFunctions[f]) + _, ok := e.codes[errModule] + require.False(t, ok) + }) + t.Run("ok", func(t *testing.T) { + e := et.NewEngine(wasm.Features20191205).(*engine) + + okModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []wasm.Index{0, 0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + }, } + err := e.CompileModule(okModule) + require.NoError(t, err) + + compiled, ok := e.codes[okModule] + require.True(t, ok) + require.Equal(t, len(okModule.FunctionSection), len(compiled)) + + _, ok = e.codes[okModule] + require.True(t, ok) }) } -func TestInterpreter_Close(t *testing.T) { - for _, tc := range []struct { - name string - importedFunctions, moduleFunctions []*wasm.FunctionInstance - }{ - { - name: "only module-defined", - moduleFunctions: []*wasm.FunctionInstance{newFunctionInstance(0), newFunctionInstance(1)}, - }, - { - name: "only imports", - importedFunctions: []*wasm.FunctionInstance{newFunctionInstance(0), newFunctionInstance(1)}, - }, - { - name: "imports and module-defined", - importedFunctions: []*wasm.FunctionInstance{newFunctionInstance(0), newFunctionInstance(1)}, - moduleFunctions: []*wasm.FunctionInstance{newFunctionInstance(100), newFunctionInstance(200), newFunctionInstance(300)}, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - e := et.NewEngine(wasm.Features20191205).(*engine) - if len(tc.importedFunctions) > 0 { - // initialize the module-engine containing imported functions - me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, tc.importedFunctions, nil, nil) - require.NoError(t, err) - require.Equal(t, len(tc.importedFunctions), len(me.(*moduleEngine).compiledFunctions)) - } - - me, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, tc.importedFunctions, tc.moduleFunctions, nil, nil) - require.NoError(t, err) - require.Equal(t, len(tc.importedFunctions)+len(tc.moduleFunctions), len(me.(*moduleEngine).compiledFunctions)) - - require.Equal(t, len(tc.importedFunctions)+len(tc.moduleFunctions), len(e.compiledFunctions)) - for _, f := range tc.importedFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - } - for _, f := range tc.moduleFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - } - - me.Close() - - require.Equal(t, len(tc.importedFunctions), len(e.compiledFunctions)) - for _, f := range tc.importedFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - } - for i, f := range tc.moduleFunctions { - require.Nil(t, e.compiledFunctions[f], i) - } - }) - } -} - -func TestEngine_CachedCompiledFunctionsPerModule(t *testing.T) { +func TestEngine_CachedcodesPerModule(t *testing.T) { e := et.NewEngine(wasm.Features20191205).(*engine) - exp := []*compiledFunction{ - {source: &wasm.FunctionInstance{DebugName: "1"}}, - {source: &wasm.FunctionInstance{DebugName: "2"}}, + exp := []*code{ + {body: []*interpreterOp{}}, + {body: []*interpreterOp{}}, } m := &wasm.Module{} - e.addCachedCompiledFunctions(m, exp) + e.addCodes(m, exp) - actual, ok := e.getCachedCompiledFunctions(m) + actual, ok := e.getCodes(m) require.True(t, ok) require.Equal(t, len(exp), len(actual)) for i := range actual { require.Equal(t, exp[i], actual[i]) } - e.deleteCachedCompiledFunctions(m) - _, ok = e.getCachedCompiledFunctions(m) + e.deleteCodes(m) + _, ok = e.getCodes(m) require.False(t, ok) } - -func TestEngine_NewModuleEngine_cache(t *testing.T) { - e := et.NewEngine(wasm.Features20191205).(*engine) - importedModuleSource := &wasm.Module{} - - // No cache. - importedME, err := e.NewModuleEngine("1", importedModuleSource, nil, []*wasm.FunctionInstance{ - newFunctionInstance(1), - newFunctionInstance(2), - }, nil, nil) - require.NoError(t, err) - - // Cached. - importedMEFromCache, err := e.NewModuleEngine("2", importedModuleSource, nil, []*wasm.FunctionInstance{ - newFunctionInstance(1), - newFunctionInstance(2), - }, nil, nil) - require.NoError(t, err) - - require.NotEqual(t, importedME, importedMEFromCache) - require.NotEqual(t, importedME.Name(), importedMEFromCache.Name()) - - // Check compiled functions. - ime, imeCache := importedME.(*moduleEngine), importedMEFromCache.(*moduleEngine) - require.Equal(t, len(ime.compiledFunctions), len(imeCache.compiledFunctions)) - - for i, fn := range ime.compiledFunctions { - // Compiled functions must be cloend. - fnCached := imeCache.compiledFunctions[i] - require.NotEqual(t, fn, fnCached) - require.NotEqual(t, fn.moduleEngine, fnCached.moduleEngine) - require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. - // But the body stays the same. - require.Equal(t, fn.body, fnCached.body) - } - - // Next is to veirfy the caching works for modules with imports. - importedFunc := ime.compiledFunctions[0].source - moduleSource := &wasm.Module{} - - // No cache. - modEng, err := e.NewModuleEngine("3", moduleSource, - []*wasm.FunctionInstance{importedFunc}, // Import one function. - []*wasm.FunctionInstance{ - newFunctionInstance(10), - newFunctionInstance(20), - }, nil, nil) - require.NoError(t, err) - - // Cached. - modEngCache, err := e.NewModuleEngine("4", moduleSource, - []*wasm.FunctionInstance{importedFunc}, // Import one function. - []*wasm.FunctionInstance{ - newFunctionInstance(10), - newFunctionInstance(20), - }, nil, nil) - require.NoError(t, err) - - require.NotEqual(t, modEng, modEngCache) - require.NotEqual(t, modEng.Name(), modEngCache.Name()) - - me, meCache := modEng.(*moduleEngine), modEngCache.(*moduleEngine) - require.Equal(t, len(me.compiledFunctions), len(meCache.compiledFunctions)) - - for i, fn := range me.compiledFunctions { - fnCached := meCache.compiledFunctions[i] - if i == 0 { - // This case the function is imported, so it must be the same for both module engines. - require.Equal(t, fn, fnCached) - require.Equal(t, importedFunc, fn.source) - } else { - // Compiled functions must be cloend. - require.NotEqual(t, fn, fnCached) - require.NotEqual(t, fn.moduleEngine, fnCached.moduleEngine) - require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. - // But the code segment stays the same. - require.Equal(t, fn.body, fnCached.body) - } - } -} - -func newFunctionInstance(id int) *wasm.FunctionInstance { - return &wasm.FunctionInstance{ - DebugName: strconv.Itoa(id), - Type: &wasm.FunctionType{}, - Body: []byte{wasm.OpcodeEnd}, - Module: &wasm.ModuleInstance{}, - } -} diff --git a/internal/wasm/jit/arch_amd64.go b/internal/wasm/jit/arch_amd64.go index 60298e7b..6589b92e 100644 --- a/internal/wasm/jit/arch_amd64.go +++ b/internal/wasm/jit/arch_amd64.go @@ -1,7 +1,6 @@ package jit import ( - "github.com/tetratelabs/wazero/internal/wasm" "github.com/tetratelabs/wazero/internal/wazeroir" ) @@ -24,6 +23,6 @@ func init() { // newCompiler returns a new compiler interface which can be used to compile the given function instance. // Note: ir param can be nil for host functions. -func newCompiler(f *wasm.FunctionInstance, ir *wazeroir.CompilationResult) (compiler, error) { - return newAmd64Compiler(f, ir) +func newCompiler(ir *wazeroir.CompilationResult) (compiler, error) { + return newAmd64Compiler(ir) } diff --git a/internal/wasm/jit/arch_arm64.go b/internal/wasm/jit/arch_arm64.go index 15eb1024..a2a18c48 100644 --- a/internal/wasm/jit/arch_arm64.go +++ b/internal/wasm/jit/arch_arm64.go @@ -3,7 +3,6 @@ package jit import ( "math" - "github.com/tetratelabs/wazero/internal/wasm" "github.com/tetratelabs/wazero/internal/wazeroir" ) @@ -49,6 +48,6 @@ func init() { // newCompiler returns a new compiler interface which can be used to compile the given function instance. // Note: ir param can be nil for host functions. -func newCompiler(f *wasm.FunctionInstance, ir *wazeroir.CompilationResult) (compiler, error) { - return newArm64Compiler(f, ir) +func newCompiler(ir *wazeroir.CompilationResult) (compiler, error) { + return newArm64Compiler(ir) } diff --git a/internal/wasm/jit/arch_arm64.s b/internal/wasm/jit/arch_arm64.s index f10ccdc8..b453fcb0 100644 --- a/internal/wasm/jit/arch_arm64.s +++ b/internal/wasm/jit/arch_arm64.s @@ -7,8 +7,8 @@ TEXT ·jitcall(SB),NOSPLIT|NOFRAME,$0-24 MOVD ce+8(FP),R0 // In arm64, return address is stored in R30 after jumping into the code. // We save the return address value into archContext.jitReturnAddress in Engine. - // Note that the const 120 drifts after editting Engine or archContext struct. See TestArchContextOffsetInEngine. - MOVD R30,120(R0) + // Note that the const 128 drifts after editting Engine or archContext struct. See TestArchContextOffsetInEngine. + MOVD R30,128(R0) // Load the address of *wasm.ModuleInstance into arm64CallingConventionModuleInstanceAddressRegister. MOVD moduleInstanceAddress+16(FP),R29 // Load the address of native code. diff --git a/internal/wasm/jit/arch_other.go b/internal/wasm/jit/arch_other.go index ac63cd38..35662c3e 100644 --- a/internal/wasm/jit/arch_other.go +++ b/internal/wasm/jit/arch_other.go @@ -6,7 +6,6 @@ import ( "fmt" "runtime" - "github.com/tetratelabs/wazero/internal/wasm" "github.com/tetratelabs/wazero/internal/wazeroir" ) @@ -14,6 +13,6 @@ import ( type archContext struct{} // newCompiler returns an unsupported error. -func newCompiler(f *wasm.FunctionInstance, ir *wazeroir.CompilationResult) (compiler, error) { +func newCompiler(ir *wazeroir.CompilationResult) (compiler, error) { return nil, fmt.Errorf("unsupported GOARCH %s", runtime.GOARCH) } diff --git a/internal/wasm/jit/compiler.go b/internal/wasm/jit/compiler.go index 2ce6b972..275584fe 100644 --- a/internal/wasm/jit/compiler.go +++ b/internal/wasm/jit/compiler.go @@ -14,8 +14,8 @@ type compiler interface { compilePreamble() error // compile generates the byte slice of native code. // stackPointerCeil is the max stack pointer that the target function would reach. - // staticData is compiledFunctionStaticData for the resulting native code. - compile() (code []byte, staticData compiledFunctionStaticData, stackPointerCeil uint64, err error) + // staticData is codeStaticData for the resulting native code. + compile() (code []byte, staticData codeStaticData, stackPointerCeil uint64, err error) // compileHostFunction emits the trampoline code from which native code can jump into the host function. // TODO: maybe we wouldn't need to have trampoline for host functions. compileHostFunction() error diff --git a/internal/wasm/jit/engine.go b/internal/wasm/jit/engine.go index bb37bc50..b0aace76 100644 --- a/internal/wasm/jit/engine.go +++ b/internal/wasm/jit/engine.go @@ -18,10 +18,9 @@ import ( type ( // engine is an JIT implementation of wasm.Engine engine struct { - enabledFeatures wasm.Features - compiledFunctions map[*wasm.FunctionInstance]*compiledFunction // guarded by mutex. - cachedCompiledFunctionsPerModule map[*wasm.Module][]*compiledFunction // guarded by mutex. - mux sync.RWMutex + enabledFeatures wasm.Features + codes map[*wasm.Module][]*code // guarded by mutex. + mux sync.RWMutex // setFinalizer defaults to runtime.SetFinalizer, but overridable for tests. setFinalizer func(obj interface{}, finalizer interface{}) } @@ -31,14 +30,12 @@ type ( // name is the name the module was instantiated with used for error handling. name string - // compiledFunctions are the compiled functions in a module instances. + // functions are the functions in a module instances. // The index is module instance-scoped. We intentionally avoid using map // as the underlying memory region is accessed by assembly directly by using - // compiledFunctionsElement0Address. - compiledFunctions []*compiledFunction + // codesElement0Address. + functions []*function - // parentEngine holds *engine from which this module engine is created from. - parentEngine *engine importedFunctionCount uint32 } @@ -110,8 +107,11 @@ type ( // tableSliceLen is the length of the memory buffer, i.e. len(ModuleInstance.Tables[0].Table). tableSliceLen uint64 - // compiledFunctionsElement0Address is &moduleContext.engine.compiledFunctions[0] as uintptr. - compiledFunctionsElement0Address uintptr + // codesElement0Address is &moduleContext.engine.codes[0] as uintptr. + codesElement0Address uintptr + + // typeIDsElement0Address holds the &ModuleInstance.typeIDs[0] as uintptr. + typeIDsElement0Address uintptr } // valueStackContext stores the data to access engine.valueStack. @@ -159,14 +159,13 @@ type ( // Set when making function call from this function frame. returnStackBasePointer uint64 // Set when making function call to this function frame. - compiledFunction *compiledFunction + function *function // _ is a necessary padding to make the size of callFrame struct a power of 2. _ [8]byte } - compiledFunction struct { - // The following fields are accessed by JITed code. - + // Function corresponds to function instance in Wasm, and is created from `code`. + function struct { // codeInitialAddress is the pre-calculated pointer pointing to the initial byte of .codeSegment slice. // That mean codeInitialAddress always equals uintptr(unsafe.Pointer(&.codeSegment[0])) // and we cache the value (uintptr(unsafe.Pointer(&.codeSegment[0]))) to this field, @@ -178,40 +177,49 @@ type ( source *wasm.FunctionInstance // moduleInstanceAddress holds the address of source.ModuleInstance. moduleInstanceAddress uintptr + // parent holds code from which this is crated. + parent *code + } - // Followings are not accessed by JITed code. - + // code corresponds to a function in a module (not insantaited one). This holds the machine code + // compiled by Wazero's JIT compiler. + code struct { // codeSegment is holding the compiled native code as a byte slice. codeSegment []byte - // See the doc for compiledFunctionStaticData type. - staticData compiledFunctionStaticData + // See the doc for codeStaticData type. + staticData codeStaticData + // stackPointerCeil is the max of the stack pointer this function can reach. Lazily applied via maybeGrowValueStack. + stackPointerCeil uint64 + + // indexInModule is the index of this function in the module. For logging purpose. + indexInModule wasm.Index + // sourceModule is the module from which this function is compiled. For logging purpose. + sourceModule *wasm.Module } // staticData holds the read-only data (i.e. out side of codeSegment which is marked as executable) per function. // This is used to store jump tables for br_table instructions. // The primary index is the logical separation of multiple data, for example data[0] and data[1] // correspond to different jump tables for different br_table instructions. - compiledFunctionStaticData = [][]byte + codeStaticData = [][]byte ) -func (c *compiledFunction) clone(newSourceInstance *wasm.FunctionInstance) *compiledFunction { - // Note: we don't need to set finalizer to munmap the code segment since it is - // already a target of munmap by the finalizer set on the original compiledFunction `c`. - return &compiledFunction{ - codeInitialAddress: c.codeInitialAddress, +// createFunction creates a new function which uses the native code compiled. +func (c *code) createFunction(f *wasm.FunctionInstance) *function { + return &function{ + codeInitialAddress: uintptr(unsafe.Pointer(&c.codeSegment[0])), stackPointerCeil: c.stackPointerCeil, - source: newSourceInstance, - moduleInstanceAddress: uintptr(unsafe.Pointer(newSourceInstance.Module)), - codeSegment: c.codeSegment, - staticData: c.staticData, + moduleInstanceAddress: uintptr(unsafe.Pointer(f.Module)), + source: f, + parent: c, } } // Native code reads/writes Go's structs with the following constants. // See TestVerifyOffsetValue for how to derive these values. const ( - // Offsets for moduleEngine.compiledFunctions - moduleEngineCompiledFunctionsOffset = 16 + // Offsets for moduleEngine.functions + moduleEngineFunctionsOffset = 16 // Offsets for callEngine globalContext. callEngineGlobalContextValueStackElement0AddressOffset = 0 @@ -221,40 +229,42 @@ const ( callEngineGlobalContextCallFrameStackPointerOffset = 32 // Offsets for callEngine moduleContext. - callEngineModuleContextModuleInstanceAddressOffset = 40 - callEngineModuleContextGlobalElement0AddressOffset = 48 - callEngineModuleContextMemoryElement0AddressOffset = 56 - callEngineModuleContextMemorySliceLenOffset = 64 - callEngineModuleContextTableElement0AddressOffset = 72 - callEngineModuleContextTableSliceLenOffset = 80 - callEngineModuleContextCompiledFunctionsElement0AddressOffset = 88 + callEngineModuleContextModuleInstanceAddressOffset = 40 + callEngineModuleContextGlobalElement0AddressOffset = 48 + callEngineModuleContextMemoryElement0AddressOffset = 56 + callEngineModuleContextMemorySliceLenOffset = 64 + callEngineModuleContextTableElement0AddressOffset = 72 + callEngineModuleContextTableSliceLenOffset = 80 + callEngineModuleContextcodesElement0AddressOffset = 88 + callEngineModuleContextTypeIDsElement0AddressOffset = 96 // Offsets for callEngine valueStackContext. - callEngineValueStackContextStackPointerOffset = 96 - callEngineValueStackContextStackBasePointerOffset = 104 + callEngineValueStackContextStackPointerOffset = 104 + callEngineValueStackContextStackBasePointerOffset = 112 // Offsets for callEngine exitContext. - callEngineExitContextJITCallStatusCodeOffset = 112 - callEngineExitContextBuiltinFunctionCallAddressOffset = 116 + callEngineExitContextJITCallStatusCodeOffset = 120 + callEngineExitContextBuiltinFunctionCallAddressOffset = 124 // Offsets for callFrame. callFrameDataSize = 32 callFrameDataSizeMostSignificantSetBit = 5 callFrameReturnAddressOffset = 0 callFrameReturnStackBasePointerOffset = 8 - callFrameCompiledFunctionOffset = 16 + callFrameFunctionOffset = 16 - // Offsets for compiledFunction. - compiledFunctionCodeInitialAddressOffset = 0 - compiledFunctionStackPointerCeilOffset = 8 - compiledFunctionSourceOffset = 16 - compiledFunctionModuleInstanceAddressOffset = 24 + // Offsets for function. + functionCodeInitialAddressOffset = 0 + functionStackPointerCeilOffset = 8 + functionSourceOffset = 16 + functionModuleInstanceAddressOffset = 24 // Offsets for wasm.ModuleInstance. moduleInstanceGlobalsOffset = 48 moduleInstanceMemoryOffset = 72 moduleInstanceTableOffset = 80 moduleInstanceEngineOffset = 120 + moduleInstanceTypeIDsOffset = 136 // Offsets for wasm.TableInstance. tableInstanceTableOffset = 0 @@ -334,6 +344,20 @@ func (s jitCallStatusCode) String() (ret string) { ret = "call_builtin_function" case jitCallStatusCodeUnreachable: ret = "unreachable" + case jitCallStatusCodeInvalidFloatToIntConversion: + ret = "invalid float to int conversion" + case jitCallStatusCodeMemoryOutOfBounds: + ret = "memory out of bounds" + case jitCallStatusCodeInvalidTableAccess: + ret = "invalid table access" + case jitCallStatusCodeTypeMismatchOnIndirectCall: + ret = "type mismatch on indirect call" + case jitCallStatusIntegerOverflow: + ret = "integer overflow" + case jitCallStatusIntegerDivisionByZero: + ret = "integer division by zero" + default: + panic("BUG") } return } @@ -342,12 +366,12 @@ func (s jitCallStatusCode) String() (ret string) { func (c *callFrame) String() string { return fmt.Sprintf( "[%s: return address=0x%x, return stack base pointer=%d]", - c.compiledFunction.source.DebugName, c.returnAddress, c.returnStackBasePointer, + c.function.source.DebugName, c.returnAddress, c.returnStackBasePointer, ) } -// releaseCompiledFunction is a runtime.SetFinalizer function that munmaps the compiledFunction.codeSegment. -func releaseCompiledFunction(compiledFn *compiledFunction) { +// releaseCode is a runtime.SetFinalizer function that munmaps the code.codeSegment. +func releaseCode(compiledFn *code) { codeSegment := compiledFn.codeSegment if codeSegment == nil { return // already released @@ -359,13 +383,62 @@ func releaseCompiledFunction(compiledFn *compiledFunction) { // munmap failure cannot recover, and happen asynchronously on the finalizer thread. While finalizer // functions can return errors, they are ignored. To make these visible for troubleshooting, we panic // with additional context. module+funcidx should be enough, but if not, we can add more later. - panic(fmt.Errorf("jit: failed to munmap code segment for %s.function[%d]: %w", compiledFn.source.Module.Name, compiledFn.source.Index, err)) + panic(fmt.Errorf("jit: failed to munmap code segment for %s.function[%d]: %w", compiledFn.sourceModule.NameSection.ModuleName, + compiledFn.indexInModule, err)) } } -// ReleaseCompilationCache implements the same method as documented on wasm.Engine. -func (e *engine) ReleaseCompilationCache(module *wasm.Module) { - e.deleteCachedCompiledFunctions(module) +// DeleteCompiledModule implements the same method as documented on wasm.Engine. +func (e *engine) DeleteCompiledModule(module *wasm.Module) { + e.deleteCodes(module) +} + +// CompileModule implements the same method as documented on wasm.Engine. +func (e *engine) CompileModule(module *wasm.Module) error { + if _, ok := e.getCodes(module); ok { // cache hit! + return nil + } + + funcs := make([]*code, 0, len(module.FunctionSection)) + + if module.IsHostModule() { + for funcIndex := range module.HostFunctionSection { + compiled, err := compileHostFunction(module.TypeSection[module.FunctionSection[funcIndex]]) + if err != nil { + return fmt.Errorf("function[%d/%d] %w", funcIndex, len(module.FunctionSection)-1, err) + } + + // As this uses mmap, we need a finalizer in case moduleEngine.Close was never called. Regardless, we need a + // finalizer due to how moduleEngine.Close is implemented. + e.setFinalizer(compiled, releaseCode) + + compiled.indexInModule = wasm.Index(funcIndex) + compiled.sourceModule = module + funcs = append(funcs, compiled) + } + } else { + irs, err := wazeroir.CompileFunctions(e.enabledFeatures, module) + if err != nil { + return err + } + + for funcIndex := range module.FunctionSection { + compiled, err := compileWasmFunction(e.enabledFeatures, irs[funcIndex]) + if err != nil { + return fmt.Errorf("function[%d/%d] %w", funcIndex, len(module.FunctionSection)-1, err) + } + + // As this uses mmap, we need to munmap on the compiled machine code when it's GCed. + e.setFinalizer(compiled, releaseCode) + + compiled.indexInModule = wasm.Index(funcIndex) + compiled.sourceModule = module + + funcs = append(funcs, compiled) + } + } + e.addCodes(module, funcs) + return nil } // NewModuleEngine implements the same method as documented on wasm.Engine. @@ -373,123 +446,48 @@ func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunct imported := uint32(len(importedFunctions)) me := &moduleEngine{ name: name, - compiledFunctions: make([]*compiledFunction, 0, imported+uint32(len(moduleFunctions))), - parentEngine: e, + functions: make([]*function, 0, imported+uint32(len(moduleFunctions))), importedFunctionCount: imported, } - for idx, f := range importedFunctions { - cf, ok := e.getCompiledFunction(f) - if !ok { - return nil, fmt.Errorf("import[%d] func[%s]: uncompiled", idx, f.DebugName) - } - me.compiledFunctions = append(me.compiledFunctions, cf) + for _, f := range importedFunctions { + cf := f.Module.Engine.(*moduleEngine).functions[f.Index] + me.functions = append(me.functions, cf) } - if cached, ok := e.getCachedCompiledFunctions(module); ok { // cache hit. - for i, c := range cached[len(importedFunctions):] { - cloned := c.clone(moduleFunctions[i]) - me.compiledFunctions = append(me.compiledFunctions, cloned) - } - } else { // cache miss. - for i, f := range moduleFunctions { - var compiled *compiledFunction - var err error - if f.Kind == wasm.FunctionKindWasm { - compiled, err = compileWasmFunction(e.enabledFeatures, f) - } else { - compiled, err = compileHostFunction(f) - } - if err != nil { - me.Close() // safe because the reference to me was never leaked. - return nil, fmt.Errorf("function[%s(%d/%d)] %w", f.DebugName, i, len(moduleFunctions)-1, err) - } + codes, ok := e.getCodes(module) + if !ok { + return nil, fmt.Errorf("source module for %s must be compiled before instantiation", name) + } - // As this uses mmap, we need a finalizer in case moduleEngine.Close was never called. Regardless, we need a - // finalizer due to how moduleEngine.Close is implemented. - e.setFinalizer(compiled, releaseCompiledFunction) - - me.compiledFunctions = append(me.compiledFunctions, compiled) - - // Add the compiled function to the store-wide engine as well so that - // the future importing module can refer the function instance. - e.addCompiledFunction(f, compiled) - } - - e.addCachedCompiledFunctions(module, me.compiledFunctions) + for i, c := range codes { + f := moduleFunctions[i] + function := c.createFunction(f) + me.functions = append(me.functions, function) } for elemIdx, funcidx := range tableInit { // Initialize any elements with compiled functions - table.Table[elemIdx] = me.compiledFunctions[funcidx] + table.Table[elemIdx] = me.functions[funcidx] } return me, nil } -// Close is guarded by the caller with CAS, which means it happens only once. However, there is a race-condition -// inside the critical section: functions are removed from the parent engine, but there's no guard to prevent this -// moduleInstance from making new calls. This means at inside the critical section there could be in-flight calls, -// and even after it new calls can be made, given a reference to this moduleEngine. -// -// To ensure neither in-flight, nor new calls segfault due to missing code segment, memory isn't unmapped here. So, this -// relies on the fact that NewModuleEngine already added a finalizer for each compiledFunction, -// -// Note that the finalizer is a queue of work to be done at some point (perhaps never). In worst case, the finalizer -// doesn't run and functions in already closed modules retain memory until exhaustion. -// -// Potential future design (possibly faulty, so expect impl to be more complete or better): -// * Change this to implement io.Closer and document this is blocking -// * This implies adding docs can suggest this is run in a goroutine -// * io.Closer allows an error return we can use in case an unrecoverable error happens -// * Continue to guard with CAS so that close is only executed once -// * Once in the critical section, write a status bit to a fixed memory location. -// * End new calls with a Closed Error if this is read. -// * This guard allows Close to eventually complete. -// * Block exiting the critical section until all in-flight calls complete. -// * Knowing which in-flight calls from other modules, that can use this module may be tricky -// * Pure wasm functions can be left to complete. -// * Host functions are the only unknowns (ex can do I/O) so they may need to be tracked. -func (me *moduleEngine) Close() { - // Release all the function instances declared in this module. - for _, cf := range me.compiledFunctions[me.importedFunctionCount:] { - // NOTE: we still rely on the finalizer of cf until the notes on this function are addressed. - me.parentEngine.deleteCompiledFunction(cf.source) - } -} - -func (e *engine) deleteCompiledFunction(f *wasm.FunctionInstance) { +func (e *engine) deleteCodes(module *wasm.Module) { e.mux.Lock() defer e.mux.Unlock() - delete(e.compiledFunctions, f) + delete(e.codes, module) } -func (e *engine) addCompiledFunction(f *wasm.FunctionInstance, cf *compiledFunction) { +func (e *engine) addCodes(module *wasm.Module, fs []*code) { e.mux.Lock() defer e.mux.Unlock() - e.compiledFunctions[f] = cf + e.codes[module] = fs } -func (e *engine) getCompiledFunction(f *wasm.FunctionInstance) (cf *compiledFunction, ok bool) { +func (e *engine) getCodes(module *wasm.Module) (fs []*code, ok bool) { e.mux.RLock() defer e.mux.RUnlock() - cf, ok = e.compiledFunctions[f] - return -} -func (e *engine) deleteCachedCompiledFunctions(module *wasm.Module) { - e.mux.Lock() - defer e.mux.Unlock() - delete(e.cachedCompiledFunctionsPerModule, module) -} - -func (e *engine) addCachedCompiledFunctions(module *wasm.Module, fs []*compiledFunction) { - e.mux.Lock() - defer e.mux.Unlock() - e.cachedCompiledFunctionsPerModule[module] = fs -} - -func (e *engine) getCachedCompiledFunctions(module *wasm.Module) (fs []*compiledFunction, ok bool) { - e.mux.RLock() - defer e.mux.RUnlock() - fs, ok = e.cachedCompiledFunctionsPerModule[module] + fs, ok = e.codes[module] return } @@ -501,11 +499,11 @@ func (me *moduleEngine) Name() string { // Call implements the same method as documented on wasm.ModuleEngine. func (me *moduleEngine) Call(m *wasm.ModuleContext, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) { // Note: The input parameters are pre-validated, so a compiled function is only absent on close. Updates to - // compiledFunctions on close aren't locked, neither is this read. - compiled := me.compiledFunctions[f.Index] + // codes on close aren't locked, neither is this read. + compiled := me.functions[f.Index] if compiled == nil { // Lazy check the cause as it could be because the module was already closed. if err = m.FailIfClosed(); err == nil { - panic(fmt.Errorf("BUG: %s.compiledFunctions[%d] was nil before close", me.name, f.Index)) + panic(fmt.Errorf("BUG: %s.func[%d] was nil before close", me.name, f.Index)) } return } @@ -537,7 +535,7 @@ func (me *moduleEngine) Call(m *wasm.ModuleContext, f *wasm.FunctionInstance, pa builder.AddFrame(fn.DebugName, fn.ParamTypes(), fn.ResultTypes()) } for i := uint64(0); i < ce.globalContext.callFrameStackPointer; i++ { - fn := ce.callFrameStack[ce.globalContext.callFrameStackPointer-1-i].compiledFunction.source + fn := ce.callFrameStack[ce.globalContext.callFrameStackPointer-1-i].function.source builder.AddFrame(fn.DebugName, fn.ParamTypes(), fn.ResultTypes()) } err = builder.FromRecovered(v) @@ -569,10 +567,9 @@ func NewEngine(enabledFeatures wasm.Features) wasm.Engine { func newEngine(enabledFeatures wasm.Features) *engine { return &engine{ - enabledFeatures: enabledFeatures, - compiledFunctions: map[*wasm.FunctionInstance]*compiledFunction{}, - cachedCompiledFunctionsPerModule: map[*wasm.Module][]*compiledFunction{}, - setFinalizer: runtime.SetFinalizer, + enabledFeatures: enabledFeatures, + codes: map[*wasm.Module][]*code{}, + setFinalizer: runtime.SetFinalizer, } } @@ -721,9 +718,9 @@ func (ce *callEngine) execHostFunction(fk wasm.FunctionKind, f *reflect.Value, c } } -func (ce *callEngine) execWasmFunction(ctx *wasm.ModuleContext, f *compiledFunction) { +func (ce *callEngine) execWasmFunction(ctx *wasm.ModuleContext, f *function) { // Push the initial callframe. - ce.callFrameStack[0] = callFrame{returnAddress: f.codeInitialAddress, compiledFunction: f} + ce.callFrameStack[0] = callFrame{returnAddress: f.codeInitialAddress, function: f} ce.globalContext.callFrameStackPointer++ jitentry: @@ -742,23 +739,23 @@ jitentry: case jitCallStatusCodeReturned: // Meaning that all the function frames above the previous call frame stack pointer are executed. case jitCallStatusCodeCallHostFunction: - calleeHostFunction := ce.callFrameTop().compiledFunction.source + calleeHostFunction := ce.callFrameTop().function.source // Not "callFrameTop" but take the below of peek with "callFrameAt(1)" as the top frame is for host function, // but when making host function calls, we need to pass the memory instance of host function caller. - callerCompiledFunction := ce.callFrameAt(1).compiledFunction + callercode := ce.callFrameAt(1).function // A host function is invoked with the calling frame's memory, which may be different if in another module. ce.execHostFunction(calleeHostFunction.Kind, calleeHostFunction.GoFunc, - ctx.WithMemory(callerCompiledFunction.source.Module.Memory), + ctx.WithMemory(callercode.source.Module.Memory), ) goto jitentry case jitCallStatusCodeCallBuiltInFunction: switch ce.exitContext.builtinFunctionCallIndex { case builtinFunctionIndexMemoryGrow: - callerCompiledFunction := ce.callFrameTop().compiledFunction - ce.builtinFunctionMemoryGrow(callerCompiledFunction.source.Module.Memory) + callercode := ce.callFrameTop().function + ce.builtinFunctionMemoryGrow(callercode.source.Module.Memory) case builtinFunctionIndexGrowValueStack: - callerCompiledFunction := ce.callFrameTop().compiledFunction - ce.builtinFunctionGrowValueStack(callerCompiledFunction.stackPointerCeil) + callercode := ce.callFrameTop().function + ce.builtinFunctionGrowValueStack(callercode.stackPointerCeil) case builtinFunctionIndexGrowCallFrameStack: ce.builtinFunctionGrowCallFrameStack() } @@ -817,8 +814,8 @@ func (ce *callEngine) builtinFunctionMemoryGrow(mem *wasm.MemoryInstance) { ce.moduleContext.memoryElement0Address = bufSliceHeader.Data } -func compileHostFunction(f *wasm.FunctionInstance) (*compiledFunction, error) { - compiler, err := newCompiler(f, nil) +func compileHostFunction(sig *wasm.FunctionType) (*code, error) { + compiler, err := newCompiler(&wazeroir.CompilationResult{Signature: sig}) if err != nil { return nil, err } @@ -827,26 +824,16 @@ func compileHostFunction(f *wasm.FunctionInstance) (*compiledFunction, error) { return nil, err } - code, _, _, err := compiler.compile() + c, _, _, err := compiler.compile() if err != nil { return nil, err } - return &compiledFunction{ - source: f, - codeSegment: code, - codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), - moduleInstanceAddress: uintptr(unsafe.Pointer(f.Module)), - }, nil + return &code{codeSegment: c}, nil } -func compileWasmFunction(enabledFeatures wasm.Features, f *wasm.FunctionInstance) (*compiledFunction, error) { - ir, err := wazeroir.Compile(enabledFeatures, f) - if err != nil { - return nil, fmt.Errorf("failed to lower to wazeroir: %w", err) - } - - compiler, err := newCompiler(f, ir) +func compileWasmFunction(enabledFeatures wasm.Features, ir *wazeroir.CompilationResult) (*code, error) { + compiler, err := newCompiler(ir) if err != nil { return nil, fmt.Errorf("failed to initialize assembly builder: %w", err) } @@ -1024,17 +1011,10 @@ func compileWasmFunction(enabledFeatures wasm.Features, f *wasm.FunctionInstance } } - code, staticData, stackPointerCeil, err := compiler.compile() + c, staticData, stackPointerCeil, err := compiler.compile() if err != nil { return nil, fmt.Errorf("failed to compile: %w", err) } - return &compiledFunction{ - source: f, - codeSegment: code, - codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), - moduleInstanceAddress: uintptr(unsafe.Pointer(f.Module)), - stackPointerCeil: stackPointerCeil, - staticData: staticData, - }, nil + return &code{codeSegment: c, stackPointerCeil: stackPointerCeil, staticData: staticData}, nil } diff --git a/internal/wasm/jit/engine_test.go b/internal/wasm/jit/engine_test.go index 314417ea..b036247e 100644 --- a/internal/wasm/jit/engine_test.go +++ b/internal/wasm/jit/engine_test.go @@ -5,7 +5,6 @@ import ( "fmt" "math" "runtime" - "strconv" "testing" "unsafe" @@ -17,7 +16,7 @@ import ( // Ensures that the offset consts do not drift when we manipulate the target structs. func TestJIT_VerifyOffsetValue(t *testing.T) { var me moduleEngine - require.Equal(t, int(unsafe.Offsetof(me.compiledFunctions)), moduleEngineCompiledFunctionsOffset) + require.Equal(t, int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset) var ce callEngine // Offsets for callEngine.globalContext. @@ -34,7 +33,8 @@ func TestJIT_VerifyOffsetValue(t *testing.T) { require.Equal(t, int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset) require.Equal(t, int(unsafe.Offsetof(ce.tableElement0Address)), callEngineModuleContextTableElement0AddressOffset) require.Equal(t, int(unsafe.Offsetof(ce.tableSliceLen)), callEngineModuleContextTableSliceLenOffset) - require.Equal(t, int(unsafe.Offsetof(ce.compiledFunctionsElement0Address)), callEngineModuleContextCompiledFunctionsElement0AddressOffset) + require.Equal(t, int(unsafe.Offsetof(ce.codesElement0Address)), callEngineModuleContextcodesElement0AddressOffset) + require.Equal(t, int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset) // Offsets for callEngine.valueStackContext require.Equal(t, int(unsafe.Offsetof(ce.stackPointer)), callEngineValueStackContextStackPointerOffset) @@ -52,14 +52,14 @@ func TestJIT_VerifyOffsetValue(t *testing.T) { require.Equal(t, math.Ilogb(float64(callFrameDataSize)), callFrameDataSizeMostSignificantSetBit) require.Equal(t, int(unsafe.Offsetof(frame.returnAddress)), callFrameReturnAddressOffset) require.Equal(t, int(unsafe.Offsetof(frame.returnStackBasePointer)), callFrameReturnStackBasePointerOffset) - require.Equal(t, int(unsafe.Offsetof(frame.compiledFunction)), callFrameCompiledFunctionOffset) + require.Equal(t, int(unsafe.Offsetof(frame.function)), callFrameFunctionOffset) - // Offsets for compiledFunction. - var compiledFunc compiledFunction - require.Equal(t, int(unsafe.Offsetof(compiledFunc.codeInitialAddress)), compiledFunctionCodeInitialAddressOffset) - require.Equal(t, int(unsafe.Offsetof(compiledFunc.stackPointerCeil)), compiledFunctionStackPointerCeilOffset) - require.Equal(t, int(unsafe.Offsetof(compiledFunc.source)), compiledFunctionSourceOffset) - require.Equal(t, int(unsafe.Offsetof(compiledFunc.moduleInstanceAddress)), compiledFunctionModuleInstanceAddressOffset) + // Offsets for code. + var compiledFunc function + require.Equal(t, int(unsafe.Offsetof(compiledFunc.codeInitialAddress)), functionCodeInitialAddressOffset) + require.Equal(t, int(unsafe.Offsetof(compiledFunc.stackPointerCeil)), functionStackPointerCeilOffset) + require.Equal(t, int(unsafe.Offsetof(compiledFunc.source)), functionSourceOffset) + require.Equal(t, int(unsafe.Offsetof(compiledFunc.moduleInstanceAddress)), functionModuleInstanceAddressOffset) // Offsets for wasm.ModuleInstance. var moduleInstance wasm.ModuleInstance @@ -67,6 +67,7 @@ func TestJIT_VerifyOffsetValue(t *testing.T) { require.Equal(t, int(unsafe.Offsetof(moduleInstance.Memory)), moduleInstanceMemoryOffset) require.Equal(t, int(unsafe.Offsetof(moduleInstance.Table)), moduleInstanceTableOffset) require.Equal(t, int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset) + require.Equal(t, int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset) var functionInstance wasm.FunctionInstance require.Equal(t, int(unsafe.Offsetof(functionInstance.TypeID)), functionInstanceTypeIDOffset) @@ -111,7 +112,7 @@ func (e *engineTester) InitTable(me wasm.ModuleEngine, initTableLen uint32, init table := make([]interface{}, initTableLen) internal := me.(*moduleEngine) for idx, fnidx := range initTableIdxToFnIdx { - table[idx] = internal.compiledFunctions[fnidx] + table[idx] = internal.functions[fnidx] } return table } @@ -147,219 +148,81 @@ func requireSupportedOSArch(t *testing.T) { } } -func TestJIT_EngineCompile_Errors(t *testing.T) { - t.Run("invalid import", func(t *testing.T) { - e := et.NewEngine(wasm.Features20191205) - _, err := e.NewModuleEngine( - t.Name(), - &wasm.Module{}, - []*wasm.FunctionInstance{{Module: &wasm.ModuleInstance{Name: "uncompiled"}, DebugName: "uncompiled.fn"}}, - nil, // moduleFunctions - nil, // table - nil, // tableInit - ) - require.EqualError(t, err, "import[0] func[uncompiled.fn]: uncompiled") - }) - - t.Run("release on compilation error", func(t *testing.T) { - e := et.NewEngine(wasm.Features20191205).(*engine) - - importedFunctions := []*wasm.FunctionInstance{ - {DebugName: "1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "3", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "4", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - } - _, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) - require.NoError(t, err) - - require.Equal(t, len(importedFunctions), len(e.compiledFunctions)) - - moduleFunctions := []*wasm.FunctionInstance{ - {DebugName: "ok1", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "ok2", Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}}, - {DebugName: "invalid code", Type: &wasm.FunctionType{}, Body: []byte{ - wasm.OpcodeCall, // Call instruction without immediate for call target index is invalid and should fail to compile. - }, Module: &wasm.ModuleInstance{}}, - } - - _, err = e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, nil, nil) - require.EqualError(t, err, "function[invalid code(2/2)] failed to lower to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") - - // On the compilation failure, all the compiled functions including succeeded ones must be released. - require.Equal(t, len(importedFunctions), len(e.compiledFunctions)) - for _, f := range moduleFunctions { - require.Nil(t, e.compiledFunctions[f]) - } - }) -} - -type fakeFinalizer map[*compiledFunction]func(*compiledFunction) +type fakeFinalizer map[*code]func(*code) func (f fakeFinalizer) setFinalizer(obj interface{}, finalizer interface{}) { - cf := obj.(*compiledFunction) + cf := obj.(*code) if _, ok := f[cf]; ok { // easier than adding a field for testing.T panic(fmt.Sprintf("BUG: %v already had its finalizer set", cf)) } - f[cf] = finalizer.(func(*compiledFunction)) + f[cf] = finalizer.(func(*code)) } -func TestJIT_NewModuleEngine_CompiledFunctions(t *testing.T) { - e := et.NewEngine(wasm.Features20191205).(*engine) +func TestJIT_CompileModule(t *testing.T) { + t.Run("ok", func(t *testing.T) { + e := et.NewEngine(wasm.Features20191205).(*engine) + ff := fakeFinalizer{} + e.setFinalizer = ff.setFinalizer - importedFinalizer := fakeFinalizer{} - e.setFinalizer = importedFinalizer.setFinalizer + okModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []wasm.Index{0, 0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + }, + } - importedFunctions := []*wasm.FunctionInstance{ - newFunctionInstance(10), - newFunctionInstance(20), - } - modE, err := e.NewModuleEngine(t.Name(), &wasm.Module{}, nil, importedFunctions, nil, nil) - require.NoError(t, err) - defer modE.Close() - imported := modE.(*moduleEngine) + err := e.CompileModule(okModule) + require.NoError(t, err) - importingFinalizer := fakeFinalizer{} - e.setFinalizer = importingFinalizer.setFinalizer + // Compiling same module shouldn't be compiled again, but instead should be cached. + err = e.CompileModule(okModule) + require.NoError(t, err) - moduleFunctions := []*wasm.FunctionInstance{ - newFunctionInstance(100), - newFunctionInstance(200), - newFunctionInstance(300), - } + compiled, ok := e.codes[okModule] + require.True(t, ok) + require.Equal(t, len(okModule.FunctionSection), len(compiled)) - modE, err = e.NewModuleEngine(t.Name(), &wasm.Module{}, importedFunctions, moduleFunctions, nil, nil) - require.NoError(t, err) - defer modE.Close() - importing := modE.(*moduleEngine) + // Pretend the finalizer executed, by invoking them one-by-one. + for k, v := range ff { + v(k) + } + }) - // Ensure the importing module didn't try to finalize the imported functions. - require.Equal(t, len(importedFunctions), len(imported.compiledFunctions)) - for _, f := range importedFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - cf := e.compiledFunctions[f] - require.NotNil(t, importedFinalizer[cf], cf) - require.Nil(t, importingFinalizer[cf], cf) - } + t.Run("fail", func(t *testing.T) { + errModule := &wasm.Module{ + TypeSection: []*wasm.FunctionType{{}}, + FunctionSection: []wasm.Index{0, 0, 0}, + CodeSection: []*wasm.Code{ + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeEnd}}, + {Body: []byte{wasm.OpcodeCall}}, // Call instruction without immediate for call target index is invalid and should fail to compile. + }, + } - // The importing module's compiled functions include ones it compiled (module-defined) and imported ones). - require.Equal(t, len(importedFunctions)+len(moduleFunctions), len(importing.compiledFunctions)) + e := et.NewEngine(wasm.Features20191205).(*engine) + err := e.CompileModule(errModule) + require.EqualError(t, err, "failed to lower func[2/3] to wazeroir: handling instruction: apply stack failed for call: reading immediates: EOF") - // Ensure the importing module only tried to finalize its own functions. - for _, f := range moduleFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - cf := e.compiledFunctions[f] - require.Nil(t, importedFinalizer[cf], cf) - require.NotNil(t, importingFinalizer[cf], cf) - } - - // Pretend the finalizer executed, by invoking them one-by-one. - for k, v := range importingFinalizer { - v(k) - } - for k, v := range importedFinalizer { - v(k) - } - for _, f := range e.compiledFunctions { - require.Nil(t, f.codeSegment) // Set to nil if the correct finalizer was associated. - } + // On the compilation failure, the compiled functions must not be cached. + _, ok := e.codes[errModule] + require.False(t, ok) + }) } -// TestReleaseCompiledFunction_Panic tests that an unexpected panic has some identifying information in it. -func TestJIT_ReleaseCompiledFunction_Panic(t *testing.T) { +// TestReleasecode_Panic tests that an unexpected panic has some identifying information in it. +func TestJIT_Releasecode_Panic(t *testing.T) { captured := require.CapturePanic(func() { - releaseCompiledFunction(&compiledFunction{ - codeSegment: []byte{wasm.OpcodeEnd}, // never compiled means it was never mapped. - source: &wasm.FunctionInstance{Index: 2, Module: &wasm.ModuleInstance{Name: t.Name()}}, // for error string + releaseCode(&code{ + indexInModule: 2, + sourceModule: &wasm.Module{NameSection: &wasm.NameSection{ModuleName: t.Name()}}, + codeSegment: []byte{wasm.OpcodeEnd}, // never compiled means it was never mapped. }) }) - require.Contains(t, captured.Error(), fmt.Sprintf("jit: failed to munmap code segment for %[1]s.function[2]:", t.Name())) -} - -func TestJIT_ModuleEngine_Close(t *testing.T) { - newFunctionInstance := func(id int) *wasm.FunctionInstance { - return &wasm.FunctionInstance{ - DebugName: strconv.Itoa(id), Type: &wasm.FunctionType{}, Body: []byte{wasm.OpcodeEnd}, Module: &wasm.ModuleInstance{}} - } - - for _, tc := range []struct { - name string - importedFunctions, moduleFunctions []*wasm.FunctionInstance - }{ - { - name: "no imports", - moduleFunctions: []*wasm.FunctionInstance{newFunctionInstance(0), newFunctionInstance(1)}, - }, - { - name: "only imports", - importedFunctions: []*wasm.FunctionInstance{newFunctionInstance(0), newFunctionInstance(1)}, - }, - { - name: "mix", - importedFunctions: []*wasm.FunctionInstance{newFunctionInstance(0), newFunctionInstance(1)}, - moduleFunctions: []*wasm.FunctionInstance{newFunctionInstance(100), newFunctionInstance(200), newFunctionInstance(300)}, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - e := et.NewEngine(wasm.Features20191205).(*engine) - var imported *moduleEngine - if len(tc.importedFunctions) > 0 { - // Instantiate the imported module - modEngine, err := e.NewModuleEngine( - fmt.Sprintf("%s - imported functions", t.Name()), - &wasm.Module{}, - nil, // moduleFunctions - tc.importedFunctions, - nil, // table - nil, // tableInit - ) - require.NoError(t, err) - imported = modEngine.(*moduleEngine) - require.Equal(t, len(tc.importedFunctions), len(imported.compiledFunctions)) - } - - importing, err := e.NewModuleEngine( - fmt.Sprintf("%s - module-defined functions", t.Name()), - &wasm.Module{}, - tc.importedFunctions, - tc.moduleFunctions, - nil, // table - nil, // tableInit - ) - require.NoError(t, err) - require.Equal(t, len(tc.importedFunctions)+len(tc.moduleFunctions), len(importing.(*moduleEngine).compiledFunctions)) - - require.Equal(t, len(tc.importedFunctions)+len(tc.moduleFunctions), len(e.compiledFunctions)) - - for _, f := range tc.importedFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - } - for _, f := range tc.moduleFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - } - - importing.Close() - - // Closing the importing module shouldn't delete the imported functions from the engine. - require.Equal(t, len(tc.importedFunctions), len(e.compiledFunctions)) - for _, f := range tc.importedFunctions { - require.NotNil(t, e.compiledFunctions[f], f) - } - - // However, closing the importing module should delete its own functions from the engine. - for i, f := range tc.moduleFunctions { - require.Nil(t, e.compiledFunctions[f], i) - } - - if len(tc.importedFunctions) > 0 { - imported.Close() - } - - // When all modules are closed, the engine should be empty. - require.Equal(t, 0, len(e.compiledFunctions), "expected no compiledFunctions") - }) - } + require.Contains(t, captured.Error(), fmt.Sprintf("jit: failed to munmap code segment for %[1]s.function[2]", t.Name())) } // Ensures that value stack and call-frame stack are allocated on heap which @@ -391,6 +254,9 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { }}, map[string]*wasm.Memory{}, map[string]*wasm.Global{}, enabledFeatures) require.NoError(t, err) + err = store.Engine.CompileModule(hm) + require.NoError(t, err) + _, err = store.Instantiate(context.Background(), hm, hostModuleName, nil) require.NoError(t, err) @@ -442,6 +308,9 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { }, } + err = store.Engine.CompileModule(m) + require.NoError(t, err) + mi, err := store.Instantiate(context.Background(), m, t.Name(), nil) require.NoError(t, err) @@ -457,114 +326,24 @@ func TestJIT_SliceAllocatedOnHeap(t *testing.T) { } // TODO: move most of this logic to enginetest.go so that there is less drift between interpreter and jit -func TestEngine_CachedCompiledFunctionsPerModule(t *testing.T) { +func TestEngine_Cachedcodes(t *testing.T) { e := newEngine(wasm.Features20191205) - exp := []*compiledFunction{ - {source: &wasm.FunctionInstance{DebugName: "1"}}, - {source: &wasm.FunctionInstance{DebugName: "2"}}, + exp := []*code{ + {codeSegment: []byte{0x0}}, + {codeSegment: []byte{0x0}}, } m := &wasm.Module{} - e.addCachedCompiledFunctions(m, exp) + e.addCodes(m, exp) - actual, ok := e.getCachedCompiledFunctions(m) + actual, ok := e.getCodes(m) require.True(t, ok) require.Equal(t, len(exp), len(actual)) for i := range actual { require.Equal(t, exp[i], actual[i]) } - e.deleteCachedCompiledFunctions(m) - _, ok = e.getCachedCompiledFunctions(m) + e.deleteCodes(m) + _, ok = e.getCodes(m) require.False(t, ok) } - -// TODO: move most of this logic to enginetest.go so that there is less drift between interpreter and jit -func TestEngine_NewModuleEngine_cache(t *testing.T) { - e := newEngine(wasm.Features20191205) - importedModuleSource := &wasm.Module{} - - // No cache. - importedME, err := e.NewModuleEngine("1", importedModuleSource, nil, []*wasm.FunctionInstance{ - newFunctionInstance(1), - newFunctionInstance(2), - }, nil, nil) - require.NoError(t, err) - - // Cached. - importedMEFromCache, err := e.NewModuleEngine("2", importedModuleSource, nil, []*wasm.FunctionInstance{ - newFunctionInstance(1), - newFunctionInstance(2), - }, nil, nil) - require.NoError(t, err) - - require.NotEqual(t, importedME, importedMEFromCache) - require.NotEqual(t, importedME.Name(), importedMEFromCache.Name()) - - // Check compiled functions. - ime, imeCache := importedME.(*moduleEngine), importedMEFromCache.(*moduleEngine) - require.Equal(t, len(ime.compiledFunctions), len(imeCache.compiledFunctions)) - - for i, fn := range ime.compiledFunctions { - // Compiled functions must be cloend. - fnCached := imeCache.compiledFunctions[i] - require.NotEqual(t, fn, fnCached) - require.NotEqual(t, fn.moduleInstanceAddress, fnCached.moduleInstanceAddress) - require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. - // But the code segment stays the same. - require.Equal(t, fn.codeSegment, fnCached.codeSegment) - } - - // Next is to veirfy the caching works for modules with imports. - importedFunc := ime.compiledFunctions[0].source - moduleSource := &wasm.Module{} - - // No cache. - modEng, err := e.NewModuleEngine("3", moduleSource, - []*wasm.FunctionInstance{importedFunc}, // Import one function. - []*wasm.FunctionInstance{ - newFunctionInstance(10), - newFunctionInstance(20), - }, nil, nil) - require.NoError(t, err) - - // Cached. - modEngCache, err := e.NewModuleEngine("4", moduleSource, - []*wasm.FunctionInstance{importedFunc}, // Import one function. - []*wasm.FunctionInstance{ - newFunctionInstance(10), - newFunctionInstance(20), - }, nil, nil) - require.NoError(t, err) - - require.NotEqual(t, modEng, modEngCache) - require.NotEqual(t, modEng.Name(), modEngCache.Name()) - - me, meCache := modEng.(*moduleEngine), modEngCache.(*moduleEngine) - require.Equal(t, len(me.compiledFunctions), len(meCache.compiledFunctions)) - - for i, fn := range me.compiledFunctions { - fnCached := meCache.compiledFunctions[i] - if i == 0 { - // This case the function is imported, so it must be the same for both module engines. - require.Equal(t, fn, fnCached) - require.Equal(t, importedFunc, fn.source) - } else { - // Compiled functions must be cloend. - require.NotEqual(t, fn, fnCached) - require.NotEqual(t, fn.moduleInstanceAddress, fnCached.moduleInstanceAddress) - require.NotEqual(t, unsafe.Pointer(fn.source), unsafe.Pointer(fnCached.source)) // unsafe.Pointer to compare the actual address. - // But the code segment stays the same. - require.Equal(t, fn.codeSegment, fnCached.codeSegment) - } - } -} - -func newFunctionInstance(id int) *wasm.FunctionInstance { - return &wasm.FunctionInstance{ - DebugName: strconv.Itoa(id), - Type: &wasm.FunctionType{}, - Body: []byte{wasm.OpcodeEnd}, - Module: &wasm.ModuleInstance{}, - } -} diff --git a/internal/wasm/jit/jit_controlflow_test.go b/internal/wasm/jit/jit_controlflow_test.go index 793c88bc..c6d3dc65 100644 --- a/internal/wasm/jit/jit_controlflow_test.go +++ b/internal/wasm/jit/jit_controlflow_test.go @@ -529,13 +529,15 @@ func TestCompiler_compileCallIndirect(t *testing.T) { t.Run("out of bounds", func(t *testing.T) { env := newJITEnvironment() env.setTable(make([]interface{}, 10)) - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: &wasm.FunctionType{}, + Types: []*wasm.FunctionType{{}}, + HasTable: true, + }) err := compiler.compilePreamble() require.NoError(t, err) targetOperation := &wazeroir.OperationCallIndirect{} - // Ensure that the module instance has the type information for targetOperation.TypeIndex. - env.module().Types = []*wasm.TypeInstance{{Type: &wasm.FunctionType{}}} // Place the offset value. err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: 10}) @@ -557,18 +559,21 @@ func TestCompiler_compileCallIndirect(t *testing.T) { t.Run("uninitialized", func(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: &wasm.FunctionType{}, + Types: []*wasm.FunctionType{{}}, + HasTable: true, + }) err := compiler.compilePreamble() require.NoError(t, err) targetOperation := &wazeroir.OperationCallIndirect{} targetOffset := &wazeroir.OperationConstI32{Value: uint32(0)} - // Ensure that the module instance has the type information for targetOperation.TypeIndex, - env.module().Types = []*wasm.TypeInstance{{Type: &wasm.FunctionType{}}} // and the typeID doesn't match the table[targetOffset]'s type ID. table := make([]interface{}, 10) env.setTable(table) + env.module().TypeIDs = make([]wasm.FunctionTypeID, 10) table[0] = nil // Place the offset value. @@ -591,19 +596,23 @@ func TestCompiler_compileCallIndirect(t *testing.T) { t.Run("type not match", func(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: &wasm.FunctionType{}, + Types: []*wasm.FunctionType{{}}, + HasTable: true, + }) err := compiler.compilePreamble() require.NoError(t, err) targetOperation := &wazeroir.OperationCallIndirect{} targetOffset := &wazeroir.OperationConstI32{Value: uint32(0)} - env.module().Types = []*wasm.TypeInstance{{Type: &wasm.FunctionType{}, TypeID: 1000}} + env.module().TypeIDs = []wasm.FunctionTypeID{1000} // Ensure that the module instance has the type information for targetOperation.TypeIndex, // and the typeID doesn't match the table[targetOffset]'s type ID. table := make([]interface{}, 10) env.setTable(table) - cf := &compiledFunction{source: &wasm.FunctionInstance{TypeID: 50}} + cf := &function{source: &wasm.FunctionInstance{TypeID: 50}} table[0] = cf // Place the offset value. @@ -622,7 +631,7 @@ func TestCompiler_compileCallIndirect(t *testing.T) { require.NoError(t, err) env.exec(code) - require.Equal(t, jitCallStatusCodeTypeMismatchOnIndirectCall, env.jitStatus()) + require.Equal(t, jitCallStatusCodeTypeMismatchOnIndirectCall.String(), env.jitStatus().String()) }) t.Run("ok", func(t *testing.T) { @@ -641,9 +650,9 @@ func TestCompiler_compileCallIndirect(t *testing.T) { // Ensure that the module instance has the type information for targetOperation.TypeIndex, // and the typeID matches the table[targetOffset]'s type ID. - env.module().Types = make([]*wasm.TypeInstance, 100) - env.module().Types[operation.TypeIndex] = &wasm.TypeInstance{Type: targetType, TypeID: targetTypeID} - env.module().Engine = &moduleEngine{compiledFunctions: []*compiledFunction{}} + env.module().TypeIDs = make([]wasm.FunctionTypeID, 100) + env.module().TypeIDs[operation.TypeIndex] = targetTypeID + env.module().Engine = &moduleEngine{functions: []*function{}} me := env.moduleEngine() for i := 0; i < len(table); i++ { @@ -663,19 +672,19 @@ func TestCompiler_compileCallIndirect(t *testing.T) { err = compiler.compileReturnFunction() require.NoError(t, err) - code, _, _, err := compiler.compile() + c, _, _, err := compiler.compile() require.NoError(t, err) - cf := &compiledFunction{ - codeSegment: code, - codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), + f := &function{ + parent: &code{codeSegment: c}, + codeInitialAddress: uintptr(unsafe.Pointer(&c[0])), moduleInstanceAddress: uintptr(unsafe.Pointer(env.moduleInstance)), source: &wasm.FunctionInstance{ TypeID: targetTypeID, }, } - me.compiledFunctions = append(me.compiledFunctions, cf) - table[i] = cf + me.functions = append(me.functions, f) + table[i] = f }) } @@ -686,7 +695,11 @@ func TestCompiler_compileCallIndirect(t *testing.T) { env.setCallFrameStackPointerLen(1) } - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: targetType, + Types: []*wasm.FunctionType{targetType}, + HasTable: true, + }) err := compiler.compilePreamble() require.NoError(t, err) @@ -726,7 +739,7 @@ func TestCompiler_compileCallIndirect(t *testing.T) { ) } - require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + require.Equal(t, jitCallStatusCodeReturned.String(), env.jitStatus().String()) require.Equal(t, uint64(1), env.stackPointer()) require.Equal(t, expectedReturnValue, uint32(env.ce.popValue())) }) @@ -763,7 +776,7 @@ func TestCompiler_compileCall(t *testing.T) { // the mutex lock and must release on the cleanup of each subtest. // TODO: delete after https://github.com/tetratelabs/wazero/issues/233 t.Run(fmt.Sprintf("compiling call target %d", i), func(t *testing.T) { - compiler := env.requireNewCompiler(t, newCompiler, targetFunctionType) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{Signature: targetFunctionType}) err := compiler.compilePreamble() require.NoError(t, err) @@ -776,12 +789,12 @@ func TestCompiler_compileCall(t *testing.T) { err = compiler.compileReturnFunction() require.NoError(t, err) - code, _, _, err := compiler.compile() + c, _, _, err := compiler.compile() require.NoError(t, err) index := wasm.Index(i) - me.compiledFunctions = append(me.compiledFunctions, &compiledFunction{ - codeSegment: code, - codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), + me.functions = append(me.functions, &function{ + parent: &code{codeSegment: c}, + codeInitialAddress: uintptr(unsafe.Pointer(&c[0])), moduleInstanceAddress: uintptr(unsafe.Pointer(env.moduleInstance)), }) env.module().Functions = append(env.module().Functions, @@ -790,7 +803,12 @@ func TestCompiler_compileCall(t *testing.T) { } // Now we start building the caller's code. - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: &wasm.FunctionType{}, + Functions: make([]uint32, numCalls), + Types: []*wasm.FunctionType{targetFunctionType}, + }) + err := compiler.compilePreamble() require.NoError(t, err) @@ -885,23 +903,24 @@ func TestCompiler_returnFunction(t *testing.T) { err = compiler.compileReturnFunction() require.NoError(t, err) - code, _, _, err := compiler.compile() + c, _, _, err := compiler.compile() require.NoError(t, err) // Compiles and adds to the engine. - compiledFunction := &compiledFunction{ - codeSegment: code, codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), + f := &function{ + parent: &code{codeSegment: c}, + codeInitialAddress: uintptr(unsafe.Pointer(&c[0])), moduleInstanceAddress: uintptr(unsafe.Pointer(env.moduleInstance)), } - moduleEngine.compiledFunctions = append(moduleEngine.compiledFunctions, compiledFunction) + moduleEngine.functions = append(moduleEngine.functions, f) // Pushes the frame whose return address equals the beginning of the function just compiled. frame := callFrame{ // Set the return address to the beginning of the function so that we can execute the constI32 above. - returnAddress: compiledFunction.codeInitialAddress, + returnAddress: f.codeInitialAddress, // Note: return stack base pointer is set to funcaddr*5 and this is where the const should be pushed. returnStackBasePointer: uint64(funcIndex) * 5, - compiledFunction: compiledFunction, + function: f, } ce.callFrameStack[ce.globalContext.callFrameStackPointer] = frame ce.globalContext.callFrameStackPointer++ @@ -912,7 +931,7 @@ func TestCompiler_returnFunction(t *testing.T) { require.Equal(t, uint64(callFrameNums), env.callFrameStackPointer()) // Run code from the top frame. - env.exec(ce.callFrameTop().compiledFunction.codeSegment) + env.exec(ce.callFrameTop().function.parent.codeSegment) // Check the exit status and the values on stack. require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) diff --git a/internal/wasm/jit/jit_global_test.go b/internal/wasm/jit/jit_global_test.go index 5fa7a7eb..4acb9dac 100644 --- a/internal/wasm/jit/jit_global_test.go +++ b/internal/wasm/jit/jit_global_test.go @@ -16,7 +16,10 @@ func TestCompiler_compileGlobalGet(t *testing.T) { tp := tp t.Run(wasm.ValueTypeName(tp), func(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: &wasm.FunctionType{}, + Globals: []*wasm.GlobalType{nil, {ValType: tp}}, + }) // Setup the global. (Start with nil as a dummy so that global index can be non-trivial.) globals := []*wasm.GlobalInstance{nil, {Val: globalValue, Type: &wasm.GlobalType{ValType: tp}}} @@ -66,7 +69,10 @@ func TestCompiler_compileGlobalSet(t *testing.T) { tp := tp t.Run(wasm.ValueTypeName(tp), func(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{ + Signature: &wasm.FunctionType{}, + Globals: []*wasm.GlobalType{nil, {ValType: tp}}, + }) // Setup the global. (Start with nil as a dummy so that global index can be non-trivial.) env.addGlobals(nil, &wasm.GlobalInstance{Val: 40, Type: &wasm.GlobalType{ValType: tp}}) diff --git a/internal/wasm/jit/jit_impl_amd64.go b/internal/wasm/jit/jit_impl_amd64.go index d872a1bb..5cfd7108 100644 --- a/internal/wasm/jit/jit_impl_amd64.go +++ b/internal/wasm/jit/jit_impl_amd64.go @@ -115,7 +115,6 @@ func (c *amd64Compiler) String() string { type amd64Compiler struct { assembler amd64.Assembler - f *wasm.FunctionInstance ir *wazeroir.CompilationResult // locationStack holds the state of wazeroir virtual stack. // and each item is either placed in register or the actual memory stack. @@ -128,12 +127,11 @@ type amd64Compiler struct { currentLabel string // onStackPointerCeilDeterminedCallBack hold a callback which are called when the max stack pointer is determined BEFORE generating native code. onStackPointerCeilDeterminedCallBack func(stackPointerCeil uint64) - staticData compiledFunctionStaticData + staticData codeStaticData } -func newAmd64Compiler(f *wasm.FunctionInstance, ir *wazeroir.CompilationResult) (compiler, error) { +func newAmd64Compiler(ir *wazeroir.CompilationResult) (compiler, error) { c := &amd64Compiler{ - f: f, assembler: amd64.NewAssemblerImpl(), locationStack: newValueLocationStack(), currentLabel: wazeroir.EntrypointLabel, @@ -195,7 +193,7 @@ func (c *amd64Compiler) compileHostFunction() error { } // compile implements compiler.compile for the amd64 architecture. -func (c *amd64Compiler) compile() (code []byte, staticData compiledFunctionStaticData, stackPointerCeil uint64, err error) { +func (c *amd64Compiler) compile() (code []byte, staticData codeStaticData, stackPointerCeil uint64, err error) { // c.stackPointerCeil tracks the stack pointer ceiling (max seen) value across all valueLocationStack(s) // used for all labels (via setLocationStack), excluding the current one. // Hence, we check here if the final block's max one exceeds the current c.stackPointerCeil. @@ -226,8 +224,8 @@ func (c *amd64Compiler) compile() (code []byte, staticData compiledFunctionStati } func (c *amd64Compiler) pushFunctionParams() { - if c.f != nil && c.f.Type != nil { - for _, t := range c.f.Type.Params { + if c.ir != nil { + for _, t := range c.ir.Signature.Params { loc := c.locationStack.pushValueLocationOnStack() switch t { case wasm.ValueTypeI32, wasm.ValueTypeI64: @@ -349,7 +347,7 @@ func (c *amd64Compiler) compileGlobalGet(o *wazeroir.OperationGlobalGet) error { // When an integer, reuse the pointer register for the value. Otherwise, allocate a float register for it. valueReg := intReg - wasmType := c.f.Module.Globals[o.Index].Type.ValType + wasmType := c.ir.Globals[o.Index].ValType switch wasmType { case wasm.ValueTypeF32, wasm.ValueTypeF64: valueReg, err = c.allocateRegister(generalPurposeRegisterTypeFloat) @@ -731,18 +729,19 @@ func (c *amd64Compiler) compileLabel(o *wazeroir.OperationLabel) (skipLabel bool // compileCall implements compiler.compileCall for the amd64 architecture. func (c *amd64Compiler) compileCall(o *wazeroir.OperationCall) error { - target := c.f.Module.Functions[o.FunctionIndex] - if err := c.compileCallFunctionImpl(o.FunctionIndex, asm.NilRegister, target.Type); err != nil { + target := c.ir.Functions[o.FunctionIndex] + targetType := c.ir.Types[target] + if err := c.compileCallFunctionImpl(o.FunctionIndex, asm.NilRegister, targetType); err != nil { return err } // We consumed the function parameters from the stack after call. - for i := 0; i < len(target.Type.Params); i++ { + for i := 0; i < len(targetType.Params); i++ { c.locationStack.pop() } // Also, the function results were pushed by the call. - for _, t := range target.Type.Results { + for _, t := range targetType.Results { loc := c.locationStack.pushValueLocationOnStack() switch t { case wasm.ValueTypeI32, wasm.ValueTypeI64: @@ -765,6 +764,13 @@ func (c *amd64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e if err != nil { return err } + c.locationStack.markRegisterUsed(tmp) + + tmp2, err := c.allocateRegister(generalPurposeRegisterTypeInt) + if err != nil { + return err + } + c.locationStack.markRegisterUsed(tmp2) // First, we need to check if the offset doesn't exceed the length of table. c.assembler.CompileMemoryToRegister(amd64.CMPQ, amd64ReservedRegisterForCallEngine, callEngineModuleContextTableSliceLenOffset, offset.register) @@ -784,10 +790,10 @@ func (c *amd64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e c.assembler.CompileMemoryToRegister(amd64.ADDQ, amd64ReservedRegisterForCallEngine, callEngineModuleContextTableElement0AddressOffset, offset.register) - // "offset = (*offset) + interfaceDataOffset (== table[offset] + interfaceDataOffset == *compiledFunction type)" + // "offset = (*offset) + interfaceDataOffset (== table[offset] + interfaceDataOffset == *code type)" c.assembler.CompileMemoryToRegister(amd64.MOVQ, offset.register, interfaceDataOffset, offset.register) - // At this point offset.register holds the address of *compiledFunction (as uintptr) at wasm.Table[offset]. + // At this point offset.register holds the address of *code (as uintptr) at wasm.Table[offset]. // // Check if the value of table[offset] equals zero, meaning that the target is uninitialized. c.assembler.CompileRegisterToConst(amd64.CMPQ, offset.register, 0) @@ -800,28 +806,32 @@ func (c *amd64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e c.assembler.SetJumpTargetOnNext(jumpIfInitialized) - // Next we need to check the type matches, i.e. table[offset].source.TypeID == targetFunctionType. + // Next we need to check the type matches, i.e. table[offset].source.TypeID == targetFunctionType's typeID. // // "tmp = table[offset].source ( == *FunctionInstance type)" - c.assembler.CompileMemoryToRegister(amd64.MOVQ, offset.register, compiledFunctionSourceOffset, tmp) + c.assembler.CompileMemoryToRegister(amd64.MOVQ, offset.register, functionSourceOffset, tmp) - ti := c.f.Module.Types[o.TypeIndex] - targetFunctionType := ti.Type - c.assembler.CompileMemoryToConst(amd64.CMPL, tmp, functionInstanceTypeIDOffset, int64(ti.TypeID)) + // "tmp2 = [&moduleInstance.TypeIDs[0] + index * 4] (== moduleInstance.TypeIDs[index])" + c.assembler.CompileMemoryToRegister(amd64.MOVQ, + amd64ReservedRegisterForCallEngine, callEngineModuleContextTypeIDsElement0AddressOffset, + tmp2) + c.assembler.CompileMemoryToRegister(amd64.MOVQ, tmp2, int64(o.TypeIndex)*4, tmp2) // Jump if the type matches. + c.assembler.CompileMemoryToRegister(amd64.CMPL, tmp, functionInstanceTypeIDOffset, tmp2) jumpIfTypeMatch := c.assembler.CompileJump(amd64.JEQ) // Otherwise, exit with type mismatch status. c.compileExitFromNativeCode(jitCallStatusCodeTypeMismatchOnIndirectCall) c.assembler.SetJumpTargetOnNext(jumpIfTypeMatch) + targetFunctionType := c.ir.Types[o.TypeIndex] if err = c.compileCallFunctionImpl(0, offset.register, targetFunctionType); err != nil { return nil } // The offset register should be marked as un-used as we consumed in the function call. - c.locationStack.markRegisterUnused(offset.register, tmp) + c.locationStack.markRegisterUnused(offset.register, tmp, tmp2) // We consumed the function parameters from the stack after call. for i := 0; i < len(targetFunctionType.Params); i++ { @@ -3394,25 +3404,25 @@ func (c *amd64Compiler) allocateRegister(t generalPurposeRegisterType) (reg asm. // // Note: this is the counter part for returnFunction, and see the comments there as well // to understand how the function calls are achieved. -func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, compiledFunctionAddressRegister asm.Register, functype *wasm.FunctionType) error { +func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, codeAddressRegister asm.Register, functype *wasm.FunctionType) error { // Release all the registers as our calling convention requires the caller-save. c.compileReleaseAllRegistersToStack() // First, we have to make sure that - if !isNilRegister(compiledFunctionAddressRegister) { - c.locationStack.markRegisterUsed(compiledFunctionAddressRegister) + if !isNilRegister(codeAddressRegister) { + c.locationStack.markRegisterUsed(codeAddressRegister) } // Obtain the temporary registers to be used in the followings. freeRegs, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 4) if !found { - // This in theory never happen as all the registers must be free except compiledFunctionAddressRegister. + // This in theory never happen as all the registers must be free except codeAddressRegister. return fmt.Errorf("could not find enough free registers") } c.locationStack.markRegisterUsed(freeRegs...) // Alias these free tmp registers for readability. - callFrameStackPointerRegister, tmpRegister, targetCompiledFunctionAddressRegister, + callFrameStackPointerRegister, tmpRegister, targetcodeAddressRegister, callFrameStackTopAddressRegister := freeRegs[0], freeRegs[1], freeRegs[2], freeRegs[3] // First, we read the current call frame stack pointer. @@ -3428,10 +3438,10 @@ func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, compiledFuncti jmpIfNotCallFrameStackNeedsGrow := c.assembler.CompileJump(amd64.JNE) // Otherwise, we have to make the builtin function call to grow the call frame stack. - if !isNilRegister(compiledFunctionAddressRegister) { + if !isNilRegister(codeAddressRegister) { // If we need to get the target funcaddr from register (call_indirect case), we must save it before growing the // call-frame stack, as the register is not saved across function calls. - savedOffsetLocation := c.pushValueLocationOnRegister(compiledFunctionAddressRegister) + savedOffsetLocation := c.pushValueLocationOnRegister(codeAddressRegister) c.compileReleaseRegisterToStack(savedOffsetLocation) } @@ -3441,13 +3451,13 @@ func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, compiledFuncti } // For call_indirect, we need to push the value back to the register. - if !isNilRegister(compiledFunctionAddressRegister) { + if !isNilRegister(codeAddressRegister) { // Since this is right after callGoFunction, we have to initialize the stack base pointer // to properly load the value on memory stack. c.compileReservedStackBasePointerInitialization() savedOffsetLocation := c.locationStack.pop() - savedOffsetLocation.setRegister(compiledFunctionAddressRegister) + savedOffsetLocation.setRegister(codeAddressRegister) c.compileLoadValueOnStackToRegister(savedOffsetLocation) } @@ -3483,7 +3493,7 @@ func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, compiledFuncti // where: // ra.* = callFrame.returnAddress // rb.* = callFrame.returnStackBasePointer - // rc.* = callFrame.compiledFunction + // rc.* = callFrame.code // _ = callFrame's padding (see comment on callFrame._ field.) // // In the following comment, we use the notations in the above example. @@ -3525,28 +3535,28 @@ func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, compiledFuncti // 3) Set rc.next to specify which function is executed on the current call frame (needs to make builtin function calls). { - if isNilRegister(compiledFunctionAddressRegister) { - // We must set the target function's address(pointer) of *compiledFunction into the next call-frame stack. + if isNilRegister(codeAddressRegister) { + // We must set the target function's address(pointer) of *code into the next call-frame stack. // In the example, this is equivalent to writing the value into "rc.next". // - // First, we read the address of the first item of callEngine.compiledFunctions slice (= &callEngine.compiledFunctions[0]) + // First, we read the address of the first item of callEngine.codes slice (= &callEngine.codes[0]) // into tmpRegister. - c.assembler.CompileMemoryToRegister(amd64.MOVQ, amd64ReservedRegisterForCallEngine, callEngineModuleContextCompiledFunctionsElement0AddressOffset, tmpRegister) + c.assembler.CompileMemoryToRegister(amd64.MOVQ, amd64ReservedRegisterForCallEngine, callEngineModuleContextcodesElement0AddressOffset, tmpRegister) - // Next, read the address of the target function (= &callEngine.compiledFunctions[offset]) + // Next, read the address of the target function (= &callEngine.codes[offset]) // into targetAddressRegister. c.assembler.CompileMemoryToRegister(amd64.MOVQ, // Note: FunctionIndex is limited up to 2^27 so this offset never exceeds 32-bit integer. - // *8 because the size of *compiledFunction equals 8 bytes. + // *8 because the size of *code equals 8 bytes. tmpRegister, int64(index)*8, - targetCompiledFunctionAddressRegister, + targetcodeAddressRegister, ) } else { - targetCompiledFunctionAddressRegister = compiledFunctionAddressRegister + targetcodeAddressRegister = codeAddressRegister } - // Finally, we are ready to place the address of the target function's *compiledFunction into the new call-frame. + // Finally, we are ready to place the address of the target function's *code into the new call-frame. // In the example, this is equivalent to set "rc.next". - c.assembler.CompileRegisterToMemory(amd64.MOVQ, targetCompiledFunctionAddressRegister, callFrameStackTopAddressRegister, callFrameCompiledFunctionOffset) + c.assembler.CompileRegisterToMemory(amd64.MOVQ, targetcodeAddressRegister, callFrameStackTopAddressRegister, callFrameFunctionOffset) } // 4) Set ra.1 so that we can return back to this function properly. @@ -3568,11 +3578,11 @@ func (c *amd64Compiler) compileCallFunctionImpl(index wasm.Index, compiledFuncti c.assembler.CompileNoneToMemory(amd64.INCQ, amd64ReservedRegisterForCallEngine, callEngineGlobalContextCallFrameStackPointerOffset) // Also, we have to put the target function's *wasm.ModuleInstance into amd64CallingConventionModuleInstanceAddressRegister. - c.assembler.CompileMemoryToRegister(amd64.MOVQ, targetCompiledFunctionAddressRegister, compiledFunctionModuleInstanceAddressOffset, + c.assembler.CompileMemoryToRegister(amd64.MOVQ, targetcodeAddressRegister, functionModuleInstanceAddressOffset, amd64CallingConventionModuleInstanceAddressRegister) // And jump into the initial address of the target function. - c.assembler.CompileJumpToMemory(amd64.JMP, targetCompiledFunctionAddressRegister, compiledFunctionCodeInitialAddressOffset) + c.assembler.CompileJumpToMemory(amd64.JMP, targetcodeAddressRegister, functionCodeInitialAddressOffset) // All the registers used are temporary so we mark them unused. c.locationStack.markRegisterUnused(freeRegs...) @@ -3667,7 +3677,7 @@ func (c *amd64Compiler) compileReturnFunction() error { // where: // ra.* = callFrame.returnAddress // rb.* = callFrame.returnStackBasePointer - // rc.* = callFrame.compiledFunction + // rc.* = callFrame.code // _ = callFrame's padding (see comment on callFrame._ field.) // // What we have to do in the following is that @@ -3687,11 +3697,11 @@ func (c *amd64Compiler) compileReturnFunction() error { // 2) Load rc.caller.moduleInstanceAddress into amd64CallingConventionModuleInstanceAddressRegister c.assembler.CompileMemoryToRegister(amd64.MOVQ, // "rc.caller" is BELOW the top address. See the above example for detail. - callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameCompiledFunctionOffset), + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameFunctionOffset), amd64CallingConventionModuleInstanceAddressRegister, ) c.assembler.CompileMemoryToRegister(amd64.MOVQ, - amd64CallingConventionModuleInstanceAddressRegister, compiledFunctionModuleInstanceAddressOffset, + amd64CallingConventionModuleInstanceAddressRegister, functionModuleInstanceAddressOffset, amd64CallingConventionModuleInstanceAddressRegister, ) @@ -3850,7 +3860,7 @@ func (c *amd64Compiler) compileReservedStackBasePointerInitialization() { } func (c *amd64Compiler) compileReservedMemoryPointerInitialization() { - if c.f.Module.Memory != nil { + if c.ir.HasMemory { c.assembler.CompileMemoryToRegister(amd64.MOVQ, amd64ReservedRegisterForCallEngine, callEngineModuleContextMemoryElement0AddressOffset, amd64ReservedRegisterForMemory, @@ -3925,14 +3935,15 @@ func (c *amd64Compiler) compileModuleContextInitialization() error { // * callEngine.moduleContext.tableSliceLen // * callEngine.moduleContext.memoryElement0Address // * callEngine.moduleContext.memorySliceLen - // * callEngine.moduleContext.compiledFunctionsElement0Address + // * callEngine.moduleContext.codesElement0Address + // * callEngine.moduleContext.typeIDsElement0Address // Update globalElement0Address. // // Note: if there's global.get or set instruction in the function, the existence of the globals // is ensured by function validation at module instantiation phase, and that's why it is ok to // skip the initialization if the module's globals slice is empty. - if len(c.f.Module.Globals) > 0 { + if len(c.ir.Globals) > 0 { // Since ModuleInstance.Globals is []*globalInstance, internally // the address of the first item in the underlying array lies exactly on the globals offset. // See https://go.dev/blog/slices-intro if unfamiliar. @@ -3946,7 +3957,7 @@ func (c *amd64Compiler) compileModuleContextInitialization() error { // Note: if there's table instruction in the function, the existence of the table // is ensured by function validation at module instantiation phase, and that's // why it is ok to skip the initialization if the module's table doesn't exist. - if c.f.Module.Table != nil { + if c.ir.HasTable { // First, we need to read the *wasm.Table. c.assembler.CompileMemoryToRegister(amd64.MOVQ, amd64CallingConventionModuleInstanceAddressRegister, moduleInstanceTableOffset, tmpRegister) @@ -3958,13 +3969,18 @@ func (c *amd64Compiler) compileModuleContextInitialization() error { c.assembler.CompileRegisterToMemory(amd64.MOVQ, tmpRegister2, amd64ReservedRegisterForCallEngine, callEngineModuleContextTableElement0AddressOffset) - // Finally, read the length of table and update tableSliceLen accordingly. + // Next, read the length of table and update tableSliceLen accordingly. c.assembler.CompileMemoryToRegister(amd64.MOVQ, tmpRegister, tableInstanceTableLenOffset, tmpRegister2) // And put the length into tableSliceLen. - c.assembler.CompileRegisterToMemory(amd64.MOVQ, tmpRegister2, amd64ReservedRegisterForCallEngine, callEngineModuleContextTableSliceLenOffset) + + // Finally, we put &ModuleInstance.TypeIDs[0] into moduleContext.typeIDsElement0Address. + c.assembler.CompileMemoryToRegister(amd64.MOVQ, + amd64CallingConventionModuleInstanceAddressRegister, moduleInstanceTypeIDsOffset, tmpRegister) + c.assembler.CompileRegisterToMemory(amd64.MOVQ, + tmpRegister, amd64ReservedRegisterForCallEngine, callEngineModuleContextTypeIDsElement0AddressOffset) } // Update memoryElement0Address and memorySliceLen. @@ -3972,7 +3988,7 @@ func (c *amd64Compiler) compileModuleContextInitialization() error { // Note: if there's memory instruction in the function, memory instance must be non-nil. // That is ensured by function validation at module instantiation phase, and that's // why it is ok to skip the initialization if the module's memory instance is nil. - if c.f.Module.Memory != nil { + if c.ir.HasMemory { c.assembler.CompileMemoryToRegister(amd64.MOVQ, amd64CallingConventionModuleInstanceAddressRegister, moduleInstanceMemoryOffset, tmpRegister) // Set length. @@ -3986,7 +4002,7 @@ func (c *amd64Compiler) compileModuleContextInitialization() error { amd64ReservedRegisterForCallEngine, callEngineModuleContextMemoryElement0AddressOffset) } - // Update moduleContext.compiledFunctionsElement0Address + // Update moduleContext.codesElement0Address { // "tmpRegister = [moduleInstanceAddressRegister + moduleInstanceEngineOffset + interfaceDataOffset] (== *moduleEngine)" // @@ -3998,11 +4014,11 @@ func (c *amd64Compiler) compileModuleContextInitialization() error { // * https://github.com/golang/go/blob/release-branch.go1.17/src/runtime/runtime2.go#L207-L210 c.assembler.CompileMemoryToRegister(amd64.MOVQ, amd64CallingConventionModuleInstanceAddressRegister, moduleInstanceEngineOffset+interfaceDataOffset, tmpRegister) - // "tmpRegister = [tmpRegister + moduleEngineCompiledFunctionsOffset] (== &moduleEngine.compiledFunctions[0])" - c.assembler.CompileMemoryToRegister(amd64.MOVQ, tmpRegister, moduleEngineCompiledFunctionsOffset, tmpRegister) + // "tmpRegister = [tmpRegister + moduleEnginecodesOffset] (== &moduleEngine.codes[0])" + c.assembler.CompileMemoryToRegister(amd64.MOVQ, tmpRegister, moduleEngineFunctionsOffset, tmpRegister) - // "callEngine.moduleContext.compiledFunctionsElement0Address = tmpRegister". - c.assembler.CompileRegisterToMemory(amd64.MOVQ, tmpRegister, amd64ReservedRegisterForCallEngine, callEngineModuleContextCompiledFunctionsElement0AddressOffset) + // "callEngine.moduleContext.codesElement0Address = tmpRegister". + c.assembler.CompileRegisterToMemory(amd64.MOVQ, tmpRegister, amd64ReservedRegisterForCallEngine, callEngineModuleContextcodesElement0AddressOffset) } c.locationStack.markRegisterUnused(regs...) diff --git a/internal/wasm/jit/jit_impl_arm64.go b/internal/wasm/jit/jit_impl_arm64.go index 6f0c93c8..c02288aa 100644 --- a/internal/wasm/jit/jit_impl_arm64.go +++ b/internal/wasm/jit/jit_impl_arm64.go @@ -22,7 +22,6 @@ import ( type arm64Compiler struct { assembler arm64.Assembler - f *wasm.FunctionInstance ir *wazeroir.CompilationResult // locationStack holds the state of wazeroir virtual stack. // and each item is either placed in register or the actual memory stack. @@ -33,14 +32,13 @@ type arm64Compiler struct { stackPointerCeil uint64 // onStackPointerCeilDeterminedCallBack hold a callback which are called when the ceil of stack pointer is determined before generating native code. onStackPointerCeilDeterminedCallBack func(stackPointerCeil uint64) - // compiledFunctionStaticData holds br_table offset tables. - // See compiledFunctionStaticData and arm64Compiler.compileBrTable. - staticData compiledFunctionStaticData + // codeStaticData holds br_table offset tables. + // See codeStaticData and arm64Compiler.compileBrTable. + staticData codeStaticData } -func newArm64Compiler(f *wasm.FunctionInstance, ir *wazeroir.CompilationResult) (compiler, error) { +func newArm64Compiler(ir *wazeroir.CompilationResult) (compiler, error) { return &arm64Compiler{ - f: f, assembler: arm64.NewAssemblerImpl(arm64ReservedRegisterForTemporary), locationStack: newValueLocationStack(), ir: ir, @@ -87,11 +85,11 @@ var ( const ( // arm64CallEngineArchContextJITCallReturnAddressOffset is the offset of archContext.jitCallReturnAddress in callEngine. - arm64CallEngineArchContextJITCallReturnAddressOffset = 120 + arm64CallEngineArchContextJITCallReturnAddressOffset = 128 // arm64CallEngineArchContextMinimum32BitSignedIntOffset is the offset of archContext.minimum32BitSignedIntAddress in callEngine. - arm64CallEngineArchContextMinimum32BitSignedIntOffset = 128 + arm64CallEngineArchContextMinimum32BitSignedIntOffset = 136 // arm64CallEngineArchContextMinimum64BitSignedIntOffset is the offset of archContext.minimum64BitSignedIntAddress in callEngine. - arm64CallEngineArchContextMinimum64BitSignedIntOffset = 136 + arm64CallEngineArchContextMinimum64BitSignedIntOffset = 144 ) func isZeroRegister(r asm.Register) bool { @@ -103,7 +101,7 @@ func (c *arm64Compiler) addStaticData(d []byte) { } // compile implements compiler.compile for the arm64 architecture. -func (c *arm64Compiler) compile() (code []byte, staticData compiledFunctionStaticData, stackPointerCeil uint64, err error) { +func (c *arm64Compiler) compile() (code []byte, staticData codeStaticData, stackPointerCeil uint64, err error) { // c.stackPointerCeil tracks the stack pointer ceiling (max seen) value across all valueLocationStack(s) // used for all labels (via setLocationStack), excluding the current one. // Hence, we check here if the final block's max one exceeds the current c.stackPointerCeil. @@ -178,10 +176,7 @@ func (c *arm64Compiler) String() (ret string) { return } // pushFunctionParams pushes any function parameters onto the stack, setting appropriate register types. func (c *arm64Compiler) pushFunctionParams() { - if c.f == nil || c.f.Type == nil { - return - } - for _, t := range c.f.Type.Params { + for _, t := range c.ir.Signature.Params { loc := c.locationStack.pushValueLocationOnStack() switch t { case wasm.ValueTypeI32, wasm.ValueTypeI64: @@ -335,7 +330,7 @@ func (c *arm64Compiler) compileReturnFunction() error { // where: // ra.* = callFrame.returnAddress // rb.* = callFrame.returnStackBasePointer - // rc.* = callFrame.compiledFunction + // rc.* = callFrame.code // _ = callFrame's padding (see comment on callFrame._ field.) // // What we have to do in the following is that @@ -355,10 +350,10 @@ func (c *arm64Compiler) compileReturnFunction() error { // 2) Load rc.caller.moduleInstanceAddress into arm64CallingConventionModuleInstanceAddressRegister. c.assembler.CompileMemoryToRegister(arm64.MOVD, // "rb.caller" is below the top address. - callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameCompiledFunctionOffset), + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameFunctionOffset), arm64CallingConventionModuleInstanceAddressRegister) c.assembler.CompileMemoryToRegister(arm64.MOVD, - arm64CallingConventionModuleInstanceAddressRegister, compiledFunctionModuleInstanceAddressOffset, + arm64CallingConventionModuleInstanceAddressRegister, functionModuleInstanceAddressOffset, arm64CallingConventionModuleInstanceAddressRegister) // 3) Branch into the address of "ra.caller". @@ -404,7 +399,7 @@ func (c *arm64Compiler) compileHostFunction() error { // First we must update the location stack to reflect the number of host function inputs. c.pushFunctionParams() - if err := c.compileCallGoFunction(jitCallStatusCodeCallHostFunction, c.f.Index); err != nil { + if err := c.compileCallGoFunction(jitCallStatusCodeCallHostFunction, 0); err != nil { return err } return c.compileReturnFunction() @@ -483,7 +478,7 @@ func (c *arm64Compiler) compileGlobalGet(o *wazeroir.OperationGlobalGet) error { } var intMov, floatMov asm.Instruction = arm64.NOP, arm64.NOP - switch c.f.Module.Globals[o.Index].Type.ValType { + switch c.ir.Globals[o.Index].ValType { case wasm.ValueTypeI32: intMov = arm64.MOVWU case wasm.ValueTypeI64: @@ -531,7 +526,7 @@ func (c *arm64Compiler) compileGlobalSet(o *wazeroir.OperationGlobalSet) error { } var mov asm.Instruction - switch c.f.Module.Globals[o.Index].Type.ValType { + switch c.ir.Globals[o.Index].ValType { case wasm.ValueTypeI32: mov = arm64.MOVWU case wasm.ValueTypeI64: @@ -852,12 +847,12 @@ func (c *arm64Compiler) compileBrTable(o *wazeroir.OperationBrTable) error { // compileCall implements compiler.compileCall for the arm64 architecture. func (c *arm64Compiler) compileCall(o *wazeroir.OperationCall) error { - tp := c.f.Module.Functions[o.FunctionIndex].Type + tp := c.ir.Types[c.ir.Functions[o.FunctionIndex]] return c.compileCallImpl(o.FunctionIndex, asm.NilRegister, tp) } // compileCallImpl implements compiler.compileCall and compiler.compileCallIndirect for the arm64 architecture. -func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddressRegister asm.Register, functype *wasm.FunctionType) error { +func (c *arm64Compiler) compileCallImpl(index wasm.Index, codeAddressRegister asm.Register, functype *wasm.FunctionType) error { // Release all the registers as our calling convention requires the caller-save. if err := c.compileReleaseAllRegistersToStack(); err != nil { return err @@ -870,7 +865,7 @@ func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddres c.markRegisterUsed(freeRegisters...) // Alias for readability. - callFrameStackPointerRegister, callFrameStackTopAddressRegister, compiledFunctionRegister, oldStackBasePointer, + callFrameStackPointerRegister, callFrameStackTopAddressRegister, codeRegister, oldStackBasePointer, tmp := freeRegisters[0], freeRegisters[1], freeRegisters[2], freeRegisters[3], freeRegisters[4] // First, we have to check if we need to grow the callFrame stack. @@ -891,10 +886,10 @@ func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddres // If these values equal, we need to grow the callFrame stack. // For call_indirect, we need to push the value back to the register. - if !isNilRegister(compiledFunctionAddressRegister) { + if !isNilRegister(codeAddressRegister) { // If we need to get the target funcaddr from register (call_indirect case), we must save it before growing the // call-frame stack, as the register is not saved across function calls. - savedOffsetLocation := c.pushValueLocationOnRegister(compiledFunctionAddressRegister) + savedOffsetLocation := c.pushValueLocationOnRegister(codeAddressRegister) c.compileReleaseRegisterToStack(savedOffsetLocation) } @@ -903,13 +898,13 @@ func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddres } // For call_indirect, we need to push the value back to the register. - if !isNilRegister(compiledFunctionAddressRegister) { + if !isNilRegister(codeAddressRegister) { // Since this is right after callGoFunction, we have to initialize the stack base pointer // to properly load the value on memory stack. c.compileReservedStackBasePointerRegisterInitialization() savedOffsetLocation := c.locationStack.pop() - savedOffsetLocation.setRegister(compiledFunctionAddressRegister) + savedOffsetLocation.setRegister(codeAddressRegister) c.compileLoadValueOnStackToRegister(savedOffsetLocation) } @@ -933,7 +928,7 @@ func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddres // where: // ra.* = callFrame.returnAddress // rb.* = callFrame.returnStackBasePointer - // rc.* = callFrame.compiledFunction + // rc.* = callFrame.code // _ = callFrame's padding (see comment on callFrame._ field.) // // In the following comment, we use the notations in the above example. @@ -967,27 +962,27 @@ func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddres // 3) Set rc.next to specify which function is executed on the current call frame. // - // First, we read the address of the first item of ce.compiledFunctions slice (= &ce.compiledFunctions[0]) + // First, we read the address of the first item of ce.codes slice (= &ce.codes[0]) // into tmp. c.assembler.CompileMemoryToRegister(arm64.MOVD, - arm64ReservedRegisterForCallEngine, callEngineModuleContextCompiledFunctionsElement0AddressOffset, + arm64ReservedRegisterForCallEngine, callEngineModuleContextcodesElement0AddressOffset, tmp) - // Next, read the index of the target function (= &ce.compiledFunctions[offset]) - // into compiledFunctionIndexRegister. - if isNilRegister(compiledFunctionAddressRegister) { + // Next, read the index of the target function (= &ce.codes[offset]) + // into codeIndexRegister. + if isNilRegister(codeAddressRegister) { c.assembler.CompileMemoryToRegister( arm64.MOVD, - tmp, int64(index)*8, // * 8 because the size of *compiledFunction equals 8 bytes. - compiledFunctionRegister) + tmp, int64(index)*8, // * 8 because the size of *code equals 8 bytes. + codeRegister) } else { - compiledFunctionRegister = compiledFunctionAddressRegister + codeRegister = codeAddressRegister } - // Finally, we are ready to write the address of the target function's *compiledFunction into the new call-frame. + // Finally, we are ready to write the address of the target function's *code into the new call-frame. c.assembler.CompileRegisterToMemory(arm64.MOVD, - compiledFunctionRegister, - callFrameStackTopAddressRegister, callFrameCompiledFunctionOffset) + codeRegister, + callFrameStackTopAddressRegister, callFrameFunctionOffset) // 4) Set ra.current so that we can return back to this function properly. // @@ -1009,15 +1004,15 @@ func (c *arm64Compiler) compileCallImpl(index wasm.Index, compiledFunctionAddres tmp, arm64ReservedRegisterForCallEngine, callEngineGlobalContextCallFrameStackPointerOffset) - // Also, we have to put the compiledFunction's moduleinstance address into arm64CallingConventionModuleInstanceAddressRegister. + // Also, we have to put the code's moduleinstance address into arm64CallingConventionModuleInstanceAddressRegister. c.assembler.CompileMemoryToRegister(arm64.MOVD, - compiledFunctionRegister, compiledFunctionModuleInstanceAddressOffset, + codeRegister, functionModuleInstanceAddressOffset, arm64CallingConventionModuleInstanceAddressRegister, ) // Then, br into the target function's initial address. c.assembler.CompileMemoryToRegister(arm64.MOVD, - compiledFunctionRegister, compiledFunctionCodeInitialAddressOffset, + codeRegister, functionCodeInitialAddressOffset, tmp) c.assembler.CompileJumpToMemory(arm64.B, tmp) @@ -1091,6 +1086,13 @@ func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e if err != nil { return err } + c.markRegisterUsed(tmp) + + tmp2, err := c.allocateRegister(generalPurposeRegisterTypeInt) + if err != nil { + return err + } + c.markRegisterUsed(tmp2) // First, we need to check if the offset doesn't exceed the length of table. // "tmp = len(table)" @@ -1125,7 +1127,7 @@ func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e offset.register, ) - // "offset = (*offset) + interfaceDataOffset (== table[offset] + interfaceDataOffset == *compiledFunction type)" + // "offset = (*offset) + interfaceDataOffset (== table[offset] + interfaceDataOffset == *code type)" c.assembler.CompileMemoryToRegister(arm64.MOVD, offset.register, interfaceDataOffset, offset.register) // Check if the value of table[offset] equals zero, meaning that the target element is uninitialized. @@ -1134,12 +1136,11 @@ func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e c.compileExitFromNativeCode(jitCallStatusCodeInvalidTableAccess) c.assembler.SetJumpTargetOnNext(brIfInitialized) - targetFunctionType := c.f.Module.Types[o.TypeIndex] // Next we check the type matches, i.e. table[offset].source.TypeID == targetFunctionType. // "tmp = table[offset].source ( == *FunctionInstance type)" c.assembler.CompileMemoryToRegister( arm64.MOVD, - offset.register, compiledFunctionSourceOffset, + offset.register, functionSourceOffset, tmp, ) // "tmp = [tmp + functionInstanceTypeIDOffset] (== table[offset].source.TypeID)" @@ -1147,22 +1148,26 @@ func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e arm64.MOVW, tmp, functionInstanceTypeIDOffset, tmp, ) - // "arm64ReservedRegisterForTemporary = targetFunctionType.TypeID" - c.assembler.CompileConstToRegister(arm64.MOVW, int64(targetFunctionType.TypeID), arm64ReservedRegisterForTemporary) + // "tmp2 = ModuleInstance.TypeIDs[index]" + c.assembler.CompileMemoryToRegister(arm64.MOVD, + arm64ReservedRegisterForCallEngine, callEngineModuleContextTypeIDsElement0AddressOffset, + tmp2) + c.assembler.CompileMemoryToRegister(arm64.MOVWU, tmp2, int64(o.TypeIndex)*4, tmp2) // Compare these two values, and if they equal, we are ready to make function call. - c.assembler.CompileTwoRegistersToNone(arm64.CMPW, tmp, arm64ReservedRegisterForTemporary) + c.assembler.CompileTwoRegistersToNone(arm64.CMPW, tmp, tmp2) brIfTypeMatched := c.assembler.CompileJump(arm64.BEQ) c.compileExitFromNativeCode(jitCallStatusCodeTypeMismatchOnIndirectCall) c.assembler.SetJumpTargetOnNext(brIfTypeMatched) - if err := c.compileCallImpl(0, offset.register, targetFunctionType.Type); err != nil { + targetFunctionType := c.ir.Types[o.TypeIndex] + if err := c.compileCallImpl(0, offset.register, targetFunctionType); err != nil { return err } // The offset register should be marked as un-used as we consumed in the function call. - c.markRegisterUnused(offset.register) + c.markRegisterUnused(offset.register, tmp, tmp2) return nil } @@ -3041,7 +3046,7 @@ func (c *arm64Compiler) compileReservedStackBasePointerRegisterInitialization() } func (c *arm64Compiler) compileReservedMemoryRegisterInitialization() { - if c.f.Module.Memory != nil { + if c.ir.HasMemory { // "arm64ReservedRegisterForMemory = ce.MemoryElement0Address" c.assembler.CompileMemoryToRegister( arm64.MOVD, @@ -3086,14 +3091,15 @@ func (c *arm64Compiler) compileModuleContextInitialization() error { // * callEngine.moduleContext.memorySliceLen // * callEngine.moduleContext.tableElement0Address // * callEngine.moduleContext.tableSliceLen - // * callEngine.moduleContext.compiledFunctionsElement0Address + // * callEngine.moduleContext.codesElement0Address + // * callEngine.moduleContext.typeIDsElement0Address // Update globalElement0Address. // // Note: if there's global.get or set instruction in the function, the existence of the globals // is ensured by function validation at module instantiation phase, and that's why it is ok to // skip the initialization if the module's globals slice is empty. - if len(c.f.Module.Globals) > 0 { + if len(c.ir.Globals) > 0 { // "tmpX = &moduleInstance.Globals[0]" c.assembler.CompileMemoryToRegister(arm64.MOVD, arm64CallingConventionModuleInstanceAddressRegister, moduleInstanceGlobalsOffset, @@ -3112,7 +3118,7 @@ func (c *arm64Compiler) compileModuleContextInitialization() error { // Note: if there's memory instruction in the function, memory instance must be non-nil. // That is ensured by function validation at module instantiation phase, and that's // why it is ok to skip the initialization if the module's memory instance is nil. - if c.f.Module.Memory != nil { + if c.ir.HasMemory { // "tmpX = moduleInstance.Memory" c.assembler.CompileMemoryToRegister( arm64.MOVD, @@ -3151,12 +3157,12 @@ func (c *arm64Compiler) compileModuleContextInitialization() error { ) } - // Update tableElement0Address and tableSliceLen. + // Update tableElement0Address, tableSliceLen and typeIDsElement0Address. // // Note: if there's table instruction in the function, the existence of the table // is ensured by function validation at module instantiation phase, and that's // why it is ok to skip the initialization if the module's table doesn't exist. - if c.f.Module.Table != nil { + if c.ir.HasTable { // "tmpX = &tables[0] (type of **wasm.Table)" c.assembler.CompileMemoryToRegister( arm64.MOVD, @@ -3191,9 +3197,15 @@ func (c *arm64Compiler) compileModuleContextInitialization() error { tmpY, arm64ReservedRegisterForCallEngine, callEngineModuleContextTableSliceLenOffset, ) + + // Finally, we put &ModuleInstance.TypeIDs[0] into moduleContext.typeIDsElement0Address. + c.assembler.CompileMemoryToRegister(arm64.MOVD, + arm64CallingConventionModuleInstanceAddressRegister, moduleInstanceTypeIDsOffset, tmpX) + c.assembler.CompileRegisterToMemory(arm64.MOVD, + tmpX, arm64ReservedRegisterForCallEngine, callEngineModuleContextTypeIDsElement0AddressOffset) } - // Update callEngine.moduleContext.compiledFunctionsElement0Address + // Update callEngine.moduleContext.codesElement0Address { // "tmpX = [moduleInstanceAddressRegister + moduleInstanceEngineOffset + interfaceDataOffset] (== *moduleEngine)" // @@ -3209,18 +3221,18 @@ func (c *arm64Compiler) compileModuleContextInitialization() error { tmpX, ) - // "tmpY = [tmpX + moduleEngineCompiledFunctionsOffset] (== &moduleEngine.compiledFunctions[0])" + // "tmpY = [tmpX + moduleEngineFunctionsOffset] (== &moduleEngine.codes[0])" c.assembler.CompileMemoryToRegister( arm64.MOVD, - tmpX, moduleEngineCompiledFunctionsOffset, + tmpX, moduleEngineFunctionsOffset, tmpY, ) - // "callEngine.moduleContext.compiledFunctionsElement0Address = tmpY". + // "callEngine.moduleContext.codesElement0Address = tmpY". c.assembler.CompileRegisterToMemory( arm64.MOVD, tmpY, - arm64ReservedRegisterForCallEngine, callEngineModuleContextCompiledFunctionsElement0AddressOffset, + arm64ReservedRegisterForCallEngine, callEngineModuleContextcodesElement0AddressOffset, ) } diff --git a/internal/wasm/jit/jit_initialization_test.go b/internal/wasm/jit/jit_initialization_test.go index 6301aad3..3e5cdf6f 100644 --- a/internal/wasm/jit/jit_initialization_test.go +++ b/internal/wasm/jit/jit_initialization_test.go @@ -8,6 +8,7 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/wasm" + "github.com/tetratelabs/wazero/internal/wazeroir" ) func TestCompiler_compileModuleContextInitialization(t *testing.T) { @@ -21,13 +22,15 @@ func TestCompiler_compileModuleContextInitialization(t *testing.T) { Globals: []*wasm.GlobalInstance{{Val: 100}}, Memory: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, Table: &wasm.TableInstance{Table: make([]interface{}, 20)}, + TypeIDs: make([]wasm.FunctionTypeID, 10), }, }, { name: "globals nil", moduleInstance: &wasm.ModuleInstance{ - Memory: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, - Table: &wasm.TableInstance{Table: make([]interface{}, 20)}, + Memory: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, + Table: &wasm.TableInstance{Table: make([]interface{}, 20)}, + TypeIDs: make([]wasm.FunctionTypeID, 10), }, }, { @@ -35,19 +38,22 @@ func TestCompiler_compileModuleContextInitialization(t *testing.T) { moduleInstance: &wasm.ModuleInstance{ Globals: []*wasm.GlobalInstance{{Val: 100}}, Table: &wasm.TableInstance{Table: make([]interface{}, 20)}, + TypeIDs: make([]wasm.FunctionTypeID, 10), }, }, { name: "table nil", moduleInstance: &wasm.ModuleInstance{ - Memory: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, - Table: &wasm.TableInstance{Table: nil}, + Memory: &wasm.MemoryInstance{Buffer: make([]byte, 10)}, + Table: &wasm.TableInstance{Table: nil}, + TypeIDs: make([]wasm.FunctionTypeID, 10), }, }, { name: "table empty", moduleInstance: &wasm.ModuleInstance{ - Table: &wasm.TableInstance{Table: make([]interface{}, 0)}, + Table: &wasm.TableInstance{Table: make([]interface{}, 0)}, + TypeIDs: make([]wasm.FunctionTypeID, 10), }, }, { @@ -67,8 +73,15 @@ func TestCompiler_compileModuleContextInitialization(t *testing.T) { env.moduleInstance = tc.moduleInstance ce := env.callEngine() - compiler := env.requireNewCompiler(t, newCompiler, nil) - me := &moduleEngine{compiledFunctions: make([]*compiledFunction, 10)} + ir := &wazeroir.CompilationResult{ + HasMemory: tc.moduleInstance.Memory != nil, + HasTable: tc.moduleInstance.Table != nil, + } + for _, g := range tc.moduleInstance.Globals { + ir.Globals = append(ir.Globals, g.Type) + } + compiler := env.requireNewCompiler(t, newCompiler, ir) + me := &moduleEngine{functions: make([]*function, 10)} tc.moduleInstance.Engine = me // The golang-asm assembler skips the first instruction, so we emit NOP here which is ignored. @@ -104,9 +117,10 @@ func TestCompiler_compileModuleContextInitialization(t *testing.T) { tableHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Table.Table)) require.Equal(t, uint64(tableHeader.Len), ce.moduleContext.tableSliceLen) require.Equal(t, tableHeader.Data, ce.moduleContext.tableElement0Address) + require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.TypeIDs[0])), ce.moduleContext.typeIDsElement0Address) } - require.Equal(t, uintptr(unsafe.Pointer(&me.compiledFunctions[0])), ce.moduleContext.compiledFunctionsElement0Address) + require.Equal(t, uintptr(unsafe.Pointer(&me.functions[0])), ce.moduleContext.codesElement0Address) }) } } diff --git a/internal/wasm/jit/jit_memory_test.go b/internal/wasm/jit/jit_memory_test.go index adce5c0c..f1f478fb 100644 --- a/internal/wasm/jit/jit_memory_test.go +++ b/internal/wasm/jit/jit_memory_test.go @@ -51,7 +51,7 @@ func TestCompiler_compileMemoryGrow(t *testing.T) { func TestCompiler_compileMemorySize(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}) err := compiler.compilePreamble() require.NoError(t, err) @@ -235,7 +235,7 @@ func TestCompiler_compileLoad(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}) err := compiler.compilePreamble() require.NoError(t, err) @@ -370,7 +370,7 @@ func TestCompiler_compileStore(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { env := newJITEnvironment() - compiler := env.requireNewCompiler(t, newCompiler, nil) + compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{HasMemory: true, Signature: &wasm.FunctionType{}}) err := compiler.compilePreamble() require.NoError(t, err) diff --git a/internal/wasm/jit/jit_test.go b/internal/wasm/jit/jit_test.go index 409404f7..4655b6a7 100644 --- a/internal/wasm/jit/jit_test.go +++ b/internal/wasm/jit/jit_test.go @@ -117,10 +117,10 @@ func (j *jitEnv) callEngine() *callEngine { return j.ce } -func (j *jitEnv) exec(code []byte) { - compiledFunction := &compiledFunction{ - codeSegment: code, - codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), +func (j *jitEnv) exec(codeSegment []byte) { + f := &function{ + parent: &code{codeSegment: codeSegment}, + codeInitialAddress: uintptr(unsafe.Pointer(&codeSegment[0])), moduleInstanceAddress: uintptr(unsafe.Pointer(j.moduleInstance)), source: &wasm.FunctionInstance{ Kind: wasm.FunctionKindWasm, @@ -129,25 +129,30 @@ func (j *jitEnv) exec(code []byte) { }, } - j.ce.callFrameStack[j.ce.globalContext.callFrameStackPointer] = callFrame{compiledFunction: compiledFunction} + j.ce.callFrameStack[j.ce.globalContext.callFrameStackPointer] = callFrame{function: f} j.ce.globalContext.callFrameStackPointer++ jitcall( - uintptr(unsafe.Pointer(&code[0])), + uintptr(unsafe.Pointer(&codeSegment[0])), uintptr(unsafe.Pointer(j.ce)), uintptr(unsafe.Pointer(j.moduleInstance)), ) } // newTestCompiler allows us to test a different architecture than the current one. -type newTestCompiler func(f *wasm.FunctionInstance, ir *wazeroir.CompilationResult) (compiler, error) +type newTestCompiler func(ir *wazeroir.CompilationResult) (compiler, error) -func (j *jitEnv) requireNewCompiler(t *testing.T, fn newTestCompiler, functype *wasm.FunctionType) compilerImpl { +func (j *jitEnv) requireNewCompiler(t *testing.T, fn newTestCompiler, ir *wazeroir.CompilationResult) compilerImpl { requireSupportedOSArch(t) - c, err := fn( - &wasm.FunctionInstance{Module: j.moduleInstance, Kind: wasm.FunctionKindWasm, Type: functype}, - &wazeroir.CompilationResult{LabelCallers: map[string]uint32{}}, - ) + + if ir == nil { + ir = &wazeroir.CompilationResult{ + LabelCallers: map[string]uint32{}, + Signature: &wasm.FunctionType{}, + } + } + c, err := fn(ir) + require.NoError(t, err) ret, ok := c.(compilerImpl) diff --git a/internal/wasm/jit/mmap_test.go b/internal/wasm/jit/mmap_test.go index 8c12eb43..37f8653d 100644 --- a/internal/wasm/jit/mmap_test.go +++ b/internal/wasm/jit/mmap_test.go @@ -8,14 +8,14 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) -var code, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024)) +var testCode, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024)) func Test_mmapCodeSegment(t *testing.T) { requireSupportedOSArch(t) - newCode, err := mmapCodeSegment(code) + newCode, err := mmapCodeSegment(testCode) require.NoError(t, err) // Verify that the mmap is the same as the original. - require.Equal(t, code, newCode) + require.Equal(t, testCode, newCode) // TODO: test newCode can executed. t.Run("panic on zero length", func(t *testing.T) { @@ -30,9 +30,9 @@ func Test_munmapCodeSegment(t *testing.T) { requireSupportedOSArch(t) // Errors if never mapped - require.Error(t, munmapCodeSegment(code)) + require.Error(t, munmapCodeSegment(testCode)) - newCode, err := mmapCodeSegment(code) + newCode, err := mmapCodeSegment(testCode) require.NoError(t, err) // First munmap should succeed. require.NoError(t, munmapCodeSegment(newCode)) diff --git a/internal/wasm/module.go b/internal/wasm/module.go index ea30c4a3..5509d1e7 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -210,7 +210,7 @@ func (m *Module) Validate(enabledFeatures Features) error { return errors.New("cannot mix functions and host functions in the same module") } - functions, globals, memory, table, err := m.allDeclarations() + functions, globals, memory, table, err := m.AllDeclarations() if err != nil { return err } @@ -692,8 +692,8 @@ type NameMapAssoc struct { NameMap NameMap } -// allDeclarations returns all declarations for functions, globals, memories and tables in a module including imported ones. -func (m *Module) allDeclarations() (functions []Index, globals []*GlobalType, memory *Memory, table *Table, err error) { +// AllDeclarations returns all declarations for functions, globals, memories and tables in a module including imported ones. +func (m *Module) AllDeclarations() (functions []Index, globals []*GlobalType, memory *Memory, table *Table, err error) { for _, imp := range m.ImportSection { switch imp.Type { case ExternTypeFunc: diff --git a/internal/wasm/module_context.go b/internal/wasm/module_context.go index ebd93529..2fdef059 100644 --- a/internal/wasm/module_context.go +++ b/internal/wasm/module_context.go @@ -89,7 +89,6 @@ func (m *ModuleContext) CloseWithExitCode(exitCode uint32) (err error) { if !atomic.CompareAndSwapUint64(m.closed, 0, closed) { return nil } - m.module.Engine.Close() m.store.deleteModule(m.Name()) if sys := m.Sys; sys != nil { // ex nil if from ModuleBuilder return sys.Close() diff --git a/internal/wasm/module_test.go b/internal/wasm/module_test.go index 69a5a4f2..196f823e 100644 --- a/internal/wasm/module_test.go +++ b/internal/wasm/module_test.go @@ -167,7 +167,7 @@ func TestModule_allDeclarations(t *testing.T) { } { tc := tc t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - functions, globals, memory, table, err := tc.module.allDeclarations() + functions, globals, memory, table, err := tc.module.AllDeclarations() require.NoError(t, err) require.Equal(t, tc.expectedFunctions, functions) require.Equal(t, tc.expectedGlobals, globals) diff --git a/internal/wasm/store.go b/internal/wasm/store.go index 67827e63..78e110cd 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -62,13 +62,17 @@ type ( // Memory is set when Module.MemorySection had a memory, regardless of whether it was exported. Memory *MemoryInstance Table *TableInstance - Types []*TypeInstance + Types []*FunctionType // Ctx holds default function call context from this function instance. Ctx *ModuleContext // Engine implements function calls for this module. Engine ModuleEngine + + // TypeIDs is index-correlated with types and holds typeIDs which is uniquely assigned to a type by store. + // This is necessary to achieve fast runtime type checking for indirect function calls at runtime. + TypeIDs []FunctionTypeID } // ExportInstance represents an exported instance in a Store. @@ -119,14 +123,6 @@ type ( Index Index } - // TypeInstance is a store-specific representation of FunctionType where the function type - // is coupled with TypeID which is specific in a store. - TypeInstance struct { - Type *FunctionType - // TypeID is assigned by a store for FunctionType. - TypeID FunctionTypeID - } - // GlobalInstance represents a global instance in a store. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#global-instances%E2%91%A0 GlobalInstance struct { @@ -150,15 +146,16 @@ const ( // addSections adds section elements to the ModuleInstance func (m *ModuleInstance) addSections(module *Module, importedFunctions, functions []*FunctionInstance, importedGlobals, globals []*GlobalInstance, table *TableInstance, memory, importedMemory *MemoryInstance, - typeInstances []*TypeInstance) { + types []*FunctionType, typeIDs []FunctionTypeID) { - m.Types = typeInstances + m.Types = types + m.TypeIDs = typeIDs m.Functions = append(m.Functions, importedFunctions...) for i, f := range functions { // Associate each function with the type instance and the module instance's pointer. f.Module = m - f.TypeID = typeInstances[module.FunctionSection[i]].TypeID + f.TypeID = typeIDs[module.FunctionSection[i]] m.Functions = append(m.Functions, f) } @@ -252,7 +249,7 @@ func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sy return nil, err } - types, err := s.getTypes(module.TypeSection) + typeIDs, err := s.getFunctionTypeIDs(module.TypeSection) if err != nil { s.deleteModule(name) return nil, err @@ -284,7 +281,7 @@ func (s *Store) Instantiate(ctx context.Context, module *Module, name string, sy // Now we have all instances from imports and local ones, so ready to create a new ModuleInstance. m := &ModuleInstance{Name: name} - m.addSections(module, importedFunctions, functions, importedGlobals, globals, table, importedMemory, memory, types) + m.addSections(module, importedFunctions, functions, importedGlobals, globals, table, importedMemory, memory, module.TypeSection, typeIDs) if err = m.validateData(module.DataSection); err != nil { s.deleteModule(name) @@ -500,13 +497,13 @@ func executeConstExpression(globals []*GlobalInstance, expr *ConstantExpression) return } -func (s *Store) getTypes(ts []*FunctionType) ([]*TypeInstance, error) { +func (s *Store) getFunctionTypeIDs(ts []*FunctionType) ([]FunctionTypeID, error) { // We take write-lock here as the following might end up mutating typeIDs map. s.mux.Lock() defer s.mux.Unlock() - ret := make([]*TypeInstance, len(ts)) + ret := make([]FunctionTypeID, len(ts)) for i, t := range ts { - inst, err := s.getTypeInstance(t) + inst, err := s.getFunctionTypeID(t) if err != nil { return nil, err } @@ -515,16 +512,16 @@ func (s *Store) getTypes(ts []*FunctionType) ([]*TypeInstance, error) { return ret, nil } -func (s *Store) getTypeInstance(t *FunctionType) (*TypeInstance, error) { +func (s *Store) getFunctionTypeID(t *FunctionType) (FunctionTypeID, error) { key := t.String() id, ok := s.typeIDs[key] if !ok { l := uint32(len(s.typeIDs)) if l >= s.functionMaxTypes { - return nil, fmt.Errorf("too many function types in a store") + return 0, fmt.Errorf("too many function types in a store") } id = FunctionTypeID(len(s.typeIDs)) s.typeIDs[key] = id } - return &TypeInstance{Type: t, TypeID: id}, nil + return id, nil } diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 526b2d97..6626c5eb 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -444,8 +444,11 @@ func (e *mockEngine) NewModuleEngine(_ string, _ *Module, _, _ []*FunctionInstan return &mockModuleEngine{callFailIndex: e.callFailIndex}, nil } -// ReleaseCompilationCache implements the same method as documented on wasm.Engine. -func (e *mockEngine) ReleaseCompilationCache(*Module) {} +// DeleteCompiledModule implements the same method as documented on wasm.Engine. +func (e *mockEngine) DeleteCompiledModule(*Module) {} + +// CompileModule implements the same method as documented on wasm.Engine. +func (e *mockEngine) CompileModule(module *Module) error { return nil } // Name implements the same method as documented on wasm.ModuleEngine. func (e *mockModuleEngine) Name() string { @@ -466,7 +469,7 @@ func (e *mockModuleEngine) Call(ctx *ModuleContext, f *FunctionInstance, _ ...ui func (e *mockModuleEngine) Close() { } -func TestStore_getTypeInstance(t *testing.T) { +func TestStore_getFunctionTypeID(t *testing.T) { t.Run("too many functions", func(t *testing.T) { s := newStore() const max = 10 @@ -475,7 +478,7 @@ func TestStore_getTypeInstance(t *testing.T) { for i := 0; i < max; i++ { s.typeIDs[strconv.Itoa(i)] = 0 } - _, err := s.getTypeInstance(&FunctionType{}) + _, err := s.getFunctionTypeID(&FunctionType{}) require.Error(t, err) }) t.Run("ok", func(t *testing.T) { @@ -488,13 +491,12 @@ func TestStore_getTypeInstance(t *testing.T) { tc := tc t.Run(tc.String(), func(t *testing.T) { s := newStore() - actual, err := s.getTypeInstance(tc) + actual, err := s.getFunctionTypeID(tc) require.NoError(t, err) expectedTypeID, ok := s.typeIDs[tc.String()] require.True(t, ok) - require.Equal(t, expectedTypeID, actual.TypeID) - require.Equal(t, tc, actual.Type) + require.Equal(t, expectedTypeID, actual) }) } }) diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 48e34529..2eb28008 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -114,8 +114,21 @@ type compiler struct { depth int } pc uint64 - f *wasm.FunctionInstance result CompilationResult + + // body holds the code for the function's body where Wasm instructions are stored. + body []byte + // sig is the function type of the target function. + sig *wasm.FunctionType + // localTypes holds the target function locals' value types. + localTypes []wasm.ValueType + + // types hold all the function types in the module where the targe function exists. + types []*wasm.FunctionType + // funcs holds the type indexes for all declard functions in the module where the targe function exists. + funcs []uint32 + // globals holds the global types for all declard globas in the module where the targe function exists. + globals []*wasm.GlobalType } // For debugging only. @@ -152,40 +165,92 @@ type CompilationResult struct { // // This example the label corresponding to `(block i32.const 1111)` is never be reached at runtime because `br 0` exits the function before we reach there LabelCallers map[string]uint32 + + // Signature is the function type of the compilation target function. + Signature *wasm.FunctionType + // Globals holds all the declarations of globals in the module from which this function is compiled. + Globals []*wasm.GlobalType + // Functions holds all the declarations of function in the module from which this function is compiled, including itself. + Functions []wasm.Index + // Types holds all the types in the module from which this function is compiled. + Types []*wasm.FunctionType + // HasMemory is true if the module from which this function is compiled has memory declaration. + HasMemory bool + // HasTable is true if the module from which this function is compiled has table declaration. + HasTable bool +} + +func CompileFunctions(enabledFeatures wasm.Features, module *wasm.Module) ([]*CompilationResult, error) { + functions, globals, mem, table, err := module.AllDeclarations() + if err != nil { + return nil, err + } + + hasMemory, hasTable := mem != nil, table != nil + + var ret []*CompilationResult + for funcInxdex := range module.FunctionSection { + typeID := module.FunctionSection[funcInxdex] + sig := module.TypeSection[typeID] + code := module.CodeSection[funcInxdex] + r, err := compile(enabledFeatures, sig, code.Body, code.LocalTypes, module.TypeSection, functions, globals) + if err != nil { + return nil, fmt.Errorf("failed to lower func[%d/%d] to wazeroir: %w", funcInxdex, len(functions), err) + } + r.Globals = globals + r.Functions = functions + r.Types = module.TypeSection + r.HasMemory = hasMemory + r.HasTable = hasTable + r.Signature = sig + ret = append(ret, r) + } + return ret, nil } // Compile lowers given function instance into wazeroir operations // so that the resulting operations can be consumed by the interpreter // or the JIT compilation engine. -func Compile(enabledFeatures wasm.Features, f *wasm.FunctionInstance) (*CompilationResult, error) { +func compile(enabledFeatures wasm.Features, + sig *wasm.FunctionType, + body []byte, + localTypes []wasm.ValueType, + types []*wasm.FunctionType, + functions []uint32, globals []*wasm.GlobalType, +) (*CompilationResult, error) { c := compiler{ enabledFeatures: enabledFeatures, controlFrames: &controlFrames{}, - f: f, result: CompilationResult{LabelCallers: map[string]uint32{}}, + body: body, + localTypes: localTypes, + sig: sig, + globals: globals, + funcs: functions, + types: types, } // Push function arguments. - for _, t := range f.Type.Params { + for _, t := range sig.Params { c.stackPush(wasmValueTypeToUnsignedType(t)) } // Emit const expressions for locals. // Note that here we don't take function arguments // into account, meaning that callers must push // arguments before entering into the function body. - for _, t := range f.LocalTypes { + for _, t := range localTypes { c.emitDefaultValue(t) } // Insert the function control frame. c.controlFrames.push(&controlFrame{ frameID: c.nextID(), - blockType: f.Type, + blockType: c.sig, kind: controlFrameKindFunction, }) // Now, enter the function body. - for !c.controlFrames.empty() { + for !c.controlFrames.empty() && c.pc < uint64(len(c.body)) { if err := c.handleInstruction(); err != nil { return nil, fmt.Errorf("handling instruction: %w", err) } @@ -196,7 +261,7 @@ func Compile(enabledFeatures wasm.Features, f *wasm.FunctionInstance) (*Compilat // Translate the current Wasm instruction to wazeroir's operations, // and emit the results into c.results. func (c *compiler) handleInstruction() error { - op := c.f.Body[c.pc] + op := c.body[c.pc] if buildoptions.IsDebugMode { fmt.Printf("handling %s, unreachable_state(on=%v,depth=%d)\n", wasm.InstructionName(op), @@ -223,8 +288,8 @@ operatorSwitch: case wasm.OpcodeNop: // Nop is noop! case wasm.OpcodeBlock: - bt, num, err := wasm.DecodeBlockType(c.f.Module.Types, - bytes.NewReader(c.f.Body[c.pc+1:]), c.enabledFeatures) + bt, num, err := wasm.DecodeBlockType(c.types, + bytes.NewReader(c.body[c.pc+1:]), c.enabledFeatures) if err != nil { return fmt.Errorf("reading block type for block instruction: %w", err) } @@ -247,7 +312,7 @@ operatorSwitch: c.controlFrames.push(frame) case wasm.OpcodeLoop: - bt, num, err := wasm.DecodeBlockType(c.f.Module.Types, bytes.NewReader(c.f.Body[c.pc+1:]), c.enabledFeatures) + bt, num, err := wasm.DecodeBlockType(c.types, bytes.NewReader(c.body[c.pc+1:]), c.enabledFeatures) if err != nil { return fmt.Errorf("reading block type for loop instruction: %w", err) } @@ -282,7 +347,7 @@ operatorSwitch: ) case wasm.OpcodeIf: - bt, num, err := wasm.DecodeBlockType(c.f.Module.Types, bytes.NewReader(c.f.Body[c.pc+1:]), c.enabledFeatures) + bt, num, err := wasm.DecodeBlockType(c.types, bytes.NewReader(c.body[c.pc+1:]), c.enabledFeatures) if err != nil { return fmt.Errorf("reading block type for if instruction: %w", err) } @@ -474,7 +539,7 @@ operatorSwitch: } case wasm.OpcodeBr: - targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) + targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("read the target for br_if: %w", err) } @@ -494,7 +559,7 @@ operatorSwitch: // and can be safely removed. c.markUnreachable() case wasm.OpcodeBrIf: - targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) + targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("read the target for br_if: %w", err) } @@ -519,7 +584,7 @@ operatorSwitch: }, ) case wasm.OpcodeBrTable: - r := bytes.NewReader(c.f.Body[c.pc+1:]) + r := bytes.NewReader(c.body[c.pc+1:]) numTargets, n, err := leb128.DecodeUint32(r) if err != nil { return fmt.Errorf("error reading number of targets in br_table: %w", err) @@ -591,7 +656,7 @@ operatorSwitch: if index == nil { return fmt.Errorf("index does not exist for indirect function call") } - tableIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) + tableIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("read target for br_table: %w", err) } @@ -847,7 +912,7 @@ operatorSwitch: &OperationMemoryGrow{}, ) case wasm.OpcodeI32Const: - val, num, err := leb128.DecodeInt32(bytes.NewReader(c.f.Body[c.pc+1:])) + val, num, err := leb128.DecodeInt32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -856,7 +921,7 @@ operatorSwitch: &OperationConstI32{Value: uint32(val)}, ) case wasm.OpcodeI64Const: - val, num, err := leb128.DecodeInt64(bytes.NewReader(c.f.Body[c.pc+1:])) + val, num, err := leb128.DecodeInt64(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return fmt.Errorf("reading i64.const value: %v", err) } @@ -865,13 +930,13 @@ operatorSwitch: &OperationConstI64{Value: uint64(val)}, ) case wasm.OpcodeF32Const: - v := math.Float32frombits(binary.LittleEndian.Uint32(c.f.Body[c.pc+1:])) + v := math.Float32frombits(binary.LittleEndian.Uint32(c.body[c.pc+1:])) c.pc += 4 c.emit( &OperationConstF32{Value: v}, ) case wasm.OpcodeF64Const: - v := math.Float64frombits(binary.LittleEndian.Uint64(c.f.Body[c.pc+1:])) + v := math.Float64frombits(binary.LittleEndian.Uint64(c.body[c.pc+1:])) c.pc += 8 c.emit( &OperationConstF64{Value: v}, @@ -1418,7 +1483,7 @@ func (c *compiler) applyToStack(opcode wasm.Opcode) (*uint32, error) { wasm.OpcodeGlobalGet, wasm.OpcodeGlobalSet: // Assumes that we are at the opcode now so skip it before read immediates. - v, num, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) + v, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) if err != nil { return nil, fmt.Errorf("reading immediates: %w", err) } @@ -1437,7 +1502,7 @@ func (c *compiler) applyToStack(opcode wasm.Opcode) (*uint32, error) { } // Retrieve the signature of the opcode. - s, err := wasmOpcodeSignature(c.f, opcode, index) + s, err := c.wasmOpcodeSignature(opcode, index) if err != nil { return nil, err } @@ -1566,7 +1631,7 @@ func (c *compiler) getFrameDropRange(frame *controlFrame, isEnd bool) *Inclusive } func (c *compiler) readMemoryImmediate(tag string) (*MemoryImmediate, error) { - r := bytes.NewReader(c.f.Body[c.pc+1:]) + r := bytes.NewReader(c.body[c.pc+1:]) alignment, num, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("reading alignment for %s: %w", tag, err) diff --git a/internal/wazeroir/compiler_test.go b/internal/wazeroir/compiler_test.go index 0c38a7a9..029a896d 100644 --- a/internal/wazeroir/compiler_test.go +++ b/internal/wazeroir/compiler_test.go @@ -1,7 +1,6 @@ package wazeroir import ( - "context" "testing" "github.com/tetratelabs/wazero/api" @@ -33,6 +32,9 @@ func TestCompile(t *testing.T) { &OperationBr{Target: &BranchTarget{}}, // return! }, LabelCallers: map[string]uint32{}, + Functions: []uint32{0}, + Types: []*wasm.FunctionType{{}}, + Signature: &wasm.FunctionType{}, }, }, { @@ -47,6 +49,9 @@ func TestCompile(t *testing.T) { &OperationBr{Target: &BranchTarget{}}, // return! }, LabelCallers: map[string]uint32{}, + Types: []*wasm.FunctionType{{Params: []wasm.ValueType{i32}, Results: []wasm.ValueType{i32}}}, + Functions: []uint32{0}, + Signature: &wasm.FunctionType{Params: []wasm.ValueType{wasm.ValueTypeI32}, Results: []wasm.ValueType{wasm.ValueTypeI32}}, }, }, } @@ -59,13 +64,10 @@ func TestCompile(t *testing.T) { if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - functions, err := compileFunctions(enabledFeatures, tc.module) - require.NoError(t, err) - require.Equal(t, 1, len(functions)) - res, err := Compile(enabledFeatures, functions[0]) + res, err := CompileFunctions(enabledFeatures, tc.module) require.NoError(t, err) - require.Equal(t, tc.expected, res) + require.Equal(t, tc.expected, res[0]) }) } } @@ -109,6 +111,9 @@ func TestCompile_Block(t *testing.T) { // shouldn't because the br instruction is stack-polymorphic. In other words, (br 0) substitutes for the // two i32 parameters to add. LabelCallers: map[string]uint32{".L2_cont": 1}, + Functions: []uint32{0}, + Types: []*wasm.FunctionType{v_v}, + Signature: v_v, }, }, } @@ -123,6 +128,12 @@ func TestCompile_Block(t *testing.T) { } func TestCompile_MultiValue(t *testing.T) { + i32i32_i32i32 := &wasm.FunctionType{Params: []wasm.ValueType{ + wasm.ValueTypeI32, wasm.ValueTypeI32}, + Results: []wasm.ValueType{wasm.ValueTypeI32, wasm.ValueTypeI32}, + } + _i32i64 := &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI32, wasm.ValueTypeI64}} + tests := []struct { name string module *wasm.Module @@ -143,6 +154,9 @@ func TestCompile_MultiValue(t *testing.T) { &OperationBr{Target: &BranchTarget{}}, // return! }, LabelCallers: map[string]uint32{}, + Signature: i32i32_i32i32, + Functions: []wasm.Index{0}, + Types: []*wasm.FunctionType{i32i32_i32i32}, }, }, { @@ -184,6 +198,9 @@ func TestCompile_MultiValue(t *testing.T) { // Note: f64.add comes after br 0 so is unreachable. This is why neither the add, nor its other operand // are in the above compilation result. LabelCallers: map[string]uint32{".L2_cont": 1}, // arbitrary label + Signature: v_f64f64, + Functions: []wasm.Index{0}, + Types: []*wasm.FunctionType{v_f64f64}, }, }, { @@ -199,6 +216,9 @@ func TestCompile_MultiValue(t *testing.T) { &OperationBr{Target: &BranchTarget{}}, // return! }, LabelCallers: map[string]uint32{}, + Signature: _i32i64, + Functions: []wasm.Index{0}, + Types: []*wasm.FunctionType{_i32i64}, }, }, { @@ -248,6 +268,9 @@ func TestCompile_MultiValue(t *testing.T) { ".L2_cont": 2, ".L2_else": 1, }, + Signature: i32_i32, + Functions: []wasm.Index{0}, + Types: []*wasm.FunctionType{i32_i32}, }, }, { @@ -301,6 +324,9 @@ func TestCompile_MultiValue(t *testing.T) { ".L2_cont": 2, ".L2_else": 1, }, + Signature: i32_i32, + Functions: []wasm.Index{0}, + Types: []*wasm.FunctionType{i32_i32, i32i32_i32}, }, }, { @@ -354,6 +380,9 @@ func TestCompile_MultiValue(t *testing.T) { ".L2_cont": 2, ".L2_else": 1, }, + Signature: i32_i32, + Functions: []wasm.Index{0}, + Types: []*wasm.FunctionType{i32_i32, i32i32_i32}, }, }, } @@ -366,13 +395,9 @@ func TestCompile_MultiValue(t *testing.T) { if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - functions, err := compileFunctions(enabledFeatures, tc.module) + res, err := CompileFunctions(enabledFeatures, tc.module) require.NoError(t, err) - require.Equal(t, 1, len(functions)) - - res, err := Compile(enabledFeatures, functions[0]) - require.NoError(t, err) - require.Equal(t, tc.expected, res) + require.Equal(t, tc.expected, res[0]) }) } } @@ -381,13 +406,9 @@ func requireCompilationResult(t *testing.T, enabledFeatures wasm.Features, expec if enabledFeatures == 0 { enabledFeatures = wasm.FeaturesFinished } - functions, err := compileFunctions(enabledFeatures, module) + res, err := CompileFunctions(enabledFeatures, module) require.NoError(t, err) - require.Equal(t, 1, len(functions)) - - res, err := Compile(enabledFeatures, functions[0]) - require.NoError(t, err) - require.Equal(t, expected, res) + require.Equal(t, expected, res[0]) } func requireModuleText(t *testing.T, source string) *wasm.Module { @@ -395,35 +416,3 @@ func requireModuleText(t *testing.T, source string) *wasm.Module { require.NoError(t, err) return m } - -func compileFunctions(enabledFeatures wasm.Features, module *wasm.Module) ([]*wasm.FunctionInstance, error) { - cf := &catchFunctions{} - _, err := wasm.NewStore(enabledFeatures, cf).Instantiate(context.Background(), module, "", wasm.DefaultSysContext()) - return cf.functions, err -} - -type catchFunctions struct { - functions []*wasm.FunctionInstance -} - -// NewModuleEngine implements the same method as documented on wasm.Engine. -func (e *catchFunctions) NewModuleEngine(_ string, _ *wasm.Module, _, functions []*wasm.FunctionInstance, _ *wasm.TableInstance, _ map[wasm.Index]wasm.Index) (wasm.ModuleEngine, error) { - e.functions = functions - return e, nil -} - -// Name implements the same method as documented on wasm.ModuleEngine. -func (e *catchFunctions) Name() string { - return "" -} - -// Call implements the same method as documented on wasm.ModuleEngine. -func (e *catchFunctions) Call(_ *wasm.ModuleContext, _ *wasm.FunctionInstance, _ ...uint64) ([]uint64, error) { - return nil, nil -} - -// Close implements the same method as documented on wasm.ModuleEngine. -func (e *catchFunctions) Close() {} - -// ReleaseCompilationCache implements the same method as documented on wasm.Engine. -func (e *catchFunctions) ReleaseCompilationCache(*wasm.Module) {} diff --git a/internal/wazeroir/signature.go b/internal/wazeroir/signature.go index 78307072..b1fee8bb 100644 --- a/internal/wazeroir/signature.go +++ b/internal/wazeroir/signature.go @@ -147,7 +147,7 @@ var ( // the function instance (for example, local types). // "index" parameter is not used by most of opcodes. // The returned signature is used for stack validation when lowering Wasm's opcodes to wazeroir. -func wasmOpcodeSignature(f *wasm.FunctionInstance, op wasm.Opcode, index uint32) (*signature, error) { +func (c *compiler) wasmOpcodeSignature(op wasm.Opcode, index uint32) (*signature, error) { switch op { case wasm.OpcodeUnreachable, wasm.OpcodeNop, wasm.OpcodeBlock, wasm.OpcodeLoop: return signature_None_None, nil @@ -160,9 +160,9 @@ func wasmOpcodeSignature(f *wasm.FunctionInstance, op wasm.Opcode, index uint32) case wasm.OpcodeReturn: return signature_None_None, nil case wasm.OpcodeCall: - return funcTypeToSignature(f.Module.Functions[index].Type), nil + return funcTypeToSignature(c.types[c.funcs[index]]), nil case wasm.OpcodeCallIndirect: - ret := funcTypeToSignature(f.Module.Types[index].Type) + ret := funcTypeToSignature(c.types[index]) ret.in = append(ret.in, UnsignedTypeI32) return ret, nil case wasm.OpcodeDrop: @@ -170,54 +170,54 @@ func wasmOpcodeSignature(f *wasm.FunctionInstance, op wasm.Opcode, index uint32) case wasm.OpcodeSelect: return signature_UnknownUnknownI32_Unknown, nil case wasm.OpcodeLocalGet: - inputLen := uint32(len(f.Type.Params)) - if l := uint32(len(f.LocalTypes)) + inputLen; index >= l { + inputLen := uint32(len(c.sig.Params)) + if l := uint32(len(c.localTypes)) + inputLen; index >= l { return nil, fmt.Errorf("invalid local index for local.get %d >= %d", index, l) } var t UnsignedType if index < inputLen { - t = wasmValueTypeToUnsignedType(f.Type.Params[index]) + t = wasmValueTypeToUnsignedType(c.sig.Params[index]) } else { - t = wasmValueTypeToUnsignedType(f.LocalTypes[index-inputLen]) + t = wasmValueTypeToUnsignedType(c.localTypes[index-inputLen]) } return &signature{out: []UnsignedType{t}}, nil case wasm.OpcodeLocalSet: - inputLen := uint32(len(f.Type.Params)) - if l := uint32(len(f.LocalTypes)) + inputLen; index >= l { + inputLen := uint32(len(c.sig.Params)) + if l := uint32(len(c.localTypes)) + inputLen; index >= l { return nil, fmt.Errorf("invalid local index for local.get %d >= %d", index, l) } var t UnsignedType if index < inputLen { - t = wasmValueTypeToUnsignedType(f.Type.Params[index]) + t = wasmValueTypeToUnsignedType(c.sig.Params[index]) } else { - t = wasmValueTypeToUnsignedType(f.LocalTypes[index-inputLen]) + t = wasmValueTypeToUnsignedType(c.localTypes[index-inputLen]) } return &signature{in: []UnsignedType{t}}, nil case wasm.OpcodeLocalTee: - inputLen := uint32(len(f.Type.Params)) - if l := uint32(len(f.LocalTypes)) + inputLen; index >= l { + inputLen := uint32(len(c.sig.Params)) + if l := uint32(len(c.localTypes)) + inputLen; index >= l { return nil, fmt.Errorf("invalid local index for local.get %d >= %d", index, l) } var t UnsignedType if index < inputLen { - t = wasmValueTypeToUnsignedType(f.Type.Params[index]) + t = wasmValueTypeToUnsignedType(c.sig.Params[index]) } else { - t = wasmValueTypeToUnsignedType(f.LocalTypes[index-inputLen]) + t = wasmValueTypeToUnsignedType(c.localTypes[index-inputLen]) } return &signature{in: []UnsignedType{t}, out: []UnsignedType{t}}, nil case wasm.OpcodeGlobalGet: - if len(f.Module.Globals) <= int(index) { - return nil, fmt.Errorf("invalid global index for global.get %d >= %d", index, len(f.Module.Globals)) + if len(c.globals) <= int(index) { + return nil, fmt.Errorf("invalid global index for global.get %d >= %d", index, len(c.globals)) } return &signature{ - out: []UnsignedType{wasmValueTypeToUnsignedType(f.Module.Globals[index].Type.ValType)}, + out: []UnsignedType{wasmValueTypeToUnsignedType(c.globals[index].ValType)}, }, nil case wasm.OpcodeGlobalSet: - if len(f.Module.Globals) <= int(index) { - return nil, fmt.Errorf("invalid global index for global.get %d >= %d", index, len(f.Module.Globals)) + if len(c.globals) <= int(index) { + return nil, fmt.Errorf("invalid global index for global.get %d >= %d", index, len(c.globals)) } return &signature{ - in: []UnsignedType{wasmValueTypeToUnsignedType(f.Module.Globals[index].Type.ValType)}, + in: []UnsignedType{wasmValueTypeToUnsignedType(c.globals[index].ValType)}, }, nil case wasm.OpcodeI32Load: return signature_I32_I32, nil diff --git a/wasm.go b/wasm.go index 758c2d24..6a02e6a3 100644 --- a/wasm.go +++ b/wasm.go @@ -164,7 +164,13 @@ func (r *runtime) CompileModule(source []byte) (*CompiledCode, error) { return nil, err } - return &CompiledCode{module: internal}, nil + if err = r.store.Engine.CompileModule(internal); err != nil { + return nil, err + } + + ret := &CompiledCode{module: internal} + ret.addCacheEntry(internal, r.store.Engine) + return ret, nil } // InstantiateModuleFromCode implements Runtime.InstantiateModuleFromCode @@ -207,6 +213,14 @@ func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *Module } module := config.replaceImports(code.module) + if module != code.module { + // If replacing imports had an effect, the module changed, so we have to recompile it. + // TODO: maybe we should move replaceImports configs into CompileModule. + if err = r.store.Engine.CompileModule(module); err != nil { + return nil, err + } + code.addCacheEntry(module, r.store.Engine) + } mod, err = r.store.Instantiate(r.ctx, module, name, sysCtx) if err != nil { @@ -226,6 +240,5 @@ func (r *runtime) InstantiateModuleWithConfig(code *CompiledCode, config *Module return } } - code.addCacheEntry(module, r.store.Engine) return } diff --git a/wasm_test.go b/wasm_test.go index e2a8edc0..65c4c844 100644 --- a/wasm_test.go +++ b/wasm_test.go @@ -218,7 +218,7 @@ func TestModule_Global(t *testing.T) { for _, tt := range tests { tc := tt - r := NewRuntime() + r := NewRuntime().(*runtime) t.Run(tc.name, func(t *testing.T) { var code *CompiledCode if tc.module != nil { @@ -227,6 +227,9 @@ func TestModule_Global(t *testing.T) { code, _ = tc.builder(r).Build() } + err := r.store.Engine.CompileModule(code.module) + require.NoError(t, err) + // Instantiate the module and get the export of the above global module, err := r.InstantiateModule(code) require.NoError(t, err) @@ -489,7 +492,9 @@ func (e *mockEngine) NewModuleEngine(_ string, module *wasm.Module, _, _ []*wasm return nil, nil } -// ReleaseCompilationCache implements the same method as documented on wasm.Engine. -func (e *mockEngine) ReleaseCompilationCache(module *wasm.Module) { +// DeleteCompiledModule implements the same method as documented on wasm.Engine. +func (e *mockEngine) DeleteCompiledModule(module *wasm.Module) { delete(e.cachedModules, module) } + +func (e *mockEngine) CompileModule(module *wasm.Module) error { return nil }