diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go index ab13ca7c..6887024e 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go @@ -215,7 +215,15 @@ func (m *machine) LowerInstr(instr *ssa.Instruction) { m.insert(undef) case ssa.OpcodeSelect: c, x, y := instr.SelectData() - m.lowerSelect(c, x, y, instr.Return()) + if x.Type() == ssa.TypeV128 { + rc := m.getOperand_NR(m.compiler.ValueDefinition(c), extModeNone) + rn := m.getOperand_NR(m.compiler.ValueDefinition(x), extModeNone) + rm := m.getOperand_NR(m.compiler.ValueDefinition(y), extModeNone) + rd := operandNR(m.compiler.VRegOf(instr.Return())) + m.lowerSelectVec(rc, rn, rm, rd) + } else { + m.lowerSelect(c, x, y, instr.Return()) + } case ssa.OpcodeClz: x := instr.Arg() result := instr.Return() @@ -1072,7 +1080,7 @@ func (m *machine) lowerVMinMaxPseudo(instr *ssa.Instruction, max bool) { if max { fcmgt.asVecRRR(vecOpFcmgt, rd, rm, rn, arr) } else { - // if min, swap the args + // If min, swap the args. fcmgt.asVecRRR(vecOpFcmgt, rd, rn, rm, arr) } m.insert(fcmgt) @@ -1818,5 +1826,27 @@ func (m *machine) lowerSelect(c, x, y, result ssa.Value) { fcsel := m.allocateInstr() fcsel.asFpuCSel(rd, rn, rm, cc, x.Type().Bits() == 64) m.insert(fcsel) + default: + panic("BUG") } } + +func (m *machine) lowerSelectVec(rc, rn, rm, rd operand) { + tmp := operandNR(m.compiler.AllocateVReg(regalloc.RegTypeInt)) + + // Sets all bits to 1 if rc is not zero. + alu := m.allocateInstr() + alu.asALU(aluOpSub, tmp, operandNR(xzrVReg), rc, true) + m.insert(alu) + + // Then move the bits to the result vector register. + dup := m.allocateInstr() + dup.asVecDup(rd, tmp, vecArrangement2D) + m.insert(dup) + + // Now that `rd` has either all bits one or zero depending on `rc`, + // we can use bsl to select between `rn` and `rm`. + ins := m.allocateInstr() + ins.asVecRRR(vecOpBsl, rd, rn, rm, vecArrangement16B) + m.insert(ins) +} diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go index 4e9b7081..1331a1c8 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go @@ -831,3 +831,23 @@ ushl x1.16b, x2.16b, x1.16b }) } } + +func TestMachine_lowerSelectVec(t *testing.T) { + _, _, m := newSetupWithMockContext() + c := operandNR(m.compiler.AllocateVReg(regalloc.RegTypeInt)) + rn := operandNR(m.compiler.AllocateVReg(regalloc.RegTypeFloat)) + rm := operandNR(m.compiler.AllocateVReg(regalloc.RegTypeFloat)) + rd := operandNR(m.compiler.AllocateVReg(regalloc.RegTypeFloat)) + + require.Equal(t, 1, int(c.reg().ID())) + require.Equal(t, 2, int(rn.reg().ID())) + require.Equal(t, 3, int(rm.reg().ID())) + require.Equal(t, 4, int(rd.reg().ID())) + + m.lowerSelectVec(c, rn, rm, rd) + require.Equal(t, ` +sub x5?, xzr, x1? +dup v4?.2d, x5? +bsl v4?.16b, v2?.16b, v3?.16b +`, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") +} diff --git a/internal/integration_test/fuzzcases/fuzzcases_test.go b/internal/integration_test/fuzzcases/fuzzcases_test.go index 8aa2f90a..b0f9aa65 100644 --- a/internal/integration_test/fuzzcases/fuzzcases_test.go +++ b/internal/integration_test/fuzzcases/fuzzcases_test.go @@ -4,10 +4,13 @@ import ( "context" "embed" "fmt" + "runtime" + "strings" "testing" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/internal/engine/wazevo" "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/internal/testing/binaryencoding" "github.com/tetratelabs/wazero/internal/testing/require" @@ -44,9 +47,28 @@ func runWithInterpreter(t *testing.T, runner func(t *testing.T, r wazero.Runtime }) } +func runWithWazevo(t *testing.T, runner func(t *testing.T, r wazero.Runtime)) { + t.Run("wazevo", func(t *testing.T) { + name := t.Name() + for _, skipTarget := range []string{"695", "701", "718"} { + if strings.Contains(name, skipTarget) { + t.Skip("TODO: skipping for wazevo until SIMD is completed") + } + } + config := wazero.NewRuntimeConfigInterpreter() + wazevo.ConfigureWazevo(config) + r := wazero.NewRuntimeWithConfig(ctx, config) + defer r.Close(ctx) + runner(t, r) + }) +} + func run(t *testing.T, runner func(t *testing.T, r wazero.Runtime)) { runWithInterpreter(t, runner) runWithCompiler(t, runner) + if runtime.GOARCH == "arm64" { + runWithWazevo(t, runner) + } } // Test695 requires two functions to exit with "out of bounds memory access" consistently across the implementations. @@ -66,20 +88,22 @@ func Test695(t *testing.T) { } func Test696(t *testing.T) { - functionNames := [4]string{ - "select with 0 / after calling dummy", - "select with 0", - "typed select with 1 / after calling dummy", - "typed select with 1", - } - run(t, func(t *testing.T, r wazero.Runtime) { module, err := r.Instantiate(ctx, getWasmBinary(t, 696)) require.NoError(t, err) - - for _, name := range functionNames { - _, err := module.ExportedFunction(name).Call(ctx) + for _, tc := range []struct { + fnName string + in uint64 + exp [2]uint64 + }{ + {fnName: "select", in: 1, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, + {fnName: "select", in: 0, exp: [2]uint64{0x1111111111111111, 0x2222222222222222}}, + {fnName: "typed select", in: 1, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, + {fnName: "typed select", in: 0, exp: [2]uint64{0x1111111111111111, 0x2222222222222222}}, + } { + res, err := module.ExportedFunction(tc.fnName).Call(ctx, tc.in) require.NoError(t, err) + require.Equal(t, tc.exp[:], res) } }) } diff --git a/internal/integration_test/fuzzcases/testdata/696.wasm b/internal/integration_test/fuzzcases/testdata/696.wasm index e0eb7547..eab54dae 100644 Binary files a/internal/integration_test/fuzzcases/testdata/696.wasm and b/internal/integration_test/fuzzcases/testdata/696.wasm differ diff --git a/internal/integration_test/fuzzcases/testdata/696.wat b/internal/integration_test/fuzzcases/testdata/696.wat index 4aae7535..f15a6dce 100644 --- a/internal/integration_test/fuzzcases/testdata/696.wat +++ b/internal/integration_test/fuzzcases/testdata/696.wat @@ -1,64 +1,18 @@ (module (func $dummy) - (func (export "select with 0 / after calling dummy") - v128.const i64x2 0xffffffffffffffff 0xffffffffffffffff - v128.const i64x2 0xeeeeeeeeeeeeeeee 0xeeeeeeeeeeeeeeee - i32.const 0 ;; choose 0xeeeeeeeeeeeeeeee lane. + (func (export "select") (param i32) (result v128) + v128.const i64x2 0xffffffffffffffff 0xeeeeeeeeeeeeeeee + v128.const i64x2 0x1111111111111111 0x2222222222222222 + local.get 0 call 0 ;; calling dummy function before select to select - ;; check the equality. - i64x2.extract_lane 0 - i64.const 0xeeeeeeeeeeeeeeee - i64.eq - (if - (then) - (else unreachable) - ) ) - (func (export "select with 0") - v128.const i64x2 0xffffffffffffffff 0xffffffffffffffff - v128.const i64x2 0xeeeeeeeeeeeeeeee 0xeeeeeeeeeeeeeeee - i32.const 0 ;; choose 0xeeeeeeeeeeeeeeee lane. - select - ;; check the equality. - i64x2.extract_lane 0 - i64.const 0xeeeeeeeeeeeeeeee - i64.eq - (if - (then) - (else unreachable) - ) - ) - - (func (export "typed select with 1 / after calling dummy") - v128.const i64x2 0xffffffffffffffff 0xffffffffffffffff - v128.const i64x2 0xeeeeeeeeeeeeeeee 0xeeeeeeeeeeeeeeee - i32.const 1 ;; choose 0xffffffffffffffff lane. + (func (export "typed select") (param i32) (result v128) + v128.const i64x2 0xffffffffffffffff 0xeeeeeeeeeeeeeeee + v128.const i64x2 0x1111111111111111 0x2222222222222222 + local.get 0 call 0 ;; calling dummy function before select to select (result v128) - ;; check the equality. - i64x2.extract_lane 0 - i64.const 0xffffffffffffffff - i64.eq - (if - (then) - (else unreachable) - ) - ) - - (func (export "typed select with 1") - v128.const i64x2 0xffffffffffffffff 0xffffffffffffffff - v128.const i64x2 0xeeeeeeeeeeeeeeee 0xeeeeeeeeeeeeeeee - i32.const 1 ;; choose 0xffffffffffffffff lane. - select (result v128) - ;; check the equality. - i64x2.extract_lane 0 - i64.const 0xffffffffffffffff - i64.eq - (if - (then) - (else unreachable) - ) ) )