wazevo: adds support for context cancelation (#1709)

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
This commit is contained in:
Takeshi Yoneda
2023-09-14 13:22:30 +09:00
committed by GitHub
parent 69c15b10ca
commit 173fae7b81
13 changed files with 165 additions and 38 deletions

View File

@@ -3100,7 +3100,7 @@ L9 (SSA Block: blk6):
t.Run(tc.name, func(t *testing.T) {
ssab := ssa.NewBuilder()
offset := wazevoapi.NewModuleContextOffsetData(tc.m)
fc := frontend.NewFrontendCompiler(tc.m, ssab, &offset)
fc := frontend.NewFrontendCompiler(tc.m, ssab, &offset, false)
machine := newMachine()
machine.DisableStackCheck()
be := backend.NewCompiler(context.Background(), machine, ssab)

View File

@@ -376,7 +376,7 @@ func (m *machine) insertStackBoundsCheck(requiredStackSize int64, cur *instructi
ldrAddress.asULoad(operandNR(tmpRegVReg), addressMode{
kind: addressModeKindRegUnsignedImm12,
rn: x0VReg, // execution context is always the first argument
imm: wazevoapi.ExecutionContextOffsets.StackGrowCallSequenceAddress.I64(),
imm: wazevoapi.ExecutionContextOffsets.StackGrowCallTrampolineAddress.I64(),
}, 64)
cur = linkInstr(cur, ldrAddress)

View File

@@ -63,11 +63,12 @@ type (
stackGrowRequiredSize uintptr
// memoryGrowTrampolineAddress holds the address of memory grow trampoline function.
memoryGrowTrampolineAddress *byte
// stackGrowCallSequenceAddress holds the address of stack grow call sequence function.
stackGrowCallSequenceAddress *byte
// stackGrowCallTrampolineAddress holds the address of stack grow trampoline function.
stackGrowCallTrampolineAddress *byte
// checkModuleExitCodeTrampolineAddress holds the address of check-module-exit-code function.
checkModuleExitCodeTrampolineAddress *byte
// savedRegisters is the opaque spaces for save/restore registers.
// We want to align 16 bytes for each register, so we use [64][2]uint64.
_ uint64
savedRegisters [64][2]uint64
// goFunctionCallCalleeModuleContextOpaque is the pointer to the target Go function's moduleContextOpaque.
goFunctionCallCalleeModuleContextOpaque uintptr
@@ -138,6 +139,19 @@ func (c *callEngine) CallWithStack(ctx context.Context, paramResultStack []uint6
// CallWithStack implements api.Function.
func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint64) (err error) {
p := c.parent
ensureTermination := p.parent.ensureTermination
m := p.module
if ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the module and return the error.
m.CloseWithCtxErr(ctx)
return m.FailIfClosed()
default:
}
}
var paramResultPtr *uint64
if len(paramResultStack) > 0 {
paramResultPtr = &paramResultStack[0]
@@ -165,6 +179,11 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
}
}()
if ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}
entrypoint(c.preambleExecutable, c.executable, c.execCtxPtr, c.parent.opaquePtr, paramResultPtr, c.stackTop)
for {
switch ec := c.execCtx.exitCode; ec & wazevoapi.ExitCodeMask {
@@ -210,6 +229,15 @@ func (c *callEngine) callWithStack(ctx context.Context, paramResultStack []uint6
f.Call(ctx, mod, c.execCtx.goFunctionCallStack[:])
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr, c.execCtx.stackPointerBeforeGoCall)
case wazevoapi.ExitCodeCheckModuleExitCode:
// Note: this operation must be done in Go, not native code. The reason is that
// native code cannot be preempted and that means it can block forever if there are not
// enough OS threads (which we don't have control over).
if err := m.FailIfClosed(); err != nil {
panic(err)
}
c.execCtx.exitCode = wazevoapi.ExitCodeOK
afterGoFunctionCallEntrypoint(c.execCtx.goCallReturnAddress, c.execCtxPtr, c.execCtx.stackPointerBeforeGoCall)
case wazevoapi.ExitCodeUnreachable:
panic(wasmruntime.ErrRuntimeUnreachable)
case wazevoapi.ExitCodeMemoryOutOfBounds:

View File

@@ -45,6 +45,9 @@ type (
sharedFunctions struct {
// memoryGrowExecutable is a compiled executable for memory.grow builtin function.
memoryGrowExecutable []byte
// checkModuleExitCode is a compiled executable for checking module instance exit code. This
// is used when ensureTermination is true.
checkModuleExitCode []byte
// stackGrowExecutable is a compiled executable for growing stack builtin function.
stackGrowExecutable []byte
entryPreambles map[*wasm.FunctionType][]byte
@@ -58,6 +61,7 @@ type (
parent *engine
module *wasm.Module
entryPreambles []*byte // indexed-correlated with the type index.
ensureTermination bool
// The followings are only available for non host modules.
@@ -78,7 +82,7 @@ func NewEngine(ctx context.Context, _ api.CoreFeatures, _ filecache.Cache) wasm.
machine: machine,
be: be,
}
e.compileBuiltinFunctions()
e.compileSharedFunctions()
return e
}
@@ -108,6 +112,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
e.rels = e.rels[:0]
cm := &compiledModule{
offsets: wazevoapi.NewModuleContextOffsetData(module), parent: e, module: module,
ensureTermination: ensureTermination,
}
if module.IsHostModule {
@@ -131,7 +136,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
// Creates new compiler instances which are reused for each function.
ssaBuilder := ssa.NewBuilder()
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets)
fe := frontend.NewFrontendCompiler(module, ssaBuilder, &cm.offsets, ensureTermination)
machine := newMachine()
be := backend.NewCompiler(ctx, machine, ssaBuilder)
@@ -154,7 +159,7 @@ func (e *engine) compileModule(ctx context.Context, module *wasm.Module, listene
ctx = wazevoapi.SetCurrentFunctionName(ctx, fmt.Sprintf("[%d/%d] \"%s\"", i, len(module.CodeSection)-1, name))
}
body, rels, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, listeners, ensureTermination)
body, rels, err := e.compileLocalWasmFunction(ctx, module, wasm.Index(i), fe, ssaBuilder, be, listeners)
if err != nil {
return nil, fmt.Errorf("compile function %d/%d: %v", i, len(module.CodeSection)-1, err)
}
@@ -214,7 +219,7 @@ func (e *engine) compileLocalWasmFunction(
fe *frontend.Compiler,
ssaBuilder ssa.Builder,
be backend.Compiler,
_ []experimental.FunctionListener, _ bool,
_ []experimental.FunctionListener,
) (body []byte, rels []backend.RelocationInfo, err error) {
typ := &module.TypeSection[module.FunctionSection[localFunctionIndex]]
codeSeg := &module.CodeSection[localFunctionIndex]
@@ -468,7 +473,7 @@ func (e *engine) NewModuleEngine(m *wasm.Module, mi *wasm.ModuleInstance) (wasm.
return me, nil
}
func (e *engine) compileBuiltinFunctions() {
func (e *engine) compileSharedFunctions() {
e.sharedFunctions = &sharedFunctions{entryPreambles: make(map[*wasm.FunctionType][]byte)}
e.be.Init()
@@ -480,6 +485,15 @@ func (e *engine) compileBuiltinFunctions() {
e.sharedFunctions.memoryGrowExecutable = mmapExecutable(src)
}
e.be.Init()
{
src := e.machine.CompileGoFunctionTrampoline(wazevoapi.ExitCodeCheckModuleExitCode, &ssa.Signature{
Params: []ssa.Type{ssa.TypeI32 /* exec context */},
Results: []ssa.Type{ssa.TypeI32},
}, false)
e.sharedFunctions.checkModuleExitCode = mmapExecutable(src)
}
// TODO: table grow, etc.
e.be.Init()
@@ -491,21 +505,26 @@ func (e *engine) compileBuiltinFunctions() {
e.setFinalizer(e.sharedFunctions, sharedFunctionsFinalizer)
}
func sharedFunctionsFinalizer(bf *sharedFunctions) {
if err := platform.MunmapCodeSegment(bf.memoryGrowExecutable); err != nil {
func sharedFunctionsFinalizer(sf *sharedFunctions) {
if err := platform.MunmapCodeSegment(sf.memoryGrowExecutable); err != nil {
panic(err)
}
if err := platform.MunmapCodeSegment(bf.stackGrowExecutable); err != nil {
if err := platform.MunmapCodeSegment(sf.checkModuleExitCode); err != nil {
panic(err)
}
for _, f := range bf.entryPreambles {
if err := platform.MunmapCodeSegment(sf.stackGrowExecutable); err != nil {
panic(err)
}
for _, f := range sf.entryPreambles {
if err := platform.MunmapCodeSegment(f); err != nil {
panic(err)
}
}
bf.memoryGrowExecutable = nil
bf.stackGrowExecutable = nil
sf.memoryGrowExecutable = nil
sf.checkModuleExitCode = nil
sf.stackGrowExecutable = nil
sf.entryPreambles = nil
}
func compiledModuleFinalizer(cm *compiledModule) {

View File

@@ -13,19 +13,37 @@ import (
)
func Test_sharedFunctionsFinalizer(t *testing.T) {
bf := &sharedFunctions{}
sf := &sharedFunctions{}
b1, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
b2, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
bf.memoryGrowExecutable = b1
bf.stackGrowExecutable = b2
sharedFunctionsFinalizer(bf)
require.Nil(t, bf.memoryGrowExecutable)
require.Nil(t, bf.stackGrowExecutable)
b3, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
b4, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
b5, err := platform.MmapCodeSegment(100)
require.NoError(t, err)
preabmles := map[*wasm.FunctionType][]byte{
{Params: []wasm.ValueType{}}: b4,
{Params: []wasm.ValueType{wasm.ValueTypeI32}}: b5,
}
sf.memoryGrowExecutable = b1
sf.stackGrowExecutable = b2
sf.checkModuleExitCode = b3
sf.entryPreambles = preabmles
sharedFunctionsFinalizer(sf)
require.Nil(t, sf.memoryGrowExecutable)
require.Nil(t, sf.stackGrowExecutable)
require.Nil(t, sf.checkModuleExitCode)
require.Nil(t, sf.entryPreambles)
}
func Test_compiledModuleFinalizer(t *testing.T) {

View File

@@ -20,6 +20,9 @@ type Compiler struct {
ssaBuilder ssa.Builder
signatures map[*wasm.FunctionType]*ssa.Signature
memoryGrowSig ssa.Signature
checkModuleExitCodeSig ssa.Signature
checkModuleExitCodeArg [1]ssa.Value
ensureTermination bool
// Followings are reset by per function.
@@ -43,13 +46,14 @@ type Compiler struct {
}
// NewFrontendCompiler returns a frontend Compiler.
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData) *Compiler {
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool) *Compiler {
c := &Compiler{
m: m,
ssaBuilder: ssaBuilder,
br: bytes.NewReader(nil),
wasmLocalToVariable: make(map[wasm.Index]ssa.Variable),
offset: offset,
ensureTermination: ensureTermination,
}
c.signatures = make(map[*wasm.FunctionType]*ssa.Signature, len(m.TypeSection)+1)
@@ -70,6 +74,12 @@ func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoa
}
c.ssaBuilder.DeclareSignature(&c.memoryGrowSig)
c.checkModuleExitCodeSig = ssa.Signature{
ID: c.memoryGrowSig.ID + 1,
// Only takes execution context.
Params: []ssa.Type{ssa.TypeI64},
}
c.ssaBuilder.DeclareSignature(&c.checkModuleExitCodeSig)
return c
}

View File

@@ -18,6 +18,7 @@ func TestCompiler_LowerToSSA(t *testing.T) {
// `~/wasmtime/target/debug/clif-util wasm --target aarch64-apple-darwin testcase.wat -p -t`
for _, tc := range []struct {
name string
ensureTermination bool
// m is the *wasm.Module to be compiled in this test.
m *wasm.Module
// targetIndex is the index of a local function to be compiled in this test.
@@ -219,6 +220,36 @@ blk0: (exec_ctx:i64, module_ctx:i64)
blk1: () <-- (blk0,blk1)
Jump blk1
`,
},
{
name: "loop - br / ensure termination", m: testcases.LoopBr.Module,
ensureTermination: true,
exp: `
signatures:
sig2: i64_v
blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk1
blk1: () <-- (blk0,blk1)
v2:i64 = Load exec_ctx, 0x58
CallIndirect v2:sig2, exec_ctx
Jump blk1
blk2: ()
`,
expAfterOpt: `
signatures:
sig2: i64_v
blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk1
blk1: () <-- (blk0,blk1)
v2:i64 = Load exec_ctx, 0x58
CallIndirect v2:sig2, exec_ctx
Jump blk1
`,
},
{
@@ -1736,7 +1767,7 @@ blk4: () <-- (blk2,blk3)
b := ssa.NewBuilder()
offset := wazevoapi.NewModuleContextOffsetData(tc.m)
fc := NewFrontendCompiler(tc.m, b, &offset)
fc := NewFrontendCompiler(tc.m, b, &offset, tc.ensureTermination)
typeIndex := tc.m.FunctionSection[tc.targetIndex]
code := &tc.m.CodeSection[tc.targetIndex]
fc.Init(tc.targetIndex, &tc.m.TypeSection[typeIndex], code.LocalTypes, code.Body)

View File

@@ -1012,6 +1012,19 @@ func (c *Compiler) lowerCurrentOpcode() {
c.switchTo(originalLen, loopHeader)
if c.ensureTermination {
checkModuleExitCodePtr := builder.AllocateInstruction().
AsLoad(c.execCtxPtrValue,
wazevoapi.ExecutionContextOffsets.CheckModuleExitCodeTrampolineAddress.U32(),
ssa.TypeI64,
).Insert(builder).Return()
c.checkModuleExitCodeArg[0] = c.execCtxPtrValue
builder.AllocateInstruction().
AsCallIndirect(checkModuleExitCodePtr, &c.checkModuleExitCodeSig, c.checkModuleExitCodeArg[:]).
Insert(builder)
}
case wasm.OpcodeIf:
bt := c.readBlockType()

View File

@@ -142,7 +142,8 @@ func (m *moduleEngine) NewFunction(index wasm.Index) api.Function {
}
ce.execCtx.memoryGrowTrampolineAddress = &m.parent.sharedFunctions.memoryGrowExecutable[0]
ce.execCtx.stackGrowCallSequenceAddress = &m.parent.sharedFunctions.stackGrowExecutable[0]
ce.execCtx.stackGrowCallTrampolineAddress = &m.parent.sharedFunctions.stackGrowExecutable[0]
ce.execCtx.checkModuleExitCodeTrampolineAddress = &m.parent.sharedFunctions.checkModuleExitCode[0]
ce.init()
return ce
}

View File

@@ -58,7 +58,8 @@ func Test_ExecutionContextOffsets(t *testing.T) {
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackPointerBeforeGoCall)), offsets.StackPointerBeforeGoCall)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowRequiredSize)), offsets.StackGrowRequiredSize)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.memoryGrowTrampolineAddress)), offsets.MemoryGrowTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowCallSequenceAddress)), offsets.StackGrowCallSequenceAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.stackGrowCallTrampolineAddress)), offsets.StackGrowCallTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.checkModuleExitCodeTrampolineAddress)), offsets.CheckModuleExitCodeTrampolineAddress)
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.savedRegisters))%16, wazevoapi.Offset(0),
"SavedRegistersBegin must be aligned to 16 bytes")
require.Equal(t, wazevoapi.Offset(unsafe.Offsetof(execCtx.savedRegisters)), offsets.SavedRegistersBegin)

View File

@@ -19,6 +19,7 @@ const (
ExitCodeIntegerDivisionByZero
ExitCodeIntegerOverflow
ExitCodeInvalidConversionToInteger
ExitCodeCheckModuleExitCode
exitCodeMax
)
@@ -51,6 +52,8 @@ func (e ExitCode) String() string {
return "integer_overflow"
case ExitCodeInvalidConversionToInteger:
return "invalid_conversion_to_integer"
case ExitCodeCheckModuleExitCode:
return "check_module_exit_code"
}
panic("TODO")
}

View File

@@ -24,7 +24,8 @@ var ExecutionContextOffsets = ExecutionContextOffsetData{
StackPointerBeforeGoCall: 56,
StackGrowRequiredSize: 64,
MemoryGrowTrampolineAddress: 72,
StackGrowCallSequenceAddress: 80,
StackGrowCallTrampolineAddress: 80,
CheckModuleExitCodeTrampolineAddress: 88,
SavedRegistersBegin: 96,
GoFunctionCallCalleeModuleContextOpaque: 1120,
GoFunctionCallStackBegin: 1128,
@@ -53,8 +54,10 @@ type ExecutionContextOffsetData struct {
StackGrowRequiredSize Offset
// MemoryGrowTrampolineAddress is an offset of `memoryGrowTrampolineAddress` field in wazevo.executionContext
MemoryGrowTrampolineAddress Offset
// StackGrowCallSequenceAddress is an offset of `stackGrowCallSequenceAddress` field in wazevo.executionContext
StackGrowCallSequenceAddress Offset
// stackGrowCallTrampolineAddress is an offset of `stackGrowCallTrampolineAddress` field in wazevo.executionContext.
StackGrowCallTrampolineAddress Offset
// CheckModuleExitCodeTrampolineAddress is an offset of `checkModuleExitCodeTrampolineAddress` field in wazevo.executionContext.
CheckModuleExitCodeTrampolineAddress Offset
// GoCallReturnAddress is an offset of the first element of `savedRegisters` field in wazevo.executionContext
SavedRegistersBegin Offset
// GoFunctionCallCalleeModuleContextOpaque is an offset of `goFunctionCallCalleeModuleContextOpaque` field in wazevo.executionContext

View File

@@ -51,7 +51,7 @@ var tests = map[string]testCase{
"overflow integer addition": {f: testOverflow},
"un-signed extend global": {f: testGlobalExtend},
"user-defined primitive in host func": {f: testUserDefinedPrimitiveHostFunc},
"ensures invocations terminate on module close": {f: testEnsureTerminationOnClose, wazevoSkip: true},
"ensures invocations terminate on module close": {f: testEnsureTerminationOnClose},
"call host function indirectly": {f: callHostFunctionIndirect, wazevoSkip: true},
"lookup function": {f: testLookupFunction},
"memory grow in recursive call": {f: testMemoryGrowInRecursiveCall},