Adds ability to disable mutable globals and improves decode perf (#315)

This adds `RuntimeConfig.WithFeatureMutableGlobal(enabled bool)`, which
allows disabling of mutable globals. When disabled, any attempt to add a
mutable global, either explicitly or implicitly via decoding wasm will
fail.

To support this, there's a new `Features` bitflag that can allow up to
63 feature toggles without passing structs.

While here, I fixed a significant performance problem in decoding
binary:

Before
```
BenchmarkCodecExample/binary.DecodeModule-16         	  184243	      5623 ns/op	    3848 B/op	     184 allocs/op
```

Now
```
BenchmarkCodecExample/binary.DecodeModule-16         	  294084	      3520 ns/op	    2176 B/op	      91 allocs/op

```

Signed-off-by: Adrian Cole <adrian@tetrate.io>
This commit is contained in:
Crypt Keeper
2022-03-03 10:31:10 +08:00
committed by GitHub
parent 490f830096
commit cfb11f352a
33 changed files with 519 additions and 221 deletions

View File

@@ -1,6 +1,7 @@
package binary
import (
"bytes"
"fmt"
"io"
"math"
@@ -9,37 +10,43 @@ import (
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeCode(r io.Reader) (*wasm.Code, error) {
func decodeCode(r *bytes.Reader) (*wasm.Code, error) {
ss, _, err := leb128.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("get the size of code: %w", err)
}
r = io.LimitReader(r, int64(ss))
remaining := int64(ss)
// parse locals
ls, _, err := leb128.DecodeUint32(r)
ls, bytesRead, err := leb128.DecodeUint32(r)
remaining -= int64(bytesRead)
if err != nil {
return nil, fmt.Errorf("get the size locals: %v", err)
} else if remaining < 0 {
return nil, io.EOF
}
var nums []uint64
var types []wasm.ValueType
var sum uint64
b := make([]byte, 1)
var n uint32
for i := uint32(0); i < ls; i++ {
n, _, err := leb128.DecodeUint32(r)
n, bytesRead, err = leb128.DecodeUint32(r)
remaining -= int64(bytesRead) + 1 // +1 for the subsequent ReadByte
if err != nil {
return nil, fmt.Errorf("read n of locals: %v", err)
} else if remaining < 0 {
return nil, io.EOF
}
sum += uint64(n)
nums = append(nums, uint64(n))
_, err = io.ReadFull(r, b)
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("read type of local: %v", err)
}
switch vt := b[0]; vt {
switch vt := b; vt {
case wasm.ValueTypeI32, wasm.ValueTypeF32, wasm.ValueTypeI64, wasm.ValueTypeF64:
types = append(types, vt)
default:
@@ -59,8 +66,8 @@ func decodeCode(r io.Reader) (*wasm.Code, error) {
}
}
body, err := io.ReadAll(r)
if err != nil {
body := make([]byte, remaining)
if _, err = io.ReadFull(r, body); err != nil {
return nil, fmt.Errorf("read body: %w", err)
}

View File

@@ -3,52 +3,53 @@ package binary
import (
"bytes"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/ieee754"
"github.com/tetratelabs/wazero/internal/leb128"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeConstantExpression(r io.Reader) (*wasm.ConstantExpression, error) {
b := make([]byte, 1)
_, err := io.ReadFull(r, b)
func decodeConstantExpression(r *bytes.Reader) (*wasm.ConstantExpression, error) {
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("read opcode: %v", err)
}
buf := new(bytes.Buffer)
teeR := io.TeeReader(r, buf)
opcode := b[0]
remainingBeforeData := int64(r.Len())
offsetAtData := r.Size() - remainingBeforeData
opcode := b
switch opcode {
case wasm.OpcodeI32Const:
_, _, err = leb128.DecodeInt32(teeR)
_, _, err = leb128.DecodeInt32(r)
case wasm.OpcodeI64Const:
_, _, err = leb128.DecodeInt64(teeR)
_, _, err = leb128.DecodeInt64(r)
case wasm.OpcodeF32Const:
_, err = ieee754.DecodeFloat32(teeR)
_, err = ieee754.DecodeFloat32(r)
case wasm.OpcodeF64Const:
_, err = ieee754.DecodeFloat64(teeR)
_, err = ieee754.DecodeFloat64(r)
case wasm.OpcodeGlobalGet:
_, _, err = leb128.DecodeUint32(teeR)
_, _, err = leb128.DecodeUint32(r)
default:
return nil, fmt.Errorf("%v for const expression opt code: %#x", ErrInvalidByte, b[0])
return nil, fmt.Errorf("%v for const expression opt code: %#x", ErrInvalidByte, b)
}
if err != nil {
return nil, fmt.Errorf("read value: %v", err)
}
if _, err := io.ReadFull(r, b); err != nil {
if b, err = r.ReadByte(); err != nil {
return nil, fmt.Errorf("look for end opcode: %v", err)
}
if b[0] != byte(wasm.OpcodeEnd) {
if b != wasm.OpcodeEnd {
return nil, fmt.Errorf("constant expression has been not terminated")
}
return &wasm.ConstantExpression{
Opcode: opcode,
Data: buf.Bytes(),
}, nil
data := make([]byte, remainingBeforeData-int64(r.Len()))
if _, err := r.ReadAt(data, offsetAtData); err != nil {
return nil, fmt.Errorf("error re-buffering ConstantExpression.Data")
}
return &wasm.ConstantExpression{Opcode: opcode, Data: data}, nil
}

View File

@@ -1,6 +1,7 @@
package binary
import (
"bytes"
"fmt"
"io"
@@ -8,7 +9,7 @@ import (
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeDataSegment(r io.Reader) (*wasm.DataSegment, error) {
func decodeDataSegment(r *bytes.Reader) (*wasm.DataSegment, error) {
d, _, err := leb128.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("read memory index: %v", err)

View File

@@ -9,9 +9,9 @@ import (
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
// DecodeModule implements wasm.DecodeModule for the WebAssembly 1.0 (20191205) Binary Format
// DecodeModule implements internalwasm.DecodeModule for the WebAssembly 1.0 (20191205) Binary Format
// See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-format%E2%91%A0
func DecodeModule(binary []byte) (*wasm.Module, error) {
func DecodeModule(binary []byte, features wasm.Features) (*wasm.Module, error) {
r := bytes.NewReader(binary)
// Magic number.
@@ -71,7 +71,9 @@ func DecodeModule(binary []byte) (*wasm.Module, error) {
case wasm.SectionIDType:
m.TypeSection, err = decodeTypeSection(r)
case wasm.SectionIDImport:
m.ImportSection, err = decodeImportSection(r)
if m.ImportSection, err = decodeImportSection(r, features); err != nil {
return nil, err // avoid re-wrapping the error.
}
case wasm.SectionIDFunction:
m.FunctionSection, err = decodeFunctionSection(r)
case wasm.SectionIDTable:
@@ -79,7 +81,9 @@ func DecodeModule(binary []byte) (*wasm.Module, error) {
case wasm.SectionIDMemory:
m.MemorySection, err = decodeMemorySection(r)
case wasm.SectionIDGlobal:
m.GlobalSection, err = decodeGlobalSection(r)
if m.GlobalSection, err = decodeGlobalSection(r, features); err != nil {
return nil, err // avoid re-wrapping the error.
}
case wasm.SectionIDExport:
m.ExportSection, err = decodeExportSection(r)
case wasm.SectionIDStart:

View File

@@ -1,6 +1,7 @@
package binary
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
@@ -87,7 +88,7 @@ func TestDecodeModule(t *testing.T) {
tc := tt
t.Run(tc.name, func(t *testing.T) {
m, e := DecodeModule(EncodeModule(tc.input))
m, e := DecodeModule(EncodeModule(tc.input), wasm.Features20191205)
require.NoError(t, e)
require.Equal(t, tc.input, m)
})
@@ -97,7 +98,7 @@ func TestDecodeModule(t *testing.T) {
wasm.SectionIDCustom, 0xf, // 15 bytes in this section
0x04, 'm', 'e', 'm', 'e',
1, 2, 3, 4, 5, 6, 7, 8, 9, 0)
m, e := DecodeModule(input)
m, e := DecodeModule(input, wasm.Features20191205)
require.NoError(t, e)
require.Equal(t, &wasm.Module{}, m)
})
@@ -111,7 +112,7 @@ func TestDecodeModule(t *testing.T) {
subsectionIDModuleName, 0x07, // 7 bytes in this subsection
0x06, // the Module name simple is 6 bytes long
's', 'i', 'm', 'p', 'l', 'e')
m, e := DecodeModule(input)
m, e := DecodeModule(input, wasm.Features20191205)
require.NoError(t, e)
require.Equal(t, &wasm.Module{NameSection: &wasm.NameSection{ModuleName: "simple"}}, m)
})
@@ -121,6 +122,7 @@ func TestDecodeModule_Errors(t *testing.T) {
tests := []struct {
name string
input []byte
features wasm.Features
expectedErr string
}{
{
@@ -144,13 +146,33 @@ func TestDecodeModule_Errors(t *testing.T) {
subsectionIDModuleName, 0x02, 0x01, 'x'),
expectedErr: "section custom: redundant custom section name",
},
{
name: fmt.Sprintf("define mutable global when %s disabled", wasm.FeatureMutableGlobal),
features: wasm.Features20191205.Set(wasm.FeatureMutableGlobal, false),
input: append(append(Magic, version...),
wasm.SectionIDGlobal, 0x06, // 6 bytes in this section
0x01, wasm.ValueTypeI32, 0x01, // 1 global i32 mutable
wasm.OpcodeI32Const, 0x00, wasm.OpcodeEnd, // arbitrary init to zero
),
expectedErr: "global[0]: feature mutable-global is disabled",
},
{
name: fmt.Sprintf("import mutable global when %s disabled", wasm.FeatureMutableGlobal),
features: wasm.Features20191205.Set(wasm.FeatureMutableGlobal, false),
input: append(append(Magic, version...),
wasm.SectionIDImport, 0x08, // 8 bytes in this section
0x01, 0x01, 'a', 0x01, 'b', wasm.ExternTypeGlobal, // 1 import a.b of type global
wasm.ValueTypeI32, 0x01, // 1 global i32 mutable
),
expectedErr: "import[0] global[a.b]: feature mutable-global is disabled",
},
}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
_, e := DecodeModule(tc.input)
_, e := DecodeModule(tc.input, tc.features)
require.EqualError(t, e, tc.expectedErr)
})
}

View File

@@ -1,14 +1,14 @@
package binary
import (
"bytes"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/leb128"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeElementSegment(r io.Reader) (*wasm.ElementSegment, error) {
func decodeElementSegment(r *bytes.Reader) (*wasm.ElementSegment, error) {
ti, _, err := leb128.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("get table index: %w", err)

View File

@@ -3,7 +3,6 @@ package binary
import (
"bytes"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/leb128"
wasm "github.com/tetratelabs/wazero/internal/wasm"
@@ -16,19 +15,19 @@ func decodeExport(r *bytes.Reader) (i *wasm.Export, err error) {
return nil, err
}
b := make([]byte, 1)
if _, err = io.ReadFull(r, b); err != nil {
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("error decoding export kind: %w", err)
}
i.Type = b[0]
i.Type = b
switch i.Type {
case wasm.ExternTypeFunc, wasm.ExternTypeTable, wasm.ExternTypeMemory, wasm.ExternTypeGlobal:
if i.Index, _, err = leb128.DecodeUint32(r); err != nil {
return nil, fmt.Errorf("error decoding export index: %w", err)
}
default:
return nil, fmt.Errorf("%w: invalid byte for exportdesc: %#x", ErrInvalidByte, b[0])
return nil, fmt.Errorf("%w: invalid byte for exportdesc: %#x", ErrInvalidByte, b)
}
return
}

View File

@@ -1,25 +1,21 @@
package binary
import (
"fmt"
"io"
"bytes"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeGlobal(r io.Reader) (*wasm.Global, error) {
gt, err := decodeGlobalType(r)
func decodeGlobal(r *bytes.Reader, features wasm.Features) (*wasm.Global, error) {
gt, err := decodeGlobalType(r, features)
if err != nil {
return nil, fmt.Errorf("read global type: %v", err)
return nil, err
}
init, err := decodeConstantExpression(r)
if err != nil {
return nil, fmt.Errorf("get init expression: %v", err)
return nil, err
}
return &wasm.Global{
Type: gt,
Init: init,
}, nil
return &wasm.Global{Type: gt, Init: init}, nil
}

View File

@@ -3,47 +3,40 @@ package binary
import (
"bytes"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/leb128"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeImport(r *bytes.Reader) (i *wasm.Import, err error) {
func decodeImport(r *bytes.Reader, idx uint32, features wasm.Features) (i *wasm.Import, err error) {
i = &wasm.Import{}
if i.Module, _, err = decodeUTF8(r, "import module"); err != nil {
return nil, err
return nil, fmt.Errorf("import[%d] error decoding module: %w", idx, err)
}
if i.Name, _, err = decodeUTF8(r, "import name"); err != nil {
return nil, err
return nil, fmt.Errorf("import[%d] error decoding name: %w", idx, err)
}
b := make([]byte, 1)
if _, err = io.ReadFull(r, b); err != nil {
return nil, fmt.Errorf("error decoding import kind: %w", err)
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("import[%d] error decoding type: %w", idx, err)
}
i.Type = b[0]
i.Type = b
switch i.Type {
case wasm.ExternTypeFunc:
if i.DescFunc, _, err = leb128.DecodeUint32(r); err != nil {
return nil, fmt.Errorf("error decoding import func typeindex: %w", err)
}
i.DescFunc, _, err = leb128.DecodeUint32(r)
case wasm.ExternTypeTable:
if i.DescTable, err = decodeTableType(r); err != nil {
return nil, fmt.Errorf("error decoding import table desc: %w", err)
}
i.DescTable, err = decodeTableType(r)
case wasm.ExternTypeMemory:
if i.DescMem, err = decodeMemoryType(r); err != nil {
return nil, fmt.Errorf("error decoding import mem desc: %w", err)
}
i.DescMem, err = decodeMemoryType(r)
case wasm.ExternTypeGlobal:
if i.DescGlobal, err = decodeGlobalType(r); err != nil {
return nil, fmt.Errorf("error decoding import global desc: %w", err)
}
i.DescGlobal, err = decodeGlobalType(r, features)
default:
return nil, fmt.Errorf("%w: invalid byte for importdesc: %#x", ErrInvalidByte, b[0])
err = fmt.Errorf("%w: invalid byte for importdesc: %#x", ErrInvalidByte, b)
}
if err != nil {
return nil, fmt.Errorf("import[%d] %s[%s.%s]: %w", idx, wasm.ExternTypeName(i.Type), i.Module, i.Name, err)
}
return
}

View File

@@ -1,8 +1,8 @@
package binary
import (
"bytes"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/leb128"
wasm "github.com/tetratelabs/wazero/internal/wasm"
@@ -11,15 +11,14 @@ import (
// decodeLimitsType returns the wasm.LimitsType decoded with the WebAssembly 1.0 (20191205) Binary Format.
//
// See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#limits%E2%91%A6
func decodeLimitsType(r io.Reader) (*wasm.LimitsType, error) {
b := make([]byte, 1)
_, err := io.ReadFull(r, b)
func decodeLimitsType(r *bytes.Reader) (*wasm.LimitsType, error) {
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("read leading byte: %v", err)
}
ret := &wasm.LimitsType{}
switch b[0] {
switch b {
case 0x00:
ret.Min, _, err = leb128.DecodeUint32(r)
if err != nil {
@@ -36,7 +35,7 @@ func decodeLimitsType(r io.Reader) (*wasm.LimitsType, error) {
}
ret.Max = &m
default:
return nil, fmt.Errorf("%v for limits: %#x != 0x00 or 0x01", ErrInvalidByte, b[0])
return nil, fmt.Errorf("%v for limits: %#x != 0x00 or 0x01", ErrInvalidByte, b)
}
return ret, nil
}

View File

@@ -1,8 +1,8 @@
package binary
import (
"bytes"
"fmt"
"io"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
@@ -10,7 +10,7 @@ import (
// decodeMemoryType returns the wasm.MemoryType decoded with the WebAssembly 1.0 (20191205) Binary Format.
//
// See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-memory
func decodeMemoryType(r io.Reader) (*wasm.MemoryType, error) {
func decodeMemoryType(r *bytes.Reader) (*wasm.MemoryType, error) {
ret, err := decodeLimitsType(r)
if err != nil {
return nil, err

View File

@@ -3,13 +3,12 @@ package binary
import (
"bytes"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/leb128"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeTypeSection(r io.Reader) ([]*wasm.FunctionType, error) {
func decodeTypeSection(r *bytes.Reader) ([]*wasm.FunctionType, error) {
vs, _, err := leb128.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("get size of vector: %w", err)
@@ -24,14 +23,14 @@ func decodeTypeSection(r io.Reader) ([]*wasm.FunctionType, error) {
return result, nil
}
func decodeFunctionType(r io.Reader) (*wasm.FunctionType, error) {
b := make([]byte, 1)
if _, err := io.ReadFull(r, b); err != nil {
func decodeFunctionType(r *bytes.Reader) (*wasm.FunctionType, error) {
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("read leading byte: %w", err)
}
if b[0] != 0x60 {
return nil, fmt.Errorf("%w: %#x != 0x60", ErrInvalidByte, b[0])
if b != 0x60 {
return nil, fmt.Errorf("%w: %#x != 0x60", ErrInvalidByte, b)
}
s, _, err := leb128.DecodeUint32(r)
@@ -62,7 +61,7 @@ func decodeFunctionType(r io.Reader) (*wasm.FunctionType, error) {
}, nil
}
func decodeImportSection(r *bytes.Reader) ([]*wasm.Import, error) {
func decodeImportSection(r *bytes.Reader, features wasm.Features) ([]*wasm.Import, error) {
vs, _, err := leb128.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("get size of vector: %w", err)
@@ -70,8 +69,8 @@ func decodeImportSection(r *bytes.Reader) ([]*wasm.Import, error) {
result := make([]*wasm.Import, vs)
for i := uint32(0); i < vs; i++ {
if result[i], err = decodeImport(r); err != nil {
return nil, fmt.Errorf("read import: %w", err)
if result[i], err = decodeImport(r, i, features); err != nil {
return nil, err
}
}
return result, nil
@@ -122,7 +121,7 @@ func decodeMemorySection(r *bytes.Reader) ([]*wasm.MemoryType, error) {
return result, nil
}
func decodeGlobalSection(r *bytes.Reader) ([]*wasm.Global, error) {
func decodeGlobalSection(r *bytes.Reader, features wasm.Features) ([]*wasm.Global, error) {
vs, _, err := leb128.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("get size of vector: %w", err)
@@ -130,8 +129,8 @@ func decodeGlobalSection(r *bytes.Reader) ([]*wasm.Global, error) {
result := make([]*wasm.Global, vs)
for i := uint32(0); i < vs; i++ {
if result[i], err = decodeGlobal(r); err != nil {
return nil, fmt.Errorf("read global: %v ", err)
if result[i], err = decodeGlobal(r, features); err != nil {
return nil, fmt.Errorf("global[%d]: %w", i, err)
}
}
return result, nil

View File

@@ -1,20 +1,20 @@
package binary
import (
"bytes"
"fmt"
"io"
wasm "github.com/tetratelabs/wazero/internal/wasm"
)
func decodeTableType(r io.Reader) (*wasm.TableType, error) {
b := make([]byte, 1)
if _, err := io.ReadFull(r, b); err != nil {
func decodeTableType(r *bytes.Reader) (*wasm.TableType, error) {
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("read leading byte: %v", err)
}
if b[0] != 0x70 {
return nil, fmt.Errorf("%w: invalid element type %#x != %#x", ErrInvalidByte, b[0], 0x70)
if b != 0x70 {
return nil, fmt.Errorf("%w: invalid element type %#x != %#x", ErrInvalidByte, b, 0x70)
}
lm, err := decodeLimitsType(r)
@@ -28,7 +28,7 @@ func decodeTableType(r io.Reader) (*wasm.TableType, error) {
}, nil
}
func decodeGlobalType(r io.Reader) (*wasm.GlobalType, error) {
func decodeGlobalType(r *bytes.Reader, features wasm.Features) (*wasm.GlobalType, error) {
vt, err := decodeValueTypes(r, 1)
if err != nil {
return nil, fmt.Errorf("read value type: %w", err)
@@ -38,14 +38,17 @@ func decodeGlobalType(r io.Reader) (*wasm.GlobalType, error) {
ValType: vt[0],
}
b := make([]byte, 1)
if _, err := io.ReadFull(r, b); err != nil {
b, err := r.ReadByte()
if err != nil {
return nil, fmt.Errorf("read mutablity: %w", err)
}
switch mut := b[0]; mut {
switch mut := b; mut {
case 0x00:
case 0x01:
if err = features.Require(wasm.FeatureMutableGlobal); err != nil {
return nil, err
}
ret.Mutable = true
default:
return nil, fmt.Errorf("%w for mutability: %#x != 0x00 or 0x01", ErrInvalidByte, mut)

View File

@@ -42,7 +42,7 @@ func encodeValTypes(vt []wasm.ValueType) []byte {
return append(count, vt...)
}
func decodeValueTypes(r io.Reader, num uint32) ([]wasm.ValueType, error) {
func decodeValueTypes(r *bytes.Reader, num uint32) ([]wasm.ValueType, error) {
if num == 0 {
return nil, nil
}