Provide new StackIterator to Before Listener hook (#1363)

Signed-off-by: Thomas Pelletier <thomas@pelletier.codes>
This commit is contained in:
Thomas Pelletier
2023-04-17 19:22:58 -04:00
committed by GitHub
parent cd34767954
commit 9aca08c5e6
10 changed files with 428 additions and 18 deletions

View File

@@ -6,6 +6,24 @@ import (
"github.com/tetratelabs/wazero/api"
)
// StackIterator allows iterating on each function of the call stack, starting
// from the top. At least one call to Next() is required to start the iteration.
//
// Note: The iterator provides a view of the call stack at the time of
// iteration. As a result, parameter values may be different than the ones their
// function was called with.
type StackIterator interface {
// Next moves the iterator to the next function in the stack. Returns false
// if it reached the bottom of the stack.
Next() bool
// FunctionDefinition returns the function type of the current function.
FunctionDefinition() api.FunctionDefinition
// Parameters returns api.ValueType-encoded parameters of the current
// function. Do not modify the content of the slice, and copy out any value
// you need.
Parameters() []uint64
}
// FunctionListenerFactoryKey is a context.Context Value key. Its associated value should be a FunctionListenerFactory.
//
// See https://github.com/tetratelabs/wazero/issues/451
@@ -35,9 +53,12 @@ type FunctionListener interface {
// - mod: the calling module.
// - def: the function definition.
// - paramValues: api.ValueType encoded parameters.
// - stackIterator: iterator on the call stack. At least one entry is
// guaranteed (the called function), whose Args() will be equal to
// paramValues. The iterator will be reused between calls to Before.
//
// Note: api.Memory is meant for inspection, not modification.
Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64) context.Context
Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64, stackIterator StackIterator) context.Context
// After is invoked after a function is called.
//

View File

