api: adds CallWithStack to avoid allocations (#1407)

Signed-off-by: Nuno Cruces <ncruces@users.noreply.github.com>
This commit is contained in:
Nuno Cruces
2023-05-01 00:52:40 +01:00
committed by GitHub
parent 0bfb4b52eb
commit 77e8d72d67
11 changed files with 188 additions and 105 deletions

View File

@@ -366,6 +366,34 @@ type Function interface {
// the end-to-end demonstrations of how these terminations can be performed.
Call(ctx context.Context, params ...uint64) ([]uint64, error)
// CallWithStack is an optimized variation of Call that saves memory
// allocations when the stack slice is reused across calls.
//
// Stack length must be at least the max of parameter or result length.
// The caller adds parameters in order to the stack, and reads any results
// in order from the stack, except in the error case.
//
// For example, the following reuses the same stack slice to call searchFn
// repeatedly saving one allocation per iteration:
//
// stack := make([]uint64, 4)
// for i, search := range searchParams {
// // copy the next params to the stack
// copy(stack, search)
// if err := searchFn.CallWithStack(ctx, stack); err != nil {
// return err
// } else if stack[0] == 1 { // found
// return i // searchParams[i] matched!
// }
// }
//
// # Notes
//
// - This is similar to GoModuleFunction, except for using calling functions
// instead of implementing them. Moreover, this is used regardless of
// whether the callee is a host or wasm defined function.
CallWithStack(ctx context.Context, stack []uint64) error
internalapi.WazeroOnly
}

View File

@@ -76,8 +76,11 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) {
panic(err)
}
// This needs copy (not reslice) because the stack is reused for results.
// Consider invoke_i (zero arguments, one result): index zero (tableOffset)
// is needed to store the result.
tableOffset := wasm.Index(stack[0]) // position in the module's only table.
params := stack[1:] // parameters to the dynamic function being called
copy(stack, stack[1:]) // pop the tableOffset.
// Lookup the table index we will call.
t := m.Tables[0] // Note: Emscripten doesn't use multiple tables
@@ -86,10 +89,8 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) {
panic(err)
}
ret, err := m.Engine.NewFunction(idx).Call(ctx, params...)
err = m.Engine.NewFunction(idx).CallWithStack(ctx, stack)
if err != nil {
panic(err)
}
// if there are any results, copy them back to the stack
copy(stack, ret)
}

View File

@@ -705,6 +705,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition {
// Call implements the same method as documented on wasm.ModuleEngine.
func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) {
ft := ce.initialFn.funcType
if n := ft.ParamNumInUint64; n != len(params) {
return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params))
}
return ce.call(ctx, params, nil)
}
// CallWithStack implements the same method as documented on wasm.ModuleEngine.
func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
params, results, err := wasm.SplitCallStack(ce.initialFn.funcType, stack)
if err != nil {
return err
}
_, err = ce.call(ctx, params, results)
return err
}
func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.initialFn.moduleInstance
if ce.ensureTermination {
select {
@@ -717,13 +735,6 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
}
}
tp := ce.initialFn.funcType
paramCount := len(params)
if tp.ParamNumInUint64 != paramCount {
return nil, fmt.Errorf("expected %d params, but passed %d", ce.initialFn.funcType.ParamNumInUint64, paramCount)
}
// We ensure that this Call method never panics as
// this Call method is indirectly invoked by embedders via store.CallFunction,
// and we have to make sure that all the runtime errors, including the one happening inside
@@ -736,7 +747,8 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
}
}()
ce.initializeStack(tp, params)
ft := ce.initialFn.funcType
ce.initializeStack(ft, params)
if ce.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
@@ -747,12 +759,12 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
// This returns a safe copy of the results, instead of a slice view. If we
// returned a re-slice, the caller could accidentally or purposefully
// corrupt the stack of subsequent calls
if resultCount := tp.ResultNumInUint64; resultCount > 0 {
results = make([]uint64, resultCount)
copy(results, ce.stack[:resultCount])
// corrupt the stack of subsequent calls.
if results == nil && ft.ResultNumInUint64 > 0 {
results = make([]uint64, ft.ResultNumInUint64)
}
return
copy(results, ce.stack)
return results, nil
}
// initializeStack initializes callEngine.stack before entering native code.

View File

