This PR follows @hafeidejiangyou's advice to not only enable end users to
avoid reflection when calling host functions, but also to use that approach
ourselves internally. The performance results are staggering and will be
noticeable in high-performance applications.
Before
```
BenchmarkHostCall/Call
BenchmarkHostCall/Call-16 1000000 1050 ns/op
Benchmark_EnvironGet/environGet
Benchmark_EnvironGet/environGet-16 525492 2224 ns/op
```
Now
```
BenchmarkHostCall/Call
BenchmarkHostCall/Call-16 14807203 83.22 ns/op
Benchmark_EnvironGet/environGet
Benchmark_EnvironGet/environGet-16 951690 1054 ns/op
```
To accomplish this, this PR consolidates code around host function
definition and enables a fast path for functions where the user takes
responsibility for defining its WebAssembly mappings. Existing users
will need to change their code a bit, as signatures have changed.
For example, we are now stricter: all host functions require a
`context.Context` as parameter zero. Also, we've replaced
`HostModuleBuilder.ExportFunction` and `ExportFunctions` with a new type
`HostFunctionBuilder` that consolidates the responsibility and the
documentation.
```diff
ctx := context.Background()
-hello := func() {
+hello := func(context.Context) {
fmt.Fprintln(stdout, "hello!")
}
-_, err := r.NewHostModuleBuilder("env").ExportFunction("hello", hello).Instantiate(ctx, r)
+_, err := r.NewHostModuleBuilder("env").
+ NewFunctionBuilder().WithFunc(hello).Export("hello").
+ Instantiate(ctx, r)
```
Power users can now use `HostFunctionBuilder` to define functions that
won't use reflection. There are two interfaces to choose from,
depending on whether that function needs access to the calling module:
`api.GoFunction` and `api.GoModuleFunction`. Here's an example that defines
one.
```go
builder.WithGoFunction(api.GoFunc(func(ctx context.Context, params []uint64) []uint64 {
	x, y := uint32(params[0]), uint32(params[1])
	sum := x + y
	return []uint64{uint64(sum)}
}), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32})
```
As you'll notice and as documented, this approach is more verbose and
not for everyone. If you aren't making a low-level library, you can
likely afford the roughly 1µs penalty for the convenience of reflection.
However, we are happy to enable this option for foundational libraries
and those with high performance requirements (like ourselves)!
Fixes #825
Signed-off-by: Adrian Cole <adrian@tetrate.io>
279 lines
7.4 KiB
Go
279 lines
7.4 KiB
Go
package wasm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"reflect"
|
|
|
|
"github.com/tetratelabs/wazero/api"
|
|
)
|
|
|
|
// Below are reflection code to get the interface type used to parse functions and set values.
|
|
|
|
var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem()
|
|
var goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
|
|
var errorType = reflect.TypeOf((*error)(nil)).Elem()
|
|
|
|
// compile-time check to ensure reflectGoModuleFunction implements
|
|
// api.GoModuleFunction.
|
|
var _ api.GoModuleFunction = (*reflectGoModuleFunction)(nil)
|
|
|
|
type reflectGoModuleFunction struct {
|
|
fn *reflect.Value
|
|
params, results []ValueType
|
|
}
|
|
|
|
// Call implements the same method as documented on api.GoModuleFunction.
|
|
func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, params []uint64) []uint64 {
|
|
return callGoFunc(ctx, mod, f.fn, params)
|
|
}
|
|
|
|
// EqualTo is exposed for testing.
|
|
func (f *reflectGoModuleFunction) EqualTo(that interface{}) bool {
|
|
if f2, ok := that.(*reflectGoModuleFunction); !ok {
|
|
return false
|
|
} else {
|
|
// TODO compare reflect pointers
|
|
return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
|
|
}
|
|
}
|
|
|
|
// compile-time check to ensure reflectGoFunction implements api.GoFunction.
|
|
var _ api.GoFunction = (*reflectGoFunction)(nil)
|
|
|
|
type reflectGoFunction struct {
|
|
fn *reflect.Value
|
|
params, results []ValueType
|
|
}
|
|
|
|
// EqualTo is exposed for testing.
|
|
func (f *reflectGoFunction) EqualTo(that interface{}) bool {
|
|
if f2, ok := that.(*reflectGoFunction); !ok {
|
|
return false
|
|
} else {
|
|
// TODO compare reflect pointers
|
|
return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
|
|
}
|
|
}
|
|
|
|
// Call implements the same method as documented on api.GoFunction.
|
|
func (f *reflectGoFunction) Call(ctx context.Context, params []uint64) []uint64 {
|
|
return callGoFunc(ctx, nil, f.fn, params)
|
|
}
|
|
|
|
// PopValues removes count values from a stack by calling popper count times,
// and returns them in stack order so they can be passed as the params of an
// api.GoFunction or api.GoModuleFunction.
//
// For example, if the host function F requires the (x1 uint32, x2 float32)
// parameters, and the stack is [..., A, B], then the function is called as
// F(A, B) where A and B are interpreted as uint32 and float32 respectively.
//
// A zero count returns nil without calling popper.
//
// Note: the popper intentionally doesn't return bool or error because the
// caller's stack depth is trusted.
func PopValues(count int, popper func() uint64) []uint64 {
	if count == 0 {
		return nil
	}

	// popper yields the topmost value first, so fill from the end of the
	// slice toward its start to restore the stack order.
	values := make([]uint64, count)
	for n := range values {
		values[count-1-n] = popper()
	}
	return values
}