@@ -11,6 +11,7 @@ import (
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/internal/wasm"
)
// listenerWasm was generated by the following:
@@ -43,7 +44,7 @@ func (u uniqGoFuncs) NewListener(def api.FunctionDefinition) experimental.Functi
}
// Before implements FunctionListener.Before
func (u uniqGoFuncs) Before(ctx context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64) context.Context {
func (u uniqGoFuncs) Before(ctx context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64, _ experimental.StackIterator) context.Context {
u[def.DebugName()] = struct{}{}
return ctx
}
@@ -83,3 +84,65 @@ func Example_customListenerFactory() {
// wasi_snapshot_preview1.fd_write
// wasi_snapshot_preview1.random_get
}
func Example_stackIterator() {
it := &fakeStackIterator{}
for it.Next() {
fmt.Println("function:", it.FunctionDefinition().DebugName(), "args", it.Args())
}
// Output:
// function: fn0 args [1 2 3]
// function: fn1 args []
// function: fn2 args [4]
}
type fakeStackIterator struct {
iteration int
def api.FunctionDefinition
args []uint64
}
func (s *fakeStackIterator) Next() bool {
switch s.iteration {
case 0:
s.def = &mockFunctionDefinition{debugName: "fn0"}
s.args = []uint64{1, 2, 3}
case 1:
s.def = &mockFunctionDefinition{debugName: "fn1"}
s.args = []uint64{}
case 2:
s.def = &mockFunctionDefinition{debugName: "fn2"}
s.args = []uint64{4}
case 3:
return false
}
s.iteration++
return true
}
func (s *fakeStackIterator) FunctionDefinition() api.FunctionDefinition {
return s.def
}
func (s *fakeStackIterator) Args() []uint64 {
return s.args
}
type mockFunctionDefinition struct {
debugName string
*wasm.FunctionDefinition
}
func (f *mockFunctionDefinition) DebugName() string {
return f.debugName
}
func (f *mockFunctionDefinition) ParamTypes() []wasm.ValueType {
return []wasm.ValueType{}
}
func (f *mockFunctionDefinition) ResultTypes() []wasm.ValueType {
return []wasm.ValueType{}
}

View File

@@ -21,7 +21,7 @@ type recorder struct {
beforeNames, afterNames []string
}
func (r *recorder) Before(ctx context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64) context.Context {
func (r *recorder) Before(ctx context.Context, _ api.Module, def api.FunctionDefinition, _ []uint64, _ experimental.StackIterator) context.Context {
r.beforeNames = append(r.beforeNames, def.DebugName())
return ctx
}

View File

@@ -180,7 +180,7 @@ type loggingListener struct {
// Before logs to stdout the module and function name, prefixed with '-->' and
// indented based on the call nesting level.
func (l *loggingListener) Before(ctx context.Context, mod api.Module, _ api.FunctionDefinition, params []uint64) context.Context {
func (l *loggingListener) Before(ctx context.Context, mod api.Module, _ api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context {
// First, see if this invocation is sampled.
sampled := true
if s := l.pSampler; s != nil {

View File

@@ -301,7 +301,7 @@ func Test_loggingListener(t *testing.T) {
l := lf.NewListener(def)
out.Reset()
ctx := l.Before(testCtx, nil, def, tc.params)
ctx := l.Before(testCtx, nil, def, tc.params, nil)
l.After(ctx, nil, def, tc.err, tc.results)
require.Equal(t, tc.expected, out.String())
})
@@ -337,8 +337,8 @@ func Test_loggingListener_indentation(t *testing.T) {
def2 := &m.FunctionDefinitionSection[1]
l2 := lf.NewListener(def2)
ctx := l1.Before(testCtx, nil, def1, []uint64{})
ctx1 := l2.Before(ctx, nil, def2, []uint64{})
ctx := l1.Before(testCtx, nil, def1, []uint64{}, nil)
ctx1 := l2.Before(ctx, nil, def2, []uint64{}, nil)
l2.After(ctx1, nil, def2, nil, []uint64{})
l1.After(ctx, nil, def1, nil, []uint64{})
require.Equal(t, `--> test.fn1()

View File

@@ -133,6 +133,10 @@ type (
// contextStack is a stack of contexts which is pushed and popped by function listeners.
// This is used and modified when there are function listeners.
contextStack *contextStack
// stackIterator provides a way to iterate over the stack for Listeners.
// It is setup and valid only during a call to a Listener hook.
stackIterator stackIterator
}
// contextStack is a stack of context.Context.
@@ -1034,12 +1038,67 @@ func (ce *callEngine) builtinFunctionTableGrow(tables []*wasm.TableInstance) {
ce.pushValue(uint64(res))
}
// stackIterator implements experimental.StackIterator.
type stackIterator struct {
stack []uint64
fn *function
base int
started bool
}
func (si *stackIterator) reset(stack []uint64, fn *function, base int) {
si.stack = stack
si.fn = fn
si.base = base
si.started = false
}
func (si *stackIterator) clear() {
si.stack = nil
si.fn = nil
si.base = 0
si.started = false
}
// Next implements experimental.StackIterator.
func (si *stackIterator) Next() bool {
if !si.started {
si.started = true
return true
}
if si.fn == nil || si.base == 0 {
return false
}
frame := si.base + callFrameOffset(si.fn.funcType)
si.base = int(si.stack[frame+1] >> 3)
// *function lives in the third field of callFrame struct. This must be
// aligned with the definition of callFrame struct.
si.fn = (*function)(unsafe.Pointer(uintptr(si.stack[frame+2])))
return si.fn != nil
}
// FunctionDefinition implements experimental.StackIterator.
func (si *stackIterator) FunctionDefinition() api.FunctionDefinition {
return si.fn.def
}
// Args implements experimental.StackIterator.
func (si *stackIterator) Parameters() []uint64 {
return si.stack[si.base : si.base+si.fn.funcType.ParamNumInUint64]
}
func (ce *callEngine) builtinFunctionFunctionListenerBefore(ctx context.Context, mod api.Module, fn *function) {
base := int(ce.stackBasePointerInBytes >> 3)
listerCtx := fn.parent.listener.Before(ctx, mod, fn.def, ce.stack[base:base+fn.funcType.ParamNumInUint64])
ce.stackIterator.reset(ce.stack, fn, base)
listerCtx := fn.parent.listener.Before(ctx, mod, fn.def, ce.stack[base:base+fn.funcType.ParamNumInUint64], &ce.stackIterator)
prevStackTop := ce.contextStack
ce.contextStack = &contextStack{self: ctx, prev: prevStackTop}
ce.ctx = listerCtx
ce.stackIterator.clear()
}
func (ce *callEngine) builtinFunctionFunctionListenerAfter(ctx context.Context, mod api.Module, fn *function) {

View File

@@ -134,6 +134,10 @@ func TestCompiler_ModuleEngine_Memory(t *testing.T) {
enginetest.RunTestModuleEngine_Memory(t, et)
}
func TestCompiler_BeforeListenerStackIterator(t *testing.T) {
enginetest.RunTestModuleEngine_BeforeListenerStackIterator(t, et)
}
// requireSupportedOSArch is duplicated also in the platform package to ensure no cyclic dependency.
func requireSupportedOSArch(t *testing.T) {
if !platform.CompilerSupported() {
@@ -575,24 +579,39 @@ func Test_callFrameOffset(t *testing.T) {
require.Equal(t, 100, callFrameOffset(&wasm.FunctionType{ParamNumInUint64: 100, ResultNumInUint64: 50}))
}
type stackEntry struct {
def api.FunctionDefinition
args []uint64
}
func assertStackIterator(t *testing.T, it experimental.StackIterator, expected []stackEntry) {
var actual []stackEntry
for it.Next() {
actual = append(actual, stackEntry{def: it.FunctionDefinition(), args: it.Parameters()})
}
require.Equal(t, expected, actual)
}
func TestCallEngine_builtinFunctionFunctionListenerBefore(t *testing.T) {
nextContext, currentContext, prevContext := context.Background(), context.Background(), context.Background()
def := newMockFunctionDefinition("1")
f := &function{
def: newMockFunctionDefinition("1"),
def: def,
funcType: &wasm.FunctionType{ParamNumInUint64: 3},
parent: &code{
listener: mockListener{
before: func(ctx context.Context, _ api.Module, def api.FunctionDefinition, paramValues []uint64) context.Context {
before: func(ctx context.Context, _ api.Module, def api.FunctionDefinition, paramValues []uint64, stackIterator experimental.StackIterator) context.Context {
require.Equal(t, currentContext, ctx)
require.Equal(t, []uint64{2, 3, 4}, paramValues)
assertStackIterator(t, stackIterator, []stackEntry{{def: def, args: []uint64{2, 3, 4}}})
return nextContext
},
},
},
}
ce := &callEngine{
ctx: currentContext, stack: []uint64{0, 1, 2, 3, 4, 5},
ctx: currentContext, stack: []uint64{0, 1, 2, 3, 4, 0, 0, 0},
stackContext: stackContext{stackBasePointerInBytes: 16},
contextStack: &contextStack{self: prevContext},
}
@@ -631,12 +650,12 @@ func TestCallEngine_builtinFunctionFunctionListenerAfter(t *testing.T) {
}
type mockListener struct {
before func(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64) context.Context
before func(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64, stackIterator experimental.StackIterator) context.Context
after func(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, resultValues []uint64)
}
func (m mockListener) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64) context.Context {
return m.before(ctx, mod, def, paramValues)
func (m mockListener) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64, stackIterator experimental.StackIterator) context.Context {
return m.before(ctx, mod, def, paramValues, stackIterator)
}
func (m mockListener) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, resultValues []uint64) {

View File

@@ -100,6 +100,9 @@ type callEngine struct {
// compiled is the initial function for this call engine.
compiled *function
// stackiterator Listeners to walk frames and stack.
stackIterator stackIterator
}
func (e *moduleEngine) newCallEngine(compiled *function) *callEngine {
@@ -166,6 +169,9 @@ type callFrame struct {
pc uint64
// f is the compiled function used in this function frame.
f *function
// base index in the frame of this function, used to detect the count of
// values on the stack.
base int
}
type code struct {
@@ -198,6 +204,58 @@ func functionFromUintptr(ptr uintptr) *function {
return *(**function)(unsafe.Pointer(wrapped))
}
// stackIterator implements experimental.StackIterator.
type stackIterator struct {
stack []uint64
frames []*callFrame
started bool
fn *function
}
func (si *stackIterator) reset(stack []uint64, frames []*callFrame, f *function) {
si.fn = f
si.stack = stack
si.frames = frames
si.started = false
}
func (si *stackIterator) clear() {
si.stack = nil
si.frames = nil
si.started = false
si.fn = nil
}
// Next implements experimental.StackIterator.
func (si *stackIterator) Next() bool {
if !si.started {
si.started = true
return true
}
if len(si.frames) == 0 {
return false
}
frame := si.frames[len(si.frames)-1]
si.stack = si.stack[:frame.base]
si.fn = frame.f
si.frames = si.frames[:len(si.frames)-1]
return true
}
// FunctionDefinition implements experimental.StackIterator.
func (si *stackIterator) FunctionDefinition() api.FunctionDefinition {
return si.fn.def
}
// Args implements experimental.StackIterator.
func (si *stackIterator) Parameters() []uint64 {
paramsCount := si.fn.funcType.ParamNumInUint64
top := len(si.stack)
return si.stack[top-paramsCount:]
}
// interpreter mode doesn't maintain call frames in the stack, so pass the zero size to the IR.
const callFrameStackSize = 0
@@ -475,9 +533,11 @@ func (ce *callEngine) callGoFunc(ctx context.Context, m *wasm.ModuleInstance, f
lsn := f.parent.listener
if lsn != nil {
params := stack[:typ.ParamNumInUint64]
ctx = lsn.Before(ctx, m, def, params)
ce.stackIterator.reset(ce.stack, ce.frames, f)
ctx = lsn.Before(ctx, m, def, params, &ce.stackIterator)
ce.stackIterator.clear()
}
frame := &callFrame{f: f}
frame := &callFrame{f: f, base: len(ce.stack)}
ce.pushFrame(frame)
fn := f.parent.hostFn
@@ -497,7 +557,7 @@ func (ce *callEngine) callGoFunc(ctx context.Context, m *wasm.ModuleInstance, f
}
func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance, f *function) {
frame := &callFrame{f: f}
frame := &callFrame{f: f, base: len(ce.stack)}
moduleInst := f.moduleInstance
functions := moduleInst.Engine.(*moduleEngine).functions
memoryInst := moduleInst.MemoryInstance
@@ -3966,7 +4026,10 @@ func i32Abs(v uint32) uint32 {
func (ce *callEngine) callNativeFuncWithListener(ctx context.Context, m *wasm.ModuleInstance, f *function, fnl experimental.FunctionListener) context.Context {
def, typ := &f.moduleInstance.Definitions[f.index], f.funcType
ctx = fnl.Before(ctx, m, def, ce.peekValues(len(typ.Params)))
ce.stackIterator.reset(ce.stack, ce.frames, f)
ctx = fnl.Before(ctx, m, def, ce.peekValues(len(typ.Params)), &ce.stackIterator)
ce.stackIterator.clear()
ce.callNativeFunc(ctx, m, f)
fnl.After(ctx, m, def, nil, ce.peekValues(len(typ.Results)))
return ctx

View File

@@ -561,3 +561,7 @@ func TestEngine_CachedcodesPerModule(t *testing.T) {
_, ok = e.getCodes(m)
require.False(t, ok)
}
func TestCompiler_BeforeListenerStackIterator(t *testing.T) {
enginetest.RunTestModuleEngine_BeforeListenerStackIterator(t, et)
}

View File

@@ -458,6 +458,187 @@ wasm stack trace:
}
}
// This tests that the StackIterator provided by the Engine to the Before hook
// of the listener is properly able to walk the stack. As an example, it
// validates that the following call stack is properly walked:
//
// 1. f1(2,3,4) [no return, no local]
// 2. calls f2(no arg) [1 return, 1 local]
// 3. calls f3(5) [1 return, no local]
// 4. calls f4(6) [1 return, HOST]
func RunTestModuleEngine_BeforeListenerStackIterator(t *testing.T, et EngineTester) {
e := et.NewEngine(api.CoreFeaturesV2)
type stackEntry struct {
debugName string
args []uint64
}
expectedCallstacks := [][]stackEntry{
{ // when calling f1
{debugName: "whatever.f1", args: []uint64{2, 3, 4}},
},
{ // when calling f2
{debugName: "whatever.f2", args: []uint64{}},
{debugName: "whatever.f1", args: []uint64{2, 3, 4}},
},
{ // when calling f3
{debugName: "whatever.f3", args: []uint64{5}},
{debugName: "whatever.f2", args: []uint64{}},
{debugName: "whatever.f1", args: []uint64{2, 3, 4}},
},
{ // when calling f4
{debugName: "whatever.f4", args: []uint64{6}},
{debugName: "whatever.f3", args: []uint64{5}},
{debugName: "whatever.f2", args: []uint64{}},
{debugName: "whatever.f1", args: []uint64{2, 3, 4}},
},
}
fnListener := &fnListener{
beforeFn: func(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64, si experimental.StackIterator) context.Context {
require.True(t, len(expectedCallstacks) > 0)
expectedCallstack := expectedCallstacks[0]
for si.Next() {
require.True(t, len(expectedCallstack) > 0)
require.Equal(t, expectedCallstack[0].debugName, si.FunctionDefinition().DebugName())
require.Equal(t, expectedCallstack[0].args, si.Parameters())
expectedCallstack = expectedCallstack[1:]
}
require.Equal(t, 0, len(expectedCallstack))
expectedCallstacks = expectedCallstacks[1:]
return ctx
},
}
functionTypes := []wasm.FunctionType{
// f1 type
{
Params: []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32},
ParamNumInUint64: 3,
Results: []api.ValueType{},
ResultNumInUint64: 0,
},
// f2 type
{
Params: []api.ValueType{},
ParamNumInUint64: 0,
Results: []api.ValueType{api.ValueTypeI32},
ResultNumInUint64: 1,
},
// f3 type
{
Params: []api.ValueType{api.ValueTypeI32},
ParamNumInUint64: 1,
Results: []api.ValueType{api.ValueTypeI32},
ResultNumInUint64: 1,
},
// f4 type
{
Params: []api.ValueType{api.ValueTypeI32},
ParamNumInUint64: 1,
Results: []api.ValueType{api.ValueTypeI32},
ResultNumInUint64: 1,
},
}
hostgofn := wasm.MustParseGoReflectFuncCode(func(x int32) int32 {
return x + 100
})
m := &wasm.Module{
TypeSection: functionTypes,
FunctionSection: []wasm.Index{0, 1, 2, 3},
NameSection: &wasm.NameSection{
ModuleName: "whatever",
FunctionNames: wasm.NameMap{
{Index: wasm.Index(0), Name: "f1"},
{Index: wasm.Index(1), Name: "f2"},
{Index: wasm.Index(2), Name: "f3"},
{Index: wasm.Index(3), Name: "f4"},
},
},
CodeSection: []wasm.Code{
{ // f1
Body: []byte{
wasm.OpcodeI32Const, 0, // reserve return for f2
wasm.OpcodeCall,
1, // call f2
wasm.OpcodeEnd,
},
},
{ // f2
LocalTypes: []wasm.ValueType{wasm.ValueTypeI32},
Body: []byte{
wasm.OpcodeI32Const, 42, // local for f2
wasm.OpcodeLocalSet, 0,
wasm.OpcodeI32Const, 5, // argument of f3
wasm.OpcodeCall,
2, // call f3
wasm.OpcodeEnd,
},
},
{ // f3
Body: []byte{
wasm.OpcodeI32Const, 6,
wasm.OpcodeCall,
3, // call host function
wasm.OpcodeEnd,
},
},
// f4 [host function]
hostgofn,
},
ExportSection: []wasm.Export{
{Name: "f1", Type: wasm.ExternTypeFunc, Index: 0},
},
ID: wasm.ModuleID{0},
}
m.BuildFunctionDefinitions()
listeners := buildListeners(fnListener, m)
err := e.CompileModule(testCtx, m, listeners, false)
require.NoError(t, err)
module := &wasm.ModuleInstance{
ModuleName: t.Name(),
TypeIDs: []wasm.FunctionTypeID{0, 1, 2, 3},
Definitions: m.FunctionDefinitionSection,
Exports: exportMap(m),
}
me, err := e.NewModuleEngine(m, module)
require.NoError(t, err)
linkModuleToEngine(module, me)
initCallEngine := me.NewFunction(0) // f1
_, err = initCallEngine.Call(testCtx, 2, 3, 4)
require.NoError(t, err)
require.Equal(t, 0, len(expectedCallstacks))
}
type fnListener struct {
beforeFn func(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64, stackIterator experimental.StackIterator) context.Context
afterFn func(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, resultValues []uint64)
}
func (f *fnListener) NewListener(fnd api.FunctionDefinition) experimental.FunctionListener {
return f
}
func (f fnListener) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, paramValues []uint64, stackIterator experimental.StackIterator) context.Context {
if f.beforeFn != nil {
return f.beforeFn(ctx, mod, def, paramValues, stackIterator)
}
return ctx
}
func (f fnListener) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, resultValues []uint64) {
if f.afterFn != nil {
f.afterFn(ctx, mod, def, err, resultValues)
}
}
// RunTestModuleEngine_Memory shows that the byte slice returned from api.Memory Read is not a copy, rather a re-slice
// of the underlying memory. This allows both host and Wasm to see each other's writes, unless one side changes the
// capacity of the slice.