diff --git a/internal/wasm/counts_test.go b/internal/wasm/counts_test.go index c0312384..cacee95c 100644 --- a/internal/wasm/counts_test.go +++ b/internal/wasm/counts_test.go @@ -213,12 +213,12 @@ func TestModule_SectionElementCount(t *testing.T) { expected: map[string]uint32{}, }, { - name: "only name section", + name: "NameSection", input: &Module{NameSection: &NameSection{ModuleName: "simple"}}, expected: map[string]uint32{"custom": 1}, }, { - name: "type section", + name: "TypeSection", input: &Module{ TypeSection: []*FunctionType{ {}, @@ -229,7 +229,7 @@ func TestModule_SectionElementCount(t *testing.T) { expected: map[string]uint32{"type": 3}, }, { - name: "type and import section", + name: "TypeSection and ImportSection", input: &Module{ TypeSection: []*FunctionType{ {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, @@ -250,7 +250,7 @@ func TestModule_SectionElementCount(t *testing.T) { expected: map[string]uint32{"import": 2, "type": 2}, }, { - name: "type function and start section", + name: "TypeSection, FunctionSection, CodeSection, ExportSection and StartSection", input: &Module{ TypeSection: []*FunctionType{{}}, FunctionSection: []Index{0}, @@ -265,7 +265,7 @@ func TestModule_SectionElementCount(t *testing.T) { expected: map[string]uint32{"code": 1, "export": 1, "function": 1, "start": 1, "type": 1}, }, { - name: "memory and data", + name: "MemorySection and DataSection", input: &Module{ MemorySection: []*MemoryType{{Min: 1}}, DataSection: []*DataSegment{{MemoryIndex: 0, OffsetExpression: empty}}, @@ -273,7 +273,7 @@ func TestModule_SectionElementCount(t *testing.T) { expected: map[string]uint32{"data": 1, "memory": 1}, }, { - name: "table and element", + name: "TableSection and ElementSection", input: &Module{ TableSection: []*TableType{{ElemType: 0x70, Limit: &LimitsType{Min: 1}}}, ElementSection: []*ElementSegment{{TableIndex: 0, OffsetExpr: empty}}, diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index e3dfe154..eac65a90 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -31,7 +31,7 @@ func validateFunction( tables []*TableType, types []*FunctionType, maxStackValues int, - features Features, + enabledFeatures Features, ) error { // Note: In WebAssembly 1.0 (20191205), multiple memories are not allowed. hasMemory := len(memories) > 0 @@ -39,7 +39,7 @@ func validateFunction( hasTable := len(tables) > 0 // We start with the outermost control block which is for function return if the code branches into it. - controlBloclStack := []*controlBlock{{blockType: functionType}} + controlBlockStack := []*controlBlock{{blockType: functionType}} // Create the valueTypeStack to track the state of Wasm value stacks at anypoint of execution. valueTypeStack := &valueTypeStack{} @@ -348,12 +348,12 @@ func validateFunction( index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) - } else if int(index) >= len(controlBloclStack) { + } else if int(index) >= len(controlBlockStack) { return fmt.Errorf("invalid br operation: index out of range") } pc += num - 1 // Check type soundness. - target := controlBloclStack[len(controlBloclStack)-int(index)-1] + target := controlBlockStack[len(controlBlockStack)-int(index)-1] targetResultType := target.blockType.Results if target.isLoop { // Loop operation doesn't require results since the continuation is @@ -370,17 +370,17 @@ func validateFunction( index, num, err := leb128.DecodeUint32(bytes.NewReader(body[pc:])) if err != nil { return fmt.Errorf("read immediate: %v", err) - } else if int(index) >= len(controlBloclStack) { + } else if int(index) >= len(controlBlockStack) { return fmt.Errorf( "invalid ln param given for br_if: index=%d with %d for the current lable stack length", - index, len(controlBloclStack)) + index, len(controlBlockStack)) } pc += num - 1 if err := valueTypeStack.popAndVerifyType(ValueTypeI32); err != nil { return fmt.Errorf("cannot pop the required operand for br_if") } // Check type soundness. - target := controlBloclStack[len(controlBloclStack)-int(index)-1] + target := controlBlockStack[len(controlBlockStack)-int(index)-1] targetResultType := target.blockType.Results if target.isLoop { // Loop operation doesn't require results since the continuation is @@ -414,17 +414,17 @@ func validateFunction( ln, n, err := leb128.DecodeUint32(r) if err != nil { return fmt.Errorf("read immediate: %w", err) - } else if int(ln) >= len(controlBloclStack) { + } else if int(ln) >= len(controlBlockStack) { return fmt.Errorf( "invalid ln param given for br_table: ln=%d with %d for the current lable stack length", - ln, len(controlBloclStack)) + ln, len(controlBlockStack)) } pc += n + num - 1 // Check type soundness. if err := valueTypeStack.popAndVerifyType(ValueTypeI32); err != nil { return fmt.Errorf("cannot pop the required operand for br_table") } - lnLabel := controlBloclStack[len(controlBloclStack)-1-int(ln)] + lnLabel := controlBlockStack[len(controlBlockStack)-1-int(ln)] expType := lnLabel.blockType.Results if lnLabel.isLoop { // Loop operation doesn't require results since the continuation is @@ -432,10 +432,10 @@ func validateFunction( expType = []ValueType{} } for _, l := range list { - if int(l) >= len(controlBloclStack) { + if int(l) >= len(controlBlockStack) { return fmt.Errorf("invalid l param given for br_table") } - label := controlBloclStack[len(controlBloclStack)-1-int(l)] + label := controlBlockStack[len(controlBlockStack)-1-int(l)] expType2 := label.blockType.Results if label.isLoop { // Loop operation doesn't require results since the continuation is @@ -699,7 +699,7 @@ func validateFunction( } valueTypeStack.push(ValueTypeF64) case OpcodeI32Extend8S, OpcodeI32Extend16S: - if err := features.Require(FeatureSignExtensionOps); err != nil { + if err := enabledFeatures.Require(FeatureSignExtensionOps); err != nil { return fmt.Errorf("%s invalid as %v", instructionNames[op], err) } if err := valueTypeStack.popAndVerifyType(ValueTypeI32); err != nil { @@ -707,7 +707,7 @@ func validateFunction( } valueTypeStack.push(ValueTypeI32) case OpcodeI64Extend8S, OpcodeI64Extend16S, OpcodeI64Extend32S: - if err := features.Require(FeatureSignExtensionOps); err != nil { + if err := enabledFeatures.Require(FeatureSignExtensionOps); err != nil { return fmt.Errorf("%s invalid as %v", instructionNames[op], err) } if err := valueTypeStack.popAndVerifyType(ValueTypeI64); err != nil { @@ -722,7 +722,7 @@ func validateFunction( if err != nil { return fmt.Errorf("read block: %w", err) } - controlBloclStack = append(controlBloclStack, &controlBlock{ + controlBlockStack = append(controlBlockStack, &controlBlock{ startAt: pc, blockType: bt, blockTypeBytes: num, @@ -734,7 +734,7 @@ func validateFunction( if err != nil { return fmt.Errorf("read block: %w", err) } - controlBloclStack = append(controlBloclStack, &controlBlock{ + controlBlockStack = append(controlBlockStack, &controlBlock{ startAt: pc, blockType: bt, blockTypeBytes: num, @@ -747,7 +747,7 @@ func validateFunction( if err != nil { return fmt.Errorf("read block: %w", err) } - controlBloclStack = append(controlBloclStack, &controlBlock{ + controlBlockStack = append(controlBlockStack, &controlBlock{ startAt: pc, blockType: bt, blockTypeBytes: num, @@ -759,7 +759,7 @@ func validateFunction( valueTypeStack.pushStackLimit() pc += num } else if op == OpcodeElse { - bl := controlBloclStack[len(controlBloclStack)-1] + bl := controlBlockStack[len(controlBlockStack)-1] bl.elseAt = pc // Check the type soundness of the instructions *before*  entering this Eles Op. if err := valueTypeStack.popResults(bl.blockType.Results, true); err != nil { @@ -769,9 +769,9 @@ func validateFunction( // then block. valueTypeStack.resetAtStackLimit() } else if op == OpcodeEnd { - bl := controlBloclStack[len(controlBloclStack)-1] + bl := controlBlockStack[len(controlBlockStack)-1] bl.endAt = pc - controlBloclStack = controlBloclStack[:len(controlBloclStack)-1] + controlBlockStack = controlBlockStack[:len(controlBlockStack)-1] if bl.isIf && bl.elseAt <= bl.startAt { if len(bl.blockType.Results) > 0 { return fmt.Errorf("type mismatch between then and else blocks") @@ -836,7 +836,7 @@ func validateFunction( } } - if len(controlBloclStack) > 0 { + if len(controlBlockStack) > 0 { return fmt.Errorf("ill-nested block exists") } if valueTypeStack.maximumStackPointer > maxStackValues { diff --git a/internal/wasm/func_validation_test.go b/internal/wasm/func_validation_test.go index 008c2032..8d7e9af5 100644 --- a/internal/wasm/func_validation_test.go +++ b/internal/wasm/func_validation_test.go @@ -26,11 +26,11 @@ func TestValidateFunction_valueStackLimit(t *testing.T) { body = append(body, OpcodeEnd) t.Run("not exceed", func(t *testing.T) { - err := validateFunction(&FunctionType{}, body, nil, nil, nil, nil, nil, nil, max+1, Features(0)) + err := validateFunction(&FunctionType{}, body, nil, nil, nil, nil, nil, nil, max+1, Features20191205) require.NoError(t, err) }) t.Run("exceed", func(t *testing.T) { - err := validateFunction(&FunctionType{}, body, nil, nil, nil, nil, nil, nil, max, Features(0)) + err := validateFunction(&FunctionType{}, body, nil, nil, nil, nil, nil, nil, max, Features20191205) require.Error(t, err) expMsg := fmt.Sprintf("function may have %d stack values, which exceeds limit %d", valuesNum, max) require.Equal(t, expMsg, err.Error()) @@ -69,7 +69,7 @@ func TestValidateFunction_SignExtensionOps(t *testing.T) { tc := tt t.Run(InstructionName(tc.input), func(t *testing.T) { t.Run("disabled", func(t *testing.T) { - err := validateFunction(&FunctionType{}, []byte{tc.input}, nil, nil, nil, nil, nil, nil, maxStackHeight, Features(0)) + err := validateFunction(&FunctionType{}, []byte{tc.input}, nil, nil, nil, nil, nil, nil, maxStackHeight, Features20191205) require.EqualError(t, err, tc.expectedErrOnDisable) }) t.Run("enabled", func(t *testing.T) { diff --git a/internal/wasm/module.go b/internal/wasm/module.go index 094e7023..c3947f5e 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -3,6 +3,7 @@ package internalwasm import ( "bytes" "fmt" + "strings" "github.com/tetratelabs/wazero/internal/ieee754" "github.com/tetratelabs/wazero/internal/leb128" @@ -26,8 +27,8 @@ type EncodeModule func(m *Module) (bytes []byte) // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#modules%E2%91%A8 // // Differences from the specification: -// * The NameSection is decoded, so not present as a key "name" in CustomSections. -// * The ExportSection is represented as a map for lookup convenience. +// * NameSection is the only key ("name") decoded from the SectionIDCustom. +// * ExportSection is represented as a map for lookup convenience. type Module struct { // TypeSection contains the unique FunctionType of functions imported or defined in this module. // @@ -173,7 +174,7 @@ func (m *Module) TypeOfFunction(funcIdx Index) *FunctionType { return m.TypeSection[typeIdx] } -func (m *Module) Validate(features Features) error { +func (m *Module) Validate(enabledFeatures Features) error { if err := m.validateStartSection(); err != nil { return err } @@ -194,7 +195,7 @@ func (m *Module) Validate(features Features) error { return err } - if err := m.validateFunctions(functions, globals, memories, tables, features); err != nil { + if err := m.validateFunctions(functions, globals, memories, tables, enabledFeatures); err != nil { return err } @@ -212,10 +213,10 @@ func (m *Module) validateStartSection() error { startIndex := *m.StartSection ft := m.TypeOfFunction(startIndex) if ft == nil { // TODO: move this check to decoder so that a module can never be decoded invalidly - return fmt.Errorf("start function has an invalid type") + return fmt.Errorf("invalid start function: func[%d] has an invalid type", startIndex) } if len(ft.Params) > 0 || len(ft.Results) > 0 { - return fmt.Errorf("start function must have an empty (nullary) signature: %s", ft.String()) + return fmt.Errorf("invalid start function: func[%d] must have an empty (nullary) signature: %s", startIndex, ft) } } return nil @@ -237,29 +238,49 @@ func (m *Module) validateGlobals(globals []*GlobalType, maxGlobals int) error { return nil } -func (m *Module) validateFunctions(functions []Index, globals []*GlobalType, memories []*MemoryType, tables []*TableType, features Features) error { +func (m *Module) validateFunctions(functions []Index, globals []*GlobalType, memories []*MemoryType, tables []*TableType, enabledFeatures Features) error { + typeCount := m.SectionElementCount(SectionIDType) + codeCount := m.SectionElementCount(SectionIDCode) + functionCount := m.SectionElementCount(SectionIDFunction) + if codeCount != functionCount { + return fmt.Errorf("code count (%d) != function count (%d)", codeCount, functionCount) + } + // The wazero specific limitation described at RATIONALE.md. const maximumValuesOnStack = 1 << 27 - - for codeIndex, typeIndex := range m.FunctionSection { - if typeIndex >= m.SectionElementCount(SectionIDType) { - return fmt.Errorf("function type index out of range") - } else if uint32(codeIndex) >= m.SectionElementCount(SectionIDCode) { - return fmt.Errorf("code index out of range") + for idx, typeIndex := range m.FunctionSection { + if typeIndex >= typeCount { + return fmt.Errorf("invalid %s: type section index %d out of range", m.funcDesc(SectionIDFunction, Index(idx)), typeIndex) } if err := validateFunction( m.TypeSection[typeIndex], - m.CodeSection[codeIndex].Body, - m.CodeSection[codeIndex].LocalTypes, - functions, globals, memories, tables, m.TypeSection, maximumValuesOnStack, features); err != nil { - idx := m.SectionElementCount(SectionIDFunction) - 1 - return fmt.Errorf("invalid function (%d/%d): %v", codeIndex, idx, err) + m.CodeSection[idx].Body, + m.CodeSection[idx].LocalTypes, + functions, globals, memories, tables, m.TypeSection, maximumValuesOnStack, enabledFeatures); err != nil { + + return fmt.Errorf("invalid %s: %w", m.funcDesc(SectionIDFunction, Index(idx)), err) } } return nil } +func (m *Module) funcDesc(sectionID SectionID, sectionIndex Index) string { + // Try to improve the error message by collecting any exports: + var exportNames []string + funcIdx := sectionIndex + m.importCount(ExternTypeFunc) + for _, e := range m.ExportSection { + if e.Index == funcIdx && e.Type == ExternTypeFunc { + exportNames = append(exportNames, fmt.Sprintf("%q", e.Name)) + } + } + sectionIDName := SectionIDName(sectionID) + if exportNames == nil { + return fmt.Sprintf("%s[%d]", sectionIDName, sectionIndex) + } + return fmt.Sprintf("%s[%d] (export %s)", sectionIDName, sectionIndex, strings.Join(exportNames, ",")) +} + func (m *Module) validateTables(tables []*TableType, globals []*GlobalType) error { if len(tables) > 1 { return fmt.Errorf("multiple tables are not supported") @@ -468,6 +489,11 @@ type FunctionType struct { string string } +// EqualsSignature returns true if the function type has the same parameters and results. +func (t *FunctionType) EqualsSignature(params []ValueType, results []ValueType) bool { + return bytes.Equal(t.Params, params) && bytes.Equal(t.Results, results) +} + // key gets or generates the key for Store.typeIDs. Ex. "i32_v" for one i32 parameter and no (void) result. func (t *FunctionType) key() string { if t.string != "" { diff --git a/internal/wasm/module_test.go b/internal/wasm/module_test.go index 924b8b0d..dcb4aae8 100644 --- a/internal/wasm/module_test.go +++ b/internal/wasm/module_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "math" + "strings" "testing" "github.com/stretchr/testify/require" @@ -186,106 +187,6 @@ func TestModule_allDeclarations(t *testing.T) { } } -func TestModule_SectionSize(t *testing.T) { - i32, f32 := ValueTypeI32, ValueTypeF32 - zero := uint32(0) - empty := &ConstantExpression{Opcode: OpcodeI32Const, Data: []byte{0x00}} - - tests := []struct { - name string - input *Module - expected map[string]uint32 - }{ - { - name: "empty", - input: &Module{}, - expected: map[string]uint32{}, - }, - { - name: "only name section", - input: &Module{NameSection: &NameSection{ModuleName: "simple"}}, - expected: map[string]uint32{"custom": 1}, - }, - { - name: "type section", - input: &Module{ - TypeSection: []*FunctionType{ - {}, - {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, - {Params: []ValueType{i32, i32, i32, i32}, Results: []ValueType{i32}}, - }, - }, - expected: map[string]uint32{"type": 3}, - }, - { - name: "type and import section", - input: &Module{ - TypeSection: []*FunctionType{ - {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, - {Params: []ValueType{f32, f32}, Results: []ValueType{f32}}, - }, - ImportSection: []*Import{ - { - Module: "Math", Name: "Mul", - Type: ExternTypeFunc, - DescFunc: 1, - }, { - Module: "Math", Name: "Add", - Type: ExternTypeFunc, - DescFunc: 0, - }, - }, - }, - expected: map[string]uint32{"import": 2, "type": 2}, - }, - { - name: "type function and start section", - input: &Module{ - TypeSection: []*FunctionType{{}}, - FunctionSection: []Index{0}, - CodeSection: []*Code{ - {Body: []byte{OpcodeLocalGet, 0, OpcodeLocalGet, 1, OpcodeI32Add, OpcodeEnd}}, - }, - ExportSection: map[string]*Export{ - "AddInt": {Name: "AddInt", Type: ExternTypeFunc, Index: Index(0)}, - }, - StartSection: &zero, - }, - expected: map[string]uint32{"code": 1, "export": 1, "function": 1, "start": 1, "type": 1}, - }, - { - name: "memory and data", - input: &Module{ - MemorySection: []*MemoryType{{Min: 1}}, - DataSection: []*DataSegment{{MemoryIndex: 0, OffsetExpression: empty}}, - }, - expected: map[string]uint32{"data": 1, "memory": 1}, - }, - { - name: "table and element", - input: &Module{ - TableSection: []*TableType{{ElemType: 0x70, Limit: &LimitsType{Min: 1}}}, - ElementSection: []*ElementSegment{{TableIndex: 0, OffsetExpr: empty}}, - }, - expected: map[string]uint32{"element": 1, "table": 1}, - }, - } - - for _, tt := range tests { - tc := tt - - t.Run(tc.name, func(t *testing.T) { - actual := map[string]uint32{} - for i := SectionID(0); i <= SectionIDData; i++ { - if size := tc.input.SectionElementCount(i); size > 0 { - actual[SectionIDName(i)] = size - } - } - require.Equal(t, tc.expected, actual) - }) - } -} - func TestValidateConstExpression(t *testing.T) { t.Run("invalid opcode", func(t *testing.T) { expr := &ConstantExpression{Opcode: OpcodeNop} @@ -379,6 +280,35 @@ func TestValidateConstExpression(t *testing.T) { }) } +func TestModule_Validate_Errors(t *testing.T) { + zero := Index(0) + tests := []struct { + name string + input *Module + expectedErr string + }{ + { + name: "StartSection points to an invalid func", + input: &Module{ + TypeSection: nil, + FunctionSection: []uint32{0}, + CodeSection: []*Code{{Body: []byte{OpcodeEnd}}}, + StartSection: &zero, + }, + expectedErr: "invalid start function: func[0] has an invalid type", + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + err := tc.input.Validate(Features20191205) + require.EqualError(t, err, tc.expectedErr) + }) + } +} + func TestModule_validateStartSection(t *testing.T) { t.Run("no start section", func(t *testing.T) { m := Module{} @@ -463,44 +393,85 @@ func TestModule_validateGlobals(t *testing.T) { } func TestModule_validateFunctions(t *testing.T) { - t.Run("type index out of range", func(t *testing.T) { - m := Module{FunctionSection: []uint32{1000 /* arbitrary large */}} - err := m.validateFunctions(nil, nil, nil, nil, Features(0)) - require.Error(t, err) - require.Contains(t, err.Error(), "function type index out of range") - }) - t.Run("insufficient code section", func(t *testing.T) { - m := Module{ - FunctionSection: []uint32{0}, - TypeSection: []*FunctionType{{}}, - // Code section not exists. - } - err := m.validateFunctions(nil, nil, nil, nil, Features(0)) - require.Error(t, err) - require.Contains(t, err.Error(), "code index out of range") - }) - t.Run("invalid function", func(t *testing.T) { - m := Module{ - FunctionSection: []uint32{0}, - TypeSection: []*FunctionType{{}}, - CodeSection: []*Code{ - {Body: []byte{OpcodeF32Abs}}, - }, - } - err := m.validateFunctions(nil, nil, nil, nil, Features(0)) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid function (0/0): cannot pop the 1st f32 operand") - }) t.Run("ok", func(t *testing.T) { m := Module{ - FunctionSection: []uint32{0}, TypeSection: []*FunctionType{{}}, - CodeSection: []*Code{ - {Body: []byte{OpcodeI32Const, 0, OpcodeDrop, OpcodeEnd}}, + FunctionSection: []uint32{0}, + CodeSection: []*Code{{Body: []byte{OpcodeI32Const, 0, OpcodeDrop, OpcodeEnd}}}, + } + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.NoError(t, err) + }) + t.Run("function, but no code", func(t *testing.T) { + m := Module{ + TypeSection: []*FunctionType{{}}, + FunctionSection: []Index{0}, + CodeSection: nil, + } + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.Error(t, err) + require.EqualError(t, err, "code count (0) != function count (1)") + }) + t.Run("function out of range of code", func(t *testing.T) { + m := Module{ + TypeSection: []*FunctionType{{}}, + FunctionSection: []Index{1}, + CodeSection: []*Code{{Body: []byte{OpcodeEnd}}}, + } + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.Error(t, err) + require.EqualError(t, err, "invalid function[0]: type section index 1 out of range") + }) + t.Run("invalid", func(t *testing.T) { + m := Module{ + TypeSection: []*FunctionType{{}}, + FunctionSection: []Index{0}, + CodeSection: []*Code{{Body: []byte{OpcodeF32Abs}}}, + } + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid function[0]: cannot pop the 1st f32 operand") + }) + t.Run("invalid - exported", func(t *testing.T) { + m := Module{ + TypeSection: []*FunctionType{{}}, + FunctionSection: []Index{0}, + CodeSection: []*Code{{Body: []byte{OpcodeF32Abs}}}, + ExportSection: map[string]*Export{"f1": {Name: "f1", Type: ExternTypeFunc, Index: 0}}, + } + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.Error(t, err) + require.Contains(t, err.Error(), `invalid function[0] (export "f1"): cannot pop the 1st f32 operand`) + }) + t.Run("invalid - exported after import", func(t *testing.T) { + m := Module{ + TypeSection: []*FunctionType{{}}, + ImportSection: []*Import{{Type: ExternTypeFunc}}, + FunctionSection: []Index{0}, + CodeSection: []*Code{{Body: []byte{OpcodeF32Abs}}}, + ExportSection: map[string]*Export{"f1": {Name: "f1", Type: ExternTypeFunc, Index: 1}}, + } + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.Error(t, err) + require.Contains(t, err.Error(), `invalid function[0] (export "f1"): cannot pop the 1st f32 operand`) + }) + t.Run("invalid - exported twice", func(t *testing.T) { + m := Module{ + TypeSection: []*FunctionType{{}}, + FunctionSection: []Index{0}, + CodeSection: []*Code{{Body: []byte{OpcodeF32Abs}}}, + ExportSection: map[string]*Export{ + "f1": {Name: "f1", Type: ExternTypeFunc, Index: 0}, + "f2": {Name: "f2", Type: ExternTypeFunc, Index: 0}, }, } - err := m.validateFunctions(nil, nil, nil, nil, Features(0)) - require.NoError(t, err) + err := m.validateFunctions(nil, nil, nil, nil, Features20191205) + require.Error(t, err) + + // go map keys do not iterate consistently + if !strings.Contains(err.Error(), `invalid function[0] (export "f1","f2"): cannot pop the 1st f32 operand`) { + require.Contains(t, err.Error(), `invalid function[0] (export "f2","f1"): cannot pop the 1st f32 operand`) + } }) } @@ -718,7 +689,7 @@ func TestModule_buildMemoryInstance(t *testing.T) { t.Run("non-nil", func(t *testing.T) { min := uint32(1) max := uint32(10) - m := Module{MemorySection: []*MemoryType{&LimitsType{Min: min, Max: &max}}} + m := Module{MemorySection: []*MemoryType{{Min: min, Max: &max}}} mem := m.buildMemory() require.Equal(t, min, mem.Min) require.Equal(t, max, *mem.Max) diff --git a/internal/wasm/store.go b/internal/wasm/store.go index b650ee52..01254b62 100644 --- a/internal/wasm/store.go +++ b/internal/wasm/store.go @@ -425,16 +425,16 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error) module.buildFunctions(), module.buildGlobals(importedGlobals), module.buildTable(), module.buildMemory() // Now we have all instances from imports and local ones, so ready to create a new ModuleInstance. - instance := &ModuleInstance{Name: name} - instance.addSections(module, importedFunctions, functions, importedGlobals, + m := &ModuleInstance{Name: name} + m.addSections(module, importedFunctions, functions, importedGlobals, globals, importedTable, table, importedMemory, memory, types, moduleImports) - if err = instance.validateElements(module.ElementSection); err != nil { + if err = m.validateElements(module.ElementSection); err != nil { s.deleteModule(name) return nil, err } - if err = instance.validateData(module.DataSection); err != nil { + if err = m.validateData(module.DataSection); err != nil { s.deleteModule(name) return nil, err } @@ -446,18 +446,17 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error) if err = s.engine.Compile(f); err != nil { s.deleteModule(name) // On the failure, release the assigned funcaddr and already compiled functions. - _ = s.releaseFunctions(functions[:i]...) // ignore any release error so we can report the original one. - idx := module.SectionElementCount(SectionIDFunction) - 1 - return nil, fmt.Errorf("compilation failed at index %d/%d: %w", i, idx, err) + _ = s.releaseFunctions(functions[:i]...) // ignore any release error, so we can report the original one. + return nil, fmt.Errorf("%s compilation failed: %w", module.funcDesc(SectionIDFunction, Index(i)), err) } } // Now all the validation passes, we are safe to mutate memory/table instances (possibly imported ones). - instance.applyElements(module.ElementSection) - instance.applyData(module.DataSection) + m.applyElements(module.ElementSection) + m.applyData(module.DataSection) // Build the default context for calls to this module. - instance.Ctx = NewModuleContext(s.ctx, s.engine, instance) + m.Ctx = NewModuleContext(s.ctx, s.engine, m) // Plus, we can finalize the module import reference count. moduleImportsFinalized = true @@ -465,15 +464,15 @@ func (s *Store) Instantiate(module *Module, name string) (*PublicModule, error) // Execute the start function. if module.StartSection != nil { funcIdx := *module.StartSection - if _, err = s.engine.Call(instance.Ctx, instance.Functions[funcIdx]); err != nil { + if _, err = s.engine.Call(m.Ctx, m.Functions[funcIdx]); err != nil { s.deleteModule(name) - return nil, fmt.Errorf("module[%s] start function failed: %w", name, err) + return nil, fmt.Errorf("start %s failed: %w", module.funcDesc(SectionIDFunction, funcIdx), err) } } // Now that the instantiation is complete without error, add it. This makes it visible for import. - s.addModule(instance) - return &PublicModule{s: s, instance: instance}, nil + s.addModule(m) + return &PublicModule{s: s, instance: m}, nil } // ReleaseModule deallocates resources if a module with the given name exists. @@ -634,8 +633,8 @@ func (s *Store) HostModule(moduleName string) publicwasm.HostModule { } func (s *Store) resolveImports(module *Module) ( - functions []*FunctionInstance, globals []*GlobalInstance, - table *TableInstance, memory *MemoryInstance, + importedFunctions []*FunctionInstance, importedGlobals []*GlobalInstance, + importedTable *TableInstance, importedMemory *MemoryInstance, moduleImports map[*ModuleInstance]struct{}, err error, ) { @@ -643,89 +642,121 @@ func (s *Store) resolveImports(module *Module) ( defer s.mux.RUnlock() moduleImports = map[*ModuleInstance]struct{}{} - for _, is := range module.ImportSection { - m, ok := s.modules[is.Module] + for idx, i := range module.ImportSection { + m, ok := s.modules[i.Module] if !ok { - err = fmt.Errorf("module \"%s\" not instantiated", is.Module) + err = fmt.Errorf("module[%s] not instantiated", i.Module) return } m.incDependentCount() // TODO: check if the module is already released. See #293 moduleImports[m] = struct{}{} - var exp *ExportInstance - exp, err = m.getExport(is.Name, is.Type) + var imported *ExportInstance + imported, err = m.getExport(i.Name, i.Type) if err != nil { return } - switch is.Type { + switch i.Type { case ExternTypeFunc: - typeIndex := is.DescFunc + typeIndex := i.DescFunc + // TODO: this shouldn't be possible as invalid should fail validate if int(typeIndex) >= len(module.TypeSection) { - err = fmt.Errorf("unknown type for function import") + err = errorInvalidImport(i, idx, fmt.Errorf("function type out of range")) return } - expectedType := module.TypeSection[typeIndex] - f := exp.Function - if !bytes.Equal(expectedType.Results, f.Type.Results) || !bytes.Equal(expectedType.Params, f.Type.Params) { - err = fmt.Errorf("signature mimatch: %s != %s", expectedType, f.Type) - return - } - functions = append(functions, f) - case ExternTypeTable: - tableType := is.DescTable - table = exp.Table - if table.ElemType != tableType.ElemType { - err = fmt.Errorf("incompatible table import: element type mismatch") - return - } - if table.Min < tableType.Limit.Min { - err = fmt.Errorf("incompatible table import: minimum size mismatch") + expectedType := module.TypeSection[i.DescFunc] + importedFunction := imported.Function + + actualType := importedFunction.Type + if !expectedType.EqualsSignature(actualType.Params, actualType.Results) { + err = errorInvalidImport(i, idx, fmt.Errorf("signature mismatch: %s != %s", expectedType, actualType)) return } - if tableType.Limit.Max != nil { - if table.Max == nil { - err = fmt.Errorf("incompatible table import: maximum size mismatch") + importedFunctions = append(importedFunctions, importedFunction) + case ExternTypeTable: + expected := i.DescTable + importedTable = imported.Table + + if importedTable.ElemType != expected.ElemType { + err = errorInvalidImport(i, idx, fmt.Errorf("element type mismatch: %s != %s", + ValueTypeName(expected.ElemType), ValueTypeName(importedTable.ElemType))) + return + } + + if expected.Limit.Min > importedTable.Min { + err = errorMinSizeMismatch(i, idx, expected.Limit.Min, importedTable.Min) + return + } + + if expected.Limit.Max != nil { + expectedMax := *expected.Limit.Max + if importedTable.Max == nil { + err = errorNoMax(i, idx, expectedMax) return - } else if *table.Max > *tableType.Limit.Max { - err = fmt.Errorf("incompatible table import: maximum size mismatch") + } else if expectedMax < *importedTable.Max { + err = errorMaxSizeMismatch(i, idx, expectedMax, *importedTable.Max) return } } case ExternTypeMemory: - memoryType := is.DescMem - memory = exp.Memory - if memory.Min < memoryType.Min { - err = fmt.Errorf("incompatible memory import: minimum size mismatch") + expected := i.DescMem + importedMemory = imported.Memory + + if expected.Min > importedMemory.Min { + err = errorMinSizeMismatch(i, idx, expected.Min, importedMemory.Min) return } - if memoryType.Max != nil { - if memory.Max == nil { - err = fmt.Errorf("incompatible memory import: maximum size mismatch") + + if expected.Max != nil { + expectedMax := *expected.Max + if importedMemory.Max == nil { + err = errorNoMax(i, idx, expectedMax) return - } else if *memory.Max > *memoryType.Max { - err = fmt.Errorf("incompatible memory import: maximum size mismatch") + } else if expectedMax < *importedMemory.Max { + err = errorMaxSizeMismatch(i, idx, expectedMax, *importedMemory.Max) return } } case ExternTypeGlobal: - globalType := is.DescGlobal - g := exp.Global - if globalType.Mutable != g.Type.Mutable { - err = fmt.Errorf("incompatible global import: mutability mismatch") - return - } else if globalType.ValType != g.Type.ValType { - err = fmt.Errorf("incompatible global import: value type mismatch") + expected := i.DescGlobal + importedGlobal := imported.Global + + if expected.Mutable != importedGlobal.Type.Mutable { + err = errorInvalidImport(i, idx, fmt.Errorf("mutability mismatch: %t != %t", + expected.Mutable, importedGlobal.Type.Mutable)) return } - globals = append(globals, g) + + if expected.ValType != importedGlobal.Type.ValType { + err = errorInvalidImport(i, idx, fmt.Errorf("value type mismatch: %s != %s", + ValueTypeName(expected.ValType), ValueTypeName(importedGlobal.Type.ValType))) + return + } + importedGlobals = append(importedGlobals, importedGlobal) } } return } +func errorMinSizeMismatch(i *Import, idx int, expected, actual uint32) error { + return errorInvalidImport(i, idx, fmt.Errorf("minimum size mismatch: %d > %d", expected, actual)) +} + +func errorNoMax(i *Import, idx int, expected uint32) error { + return errorInvalidImport(i, idx, fmt.Errorf("maximum size mismatch: %d, but actual has no max", expected)) +} + +func errorMaxSizeMismatch(i *Import, idx int, expected, actual uint32) error { + return errorInvalidImport(i, idx, fmt.Errorf("maximum size mismatch: %d < %d", expected, actual)) +} + +func errorInvalidImport(i *Import, idx int, err error) error { + return fmt.Errorf("import[%d] %s[%s.%s]: %w", idx, ExternTypeName(i.Type), i.Module, i.Name, err) +} + func executeConstExpression(globals []*GlobalInstance, expr *ConstantExpression) (v interface{}) { r := bytes.NewReader(expr.Data) switch expr.Opcode { diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index c629221c..af8bec2d 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -240,7 +240,7 @@ func TestStore_concurrent(t *testing.T) { require.Len(t, s.modules, 0) } -func TestSotre_Instantiate_Errors(t *testing.T) { +func TestStore_Instantiate_Errors(t *testing.T) { const importedModuleName = "imported" const importingModuleName = "test" @@ -255,13 +255,13 @@ func TestSotre_Instantiate_Errors(t *testing.T) { _, err = s.Instantiate(&Module{ TypeSection: []*FunctionType{{}}, ImportSection: []*Import{ - // The fisrt import resolve succeeds -> increment hm.dependentCount. + // The first import resolve succeeds -> increment hm.dependentCount. {Type: ExternTypeFunc, Module: importedModuleName, Name: "fn", DescFunc: 0}, // But the second one tries to import uninitialized-module -> {Type: ExternTypeFunc, Module: "non-exist", Name: "fn", DescFunc: 0}, }, }, importingModuleName) - require.EqualError(t, err, "module \"non-exist\" not instantiated") + require.EqualError(t, err, "module[non-exist] not instantiated") // hm.dependentCount must be intact as the instantiation failed. require.Zero(t, hm.dependentCount) @@ -285,7 +285,7 @@ func TestSotre_Instantiate_Errors(t *testing.T) { {Body: []byte{OpcodeEnd}}, // FunctionIndex = 1 {Body: []byte{OpcodeEnd}}, // FunctionIndex = 2 {Body: []byte{OpcodeEnd}}, // FunctionIndex = 3 == compilation failing function. - // Functions after failrued must not be passed to engine.Release. + // Functions after failure must not be passed to engine.Release. {Body: []byte{OpcodeEnd}}, {Body: []byte{OpcodeEnd}}, {Body: []byte{OpcodeEnd}}, @@ -295,7 +295,7 @@ func TestSotre_Instantiate_Errors(t *testing.T) { {Type: ExternTypeFunc, Module: importedModuleName, Name: "fn", DescFunc: 0}, }, }, importingModuleName) - require.EqualError(t, err, "compilation failed at index 2/2: compilation failed") + require.EqualError(t, err, "function[2] compilation failed: compilation failed") // hm.dependentCount must be intact as the instantiation failed. require.Zero(t, hm.dependentCount) @@ -324,7 +324,7 @@ func TestSotre_Instantiate_Errors(t *testing.T) { {Type: ExternTypeFunc, Module: importedModuleName, Name: "fn", DescFunc: 0}, }, }, importingModuleName) - require.EqualError(t, err, "module[test] start function failed: call failed") + require.EqualError(t, err, "start function[1] failed: call failed") // hm.dependentCount must stay incremented as the instantiation itself has already succeeded. require.Equal(t, 1, hm.dependentCount) @@ -658,7 +658,7 @@ func TestStore_resolveImports(t *testing.T) { t.Run("module not instantiated", func(t *testing.T) { s := newStore() _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: "unknown", Name: "unknown"}}}) - require.EqualError(t, err, "module \"unknown\" not instantiated") + require.EqualError(t, err, "module[unknown] not instantiated") }) t.Run("export instance not found", func(t *testing.T) { s := newStore() @@ -667,24 +667,6 @@ func TestStore_resolveImports(t *testing.T) { require.EqualError(t, err, "\"unknown\" is not exported in module \"test\"") }) t.Run("func", func(t *testing.T) { - t.Run("unknwon type", func(t *testing.T) { - s := newStore() - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: {}}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeFunc, DescFunc: 100}}}) - require.EqualError(t, err, "unknown type for function import") - }) - t.Run("signature mismatch", func(t *testing.T) { - s := newStore() - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Function: &FunctionInstance{Type: &FunctionType{}}, - }}, Name: moduleName} - m := &Module{ - TypeSection: []*FunctionType{{Results: []ValueType{ValueTypeF32}}}, - ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeFunc, DescFunc: 0}}, - } - _, _, _, _, _, err := s.resolveImports(m) - require.EqualError(t, err, "signature mimatch: v_f32 != v_v") - }) t.Run("ok", func(t *testing.T) { s := newStore() f := &FunctionInstance{Type: &FunctionType{Results: []ValueType{ValueTypeF32}}} @@ -701,26 +683,26 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, functions, f) require.Equal(t, 1, s.modules[moduleName].dependentCount) }) + t.Run("type out of range", func(t *testing.T) { + s := newStore() + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: {}}, Name: moduleName} + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeFunc, DescFunc: 100}}}) + require.EqualError(t, err, "import[0] func[test.target]: function type out of range") + }) + t.Run("signature mismatch", func(t *testing.T) { + s := newStore() + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { + Function: &FunctionInstance{Type: &FunctionType{}}, + }}, Name: moduleName} + m := &Module{ + TypeSection: []*FunctionType{{Results: []ValueType{ValueTypeF32}}}, + ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeFunc, DescFunc: 0}}, + } + _, _, _, _, _, err := s.resolveImports(m) + require.EqualError(t, err, "import[0] func[test.target]: signature mismatch: v_f32 != v_v") + }) }) t.Run("global", func(t *testing.T) { - t.Run("mutability mismatch", func(t *testing.T) { - s := newStore() - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeGlobal, - Global: &GlobalInstance{Type: &GlobalType{Mutable: false}}, - }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: &GlobalType{Mutable: true}}}}) - require.EqualError(t, err, "incompatible global import: mutability mismatch") - }) - t.Run("type mismatch", func(t *testing.T) { - s := newStore() - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeGlobal, - Global: &GlobalInstance{Type: &GlobalType{ValType: ValueTypeI32}}, - }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: &GlobalType{ValType: ValueTypeF64}}}}) - require.EqualError(t, err, "incompatible global import: value type mismatch") - }) t.Run("ok", func(t *testing.T) { s := newStore() inst := &GlobalInstance{Type: &GlobalType{ValType: ValueTypeI32}} @@ -730,38 +712,26 @@ func TestStore_resolveImports(t *testing.T) { require.Contains(t, globals, inst) require.Equal(t, 1, s.modules[moduleName].dependentCount) }) + t.Run("mutability mismatch", func(t *testing.T) { + s := newStore() + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { + Type: ExternTypeGlobal, + Global: &GlobalInstance{Type: &GlobalType{Mutable: false}}, + }}, Name: moduleName} + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: &GlobalType{Mutable: true}}}}) + require.EqualError(t, err, "import[0] global[test.target]: mutability mismatch: true != false") + }) + t.Run("type mismatch", func(t *testing.T) { + s := newStore() + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { + Type: ExternTypeGlobal, + Global: &GlobalInstance{Type: &GlobalType{ValType: ValueTypeI32}}, + }}, Name: moduleName} + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: &GlobalType{ValType: ValueTypeF64}}}}) + require.EqualError(t, err, "import[0] global[test.target]: value type mismatch: f64 != i32") + }) }) t.Run("table", func(t *testing.T) { - t.Run("element type", func(t *testing.T) { - s := newStore() - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeTable, - Table: &TableInstance{ElemType: 0x00}, // Unknown! - }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: &TableType{ElemType: 0x1}}}}) - require.EqualError(t, err, "incompatible table import: element type mismatch") - }) - t.Run("minimum size mismatch", func(t *testing.T) { - s := newStore() - importTableType := &TableType{Limit: &LimitsType{Min: 2}} - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeTable, - Table: &TableInstance{Min: importTableType.Limit.Min - 1}, - }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: importTableType}}}) - require.EqualError(t, err, "incompatible table import: minimum size mismatch") - }) - t.Run("maximum size mismatch", func(t *testing.T) { - s := newStore() - max := uint32(10) - importTableType := &TableType{Limit: &LimitsType{Max: &max}} - s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeTable, - Table: &TableInstance{Min: importTableType.Limit.Min - 1}, - }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: importTableType}}}) - require.EqualError(t, err, "incompatible table import: maximum size mismatch") - }) t.Run("ok", func(t *testing.T) { s := newStore() max := uint32(10) @@ -775,29 +745,38 @@ func TestStore_resolveImports(t *testing.T) { require.Equal(t, table, tableInst) require.Equal(t, 1, s.modules[moduleName].dependentCount) }) - }) - t.Run("memory", func(t *testing.T) { + t.Run("element type", func(t *testing.T) { + s := newStore() + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { + Type: ExternTypeTable, + Table: &TableInstance{ElemType: 0x00}, // Unknown! + }}, Name: moduleName} + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: &TableType{ElemType: 0x1}}}}) + require.EqualError(t, err, "import[0] table[test.target]: element type mismatch: unknown != unknown") + }) t.Run("minimum size mismatch", func(t *testing.T) { s := newStore() - importMemoryType := &MemoryType{Min: 2} + importTableType := &TableType{Limit: &LimitsType{Min: 2}} s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeMemory, - Memory: &MemoryInstance{Min: importMemoryType.Min - 1}, + Type: ExternTypeTable, + Table: &TableInstance{Min: importTableType.Limit.Min - 1}, }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}}) - require.EqualError(t, err, "incompatible memory import: minimum size mismatch") + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: importTableType}}}) + require.EqualError(t, err, "import[0] table[test.target]: minimum size mismatch: 2 > 1") }) t.Run("maximum size mismatch", func(t *testing.T) { s := newStore() max := uint32(10) - importMemoryType := &MemoryType{Max: &max} + importTableType := &TableType{Limit: &LimitsType{Max: &max}} s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { - Type: ExternTypeMemory, - Memory: &MemoryInstance{}, + Type: ExternTypeTable, + Table: &TableInstance{Min: importTableType.Limit.Min - 1}, }}, Name: moduleName} - _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}}) - require.EqualError(t, err, "incompatible memory import: maximum size mismatch") + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeTable, DescTable: importTableType}}}) + require.EqualError(t, err, "import[0] table[test.target]: maximum size mismatch: 10, but actual has no max") }) + }) + t.Run("memory", func(t *testing.T) { t.Run("ok", func(t *testing.T) { s := newStore() max := uint32(10) @@ -811,6 +790,27 @@ func TestStore_resolveImports(t *testing.T) { require.Equal(t, memory, memoryInst) require.Equal(t, 1, s.modules[moduleName].dependentCount) }) + t.Run("minimum size mismatch", func(t *testing.T) { + s := newStore() + importMemoryType := &MemoryType{Min: 2} + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { + Type: ExternTypeMemory, + Memory: &MemoryInstance{Min: importMemoryType.Min - 1}, + }}, Name: moduleName} + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}}) + require.EqualError(t, err, "import[0] memory[test.target]: minimum size mismatch: 2 > 1") + }) + t.Run("maximum size mismatch", func(t *testing.T) { + s := newStore() + max := uint32(10) + importMemoryType := &MemoryType{Max: &max} + s.modules[moduleName] = &ModuleInstance{Exports: map[string]*ExportInstance{name: { + Type: ExternTypeMemory, + Memory: &MemoryInstance{}, + }}, Name: moduleName} + _, _, _, _, _, err := s.resolveImports(&Module{ImportSection: []*Import{{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}}) + require.EqualError(t, err, "import[0] memory[test.target]: maximum size mismatch: 10, but actual has no max") + }) }) } diff --git a/internal/wasm/text/decoder.go b/internal/wasm/text/decoder.go index cd3c5bc2..08233810 100644 --- a/internal/wasm/text/decoder.go +++ b/internal/wasm/text/decoder.go @@ -721,7 +721,7 @@ INLINED: // A type can be defined after its type use. Ex. (module (func (param i32)) (type (func (param i32))) // This uses an inner loop to avoid creating a large map for an edge case. for realIdx, t := range p.module.TypeSection { - if funcTypeEquals(t, inlined.Params, inlined.Results) { + if t.EqualsSignature(inlined.Params, inlined.Results) { inlinedToRealIdx[inlinedIdx] = wasm.Index(realIdx) continue INLINED } diff --git a/internal/wasm/text/typeuse_parser.go b/internal/wasm/text/typeuse_parser.go index 94daefa4..ae6cb62f 100644 --- a/internal/wasm/text/typeuse_parser.go +++ b/internal/wasm/text/typeuse_parser.go @@ -1,7 +1,6 @@ package text import ( - "bytes" "errors" "fmt" @@ -390,7 +389,7 @@ func (p *typeUseParser) end(pos callbackPosition, tok tokenType, tokenBytes []by func (p *typeUseParser) inlinedTypeIndex() wasm.Index { it := p.currentInlinedType for i, t := range p.module.TypeSection { - if funcTypeEquals(t, it.Params, it.Results) { + if t.EqualsSignature(it.Params, it.Results) { return wasm.Index(i) } } @@ -424,7 +423,7 @@ func (p *typeUseParser) typeFieldIndex() (wasm.Index, error) { // it didn't already exist. func (p *typeUseParser) maybeAddInlinedType(it *wasm.FunctionType) { for i, t := range p.inlinedTypes { - if funcTypeEquals(t, it.Params, it.Results) { + if t.EqualsSignature(it.Params, it.Results) { p.recordInlinedType(wasm.Index(i)) return } @@ -452,12 +451,8 @@ func (p *typeUseParser) recordInlinedType(inlinedIdx wasm.Index) { // >> If inline declarations are given, then their types must match the referenced function type. // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#type-uses%E2%91%A0 func requireInlinedMatchesReferencedType(typeSection []*wasm.FunctionType, index wasm.Index, params, results []wasm.ValueType) error { - if !funcTypeEquals(typeSection[index], params, results) { + if !typeSection[index].EqualsSignature(params, results) { return fmt.Errorf("inlined type doesn't match module.type[%d].func", index) } return nil } - -func funcTypeEquals(f *wasm.FunctionType, params []wasm.ValueType, results []wasm.ValueType) bool { - return bytes.Equal(f.Params, params) && bytes.Equal(f.Results, results) -}