// callGoFunc executes the reflective function by converting params to Go
|
|
// types. The results of the function call are converted back to api.ValueType.
|
|
func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params []uint64) []uint64 {
|
|
tp := fn.Type()
|
|
|
|
var in []reflect.Value
|
|
if tp.NumIn() != 0 {
|
|
in = make([]reflect.Value, tp.NumIn())
|
|
|
|
i := 1
|
|
in[0] = newContextVal(ctx)
|
|
if mod != nil {
|
|
in[1] = newModuleVal(mod)
|
|
i++
|
|
}
|
|
|
|
for _, raw := range params {
|
|
val := reflect.New(tp.In(i)).Elem()
|
|
k := tp.In(i).Kind()
|
|
switch k {
|
|
case reflect.Float32:
|
|
val.SetFloat(float64(math.Float32frombits(uint32(raw))))
|
|
case reflect.Float64:
|
|
val.SetFloat(math.Float64frombits(raw))
|
|
case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
val.SetUint(raw)
|
|
case reflect.Int32, reflect.Int64:
|
|
val.SetInt(int64(raw))
|
|
default:
|
|
panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k))
|
|
}
|
|
in[i] = val
|
|
i++
|
|
}
|
|
}
|
|
|
|
// Execute the host function and push back the call result onto the stack.
|
|
var results []uint64
|
|
if tp.NumOut() > 0 {
|
|
results = make([]uint64, 0, tp.NumOut())
|
|
}
|
|
for i, ret := range fn.Call(in) {
|
|
switch ret.Kind() {
|
|
case reflect.Float32:
|
|
results = append(results, uint64(math.Float32bits(float32(ret.Float()))))
|
|
case reflect.Float64:
|
|
results = append(results, math.Float64bits(ret.Float()))
|
|
case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
results = append(results, ret.Uint())
|
|
case reflect.Int32, reflect.Int64:
|
|
results = append(results, uint64(ret.Int()))
|
|
default:
|
|
panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind()))
|
|
}
|
|
}
|
|
return results
|
|
}
|
|
|
|
// newContextVal wraps ctx in a settable reflect.Value whose type is the
// context.Context interface. callGoFunc uses it as param[0] of a reflective
// call.
//
// A nil ctx yields a Value holding a nil context.Context instead of panicking.
// reflect.ValueOf(nil) is the zero Value, which reflect.Value.Set rejects.
func newContextVal(ctx context.Context) reflect.Value {
	// Taking the address of the interface-typed parameter keeps the static type
	// context.Context, whether ctx holds a concrete context or is nil.
	return reflect.ValueOf(&ctx).Elem()
}