@@ -67,6 +67,16 @@ func TestCompiler_ModuleEngine_Call(t *testing.T) {
`, "\n"+functionLog.String())
}
func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) {
defer functionLog.Reset()
requireSupportedOSArch(t)
enginetest.RunTestModuleEngineCallWithStack(t, et)
require.Equal(t, `
--> .$0(1,2)
<-- (1,2)
`, "\n"+functionLog.String())
}
func TestCompiler_ModuleEngine_Call_HostFn(t *testing.T) {
defer functionLog.Reset()
requireSupportedOSArch(t)

View File

@@ -116,6 +116,10 @@ func (ce *callEngine) pushValue(v uint64) {
ce.stack = append(ce.stack, v)
}
func (ce *callEngine) pushValues(v []uint64) {
ce.stack = append(ce.stack, v...)
}
func (ce *callEngine) popValue() (v uint64) {
// No need to check stack bound
// as we can assume that all the operations
@@ -129,6 +133,12 @@ func (ce *callEngine) popValue() (v uint64) {
return
}
func (ce *callEngine) popValues(v []uint64) {
stackTopIndex := len(ce.stack) - len(v)
copy(v, ce.stack[stackTopIndex:])
ce.stack = ce.stack[:stackTopIndex]
}
// peekValues peeks api.ValueType values from the stack and returns them.
func (ce *callEngine) peekValues(count int) []uint64 {
if count == 0 {
@@ -445,10 +455,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition {
// Call implements the same method as documented on api.Function.
func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) {
return ce.call(ctx, ce.compiled, params)
ft := ce.compiled.funcType
if n := ft.ParamNumInUint64; n != len(params) {
return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params))
}
return ce.call(ctx, params, nil)
}
func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (results []uint64, err error) {
// CallWithStack implements the same method as documented on api.Function.
func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
params, results, err := wasm.SplitCallStack(ce.compiled.funcType, stack)
if err != nil {
return err
}
_, err = ce.call(ctx, params, results)
return err
}
func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.compiled.moduleInstance
if ce.compiled.parent.ensureTermination {
select {
@@ -461,13 +485,6 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (
}
}
ft := tf.funcType
paramSignature := ft.ParamNumInUint64
paramCount := len(params)
if paramSignature != paramCount {
return nil, fmt.Errorf("expected %d params, but passed %d", paramSignature, paramCount)
}
defer func() {
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
if err == nil {
@@ -480,22 +497,24 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (
}
}()
for _, param := range params {
ce.pushValue(param)
}
ce.pushValues(params)
if ce.compiled.parent.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}
ce.callFunction(ctx, m, tf)
ce.callFunction(ctx, m, ce.compiled)
// This returns a safe copy of the results, instead of a slice view. If we
// returned a re-slice, the caller could accidentally or purposefully
// corrupt the stack of subsequent calls.
results = wasm.PopValues(ft.ResultNumInUint64, ce.popValue)
return
ft := ce.compiled.funcType
if results == nil && ft.ResultNumInUint64 > 0 {
results = make([]uint64, ft.ResultNumInUint64)
}
ce.popValues(results)
return results, nil
}
// recoverOnCall takes the recovered value `recoverOnCall`, and wraps it

View File

@@ -105,6 +105,15 @@ func TestInterpreter_ModuleEngine_Call(t *testing.T) {
`, "\n"+functionLog.String())
}
func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) {
defer functionLog.Reset()
enginetest.RunTestModuleEngineCallWithStack(t, et)
require.Equal(t, `
--> .$0(1,2)
<-- (1,2)
`, "\n"+functionLog.String())
}
func TestInterpreter_ModuleEngine_Call_HostFn(t *testing.T) {
defer functionLog.Reset()
enginetest.RunTestModuleEngineCallHostFn(t, et)

View File

@@ -58,6 +58,23 @@ func BenchmarkHostFunctionCall(b *testing.B) {
}
}
})
b.Run(fn+"_with_stack", func(b *testing.B) {
ce := getCallEngine(m, fn)
b.ResetTimer()
stack := make([]uint64, 1)
for i := 0; i < b.N; i++ {
stack[0] = offset
err := ce.CallWithStack(testCtx, stack)
if err != nil {
b.Fatal(err)
}
if uint32(stack[0]) != math.Float32bits(val) {
b.Fail()
}
}
})
}
}

View File

