From f391a1d312ce433502a15a7343c13a21ed703f62 Mon Sep 17 00:00:00 2001 From: Clifton Kaznocha Date: Wed, 2 Nov 2022 19:23:35 -0700 Subject: [PATCH] add ieee754 and leb128 byte slice funcs (#837) Signed-off-by: Clifton Kaznocha --- internal/ieee754/ieee754.go | 20 +++--- internal/leb128/leb128.go | 111 ++++++++++++++++++++--------- internal/leb128/leb128_test.go | 12 ++-- internal/wasm/binary/const_expr.go | 13 +++- internal/wasm/binary/element.go | 2 +- internal/wasm/func_validation.go | 46 ++++++------ internal/wasm/module.go | 25 +++---- internal/wasm/store.go | 14 ++-- internal/wasm/table.go | 5 +- internal/wazeroir/compiler.go | 43 ++++++----- 10 files changed, 170 insertions(+), 121 deletions(-) diff --git a/internal/ieee754/ieee754.go b/internal/ieee754/ieee754.go index 5968a6a1..0c929895 100644 --- a/internal/ieee754/ieee754.go +++ b/internal/ieee754/ieee754.go @@ -8,24 +8,22 @@ import ( // DecodeFloat32 decodes a float32 in IEEE 754 binary representation. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#floating-point%E2%91%A2 -func DecodeFloat32(r io.Reader) (float32, error) { - buf := make([]byte, 4) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err +func DecodeFloat32(buf []byte) (float32, error) { + if len(buf) < 4 { + return 0, io.ErrUnexpectedEOF } - raw := binary.LittleEndian.Uint32(buf) + + raw := binary.LittleEndian.Uint32(buf[:4]) return math.Float32frombits(raw), nil } // DecodeFloat64 decodes a float64 in IEEE 754 binary representation. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#floating-point%E2%91%A2 -func DecodeFloat64(r io.Reader) (float64, error) { - buf := make([]byte, 8) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err +func DecodeFloat64(buf []byte) (float64, error) { + if len(buf) < 8 { + return 0, io.ErrUnexpectedEOF } + raw := binary.LittleEndian.Uint64(buf) return math.Float64frombits(raw), nil } diff --git a/internal/leb128/leb128.go b/internal/leb128/leb128.go index 093e2925..9a62360b 100644 --- a/internal/leb128/leb128.go +++ b/internal/leb128/leb128.go @@ -1,14 +1,25 @@ package leb128 import ( - "bytes" "errors" "fmt" + "io" ) const ( maxVarintLen32 = 5 + maxVarintLen33 = maxVarintLen32 maxVarintLen64 = 10 + + int33Mask int64 = 1 << 7 + int33Mask2 = ^int33Mask + int33Mask3 = 1 << 6 + int33Mask4 = 8589934591 // 2^33-1 + int33Mask5 = 1 << 32 + int33Mask6 = int33Mask4 + 1 // 2^33 + + int64Mask3 = 1 << 6 + int64Mask4 = ^0 ) var ( @@ -98,12 +109,37 @@ func EncodeUint64(value uint64) (buf []byte) { } } -func DecodeUint32(r *bytes.Reader) (ret uint32, bytesRead uint64, err error) { +type nextByte interface { + next(i int) (byte, error) +} + +type byteSliceNext []byte + +func (n byteSliceNext) next(i int) (byte, error) { + if i >= len(n) { + return 0, io.EOF + } + return n[i], nil +} + +type byteReaderNext struct{ io.ByteReader } + +func (n byteReaderNext) next(_ int) (byte, error) { return n.ReadByte() } + +func DecodeUint32(r io.ByteReader) (ret uint32, bytesRead uint64, err error) { + return decodeUint32(byteReaderNext{r}) +} + +func LoadUint32(buf []byte) (ret uint32, bytesRead uint64, err error) { + return decodeUint32(byteSliceNext(buf)) +} + +func decodeUint32(buf nextByte) (ret uint32, bytesRead uint64, err error) { // Derived from https://github.com/golang/go/blob/aafad20b617ee63d58fcd4f6e0d98fe27760678c/src/encoding/binary/varint.go // with the modification on the overflow handling tailored for 32-bits. var s uint32 for i := 0; i < maxVarintLen32; i++ { - b, err := r.ReadByte() + b, err := buf.next(i) if err != nil { return 0, 0, err } @@ -120,14 +156,19 @@ func DecodeUint32(r *bytes.Reader) (ret uint32, bytesRead uint64, err error) { return 0, 0, errOverflow32 } -func DecodeUint64(r *bytes.Reader) (ret uint64, bytesRead uint64, err error) { +func LoadUint64(buf []byte) (ret uint64, bytesRead uint64, err error) { + bufLen := len(buf) + if bufLen == 0 { + return 0, 0, io.EOF + } + // Derived from https://github.com/golang/go/blob/aafad20b617ee63d58fcd4f6e0d98fe27760678c/src/encoding/binary/varint.go var s uint64 for i := 0; i < maxVarintLen64; i++ { - b, err := r.ReadByte() - if err != nil { - return 0, 0, err + if i >= bufLen { + return 0, 0, io.EOF } + b := buf[i] if b < 0x80 { // Unused bits (non first bit) must all be zero. if i == maxVarintLen64-1 && b > 1 { @@ -141,11 +182,19 @@ func DecodeUint64(r *bytes.Reader) (ret uint64, bytesRead uint64, err error) { return 0, 0, errOverflow64 } -func DecodeInt32(r *bytes.Reader) (ret int32, bytesRead uint64, err error) { +func DecodeInt32(r io.ByteReader) (ret int32, bytesRead uint64, err error) { + return decodeInt32(byteReaderNext{r}) +} + +func LoadInt32(buf []byte) (ret int32, bytesRead uint64, err error) { + return decodeInt32(byteSliceNext(buf)) +} + +func decodeInt32(buf nextByte) (ret int32, bytesRead uint64, err error) { var shift int var b byte for { - b, err = r.ReadByte() + b, err = buf.next(int(bytesRead)) if err != nil { return 0, 0, fmt.Errorf("readByte failed: %w", err) } @@ -158,11 +207,11 @@ func DecodeInt32(r *bytes.Reader) (ret int32, bytesRead uint64, err error) { } // Over flow checks. // fixme: can be optimized. - if bytesRead > 5 { + if bytesRead > maxVarintLen32 { return 0, 0, errOverflow32 - } else if unused := b & 0b00110000; bytesRead == 5 && ret < 0 && unused != 0b00110000 { + } else if unused := b & 0b00110000; bytesRead == maxVarintLen32 && ret < 0 && unused != 0b00110000 { return 0, 0, errOverflow32 - } else if bytesRead == 5 && ret >= 0 && unused != 0x00 { + } else if bytesRead == maxVarintLen32 && ret >= 0 && unused != 0x00 { return 0, 0, errOverflow32 } return @@ -174,15 +223,7 @@ func DecodeInt32(r *bytes.Reader) (ret int32, bytesRead uint64, err error) { // still needs to fit the 32-bit range of allowed indices. Hence, this is 33, not 32-bit! // // See https://webassembly.github.io/spec/core/binary/instructions.html#control-instructions -func DecodeInt33AsInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) { - const ( - int33Mask int64 = 1 << 7 - int33Mask2 = ^int33Mask - int33Mask3 = 1 << 6 - int33Mask4 = 8589934591 // 2^33-1 - int33Mask5 = 1 << 32 - int33Mask6 = int33Mask4 + 1 // 2^33 - ) +func DecodeInt33AsInt64(r io.ByteReader) (ret int64, bytesRead uint64, err error) { var shift int var b int64 var rb byte @@ -212,25 +253,29 @@ func DecodeInt33AsInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error } // Over flow checks. // fixme: can be optimized. - if bytesRead > 5 { + if bytesRead > maxVarintLen33 { return 0, 0, errOverflow33 - } else if unused := b & 0b00100000; bytesRead == 5 && ret < 0 && unused != 0b00100000 { + } else if unused := b & 0b00100000; bytesRead == maxVarintLen33 && ret < 0 && unused != 0b00100000 { return 0, 0, errOverflow33 - } else if bytesRead == 5 && ret >= 0 && unused != 0x00 { + } else if bytesRead == maxVarintLen33 && ret >= 0 && unused != 0x00 { return 0, 0, errOverflow33 } return ret, bytesRead, nil } -func DecodeInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) { - const ( - int64Mask3 = 1 << 6 - int64Mask4 = ^0 - ) +func DecodeInt64(r io.ByteReader) (ret int64, bytesRead uint64, err error) { + return decodeInt64(byteReaderNext{r}) +} + +func LoadInt64(buf []byte) (ret int64, bytesRead uint64, err error) { + return decodeInt64(byteSliceNext(buf)) +} + +func decodeInt64(buf nextByte) (ret int64, bytesRead uint64, err error) { var shift int var b byte for { - b, err = r.ReadByte() + b, err = buf.next(int(bytesRead)) if err != nil { return 0, 0, fmt.Errorf("readByte failed: %w", err) } @@ -243,11 +288,11 @@ func DecodeInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) { } // Over flow checks. // fixme: can be optimized. - if bytesRead > 10 { + if bytesRead > maxVarintLen64 { return 0, 0, errOverflow64 - } else if unused := b & 0b00111110; bytesRead == 10 && ret < 0 && unused != 0b00111110 { + } else if unused := b & 0b00111110; bytesRead == maxVarintLen64 && ret < 0 && unused != 0b00111110 { return 0, 0, errOverflow64 - } else if bytesRead == 10 && ret >= 0 && unused != 0x00 { + } else if bytesRead == maxVarintLen64 && ret >= 0 && unused != 0x00 { return 0, 0, errOverflow64 } return diff --git a/internal/leb128/leb128_test.go b/internal/leb128/leb128_test.go index 7a8bcc9a..84d29b13 100644 --- a/internal/leb128/leb128_test.go +++ b/internal/leb128/leb128_test.go @@ -28,7 +28,7 @@ func TestEncode_DecodeInt32(t *testing.T) { {input: int32(math.MaxInt32), expected: []byte{0xff, 0xff, 0xff, 0xff, 0x7}}, } { require.Equal(t, c.expected, EncodeInt32(c.input)) - decoded, _, err := DecodeInt32(bytes.NewReader(c.expected)) + decoded, _, err := LoadInt32(c.expected) require.NoError(t, err) require.Equal(t, c.input, decoded) } @@ -55,7 +55,7 @@ func TestEncode_DecodeInt64(t *testing.T) { {input: math.MaxInt64, expected: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0}}, } { require.Equal(t, c.expected, EncodeInt64(c.input)) - decoded, _, err := DecodeInt64(bytes.NewReader(c.expected)) + decoded, _, err := LoadInt64(c.expected) require.NoError(t, err) require.Equal(t, c.input, decoded) } @@ -115,7 +115,7 @@ func TestDecodeUint32(t *testing.T) { {bytes: []byte{0x82, 0x80, 0x80, 0x80, 0x70}, expErr: true}, {bytes: []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x00}, expErr: true}, } { - actual, num, err := DecodeUint32(bytes.NewReader(c.bytes)) + actual, num, err := LoadUint32(c.bytes) if c.expErr { require.Error(t, err) } else { @@ -140,7 +140,7 @@ func TestDecodeUint64(t *testing.T) { {bytes: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1}, exp: math.MaxUint64}, {bytes: []byte{0x89, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x71}, expErr: true}, } { - actual, num, err := DecodeUint64(bytes.NewReader(c.bytes)) + actual, num, err := LoadUint64(c.bytes) if c.expErr { require.Error(t, err) } else { @@ -169,7 +169,7 @@ func TestDecodeInt32(t *testing.T) { {bytes: []byte{0xff, 0xff, 0xff, 0xff, 0x4f}, expErr: true}, {bytes: []byte{0x80, 0x80, 0x80, 0x80, 0x70}, expErr: true}, } { - actual, num, err := DecodeInt32(bytes.NewReader(c.bytes)) + actual, num, err := LoadInt32(c.bytes) if c.expErr { require.Error(t, err, fmt.Sprintf("%d-th got value %d", i, actual)) } else { @@ -220,7 +220,7 @@ func TestDecodeInt64(t *testing.T) { {bytes: []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x7f}, exp: -9223372036854775808}, } { - actual, num, err := DecodeInt64(bytes.NewReader(c.bytes)) + actual, num, err := LoadInt64(c.bytes) require.NoError(t, err) require.Equal(t, c.exp, actual) require.Equal(t, uint64(len(c.bytes)), num) diff --git a/internal/wasm/binary/const_expr.go b/internal/wasm/binary/const_expr.go index 8269d317..b391e740 100644 --- a/internal/wasm/binary/const_expr.go +++ b/internal/wasm/binary/const_expr.go @@ -3,6 +3,7 @@ package binary import ( "bytes" "fmt" + "io" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/ieee754" @@ -28,9 +29,17 @@ func decodeConstantExpression(r *bytes.Reader, enabledFeatures api.CoreFeatures) // Treat constants as signed as their interpretation is not yet known per /RATIONALE.md _, _, err = leb128.DecodeInt64(r) case wasm.OpcodeF32Const: - _, err = ieee754.DecodeFloat32(r) + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, fmt.Errorf("read f32 constant: %v", err) + } + _, err = ieee754.DecodeFloat32(buf) case wasm.OpcodeF64Const: - _, err = ieee754.DecodeFloat64(r) + buf := make([]byte, 8) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, fmt.Errorf("read f64 constant: %v", err) + } + _, err = ieee754.DecodeFloat64(buf) case wasm.OpcodeGlobalGet: _, _, err = leb128.DecodeUint32(r) case wasm.OpcodeRefNull: diff --git a/internal/wasm/binary/element.go b/internal/wasm/binary/element.go index c776645c..7876fed9 100644 --- a/internal/wasm/binary/element.go +++ b/internal/wasm/binary/element.go @@ -54,7 +54,7 @@ func decodeElementConstExprVector(r *bytes.Reader, elemType wasm.RefType, enable if elemType != wasm.RefTypeFuncref { return nil, fmt.Errorf("element type mismatch: want %s, but constexpr has funcref", wasm.RefTypeName(elemType)) } - v, _, _ := leb128.DecodeUint32(bytes.NewReader(expr.Data)) + v, _, _ := leb128.LoadUint32(expr.Data) vec[i] = &v case wasm.OpcodeRefNull: if elemType != expr.Data[0] { diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index 9a31b588..dcf16337 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -32,14 +32,14 @@ func (m *Module) validateFunction(enabledFeatures api.CoreFeatures, idx Index, f } func readMemArg(pc uint64, body []byte) (align, offset uint32, read uint64, err error) { - align, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + align, num, err := leb128.LoadUint32(body[pc:]) if err != nil { err = fmt.Errorf("read memory align: %v", err) return } read += num - offset, num, err = leb128.DecodeUint32(bytes.NewReader(body[pc+num:])) + offset, num, err = leb128.LoadUint32(body[pc+num:]) if err != nil { err = fmt.Errorf("read memory offset: %v", err) return @@ -276,7 +276,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("memory must exist for %s", InstructionName(op)) } pc++ - val, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + val, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -297,14 +297,14 @@ func (m *Module) validateFunctionWithMaxStackValues( pc++ switch Opcode(op) { case OpcodeI32Const: - _, num, err := leb128.DecodeInt32(bytes.NewReader(body[pc:])) + _, num, err := leb128.LoadInt32(body[pc:]) if err != nil { return fmt.Errorf("read i32 immediate: %s", err) } pc += num - 1 valueTypeStack.push(ValueTypeI32) case OpcodeI64Const: - _, num, err := leb128.DecodeInt64(bytes.NewReader(body[pc:])) + _, num, err := leb128.LoadInt64(body[pc:]) if err != nil { return fmt.Errorf("read i64 immediate: %v", err) } @@ -319,7 +319,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } } else if OpcodeLocalGet <= op && op <= OpcodeGlobalSet { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -384,7 +384,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } } else if op == OpcodeBr { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } else if int(index) >= len(controlBlockStack) { @@ -406,7 +406,7 @@ func (m *Module) validateFunctionWithMaxStackValues( valueTypeStack.unreachable() } else if op == OpcodeBrIf { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } else if int(index) >= len(controlBlockStack) { @@ -526,7 +526,7 @@ func (m *Module) validateFunctionWithMaxStackValues( valueTypeStack.unreachable() } else if op == OpcodeCall { pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -545,7 +545,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } } else if op == OpcodeCallIndirect { pc++ - typeIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + typeIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -555,7 +555,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("invalid type index at %s: %d", OpcodeCallIndirectName, typeIndex) } - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + tableIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read table index: %v", err) } @@ -827,7 +827,7 @@ func (m *Module) validateFunctionWithMaxStackValues( valueTypeStack.push(ValueTypeI32) case OpcodeRefFunc: pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read function index for ref.func: %v", err) } @@ -842,7 +842,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("%s is invalid as %v", InstructionName(op), err) } pc++ - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + tableIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("read immediate: %v", err) } @@ -903,7 +903,7 @@ func (m *Module) validateFunctionWithMaxStackValues( // We need to read the index to the data section. pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read data segment index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -924,7 +924,7 @@ func (m *Module) validateFunctionWithMaxStackValues( // We need to read the index to the data section. pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + index, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read data segment index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -935,7 +935,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } pc++ - val, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + val, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read memory index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -945,7 +945,7 @@ func (m *Module) validateFunctionWithMaxStackValues( if miscOpcode == OpcodeMiscMemoryCopy { pc++ // memory.copy needs two memory index which are reserved as zero. - val, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + val, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read memory index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -957,7 +957,7 @@ func (m *Module) validateFunctionWithMaxStackValues( case OpcodeMiscTableInit: params = []ValueType{ValueTypeI32, ValueTypeI32, ValueTypeI32} pc++ - elementIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + elementIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read element segment index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -966,7 +966,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } pc += num - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + tableIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read source table index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -988,7 +988,7 @@ func (m *Module) validateFunctionWithMaxStackValues( pc += num - 1 case OpcodeMiscElemDrop: pc++ - elementIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + elementIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read element segment index for %s: %v", MiscInstructionName(miscOpcode), err) } else if int(elementIndex) >= len(m.ElementSection) { @@ -999,7 +999,7 @@ func (m *Module) validateFunctionWithMaxStackValues( params = []ValueType{ValueTypeI32, ValueTypeI32, ValueTypeI32} pc++ - dstTableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + dstTableIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read destination table index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -1013,7 +1013,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } pc += num - srcTableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + srcTableIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read source table index for %s: %v", MiscInstructionName(miscOpcode), err) } @@ -1044,7 +1044,7 @@ func (m *Module) validateFunctionWithMaxStackValues( } pc++ - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) + tableIndex, num, err := leb128.LoadUint32(body[pc:]) if err != nil { return fmt.Errorf("failed to read table index for %s: %v", MiscInstructionName(miscOpcode), err) } diff --git a/internal/wasm/module.go b/internal/wasm/module.go index 60f30b12..a3c594a2 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "errors" "fmt" + "io" "sort" "strings" @@ -372,7 +373,7 @@ func (m *Module) declaredFunctionIndexes() (ret map[Index]struct{}, err error) { for i, g := range m.GlobalSection { if g.Init.Opcode == OpcodeRefFunc { var index uint32 - index, _, err = leb128.DecodeUint32(bytes.NewReader(g.Init.Data)) + index, _, err = leb128.LoadUint32(g.Init.Data) if err != nil { err = fmt.Errorf("%s[%d] failed to initialize: %w", SectionIDName(SectionIDGlobal), i, err) return @@ -480,36 +481,35 @@ func (m *Module) validateExports(enabledFeatures api.CoreFeatures, functions []I func validateConstExpression(globals []*GlobalType, numFuncs uint32, expr *ConstantExpression, expectedType ValueType) (err error) { var actualType ValueType - r := bytes.NewReader(expr.Data) switch expr.Opcode { case OpcodeI32Const: // Treat constants as signed as their interpretation is not yet known per /RATIONALE.md - _, _, err = leb128.DecodeInt32(r) + _, _, err = leb128.LoadInt32(expr.Data) if err != nil { return fmt.Errorf("read i32: %w", err) } actualType = ValueTypeI32 case OpcodeI64Const: // Treat constants as signed as their interpretation is not yet known per /RATIONALE.md - _, _, err = leb128.DecodeInt64(r) + _, _, err = leb128.LoadInt64(expr.Data) if err != nil { return fmt.Errorf("read i64: %w", err) } actualType = ValueTypeI64 case OpcodeF32Const: - _, err = ieee754.DecodeFloat32(r) + _, err = ieee754.DecodeFloat32(expr.Data) if err != nil { return fmt.Errorf("read f32: %w", err) } actualType = ValueTypeF32 case OpcodeF64Const: - _, err = ieee754.DecodeFloat64(r) + _, err = ieee754.DecodeFloat64(expr.Data) if err != nil { return fmt.Errorf("read f64: %w", err) } actualType = ValueTypeF64 case OpcodeGlobalGet: - id, _, err := leb128.DecodeUint32(r) + id, _, err := leb128.LoadUint32(expr.Data) if err != nil { return fmt.Errorf("read index of global: %w", err) } @@ -518,15 +518,16 @@ func validateConstExpression(globals []*GlobalType, numFuncs uint32, expr *Const } actualType = globals[id].ValType case OpcodeRefNull: - reftype, err := r.ReadByte() - if err != nil { - return fmt.Errorf("read reference type for ref.null: %w", err) - } else if reftype != RefTypeFuncref && reftype != RefTypeExternref { + if len(expr.Data) == 0 { + return fmt.Errorf("read reference type for ref.null: %w", io.ErrShortBuffer) + } + reftype := expr.Data[0] + if reftype != RefTypeFuncref && reftype != RefTypeExternref { return fmt.Errorf("invalid type for ref.null: 0x%x", reftype) } actualType = reftype case OpcodeRefFunc: - index, _, err := leb128.DecodeUint32(r) + index, _, err := leb128.LoadUint32(expr.Data) if err != nil { return fmt.Errorf("read i32: %w", err) } else if index >= numFuncs { diff --git a/internal/wasm/store.go b/internal/wasm/store.go index e8425917..02718bd4 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -1,7 +1,6 @@ package wasm import ( - "bytes" "context" "encoding/binary" "errors" @@ -533,20 +532,19 @@ func errorInvalidImport(i *Import, idx int, err error) error { // Global initialization constant expression can only reference the imported globals. // See the note on https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#constant-expressions%E2%91%A0 func executeConstExpression(importedGlobals []*GlobalInstance, expr *ConstantExpression) (v interface{}) { - r := bytes.NewReader(expr.Data) switch expr.Opcode { case OpcodeI32Const: // Treat constants as signed as their interpretation is not yet known per /RATIONALE.md - v, _, _ = leb128.DecodeInt32(r) + v, _, _ = leb128.LoadInt32(expr.Data) case OpcodeI64Const: // Treat constants as signed as their interpretation is not yet known per /RATIONALE.md - v, _, _ = leb128.DecodeInt64(r) + v, _, _ = leb128.LoadInt64(expr.Data) case OpcodeF32Const: - v, _ = ieee754.DecodeFloat32(r) + v, _ = ieee754.DecodeFloat32(expr.Data) case OpcodeF64Const: - v, _ = ieee754.DecodeFloat64(r) + v, _ = ieee754.DecodeFloat64(expr.Data) case OpcodeGlobalGet: - id, _, _ := leb128.DecodeUint32(r) + id, _, _ := leb128.LoadUint32(expr.Data) g := importedGlobals[id] switch g.Type.ValType { case ValueTypeI32: @@ -573,7 +571,7 @@ func executeConstExpression(importedGlobals []*GlobalInstance, expr *ConstantExp // For ref.func const expression, we temporarily store the index as value, // and if this is the const expr for global, the value will be further downed to // opaque pointer of the engine-specific compiled function. - v, _, _ = leb128.DecodeUint32(r) + v, _, _ = leb128.LoadUint32(expr.Data) case OpcodeVecV128Const: v = [2]uint64{binary.LittleEndian.Uint64(expr.Data[0:8]), binary.LittleEndian.Uint64(expr.Data[8:16])} } diff --git a/internal/wasm/table.go b/internal/wasm/table.go index 8ec18b68..831d8940 100644 --- a/internal/wasm/table.go +++ b/internal/wasm/table.go @@ -1,7 +1,6 @@ package wasm import ( - "bytes" "context" "fmt" "math" @@ -185,7 +184,7 @@ func (m *Module) validateTable(enabledFeatures api.CoreFeatures, tables []*Table // global.get needs to be discovered during initialization oc := elem.OffsetExpr.Opcode if oc == OpcodeGlobalGet { - globalIdx, _, err := leb128.DecodeUint32(bytes.NewReader(elem.OffsetExpr.Data)) + globalIdx, _, err := leb128.LoadUint32(elem.OffsetExpr.Data) if err != nil { return nil, fmt.Errorf("%s[%d] couldn't read global.get parameter: %w", SectionIDName(SectionIDElement), idx, err) } else if err = m.verifyImportGlobalI32(SectionIDElement, idx, globalIdx); err != nil { @@ -199,7 +198,7 @@ func (m *Module) validateTable(enabledFeatures api.CoreFeatures, tables []*Table ret = append(ret, &validatedActiveElementSegment{opcode: oc, arg: globalIdx, init: elem.Init, tableIndex: elem.TableIndex}) } else if oc == OpcodeI32Const { // Treat constants as signed as their interpretation is not yet known per /RATIONALE.md - o, _, err := leb128.DecodeInt32(bytes.NewReader(elem.OffsetExpr.Data)) + o, _, err := leb128.LoadInt32(elem.OffsetExpr.Data) if err != nil { return nil, fmt.Errorf("%s[%d] couldn't read i32.const parameter: %w", SectionIDName(SectionIDElement), idx, err) } diff --git a/internal/wazeroir/compiler.go b/internal/wazeroir/compiler.go index 1363239b..3cf2df94 100644 --- a/internal/wazeroir/compiler.go +++ b/internal/wazeroir/compiler.go @@ -649,7 +649,7 @@ operatorSwitch: } case wasm.OpcodeBr: - targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + targetIndex, n, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("read the target for br_if: %w", err) } @@ -674,7 +674,7 @@ operatorSwitch: // and can be safely removed. c.markUnreachable() case wasm.OpcodeBrIf: - targetIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + targetIndex, n, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("read the target for br_if: %w", err) } @@ -790,7 +790,7 @@ operatorSwitch: if index == nil { return fmt.Errorf("index does not exist for indirect function call") } - tableIndex, n, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + tableIndex, n, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("read target for br_table: %w", err) } @@ -1096,7 +1096,7 @@ operatorSwitch: &OperationMemoryGrow{}, ) case wasm.OpcodeI32Const: - val, num, err := leb128.DecodeInt32(bytes.NewReader(c.body[c.pc+1:])) + val, num, err := leb128.LoadInt32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1105,7 +1105,7 @@ operatorSwitch: &OperationConstI32{Value: uint32(val)}, ) case wasm.OpcodeI64Const: - val, num, err := leb128.DecodeInt64(bytes.NewReader(c.body[c.pc+1:])) + val, num, err := leb128.LoadInt64(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i64.const value: %v", err) } @@ -1639,7 +1639,7 @@ operatorSwitch: ) case wasm.OpcodeRefFunc: c.pc++ - index, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc:])) + index, num, err := leb128.LoadUint32(c.body[c.pc:]) if err != nil { return fmt.Errorf("failed to read function index for ref.func: %v", err) } @@ -1659,7 +1659,7 @@ operatorSwitch: ) case wasm.OpcodeTableGet: c.pc++ - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc:])) + tableIndex, num, err := leb128.LoadUint32(c.body[c.pc:]) if err != nil { return fmt.Errorf("failed to read function index for table.get: %v", err) } @@ -1669,7 +1669,7 @@ operatorSwitch: ) case wasm.OpcodeTableSet: c.pc++ - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc:])) + tableIndex, num, err := leb128.LoadUint32(c.body[c.pc:]) if err != nil { return fmt.Errorf("failed to read function index for table.set: %v", err) } @@ -1714,7 +1714,7 @@ operatorSwitch: ) case wasm.OpcodeMiscMemoryInit: c.result.UsesMemory = true - dataIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + dataIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1723,7 +1723,7 @@ operatorSwitch: &OperationMemoryInit{DataIndex: dataIndex}, ) case wasm.OpcodeMiscDataDrop: - dataIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + dataIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1744,13 +1744,13 @@ operatorSwitch: &OperationMemoryFill{}, ) case wasm.OpcodeMiscTableInit: - elemIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + elemIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } c.pc += num // Read table index which is fixed to zero currently. - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + tableIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1759,7 +1759,7 @@ operatorSwitch: &OperationTableInit{ElemIndex: elemIndex, TableIndex: tableIndex}, ) case wasm.OpcodeMiscElemDrop: - elemIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + elemIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1769,13 +1769,13 @@ operatorSwitch: ) case wasm.OpcodeMiscTableCopy: // Read the source table inde.g. - dst, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + dst, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } c.pc += num // Read the destination table inde.g. - src, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + src, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) @@ -1786,7 +1786,7 @@ operatorSwitch: ) case wasm.OpcodeMiscTableGrow: // Read the source table inde.g. - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + tableIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1796,7 +1796,7 @@ operatorSwitch: ) case wasm.OpcodeMiscTableSize: // Read the source table inde.g. - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + tableIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -1806,7 +1806,7 @@ operatorSwitch: ) case wasm.OpcodeMiscTableFill: // Read the source table index. - tableIndex, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + tableIndex, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return fmt.Errorf("reading i32.const value: %v", err) } @@ -2936,7 +2936,7 @@ func (c *compiler) applyToStack(opcode wasm.Opcode) (*uint32, error) { wasm.OpcodeGlobalGet, wasm.OpcodeGlobalSet: // Assumes that we are at the opcode now so skip it before read immediates. - v, num, err := leb128.DecodeUint32(bytes.NewReader(c.body[c.pc+1:])) + v, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return nil, fmt.Errorf("reading immediates: %w", err) } @@ -3117,13 +3117,12 @@ func (c *compiler) stackLenInUint64(ceil int) (ret int) { func (c *compiler) readMemoryArg(tag string) (*MemoryArg, error) { c.result.UsesMemory = true - r := bytes.NewReader(c.body[c.pc+1:]) - alignment, num, err := leb128.DecodeUint32(r) + alignment, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return nil, fmt.Errorf("reading alignment for %s: %w", tag, err) } c.pc += num - offset, num, err := leb128.DecodeUint32(r) + offset, num, err := leb128.LoadUint32(c.body[c.pc+1:]) if err != nil { return nil, fmt.Errorf("reading offset for %s: %w", tag, err) }