func newModuleVal(m api.Module) reflect.Value {
|
|
val := reflect.New(moduleType).Elem()
|
|
val.Set(reflect.ValueOf(m))
|
|
return val
|
|
}
|
|
|
|
// MustParseGoReflectFuncCode parses Code from the go function or panics.
|
|
//
|
|
// Exposing this simplifies FunctionDefinition of host functions in built-in host
|
|
// modules and tests.
|
|
func MustParseGoReflectFuncCode(fn interface{}) *Code {
|
|
_, _, code, err := parseGoReflectFunc(fn)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return code
|
|
}
|
|
|
|
func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code, err error) {
|
|
fnV := reflect.ValueOf(fn)
|
|
p := fnV.Type()
|
|
|
|
if fnV.Kind() != reflect.Func {
|
|
err = fmt.Errorf("kind != func: %s", fnV.Kind().String())
|
|
return
|
|
}
|
|
|
|
needsMod, needsErr := needsModule(p)
|
|
if needsErr != nil {
|
|
err = needsErr
|
|
return
|
|
}
|
|
|
|
pOffset := 1 // ctx
|
|
if needsMod {
|
|
pOffset = 2 // ctx, mod
|
|
}
|
|
|
|
pCount := p.NumIn() - pOffset
|
|
if pCount > 0 {
|
|
params = make([]ValueType, pCount)
|
|
}
|
|
for i := 0; i < len(params); i++ {
|
|
pI := p.In(i + pOffset)
|
|
if t, ok := getTypeOf(pI.Kind()); ok {
|
|
params[i] = t
|
|
continue
|
|
}
|
|
|
|
// Now, we will definitely err, decide which message is best
|
|
var arg0Type reflect.Type
|
|
if hc := pI.Implements(moduleType); hc {
|
|
arg0Type = moduleType
|
|
} else if gc := pI.Implements(goContextType); gc {
|
|
arg0Type = goContextType
|
|
}
|
|
|
|
if arg0Type != nil {
|
|
err = fmt.Errorf("param[%d] is a %s, which may be defined only once as param[0]", i+pOffset, arg0Type)
|
|
} else {
|
|
err = fmt.Errorf("param[%d] is unsupported: %s", i+pOffset, pI.Kind())
|
|
}
|
|
return
|
|
}
|
|
|
|
rCount := p.NumOut()
|
|
if rCount > 0 {
|
|
results = make([]ValueType, rCount)
|
|
}
|
|
for i := 0; i < len(results); i++ {
|
|
rI := p.Out(i)
|
|
if t, ok := getTypeOf(rI.Kind()); ok {
|
|
results[i] = t
|
|
continue
|
|
}
|
|
|
|
// Now, we will definitely err, decide which message is best
|
|
if rI.Implements(errorType) {
|
|
err = fmt.Errorf("result[%d] is an error, which is unsupported", i)
|
|
} else {
|
|
err = fmt.Errorf("result[%d] is unsupported: %s", i, rI.Kind())
|
|
}
|
|
return
|
|
}
|
|
|
|
code = &Code{IsHostFunction: true}
|
|
if needsMod {
|
|
code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
|
|
} else {
|
|
code.GoFunc = &reflectGoFunction{fn: &fnV, params: params, results: results}
|
|
}
|
|
return
|
|
}
|
|
|
|
func needsModule(p reflect.Type) (bool, error) {
|
|
pCount := p.NumIn()
|
|
if pCount == 0 {
|
|
return false, errors.New("invalid signature: context.Context must be param[0]")
|
|
}
|
|
if p.In(0).Kind() == reflect.Interface {
|
|
p0 := p.In(0)
|
|
if p0.Implements(moduleType) {
|
|
return false, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
|
|
} else if p0.Implements(goContextType) {
|
|
if pCount >= 2 && p.In(1).Implements(moduleType) {
|
|
return true, nil
|
|
}
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func getTypeOf(kind reflect.Kind) (ValueType, bool) {
|
|
switch kind {
|
|
case reflect.Float64:
|
|
return ValueTypeF64, true
|
|
case reflect.Float32:
|
|
return ValueTypeF32, true
|
|
case reflect.Int32, reflect.Uint32:
|
|
return ValueTypeI32, true
|
|
case reflect.Int64, reflect.Uint64:
|
|
return ValueTypeI64, true
|
|
case reflect.Uintptr:
|
|
return ValueTypeExternref, true
|
|
default:
|
|
return 0x00, false
|
|
}
|
|
}
|