@@ -191,6 +191,58 @@ func RunTestModuleEngineCall(t *testing.T, et EngineTester) {
})
}
func RunTestModuleEngineCallWithStack(t *testing.T, et EngineTester) {
e := et.NewEngine(api.CoreFeaturesV2)
// Define a basic function which defines two parameters and two results.
// This is used to test results when incorrect arity is used.
m := &wasm.Module{
TypeSection: []wasm.FunctionType{
{
Params: []wasm.ValueType{i64, i64},
Results: []wasm.ValueType{i64, i64},
ParamNumInUint64: 2,
ResultNumInUint64: 2,
},
},
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{
{Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeLocalGet, 1, wasm.OpcodeEnd}},
},
}
m.BuildFunctionDefinitions()
listeners := buildListeners(et.ListenerFactory(), m)
err := e.CompileModule(testCtx, m, listeners, false)
require.NoError(t, err)
// To use the function, we first need to add it to a module.
module := &wasm.ModuleInstance{
ModuleName: t.Name(), TypeIDs: []wasm.FunctionTypeID{0},
Definitions: m.FunctionDefinitionSection,
}
// Compile the module
me, err := e.NewModuleEngine(m, module)
require.NoError(t, err)
linkModuleToEngine(module, me)
// Ensure the base case doesn't fail: A single parameter should work as that matches the function signature.
const funcIndex = 0
ce := me.NewFunction(funcIndex)
stack := []uint64{1, 2}
err = ce.CallWithStack(testCtx, stack)
require.NoError(t, err)
require.Equal(t, []uint64{1, 2}, stack)
t.Run("errs when not enough parameters", func(t *testing.T) {
ce := me.NewFunction(funcIndex)
err = ce.CallWithStack(testCtx, nil)
require.EqualError(t, err, "need 2 params, but stack size is 0")
})
}
func RunTestModuleEngineLookupFunction(t *testing.T, et EngineTester) {
e := et.NewEngine(api.CoreFeaturesV1)

View File

@@ -79,26 +79,6 @@ func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
callGoFunc(ctx, nil, f.fn, stack)
}
// PopValues pops the specified number of api.ValueType parameters off the
// stack into a parameter slice for use in api.GoFunction or api.GoModuleFunction.
//
// For example, if the host function F requires the (x1 uint32, x2 float32)
// parameters, and the stack is [..., A, B], then the function is called as
// F(A, B) where A and B are interpreted as uint32 and float32 respectively.
//
// Note: the popper intentionally doesn't return bool or error because the
// caller's stack depth is trusted.
func PopValues(count int, popper func() uint64) []uint64 {
if count == 0 {
return nil
}
params := make([]uint64, count)
for i := count - 1; i >= 0; i-- {
params[i] = popper()
}
return params
}
// callGoFunc executes the reflective function by converting params to Go
// types. The results of the function call are converted back to api.ValueType.
func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {

View File

@@ -124,55 +124,6 @@ func Test_parseGoFunc_Errors(t *testing.T) {
}
}
// stack simulates the value stack in a way easy to be tested.
type stack struct {
vals []uint64
}
func (s *stack) pop() (result uint64) {
stackTopIndex := len(s.vals) - 1
result = s.vals[stackTopIndex]
s.vals = s.vals[:stackTopIndex]
return
}
func TestPopValues(t *testing.T) {
stackVals := []uint64{1, 2, 3, 4, 5, 6, 7}
tests := []struct {
name string
count int
expected []uint64
}{
{
name: "pop zero doesn't allocate a slice ",
},
{
name: "pop 1",
count: 1,
expected: []uint64{7},
},
{
name: "pop 2",
count: 2,
expected: []uint64{6, 7},
},
{
name: "pop 3",
count: 3,
expected: []uint64{5, 6, 7},
},
}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
vals := PopValues(tc.count, (&stack{stackVals}).pop)
require.Equal(t, tc.expected, vals)
})
}
}
func Test_callGoFunc(t *testing.T) {
tPtr := uintptr(unsafe.Pointer(t))
inst := &ModuleInstance{}

View File

@@ -488,12 +488,16 @@ func (e *mockModuleEngine) Close(context.Context) {
func (ce *mockCallEngine) Definition() api.FunctionDefinition { return nil }
// Call implements the same method as documented on api.Function.
func (ce *mockCallEngine) Call(_ context.Context, _ ...uint64) (results []uint64, err error) {
func (ce *mockCallEngine) Call(ctx context.Context, _ ...uint64) (results []uint64, err error) {
return nil, ce.CallWithStack(ctx, nil)
}
// CallWithStack implements the same method as documented on api.Function.
func (ce *mockCallEngine) CallWithStack(_ context.Context, _ []uint64) error {
if ce.callFailIndex >= 0 && ce.index == Index(ce.callFailIndex) {
err = errors.New("call failed")
return
return errors.New("call failed")
}
return
return nil
}
func TestStore_getFunctionTypeID(t *testing.T) {