wazevo: mitigate closed module presence in stack unwinding (#2128)

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
This commit is contained in:
Takeshi Yoneda
2024-03-07 11:51:31 +09:00
committed by GitHub
parent 4d110241c9
commit dfb2959d34
6 changed files with 81 additions and 19 deletions

View File

@@ -1532,10 +1532,13 @@ func (m *machine) lowerExitWithCode(execCtx regalloc.VReg, code wazevoapi.ExitCo
m.lowerIconst(exitCodeReg, uint64(code), false)
m.insert(setExitCode)
// Next is to save the return address.
readRip := m.allocateInstr()
m.insert(readRip)
ripReg := rbpVReg
// Next is to save the current address for stack unwinding.
nop, currentAddrLabel := m.allocateBrTarget()
m.insert(nop)
readRip := m.allocateInstr().asLEA(newOperandLabel(currentAddrLabel), ripReg)
m.insert(readRip)
saveRip := m.allocateInstr().asMovRM(
ripReg,
newOperandMem(m.newAmodeImmReg(wazevoapi.ExecutionContextOffsetGoCallReturnAddress.U32(), execCtx)),
@@ -1547,11 +1550,10 @@ func (m *machine) lowerExitWithCode(execCtx regalloc.VReg, code wazevoapi.ExitCo
exitSq := m.allocateExitSeq(execCtx)
m.insert(exitSq)
// Insert the label for the return address.
nop, l := m.allocateBrTarget()
readRip.asLEA(newOperandLabel(l), ripReg)
m.insert(nop)
return l
// Return the label for continuation.
continuation, afterLabel := m.allocateBrTarget()
m.insert(continuation)
return afterLabel
}
func (m *machine) lowerAluRmiROp(si *ssa.Instruction, op aluRmiROpcode) {

View File

@@ -279,19 +279,20 @@ func Test_machine_getOperand_Mem_Imm32_Reg(t *testing.T) {
func TestMachine_lowerExitWithCode(t *testing.T) {
_, _, m := newSetupWithMockContext()
m.lowerExitWithCode(r15VReg, wazevoapi.ExitCodeCallGoFunction)
m.lowerExitWithCode(r15VReg, wazevoapi.ExitCodeUnreachable)
m.insert(m.allocateInstr().asUD2())
m.ectx.FlushPendingInstructions()
m.ectx.RootInstr = m.ectx.PerBlockHead
require.Equal(t, `
mov.q %rsp, 56(%r15)
mov.q %rbp, 1152(%r15)
movl $6, %ebp
movl $3, %ebp
mov.l %rbp, (%r15)
L1:
lea L1, %rbp
mov.q %rbp, 48(%r15)
exit_sequence %r15
L1:
L2:
ud2
`, m.Format())
}

View File

@@ -147,6 +147,26 @@ func (c *callEngine) Call(ctx context.Context, params ...uint64) ([]uint64, erro
func (c *callEngine) addFrame(builder wasmdebug.ErrorBuilder, addr uintptr) (def api.FunctionDefinition, listener experimental.FunctionListener) {
eng := c.parent.parent.parent
cm := eng.compiledModuleOfAddr(addr)
if cm == nil {
// This case, the module might have been closed and deleted from the engine.
// We fall back to searching the imported modules that can be referenced from this callEngine.
// First, we check itself.
if checkAddrInBytes(addr, c.parent.parent.executable) {
cm = c.parent.parent
} else {
// Otherwise, search all imported modules. TODO: maybe recursive, but not sure it's useful in practice.
p := c.parent
for i := range p.importedFunctions {
candidate := p.importedFunctions[i].me.parent
if checkAddrInBytes(addr, candidate.executable) {
cm = candidate
break
}
}
}
}
if cm != nil {
index := cm.functionIndexOf(addr)
def = cm.module.FunctionDefinition(cm.module.ImportFunctionCount + index)

View File

@@ -540,7 +540,16 @@ func (e *engine) compiledModuleOfAddr(addr uintptr) *compiledModule {
if index < 0 {
return nil
}
return e.sortedCompiledModules[index]
candidate := e.sortedCompiledModules[index]
if checkAddrInBytes(addr, candidate.executable) {
// If a module is already deleted, the found module may have been wrong.
return candidate
}
return nil
}
func checkAddrInBytes(addr uintptr, b []byte) bool {
return uintptr(unsafe.Pointer(&b[0])) <= addr && addr <= uintptr(unsafe.Pointer(&b[len(b)-1]))
}
// NewModuleEngine implements wasm.Engine.

View File

@@ -110,8 +110,8 @@ func TestEngine_sortedCompiledModules(t *testing.T) {
// TODO: use unsafe.Slice after floor version is set to Go 1.20.
hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))
hdr.Data = addr
hdr.Len = 1
hdr.Cap = 1
hdr.Len = 4
hdr.Cap = 4
}
cm := &compiledModule{executables: &executables{executable: buf}}
return cm
@@ -166,14 +166,19 @@ func TestEngine_sortedCompiledModules(t *testing.T) {
require.Equal(t, unsafe.Pointer(m1), unsafe.Pointer(e.compiledModuleOfAddr(1)))
require.Equal(t, unsafe.Pointer(m1), unsafe.Pointer(e.compiledModuleOfAddr(4)))
require.Equal(t, unsafe.Pointer(m5), unsafe.Pointer(e.compiledModuleOfAddr(5)))
require.Equal(t, unsafe.Pointer(m5), unsafe.Pointer(e.compiledModuleOfAddr(9)))
require.Equal(t, unsafe.Pointer(m5), unsafe.Pointer(e.compiledModuleOfAddr(8)))
require.Equal(t, unsafe.Pointer(m10), unsafe.Pointer(e.compiledModuleOfAddr(10)))
require.Equal(t, unsafe.Pointer(m10), unsafe.Pointer(e.compiledModuleOfAddr(11)))
require.Equal(t, unsafe.Pointer(m10), unsafe.Pointer(e.compiledModuleOfAddr(12)))
require.Equal(t, unsafe.Pointer(m10), unsafe.Pointer(e.compiledModuleOfAddr(50)))
require.Equal(t, unsafe.Pointer(m10), unsafe.Pointer(e.compiledModuleOfAddr(99)))
require.Equal(t, unsafe.Pointer(m100), unsafe.Pointer(e.compiledModuleOfAddr(100)))
require.Equal(t, unsafe.Pointer(m100), unsafe.Pointer(e.compiledModuleOfAddr(10000)))
require.Equal(t, unsafe.Pointer(m100), unsafe.Pointer(e.compiledModuleOfAddr(103)))
e.deleteCompiledModuleFromSortedList(m1)
require.Equal(t, nil, e.compiledModuleOfAddr(1))
require.Equal(t, nil, e.compiledModuleOfAddr(2))
require.Equal(t, nil, e.compiledModuleOfAddr(4))
e.deleteCompiledModuleFromSortedList(m100)
require.Equal(t, nil, e.compiledModuleOfAddr(100))
require.Equal(t, nil, e.compiledModuleOfAddr(103))
})
}
@@ -201,3 +206,14 @@ func TestCompiledModule_functionIndexOf(t *testing.T) {
require.Equal(t, wasm.Index(2), cm.functionIndexOf(executableAddr+1999))
require.Equal(t, wasm.Index(3), cm.functionIndexOf(executableAddr+2000))
}
func Test_checkAddrInBytes(t *testing.T) {
bytes := []byte{0, 1, 2, 3, 4, 5, 6, 7}
begin := uintptr(unsafe.Pointer(&bytes[0]))
end := uintptr(unsafe.Pointer(&bytes[len(bytes)-1]))
require.True(t, checkAddrInBytes(begin, bytes))
require.True(t, checkAddrInBytes(end, bytes))
require.False(t, checkAddrInBytes(begin-1, bytes))
require.False(t, checkAddrInBytes(end+1, bytes))
}

View File

@@ -1151,9 +1151,15 @@ func testModuleMemory(t *testing.T, r wazero.Runtime) {
func testTwoIndirection(t *testing.T, r wazero.Runtime) {
var buf bytes.Buffer
ctx := context.WithValue(testCtx, experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf))
_, err := r.NewHostModuleBuilder("host").NewFunctionBuilder().WithFunc(func(d uint32) uint32 {
_, err := r.NewHostModuleBuilder("host").NewFunctionBuilder().WithFunc(func(
_ context.Context, m api.Module, d uint32,
) uint32 {
if d == math.MaxUint32 {
panic(errors.New("host-function panic"))
} else if d == math.MaxUint32-1 {
err := m.CloseWithExitCode(context.Background(), 1)
require.NoError(t, err)
panic(errors.New("host-function panic"))
}
return 1 / d // panics if d ==0.
}).Export("div").Instantiate(ctx)
@@ -1217,6 +1223,11 @@ wasm stack trace:
host_importer.call_host_div(i32) i32
main.main(i32) i32`},
{name: "go runtime panic", input: 0, expErr: `runtime error: integer divide by zero (recovered by wazero)
wasm stack trace:
host.div(i32) i32
host_importer.call_host_div(i32) i32
main.main(i32) i32`},
{name: "module closed and then go runtime panic", input: math.MaxUint32 - 1, expErr: `host-function panic (recovered by wazero)
wasm stack trace:
host.div(i32) i32
host_importer.call_host_div(i32) i32
@@ -1259,6 +1270,9 @@ wasm stack trace:
--> main.main(0)
--> host_importer.call_host_div(0)
==> host.div(0)
--> main.main(-2)
--> host_importer.call_host_div(-2)
==> host.div(-2)
`, "\n"+buf.String())
}