api: adds CallWithStack to avoid allocations (#1407)
Signed-off-by: Nuno Cruces <ncruces@users.noreply.github.com>
This commit is contained in:
28
api/wasm.go
28
api/wasm.go
@@ -366,6 +366,34 @@ type Function interface {
|
||||
// the end-to-end demonstrations of how these terminations can be performed.
|
||||
Call(ctx context.Context, params ...uint64) ([]uint64, error)
|
||||
|
||||
// CallWithStack is an optimized variation of Call that saves memory
|
||||
// allocations when the stack slice is reused across calls.
|
||||
//
|
||||
// Stack length must be at least the max of parameter or result length.
|
||||
// The caller adds parameters in order to the stack, and reads any results
|
||||
// in order from the stack, except in the error case.
|
||||
//
|
||||
// For example, the following reuses the same stack slice to call searchFn
|
||||
// repeatedly saving one allocation per iteration:
|
||||
//
|
||||
// stack := make([]uint64, 4)
|
||||
// for i, search := range searchParams {
|
||||
// // copy the next params to the stack
|
||||
// copy(stack, search)
|
||||
// if err := searchFn.CallWithStack(ctx, stack); err != nil {
|
||||
// return err
|
||||
// } else if stack[0] == 1 { // found
|
||||
// return i // searchParams[i] matched!
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// # Notes
|
||||
//
|
||||
// - This is similar to GoModuleFunction, except for using calling functions
|
||||
// instead of implementing them. Moreover, this is used regardless of
|
||||
// whether the callee is a host or wasm defined function.
|
||||
CallWithStack(ctx context.Context, stack []uint64) error
|
||||
|
||||
internalapi.WazeroOnly
|
||||
}
|
||||
|
||||
|
||||
@@ -76,8 +76,11 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// This needs copy (not reslice) because the stack is reused for results.
|
||||
// Consider invoke_i (zero arguments, one result): index zero (tableOffset)
|
||||
// is needed to store the result.
|
||||
tableOffset := wasm.Index(stack[0]) // position in the module's only table.
|
||||
params := stack[1:] // parameters to the dynamic function being called
|
||||
copy(stack, stack[1:]) // pop the tableOffset.
|
||||
|
||||
// Lookup the table index we will call.
|
||||
t := m.Tables[0] // Note: Emscripten doesn't use multiple tables
|
||||
@@ -86,10 +89,8 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret, err := m.Engine.NewFunction(idx).Call(ctx, params...)
|
||||
err = m.Engine.NewFunction(idx).CallWithStack(ctx, stack)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// if there are any results, copy them back to the stack
|
||||
copy(stack, ret)
|
||||
}
|
||||
|
||||
@@ -705,6 +705,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition {
|
||||
|
||||
// Call implements the same method as documented on wasm.ModuleEngine.
|
||||
func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) {
|
||||
ft := ce.initialFn.funcType
|
||||
if n := ft.ParamNumInUint64; n != len(params) {
|
||||
return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params))
|
||||
}
|
||||
return ce.call(ctx, params, nil)
|
||||
}
|
||||
|
||||
// CallWithStack implements the same method as documented on wasm.ModuleEngine.
|
||||
func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
|
||||
params, results, err := wasm.SplitCallStack(ce.initialFn.funcType, stack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = ce.call(ctx, params, results)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
|
||||
m := ce.initialFn.moduleInstance
|
||||
if ce.ensureTermination {
|
||||
select {
|
||||
@@ -717,13 +735,6 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
|
||||
}
|
||||
}
|
||||
|
||||
tp := ce.initialFn.funcType
|
||||
|
||||
paramCount := len(params)
|
||||
if tp.ParamNumInUint64 != paramCount {
|
||||
return nil, fmt.Errorf("expected %d params, but passed %d", ce.initialFn.funcType.ParamNumInUint64, paramCount)
|
||||
}
|
||||
|
||||
// We ensure that this Call method never panics as
|
||||
// this Call method is indirectly invoked by embedders via store.CallFunction,
|
||||
// and we have to make sure that all the runtime errors, including the one happening inside
|
||||
@@ -736,7 +747,8 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
|
||||
}
|
||||
}()
|
||||
|
||||
ce.initializeStack(tp, params)
|
||||
ft := ce.initialFn.funcType
|
||||
ce.initializeStack(ft, params)
|
||||
|
||||
if ce.ensureTermination {
|
||||
done := m.CloseModuleOnCanceledOrTimeout(ctx)
|
||||
@@ -747,12 +759,12 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
|
||||
|
||||
// This returns a safe copy of the results, instead of a slice view. If we
|
||||
// returned a re-slice, the caller could accidentally or purposefully
|
||||
// corrupt the stack of subsequent calls
|
||||
if resultCount := tp.ResultNumInUint64; resultCount > 0 {
|
||||
results = make([]uint64, resultCount)
|
||||
copy(results, ce.stack[:resultCount])
|
||||
// corrupt the stack of subsequent calls.
|
||||
if results == nil && ft.ResultNumInUint64 > 0 {
|
||||
results = make([]uint64, ft.ResultNumInUint64)
|
||||
}
|
||||
return
|
||||
copy(results, ce.stack)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// initializeStack initializes callEngine.stack before entering native code.
|
||||
|
||||
@@ -67,6 +67,16 @@ func TestCompiler_ModuleEngine_Call(t *testing.T) {
|
||||
`, "\n"+functionLog.String())
|
||||
}
|
||||
|
||||
func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) {
|
||||
defer functionLog.Reset()
|
||||
requireSupportedOSArch(t)
|
||||
enginetest.RunTestModuleEngineCallWithStack(t, et)
|
||||
require.Equal(t, `
|
||||
--> .$0(1,2)
|
||||
<-- (1,2)
|
||||
`, "\n"+functionLog.String())
|
||||
}
|
||||
|
||||
func TestCompiler_ModuleEngine_Call_HostFn(t *testing.T) {
|
||||
defer functionLog.Reset()
|
||||
requireSupportedOSArch(t)
|
||||
|
||||
@@ -116,6 +116,10 @@ func (ce *callEngine) pushValue(v uint64) {
|
||||
ce.stack = append(ce.stack, v)
|
||||
}
|
||||
|
||||
func (ce *callEngine) pushValues(v []uint64) {
|
||||
ce.stack = append(ce.stack, v...)
|
||||
}
|
||||
|
||||
func (ce *callEngine) popValue() (v uint64) {
|
||||
// No need to check stack bound
|
||||
// as we can assume that all the operations
|
||||
@@ -129,6 +133,12 @@ func (ce *callEngine) popValue() (v uint64) {
|
||||
return
|
||||
}
|
||||
|
||||
func (ce *callEngine) popValues(v []uint64) {
|
||||
stackTopIndex := len(ce.stack) - len(v)
|
||||
copy(v, ce.stack[stackTopIndex:])
|
||||
ce.stack = ce.stack[:stackTopIndex]
|
||||
}
|
||||
|
||||
// peekValues peeks api.ValueType values from the stack and returns them.
|
||||
func (ce *callEngine) peekValues(count int) []uint64 {
|
||||
if count == 0 {
|
||||
@@ -445,10 +455,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition {
|
||||
|
||||
// Call implements the same method as documented on api.Function.
|
||||
func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) {
|
||||
return ce.call(ctx, ce.compiled, params)
|
||||
ft := ce.compiled.funcType
|
||||
if n := ft.ParamNumInUint64; n != len(params) {
|
||||
return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params))
|
||||
}
|
||||
return ce.call(ctx, params, nil)
|
||||
}
|
||||
|
||||
func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (results []uint64, err error) {
|
||||
// CallWithStack implements the same method as documented on api.Function.
|
||||
func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
|
||||
params, results, err := wasm.SplitCallStack(ce.compiled.funcType, stack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = ce.call(ctx, params, results)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
|
||||
m := ce.compiled.moduleInstance
|
||||
if ce.compiled.parent.ensureTermination {
|
||||
select {
|
||||
@@ -461,13 +485,6 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (
|
||||
}
|
||||
}
|
||||
|
||||
ft := tf.funcType
|
||||
paramSignature := ft.ParamNumInUint64
|
||||
paramCount := len(params)
|
||||
if paramSignature != paramCount {
|
||||
return nil, fmt.Errorf("expected %d params, but passed %d", paramSignature, paramCount)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
|
||||
if err == nil {
|
||||
@@ -480,22 +497,24 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (
|
||||
}
|
||||
}()
|
||||
|
||||
for _, param := range params {
|
||||
ce.pushValue(param)
|
||||
}
|
||||
ce.pushValues(params)
|
||||
|
||||
if ce.compiled.parent.ensureTermination {
|
||||
done := m.CloseModuleOnCanceledOrTimeout(ctx)
|
||||
defer done()
|
||||
}
|
||||
|
||||
ce.callFunction(ctx, m, tf)
|
||||
ce.callFunction(ctx, m, ce.compiled)
|
||||
|
||||
// This returns a safe copy of the results, instead of a slice view. If we
|
||||
// returned a re-slice, the caller could accidentally or purposefully
|
||||
// corrupt the stack of subsequent calls.
|
||||
results = wasm.PopValues(ft.ResultNumInUint64, ce.popValue)
|
||||
return
|
||||
ft := ce.compiled.funcType
|
||||
if results == nil && ft.ResultNumInUint64 > 0 {
|
||||
results = make([]uint64, ft.ResultNumInUint64)
|
||||
}
|
||||
ce.popValues(results)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// recoverOnCall takes the recovered value `recoverOnCall`, and wraps it
|
||||
|
||||
@@ -105,6 +105,15 @@ func TestInterpreter_ModuleEngine_Call(t *testing.T) {
|
||||
`, "\n"+functionLog.String())
|
||||
}
|
||||
|
||||
func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) {
|
||||
defer functionLog.Reset()
|
||||
enginetest.RunTestModuleEngineCallWithStack(t, et)
|
||||
require.Equal(t, `
|
||||
--> .$0(1,2)
|
||||
<-- (1,2)
|
||||
`, "\n"+functionLog.String())
|
||||
}
|
||||
|
||||
func TestInterpreter_ModuleEngine_Call_HostFn(t *testing.T) {
|
||||
defer functionLog.Reset()
|
||||
enginetest.RunTestModuleEngineCallHostFn(t, et)
|
||||
|
||||
@@ -58,6 +58,23 @@ func BenchmarkHostFunctionCall(b *testing.B) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fn+"_with_stack", func(b *testing.B) {
|
||||
ce := getCallEngine(m, fn)
|
||||
|
||||
b.ResetTimer()
|
||||
stack := make([]uint64, 1)
|
||||
for i := 0; i < b.N; i++ {
|
||||
stack[0] = offset
|
||||
err := ce.CallWithStack(testCtx, stack)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if uint32(stack[0]) != math.Float32bits(val) {
|
||||
b.Fail()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -191,6 +191,58 @@ func RunTestModuleEngineCall(t *testing.T, et EngineTester) {
|
||||
})
|
||||
}
|
||||
|
||||
func RunTestModuleEngineCallWithStack(t *testing.T, et EngineTester) {
|
||||
e := et.NewEngine(api.CoreFeaturesV2)
|
||||
|
||||
// Define a basic function which defines two parameters and two results.
|
||||
// This is used to test results when incorrect arity is used.
|
||||
m := &wasm.Module{
|
||||
TypeSection: []wasm.FunctionType{
|
||||
{
|
||||
Params: []wasm.ValueType{i64, i64},
|
||||
Results: []wasm.ValueType{i64, i64},
|
||||
ParamNumInUint64: 2,
|
||||
ResultNumInUint64: 2,
|
||||
},
|
||||
},
|
||||
FunctionSection: []wasm.Index{0},
|
||||
CodeSection: []wasm.Code{
|
||||
{Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeLocalGet, 1, wasm.OpcodeEnd}},
|
||||
},
|
||||
}
|
||||
|
||||
m.BuildFunctionDefinitions()
|
||||
listeners := buildListeners(et.ListenerFactory(), m)
|
||||
err := e.CompileModule(testCtx, m, listeners, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// To use the function, we first need to add it to a module.
|
||||
module := &wasm.ModuleInstance{
|
||||
ModuleName: t.Name(), TypeIDs: []wasm.FunctionTypeID{0},
|
||||
Definitions: m.FunctionDefinitionSection,
|
||||
}
|
||||
|
||||
// Compile the module
|
||||
me, err := e.NewModuleEngine(m, module)
|
||||
require.NoError(t, err)
|
||||
linkModuleToEngine(module, me)
|
||||
|
||||
// Ensure the base case doesn't fail: A single parameter should work as that matches the function signature.
|
||||
const funcIndex = 0
|
||||
ce := me.NewFunction(funcIndex)
|
||||
|
||||
stack := []uint64{1, 2}
|
||||
err = ce.CallWithStack(testCtx, stack)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []uint64{1, 2}, stack)
|
||||
|
||||
t.Run("errs when not enough parameters", func(t *testing.T) {
|
||||
ce := me.NewFunction(funcIndex)
|
||||
err = ce.CallWithStack(testCtx, nil)
|
||||
require.EqualError(t, err, "need 2 params, but stack size is 0")
|
||||
})
|
||||
}
|
||||
|
||||
func RunTestModuleEngineLookupFunction(t *testing.T, et EngineTester) {
|
||||
e := et.NewEngine(api.CoreFeaturesV1)
|
||||
|
||||
|
||||
@@ -79,26 +79,6 @@ func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
|
||||
callGoFunc(ctx, nil, f.fn, stack)
|
||||
}
|
||||
|
||||
// PopValues pops the specified number of api.ValueType parameters off the
|
||||
// stack into a parameter slice for use in api.GoFunction or api.GoModuleFunction.
|
||||
//
|
||||
// For example, if the host function F requires the (x1 uint32, x2 float32)
|
||||
// parameters, and the stack is [..., A, B], then the function is called as
|
||||
// F(A, B) where A and B are interpreted as uint32 and float32 respectively.
|
||||
//
|
||||
// Note: the popper intentionally doesn't return bool or error because the
|
||||
// caller's stack depth is trusted.
|
||||
func PopValues(count int, popper func() uint64) []uint64 {
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
params := make([]uint64, count)
|
||||
for i := count - 1; i >= 0; i-- {
|
||||
params[i] = popper()
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// callGoFunc executes the reflective function by converting params to Go
|
||||
// types. The results of the function call are converted back to api.ValueType.
|
||||
func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {
|
||||
|
||||
@@ -124,55 +124,6 @@ func Test_parseGoFunc_Errors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// stack simulates the value stack in a way easy to be tested.
|
||||
type stack struct {
|
||||
vals []uint64
|
||||
}
|
||||
|
||||
func (s *stack) pop() (result uint64) {
|
||||
stackTopIndex := len(s.vals) - 1
|
||||
result = s.vals[stackTopIndex]
|
||||
s.vals = s.vals[:stackTopIndex]
|
||||
return
|
||||
}
|
||||
|
||||
func TestPopValues(t *testing.T) {
|
||||
stackVals := []uint64{1, 2, 3, 4, 5, 6, 7}
|
||||
tests := []struct {
|
||||
name string
|
||||
count int
|
||||
expected []uint64
|
||||
}{
|
||||
{
|
||||
name: "pop zero doesn't allocate a slice ",
|
||||
},
|
||||
{
|
||||
name: "pop 1",
|
||||
count: 1,
|
||||
expected: []uint64{7},
|
||||
},
|
||||
{
|
||||
name: "pop 2",
|
||||
count: 2,
|
||||
expected: []uint64{6, 7},
|
||||
},
|
||||
{
|
||||
name: "pop 3",
|
||||
count: 3,
|
||||
expected: []uint64{5, 6, 7},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tc := tt
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
vals := PopValues(tc.count, (&stack{stackVals}).pop)
|
||||
require.Equal(t, tc.expected, vals)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_callGoFunc(t *testing.T) {
|
||||
tPtr := uintptr(unsafe.Pointer(t))
|
||||
inst := &ModuleInstance{}
|
||||
|
||||
@@ -488,12 +488,16 @@ func (e *mockModuleEngine) Close(context.Context) {
|
||||
func (ce *mockCallEngine) Definition() api.FunctionDefinition { return nil }
|
||||
|
||||
// Call implements the same method as documented on api.Function.
|
||||
func (ce *mockCallEngine) Call(_ context.Context, _ ...uint64) (results []uint64, err error) {
|
||||
func (ce *mockCallEngine) Call(ctx context.Context, _ ...uint64) (results []uint64, err error) {
|
||||
return nil, ce.CallWithStack(ctx, nil)
|
||||
}
|
||||
|
||||
// CallWithStack implements the same method as documented on api.Function.
|
||||
func (ce *mockCallEngine) CallWithStack(_ context.Context, _ []uint64) error {
|
||||
if ce.callFailIndex >= 0 && ce.index == Index(ce.callFailIndex) {
|
||||
err = errors.New("call failed")
|
||||
return
|
||||
return errors.New("call failed")
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestStore_getFunctionTypeID(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user