From cfb11f352a5b51b3cfcabdbefccfa659ed284065 Mon Sep 17 00:00:00 2001 From: Crypt Keeper <64215+codefromthecrypt@users.noreply.github.com> Date: Thu, 3 Mar 2022 10:31:10 +0800 Subject: [PATCH] Adds ability to disable mutable globals and improves decode perf (#315) This adds `RuntimeConfig.WithFeatureMutableGlobal(enabled bool)`, which allows disabling of mutable globals. When disabled, any attempt to add a mutable global, either explicitly or implicitly via decoding wasm will fail. To support this, there's a new `Features` bitflag that can allow up to 63 feature toggles without passing structs. While here, I fixed a significant performance problem in decoding binary: Before ``` BenchmarkCodecExample/binary.DecodeModule-16 184243 5623 ns/op 3848 B/op 184 allocs/op ``` Now ``` BenchmarkCodecExample/binary.DecodeModule-16 294084 3520 ns/op 2176 B/op 91 allocs/op ``` Signed-off-by: Adrian Cole --- config.go | 30 ++++- config_test.go | 48 +++++++ internal/leb128/leb128.go | 28 ++--- internal/wasi/wasi_test.go | 5 +- internal/wasm/binary/code.go | 27 ++-- internal/wasm/binary/const_expr.go | 39 +++--- internal/wasm/binary/data.go | 3 +- internal/wasm/binary/decoder.go | 12 +- internal/wasm/binary/decoder_test.go | 30 ++++- internal/wasm/binary/element.go | 4 +- internal/wasm/binary/export.go | 9 +- internal/wasm/binary/global.go | 16 +-- internal/wasm/binary/import.go | 37 +++--- internal/wasm/binary/limits.go | 11 +- internal/wasm/binary/memory.go | 4 +- internal/wasm/binary/section.go | 25 ++-- internal/wasm/binary/types.go | 23 ++-- internal/wasm/binary/value.go | 2 +- internal/wasm/features.go | 55 ++++++++ internal/wasm/features_test.go | 94 ++++++++++++++ internal/wasm/func_validation.go | 35 +++--- internal/wasm/global_test.go | 3 +- internal/wasm/interpreter/interpreter_test.go | 2 +- internal/wasm/jit/engine_test.go | 2 +- internal/wasm/module.go | 4 +- internal/wasm/store.go | 15 ++- internal/wasm/store_test.go | 119 ++++++++++++------ internal/wasm/text/decoder.go | 6 +- internal/wasm/text/decoder_test.go | 4 +- internal/wazeroir/compiler.go | 22 ++-- tests/spectest/spec_test.go | 6 +- vs/codec_test.go | 10 +- wasm.go | 10 +- 33 files changed, 519 insertions(+), 221 deletions(-) create mode 100644 config_test.go create mode 100644 internal/wasm/features.go create mode 100644 internal/wasm/features_test.go diff --git a/config.go b/config.go index 2b54f15a..aa2a0166 100644 --- a/config.go +++ b/config.go @@ -13,18 +13,27 @@ import ( // Note: This panics at runtime the runtime.GOOS or runtime.GOARCH does not support JIT. Use NewRuntimeConfig to safely // detect and fallback to NewRuntimeConfigInterpreter if needed. func NewRuntimeConfigJIT() *RuntimeConfig { - return &RuntimeConfig{engine: jit.NewEngine(), ctx: context.Background()} + return &RuntimeConfig{ + engine: jit.NewEngine(), + ctx: context.Background(), + enabledFeatures: internalwasm.Features20191205, + } } // NewRuntimeConfigInterpreter interprets WebAssembly modules instead of compiling them into assembly. func NewRuntimeConfigInterpreter() *RuntimeConfig { - return &RuntimeConfig{engine: interpreter.NewEngine(), ctx: context.Background()} + return &RuntimeConfig{ + engine: interpreter.NewEngine(), + ctx: context.Background(), + enabledFeatures: internalwasm.Features20191205, + } } // RuntimeConfig controls runtime behavior, with the default implementation as NewRuntimeConfig type RuntimeConfig struct { - engine internalwasm.Engine - ctx context.Context + engine internalwasm.Engine + ctx context.Context + enabledFeatures internalwasm.Features } // WithContext sets the default context used to initialize the module. Defaults to context.Background if nil. @@ -39,7 +48,18 @@ func (r *RuntimeConfig) WithContext(ctx context.Context) *RuntimeConfig { if ctx == nil { ctx = context.Background() } - return &RuntimeConfig{engine: r.engine, ctx: ctx} + return &RuntimeConfig{engine: r.engine, ctx: ctx, enabledFeatures: r.enabledFeatures} +} + +// WithFeatureMutableGlobal allows globals to be mutable. This defaults to true as the feature was finished in +// WebAssembly 1.0 (20191205). +// +// When false, a wasm.Global can never be cast to a wasm.MutableGlobal, and any source that includes mutable globals +// will fail to parse. +// +func (r *RuntimeConfig) WithFeatureMutableGlobal(enabled bool) *RuntimeConfig { + enabledFeatures := r.enabledFeatures.Set(internalwasm.FeatureMutableGlobal, enabled) + return &RuntimeConfig{engine: r.engine, ctx: r.ctx, enabledFeatures: enabledFeatures} } // DecodedModule is a WebAssembly 1.0 (20191205) text or binary encoded module to instantiate. diff --git a/config_test.go b/config_test.go new file mode 100644 index 00000000..93b15cc9 --- /dev/null +++ b/config_test.go @@ -0,0 +1,48 @@ +package wazero + +import ( + "testing" + + "github.com/stretchr/testify/require" + + internalwasm "github.com/tetratelabs/wazero/internal/wasm" +) + +func TestRuntimeConfig_Features(t *testing.T) { + tests := []struct { + name string + feature internalwasm.Features + expectDefault bool + setFeature func(*RuntimeConfig, bool) *RuntimeConfig + }{ + { + name: "mutable-global", + feature: internalwasm.FeatureMutableGlobal, + expectDefault: true, + setFeature: func(c *RuntimeConfig, v bool) *RuntimeConfig { + return c.WithFeatureMutableGlobal(v) + }, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + c := NewRuntimeConfig() + require.Equal(t, tc.expectDefault, c.enabledFeatures.Get(tc.feature)) + + // Set to false even if it was initially false. + c = tc.setFeature(c, false) + require.False(t, c.enabledFeatures.Get(tc.feature)) + + // Set true makes it true + c = tc.setFeature(c, true) + require.True(t, c.enabledFeatures.Get(tc.feature)) + + // Set false makes it false again + c = tc.setFeature(c, false) + require.False(t, c.enabledFeatures.Get(tc.feature)) + }) + } +} diff --git a/internal/leb128/leb128.go b/internal/leb128/leb128.go index 627f7d47..b9a033d3 100644 --- a/internal/leb128/leb128.go +++ b/internal/leb128/leb128.go @@ -1,9 +1,9 @@ package leb128 import ( + "bytes" "errors" "fmt" - "io" ) const ( @@ -57,12 +57,12 @@ func EncodeUint32(value uint32) (buf []byte) { } } -func DecodeUint32(r io.Reader) (ret uint32, bytesRead uint64, err error) { +func DecodeUint32(r *bytes.Reader) (ret uint32, bytesRead uint64, err error) { // Derived from https://github.com/golang/go/blob/aafad20b617ee63d58fcd4f6e0d98fe27760678c/src/encoding/binary/varint.go // with the modification on the overflow handling tailored for 32-bits. var s uint32 for i := 0; i < maxVarintLen32; i++ { - b, err := readByte(r) + b, err := r.ReadByte() if err != nil { return 0, 0, err } @@ -79,11 +79,11 @@ func DecodeUint32(r io.Reader) (ret uint32, bytesRead uint64, err error) { return 0, 0, errOverflow32 } -func DecodeUint64(r io.Reader) (ret uint64, bytesRead uint64, err error) { +func DecodeUint64(r *bytes.Reader) (ret uint64, bytesRead uint64, err error) { // Derived from https://github.com/golang/go/blob/aafad20b617ee63d58fcd4f6e0d98fe27760678c/src/encoding/binary/varint.go var s uint64 for i := 0; i < maxVarintLen64; i++ { - b, err := readByte(r) + b, err := r.ReadByte() if err != nil { return 0, 0, err } @@ -100,11 +100,11 @@ func DecodeUint64(r io.Reader) (ret uint64, bytesRead uint64, err error) { return 0, 0, errOverflow64 } -func DecodeInt32(r io.Reader) (ret int32, bytesRead uint64, err error) { +func DecodeInt32(r *bytes.Reader) (ret int32, bytesRead uint64, err error) { var shift int var b byte for { - b, err = readByte(r) + b, err = r.ReadByte() if err != nil { return 0, 0, fmt.Errorf("readByte failed: %w", err) } @@ -133,7 +133,7 @@ func DecodeInt32(r io.Reader) (ret int32, bytesRead uint64, err error) { // still needs to fit the 32-bit range of allowed indices. Hence, this is 33, not 32-bit! // // See https://webassembly.github.io/spec/core/binary/instructions.html#control-instructions -func DecodeInt33AsInt64(r io.Reader) (ret int64, bytesRead uint64, err error) { +func DecodeInt33AsInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) { const ( int33Mask int64 = 1 << 7 int33Mask2 = ^int33Mask @@ -146,7 +146,7 @@ func DecodeInt33AsInt64(r io.Reader) (ret int64, bytesRead uint64, err error) { var b int64 var rb byte for shift < 35 { - rb, err = readByte(r) + rb, err = r.ReadByte() if err != nil { return 0, 0, fmt.Errorf("readByte failed: %w", err) } @@ -181,7 +181,7 @@ func DecodeInt33AsInt64(r io.Reader) (ret int64, bytesRead uint64, err error) { return ret, bytesRead, nil } -func DecodeInt64(r io.Reader) (ret int64, bytesRead uint64, err error) { +func DecodeInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) { const ( int64Mask3 = 1 << 6 int64Mask4 = ^0 @@ -189,7 +189,7 @@ func DecodeInt64(r io.Reader) (ret int64, bytesRead uint64, err error) { var shift int var b byte for { - b, err = readByte(r) + b, err = r.ReadByte() if err != nil { return 0, 0, fmt.Errorf("readByte failed: %w", err) } @@ -213,9 +213,3 @@ func DecodeInt64(r io.Reader) (ret int64, bytesRead uint64, err error) { } } } - -func readByte(r io.Reader) (byte, error) { - b := make([]byte, 1) - _, err := io.ReadFull(r, b) - return b[0], err -} diff --git a/internal/wasi/wasi_test.go b/internal/wasi/wasi_test.go index 190b59d2..1e1654a7 100644 --- a/internal/wasi/wasi_test.go +++ b/internal/wasi/wasi_test.go @@ -1094,15 +1094,16 @@ func TestSnapshotPreview1_RandomGet_SourceError(t *testing.T) { // TODO: TestSnapshotPreview1_SockShutdown TestSnapshotPreview1_SockShutdown_Errors func instantiateWasmStore(t *testing.T, wasiFunction, wasiImport, moduleName string, opts ...Option) (*wasm.Store, *wasm.ModuleContext, *wasm.FunctionInstance) { + enabledFeatures := wasm.Features20191205 mod, err := text.DecodeModule([]byte(fmt.Sprintf(`(module %[2]s (memory 1) ;; just an arbitrary size big enough for tests (export "memory" (memory 0)) (export "%[1]s" (func $wasi.%[1]s)) -)`, wasiFunction, wasiImport))) +)`, wasiFunction, wasiImport)), enabledFeatures) require.NoError(t, err) - store := wasm.NewStore(context.Background(), interpreter.NewEngine()) + store := wasm.NewStore(context.Background(), interpreter.NewEngine(), enabledFeatures) snapshotPreview1Functions := SnapshotPreview1Functions(opts...) goFunc := snapshotPreview1Functions[wasiFunction] diff --git a/internal/wasm/binary/code.go b/internal/wasm/binary/code.go index e0d3cb82..41117a54 100644 --- a/internal/wasm/binary/code.go +++ b/internal/wasm/binary/code.go @@ -1,6 +1,7 @@ package binary import ( + "bytes" "fmt" "io" "math" @@ -9,37 +10,43 @@ import ( wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeCode(r io.Reader) (*wasm.Code, error) { +func decodeCode(r *bytes.Reader) (*wasm.Code, error) { ss, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("get the size of code: %w", err) } - - r = io.LimitReader(r, int64(ss)) + remaining := int64(ss) // parse locals - ls, _, err := leb128.DecodeUint32(r) + ls, bytesRead, err := leb128.DecodeUint32(r) + remaining -= int64(bytesRead) if err != nil { return nil, fmt.Errorf("get the size locals: %v", err) + } else if remaining < 0 { + return nil, io.EOF } var nums []uint64 var types []wasm.ValueType var sum uint64 - b := make([]byte, 1) + var n uint32 for i := uint32(0); i < ls; i++ { - n, _, err := leb128.DecodeUint32(r) + n, bytesRead, err = leb128.DecodeUint32(r) + remaining -= int64(bytesRead) + 1 // +1 for the subsequent ReadByte if err != nil { return nil, fmt.Errorf("read n of locals: %v", err) + } else if remaining < 0 { + return nil, io.EOF } + sum += uint64(n) nums = append(nums, uint64(n)) - _, err = io.ReadFull(r, b) + b, err := r.ReadByte() if err != nil { return nil, fmt.Errorf("read type of local: %v", err) } - switch vt := b[0]; vt { + switch vt := b; vt { case wasm.ValueTypeI32, wasm.ValueTypeF32, wasm.ValueTypeI64, wasm.ValueTypeF64: types = append(types, vt) default: @@ -59,8 +66,8 @@ func decodeCode(r io.Reader) (*wasm.Code, error) { } } - body, err := io.ReadAll(r) - if err != nil { + body := make([]byte, remaining) + if _, err = io.ReadFull(r, body); err != nil { return nil, fmt.Errorf("read body: %w", err) } diff --git a/internal/wasm/binary/const_expr.go b/internal/wasm/binary/const_expr.go index 4a8d662b..4eb6d05b 100644 --- a/internal/wasm/binary/const_expr.go +++ b/internal/wasm/binary/const_expr.go @@ -3,52 +3,53 @@ package binary import ( "bytes" "fmt" - "io" "github.com/tetratelabs/wazero/internal/ieee754" "github.com/tetratelabs/wazero/internal/leb128" wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeConstantExpression(r io.Reader) (*wasm.ConstantExpression, error) { - b := make([]byte, 1) - _, err := io.ReadFull(r, b) +func decodeConstantExpression(r *bytes.Reader) (*wasm.ConstantExpression, error) { + b, err := r.ReadByte() if err != nil { return nil, fmt.Errorf("read opcode: %v", err) } - buf := new(bytes.Buffer) - teeR := io.TeeReader(r, buf) - opcode := b[0] + remainingBeforeData := int64(r.Len()) + offsetAtData := r.Size() - remainingBeforeData + + opcode := b switch opcode { case wasm.OpcodeI32Const: - _, _, err = leb128.DecodeInt32(teeR) + _, _, err = leb128.DecodeInt32(r) case wasm.OpcodeI64Const: - _, _, err = leb128.DecodeInt64(teeR) + _, _, err = leb128.DecodeInt64(r) case wasm.OpcodeF32Const: - _, err = ieee754.DecodeFloat32(teeR) + _, err = ieee754.DecodeFloat32(r) case wasm.OpcodeF64Const: - _, err = ieee754.DecodeFloat64(teeR) + _, err = ieee754.DecodeFloat64(r) case wasm.OpcodeGlobalGet: - _, _, err = leb128.DecodeUint32(teeR) + _, _, err = leb128.DecodeUint32(r) default: - return nil, fmt.Errorf("%v for const expression opt code: %#x", ErrInvalidByte, b[0]) + return nil, fmt.Errorf("%v for const expression opt code: %#x", ErrInvalidByte, b) } if err != nil { return nil, fmt.Errorf("read value: %v", err) } - if _, err := io.ReadFull(r, b); err != nil { + if b, err = r.ReadByte(); err != nil { return nil, fmt.Errorf("look for end opcode: %v", err) } - if b[0] != byte(wasm.OpcodeEnd) { + if b != wasm.OpcodeEnd { return nil, fmt.Errorf("constant expression has been not terminated") } - return &wasm.ConstantExpression{ - Opcode: opcode, - Data: buf.Bytes(), - }, nil + data := make([]byte, remainingBeforeData-int64(r.Len())) + if _, err := r.ReadAt(data, offsetAtData); err != nil { + return nil, fmt.Errorf("error re-buffering ConstantExpression.Data") + } + + return &wasm.ConstantExpression{Opcode: opcode, Data: data}, nil } diff --git a/internal/wasm/binary/data.go b/internal/wasm/binary/data.go index 29350bcb..2a6a848c 100644 --- a/internal/wasm/binary/data.go +++ b/internal/wasm/binary/data.go @@ -1,6 +1,7 @@ package binary import ( + "bytes" "fmt" "io" @@ -8,7 +9,7 @@ import ( wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeDataSegment(r io.Reader) (*wasm.DataSegment, error) { +func decodeDataSegment(r *bytes.Reader) (*wasm.DataSegment, error) { d, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("read memory index: %v", err) diff --git a/internal/wasm/binary/decoder.go b/internal/wasm/binary/decoder.go index dcf626c4..c2eb1c6e 100644 --- a/internal/wasm/binary/decoder.go +++ b/internal/wasm/binary/decoder.go @@ -9,9 +9,9 @@ import ( wasm "github.com/tetratelabs/wazero/internal/wasm" ) -// DecodeModule implements wasm.DecodeModule for the WebAssembly 1.0 (20191205) Binary Format +// DecodeModule implements internalwasm.DecodeModule for the WebAssembly 1.0 (20191205) Binary Format // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-format%E2%91%A0 -func DecodeModule(binary []byte) (*wasm.Module, error) { +func DecodeModule(binary []byte, features wasm.Features) (*wasm.Module, error) { r := bytes.NewReader(binary) // Magic number. @@ -71,7 +71,9 @@ func DecodeModule(binary []byte) (*wasm.Module, error) { case wasm.SectionIDType: m.TypeSection, err = decodeTypeSection(r) case wasm.SectionIDImport: - m.ImportSection, err = decodeImportSection(r) + if m.ImportSection, err = decodeImportSection(r, features); err != nil { + return nil, err // avoid re-wrapping the error. + } case wasm.SectionIDFunction: m.FunctionSection, err = decodeFunctionSection(r) case wasm.SectionIDTable: @@ -79,7 +81,9 @@ func DecodeModule(binary []byte) (*wasm.Module, error) { case wasm.SectionIDMemory: m.MemorySection, err = decodeMemorySection(r) case wasm.SectionIDGlobal: - m.GlobalSection, err = decodeGlobalSection(r) + if m.GlobalSection, err = decodeGlobalSection(r, features); err != nil { + return nil, err // avoid re-wrapping the error. + } case wasm.SectionIDExport: m.ExportSection, err = decodeExportSection(r) case wasm.SectionIDStart: diff --git a/internal/wasm/binary/decoder_test.go b/internal/wasm/binary/decoder_test.go index d2957f7e..b9b2393a 100644 --- a/internal/wasm/binary/decoder_test.go +++ b/internal/wasm/binary/decoder_test.go @@ -1,6 +1,7 @@ package binary import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -87,7 +88,7 @@ func TestDecodeModule(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - m, e := DecodeModule(EncodeModule(tc.input)) + m, e := DecodeModule(EncodeModule(tc.input), wasm.Features20191205) require.NoError(t, e) require.Equal(t, tc.input, m) }) @@ -97,7 +98,7 @@ func TestDecodeModule(t *testing.T) { wasm.SectionIDCustom, 0xf, // 15 bytes in this section 0x04, 'm', 'e', 'm', 'e', 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) - m, e := DecodeModule(input) + m, e := DecodeModule(input, wasm.Features20191205) require.NoError(t, e) require.Equal(t, &wasm.Module{}, m) }) @@ -111,7 +112,7 @@ func TestDecodeModule(t *testing.T) { subsectionIDModuleName, 0x07, // 7 bytes in this subsection 0x06, // the Module name simple is 6 bytes long 's', 'i', 'm', 'p', 'l', 'e') - m, e := DecodeModule(input) + m, e := DecodeModule(input, wasm.Features20191205) require.NoError(t, e) require.Equal(t, &wasm.Module{NameSection: &wasm.NameSection{ModuleName: "simple"}}, m) }) @@ -121,6 +122,7 @@ func TestDecodeModule_Errors(t *testing.T) { tests := []struct { name string input []byte + features wasm.Features expectedErr string }{ { @@ -144,13 +146,33 @@ func TestDecodeModule_Errors(t *testing.T) { subsectionIDModuleName, 0x02, 0x01, 'x'), expectedErr: "section custom: redundant custom section name", }, + { + name: fmt.Sprintf("define mutable global when %s disabled", wasm.FeatureMutableGlobal), + features: wasm.Features20191205.Set(wasm.FeatureMutableGlobal, false), + input: append(append(Magic, version...), + wasm.SectionIDGlobal, 0x06, // 6 bytes in this section + 0x01, wasm.ValueTypeI32, 0x01, // 1 global i32 mutable + wasm.OpcodeI32Const, 0x00, wasm.OpcodeEnd, // arbitrary init to zero + ), + expectedErr: "global[0]: feature mutable-global is disabled", + }, + { + name: fmt.Sprintf("import mutable global when %s disabled", wasm.FeatureMutableGlobal), + features: wasm.Features20191205.Set(wasm.FeatureMutableGlobal, false), + input: append(append(Magic, version...), + wasm.SectionIDImport, 0x08, // 8 bytes in this section + 0x01, 0x01, 'a', 0x01, 'b', wasm.ExternTypeGlobal, // 1 import a.b of type global + wasm.ValueTypeI32, 0x01, // 1 global i32 mutable + ), + expectedErr: "import[0] global[a.b]: feature mutable-global is disabled", + }, } for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - _, e := DecodeModule(tc.input) + _, e := DecodeModule(tc.input, tc.features) require.EqualError(t, e, tc.expectedErr) }) } diff --git a/internal/wasm/binary/element.go b/internal/wasm/binary/element.go index 5b47f6c3..6b7e0257 100644 --- a/internal/wasm/binary/element.go +++ b/internal/wasm/binary/element.go @@ -1,14 +1,14 @@ package binary import ( + "bytes" "fmt" - "io" "github.com/tetratelabs/wazero/internal/leb128" wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeElementSegment(r io.Reader) (*wasm.ElementSegment, error) { +func decodeElementSegment(r *bytes.Reader) (*wasm.ElementSegment, error) { ti, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("get table index: %w", err) diff --git a/internal/wasm/binary/export.go b/internal/wasm/binary/export.go index dc749c55..72a21dc7 100644 --- a/internal/wasm/binary/export.go +++ b/internal/wasm/binary/export.go @@ -3,7 +3,6 @@ package binary import ( "bytes" "fmt" - "io" "github.com/tetratelabs/wazero/internal/leb128" wasm "github.com/tetratelabs/wazero/internal/wasm" @@ -16,19 +15,19 @@ func decodeExport(r *bytes.Reader) (i *wasm.Export, err error) { return nil, err } - b := make([]byte, 1) - if _, err = io.ReadFull(r, b); err != nil { + b, err := r.ReadByte() + if err != nil { return nil, fmt.Errorf("error decoding export kind: %w", err) } - i.Type = b[0] + i.Type = b switch i.Type { case wasm.ExternTypeFunc, wasm.ExternTypeTable, wasm.ExternTypeMemory, wasm.ExternTypeGlobal: if i.Index, _, err = leb128.DecodeUint32(r); err != nil { return nil, fmt.Errorf("error decoding export index: %w", err) } default: - return nil, fmt.Errorf("%w: invalid byte for exportdesc: %#x", ErrInvalidByte, b[0]) + return nil, fmt.Errorf("%w: invalid byte for exportdesc: %#x", ErrInvalidByte, b) } return } diff --git a/internal/wasm/binary/global.go b/internal/wasm/binary/global.go index 0de08bbc..8adc1be1 100644 --- a/internal/wasm/binary/global.go +++ b/internal/wasm/binary/global.go @@ -1,25 +1,21 @@ package binary import ( - "fmt" - "io" + "bytes" wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeGlobal(r io.Reader) (*wasm.Global, error) { - gt, err := decodeGlobalType(r) +func decodeGlobal(r *bytes.Reader, features wasm.Features) (*wasm.Global, error) { + gt, err := decodeGlobalType(r, features) if err != nil { - return nil, fmt.Errorf("read global type: %v", err) + return nil, err } init, err := decodeConstantExpression(r) if err != nil { - return nil, fmt.Errorf("get init expression: %v", err) + return nil, err } - return &wasm.Global{ - Type: gt, - Init: init, - }, nil + return &wasm.Global{Type: gt, Init: init}, nil } diff --git a/internal/wasm/binary/import.go b/internal/wasm/binary/import.go index 5316f86c..ea92de5c 100644 --- a/internal/wasm/binary/import.go +++ b/internal/wasm/binary/import.go @@ -3,47 +3,40 @@ package binary import ( "bytes" "fmt" - "io" "github.com/tetratelabs/wazero/internal/leb128" wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeImport(r *bytes.Reader) (i *wasm.Import, err error) { +func decodeImport(r *bytes.Reader, idx uint32, features wasm.Features) (i *wasm.Import, err error) { i = &wasm.Import{} if i.Module, _, err = decodeUTF8(r, "import module"); err != nil { - return nil, err + return nil, fmt.Errorf("import[%d] error decoding module: %w", idx, err) } if i.Name, _, err = decodeUTF8(r, "import name"); err != nil { - return nil, err + return nil, fmt.Errorf("import[%d] error decoding name: %w", idx, err) } - b := make([]byte, 1) - if _, err = io.ReadFull(r, b); err != nil { - return nil, fmt.Errorf("error decoding import kind: %w", err) + b, err := r.ReadByte() + if err != nil { + return nil, fmt.Errorf("import[%d] error decoding type: %w", idx, err) } - - i.Type = b[0] + i.Type = b switch i.Type { case wasm.ExternTypeFunc: - if i.DescFunc, _, err = leb128.DecodeUint32(r); err != nil { - return nil, fmt.Errorf("error decoding import func typeindex: %w", err) - } + i.DescFunc, _, err = leb128.DecodeUint32(r) case wasm.ExternTypeTable: - if i.DescTable, err = decodeTableType(r); err != nil { - return nil, fmt.Errorf("error decoding import table desc: %w", err) - } + i.DescTable, err = decodeTableType(r) case wasm.ExternTypeMemory: - if i.DescMem, err = decodeMemoryType(r); err != nil { - return nil, fmt.Errorf("error decoding import mem desc: %w", err) - } + i.DescMem, err = decodeMemoryType(r) case wasm.ExternTypeGlobal: - if i.DescGlobal, err = decodeGlobalType(r); err != nil { - return nil, fmt.Errorf("error decoding import global desc: %w", err) - } + i.DescGlobal, err = decodeGlobalType(r, features) default: - return nil, fmt.Errorf("%w: invalid byte for importdesc: %#x", ErrInvalidByte, b[0]) + err = fmt.Errorf("%w: invalid byte for importdesc: %#x", ErrInvalidByte, b) + } + if err != nil { + return nil, fmt.Errorf("import[%d] %s[%s.%s]: %w", idx, wasm.ExternTypeName(i.Type), i.Module, i.Name, err) } return } diff --git a/internal/wasm/binary/limits.go b/internal/wasm/binary/limits.go index 9cc5f887..662ea6fd 100644 --- a/internal/wasm/binary/limits.go +++ b/internal/wasm/binary/limits.go @@ -1,8 +1,8 @@ package binary import ( + "bytes" "fmt" - "io" "github.com/tetratelabs/wazero/internal/leb128" wasm "github.com/tetratelabs/wazero/internal/wasm" @@ -11,15 +11,14 @@ import ( // decodeLimitsType returns the wasm.LimitsType decoded with the WebAssembly 1.0 (20191205) Binary Format. // // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#limits%E2%91%A6 -func decodeLimitsType(r io.Reader) (*wasm.LimitsType, error) { - b := make([]byte, 1) - _, err := io.ReadFull(r, b) +func decodeLimitsType(r *bytes.Reader) (*wasm.LimitsType, error) { + b, err := r.ReadByte() if err != nil { return nil, fmt.Errorf("read leading byte: %v", err) } ret := &wasm.LimitsType{} - switch b[0] { + switch b { case 0x00: ret.Min, _, err = leb128.DecodeUint32(r) if err != nil { @@ -36,7 +35,7 @@ func decodeLimitsType(r io.Reader) (*wasm.LimitsType, error) { } ret.Max = &m default: - return nil, fmt.Errorf("%v for limits: %#x != 0x00 or 0x01", ErrInvalidByte, b[0]) + return nil, fmt.Errorf("%v for limits: %#x != 0x00 or 0x01", ErrInvalidByte, b) } return ret, nil } diff --git a/internal/wasm/binary/memory.go b/internal/wasm/binary/memory.go index 4aefdca7..2bbb0aa6 100644 --- a/internal/wasm/binary/memory.go +++ b/internal/wasm/binary/memory.go @@ -1,8 +1,8 @@ package binary import ( + "bytes" "fmt" - "io" wasm "github.com/tetratelabs/wazero/internal/wasm" ) @@ -10,7 +10,7 @@ import ( // decodeMemoryType returns the wasm.MemoryType decoded with the WebAssembly 1.0 (20191205) Binary Format. // // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-memory -func decodeMemoryType(r io.Reader) (*wasm.MemoryType, error) { +func decodeMemoryType(r *bytes.Reader) (*wasm.MemoryType, error) { ret, err := decodeLimitsType(r) if err != nil { return nil, err diff --git a/internal/wasm/binary/section.go b/internal/wasm/binary/section.go index 57dcd07d..eba34b92 100644 --- a/internal/wasm/binary/section.go +++ b/internal/wasm/binary/section.go @@ -3,13 +3,12 @@ package binary import ( "bytes" "fmt" - "io" "github.com/tetratelabs/wazero/internal/leb128" wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeTypeSection(r io.Reader) ([]*wasm.FunctionType, error) { +func decodeTypeSection(r *bytes.Reader) ([]*wasm.FunctionType, error) { vs, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("get size of vector: %w", err) @@ -24,14 +23,14 @@ func decodeTypeSection(r io.Reader) ([]*wasm.FunctionType, error) { return result, nil } -func decodeFunctionType(r io.Reader) (*wasm.FunctionType, error) { - b := make([]byte, 1) - if _, err := io.ReadFull(r, b); err != nil { +func decodeFunctionType(r *bytes.Reader) (*wasm.FunctionType, error) { + b, err := r.ReadByte() + if err != nil { return nil, fmt.Errorf("read leading byte: %w", err) } - if b[0] != 0x60 { - return nil, fmt.Errorf("%w: %#x != 0x60", ErrInvalidByte, b[0]) + if b != 0x60 { + return nil, fmt.Errorf("%w: %#x != 0x60", ErrInvalidByte, b) } s, _, err := leb128.DecodeUint32(r) @@ -62,7 +61,7 @@ func decodeFunctionType(r io.Reader) (*wasm.FunctionType, error) { }, nil } -func decodeImportSection(r *bytes.Reader) ([]*wasm.Import, error) { +func decodeImportSection(r *bytes.Reader, features wasm.Features) ([]*wasm.Import, error) { vs, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("get size of vector: %w", err) @@ -70,8 +69,8 @@ func decodeImportSection(r *bytes.Reader) ([]*wasm.Import, error) { result := make([]*wasm.Import, vs) for i := uint32(0); i < vs; i++ { - if result[i], err = decodeImport(r); err != nil { - return nil, fmt.Errorf("read import: %w", err) + if result[i], err = decodeImport(r, i, features); err != nil { + return nil, err } } return result, nil @@ -122,7 +121,7 @@ func decodeMemorySection(r *bytes.Reader) ([]*wasm.MemoryType, error) { return result, nil } -func decodeGlobalSection(r *bytes.Reader) ([]*wasm.Global, error) { +func decodeGlobalSection(r *bytes.Reader, features wasm.Features) ([]*wasm.Global, error) { vs, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("get size of vector: %w", err) @@ -130,8 +129,8 @@ func decodeGlobalSection(r *bytes.Reader) ([]*wasm.Global, error) { result := make([]*wasm.Global, vs) for i := uint32(0); i < vs; i++ { - if result[i], err = decodeGlobal(r); err != nil { - return nil, fmt.Errorf("read global: %v ", err) + if result[i], err = decodeGlobal(r, features); err != nil { + return nil, fmt.Errorf("global[%d]: %w", i, err) } } return result, nil diff --git a/internal/wasm/binary/types.go b/internal/wasm/binary/types.go index e4c84d1a..49523a9e 100644 --- a/internal/wasm/binary/types.go +++ b/internal/wasm/binary/types.go @@ -1,20 +1,20 @@ package binary import ( + "bytes" "fmt" - "io" wasm "github.com/tetratelabs/wazero/internal/wasm" ) -func decodeTableType(r io.Reader) (*wasm.TableType, error) { - b := make([]byte, 1) - if _, err := io.ReadFull(r, b); err != nil { +func decodeTableType(r *bytes.Reader) (*wasm.TableType, error) { + b, err := r.ReadByte() + if err != nil { return nil, fmt.Errorf("read leading byte: %v", err) } - if b[0] != 0x70 { - return nil, fmt.Errorf("%w: invalid element type %#x != %#x", ErrInvalidByte, b[0], 0x70) + if b != 0x70 { + return nil, fmt.Errorf("%w: invalid element type %#x != %#x", ErrInvalidByte, b, 0x70) } lm, err := decodeLimitsType(r) @@ -28,7 +28,7 @@ func decodeTableType(r io.Reader) (*wasm.TableType, error) { }, nil } -func decodeGlobalType(r io.Reader) (*wasm.GlobalType, error) { +func decodeGlobalType(r *bytes.Reader, features wasm.Features) (*wasm.GlobalType, error) { vt, err := decodeValueTypes(r, 1) if err != nil { return nil, fmt.Errorf("read value type: %w", err) @@ -38,14 +38,17 @@ func decodeGlobalType(r io.Reader) (*wasm.GlobalType, error) { ValType: vt[0], } - b := make([]byte, 1) - if _, err := io.ReadFull(r, b); err != nil { + b, err := r.ReadByte() + if err != nil { return nil, fmt.Errorf("read mutablity: %w", err) } - switch mut := b[0]; mut { + switch mut := b; mut { case 0x00: case 0x01: + if err = features.Require(wasm.FeatureMutableGlobal); err != nil { + return nil, err + } ret.Mutable = true default: return nil, fmt.Errorf("%w for mutability: %#x != 0x00 or 0x01", ErrInvalidByte, mut) diff --git a/internal/wasm/binary/value.go b/internal/wasm/binary/value.go index 8856076e..947d8f4e 100644 --- a/internal/wasm/binary/value.go +++ b/internal/wasm/binary/value.go @@ -42,7 +42,7 @@ func encodeValTypes(vt []wasm.ValueType) []byte { return append(count, vt...) } -func decodeValueTypes(r io.Reader, num uint32) ([]wasm.ValueType, error) { +func decodeValueTypes(r *bytes.Reader, num uint32) ([]wasm.ValueType, error) { if num == 0 { return nil, nil } diff --git a/internal/wasm/features.go b/internal/wasm/features.go new file mode 100644 index 00000000..3e74e204 --- /dev/null +++ b/internal/wasm/features.go @@ -0,0 +1,55 @@ +package internalwasm + +import ( + "fmt" +) + +// Features are the currently enabled features. +// +// Note: This is a bit flag until we have too many (>63). Flags are simpler to manage in multiple places than a map. +type Features uint64 + +// Features20191205 include those finished in WebAssembly 1.0 (20191205). +// +// See https://github.com/WebAssembly/proposals/blob/main/finished-proposals.md +// See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205 +const Features20191205 = FeatureMutableGlobal + +const ( + // FeatureMutableGlobal decides if parsing should succeed on internalwasm.GlobalType Mutable + // See https://github.com/WebAssembly/mutable-global + FeatureMutableGlobal Features = 1 << iota +) + +// Set assigns the value for the given feature. +func (f Features) Set(feature Features, val bool) Features { + if val { + return f | feature + } + return f &^ feature +} + +// Get returns the value of the given feature. +func (f Features) Get(feature Features) bool { + return f&feature != 0 +} + +// Require fails with a configuration error if the given feature is not enabled +func (f Features) Require(feature Features) error { + if f&feature == 0 { + return fmt.Errorf("feature %s is disabled", feature) + } + return nil +} + +// String implements fmt.Stringer by returning each enabled feature. +func (f Features) String() string { + switch f { + case 0: + return "" + case FeatureMutableGlobal: + return "mutable-global" // match https://github.com/WebAssembly/mutable-global + default: + return "undefined" // TODO: when there are multiple features join known ones on pipe (|) + } +} diff --git a/internal/wasm/features_test.go b/internal/wasm/features_test.go new file mode 100644 index 00000000..12d62223 --- /dev/null +++ b/internal/wasm/features_test.go @@ -0,0 +1,94 @@ +package internalwasm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestFeatures_ZeroIsInvalid reminds maintainers that a bitset cannot use zero as a flag! +// This is why we start iota with 1. +func TestFeatures_ZeroIsInvalid(t *testing.T) { + f := Features(0) + f = f.Set(0, true) + require.False(t, f.Get(0)) +} + +// TestFeatures tests the bitset works as expected +func TestFeatures(t *testing.T) { + tests := []struct { + name string + feature Features + }{ + { + name: "one is the smallest flag", + feature: 1, + }, + { + name: "63 is the largest feature flag", // because uint64 + feature: 1 << 63, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + f := Features(0) + + // Defaults to false + require.False(t, f.Get(tc.feature)) + + // Set true makes it true + f = f.Set(tc.feature, true) + require.True(t, f.Get(tc.feature)) + + // Set false makes it false again + f = f.Set(tc.feature, false) + require.False(t, f.Get(tc.feature)) + }) + } +} + +func TestFeatures_String(t *testing.T) { + tests := []struct { + name string + feature Features + expected string + }{ + {name: "none", feature: 0, expected: ""}, + {name: "mutable-global", feature: FeatureMutableGlobal, expected: "mutable-global"}, + {name: "undefined", feature: 1 << 63, expected: "undefined"}, + } + + for _, tt := range tests { + tc := tt + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expected, tc.feature.String()) + }) + } +} + +func TestFeatures_Require(t *testing.T) { + tests := []struct { + name string + feature Features + expectedErr string + }{ + {name: "none", feature: 0, expectedErr: "feature mutable-global is disabled"}, + {name: "mutable-global", feature: FeatureMutableGlobal}, + {name: "undefined", feature: 1 << 63, expectedErr: "feature mutable-global is disabled"}, + } + + for _, tt := range tests { + tc := tt + t.Run(tc.name, func(t *testing.T) { + err := tc.feature.Require(FeatureMutableGlobal) + if tc.expectedErr != "" { + require.EqualError(t, err, tc.expectedErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index 59a4f361..a834c623 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -3,7 +3,6 @@ package internalwasm import ( "bytes" "fmt" - "io" "strings" "github.com/tetratelabs/wazero/internal/leb128" @@ -52,7 +51,7 @@ func validateFunction( return fmt.Errorf("unknown memory access") } pc++ - align, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + align, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read memory align: %v", err) } @@ -230,7 +229,7 @@ func validateFunction( } pc += num // offset - _, num, err = leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + _, num, err = leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read memory offset: %v", err) } @@ -240,7 +239,7 @@ func validateFunction( return fmt.Errorf("unknown memory access") } pc++ - val, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + val, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -261,14 +260,14 @@ func validateFunction( pc++ switch Opcode(op) { case OpcodeI32Const: - _, num, err := leb128.DecodeInt32(bytes.NewBuffer(body[pc:])) + _, num, err := leb128.DecodeInt32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read i32 immediate: %s", err) } pc += num - 1 valueTypeStack.push(ValueTypeI32) case OpcodeI64Const: - _, num, err := leb128.DecodeInt64(bytes.NewBuffer(body[pc:])) + _, num, err := leb128.DecodeInt64(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read i64 immediate: %v", err) } @@ -283,7 +282,7 @@ func validateFunction( } } else if OpcodeLocalGet <= op && op <= OpcodeGlobalSet { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -345,7 +344,7 @@ func validateFunction( } } else if op == OpcodeBr { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) } else if int(index) >= len(controlBloclStack) { @@ -367,7 +366,7 @@ func validateFunction( valueTypeStack.unreachable() } else if op == OpcodeBrIf { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) } else if int(index) >= len(controlBloclStack) { @@ -396,7 +395,7 @@ func validateFunction( } } else if op == OpcodeBrTable { pc++ - r := bytes.NewBuffer(body[pc:]) + r := bytes.NewReader(body[pc:]) nl, num, err := leb128.DecodeUint32(r) if err != nil { return fmt.Errorf("read immediate: %w", err) @@ -458,7 +457,7 @@ func validateFunction( valueTypeStack.unreachable() } else if op == OpcodeCall { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -477,7 +476,7 @@ func validateFunction( } } else if op == OpcodeCallIndirect { pc++ - typeIndex, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:])) + typeIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -702,7 +701,7 @@ func validateFunction( return fmt.Errorf("invalid numeric instruction 0x%x", op) } } else if op == OpcodeBlock { - bt, num, err := decodeBlockType(types, bytes.NewBuffer(body[pc+1:])) + bt, num, err := decodeBlockType(types, bytes.NewReader(body[pc+1:])) if err != nil { return fmt.Errorf("read block: %w", err) } @@ -714,7 +713,7 @@ func validateFunction( valueTypeStack.pushStackLimit() pc += num } else if op == OpcodeLoop { - bt, num, err := decodeBlockType(types, bytes.NewBuffer(body[pc+1:])) + bt, num, err := decodeBlockType(types, bytes.NewReader(body[pc+1:])) if err != nil { return fmt.Errorf("read block: %w", err) } @@ -727,7 +726,7 @@ func validateFunction( valueTypeStack.pushStackLimit() pc += num } else if op == OpcodeIf { - bt, num, err := decodeBlockType(types, bytes.NewBuffer(body[pc+1:])) + bt, num, err := decodeBlockType(types, bytes.NewReader(body[pc+1:])) if err != nil { return fmt.Errorf("read block: %w", err) } @@ -948,7 +947,7 @@ type controlBlock struct { isIf bool } -func decodeBlockType(types []*FunctionType, r io.Reader) (*FunctionType, uint64, error) { +func decodeBlockType(types []*FunctionType, r *bytes.Reader) (*FunctionType, uint64, error) { return decodeBlockTypeImpl(func(index int64) (*FunctionType, error) { if index < 0 || (index >= int64(len(types))) { return nil, fmt.Errorf("type index out of range: %d", index) @@ -958,7 +957,7 @@ func decodeBlockType(types []*FunctionType, r io.Reader) (*FunctionType, uint64, } // DecodeBlockType is exported for use in the compiler -func DecodeBlockType(types []*TypeInstance, r io.Reader) (*FunctionType, uint64, error) { +func DecodeBlockType(types []*TypeInstance, r *bytes.Reader) (*FunctionType, uint64, error) { return decodeBlockTypeImpl(func(index int64) (*FunctionType, error) { if index < 0 || (index >= int64(len(types))) { return nil, fmt.Errorf("type index out of range: %d", index) @@ -967,7 +966,7 @@ func DecodeBlockType(types []*TypeInstance, r io.Reader) (*FunctionType, uint64, }, r) } -func decodeBlockTypeImpl(functionTypeResolver func(index int64) (*FunctionType, error), r io.Reader) (*FunctionType, uint64, error) { +func decodeBlockTypeImpl(functionTypeResolver func(index int64) (*FunctionType, error), r *bytes.Reader) (*FunctionType, uint64, error) { raw, num, err := leb128.DecodeInt33AsInt64(r) if err != nil { return nil, 0, fmt.Errorf("decode int33: %w", err) diff --git a/internal/wasm/global_test.go b/internal/wasm/global_test.go index de9cde51..1c50713a 100644 --- a/internal/wasm/global_test.go +++ b/internal/wasm/global_test.go @@ -1,7 +1,6 @@ package internalwasm import ( - "context" gobinary "encoding/binary" "testing" @@ -257,7 +256,7 @@ func TestPublicModule_Global(t *testing.T) { for _, tt := range tests { tc := tt - s := NewStore(context.Background(), &catchContext{}) + s := newStore() t.Run(tc.name, func(t *testing.T) { // Instantiate the module and get the export of the above global module, err := s.Instantiate(tc.module, "") diff --git a/internal/wasm/interpreter/interpreter_test.go b/internal/wasm/interpreter/interpreter_test.go index 764085f1..1c7ae501 100644 --- a/internal/wasm/interpreter/interpreter_test.go +++ b/internal/wasm/interpreter/interpreter_test.go @@ -53,7 +53,7 @@ func TestEngine_Call(t *testing.T) { // Use exported functions to simplify instantiation of a Wasm function e := NewEngine() - store := wasm.NewStore(context.Background(), e) + store := wasm.NewStore(context.Background(), e, wasm.Features20191205) _, err := store.Instantiate(m, "") require.NoError(t, err) diff --git a/internal/wasm/jit/engine_test.go b/internal/wasm/jit/engine_test.go index 26d125fa..4c13980b 100644 --- a/internal/wasm/jit/engine_test.go +++ b/internal/wasm/jit/engine_test.go @@ -97,7 +97,7 @@ func TestEngine_Call(t *testing.T) { // Use exported functions to simplify instantiation of a Wasm function e := NewEngine() - store := wasm.NewStore(context.Background(), e) + store := wasm.NewStore(context.Background(), e, wasm.Features20191205) _, err := store.Instantiate(m, "") require.NoError(t, err) diff --git a/internal/wasm/module.go b/internal/wasm/module.go index 1a6f503a..f475ae08 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -16,7 +16,7 @@ import ( // * result is the module parsed or nil on error // * err is a FormatError invoking the parser, dangling block comments or unexpected characters. // See binary.DecodeModule and text.DecodeModule -type DecodeModule func(source []byte) (result *Module, err error) +type DecodeModule func(source []byte, features Features) (result *Module, err error) // EncodeModule encodes the given module into a byte slice depending on the format of the implementation. // See binary.EncodeModule @@ -322,7 +322,7 @@ func (m *Module) validateExports(functions []Index, globals []*GlobalType, memor func validateConstExpression(globals []*GlobalType, expr *ConstantExpression, expectedType ValueType) (err error) { var actualType ValueType - r := bytes.NewBuffer(expr.Data) + r := bytes.NewReader(expr.Data) switch expr.Opcode { case OpcodeI32Const: _, _, err = leb128.DecodeInt32(r) diff --git a/internal/wasm/store.go b/internal/wasm/store.go index abb2b26a..7cac2682 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -33,6 +33,9 @@ type ( // Engine is a global context for a Store which is in responsible for compilation and execution of Wasm modules. Engine Engine + // EnabledFeatures are read-only to allow optimizations. + EnabledFeatures Features + // ModuleInstances holds the instantiated Wasm modules by module name from Instantiate. ModuleInstances map[string]*ModuleInstance @@ -381,13 +384,14 @@ func (m *ModuleInstance) GetExport(name string, et ExternType) (*ExportInstance, return exp, nil } -func NewStore(ctx context.Context, engine Engine) *Store { +func NewStore(ctx context.Context, engine Engine, enabledFeatures Features) *Store { return &Store{ ctx: ctx, + Engine: engine, + EnabledFeatures: enabledFeatures, ModuleInstances: map[string]*ModuleInstance{}, ModuleContexts: map[string]*ModuleContext{}, TypeIDs: map[string]FunctionTypeID{}, - Engine: engine, maximumFunctionIndex: maximumFunctionIndex, maximumFunctionTypes: maximumFunctionTypes, } @@ -761,7 +765,7 @@ func (s *Store) resolveImports(module *Module) ( } func executeConstExpression(globals []*GlobalInstance, expr *ConstantExpression) (v interface{}) { - r := bytes.NewBuffer(expr.Data) + r := bytes.NewReader(expr.Data) switch expr.Opcode { case OpcodeI32Const: v, _, _ = leb128.DecodeInt32(r) @@ -829,6 +833,11 @@ func (s *Store) AddHostFunction(m *ModuleInstance, hf *GoFunc) (*FunctionInstanc } func (s *Store) AddGlobal(m *ModuleInstance, name string, value uint64, valueType ValueType, mutable bool) error { + if mutable { + if err := s.EnabledFeatures.Require(FeatureMutableGlobal); err != nil { + return err + } + } g := &GlobalInstance{ Val: value, Type: &GlobalType{Mutable: mutable, ValType: valueType}, diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 0c306b76..98978df1 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -3,6 +3,7 @@ package internalwasm import ( "context" "encoding/binary" + "fmt" "math" "os" "strconv" @@ -71,7 +72,7 @@ func TestModuleInstance_Memory(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() instance, err := s.Instantiate(tc.input, "test") require.NoError(t, err) @@ -87,7 +88,7 @@ func TestModuleInstance_Memory(t *testing.T) { } func TestStore_AddHostFunction(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() hf, err := NewGoFunc("fn", func(wasm.ModuleContext) {}) require.NoError(t, err) @@ -119,7 +120,7 @@ func TestStore_AddHostFunction(t *testing.T) { } func TestStore_ExportImportedHostFunction(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() hf, err := NewGoFunc("host_fn", func(wasm.ModuleContext) {}) require.NoError(t, err) @@ -177,7 +178,7 @@ func TestFunctionInstance_Call(t *testing.T) { t.Run(tc.name, func(t *testing.T) { engine := &catchContext{} - store := NewStore(storeCtx, engine) + store := NewStore(storeCtx, engine, Features20191205) // Define a fake host function functionName := "fn" @@ -237,14 +238,14 @@ func (e *catchContext) Release(_ *FunctionInstance) error { func TestStore_checkFuncAddrOverflow(t *testing.T) { t.Run("too many functions", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() const max = 10 s.maximumFunctionIndex = max err := s.checkFunctionIndexOverflow(max + 1) require.Error(t, err) }) t.Run("ok", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() const max = 10 s.maximumFunctionIndex = max err := s.checkFunctionIndexOverflow(max) @@ -254,7 +255,7 @@ func TestStore_checkFuncAddrOverflow(t *testing.T) { func TestStore_getTypeInstance(t *testing.T) { t.Run("too many functions", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() const max = 10 s.maximumFunctionTypes = max s.TypeIDs = make(map[string]FunctionTypeID) @@ -273,7 +274,7 @@ func TestStore_getTypeInstance(t *testing.T) { } { tc := tc t.Run(tc.String(), func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() actual, err := s.getTypeInstance(tc) require.NoError(t, err) @@ -372,12 +373,56 @@ func TestExecuteConstExpression(t *testing.T) { }) } +func TestStore_AddGlobal(t *testing.T) { + tests := []struct { + name string + enabledFeatures Features + mutable bool + expectedErr string + }{ + { + name: fmt.Sprintf("immutable when %s enabled", FeatureMutableGlobal), + enabledFeatures: Features20191205, + }, + { + name: fmt.Sprintf("immutable when %s disabled", FeatureMutableGlobal), + enabledFeatures: Features20191205.Set(FeatureMutableGlobal, false), + }, + { + name: fmt.Sprintf("mutable when %s enabled", FeatureMutableGlobal), + enabledFeatures: Features20191205, + }, + { + name: fmt.Sprintf("mutable when %s disabled", FeatureMutableGlobal), + enabledFeatures: Features20191205.Set(FeatureMutableGlobal, false), + }, + } + + for _, tt := range tests { + tc := tt + t.Run(tc.name, func(t *testing.T) { + s := NewStore(context.Background(), &catchContext{}, tc.enabledFeatures) + m := &ModuleInstance{Exports: map[string]*ExportInstance{}} + + err := s.AddGlobal(m, "test", 1, ValueTypeI64, tc.mutable) + if tc.expectedErr == "" { + require.NoError(t, err) + e, err := m.GetExport("test", ExternTypeGlobal) + require.NoError(t, err) + require.NotNil(t, e.Global) + } else { + require.EqualError(t, err, tc.expectedErr) + } + }) + } +} + func TestStore_releaseModuleInstance(t *testing.T) { // TODO: } func TestStore_releaseFunctionInstances(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() nonReleaseTargetAddr := FunctionIndex(0) maxAddr := FunctionIndex(10) s.Functions = make([]*FunctionInstance, maxAddr+1) @@ -408,7 +453,7 @@ func TestStore_releaseFunctionInstances(t *testing.T) { func TestStore_addFunctionInstances(t *testing.T) { t.Run("no released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() prevMaxAddr := FunctionIndex(10) s.Functions = make([]*FunctionInstance, prevMaxAddr+1) @@ -422,7 +467,7 @@ func TestStore_addFunctionInstances(t *testing.T) { } }) t.Run("reuse released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() expectedAddr := FunctionIndex(10) s.releasedFunctionIndex = []FunctionIndex{1, expectedAddr} expectedReleasedAddr := s.releasedFunctionIndex[:1] @@ -446,7 +491,7 @@ func TestStore_addFunctionInstances(t *testing.T) { } func TestStore_releaseGlobalInstances(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() nonReleaseTargetAddr := globalIndex(0) maxAddr := globalIndex(10) s.Globals = make([]*GlobalInstance, maxAddr+1) @@ -476,7 +521,7 @@ func TestStore_releaseGlobalInstances(t *testing.T) { func TestStore_addGlobalInstances(t *testing.T) { t.Run("no released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() prevMaxAddr := globalIndex(10) s.Globals = make([]*GlobalInstance, prevMaxAddr+1) @@ -490,7 +535,7 @@ func TestStore_addGlobalInstances(t *testing.T) { } }) t.Run("reuse released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() expectedAddr := globalIndex(10) s.releasedGlobalIndex = []globalIndex{1, expectedAddr} expectedReleasedGlobalIndex := s.releasedGlobalIndex[:1] @@ -514,7 +559,7 @@ func TestStore_addGlobalInstances(t *testing.T) { } func TestStore_releaseTableInstance(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() nonReleaseTargetAddr := tableIndex(0) maxAddr := tableIndex(10) s.Tables = make([]*TableInstance, maxAddr+1) @@ -536,7 +581,7 @@ func TestStore_releaseTableInstance(t *testing.T) { func TestStore_addTableInstance(t *testing.T) { t.Run("no released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() prevMaxAddr := tableIndex(10) s.Tables = make([]*TableInstance, prevMaxAddr+1) @@ -550,7 +595,7 @@ func TestStore_addTableInstance(t *testing.T) { } }) t.Run("reuse released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() expectedAddr := tableIndex(10) s.releasedTableIndex = []tableIndex{1000, expectedAddr} expectedReleasedTableIndex := s.releasedTableIndex[:1] @@ -574,7 +619,7 @@ func TestStore_addTableInstance(t *testing.T) { } func TestStore_releaseMemoryInstance(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() nonReleaseTargetAddr := memoryIndex(0) releaseTargetAddr := memoryIndex(10) s.Memories = make([]*MemoryInstance, releaseTargetAddr+1) @@ -596,7 +641,7 @@ func TestStore_releaseMemoryInstance(t *testing.T) { func TestStore_addMemoryInstance(t *testing.T) { t.Run("no released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() prevMaxAddr := memoryIndex(10) s.Memories = make([]*MemoryInstance, prevMaxAddr+1) @@ -610,7 +655,7 @@ func TestStore_addMemoryInstance(t *testing.T) { } }) t.Run("reuse released index", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() expectedAddr := memoryIndex(10) s.releasedMemoryIndex = []memoryIndex{1000, expectedAddr} expectedReleasedMemoryIndex := s.releasedMemoryIndex[:1] @@ -638,13 +683,13 @@ func TestStore_resolveImports(t *testing.T) { const name = "target" t.Run("module not instantiated", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: "unknown", Name: "unknown"}}}) require.Error(t, err) require.Contains(t, err.Error(), "module \"unknown\" not instantiated") }) t.Run("export instance not found", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{}, Name: moduleName} _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: "unknown"}}}) require.Error(t, err) @@ -652,14 +697,14 @@ func TestStore_resolveImports(t *testing.T) { }) t.Run("func", func(t *testing.T) { t.Run("unknwon type", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: {}}, Name: moduleName} _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeFunc, DescFunc: 100}}}) require.Error(t, err) require.Contains(t, err.Error(), "unknown type for function import") }) t.Run("signature mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Function: &FunctionInstance{FunctionType: &TypeInstance{Type: &FunctionType{}}}, }}, Name: moduleName} @@ -672,7 +717,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "signature mimatch: null_f32 != null_null") }) t.Run("ok", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() f := &FunctionInstance{FunctionType: &TypeInstance{Type: &FunctionType{Results: []ValueType{ValueTypeF32}}}} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Function: f, @@ -689,7 +734,7 @@ func TestStore_resolveImports(t *testing.T) { }) t.Run("global", func(t *testing.T) { t.Run("mutability mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Type: ExternTypeGlobal, Global: &GlobalInstance{Type: &GlobalType{Mutable: false}}, @@ -699,7 +744,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible global import: mutability mismatch") }) t.Run("type mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Type: ExternTypeGlobal, Global: &GlobalInstance{Type: &GlobalType{ValType: ValueTypeI32}}, @@ -709,7 +754,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible global import: value type mismatch") }) t.Run("ok", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() inst := &GlobalInstance{Type: &GlobalType{ValType: ValueTypeI32}} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: {Type: ExternTypeGlobal, Global: inst}}, Name: moduleName} _, globals, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: inst.Type}}}) @@ -719,7 +764,7 @@ func TestStore_resolveImports(t *testing.T) { }) t.Run("table", func(t *testing.T) { t.Run("element type", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Type: ExternTypeTable, Table: &TableInstance{ElemType: 0x00}, // Unknown! @@ -729,7 +774,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible table improt: element type mismatch") }) t.Run("minimum size mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() importTableType := &TableType{Limit: &LimitsType{Min: 2}} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Type: ExternTypeTable, @@ -740,7 +785,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible table import: minimum size mismatch") }) t.Run("maximum size mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() max := uint32(10) importTableType := &TableType{Limit: &LimitsType{Max: &max}} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { @@ -752,7 +797,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible table import: maximum size mismatch") }) t.Run("ok", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() max := uint32(10) tableInst := &TableInstance{Max: &max} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { @@ -766,7 +811,7 @@ func TestStore_resolveImports(t *testing.T) { }) t.Run("memory", func(t *testing.T) { t.Run("minimum size mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() importMemoryType := &MemoryType{Min: 2} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { Type: ExternTypeMemory, @@ -777,7 +822,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible memory import: minimum size mismatch") }) t.Run("maximum size mismatch", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() max := uint32(10) importMemoryType := &MemoryType{Max: &max} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { @@ -789,7 +834,7 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, err.Error(), "incompatible memory import: maximum size mismatch") }) t.Run("ok", func(t *testing.T) { - s := NewStore(context.Background(), &catchContext{}) + s := newStore() max := uint32(10) memoryInst := &MemoryInstance{Max: &max} s.ModuleInstances[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { @@ -924,3 +969,7 @@ func TestModuleInstance_applyElements(t *testing.T) { func Test_newModuleInstance(t *testing.T) { } + +func newStore() *Store { + return NewStore(context.Background(), &catchContext{}, Features20191205) +} diff --git a/internal/wasm/text/decoder.go b/internal/wasm/text/decoder.go index 1aca36ab..6a00d079 100644 --- a/internal/wasm/text/decoder.go +++ b/internal/wasm/text/decoder.go @@ -97,9 +97,11 @@ type moduleParser struct { fieldCountFunc, fieldCountMemory uint32 } -// DecodeModule implements wasm.DecodeModule for the WebAssembly 1.0 (20191205) Text Format +// DecodeModule implements internalwasm.DecodeModule for the WebAssembly 1.0 (20191205) Text Format // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#text-format%E2%91%A0 -func DecodeModule(source []byte) (result *wasm.Module, err error) { +func DecodeModule(source []byte, _ wasm.Features) (result *wasm.Module, err error) { + // TODO: when globals are supported, err on mutable globals if disabled + // names are the wasm.Module NameSection // // * ModuleName: ex. "test" if (module $test) diff --git a/internal/wasm/text/decoder_test.go b/internal/wasm/text/decoder_test.go index 6f8efc6f..b64399f0 100644 --- a/internal/wasm/text/decoder_test.go +++ b/internal/wasm/text/decoder_test.go @@ -1493,7 +1493,7 @@ func TestDecodeModule(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - m, err := DecodeModule([]byte(tc.input)) + m, err := DecodeModule([]byte(tc.input), wasm.Features20191205) require.NoError(t, err) require.Equal(t, tc.expected, m) }) @@ -2080,7 +2080,7 @@ func TestParseModule_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - _, err := DecodeModule([]byte(tc.input)) + _, err := DecodeModule([]byte(tc.input), wasm.Features20191205) require.EqualError(t, err, tc.expectedErr) }) } diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index ffb24fbc..60fd474d 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -221,7 +221,7 @@ operatorSwitch: // Nop is noop! case wasm.OpcodeBlock: bt, num, err := wasm.DecodeBlockType(c.f.ModuleInstance.Types, - bytes.NewBuffer(c.f.Body[c.pc+1:])) + bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("reading block type for block instruction: %w", err) } @@ -247,7 +247,7 @@ operatorSwitch: case wasm.OpcodeLoop: bt, num, err := wasm.DecodeBlockType(c.f.ModuleInstance.Types, - bytes.NewBuffer(c.f.Body[c.pc+1:])) + bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("reading block type for loop instruction: %w", err) } @@ -285,7 +285,7 @@ operatorSwitch: case wasm.OpcodeIf: bt, num, err := wasm.DecodeBlockType(c.f.ModuleInstance.Types, - bytes.NewBuffer(c.f.Body[c.pc+1:])) + bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("reading block type for if instruction: %w", err) } @@ -468,7 +468,7 @@ operatorSwitch: } case wasm.OpcodeBr: - targetIndex, n, err := leb128.DecodeUint32(bytes.NewBuffer(c.f.Body[c.pc+1:])) + targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("read the target for br_if: %w", err) } @@ -488,7 +488,7 @@ operatorSwitch: // and can be safely removed. c.markUnreachable() case wasm.OpcodeBrIf: - targetIndex, n, err := leb128.DecodeUint32(bytes.NewBuffer(c.f.Body[c.pc+1:])) + targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("read the target for br_if: %w", err) } @@ -513,7 +513,7 @@ operatorSwitch: }, ) case wasm.OpcodeBrTable: - r := bytes.NewBuffer(c.f.Body[c.pc+1:]) + r := bytes.NewReader(c.f.Body[c.pc+1:]) numTargets, n, err := leb128.DecodeUint32(r) if err != nil { return fmt.Errorf("error reading number of targets in br_table: %w", err) @@ -585,7 +585,7 @@ operatorSwitch: if index == nil { return fmt.Errorf("index does not exist for indirect function call") } - tableIndex, n, err := leb128.DecodeUint32(bytes.NewBuffer(c.f.Body[c.pc+1:])) + tableIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("read target for br_table: %w", err) } @@ -841,7 +841,7 @@ operatorSwitch: &OperationMemoryGrow{}, ) case wasm.OpcodeI32Const: - val, num, err := leb128.DecodeInt32(bytes.NewBuffer(c.f.Body[c.pc+1:])) + val, num, err := leb128.DecodeInt32(bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -850,7 +850,7 @@ operatorSwitch: &OperationConstI32{Value: uint32(val)}, ) case wasm.OpcodeI64Const: - val, num, err := leb128.DecodeInt64(bytes.NewBuffer(c.f.Body[c.pc+1:])) + val, num, err := leb128.DecodeInt64(bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return fmt.Errorf("reading i64.const value: %v", err) } @@ -1392,7 +1392,7 @@ func (c *compiler) applyToStack(opcode wasm.Opcode) (*uint32, error) { wasm.OpcodeGlobalGet, wasm.OpcodeGlobalSet: // Assumes that we are at the opcode now so skip it before read immediates. - v, num, err := leb128.DecodeUint32(bytes.NewBuffer(c.f.Body[c.pc+1:])) + v, num, err := leb128.DecodeUint32(bytes.NewReader(c.f.Body[c.pc+1:])) if err != nil { return nil, fmt.Errorf("reading immediates: %w", err) } @@ -1530,7 +1530,7 @@ func (c *compiler) getFrameDropRange(frame *controlFrame) *InclusiveRange { } func (c *compiler) readMemoryImmediate(tag string) (*MemoryImmediate, error) { - r := bytes.NewBuffer(c.f.Body[c.pc+1:]) + r := bytes.NewReader(c.f.Body[c.pc+1:]) alignment, num, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("reading alignment for %s: %w", tag, err) diff --git a/tests/spectest/spec_test.go b/tests/spectest/spec_test.go index f21a3b3f..3df58c5f 100644 --- a/tests/spectest/spec_test.go +++ b/tests/spectest/spec_test.go @@ -269,7 +269,7 @@ func runTest(t *testing.T, newEngine func() wasm.Engine) { wastName := basename(base.SourceFile) t.Run(wastName, func(t *testing.T) { - store := wasm.NewStore(context.Background(), newEngine()) + store := wasm.NewStore(context.Background(), newEngine(), wasm.Features20191205) addSpectestModule(t, store) var lastInstanceName string @@ -281,7 +281,7 @@ func runTest(t *testing.T, newEngine func() wasm.Engine) { buf, err := testcases.ReadFile(testdataPath(c.Filename)) require.NoError(t, err, msg) - mod, err := binary.DecodeModule(buf) + mod, err := binary.DecodeModule(buf, wasm.Features20191205) require.NoError(t, err, msg) lastInstanceName = c.Name @@ -414,7 +414,7 @@ func runTest(t *testing.T, newEngine func() wasm.Engine) { } func requireInstantiationError(t *testing.T, store *wasm.Store, buf []byte, msg string) { - mod, err := binary.DecodeModule(buf) + mod, err := binary.DecodeModule(buf, store.EnabledFeatures) if err != nil { return } diff --git a/vs/codec_test.go b/vs/codec_test.go index 5393c507..fce9f373 100644 --- a/vs/codec_test.go +++ b/vs/codec_test.go @@ -91,13 +91,13 @@ func newExample() *wasm.Module { func TestExampleUpToDate(t *testing.T) { t.Run("binary.DecodeModule", func(t *testing.T) { - m, err := binary.DecodeModule(exampleBinary) + m, err := binary.DecodeModule(exampleBinary, wasm.Features20191205) require.NoError(t, err) require.Equal(t, example, m) }) t.Run("text.DecodeModule", func(t *testing.T) { - m, err := text.DecodeModule(exampleText) + m, err := text.DecodeModule(exampleText, wasm.Features20191205) require.NoError(t, err) require.Equal(t, example, m) }) @@ -124,7 +124,7 @@ func BenchmarkCodecExample(b *testing.B) { b.Run("binary.DecodeModule", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - if _, err := binary.DecodeModule(exampleBinary); err != nil { + if _, err := binary.DecodeModule(exampleBinary, wasm.Features20191205); err != nil { b.Fatal(err) } } @@ -138,7 +138,7 @@ func BenchmarkCodecExample(b *testing.B) { b.Run("text.DecodeModule", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - if _, err := text.DecodeModule(exampleText); err != nil { + if _, err := text.DecodeModule(exampleText, wasm.Features20191205); err != nil { b.Fatal(err) } } @@ -146,7 +146,7 @@ func BenchmarkCodecExample(b *testing.B) { b.Run("wat2wasm via text.DecodeModule->binary.EncodeModule", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - if m, err := text.DecodeModule(exampleText); err != nil { + if m, err := text.DecodeModule(exampleText, wasm.Features20191205); err != nil { b.Fatal(err) } else { _ = binary.EncodeModule(m) diff --git a/wasm.go b/wasm.go index 22e9b744..3578e66d 100644 --- a/wasm.go +++ b/wasm.go @@ -65,12 +65,16 @@ func NewRuntime() Runtime { // NewRuntimeWithConfig returns a runtime with the given configuration. func NewRuntimeWithConfig(config *RuntimeConfig) Runtime { - return &runtime{store: internalwasm.NewStore(config.ctx, config.engine)} + return &runtime{ + store: internalwasm.NewStore(config.ctx, config.engine, config.enabledFeatures), + enabledFeatures: config.enabledFeatures, + } } // runtime allows decoupling of public interfaces from internal representation. type runtime struct { - store *internalwasm.Store + store *internalwasm.Store + enabledFeatures internalwasm.Features } // Module implements wasm.Store Module @@ -101,7 +105,7 @@ func (r *runtime) DecodeModule(source []byte) (*DecodedModule, error) { decoder = text.DecodeModule } - internal, err := decoder(source) + internal, err := decoder(source, r.enabledFeatures) if err != nil { return nil, err } else if err = internal.Validate(); err != nil {