Restores ability to define host functions w/o context via reflection (#832)

This restores the ability to leave out the initial context parameter
when defining functions with reflection. This is important because some
projects are porting from a different library to wazero, and all the
alternatives are not contextualized.

For example, this project is porting envoy host functions, and the
original definitions (in mosn) don't have a context parameter. By being
lenient, they can migrate easier.

See 6b813482b6/pkg/proxywasm/wazero/imports_v1.go

Signed-off-by: Adrian Cole <adrian@tetrate.io>
This commit is contained in:
Crypt Keeper
2022-10-28 12:44:12 -07:00
committed by GitHub
parent 1ac6c06acb
commit d108ce4c43
25 changed files with 115 additions and 69 deletions

View File

@@ -216,7 +216,7 @@ func (c *compiledModule) Name() (moduleName string) {
}
// Close implements CompiledModule.Close
func (c *compiledModule) Close(_ context.Context) error {
func (c *compiledModule) Close(context.Context) error {
c.compiledEngine.DeleteCompiledModule(c.module)
// It is possible the underlying may need to return an error later, but in any case this matches api.Module.Close.
return nil

View File

@@ -40,12 +40,12 @@ func main() {
// host-defined functions, but any name would do.
_, err := r.NewHostModuleBuilder("env").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, v uint32) {
WithFunc(func(v uint32) {
fmt.Println("log_i32 >>", v)
}).
Export("log_i32").
NewFunctionBuilder().
WithFunc(func(context.Context) uint32 {
WithFunc(func() uint32 {
if envYear, err := strconv.ParseUint(os.Getenv("CURRENT_YEAR"), 10, 64); err == nil {
return uint32(envYear) // Allow env-override to prevent annual test maintenance!
}

View File

@@ -58,7 +58,7 @@ type counter struct {
counter uint32
}
func (e *counter) getAndIncrement(context.Context) (ret uint32) {
func (e *counter) getAndIncrement() (ret uint32) {
ret = e.counter
e.counter++
return

View File

@@ -266,7 +266,7 @@ func Test_loggingListener(t *testing.T) {
var out bytes.Buffer
lf := logging.NewLoggingListenerFactory(&out)
fn := func(context.Context) {}
fn := func() {}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {

View File

@@ -33,7 +33,7 @@ func Example_functionExporter() {
// First construct your own module builder for "env"
envBuilder := r.NewHostModuleBuilder("env").
NewFunctionBuilder().
WithFunc(func(context.Context) uint32 { return 1 }).
WithFunc(func() uint32 { return 1 }).
Export("get_int")
// Now, add AssemblyScript special function imports into it.

View File

@@ -40,7 +40,7 @@ func Example_functionExporter() {
// you need.
envBuilder := r.NewHostModuleBuilder("env").
NewFunctionBuilder().
WithFunc(func(context.Context) uint32 { return 1 }).
WithFunc(func() uint32 { return 1 }).
Export("get_int")
// Now, add Emscripten special function imports into it.

View File

@@ -194,7 +194,7 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) {
const hostModuleName = "env"
const hostFnName = "grow_and_shrink_goroutine_stack"
hm, err := wasm.NewHostModule(hostModuleName, map[string]interface{}{hostFnName: func(context.Context) {
hm, err := wasm.NewHostModule(hostModuleName, map[string]interface{}{hostFnName: func() {
// This function aggressively grow the goroutine stack by recursively
// calling the function many times.
var callNum = 1000

View File

@@ -165,7 +165,7 @@ func testGlobalExtend(t *testing.T, r wazero.Runtime) {
}
func testUnreachable(t *testing.T, r wazero.Runtime) {
callUnreachable := func(context.Context) {
callUnreachable := func() {
panic("panic in host function")
}

View File

@@ -124,7 +124,7 @@ func (r *wasmedgeRuntime) Instantiate(_ context.Context, cfg *vs.RuntimeConfig)
return
}
func (r *wasmedgeRuntime) Close(_ context.Context) error {
func (r *wasmedgeRuntime) Close(context.Context) error {
if conf := r.conf; conf != nil {
conf.Release()
}
@@ -183,7 +183,7 @@ func (m *wasmedgeModule) WriteMemory(_ context.Context, offset uint32, bytes []b
return nil
}
func (m *wasmedgeModule) Close(_ context.Context) error {
func (m *wasmedgeModule) Close(context.Context) error {
if env := m.env; env != nil {
env.Release()
}

View File

@@ -144,7 +144,7 @@ func (r *wasmerRuntime) Instantiate(_ context.Context, cfg *vs.RuntimeConfig) (m
return
}
func (r *wasmerRuntime) Close(_ context.Context) error {
func (r *wasmerRuntime) Close(context.Context) error {
r.engine = nil
return nil
}
@@ -195,7 +195,7 @@ func (m *wasmerModule) Memory() []byte {
return m.mem.Data()
}
func (m *wasmerModule) Close(_ context.Context) error {
func (m *wasmerModule) Close(context.Context) error {
if instance := m.instance; instance != nil {
instance.Close()
}

View File

@@ -142,7 +142,7 @@ func (r *wasmtimeRuntime) Instantiate(_ context.Context, cfg *vs.RuntimeConfig)
return
}
func (r *wasmtimeRuntime) Close(_ context.Context) error {
func (r *wasmtimeRuntime) Close(context.Context) error {
r.engine = nil
return nil // wasmtime only closes via finalizer
}
@@ -193,7 +193,7 @@ func (m *wasmtimeModule) WriteMemory(_ context.Context, offset uint32, bytes []b
return nil
}
func (m *wasmtimeModule) Close(_ context.Context) error {
func (m *wasmtimeModule) Close(context.Context) error {
m.store = nil
m.instance = nil
m.funcs = nil

View File

@@ -153,7 +153,7 @@ func (c *FSContext) CloseFile(_ context.Context, fd uint32) bool {
}
// Close implements io.Closer
func (c *FSContext) Close(_ context.Context) (err error) {
func (c *FSContext) Close(context.Context) (err error) {
// Close any files opened in this context
for fd, entry := range c.openedFiles {
delete(c.openedFiles, fd)

View File

@@ -672,7 +672,7 @@ const (
callImportCallDivByGoName = "call_import->" + callDivByGoName
)
func divByGo(_ context.Context, d uint32) uint32 {
func divByGo(d uint32) uint32 {
if d == math.MaxUint32 {
panic(errors.New("host-function panic"))
}

View File

@@ -1,7 +1,6 @@
package binary
import (
"context"
"testing"
"github.com/tetratelabs/wazero/internal/leb128"
@@ -211,7 +210,7 @@ func TestModule_Encode(t *testing.T) {
func TestModule_Encode_HostFunctionSection_Unsupported(t *testing.T) {
// We don't currently have an approach to serialize reflect.Value pointers
fn := func(context.Context) {}
fn := func() {}
captured := require.CapturePanic(func() {
EncodeModule(&wasm.Module{

View File

@@ -1,7 +1,6 @@
package wasm
import (
"context"
"testing"
"github.com/tetratelabs/wazero/api"
@@ -10,7 +9,7 @@ import (
func TestModule_BuildFunctionDefinitions(t *testing.T) {
nopCode := &Code{Body: []byte{OpcodeEnd}}
fn := func(context.Context) {}
fn := func() {}
tests := []struct {
name string
m *Module

View File

@@ -20,7 +20,7 @@ func (g *mutableGlobal) Type() api.ValueType {
}
// Get implements the same method as documented on api.Global.
func (g *mutableGlobal) Get(_ context.Context) uint64 {
func (g *mutableGlobal) Get(context.Context) uint64 {
return g.g.Val
}
@@ -54,7 +54,7 @@ func (g globalI32) Type() api.ValueType {
}
// Get implements the same method as documented on api.Global.
func (g globalI32) Get(_ context.Context) uint64 {
func (g globalI32) Get(context.Context) uint64 {
return uint64(g)
}
@@ -74,7 +74,7 @@ func (g globalI64) Type() api.ValueType {
}
// Get implements the same method as documented on api.Global.
func (g globalI64) Get(_ context.Context) uint64 {
func (g globalI64) Get(context.Context) uint64 {
return uint64(g)
}
@@ -94,7 +94,7 @@ func (g globalF32) Type() api.ValueType {
}
// Get implements the same method as documented on api.Global.
func (g globalF32) Get(_ context.Context) uint64 {
func (g globalF32) Get(context.Context) uint64 {
return uint64(g)
}
@@ -114,7 +114,7 @@ func (g globalF64) Type() api.ValueType {
}
// Get implements the same method as documented on api.Global.
func (g globalF64) Get(_ context.Context) uint64 {
func (g globalF64) Get(context.Context) uint64 {
return uint64(g)
}

View File

@@ -11,6 +11,14 @@ import (
"github.com/tetratelabs/wazero/api"
)
type paramsKind byte
const (
paramsKindNoContext paramsKind = iota
paramsKindContext
paramsKindContextModule
)
// Below are reflection code to get the interface type used to parse functions and set values.
var moduleType = reflect.TypeOf((*api.Module)(nil)).Elem()
@@ -46,6 +54,7 @@ var _ api.GoFunction = (*reflectGoFunction)(nil)
type reflectGoFunction struct {
fn *reflect.Value
pk paramsKind
params, results []ValueType
}
@@ -55,12 +64,16 @@ func (f *reflectGoFunction) EqualTo(that interface{}) bool {
return false
} else {
// TODO compare reflect pointers
return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
return f.pk == f2.pk &&
bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
}
}
// Call implements the same method as documented on api.GoFunction.
func (f *reflectGoFunction) Call(ctx context.Context, params []uint64) []uint64 {
if f.pk == paramsKindNoContext {
ctx = nil
}
return callGoFunc(ctx, nil, f.fn, params)
}
@@ -93,8 +106,11 @@ func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, params [
if tp.NumIn() != 0 {
in = make([]reflect.Value, tp.NumIn())
i := 1
i := 0
if ctx != nil {
in[0] = newContextVal(ctx)
i++
}
if mod != nil {
in[1] = newModuleVal(mod)
i++
@@ -175,15 +191,19 @@ func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code
return
}
needsMod, needsErr := needsModule(p)
if needsErr != nil {
err = needsErr
pk, kindErr := kind(p)
if kindErr != nil {
err = kindErr
return
}
pOffset := 1 // ctx
if needsMod {
pOffset = 2 // ctx, mod
pOffset := 0
switch pk {
case paramsKindNoContext:
case paramsKindContext:
pOffset = 1
case paramsKindContextModule:
pOffset = 2
}
pCount := p.NumIn() - pOffset
@@ -234,30 +254,30 @@ func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code *Code
}
code = &Code{IsHostFunction: true}
if needsMod {
if pk == paramsKindContextModule {
code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
} else {
code.GoFunc = &reflectGoFunction{fn: &fnV, params: params, results: results}
code.GoFunc = &reflectGoFunction{pk: pk, fn: &fnV, params: params, results: results}
}
return
}
func needsModule(p reflect.Type) (bool, error) {
func kind(p reflect.Type) (paramsKind, error) {
pCount := p.NumIn()
if pCount == 0 {
return false, errors.New("invalid signature: context.Context must be param[0]")
}
if p.In(0).Kind() == reflect.Interface {
if pCount > 0 && p.In(0).Kind() == reflect.Interface {
p0 := p.In(0)
if p0.Implements(moduleType) {
return false, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
return 0, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
} else if p0.Implements(goContextType) {
if pCount >= 2 && p.In(1).Implements(moduleType) {
return true, nil
return paramsKindContextModule, nil
}
return paramsKindContext, nil
}
}
}
return false, nil
// Without context param allows portability with reflective runtimes.
// This allows people to more easily port to wazero.
return paramsKindNoContext, nil
}
func getTypeOf(kind reflect.Kind) (ValueType, bool) {

View File

@@ -20,6 +20,11 @@ func Test_parseGoFunc(t *testing.T) {
expectNeedsModule bool
expectedType *FunctionType
}{
{
name: "() -> ()",
input: func() {},
expectedType: &FunctionType{},
},
{
name: "(ctx) -> ()",
input: func(context.Context) {},
@@ -33,11 +38,16 @@ func Test_parseGoFunc(t *testing.T) {
},
{
name: "all supported params and i32 result",
input: func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
},
{
name: "all supported params and i32 result - (ctx)",
input: func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
},
{
name: "all supported params and i32 result - context.Context and api.Module",
name: "all supported params and i32 result - (ctx, mod)",
input: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
expectNeedsModule: true,
expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
@@ -62,11 +72,6 @@ func Test_parseGoFunc_Errors(t *testing.T) {
input interface{}
expectedErr string
}{
{
name: "no context",
input: func() {},
expectedErr: "invalid signature: context.Context must be param[0]",
},
{
name: "module no context",
input: func(api.Module) {},
@@ -84,12 +89,12 @@ func Test_parseGoFunc_Errors(t *testing.T) {
},
{
name: "unsupported result",
input: func(context.Context) string { return "" },
input: func() string { return "" },
expectedErr: "result[0] is unsupported: string",
},
{
name: "error result",
input: func(context.Context) error { return nil },
input: func() error { return nil },
expectedErr: "result[0] is an error, which is unsupported",
},
{
@@ -178,20 +183,43 @@ func Test_callGoFunc(t *testing.T) {
inputParams, expectedResults []uint64
}{
{
name: "context.Context void return",
name: "() -> ()",
input: func() {},
},
{
name: "(ctx) -> ()",
input: func(ctx context.Context) {
require.Equal(t, testCtx, ctx)
},
},
{
name: "context.Context and api.Module void return",
name: "(ctx, mod) -> ()",
input: func(ctx context.Context, m api.Module) {
require.Equal(t, testCtx, ctx)
require.Equal(t, callCtx, m)
},
},
{
name: "all supported params and i32 result - context.Context",
name: "all supported params and i32 result",
input: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
require.Equal(t, tPtr, v)
require.Equal(t, uint32(math.MaxUint32), w)
require.Equal(t, uint64(math.MaxUint64), x)
require.Equal(t, float32(math.MaxFloat32), y)
require.Equal(t, math.MaxFloat64, z)
return 100
},
inputParams: []uint64{
api.EncodeExternref(tPtr),
math.MaxUint32,
math.MaxUint64,
api.EncodeF32(math.MaxFloat32),
api.EncodeF64(math.MaxFloat64),
},
expectedResults: []uint64{100},
},
{
name: "all supported params and i32 result - (ctx)",
input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
require.Equal(t, testCtx, ctx)
require.Equal(t, tPtr, v)
@@ -211,7 +239,7 @@ func Test_callGoFunc(t *testing.T) {
expectedResults: []uint64{100},
},
{
name: "all supported params and i32 result - context.Context and api.Module",
name: "all supported params and i32 result - (ctx, mod)",
input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
require.Equal(t, testCtx, ctx)
require.Equal(t, callCtx, m)

View File

@@ -135,7 +135,7 @@ func TestNewHostModule_Errors(t *testing.T) {
},
{
name: "function has multiple results",
nameToGoFunc: map[string]interface{}{"fn": func(context.Context) (uint32, uint32) { return 0, 0 }},
nameToGoFunc: map[string]interface{}{"fn": func() (uint32, uint32) { return 0, 0 }},
expectedErr: "func[.fn] multiple result types invalid as feature \"multi-value\" is disabled",
},
}

View File

@@ -59,7 +59,7 @@ func (m *MemoryInstance) Definition() api.MemoryDefinition {
}
// Size implements the same method as documented on api.Memory.
func (m *MemoryInstance) Size(_ context.Context) uint32 {
func (m *MemoryInstance) Size(context.Context) uint32 {
return m.size()
}
@@ -203,7 +203,7 @@ func (m *MemoryInstance) Grow(_ context.Context, delta uint32) (result uint32, o
}
// PageSize returns the current memory buffer size in pages.
func (m *MemoryInstance) PageSize(_ context.Context) (result uint32) {
func (m *MemoryInstance) PageSize(context.Context) (result uint32) {
return memoryBytesNumToPages(uint64(len(m.Buffer)))
}

View File

@@ -278,7 +278,7 @@ func NewStore(enabledFeatures api.CoreFeatures, engine Engine) (*Store, *Namespa
}
// NewNamespace implements the same method as documented on wazero.Runtime.
func (s *Store) NewNamespace(_ context.Context) *Namespace {
func (s *Store) NewNamespace(context.Context) *Namespace {
ns := newNamespace()
s.mux.Lock()
defer s.mux.Unlock()

View File

@@ -91,7 +91,7 @@ func TestModuleInstance_Memory(t *testing.T) {
func TestStore_Instantiate(t *testing.T) {
s, ns := newStore()
m, err := NewHostModule("", map[string]interface{}{"fn": func(context.Context) {}}, nil, api.CoreFeaturesV1)
m, err := NewHostModule("", map[string]interface{}{"fn": func() {}}, nil, api.CoreFeaturesV1)
require.NoError(t, err)
sysCtx := sys.DefaultContext(nil)
@@ -169,7 +169,7 @@ func TestStore_CloseWithExitCode(t *testing.T) {
func TestStore_hammer(t *testing.T) {
const importedModuleName = "imported"
m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func(context.Context) {}}, nil, api.CoreFeaturesV1)
m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func() {}}, nil, api.CoreFeaturesV1)
require.NoError(t, err)
s, ns := newStore()
@@ -223,7 +223,7 @@ func TestStore_Instantiate_Errors(t *testing.T) {
const importedModuleName = "imported"
const importingModuleName = "test"
m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func(context.Context) {}}, nil, api.CoreFeaturesV1)
m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func() {}}, nil, api.CoreFeaturesV1)
require.NoError(t, err)
t.Run("Fails if module name already in use", func(t *testing.T) {
@@ -314,7 +314,7 @@ func TestStore_Instantiate_Errors(t *testing.T) {
}
func TestCallContext_ExportedFunction(t *testing.T) {
host, err := NewHostModule("host", map[string]interface{}{"host_fn": func(context.Context) {}}, nil, api.CoreFeaturesV1)
host, err := NewHostModule("host", map[string]interface{}{"host_fn": func() {}}, nil, api.CoreFeaturesV1)
require.NoError(t, err)
s, ns := newStore()
@@ -396,7 +396,7 @@ func (e *mockModuleEngine) Name() string {
}
// Close implements the same method as documented on wasm.ModuleEngine.
func (e *mockModuleEngine) Close(_ context.Context) {
func (e *mockModuleEngine) Close(context.Context) {
}
// Call implements the same method as documented on wasm.ModuleEngine.

View File

@@ -82,7 +82,7 @@ func TestCompile(t *testing.T) {
module: &wasm.Module{
TypeSection: []*wasm.FunctionType{v_v},
FunctionSection: []wasm.Index{0},
CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(func(context.Context) {})},
CodeSection: []*wasm.Code{wasm.MustParseGoReflectFuncCode(func() {})},
},
expected: &CompilationResult{IsHostFunction: true},
},

View File

@@ -27,7 +27,7 @@ type Runtime interface {
// Below defines and instantiates a module named "env" with one function:
//
// ctx := context.Background()
// hello := func(context.Context) {
// hello := func() {
// fmt.Fprintln(stdout, "hello!")
// }
// _, err := r.NewHostModuleBuilder("env").

View File

@@ -405,7 +405,7 @@ func TestRuntime_InstantiateModuleFromBinary_ErrorOnStart(t *testing.T) {
r := NewRuntime(testCtx)
defer r.Close(testCtx)
start := func(context.Context) {
start := func() {
panic(errors.New("ice cream"))
}