diff --git a/internal/engine/wazevo/backend/regalloc/assign.go b/internal/engine/wazevo/backend/regalloc/assign.go index 8dc4c2ba..b15c31a3 100644 --- a/internal/engine/wazevo/backend/regalloc/assign.go +++ b/internal/engine/wazevo/backend/regalloc/assign.go @@ -10,33 +10,34 @@ import ( // This is called after coloring is done. func (a *Allocator) assignRegisters(f Function) { for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nil; blk = f.ReversePostOrderBlockIteratorNext() { - info := a.blockInfoAt(blk.ID()) - lns := info.liveNodes - a.assignRegistersPerBlock(f, blk, a.vRegIDToNode, lns) + a.assignRegistersPerBlock(f, blk, a.vRegIDToNode) } } // assignRegistersPerBlock assigns real registers to virtual registers on each instruction in a block. -func (a *Allocator) assignRegistersPerBlock(f Function, blk Block, vRegIDToNode []*node, liveNodes []liveNodeInBlock) { +func (a *Allocator) assignRegistersPerBlock(f Function, blk Block, vRegIDToNode []*node) { if wazevoapi.RegAllocLoggingEnabled { fmt.Println("---------------------- assigning registers for block", blk.ID(), "----------------------") } + blkID := blk.ID() var pc programCounter for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() { - a.assignRegistersPerInstr(f, pc, instr, vRegIDToNode, liveNodes) + tree := a.blockInfos[blkID].intervalTree + a.assignRegistersPerInstr(f, pc, instr, vRegIDToNode, tree) pc += pcStride } } -func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, liveNodes []liveNodeInBlock) { +func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, tree *intervalTree) { if indirect := instr.IsIndirectCall(); instr.IsCall() || indirect { // Only take care of non-real VRegs (e.g. VReg.IsRealReg() == false) since // the real VRegs are already placed in the right registers at this point. - a.collectActiveNonRealVRegsAt( + tree.collectActiveNonRealVRegsAt( // To find the all the live registers "after" call, we need to add pcDefOffset for search. pc+pcDefOffset, - liveNodes) + &a.nodes1, + ) for _, active := range a.nodes1 { if r := active.r; a.regInfo.isCallerSaved(r) { v := active.v.SetRealReg(r) @@ -100,19 +101,19 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr panic("BUG: multiple def instructions must be special cased") } - a.handleSpills(f, pc, instr, liveNodes, usesSpills, defSpill) + a.handleSpills(f, pc, instr, usesSpills, defSpill, tree) a.vs = usesSpills[:0] // for reuse. } func (a *Allocator) handleSpills( - f Function, pc programCounter, instr Instr, liveNodes []liveNodeInBlock, - usesSpills []VReg, defSpill VReg, + f Function, pc programCounter, instr Instr, + usesSpills []VReg, defSpill VReg, tree *intervalTree, ) { _usesSpills, _defSpill := len(usesSpills) > 0, defSpill.Valid() switch { case !_usesSpills && !_defSpill: // Nothing to do. case !_usesSpills && _defSpill: // Only definition is spilled. 
- a.collectActiveNodesAt(pc+pcDefOffset, liveNodes) + tree.collectActiveRealRegNodesAt(pc+pcDefOffset, &a.nodes1) a.spillHandler.init(a.nodes1, instr) r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo) @@ -128,7 +129,7 @@ func (a *Allocator) handleSpills( f.StoreRegisterAfter(defSpill, instr) case _usesSpills: - a.collectActiveNodesAt(pc, liveNodes) + tree.collectActiveRealRegNodesAt(pc, &a.nodes1) a.spillHandler.init(a.nodes1, instr) var evicted [3]*node @@ -172,7 +173,7 @@ func (a *Allocator) handleSpills( if !defSpill.IsRealReg() { // This case, the destination register type is different from the source registers. - a.collectActiveNodesAt(pc+pcDefOffset, liveNodes) + tree.collectActiveRealRegNodesAt(pc+pcDefOffset, &a.nodes1) a.spillHandler.init(a.nodes1, instr) r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo) if evictedNode != nil { @@ -232,45 +233,3 @@ func (a *Allocator) assignIndirectCall(f Function, instr Instr, vRegIDToNode []* } instr.AssignUse(0, v) } - -// collectActiveNonRealVRegsAt collects the set of active registers at the given program counter into `a.nodes1` slice by appending -// the found registers from its beginning. This excludes the VRegs backed by a real register since this is used to list the registers -// alive but not used by a call instruction. -func (a *Allocator) collectActiveNonRealVRegsAt(pc programCounter, liveNodes []liveNodeInBlock) { - nodes := a.nodes1[:0] - for i := range liveNodes { - live := &liveNodes[i] - n := live.n - if n.spill() || n.v.IsRealReg() { - continue - } - r := &n.ranges[live.rangeIndex] - if r.begin > pc { - // liveNodes are sorted by the start program counter, so we can break here. - break - } - if pc <= r.end { // pc is in the range. - nodes = append(nodes, n) - } - } - a.nodes1 = nodes -} - -func (a *Allocator) collectActiveNodesAt(pc programCounter, liveNodes []liveNodeInBlock) { - nodes := a.nodes1[:0] - for i := range liveNodes { - live := &liveNodes[i] - n := live.n - if n.assignedRealReg() != RealRegInvalid { - r := &n.ranges[live.rangeIndex] - if r.begin > pc { - // liveNodes are sorted by the start program counter, so we can break here. - break - } - if pc <= r.end { // pc is in the range. - nodes = append(nodes, n) - } - } - } - a.nodes1 = nodes -} diff --git a/internal/engine/wazevo/backend/regalloc/assign_test.go b/internal/engine/wazevo/backend/regalloc/assign_test.go index 543c6d67..93b85ea0 100644 --- a/internal/engine/wazevo/backend/regalloc/assign_test.go +++ b/internal/engine/wazevo/backend/regalloc/assign_test.go @@ -10,17 +10,22 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) { t.Run("call", func(t *testing.T) { a := NewAllocator(&RegisterInfo{CallerSavedRegisters: [RealRegsNumMax]bool{1: true, 3: true}}) pc := programCounter(5) - liveNodes := []liveNodeInBlock{ - {n: &node{r: 1, v: 0xa, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: RealRegInvalid, v: 0xb, ranges: []liveRange{{begin: 5, end: 20}}}}, // Spill. not save target. - {n: &node{r: 2, v: FromRealReg(1, RegTypeInt), ranges: []liveRange{{begin: 5, end: 20}}}}, // Real reg-backed VReg. not save target - {n: &node{r: 3, v: 0xc, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 4, v: 0xd, ranges: []liveRange{{begin: 5, end: 20}}}}, // real reg, but not caller saved. not save target + + tree := newIntervalTree() + liveNodes := []*node{ + {r: 1, v: 0xa}, + {r: RealRegInvalid, v: 0xb}, // Spill. not save target. 
+ {r: 2, v: FromRealReg(1, RegTypeInt)}, // Real reg-backed VReg. not save target + {r: 3, v: 0xc}, + {r: 4, v: 0xd}, // real reg, but not caller saved. not save target + } + for _, n := range liveNodes { + tree.insert(n, 5, 20) } call := newMockInstr().asCall() blk := newMockBlock(0, call).entry() f := newMockFunction(blk) - a.assignRegistersPerInstr(f, pc, call, nil, liveNodes) + a.assignRegistersPerInstr(f, pc, call, nil, tree) require.Equal(t, 2, len(f.befores)) require.Equal(t, 2, len(f.afters)) @@ -30,20 +35,22 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) { pc := programCounter(5) functionPtrVRegID := 0x0 functionPtrVReg := VReg(functionPtrVRegID).SetRegType(RegTypeInt) - functionPtrLiveNode := liveNodeInBlock{ - n: &node{r: 0xf, v: functionPtrVReg, ranges: []liveRange{{begin: 4, end: pc /* killed at this indirect call. */}}}, + functionPtrLiveNode := &node{r: 0xf, v: functionPtrVReg} + tree := newIntervalTree() + tree.insert(functionPtrLiveNode, 4, pc) + liveNodes := []*node{ + {r: 1, v: 0xa}, + {r: 2, v: FromRealReg(1, RegTypeInt)}, // Real reg-backed VReg. not target + {r: 3, v: 0xc}, + {r: 4, v: 0xd}, // real reg, but not caller saved. not save target } - liveNodes := []liveNodeInBlock{ - functionPtrLiveNode, // Function pointer, used at this PC. not save target. - {n: &node{r: 1, v: 0xa, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 2, v: FromRealReg(1, RegTypeInt), ranges: []liveRange{{begin: 5, end: 20}}}}, // Real reg-backed VReg. not target - {n: &node{r: 3, v: 0xc, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 4, v: 0xd, ranges: []liveRange{{begin: 5, end: 20}}}}, // real reg, but not caller saved. not save target + for _, n := range liveNodes { + tree.insert(n, 5, 20) } callInd := newMockInstr().asIndirectCall().use(functionPtrVReg) blk := newMockBlock(0, callInd).entry() f := newMockFunction(blk) - a.assignRegistersPerInstr(f, pc, callInd, []*node{0: functionPtrLiveNode.n}, liveNodes) + a.assignRegistersPerInstr(f, pc, callInd, []*node{0: functionPtrLiveNode}, tree) require.Equal(t, 2, len(f.befores)) require.Equal(t, 2, len(f.afters)) @@ -58,18 +65,23 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) { pc := programCounter(5) functionPtrVRegID := 0x0 functionPtrVReg := VReg(functionPtrVRegID).SetRegType(RegTypeInt) - liveNodes := []liveNodeInBlock{ - {n: &node{r: 1, v: 0xa, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 2, v: FromRealReg(1, RegTypeInt), ranges: []liveRange{{begin: 5, end: 20}}}}, // Real reg-backed VReg. not target - {n: &node{r: 3, v: 0xc, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 4, v: 0xd, ranges: []liveRange{{begin: 5, end: 20}}}}, // real reg, but not caller saved. not save target + liveNodes := []*node{ + {r: 1, v: 0xa}, + {r: 2, v: FromRealReg(1, RegTypeInt)}, // Real reg-backed VReg. not target + {r: 3, v: 0xc}, + {r: 4, v: 0xd}, // real reg, but not caller saved. 
not save target } + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, 5, 20) + } + callInd := newMockInstr().asIndirectCall().use(functionPtrVReg) blk := newMockBlock(0, callInd).entry() f := newMockFunction(blk) a.assignRegistersPerInstr(f, pc, callInd, []*node{ 0: {r: RealRegInvalid}, - }, liveNodes) + }, tree) require.Equal(t, 3, len(f.befores)) require.Equal(t, 2, len(f.afters)) @@ -96,71 +108,6 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) { }) } -func TestAllocator_activeNonRealVRegsAt(t *testing.T) { - r := FromRealReg(1, RegTypeInt) - for _, tc := range []struct { - name string - lives []liveNodeInBlock - pc programCounter - want []VReg - }{ - { - name: "no live nodes", - pc: 0, - lives: []liveNodeInBlock{}, - want: []VReg{}, - }, - { - name: "no live nodes at pc", - pc: 10, - lives: []liveNodeInBlock{{n: &node{ranges: []liveRange{{begin: 100, end: 2000}}}}}, - want: []VReg{}, - }, - { - name: "one live", - pc: 10, - lives: []liveNodeInBlock{ - {n: &node{r: 2, v: 0xf, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 1, v: 0xa, ranges: []liveRange{{begin: 100, end: 2000}}}}, - }, - want: []VReg{0xf}, - }, - { - name: "three lives but one spill", - pc: 10, - lives: []liveNodeInBlock{ - {n: &node{r: 1, v: 0xa, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: RealRegInvalid, v: 0xb, ranges: []liveRange{{begin: 5, end: 20}}}}, // Spill. - {n: &node{r: 3, v: 0xc, ranges: []liveRange{{begin: 5, end: 20}}}}, - }, - want: []VReg{0xa, 0xc}, - }, - { - name: "three lives but one real reg-backed VReg", - pc: 10, - lives: []liveNodeInBlock{ - {n: &node{r: 1, v: 0xa, ranges: []liveRange{{begin: 5, end: 20}}}}, - {n: &node{r: 2, v: r, ranges: []liveRange{{begin: 5, end: 20}}}}, // Real reg-backed VReg. - {n: &node{r: 3, v: 0xc, ranges: []liveRange{{begin: 5, end: 20}}}}, - }, - want: []VReg{0xa, 0xc}, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - a := NewAllocator(&RegisterInfo{}) - a.collectActiveNonRealVRegsAt(tc.pc, tc.lives) - ans := a.nodes1 - - actual := make([]VReg, len(ans)) - for i, n := range ans { - actual[i] = n.v - } - require.Equal(t, tc.want, actual) - }) - } -} - func TestAllocator_handleSpills(t *testing.T) { requireInsertedInst := func(t *testing.T, f *mockFunction, before bool, index int, instr Instr, reload bool, v VReg) { lists := f.afters @@ -175,17 +122,20 @@ func TestAllocator_handleSpills(t *testing.T) { t.Run("no spills", func(t *testing.T) { a := NewAllocator(&RegisterInfo{}) - a.handleSpills(nil, 0, nil, nil, nil, VRegInvalid) + a.handleSpills(nil, 0, nil, nil, VRegInvalid, nil) }) t.Run("only def / evicted / Real reg backed VReg", func(t *testing.T) { const pc = 5 - liveNodes := []liveNodeInBlock{ + liveNodes := []*node{ // Real reg backed VReg. - {n: &node{r: RealRegInvalid, v: VReg(1).SetRealReg(0xa), ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xb), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xc), v: 0xc, ranges: []liveRange{{begin: pc, end: 20}}}}, + {r: RealRegInvalid, v: VReg(1).SetRealReg(0xa)}, + {r: RealReg(0xb), v: 0xa}, + {r: RealReg(0xc), v: 0xc}, + } + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, pc, 20) } - a := NewAllocator(&RegisterInfo{ AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xa, 0xb, 0xc}}, // Only live nodes are allocatable. 
}) @@ -194,23 +144,27 @@ func TestAllocator_handleSpills(t *testing.T) { vr := VReg(100).SetRegType(RegTypeInt) instr := newMockInstr().def(vr) - a.handleSpills(f, pc, instr, liveNodes, nil, vr) + a.handleSpills(f, pc, instr, nil, vr, tree) require.Equal(t, 1, len(instr.defs)) require.Equal(t, RealReg(0xa), instr.defs[0].RealReg()) require.Equal(t, 1, len(f.befores)) - requireInsertedInst(t, f, true, 0, instr, false, liveNodes[0].n.v) + requireInsertedInst(t, f, true, 0, instr, false, liveNodes[0].v) require.Equal(t, 2, len(f.afters)) - requireInsertedInst(t, f, false, 0, instr, true, liveNodes[0].n.v) + requireInsertedInst(t, f, false, 0, instr, true, liveNodes[0].v) requireInsertedInst(t, f, false, 1, instr, false, vr.SetRealReg(0xa)) }) t.Run("only def / evicted", func(t *testing.T) { const pc = 5 - liveNodes := []liveNodeInBlock{ - {n: &node{r: RealReg(0xb), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xc), v: 0xc, ranges: []liveRange{{begin: pc, end: 20}}}}, + liveNodes := []*node{ + {r: RealReg(0xb), v: 0xa}, + {r: RealReg(0xc), v: 0xc}, + } + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, pc, 20) } a := NewAllocator(&RegisterInfo{ @@ -221,23 +175,28 @@ func TestAllocator_handleSpills(t *testing.T) { vr := VReg(100).SetRegType(RegTypeInt) instr := newMockInstr().def(vr) - a.handleSpills(f, pc, instr, liveNodes, nil, vr) + a.handleSpills(f, pc, instr, nil, vr, tree) require.Equal(t, 1, len(instr.defs)) require.Equal(t, RealReg(0xb), instr.defs[0].RealReg()) require.Equal(t, 1, len(f.befores)) - requireInsertedInst(t, f, true, 0, instr, false, liveNodes[0].n.v.SetRealReg(0xb)) + requireInsertedInst(t, f, true, 0, instr, false, liveNodes[0].v.SetRealReg(0xb)) require.Equal(t, 2, len(f.afters)) - requireInsertedInst(t, f, false, 0, instr, true, liveNodes[0].n.v.SetRealReg(0xb)) + requireInsertedInst(t, f, false, 0, instr, true, liveNodes[0].v.SetRealReg(0xb)) requireInsertedInst(t, f, false, 1, instr, false, vr.SetRealReg(0xb)) }) t.Run("only def / not evicted", func(t *testing.T) { const pc = 5 - liveNodes := []liveNodeInBlock{ - {n: &node{r: RealReg(0xb), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xc), v: 0xc, ranges: []liveRange{{begin: pc, end: 20}}}}, + liveNodes := []*node{ + {r: RealReg(0xb), v: 0xa}, + {r: RealReg(0xc), v: 0xc}, + } + + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, pc, 20) } a := NewAllocator(&RegisterInfo{ @@ -248,7 +207,7 @@ func TestAllocator_handleSpills(t *testing.T) { vr := VReg(100).SetRegType(RegTypeInt) instr := newMockInstr().def(vr) - a.handleSpills(f, pc, instr, liveNodes, nil, vr) + a.handleSpills(f, pc, instr, nil, vr, tree) require.Equal(t, 1, len(instr.defs)) require.Equal(t, RealReg(0xf), instr.defs[0].RealReg()) @@ -259,9 +218,14 @@ func TestAllocator_handleSpills(t *testing.T) { t.Run("uses and def / not evicted / def same type", func(t *testing.T) { const pc = 5 - liveNodes := []liveNodeInBlock{ - {n: &node{r: RealReg(0xb), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xc), v: 0xc, ranges: []liveRange{{begin: pc, end: 20}}}}, + liveNodes := []*node{ + {r: RealReg(0xb), v: 0xa}, + {r: RealReg(0xc), v: 0xc}, + } + + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, pc, 20) } a := NewAllocator(&RegisterInfo{ @@ -277,7 +241,7 @@ func TestAllocator_handleSpills(t *testing.T) { VReg(102).SetRegType(RegTypeFloat) d1 := VReg(104).SetRegType(RegTypeFloat) instr := 
newMockInstr().use(u1, u2, u3).def(d1) - a.handleSpills(f, pc, instr, liveNodes, []VReg{u1, u3}, d1) + a.handleSpills(f, pc, instr, []VReg{u1, u3}, d1, tree) require.Equal(t, []VReg{u1.SetRealReg(0xa), u2, u3.SetRealReg(0xf)}, instr.uses) require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs) @@ -291,9 +255,15 @@ func TestAllocator_handleSpills(t *testing.T) { t.Run("uses and def / not evicted / def different type", func(t *testing.T) { const pc = 5 - liveNodes := []liveNodeInBlock{ - {n: &node{r: RealReg(0xb), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xf), v: 0xb, ranges: []liveRange{{begin: pc, end: 20}}}}, + + liveNodes := []*node{ + {r: RealReg(0xb), v: 0xa}, + {r: RealReg(0xf), v: 0xb}, + } + + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, pc, 20) } a := NewAllocator(&RegisterInfo{ @@ -307,24 +277,29 @@ func TestAllocator_handleSpills(t *testing.T) { u1 := VReg(100).SetRegType(RegTypeInt) d1 := VReg(104).SetRegType(RegTypeFloat) instr := newMockInstr().use(u1).def(d1) - a.handleSpills(f, pc, instr, liveNodes, []VReg{u1}, d1) + a.handleSpills(f, pc, instr, []VReg{u1}, d1, tree) require.Equal(t, []VReg{u1.SetRealReg(0xa)}, instr.uses) require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs) require.Equal(t, 2, len(f.befores)) requireInsertedInst(t, f, true, 0, instr, true, u1.SetRealReg(0xa)) - requireInsertedInst(t, f, true, 1, instr, false, liveNodes[1].n.v.SetRealReg(0xf)) + requireInsertedInst(t, f, true, 1, instr, false, liveNodes[1].v.SetRealReg(0xf)) require.Equal(t, 2, len(f.afters)) - requireInsertedInst(t, f, false, 0, instr, true, liveNodes[1].n.v.SetRealReg(0xf)) + requireInsertedInst(t, f, false, 0, instr, true, liveNodes[1].v.SetRealReg(0xf)) requireInsertedInst(t, f, false, 1, instr, false, d1.SetRealReg(0xf)) }) t.Run("uses and def / evicted / def different type", func(t *testing.T) { const pc = 5 - liveNodes := []liveNodeInBlock{ - {n: &node{r: RealReg(0xb), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xc), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xf), v: 0xb, ranges: []liveRange{{begin: pc, end: 20}}}}, + + liveNodes := []*node{ + {r: RealReg(0xb), v: 0xa}, + {r: RealReg(0xc), v: 0xa}, + {r: RealReg(0xf), v: 0xb}, + } + tree := newIntervalTree() + for _, n := range liveNodes { + tree.insert(n, pc, 20) } a := NewAllocator(&RegisterInfo{ @@ -338,45 +313,17 @@ func TestAllocator_handleSpills(t *testing.T) { u1 := VReg(100).SetRegType(RegTypeInt) d1 := VReg(104).SetRegType(RegTypeFloat) instr := newMockInstr().use(u1).def(d1) - a.handleSpills(f, pc, instr, liveNodes, []VReg{u1}, d1) + a.handleSpills(f, pc, instr, []VReg{u1}, d1, tree) require.Equal(t, []VReg{u1.SetRealReg(0xb)}, instr.uses) require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs) require.Equal(t, 3, len(f.befores)) - requireInsertedInst(t, f, true, 0, instr, false, liveNodes[0].n.v.SetRealReg(0xb)) + requireInsertedInst(t, f, true, 0, instr, false, liveNodes[0].v.SetRealReg(0xb)) requireInsertedInst(t, f, true, 1, instr, true, u1.SetRealReg(0xb)) - requireInsertedInst(t, f, true, 2, instr, false, liveNodes[2].n.v.SetRealReg(0xf)) + requireInsertedInst(t, f, true, 2, instr, false, liveNodes[2].v.SetRealReg(0xf)) require.Equal(t, 3, len(f.afters)) - requireInsertedInst(t, f, false, 0, instr, true, liveNodes[0].n.v.SetRealReg(0xb)) - requireInsertedInst(t, f, false, 1, instr, true, liveNodes[2].n.v.SetRealReg(0xf)) + requireInsertedInst(t, f, false, 0, instr, true, 
liveNodes[0].v.SetRealReg(0xb)) + requireInsertedInst(t, f, false, 1, instr, true, liveNodes[2].v.SetRealReg(0xf)) requireInsertedInst(t, f, false, 2, instr, false, d1.SetRealReg(0xf)) }) } - -func TestAllocator_collectActiveNodesAt(t *testing.T) { - t.Run("no live nodes", func(t *testing.T) { - a := NewAllocator(&RegisterInfo{}) - a.nodes1 = []*node{{r: 1}, {r: 2}} // Must be cleared. - a.collectActiveNodesAt(0, nil) - require.Equal(t, 0, len(a.nodes1)) - }) - - t.Run("lives", func(t *testing.T) { - const pc = 5 - liveNodes := []liveNodeInBlock{ - {n: &node{r: RealReg(0xf), v: 0xa, ranges: []liveRange{{begin: 0, end: pc - 1}}}}, - {n: &node{r: RealReg(0x1), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0x2), v: 0xa, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0x4), v: 0xb, ranges: []liveRange{{begin: pc, end: 20}}}}, - {n: &node{r: RealReg(0xf), v: 0xa, ranges: []liveRange{{begin: 1000, end: 2000000}}}}, - } - - a := NewAllocator(&RegisterInfo{}) - a.nodes1 = []*node{{r: 1}, {r: 2}} // Must be cleared. - a.collectActiveNodesAt(pc, liveNodes) - require.Equal(t, 3, len(a.nodes1)) - require.Equal(t, liveNodes[1].n, a.nodes1[0]) - require.Equal(t, liveNodes[2].n, a.nodes1[1]) - require.Equal(t, liveNodes[3].n, a.nodes1[2]) - }) -} diff --git a/internal/engine/wazevo/backend/regalloc/coloring.go b/internal/engine/wazevo/backend/regalloc/coloring.go index 841bcaf1..00f3245d 100644 --- a/internal/engine/wazevo/backend/regalloc/coloring.go +++ b/internal/engine/wazevo/backend/regalloc/coloring.go @@ -8,38 +8,48 @@ import ( ) // buildNeighbors builds the neighbors for each node in the interference graph. -// TODO: node coalescing by leveraging the info given by Instr.IsCopy(). -func (a *Allocator) buildNeighbors(f Function) { - for blk := f.PostOrderBlockIteratorBegin(); blk != nil; blk = f.PostOrderBlockIteratorNext() { - lives := a.blockInfos[blk.ID()].liveNodes - a.buildNeighborsByLiveNodes(lives) +func (a *Allocator) buildNeighbors() { + allocated := a.nodePool.Allocated() + if diff := allocated - len(a.dedup); diff > 0 { + a.dedup = append(a.dedup, make([]bool, diff+1)...) + } + for i := 0; i < allocated; i++ { + n := a.nodePool.View(i) + a.buildNeighborsFor(n) } } -func (a *Allocator) buildNeighborsByLiveNodes(lives []liveNodeInBlock) { - if len(lives) == 0 { - // TODO: shouldn't this kind of block be removed before reg alloc? - return - } - for i, src := range lives[:len(lives)-1] { - srcRange := &src.n.ranges[src.rangeIndex] - for _, dst := range lives[i+1:] { - srcN, dstN := src.n, dst.n - if dst == src || dstN == srcN { - panic(fmt.Sprintf("BUG: %s and %s are the same node", src.n.v, dst.n.v)) +func (a *Allocator) buildNeighborsFor(n *node) { + for _, r := range n.ranges { + // Collects all the nodes that are in the same range. + for _, neighbor := range r.nodes { + neighborID := neighbor.id + if neighbor.v.RegType() != n.v.RegType() { + continue } - dstRange := &dst.n.ranges[dst.rangeIndex] - if dstRange.begin > srcRange.end { - // liveNodes are sorted by the start program counter, so we can break here. - break - } - - if srcN.v.RegType() == dstN.v.RegType() && // Interfere only if they are the same type. - srcRange.intersects(dstRange) { - srcN.neighbors = append(srcN.neighbors, dst.n) - dstN.neighbors = append(dstN.neighbors, src.n) + if neighbor != n && !a.dedup[neighborID] { + n.neighbors = append(n.neighbors, neighbor) + a.dedup[neighborID] = true } } + + // And also collects all the nodes that are in the neighbor ranges. 
+ for _, neighborInterval := range r.neighbors { + for _, neighbor := range neighborInterval.nodes { + if neighbor.v.RegType() != n.v.RegType() { + continue + } + neighborID := neighbor.id + if neighbor != n && !a.dedup[neighborID] { + n.neighbors = append(n.neighbors, neighbor) + a.dedup[neighborID] = true + } + } + } + } + // Reset for the next iteration. + for _, neighbor := range n.neighbors { + a.dedup[neighbor.id] = false } } diff --git a/internal/engine/wazevo/backend/regalloc/coloring_test.go b/internal/engine/wazevo/backend/regalloc/coloring_test.go index 713e3f8a..9b2a837b 100644 --- a/internal/engine/wazevo/backend/regalloc/coloring_test.go +++ b/internal/engine/wazevo/backend/regalloc/coloring_test.go @@ -1,97 +1,13 @@ package regalloc import ( + "sort" + "strconv" "testing" "github.com/tetratelabs/wazero/internal/testing/require" ) -func TestAllocator_buildNeighborsByLiveNodes(t *testing.T) { - for _, tc := range []struct { - name string - lives []liveNodeInBlock - expectedEdges [][2]int - }{ - {name: "empty", lives: []liveNodeInBlock{}}, - { - name: "one node", - lives: []liveNodeInBlock{ - {rangeIndex: 0, n: &node{ranges: []liveRange{{begin: 0, end: 1}}}}, - }, - }, - { - name: "no overlap", - lives: []liveNodeInBlock{ - {rangeIndex: 4, n: &node{ranges: []liveRange{ - {}, {}, {}, {}, {begin: 0, end: 1}, - }}}, - {rangeIndex: 1, n: &node{v: VReg(0).SetRegType(RegTypeInt), ranges: []liveRange{ - {}, {begin: 2, end: 3}, - }}}, - // This overlaps with the above, but is not the same type. - {rangeIndex: 0, n: &node{v: VReg(1).SetRegType(RegTypeFloat), ranges: []liveRange{ - {begin: 2, end: 3}, - }}}, - }, - }, - { - name: "overlap", - lives: []liveNodeInBlock{ - {rangeIndex: 0, n: &node{v: VReg(0).SetRegType(RegTypeInt), ranges: []liveRange{ - {begin: 2, end: 50}, - }}}, - {rangeIndex: 0, n: &node{v: VReg(1).SetRegType(RegTypeInt), ranges: []liveRange{ - {begin: 2, end: 3}, - }}}, - // This overlaps with the above, but is not the same type. - {rangeIndex: 0, n: &node{v: VReg(2).SetRegType(RegTypeFloat), ranges: []liveRange{ - {begin: 2, end: 100}, - }}}, - {rangeIndex: 0, n: &node{v: VReg(3).SetRegType(RegTypeFloat), ranges: []liveRange{ - {begin: 100, end: 100}, - }}}, - }, - expectedEdges: [][2]int{ - {0, 1}, {2, 3}, - }, - }, - } { - t.Run(tc.name, func(t *testing.T) { - a := NewAllocator(&RegisterInfo{}) - - a.buildNeighborsByLiveNodes(tc.lives) - - expectedNeighborCounts := map[*node]int{} - for _, edge := range tc.expectedEdges { - i1, i2 := edge[0], edge[1] - n1, n2 := tc.lives[i1].n, tc.lives[i2].n - - var found bool - for _, neighbor := range n2.neighbors { - if neighbor == n1 { - found = true - break - } - } - require.True(t, found) - found = false - for _, neighbor := range n1.neighbors { - if neighbor == n2 { - found = true - break - } - } - require.True(t, found) - expectedNeighborCounts[n1]++ - expectedNeighborCounts[n2]++ - } - for _, n := range tc.lives { - require.Equal(t, expectedNeighborCounts[n.n], len(n.n.neighbors)) - } - }) - } -} - func TestAllocator_collectNodesByRegType(t *testing.T) { a := NewAllocator(&RegisterInfo{}) n1 := a.allocateNode() @@ -319,3 +235,74 @@ func TestAllocator_assignColor(t *testing.T) { require.True(t, ok) }) } + +func TestAllocator_buildNeighbors(t *testing.T) { + a := NewAllocator(&RegisterInfo{}) + a.dedup = make([]bool, 1000) // Enough large. 
+ + newNode := func(id int) *node { + return &node{id: id} + } + + newNodes := func(ids ...int) []*node { + var ns []*node + for _, id := range ids { + ns = append(ns, newNode(id)) + } + return ns + } + + for i, tc := range []struct { + n *node + exp []int + }{ + {n: newNode(0)}, + { + n: &node{ + ranges: []*intervalTreeNode{ + {nodes: newNodes(1, 2, 3)}, + {nodes: newNodes(4, 5, 1, 2, 3)}, + }, + }, + exp: []int{1, 2, 3, 4, 5}, + }, + { + n: &node{ + ranges: []*intervalTreeNode{ + {nodes: newNodes(1, 2, 3)}, + {nodes: newNodes(1, 2, 3)}, + {nodes: newNodes(1, 2, 3)}, + {nodes: newNodes(1, 2, 3)}, + {nodes: newNodes(4, 5, 1, 2, 3)}, + }, + }, + exp: []int{1, 2, 3, 4, 5}, + }, + { + n: &node{ + ranges: []*intervalTreeNode{ + {nodes: newNodes(1, 2, 3)}, + {nodes: newNodes(4), neighbors: []*intervalTreeNode{ + {nodes: newNodes(5, 6)}, + {nodes: newNodes(100, 200)}, + }}, + }, + }, + exp: []int{1, 2, 3, 4, 5, 6, 100, 200}, + }, + } { + tc := tc + t.Run(strconv.Itoa(i), func(t *testing.T) { + a.buildNeighborsFor(tc.n) + var collected []int + for _, nei := range tc.n.neighbors { + collected = append(collected, nei.id) + } + sort.Ints(collected) + require.Equal(t, tc.exp, collected) + for i := range a.dedup { + require.False(t, a.dedup[i]) // must be cleaned up. + } + }) + } +} diff --git a/internal/engine/wazevo/backend/regalloc/interval_tree.go b/internal/engine/wazevo/backend/regalloc/interval_tree.go new file mode 100644 index 00000000..7d74378c --- /dev/null +++ b/internal/engine/wazevo/backend/regalloc/interval_tree.go @@ -0,0 +1,149 @@ +package regalloc + +import "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi" + +type intervalTree struct { + root *intervalTreeNode + allocator intervalTreeNodeAllocator + intervals map[uint64]*intervalTreeNode +} + +func intervalTreeNodeKey(begin, end programCounter) uint64 { + return uint64(begin) | uint64(end)<<32 +} + +func (t *intervalTree) insert(n *node, begin, end programCounter) *intervalTreeNode { + key := uint64(begin) | uint64(end)<<32 + if i, ok := t.intervals[key]; ok { + i.nodes = append(i.nodes, n) + return i + } + t.root = t.root.insert(t, n, begin, end) + ret := t.intervals[key] + t.buildNeighbors(ret) // TODO: this can be done while inserting. + return ret +} + +func (t *intervalTree) reset() { + t.root = nil + t.allocator.Reset() + t.intervals = make(map[uint64]*intervalTreeNode) +} + +func newIntervalTree() *intervalTree { + return &intervalTree{ + allocator: wazevoapi.NewPool[intervalTreeNode](), + intervals: make(map[uint64]*intervalTreeNode), + } +} + +type intervalTreeNodeAllocator = wazevoapi.Pool[intervalTreeNode] + +type intervalTreeNode struct { + begin, end programCounter + nodes []*node + maxEnd programCounter + neighbors []*intervalTreeNode + left, right *intervalTreeNode + // TODO: color for red-black balancing. +} + +func (i *intervalTreeNode) insert(t *intervalTree, n *node, begin, end programCounter) *intervalTreeNode { + if i == nil { + intervalNode := t.allocator.Allocate() + intervalNode.right = nil + intervalNode.left = nil + intervalNode.nodes = append(intervalNode.nodes, n) + intervalNode.maxEnd = end + intervalNode.begin = begin + intervalNode.end = end + key := uint64(begin) | uint64(end)<<32 + t.intervals[key] = intervalNode + return intervalNode + } + if begin < i.begin { + i.left = i.left.insert(t, n, begin, end) + } else { + i.right = i.right.insert(t, n, begin, end) + } + if i.maxEnd < end { + i.maxEnd = end + } + + // TODO: balancing logic so that collection functions are faster. 
+ + return i +} + +func (t *intervalTree) buildNeighbors(from *intervalTreeNode) { + t.root.buildNeighbors(from) +} + +func (i *intervalTreeNode) buildNeighbors(from *intervalTreeNode) { + if i == nil { + return + } + if i.intersects(from) { + from.neighbors = append(from.neighbors, i) + i.neighbors = append(i.neighbors, from) + } + if i.left != nil && i.left.maxEnd >= from.begin { + i.left.buildNeighbors(from) + } + if i.begin <= from.end { + i.right.buildNeighbors(from) + } +} + +func (t *intervalTree) collectActiveNonRealVRegsAt(pc programCounter, overlaps *[]*node) { + *overlaps = (*overlaps)[:0] + t.root.collectActiveNonRealVRegsAt(pc, overlaps) +} + +func (i *intervalTreeNode) collectActiveNonRealVRegsAt(pc programCounter, overlaps *[]*node) { + if i == nil { + return + } + if i.begin <= pc && i.end >= pc { + for _, n := range i.nodes { + if n.spill() || n.v.IsRealReg() { + continue + } + *overlaps = append(*overlaps, n) + } + } + if i.left != nil && i.left.maxEnd >= pc { + i.left.collectActiveNonRealVRegsAt(pc, overlaps) + } + if i.begin <= pc { + i.right.collectActiveNonRealVRegsAt(pc, overlaps) + } +} + +func (t *intervalTree) collectActiveRealRegNodesAt(pc programCounter, overlaps *[]*node) { + *overlaps = (*overlaps)[:0] + t.root.collectActiveRealRegNodesAt(pc, overlaps) +} + +func (i *intervalTreeNode) collectActiveRealRegNodesAt(pc programCounter, overlaps *[]*node) { + if i == nil { + return + } + if i.begin <= pc && i.end >= pc { + for _, n := range i.nodes { + if n.assignedRealReg() != RealRegInvalid { + *overlaps = append(*overlaps, n) + } + } + } + if i.left != nil && i.left.maxEnd >= pc { + i.left.collectActiveRealRegNodesAt(pc, overlaps) + } + if i.begin <= pc { + i.right.collectActiveRealRegNodesAt(pc, overlaps) + } +} + +func (i *intervalTreeNode) intersects(j *intervalTreeNode) bool { + return i.begin <= j.end && i.end >= j.begin || j.begin <= i.end && j.end >= i.begin +} diff --git a/internal/engine/wazevo/backend/regalloc/interval_tree_test.go b/internal/engine/wazevo/backend/regalloc/interval_tree_test.go new file mode 100644 index 00000000..a2b195c4 --- /dev/null +++ b/internal/engine/wazevo/backend/regalloc/interval_tree_test.go @@ -0,0 +1,385 @@ +package regalloc + +import ( + "fmt" + "sort" + "testing" + + "github.com/tetratelabs/wazero/internal/testing/require" +) + +func TestIntervalTree_reset(t *testing.T) { + tree := newIntervalTree() + n := tree.allocator.Allocate() + tree.root = n + tree.reset() + + require.Nil(t, tree.root) + require.Equal(t, 0, tree.allocator.Allocated()) +} + +func TestIntervalTreeInsert(t *testing.T) { + n1 := &node{} + tree := newIntervalTree() + tree.insert(n1, 100, 200) + require.NotNil(t, tree.root) + require.NotNil(t, tree.root.nodes) + require.Equal(t, n1, tree.root.nodes[0]) + n, ok := tree.intervals[intervalTreeNodeKey(100, 200)] + require.True(t, ok) + require.Equal(t, n1, n.nodes[0]) +} + +func TestIntervalTreeNodeInsert(t *testing.T) { + t.Run("nil", func(t *testing.T) { + tree := newIntervalTree() + var n *intervalTreeNode + allocated := n.insert(tree, &node{}, 0, 100) + require.Equal(t, 1, tree.allocator.Allocated()) + require.NotNil(t, allocated) + require.Equal(t, allocated, tree.allocator.View(0)) + require.Equal(t, programCounter(100), allocated.maxEnd) + require.Equal(t, programCounter(0), allocated.begin) + require.Equal(t, programCounter(100), allocated.end) + require.Equal(t, 1, len(allocated.nodes)) + n, ok := tree.intervals[intervalTreeNodeKey(0, 100)] + require.True(t, ok) + require.Equal(t, allocated, n) + }) 
+ t.Run("left", func(t *testing.T) { + tree := newIntervalTree() + n := &intervalTreeNode{begin: 50, end: 100, maxEnd: 100} + n1 := &node{} + self := n.insert(tree, n1, 0, 200) + require.Equal(t, self, n) + require.Equal(t, 1, tree.allocator.Allocated()) + left := tree.allocator.View(0) + require.Equal(t, n.left, left) + require.Nil(t, n.right) + require.Equal(t, programCounter(200), n.maxEnd) + require.Equal(t, left.nodes[0], n1) + }) + t.Run("right", func(t *testing.T) { + tree := newIntervalTree() + n := &intervalTreeNode{begin: 50, end: 100, maxEnd: 100} + n1 := &node{} + self := n.insert(tree, n1, 150, 200) + require.Equal(t, self, n) + require.Equal(t, 1, tree.allocator.Allocated()) + right := tree.allocator.View(0) + require.Equal(t, n.right, right) + require.Nil(t, n.left) + require.Equal(t, programCounter(200), n.maxEnd) + require.Equal(t, right.nodes[0], n1) + }) +} + +type ( + interval struct { + begin, end programCounter + id int + } + queryCase struct { + query programCounter + exp []int + } +) + +func newQueryCase(s programCounter, exp ...int) queryCase { + return queryCase{query: s, exp: exp} +} + +func TestIntervalTree_collectActiveNodesAt(t *testing.T) { + for _, tc := range []struct { + name string + intervals []interval + queryCases []queryCase + }{ + { + name: "single", + intervals: []interval{{begin: 0, end: 100, id: 0}}, + queryCases: []queryCase{ + newQueryCase(0, 0), + newQueryCase(0, 0), + newQueryCase(1, 0), + newQueryCase(1, 0), + newQueryCase(101), + }, + }, + { + name: "single/2", + intervals: []interval{{begin: 50, end: 100, id: 0}}, + queryCases: []queryCase{ + newQueryCase(50, 0), + newQueryCase(50, 0), + newQueryCase(51, 0), + newQueryCase(51, 0), + newQueryCase(101), + newQueryCase(48), + }, + }, + { + name: "same id for different intervals", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}}, + queryCases: []queryCase{ + newQueryCase(0), + newQueryCase(50, 0xa), + newQueryCase(101), + newQueryCase(150, 0xa), + newQueryCase(101), + }, + }, + { + name: "two disjoint intervals", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}}, + queryCases: []queryCase{ + newQueryCase(50, 0xa), + newQueryCase(0), + newQueryCase(51, 0xa), + newQueryCase(101), + newQueryCase(150, 0xb), + newQueryCase(200, 0xb), + newQueryCase(101), + newQueryCase(201), + }, + }, + { + name: "two intersecting intervals", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}}, + queryCases: []queryCase{ + newQueryCase(0), + newQueryCase(70, 0xa, 0xb), + newQueryCase(1), + newQueryCase(50, 0xa), + newQueryCase(51, 0xa, 0xb), + newQueryCase(100, 0xa, 0xb), + newQueryCase(101, 0xb), + newQueryCase(49), + newQueryCase(1001), + }, + }, + { + name: "two enclosing interval", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}}, + queryCases: []queryCase{ + newQueryCase(24), + newQueryCase(25, 0xb), + newQueryCase(39, 0xb), + newQueryCase(40, 0xb, 0xc), + newQueryCase(100, 0xa, 0xb, 0xc), + newQueryCase(99, 0xa, 0xb, 0xc), + newQueryCase(100, 0xa, 0xb, 0xc), + newQueryCase(50, 0xa, 0xb, 0xc), + newQueryCase(51, 0xa, 0xb, 0xc), + newQueryCase(101, 0xb, 0xc), + newQueryCase(49, 0xb, 0xc), + newQueryCase(200, 0xb, 0xc), + newQueryCase(201, 0xc), + newQueryCase(1000, 0xc), + newQueryCase(1001), + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tree := newIntervalTree() + var maxID int + for _, inter := range tc.intervals { + 
n := &node{id: inter.id, r: RealReg(1)} + tree.insert(n, inter.begin, inter.end) + if maxID < inter.id { + maxID = inter.id + } + key := intervalTreeNodeKey(inter.begin, inter.end) + inserted := tree.intervals[key] + inserted.nodes = append(inserted.nodes, &node{v: VRegInvalid.SetRealReg(RealRegInvalid)}) // non-real reg should be ignored. + } + for _, qc := range tc.queryCases { + t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) { + var collected []*node + tree.collectActiveRealRegNodesAt(qc.query, &collected) + require.Equal(t, len(qc.exp), len(collected)) + var foundIDs []int + for _, n := range collected { + foundIDs = append(foundIDs, n.id) + } + sort.Slice(foundIDs, func(i, j int) bool { + return foundIDs[i] < foundIDs[j] + }) + require.Equal(t, qc.exp, foundIDs) + }) + } + }) + } +} + +func TestIntervalTree_collectActiveNonRealVRegsAt(t *testing.T) { + for _, tc := range []struct { + name string + intervals []interval + queryCases []queryCase + }{ + { + name: "single", + intervals: []interval{{begin: 0, end: 100, id: 0}}, + queryCases: []queryCase{ + newQueryCase(0, 0), + newQueryCase(0, 0), + newQueryCase(1, 0), + newQueryCase(1, 0), + newQueryCase(101), + }, + }, + { + name: "single/2", + intervals: []interval{{begin: 50, end: 100, id: 0}}, + queryCases: []queryCase{ + newQueryCase(50, 0), + newQueryCase(50, 0), + newQueryCase(51, 0), + newQueryCase(51, 0), + newQueryCase(101), + newQueryCase(48), + }, + }, + { + name: "same id for different intervals", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}}, + queryCases: []queryCase{ + newQueryCase(0), + newQueryCase(50, 0xa), + newQueryCase(101), + newQueryCase(150, 0xa), + newQueryCase(101), + }, + }, + { + name: "two disjoint intervals", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}}, + queryCases: []queryCase{ + newQueryCase(50, 0xa), + newQueryCase(0), + newQueryCase(51, 0xa), + newQueryCase(101), + newQueryCase(150, 0xb), + newQueryCase(200, 0xb), + newQueryCase(101), + newQueryCase(201), + }, + }, + { + name: "two intersecting intervals", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}}, + queryCases: []queryCase{ + newQueryCase(0), + newQueryCase(70, 0xa, 0xb), + newQueryCase(1), + newQueryCase(50, 0xa), + newQueryCase(51, 0xa, 0xb), + newQueryCase(100, 0xa, 0xb), + newQueryCase(101, 0xb), + newQueryCase(49), + newQueryCase(1001), + }, + }, + { + name: "two enclosing interval", + intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}}, + queryCases: []queryCase{ + newQueryCase(24), + newQueryCase(25, 0xb), + newQueryCase(39, 0xb), + newQueryCase(40, 0xb, 0xc), + newQueryCase(100, 0xa, 0xb, 0xc), + newQueryCase(99, 0xa, 0xb, 0xc), + newQueryCase(100, 0xa, 0xb, 0xc), + newQueryCase(50, 0xa, 0xb, 0xc), + newQueryCase(51, 0xa, 0xb, 0xc), + newQueryCase(101, 0xb, 0xc), + newQueryCase(49, 0xb, 0xc), + newQueryCase(200, 0xb, 0xc), + newQueryCase(201, 0xc), + newQueryCase(1000, 0xc), + newQueryCase(1001), + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tree := newIntervalTree() + var maxID int + for _, inter := range tc.intervals { + n := &node{id: inter.id, r: RealReg(1)} + tree.insert(n, inter.begin, inter.end) + if maxID < inter.id { + maxID = inter.id + } + key := intervalTreeNodeKey(inter.begin, inter.end) + inserted := tree.intervals[key] + // They are ignored. 
+ inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeInt)}) + inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeFloat)}) + inserted.nodes = append(inserted.nodes, &node{v: VReg(1)}) + } + for _, qc := range tc.queryCases { + t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) { + var collected []*node + tree.collectActiveNonRealVRegsAt(qc.query, &collected) + require.Equal(t, len(qc.exp), len(collected)) + var foundIDs []int + for _, n := range collected { + foundIDs = append(foundIDs, n.id) + } + sort.Slice(foundIDs, func(i, j int) bool { + return foundIDs[i] < foundIDs[j] + }) + require.Equal(t, qc.exp, foundIDs) + }) + } + }) + } +} + +func TestIntervalTreeNode_intersects(t *testing.T) { + for _, tc := range []struct { + rhs, lhs intervalTreeNode + exp bool + }{ + { + rhs: intervalTreeNode{begin: 0, end: 100}, + lhs: intervalTreeNode{begin: 0, end: 100}, + exp: true, + }, + { + rhs: intervalTreeNode{begin: 0, end: 100}, + lhs: intervalTreeNode{begin: 0, end: 99}, + exp: true, + }, + { + rhs: intervalTreeNode{begin: 0, end: 100}, + lhs: intervalTreeNode{begin: 1, end: 100}, + exp: true, + }, + { + rhs: intervalTreeNode{begin: 50, end: 100}, + lhs: intervalTreeNode{begin: 1, end: 49}, + exp: false, + }, + { + rhs: intervalTreeNode{begin: 50, end: 100}, + lhs: intervalTreeNode{begin: 1, end: 50}, + exp: true, + }, + { + rhs: intervalTreeNode{begin: 50, end: 100}, + lhs: intervalTreeNode{begin: 1, end: 51}, + exp: true, + }, + { + rhs: intervalTreeNode{begin: 50, end: 100}, + lhs: intervalTreeNode{begin: 99, end: 102}, + exp: true, + }, + } { + actual := tc.rhs.intersects(&tc.lhs) + require.Equal(t, tc.exp, actual) + } +} diff --git a/internal/engine/wazevo/backend/regalloc/regalloc.go b/internal/engine/wazevo/backend/regalloc/regalloc.go index 3cff49f3..9c9efbbf 100644 --- a/internal/engine/wazevo/backend/regalloc/regalloc.go +++ b/internal/engine/wazevo/backend/regalloc/regalloc.go @@ -69,6 +69,7 @@ type ( nodes1 []*node nodes2 []*node nodes3 []*node + dedup []bool } // blockInfo is a per-block information used during the register allocation. @@ -79,23 +80,16 @@ type ( lastUses map[VReg]programCounter kills map[VReg]programCounter // Pre-colored real registers can have multiple live ranges in one block. - realRegUses map[VReg][]programCounter - realRegDefs map[VReg][]programCounter - liveNodes []liveNodeInBlock + realRegUses map[VReg][]programCounter + realRegDefs map[VReg][]programCounter + intervalTree *intervalTree } - liveNodeInBlock struct { - // rangeIndex is the index to n.ranges which represents the live range of n.v in the block. - rangeIndex int - n *node - } - - // node represents a node interference graph of LiveRange(s) of VReg(s). + // node represents a VReg. node struct { - v VReg - // ranges holds the live ranges of this node per block. This will be accessed by - // liveNodeInBlock.rangeIndex, which in turn is stored in blockInfo.liveNodes. - ranges []liveRange + id int + v VReg + ranges []*intervalTreeNode // r is the real register assigned to this node. It is either a pre-colored register or a register assigned during allocation. r RealReg // neighbors are the nodes that this node interferes with. Such neighbors have the same RegType as this node. @@ -109,21 +103,15 @@ type ( visited bool } - // liveRange represents a lifetime of a VReg. Both begin (LiveInterval[0]) and end (LiveInterval[1]) are inclusive. 
- liveRange struct { - blockID int - begin, end programCounter - } - // programCounter represents an opaque index into the program which is used to represents a LiveInterval of a VReg. - programCounter int64 + programCounter int32 ) // DoAllocation performs register allocation on the given Function. func (a *Allocator) DoAllocation(f Function) { a.livenessAnalysis(f) a.buildLiveRanges(f) - a.buildNeighbors(f) + a.buildNeighbors() a.coloring() a.determineCalleeSavedRealRegs(f) a.assignRegisters(f) @@ -310,21 +298,12 @@ func (a *Allocator) buildLiveRanges(f Function) { for blk := f.PostOrderBlockIteratorBegin(); blk != nil; blk = f.PostOrderBlockIteratorNext() { // Order doesn't matter. blkID := blk.ID() info := a.blockInfoAt(blkID) - a.buildLiveRangesForNonReals(blkID, info) - a.buildLiveRangesForReals(blkID, info) - // Sort the live range for a fast lookup to find live registers at a given program counter. - sort.Slice(info.liveNodes, func(i, j int) bool { - inode, jnode := &info.liveNodes[i], &info.liveNodes[j] - irange, jrange := inode.n.ranges[inode.rangeIndex], jnode.n.ranges[jnode.rangeIndex] - if irange.begin == jrange.begin { - return irange.end < jrange.end - } - return irange.begin < jrange.begin - }) + a.buildLiveRangesForNonReals(info) + a.buildLiveRangesForReals(info) } } -func (a *Allocator) buildLiveRangesForNonReals(blkID int, info *blockInfo) { +func (a *Allocator) buildLiveRangesForNonReals(info *blockInfo) { ins, outs, defs, kills := info.liveIns, info.liveOuts, info.defs, info.kills // In order to do the deterministic allocation, we need to sort ins. @@ -342,7 +321,7 @@ func (a *Allocator) buildLiveRangesForNonReals(blkID int, info *blockInfo) { var begin, end programCounter if _, ok := outs[v]; ok { // v is live-in and live-out, so it is live-through. - begin, end = 0, math.MaxInt64 + begin, end = 0, math.MaxInt32 if _, ok := kills[v]; ok { panic("BUG: v is live-out but also killed") } @@ -355,9 +334,8 @@ func (a *Allocator) buildLiveRangesForNonReals(blkID int, info *blockInfo) { begin, end = 0, killPos } n := a.getOrAllocateNode(v) - rangeIndex := len(n.ranges) - n.ranges = append(n.ranges, liveRange{blockID: blkID, begin: begin, end: end}) - info.liveNodes = append(info.liveNodes, liveNodeInBlock{rangeIndex, n}) + intervalNode := info.intervalTree.insert(n, begin, end) + n.ranges = append(n.ranges, intervalNode) } // In order to do the deterministic allocation, we need to sort defs. @@ -382,7 +360,7 @@ func (a *Allocator) buildLiveRangesForNonReals(blkID int, info *blockInfo) { var end programCounter if _, ok := outs[v]; ok { // v is defined here and live-out, so it is live-through. - end = math.MaxInt64 + end = math.MaxInt32 if _, ok := kills[v]; ok { panic("BUG: v is killed here but also killed") } @@ -397,9 +375,8 @@ func (a *Allocator) buildLiveRangesForNonReals(blkID int, info *blockInfo) { } } n := a.getOrAllocateNode(v) - rangeIndex := len(n.ranges) - n.ranges = append(n.ranges, liveRange{blockID: blkID, begin: defPos, end: end}) - info.liveNodes = append(info.liveNodes, liveNodeInBlock{rangeIndex, n}) + intervalNode := info.intervalTree.insert(n, defPos, end) + n.ranges = append(n.ranges, intervalNode) } // Reuse for the next block. @@ -419,7 +396,7 @@ func (a *Allocator) buildLiveRangesForNonReals(blkID int, info *blockInfo) { } // buildLiveRangesForReals builds live ranges for pre-colored real registers. 
-func (a *Allocator) buildLiveRangesForReals(blkID int, info *blockInfo) { +func (a *Allocator) buildLiveRangesForReals(info *blockInfo) { ds, us := info.realRegDefs, info.realRegUses // In order to do the deterministic compilation, we need to sort the registers. @@ -454,8 +431,8 @@ func (a *Allocator) buildLiveRangesForReals(blkID int, info *blockInfo) { n.r = v.RealReg() n.v = v defined, used := defs[i], uses[i] - n.ranges = append(n.ranges, liveRange{blockID: blkID, begin: defined, end: used}) - info.liveNodes = append(info.liveNodes, liveNodeInBlock{0, n}) + intervalNode := info.intervalTree.insert(n, defined, used) + n.ranges = append(n.ranges, intervalNode) } } } @@ -512,6 +489,7 @@ func (a *Allocator) getOrAllocateNode(v VReg) (n *node) { func (a *Allocator) allocateNode() (n *node) { n = a.nodePool.Allocate() + n.id = a.nodePool.Allocated() - 1 n.ranges = n.ranges[:0] n.copyFromVReg = nil n.copyToVReg = nil @@ -533,7 +511,11 @@ func resetMap[T any](a *Allocator, m map[VReg]T) { } func (a *Allocator) initBlockInfo(i *blockInfo) { - i.liveNodes = i.liveNodes[:0] + if i.intervalTree == nil { + i.intervalTree = newIntervalTree() + } else { + i.intervalTree.reset() + } if i.liveOuts == nil { i.liveOuts = make(map[VReg]struct{}) } else { @@ -622,11 +604,6 @@ func (n *node) String() string { if n.r != RealRegInvalid { buf.WriteString(fmt.Sprintf(":%v", n.r)) } - buf.WriteString(" ranges[") - for _, r := range n.ranges { - buf.WriteString(fmt.Sprintf("[%v-%v]@blk%d ", r.begin, r.end, r.blockID)) - } - buf.WriteString("]") // Add neighbors buf.WriteString(" neighbors[") for _, n := range n.neighbors { @@ -640,12 +617,6 @@ func (n *node) spill() bool { return n.r == RealRegInvalid } -// intersects returns true if the two live ranges intersect. -// Note that this doesn't compare the block ID because this is called to compare two intervals in the same block. -func (l *liveRange) intersects(other *liveRange) bool { - return other.begin <= l.end && l.begin <= other.end -} - func (r *RegisterInfo) isCalleeSaved(reg RealReg) bool { return r.CalleeSavedRegisters[reg] } @@ -654,12 +625,6 @@ func (r *RegisterInfo) isCallerSaved(reg RealReg) bool { return r.CallerSavedRegisters[reg] } -// String implements fmt.Stringer for debugging. 
-func (l *liveNodeInBlock) String() string { - r := l.n.ranges[l.rangeIndex] - return fmt.Sprintf("v%d@[%v-%v]", l.n.v.ID(), r.begin, r.end) -} - func (a *Allocator) recordCopyRelation(dst, src VReg) { sr, dr := src.IsRealReg(), dst.IsRealReg() switch { diff --git a/internal/engine/wazevo/backend/regalloc/regalloc_test.go b/internal/engine/wazevo/backend/regalloc/regalloc_test.go index aa8ff93e..af9736c0 100644 --- a/internal/engine/wazevo/backend/regalloc/regalloc_test.go +++ b/internal/engine/wazevo/backend/regalloc/regalloc_test.go @@ -1,8 +1,6 @@ package regalloc import ( - "math" - "sort" "testing" "github.com/tetratelabs/wazero/internal/testing/require" @@ -473,173 +471,6 @@ func TestAllocator_livenessAnalysis_copy(t *testing.T) { require.Nil(t, n2.copyToVReg) } -func TestAllocator_buildLiveRangesForNonReals(t *testing.T) { - const blockID = 100 - const v1, v2, v3, v4, v5 = 1, 2, 3, 4, 5 - for _, tc := range []struct { - name string - info *blockInfo - exps map[VReg][]liveRange - }{ - { - name: "no defs without outs", - info: &blockInfo{ - liveIns: map[VReg]struct{}{ - v1: {}, v2: {}, v3: {}, - }, - kills: map[VReg]programCounter{v1: 1111, v2: 2222, v3: 3333}, - }, - exps: map[VReg][]liveRange{ - v1: {{blockID: blockID, begin: 0, end: 1111}}, - v2: {{blockID: blockID, begin: 0, end: 2222}}, - v3: {{blockID: blockID, begin: 0, end: 3333}}, - }, - }, - { - name: "no defs with outs", - info: &blockInfo{ - liveIns: map[VReg]struct{}{ - v1: {}, v2: {}, v3: {}, - }, - liveOuts: map[VReg]struct{}{v1: {}}, - kills: map[VReg]programCounter{v2: 2222, v3: 3333}, - }, - exps: map[VReg][]liveRange{ - v1: {{blockID: blockID, begin: 0, end: math.MaxInt64}}, - v2: {{blockID: blockID, begin: 0, end: 2222}}, - v3: {{blockID: blockID, begin: 0, end: 3333}}, - }, - }, - { - name: "only defs with outs", - info: &blockInfo{ - defs: map[VReg]programCounter{v1: 1, v2: 2, v3: 3}, - liveOuts: map[VReg]struct{}{v1: {}}, - kills: map[VReg]programCounter{v2: 2222, v3: 3333}, - }, - exps: map[VReg][]liveRange{ - v1: {{blockID: blockID, begin: 1, end: math.MaxInt64}}, - v2: {{blockID: blockID, begin: 2, end: 2222}}, - v3: {{blockID: blockID, begin: 3, end: 3333}}, - }, - }, - { - name: "only defs without outs", - info: &blockInfo{ - defs: map[VReg]programCounter{v1: 1, v2: 2, v3: 3}, - // Defined but not killed is allowed: v1 is unused variable. - kills: map[VReg]programCounter{v2: 2222, v3: 3333}, - }, - exps: map[VReg][]liveRange{ - v1: {{blockID: blockID, begin: 1, end: 1}}, // Defined and not used: [defined, defined]. 
- v2: {{blockID: blockID, begin: 2, end: 2222}}, - v3: {{blockID: blockID, begin: 3, end: 3333}}, - }, - }, - { - name: "mix and match", - info: &blockInfo{ - liveIns: map[VReg]struct{}{v1: {}, v2: {}}, - liveOuts: map[VReg]struct{}{v1: {}, v5: {}}, - defs: map[VReg]programCounter{v3: 3, v4: 4, v5: 5}, - kills: map[VReg]programCounter{v2: 2222, v4: 4444}, - }, - exps: map[VReg][]liveRange{ - v1: {{blockID: blockID, begin: 0, end: math.MaxInt64}}, - v2: {{blockID: blockID, begin: 0, end: 2222}}, - v3: {{blockID: blockID, begin: 3, end: 3}}, - v4: {{blockID: blockID, begin: 4, end: 4444}}, - v5: {{blockID: blockID, begin: 5, end: math.MaxInt64}}, - }, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - a := NewAllocator(&RegisterInfo{}) - a.buildLiveRangesForNonReals(blockID, tc.info) - require.Equal(t, len(tc.exps), a.nodePool.Allocated()) - for v, exp := range tc.exps { - liveNodes := tc.info.liveNodes - initMapInInfo(tc.info) - tc.info.liveNodes = liveNodes - t.Run(v.String(), func(t *testing.T) { - n := a.vRegIDToNode[v.ID()] - require.Equal(t, exp, n.ranges) - var found bool - for _, ln := range liveNodes { - if ln.n == n { - found = true - break - } - } - require.True(t, found) - }) - } - }) - } -} - -func TestAllocator_buildLiveRangesForReals(t *testing.T) { - realReg, realReg2 := FromRealReg(50, RegTypeInt), FromRealReg(100, RegTypeInt) - const blockID = 10 - for _, tc := range []struct { - allocatableRealRegs map[RealReg]struct{} - name string - info *blockInfo - exps map[VReg][]liveRange - }{ - { - name: "ok", - allocatableRealRegs: map[RealReg]struct{}{realReg.RealReg(): {}, realReg2.RealReg(): {}}, - info: &blockInfo{ - realRegDefs: map[VReg][]programCounter{ - realReg: {0, 10, 100}, - realReg2: {5}, - }, - realRegUses: map[VReg][]programCounter{ - realReg: {1, 11, 101}, - realReg2: {10}, - }, - }, - exps: map[VReg][]liveRange{ - realReg: {{begin: 0, end: 1}, {begin: 10, end: 11}, {begin: 100, end: 101}}, - realReg2: {{begin: 5, end: 10}}, - }, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - initMapInInfo(tc.info) - a := NewAllocator(&RegisterInfo{}) - for r := range tc.allocatableRealRegs { - a.allocatableSet[r] = true - } - a.buildLiveRangesForReals(blockID, tc.info) - - actual := map[VReg][]liveRange{} - for _, n := range tc.info.liveNodes { - n := n.n - r := n.ranges[0] - actual[n.v] = append(actual[n.v], liveRange{begin: r.begin, end: r.end, blockID: r.blockID}) - sort.Slice(actual[n.v], func(i, j int) bool { - return actual[n.v][i].begin < actual[n.v][j].begin - }) - } - - require.Equal(t, len(tc.exps), len(actual)) - for v, exp := range tc.exps { - t.Run(v.String(), func(t *testing.T) { - actual := actual[v] - for i := range exp { - exp[i].blockID = blockID - } - require.Equal(t, exp, actual) - }) - } - }) - } -} - func TestAllocator_recordCopyRelation(t *testing.T) { t.Run("real/real", func(t *testing.T) { // Just ensure that it doesn't panic. @@ -686,6 +517,9 @@ func TestAllocator_recordCopyRelation(t *testing.T) { } func initMapInInfo(info *blockInfo) { + if info.intervalTree == nil { + info.intervalTree = newIntervalTree() + } if info.liveIns == nil { info.liveIns = make(map[VReg]struct{}) }