wazevo(regalloc): simplifies live range management (#1798)
Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
This commit is contained in:
@@ -23,20 +23,21 @@ func (a *Allocator) assignRegistersPerBlock(f Function, blk Block, vRegIDToNode
|
||||
blkID := blk.ID()
|
||||
var pc programCounter
|
||||
for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() {
|
||||
tree := a.blockInfos[blkID].intervalTree
|
||||
tree := a.blockInfos[blkID].intervalMng
|
||||
a.assignRegistersPerInstr(f, pc, instr, vRegIDToNode, tree)
|
||||
pc += pcStride
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, tree *intervalTree) {
|
||||
func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, intervalMng *intervalManager) {
|
||||
if indirect := instr.IsIndirectCall(); instr.IsCall() || indirect {
|
||||
// Only take care of non-real VRegs (e.g. VReg.IsRealReg() == false) since
|
||||
// the real VRegs are already placed in the right registers at this point.
|
||||
tree.collectActiveNonRealVRegsAt(
|
||||
intervalMng.collectActiveNodes(
|
||||
// To find the all the live registers "after" call, we need to add pcDefOffset for search.
|
||||
pc+pcDefOffset,
|
||||
&a.nodes1,
|
||||
// Only take care of non-real VRegs (e.g. VReg.IsRealReg() == false) since
|
||||
// the real VRegs are already placed in the right registers at this point.
|
||||
false,
|
||||
)
|
||||
for _, active := range a.nodes1 {
|
||||
if r := active.r; a.regInfo.isCallerSaved(r) {
|
||||
@@ -101,19 +102,19 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr
|
||||
panic("BUG: multiple def instructions must be special cased")
|
||||
}
|
||||
|
||||
a.handleSpills(f, pc, instr, usesSpills, defSpill, tree)
|
||||
a.handleSpills(f, pc, instr, usesSpills, defSpill, intervalMng)
|
||||
a.vs = usesSpills[:0] // for reuse.
|
||||
}
|
||||
|
||||
func (a *Allocator) handleSpills(
|
||||
f Function, pc programCounter, instr Instr,
|
||||
usesSpills []VReg, defSpill VReg, tree *intervalTree,
|
||||
usesSpills []VReg, defSpill VReg, intervalMng *intervalManager,
|
||||
) {
|
||||
_usesSpills, _defSpill := len(usesSpills) > 0, defSpill.Valid()
|
||||
switch {
|
||||
case !_usesSpills && !_defSpill: // Nothing to do.
|
||||
case !_usesSpills && _defSpill: // Only definition is spilled.
|
||||
tree.collectActiveRealRegNodesAt(pc+pcDefOffset, &a.nodes1)
|
||||
intervalMng.collectActiveNodes(pc+pcDefOffset, &a.nodes1, true)
|
||||
a.spillHandler.init(a.nodes1, instr)
|
||||
|
||||
r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo)
|
||||
@@ -129,7 +130,7 @@ func (a *Allocator) handleSpills(
|
||||
f.StoreRegisterAfter(defSpill, instr)
|
||||
|
||||
case _usesSpills:
|
||||
tree.collectActiveRealRegNodesAt(pc, &a.nodes1)
|
||||
intervalMng.collectActiveNodes(pc, &a.nodes1, true)
|
||||
a.spillHandler.init(a.nodes1, instr)
|
||||
|
||||
var evicted [3]*node
|
||||
@@ -173,7 +174,7 @@ func (a *Allocator) handleSpills(
|
||||
|
||||
if !defSpill.IsRealReg() {
|
||||
// This case, the destination register type is different from the source registers.
|
||||
tree.collectActiveRealRegNodesAt(pc+pcDefOffset, &a.nodes1)
|
||||
intervalMng.collectActiveNodes(pc+pcDefOffset, &a.nodes1, true)
|
||||
a.spillHandler.init(a.nodes1, instr)
|
||||
r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo)
|
||||
if evictedNode != nil {
|
||||
|
||||
@@ -11,7 +11,7 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
|
||||
a := NewAllocator(&RegisterInfo{CallerSavedRegisters: [RealRegsNumMax]bool{1: true, 3: true}})
|
||||
pc := programCounter(5)
|
||||
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
liveNodes := []*node{
|
||||
{r: 1, v: 0xa},
|
||||
{r: RealRegInvalid, v: 0xb}, // Spill. not save target.
|
||||
@@ -20,12 +20,13 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
|
||||
{r: 4, v: 0xd}, // real reg, but not caller saved. not save target
|
||||
}
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, 5, 20)
|
||||
manager.insert(n, 5, 20)
|
||||
}
|
||||
manager.build()
|
||||
call := newMockInstr().asCall()
|
||||
blk := newMockBlock(0, call).entry()
|
||||
f := newMockFunction(blk)
|
||||
a.assignRegistersPerInstr(f, pc, call, nil, tree)
|
||||
a.assignRegistersPerInstr(f, pc, call, nil, manager)
|
||||
|
||||
require.Equal(t, 2, len(f.befores))
|
||||
require.Equal(t, 2, len(f.afters))
|
||||
@@ -36,8 +37,8 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
|
||||
functionPtrVRegID := 0x0
|
||||
functionPtrVReg := VReg(functionPtrVRegID).SetRegType(RegTypeInt)
|
||||
functionPtrLiveNode := &node{r: 0xf, v: functionPtrVReg}
|
||||
tree := newIntervalTree()
|
||||
tree.insert(functionPtrLiveNode, 4, pc)
|
||||
manager := newIntervalManager()
|
||||
manager.insert(functionPtrLiveNode, 4, pc)
|
||||
liveNodes := []*node{
|
||||
{r: 1, v: 0xa},
|
||||
{r: 2, v: FromRealReg(1, RegTypeInt)}, // Real reg-backed VReg. not target
|
||||
@@ -45,12 +46,13 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
|
||||
{r: 4, v: 0xd}, // real reg, but not caller saved. not save target
|
||||
}
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, 5, 20)
|
||||
manager.insert(n, 5, 20)
|
||||
}
|
||||
manager.build()
|
||||
callInd := newMockInstr().asIndirectCall().use(functionPtrVReg)
|
||||
blk := newMockBlock(0, callInd).entry()
|
||||
f := newMockFunction(blk)
|
||||
a.assignRegistersPerInstr(f, pc, callInd, []*node{0: functionPtrLiveNode}, tree)
|
||||
a.assignRegistersPerInstr(f, pc, callInd, []*node{0: functionPtrLiveNode}, manager)
|
||||
|
||||
require.Equal(t, 2, len(f.befores))
|
||||
require.Equal(t, 2, len(f.afters))
|
||||
@@ -71,17 +73,17 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
|
||||
{r: 3, v: 0xc},
|
||||
{r: 4, v: 0xd}, // real reg, but not caller saved. not save target
|
||||
}
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, 5, 20)
|
||||
manager.insert(n, 5, 20)
|
||||
}
|
||||
|
||||
manager.build()
|
||||
callInd := newMockInstr().asIndirectCall().use(functionPtrVReg)
|
||||
blk := newMockBlock(0, callInd).entry()
|
||||
f := newMockFunction(blk)
|
||||
a.assignRegistersPerInstr(f, pc, callInd, []*node{
|
||||
0: {r: RealRegInvalid},
|
||||
}, tree)
|
||||
}, manager)
|
||||
|
||||
require.Equal(t, 3, len(f.befores))
|
||||
require.Equal(t, 2, len(f.afters))
|
||||
@@ -132,10 +134,11 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
{r: RealReg(0xb), v: 0xa},
|
||||
{r: RealReg(0xc), v: 0xc},
|
||||
}
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, pc, 20)
|
||||
manager.insert(n, pc, 20)
|
||||
}
|
||||
manager.build()
|
||||
a := NewAllocator(&RegisterInfo{
|
||||
AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xa, 0xb, 0xc}}, // Only live nodes are allocatable.
|
||||
})
|
||||
@@ -144,7 +147,7 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
|
||||
vr := VReg(100).SetRegType(RegTypeInt)
|
||||
instr := newMockInstr().def(vr)
|
||||
a.handleSpills(f, pc, instr, nil, vr, tree)
|
||||
a.handleSpills(f, pc, instr, nil, vr, manager)
|
||||
require.Equal(t, 1, len(instr.defs))
|
||||
require.Equal(t, RealReg(0xa), instr.defs[0].RealReg())
|
||||
|
||||
@@ -162,11 +165,11 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
{r: RealReg(0xb), v: 0xa},
|
||||
{r: RealReg(0xc), v: 0xc},
|
||||
}
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, pc, 20)
|
||||
manager.insert(n, pc, 20)
|
||||
}
|
||||
|
||||
manager.build()
|
||||
a := NewAllocator(&RegisterInfo{
|
||||
AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xb, 0xc}}, // Only live nodes are allocatable.
|
||||
})
|
||||
@@ -175,7 +178,7 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
|
||||
vr := VReg(100).SetRegType(RegTypeInt)
|
||||
instr := newMockInstr().def(vr)
|
||||
a.handleSpills(f, pc, instr, nil, vr, tree)
|
||||
a.handleSpills(f, pc, instr, nil, vr, manager)
|
||||
require.Equal(t, 1, len(instr.defs))
|
||||
require.Equal(t, RealReg(0xb), instr.defs[0].RealReg())
|
||||
|
||||
@@ -194,11 +197,11 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
{r: RealReg(0xc), v: 0xc},
|
||||
}
|
||||
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, pc, 20)
|
||||
manager.insert(n, pc, 20)
|
||||
}
|
||||
|
||||
manager.build()
|
||||
a := NewAllocator(&RegisterInfo{
|
||||
AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xb, 0xc, 0xf /* free */}},
|
||||
})
|
||||
@@ -207,7 +210,7 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
|
||||
vr := VReg(100).SetRegType(RegTypeInt)
|
||||
instr := newMockInstr().def(vr)
|
||||
a.handleSpills(f, pc, instr, nil, vr, tree)
|
||||
a.handleSpills(f, pc, instr, nil, vr, manager)
|
||||
require.Equal(t, 1, len(instr.defs))
|
||||
require.Equal(t, RealReg(0xf), instr.defs[0].RealReg())
|
||||
|
||||
@@ -223,11 +226,11 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
{r: RealReg(0xc), v: 0xc},
|
||||
}
|
||||
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, pc, 20)
|
||||
manager.insert(n, pc, 20)
|
||||
}
|
||||
|
||||
manager.build()
|
||||
a := NewAllocator(&RegisterInfo{
|
||||
AllocatableRegisters: [3][]RealReg{
|
||||
RegTypeInt: {0xb, 0xc, 0xa /* free */},
|
||||
@@ -241,7 +244,7 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
VReg(102).SetRegType(RegTypeFloat)
|
||||
d1 := VReg(104).SetRegType(RegTypeFloat)
|
||||
instr := newMockInstr().use(u1, u2, u3).def(d1)
|
||||
a.handleSpills(f, pc, instr, []VReg{u1, u3}, d1, tree)
|
||||
a.handleSpills(f, pc, instr, []VReg{u1, u3}, d1, manager)
|
||||
require.Equal(t, []VReg{u1.SetRealReg(0xa), u2, u3.SetRealReg(0xf)}, instr.uses)
|
||||
require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs)
|
||||
|
||||
@@ -261,11 +264,11 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
{r: RealReg(0xf), v: 0xb},
|
||||
}
|
||||
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, pc, 20)
|
||||
manager.insert(n, pc, 20)
|
||||
}
|
||||
|
||||
manager.build()
|
||||
a := NewAllocator(&RegisterInfo{
|
||||
AllocatableRegisters: [3][]RealReg{
|
||||
RegTypeInt: {0xb, 0xa /* free */},
|
||||
@@ -277,7 +280,7 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
u1 := VReg(100).SetRegType(RegTypeInt)
|
||||
d1 := VReg(104).SetRegType(RegTypeFloat)
|
||||
instr := newMockInstr().use(u1).def(d1)
|
||||
a.handleSpills(f, pc, instr, []VReg{u1}, d1, tree)
|
||||
a.handleSpills(f, pc, instr, []VReg{u1}, d1, manager)
|
||||
require.Equal(t, []VReg{u1.SetRealReg(0xa)}, instr.uses)
|
||||
require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs)
|
||||
|
||||
@@ -297,11 +300,11 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
{r: RealReg(0xc), v: 0xa},
|
||||
{r: RealReg(0xf), v: 0xb},
|
||||
}
|
||||
tree := newIntervalTree()
|
||||
manager := newIntervalManager()
|
||||
for _, n := range liveNodes {
|
||||
tree.insert(n, pc, 20)
|
||||
manager.insert(n, pc, 20)
|
||||
}
|
||||
|
||||
manager.build()
|
||||
a := NewAllocator(&RegisterInfo{
|
||||
AllocatableRegisters: [3][]RealReg{
|
||||
RegTypeInt: {0xb, 0xc},
|
||||
@@ -313,7 +316,7 @@ func TestAllocator_handleSpills(t *testing.T) {
|
||||
u1 := VReg(100).SetRegType(RegTypeInt)
|
||||
d1 := VReg(104).SetRegType(RegTypeFloat)
|
||||
instr := newMockInstr().use(u1).def(d1)
|
||||
a.handleSpills(f, pc, instr, []VReg{u1}, d1, tree)
|
||||
a.handleSpills(f, pc, instr, []VReg{u1}, d1, manager)
|
||||
require.Equal(t, []VReg{u1.SetRealReg(0xb)}, instr.uses)
|
||||
require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs)
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ func TestAllocator_buildNeighbors(t *testing.T) {
|
||||
{n: newNode(0)},
|
||||
{
|
||||
n: &node{
|
||||
ranges: []*intervalTreeNode{
|
||||
ranges: []*interval{
|
||||
{nodes: newNodes(1, 2, 3)},
|
||||
{nodes: newNodes(4, 5, 1, 2, 3)},
|
||||
},
|
||||
@@ -268,7 +268,7 @@ func TestAllocator_buildNeighbors(t *testing.T) {
|
||||
},
|
||||
{
|
||||
n: &node{
|
||||
ranges: []*intervalTreeNode{
|
||||
ranges: []*interval{
|
||||
{nodes: newNodes(1, 2, 3)},
|
||||
{nodes: newNodes(1, 2, 3)},
|
||||
{nodes: newNodes(1, 2, 3)},
|
||||
@@ -280,9 +280,9 @@ func TestAllocator_buildNeighbors(t *testing.T) {
|
||||
},
|
||||
{
|
||||
n: &node{
|
||||
ranges: []*intervalTreeNode{
|
||||
ranges: []*interval{
|
||||
{nodes: newNodes(1, 2, 3)},
|
||||
{nodes: newNodes(4), neighbors: []*intervalTreeNode{
|
||||
{nodes: newNodes(4), neighbors: []*interval{
|
||||
{nodes: newNodes(5, 6)},
|
||||
{nodes: newNodes(100, 200)},
|
||||
}},
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package regalloc
|
||||
|
||||
import "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
|
||||
|
||||
type intervalTree struct {
|
||||
root *intervalTreeNode
|
||||
allocator intervalTreeNodeAllocator
|
||||
intervals map[uint64]*intervalTreeNode
|
||||
}
|
||||
|
||||
func intervalTreeNodeKey(begin, end programCounter) uint64 {
|
||||
return uint64(begin) | uint64(end)<<32
|
||||
}
|
||||
|
||||
func (t *intervalTree) insert(n *node, begin, end programCounter) *intervalTreeNode {
|
||||
key := uint64(begin) | uint64(end)<<32
|
||||
if i, ok := t.intervals[key]; ok {
|
||||
i.nodes = append(i.nodes, n)
|
||||
return i
|
||||
}
|
||||
t.root = t.root.insert(t, n, begin, end)
|
||||
ret := t.intervals[key]
|
||||
t.buildNeighbors(ret) // TODO: this can be done while inserting.
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t *intervalTree) reset() {
|
||||
t.root = nil
|
||||
t.allocator.Reset()
|
||||
t.intervals = make(map[uint64]*intervalTreeNode)
|
||||
}
|
||||
|
||||
func newIntervalTree() *intervalTree {
|
||||
return &intervalTree{
|
||||
allocator: wazevoapi.NewPool[intervalTreeNode](resetIntervalTreeNode),
|
||||
intervals: make(map[uint64]*intervalTreeNode),
|
||||
}
|
||||
}
|
||||
|
||||
type intervalTreeNodeAllocator = wazevoapi.Pool[intervalTreeNode]
|
||||
|
||||
type intervalTreeNode struct {
|
||||
begin, end programCounter
|
||||
nodes []*node
|
||||
maxEnd programCounter
|
||||
neighbors []*intervalTreeNode
|
||||
left, right *intervalTreeNode
|
||||
// TODO: color for red-black balancing.
|
||||
}
|
||||
|
||||
func resetIntervalTreeNode(i *intervalTreeNode) {
|
||||
i.begin = 0
|
||||
i.end = 0
|
||||
i.nodes = i.nodes[:0]
|
||||
i.maxEnd = 0
|
||||
i.neighbors = i.neighbors[:0]
|
||||
i.left = nil
|
||||
i.right = nil
|
||||
}
|
||||
|
||||
func (i *intervalTreeNode) insert(t *intervalTree, n *node, begin, end programCounter) *intervalTreeNode {
|
||||
if i == nil {
|
||||
intervalNode := t.allocator.Allocate()
|
||||
intervalNode.nodes = append(intervalNode.nodes, n)
|
||||
intervalNode.maxEnd = end
|
||||
intervalNode.begin = begin
|
||||
intervalNode.end = end
|
||||
key := intervalTreeNodeKey(begin, end)
|
||||
t.intervals[key] = intervalNode
|
||||
return intervalNode
|
||||
}
|
||||
if begin < i.begin {
|
||||
i.left = i.left.insert(t, n, begin, end)
|
||||
} else {
|
||||
i.right = i.right.insert(t, n, begin, end)
|
||||
}
|
||||
if i.maxEnd < end {
|
||||
i.maxEnd = end
|
||||
}
|
||||
|
||||
// TODO: balancing logic so that collection functions are faster.
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
func (t *intervalTree) buildNeighbors(from *intervalTreeNode) {
|
||||
t.root.buildNeighbors(from)
|
||||
}
|
||||
|
||||
func (i *intervalTreeNode) buildNeighbors(from *intervalTreeNode) {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
if i.intersects(from) {
|
||||
from.neighbors = append(from.neighbors, i)
|
||||
i.neighbors = append(i.neighbors, from)
|
||||
}
|
||||
if i.left != nil && i.left.maxEnd >= from.begin {
|
||||
i.left.buildNeighbors(from)
|
||||
}
|
||||
if i.begin <= from.end {
|
||||
i.right.buildNeighbors(from)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *intervalTree) collectActiveNonRealVRegsAt(pc programCounter, overlaps *[]*node) {
|
||||
*overlaps = (*overlaps)[:0]
|
||||
t.root.collectActiveNonRealVRegsAt(pc, overlaps)
|
||||
}
|
||||
|
||||
func (i *intervalTreeNode) collectActiveNonRealVRegsAt(pc programCounter, overlaps *[]*node) {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
if i.begin <= pc && i.end >= pc {
|
||||
for _, n := range i.nodes {
|
||||
if n.spill() || n.v.IsRealReg() {
|
||||
continue
|
||||
}
|
||||
*overlaps = append(*overlaps, n)
|
||||
}
|
||||
}
|
||||
if i.left != nil && i.left.maxEnd >= pc {
|
||||
i.left.collectActiveNonRealVRegsAt(pc, overlaps)
|
||||
}
|
||||
if i.begin <= pc {
|
||||
i.right.collectActiveNonRealVRegsAt(pc, overlaps)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *intervalTree) collectActiveRealRegNodesAt(pc programCounter, overlaps *[]*node) {
|
||||
*overlaps = (*overlaps)[:0]
|
||||
t.root.collectActiveRealRegNodesAt(pc, overlaps)
|
||||
}
|
||||
|
||||
func (i *intervalTreeNode) collectActiveRealRegNodesAt(pc programCounter, overlaps *[]*node) {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
if i.begin <= pc && i.end >= pc {
|
||||
for _, n := range i.nodes {
|
||||
if n.assignedRealReg() != RealRegInvalid {
|
||||
*overlaps = append(*overlaps, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
if i.left != nil && i.left.maxEnd >= pc {
|
||||
i.left.collectActiveRealRegNodesAt(pc, overlaps)
|
||||
}
|
||||
if i.begin <= pc {
|
||||
i.right.collectActiveRealRegNodesAt(pc, overlaps)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *intervalTreeNode) intersects(j *intervalTreeNode) bool {
|
||||
return i.begin <= j.end && i.end >= j.begin || j.begin <= i.end && j.end >= i.begin
|
||||
}
|
||||
@@ -1,385 +0,0 @@
|
||||
package regalloc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/tetratelabs/wazero/internal/testing/require"
|
||||
)
|
||||
|
||||
func TestIntervalTree_reset(t *testing.T) {
|
||||
tree := newIntervalTree()
|
||||
n := tree.allocator.Allocate()
|
||||
tree.root = n
|
||||
tree.reset()
|
||||
|
||||
require.Nil(t, tree.root)
|
||||
require.Equal(t, 0, tree.allocator.Allocated())
|
||||
}
|
||||
|
||||
func TestIntervalTreeInsert(t *testing.T) {
|
||||
n1 := &node{}
|
||||
tree := newIntervalTree()
|
||||
tree.insert(n1, 100, 200)
|
||||
require.NotNil(t, tree.root)
|
||||
require.NotNil(t, tree.root.nodes)
|
||||
require.Equal(t, n1, tree.root.nodes[0])
|
||||
n, ok := tree.intervals[intervalTreeNodeKey(100, 200)]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, n1, n.nodes[0])
|
||||
}
|
||||
|
||||
func TestIntervalTreeNodeInsert(t *testing.T) {
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
tree := newIntervalTree()
|
||||
var n *intervalTreeNode
|
||||
allocated := n.insert(tree, &node{}, 0, 100)
|
||||
require.Equal(t, 1, tree.allocator.Allocated())
|
||||
require.NotNil(t, allocated)
|
||||
require.Equal(t, allocated, tree.allocator.View(0))
|
||||
require.Equal(t, programCounter(100), allocated.maxEnd)
|
||||
require.Equal(t, programCounter(0), allocated.begin)
|
||||
require.Equal(t, programCounter(100), allocated.end)
|
||||
require.Equal(t, 1, len(allocated.nodes))
|
||||
n, ok := tree.intervals[intervalTreeNodeKey(0, 100)]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, allocated, n)
|
||||
})
|
||||
t.Run("left", func(t *testing.T) {
|
||||
tree := newIntervalTree()
|
||||
n := &intervalTreeNode{begin: 50, end: 100, maxEnd: 100}
|
||||
n1 := &node{}
|
||||
self := n.insert(tree, n1, 0, 200)
|
||||
require.Equal(t, self, n)
|
||||
require.Equal(t, 1, tree.allocator.Allocated())
|
||||
left := tree.allocator.View(0)
|
||||
require.Equal(t, n.left, left)
|
||||
require.Nil(t, n.right)
|
||||
require.Equal(t, programCounter(200), n.maxEnd)
|
||||
require.Equal(t, left.nodes[0], n1)
|
||||
})
|
||||
t.Run("right", func(t *testing.T) {
|
||||
tree := newIntervalTree()
|
||||
n := &intervalTreeNode{begin: 50, end: 100, maxEnd: 100}
|
||||
n1 := &node{}
|
||||
self := n.insert(tree, n1, 150, 200)
|
||||
require.Equal(t, self, n)
|
||||
require.Equal(t, 1, tree.allocator.Allocated())
|
||||
right := tree.allocator.View(0)
|
||||
require.Equal(t, n.right, right)
|
||||
require.Nil(t, n.left)
|
||||
require.Equal(t, programCounter(200), n.maxEnd)
|
||||
require.Equal(t, right.nodes[0], n1)
|
||||
})
|
||||
}
|
||||
|
||||
type (
|
||||
interval struct {
|
||||
begin, end programCounter
|
||||
id int
|
||||
}
|
||||
queryCase struct {
|
||||
query programCounter
|
||||
exp []int
|
||||
}
|
||||
)
|
||||
|
||||
func newQueryCase(s programCounter, exp ...int) queryCase {
|
||||
return queryCase{query: s, exp: exp}
|
||||
}
|
||||
|
||||
func TestIntervalTree_collectActiveNodesAt(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
intervals []interval
|
||||
queryCases []queryCase
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
intervals: []interval{{begin: 0, end: 100, id: 0}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0, 0),
|
||||
newQueryCase(0, 0),
|
||||
newQueryCase(1, 0),
|
||||
newQueryCase(1, 0),
|
||||
newQueryCase(101),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single/2",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(50, 0),
|
||||
newQueryCase(50, 0),
|
||||
newQueryCase(51, 0),
|
||||
newQueryCase(51, 0),
|
||||
newQueryCase(101),
|
||||
newQueryCase(48),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "same id for different intervals",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(101),
|
||||
newQueryCase(150, 0xa),
|
||||
newQueryCase(101),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two disjoint intervals",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(0),
|
||||
newQueryCase(51, 0xa),
|
||||
newQueryCase(101),
|
||||
newQueryCase(150, 0xb),
|
||||
newQueryCase(200, 0xb),
|
||||
newQueryCase(101),
|
||||
newQueryCase(201),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two intersecting intervals",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(70, 0xa, 0xb),
|
||||
newQueryCase(1),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(51, 0xa, 0xb),
|
||||
newQueryCase(100, 0xa, 0xb),
|
||||
newQueryCase(101, 0xb),
|
||||
newQueryCase(49),
|
||||
newQueryCase(1001),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two enclosing interval",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(24),
|
||||
newQueryCase(25, 0xb),
|
||||
newQueryCase(39, 0xb),
|
||||
newQueryCase(40, 0xb, 0xc),
|
||||
newQueryCase(100, 0xa, 0xb, 0xc),
|
||||
newQueryCase(99, 0xa, 0xb, 0xc),
|
||||
newQueryCase(100, 0xa, 0xb, 0xc),
|
||||
newQueryCase(50, 0xa, 0xb, 0xc),
|
||||
newQueryCase(51, 0xa, 0xb, 0xc),
|
||||
newQueryCase(101, 0xb, 0xc),
|
||||
newQueryCase(49, 0xb, 0xc),
|
||||
newQueryCase(200, 0xb, 0xc),
|
||||
newQueryCase(201, 0xc),
|
||||
newQueryCase(1000, 0xc),
|
||||
newQueryCase(1001),
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tree := newIntervalTree()
|
||||
var maxID int
|
||||
for _, inter := range tc.intervals {
|
||||
n := &node{id: inter.id, r: RealReg(1)}
|
||||
tree.insert(n, inter.begin, inter.end)
|
||||
if maxID < inter.id {
|
||||
maxID = inter.id
|
||||
}
|
||||
key := intervalTreeNodeKey(inter.begin, inter.end)
|
||||
inserted := tree.intervals[key]
|
||||
inserted.nodes = append(inserted.nodes, &node{v: VRegInvalid.SetRealReg(RealRegInvalid)}) // non-real reg should be ignored.
|
||||
}
|
||||
for _, qc := range tc.queryCases {
|
||||
t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) {
|
||||
var collected []*node
|
||||
tree.collectActiveRealRegNodesAt(qc.query, &collected)
|
||||
require.Equal(t, len(qc.exp), len(collected))
|
||||
var foundIDs []int
|
||||
for _, n := range collected {
|
||||
foundIDs = append(foundIDs, n.id)
|
||||
}
|
||||
sort.Slice(foundIDs, func(i, j int) bool {
|
||||
return foundIDs[i] < foundIDs[j]
|
||||
})
|
||||
require.Equal(t, qc.exp, foundIDs)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntervalTree_collectActiveNonRealVRegsAt(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
intervals []interval
|
||||
queryCases []queryCase
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
intervals: []interval{{begin: 0, end: 100, id: 0}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0, 0),
|
||||
newQueryCase(0, 0),
|
||||
newQueryCase(1, 0),
|
||||
newQueryCase(1, 0),
|
||||
newQueryCase(101),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single/2",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(50, 0),
|
||||
newQueryCase(50, 0),
|
||||
newQueryCase(51, 0),
|
||||
newQueryCase(51, 0),
|
||||
newQueryCase(101),
|
||||
newQueryCase(48),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "same id for different intervals",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(101),
|
||||
newQueryCase(150, 0xa),
|
||||
newQueryCase(101),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two disjoint intervals",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(0),
|
||||
newQueryCase(51, 0xa),
|
||||
newQueryCase(101),
|
||||
newQueryCase(150, 0xb),
|
||||
newQueryCase(200, 0xb),
|
||||
newQueryCase(101),
|
||||
newQueryCase(201),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two intersecting intervals",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(70, 0xa, 0xb),
|
||||
newQueryCase(1),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(51, 0xa, 0xb),
|
||||
newQueryCase(100, 0xa, 0xb),
|
||||
newQueryCase(101, 0xb),
|
||||
newQueryCase(49),
|
||||
newQueryCase(1001),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two enclosing interval",
|
||||
intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(24),
|
||||
newQueryCase(25, 0xb),
|
||||
newQueryCase(39, 0xb),
|
||||
newQueryCase(40, 0xb, 0xc),
|
||||
newQueryCase(100, 0xa, 0xb, 0xc),
|
||||
newQueryCase(99, 0xa, 0xb, 0xc),
|
||||
newQueryCase(100, 0xa, 0xb, 0xc),
|
||||
newQueryCase(50, 0xa, 0xb, 0xc),
|
||||
newQueryCase(51, 0xa, 0xb, 0xc),
|
||||
newQueryCase(101, 0xb, 0xc),
|
||||
newQueryCase(49, 0xb, 0xc),
|
||||
newQueryCase(200, 0xb, 0xc),
|
||||
newQueryCase(201, 0xc),
|
||||
newQueryCase(1000, 0xc),
|
||||
newQueryCase(1001),
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tree := newIntervalTree()
|
||||
var maxID int
|
||||
for _, inter := range tc.intervals {
|
||||
n := &node{id: inter.id, r: RealReg(1)}
|
||||
tree.insert(n, inter.begin, inter.end)
|
||||
if maxID < inter.id {
|
||||
maxID = inter.id
|
||||
}
|
||||
key := intervalTreeNodeKey(inter.begin, inter.end)
|
||||
inserted := tree.intervals[key]
|
||||
// They are ignored.
|
||||
inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeInt)})
|
||||
inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeFloat)})
|
||||
inserted.nodes = append(inserted.nodes, &node{v: VReg(1)})
|
||||
}
|
||||
for _, qc := range tc.queryCases {
|
||||
t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) {
|
||||
var collected []*node
|
||||
tree.collectActiveNonRealVRegsAt(qc.query, &collected)
|
||||
require.Equal(t, len(qc.exp), len(collected))
|
||||
var foundIDs []int
|
||||
for _, n := range collected {
|
||||
foundIDs = append(foundIDs, n.id)
|
||||
}
|
||||
sort.Slice(foundIDs, func(i, j int) bool {
|
||||
return foundIDs[i] < foundIDs[j]
|
||||
})
|
||||
require.Equal(t, qc.exp, foundIDs)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntervalTreeNode_intersects(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
rhs, lhs intervalTreeNode
|
||||
exp bool
|
||||
}{
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 0, end: 100},
|
||||
lhs: intervalTreeNode{begin: 0, end: 100},
|
||||
exp: true,
|
||||
},
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 0, end: 100},
|
||||
lhs: intervalTreeNode{begin: 0, end: 99},
|
||||
exp: true,
|
||||
},
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 0, end: 100},
|
||||
lhs: intervalTreeNode{begin: 1, end: 100},
|
||||
exp: true,
|
||||
},
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 50, end: 100},
|
||||
lhs: intervalTreeNode{begin: 1, end: 49},
|
||||
exp: false,
|
||||
},
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 50, end: 100},
|
||||
lhs: intervalTreeNode{begin: 1, end: 50},
|
||||
exp: true,
|
||||
},
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 50, end: 100},
|
||||
lhs: intervalTreeNode{begin: 1, end: 51},
|
||||
exp: true,
|
||||
},
|
||||
{
|
||||
rhs: intervalTreeNode{begin: 50, end: 100},
|
||||
lhs: intervalTreeNode{begin: 99, end: 102},
|
||||
exp: true,
|
||||
},
|
||||
} {
|
||||
actual := tc.rhs.intersects(&tc.lhs)
|
||||
require.Equal(t, tc.exp, actual)
|
||||
}
|
||||
}
|
||||
147
internal/engine/wazevo/backend/regalloc/intervals.go
Normal file
147
internal/engine/wazevo/backend/regalloc/intervals.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package regalloc
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
|
||||
)
|
||||
|
||||
type (
|
||||
// intervalManager manages intervals for each block.
|
||||
intervalManager struct {
|
||||
allocator wazevoapi.Pool[interval]
|
||||
intervals map[intervalKey]*interval
|
||||
sortedIntervals []*interval
|
||||
collectionCur int
|
||||
}
|
||||
// interval represents an interval in the block, which is a range of program counters.
|
||||
// Each interval has a list of nodes which are live in the interval.
|
||||
interval struct {
|
||||
begin, end programCounter
|
||||
// nodes are nodes which are alive in this interval.
|
||||
nodes []*node
|
||||
// neighbors are intervals which are adjacent to this interval.
|
||||
neighbors []*interval
|
||||
}
|
||||
// intervalKey is a key for intervalManager.intervals which consists of begin and end.
|
||||
intervalKey uint64
|
||||
)
|
||||
|
||||
func newIntervalManager() *intervalManager {
|
||||
return &intervalManager{
|
||||
allocator: wazevoapi.NewPool[interval](resetIntervalTreeNode),
|
||||
intervals: make(map[intervalKey]*interval),
|
||||
}
|
||||
}
|
||||
|
||||
func resetIntervalTreeNode(i *interval) {
|
||||
i.begin = 0
|
||||
i.end = 0
|
||||
i.nodes = i.nodes[:0]
|
||||
i.neighbors = i.neighbors[:0]
|
||||
}
|
||||
|
||||
// intervalTreeNodeKey returns a key for intervalManager.intervals.
|
||||
func intervalTreeNodeKey(begin, end programCounter) intervalKey {
|
||||
return intervalKey(begin) | intervalKey(end)<<32
|
||||
}
|
||||
|
||||
// insert inserts a node into the interval tree.
|
||||
func (t *intervalManager) insert(n *node, begin, end programCounter) *interval {
|
||||
key := intervalTreeNodeKey(begin, end)
|
||||
if i, ok := t.intervals[key]; ok {
|
||||
i.nodes = append(i.nodes, n)
|
||||
return i
|
||||
}
|
||||
i := t.allocator.Allocate()
|
||||
i.nodes = append(i.nodes, n)
|
||||
i.begin = begin
|
||||
i.end = end
|
||||
t.intervals[key] = i
|
||||
t.sortedIntervals = append(t.sortedIntervals, i) // Will be sorted later.
|
||||
return i
|
||||
}
|
||||
|
||||
func (t *intervalManager) reset() {
|
||||
t.allocator.Reset()
|
||||
t.sortedIntervals = t.sortedIntervals[:0]
|
||||
t.intervals = make(map[intervalKey]*interval)
|
||||
t.collectionCur = 0
|
||||
}
|
||||
|
||||
// build is called after all the intervals are inserted. This sorts the intervals,
|
||||
// and builds the neighbor intervals for each interval.
|
||||
func (t *intervalManager) build() {
|
||||
sort.Slice(t.sortedIntervals, func(i, j int) bool {
|
||||
ii, ij := t.sortedIntervals[i], t.sortedIntervals[j]
|
||||
if ii.begin == ij.begin {
|
||||
return ii.end < ij.end
|
||||
}
|
||||
return ii.begin < ij.begin
|
||||
})
|
||||
|
||||
var cur int
|
||||
var existingEndMax programCounter = -1
|
||||
for i, _interval := range t.sortedIntervals {
|
||||
begin, end := _interval.begin, _interval.end
|
||||
if begin > existingEndMax {
|
||||
cur = i
|
||||
existingEndMax = end
|
||||
} else {
|
||||
for j := cur; j < i; j++ {
|
||||
existing := t.sortedIntervals[j]
|
||||
if existing.end < begin {
|
||||
continue
|
||||
}
|
||||
if existing.begin > end {
|
||||
panic("BUG")
|
||||
}
|
||||
_interval.neighbors = append(_interval.neighbors, existing)
|
||||
existing.neighbors = append(existing.neighbors, _interval)
|
||||
}
|
||||
if end > existingEndMax {
|
||||
existingEndMax = end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectActiveNodes collects nodes which are alive at pc, and the result is stored in `collected`.
|
||||
// If `real` is true, only nodes which are assigned to a real register are collected.
|
||||
func (t *intervalManager) collectActiveNodes(pc programCounter, collected *[]*node, real bool) {
|
||||
*collected = (*collected)[:0]
|
||||
|
||||
// Advance the collection cursor until the current interval's end is greater than pc.
|
||||
l := len(t.sortedIntervals)
|
||||
for cur := t.collectionCur; cur < l; cur++ {
|
||||
curNode := t.sortedIntervals[cur]
|
||||
if curNode.end < pc {
|
||||
t.collectionCur++
|
||||
continue
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for cur := t.collectionCur; cur < l; cur++ {
|
||||
curNode := t.sortedIntervals[cur]
|
||||
if curNode.end < pc {
|
||||
continue
|
||||
} else if curNode.begin > pc {
|
||||
break
|
||||
}
|
||||
|
||||
for _, n := range curNode.nodes {
|
||||
if real {
|
||||
if n.assignedRealReg() == RealRegInvalid {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if n.spill() || n.v.IsRealReg() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
*collected = append(*collected, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
258
internal/engine/wazevo/backend/regalloc/intervals_test.go
Normal file
258
internal/engine/wazevo/backend/regalloc/intervals_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package regalloc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/tetratelabs/wazero/internal/testing/require"
|
||||
)
|
||||
|
||||
func TestIntervalsManager_build(t *testing.T) {
|
||||
type (
|
||||
intervalCase struct {
|
||||
begin, end programCounter
|
||||
}
|
||||
expNeighborCase struct {
|
||||
index int
|
||||
neighbors []intervalCase
|
||||
}
|
||||
)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
intervals []intervalCase
|
||||
expNeighbors []expNeighborCase
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
intervals: []intervalCase{{begin: 0, end: 100}},
|
||||
expNeighbors: []expNeighborCase{{index: 0}},
|
||||
},
|
||||
{
|
||||
name: "disjoints",
|
||||
intervals: []intervalCase{{begin: 50, end: 100}, {begin: 1, end: 2}},
|
||||
expNeighbors: []expNeighborCase{{index: 0}, {index: 1}},
|
||||
},
|
||||
{
|
||||
name: "disjoints duplicate",
|
||||
intervals: []intervalCase{{begin: 50, end: 100}, {begin: 1, end: 2}, {begin: 50, end: 100}, {begin: 1, end: 2}},
|
||||
expNeighbors: []expNeighborCase{{index: 0}, {index: 1}, {index: 2}, {index: 3}},
|
||||
},
|
||||
{
|
||||
name: "two intersecting",
|
||||
intervals: []intervalCase{
|
||||
{begin: 70, end: 200},
|
||||
{begin: 50, end: 100},
|
||||
},
|
||||
expNeighbors: []expNeighborCase{
|
||||
{index: 0, neighbors: []intervalCase{{begin: 50, end: 100}}},
|
||||
{index: 1, neighbors: []intervalCase{{begin: 70, end: 200}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "same beginnings",
|
||||
intervals: []intervalCase{
|
||||
{begin: 50, end: 200},
|
||||
{begin: 50, end: 401},
|
||||
{begin: 50, end: 201},
|
||||
{begin: 50, end: 302},
|
||||
},
|
||||
expNeighbors: []expNeighborCase{
|
||||
{index: 0, neighbors: []intervalCase{{begin: 50, end: 201}, {begin: 50, end: 302}, {begin: 50, end: 401}}},
|
||||
{index: 1, neighbors: []intervalCase{{begin: 50, end: 200}, {begin: 50, end: 201}, {begin: 50, end: 302}}},
|
||||
{index: 2, neighbors: []intervalCase{{begin: 50, end: 200}, {begin: 50, end: 302}, {begin: 50, end: 401}}},
|
||||
{index: 3, neighbors: []intervalCase{{begin: 50, end: 200}, {begin: 50, end: 201}, {begin: 50, end: 401}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "three intersecting",
|
||||
intervals: []intervalCase{
|
||||
{begin: 70, end: 200},
|
||||
{begin: 71, end: 150},
|
||||
{begin: 50, end: 100},
|
||||
},
|
||||
expNeighbors: []expNeighborCase{
|
||||
{index: 0, neighbors: []intervalCase{{begin: 50, end: 100}, {begin: 71, end: 150}}},
|
||||
{index: 1, neighbors: []intervalCase{{begin: 50, end: 100}, {begin: 70, end: 200}}},
|
||||
{index: 2, neighbors: []intervalCase{{begin: 70, end: 200}, {begin: 71, end: 150}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two enclosing interval",
|
||||
intervals: []intervalCase{
|
||||
{begin: 50, end: 100},
|
||||
{begin: 25, end: 200},
|
||||
{begin: 40, end: 1000},
|
||||
},
|
||||
expNeighbors: []expNeighborCase{
|
||||
{index: 0, neighbors: []intervalCase{{begin: 25, end: 200}, {begin: 40, end: 1000}}},
|
||||
{index: 1, neighbors: []intervalCase{{begin: 40, end: 1000}, {begin: 50, end: 100}}},
|
||||
{index: 2, neighbors: []intervalCase{{begin: 25, end: 200}, {begin: 50, end: 100}}},
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager := newIntervalManager()
|
||||
for i, inter := range tc.intervals {
|
||||
n := &node{id: i, r: RealReg(1)}
|
||||
manager.insert(n, inter.begin, inter.end)
|
||||
}
|
||||
manager.build()
|
||||
|
||||
for i, exp := range tc.expNeighbors {
|
||||
it := tc.intervals[exp.index]
|
||||
key := intervalTreeNodeKey(it.begin, it.end)
|
||||
|
||||
var found []intervalCase
|
||||
for _, n := range manager.intervals[key].neighbors {
|
||||
found = append(found, intervalCase{begin: n.begin, end: n.end})
|
||||
}
|
||||
sort.Slice(found, func(i, j int) bool {
|
||||
return found[i].begin < found[j].begin
|
||||
})
|
||||
require.Equal(t, exp.neighbors, found, fmt.Sprintf("case=%d", i))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntervalManager_collectActiveNodes(t *testing.T) {
|
||||
type (
|
||||
queryCase struct {
|
||||
query programCounter
|
||||
exp []int
|
||||
}
|
||||
intervalCase struct {
|
||||
begin, end programCounter
|
||||
id int
|
||||
}
|
||||
)
|
||||
|
||||
newQueryCase := func(s programCounter, exp ...int) queryCase {
|
||||
return queryCase{query: s, exp: exp}
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
intervals []intervalCase
|
||||
queryCases []queryCase
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
intervals: []intervalCase{{begin: 0, end: 100, id: 0}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0, 0),
|
||||
newQueryCase(1, 0),
|
||||
newQueryCase(101),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single/2",
|
||||
intervals: []intervalCase{{begin: 50, end: 100, id: 0}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(48),
|
||||
newQueryCase(50, 0),
|
||||
newQueryCase(51, 0),
|
||||
newQueryCase(101),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "same id for different intervals",
|
||||
intervals: []intervalCase{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(101),
|
||||
newQueryCase(150, 0xa),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two disjoint intervals",
|
||||
intervals: []intervalCase{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(51, 0xa),
|
||||
newQueryCase(101),
|
||||
newQueryCase(150, 0xb),
|
||||
newQueryCase(200, 0xb),
|
||||
newQueryCase(201),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two intersecting intervals",
|
||||
intervals: []intervalCase{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(0),
|
||||
newQueryCase(1),
|
||||
newQueryCase(49),
|
||||
newQueryCase(50, 0xa),
|
||||
newQueryCase(51, 0xa, 0xb),
|
||||
newQueryCase(70, 0xa, 0xb),
|
||||
newQueryCase(100, 0xa, 0xb),
|
||||
newQueryCase(101, 0xb),
|
||||
newQueryCase(1001),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two enclosing interval",
|
||||
intervals: []intervalCase{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}},
|
||||
queryCases: []queryCase{
|
||||
newQueryCase(24),
|
||||
newQueryCase(25, 0xb),
|
||||
newQueryCase(39, 0xb),
|
||||
newQueryCase(40, 0xb, 0xc),
|
||||
newQueryCase(49, 0xb, 0xc),
|
||||
newQueryCase(50, 0xa, 0xb, 0xc),
|
||||
newQueryCase(51, 0xa, 0xb, 0xc),
|
||||
newQueryCase(99, 0xa, 0xb, 0xc),
|
||||
newQueryCase(100, 0xa, 0xb, 0xc),
|
||||
newQueryCase(101, 0xb, 0xc),
|
||||
newQueryCase(200, 0xb, 0xc),
|
||||
newQueryCase(201, 0xc),
|
||||
newQueryCase(1000, 0xc),
|
||||
newQueryCase(1001),
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, onlyReal := range []bool{false, true} {
|
||||
t.Run(fmt.Sprintf("onlyReal=%t", onlyReal), func(t *testing.T) {
|
||||
manager := newIntervalManager()
|
||||
for _, inter := range tc.intervals {
|
||||
n := &node{id: inter.id, r: RealReg(1)}
|
||||
manager.insert(n, inter.begin, inter.end)
|
||||
key := intervalTreeNodeKey(inter.begin, inter.end)
|
||||
inserted := manager.intervals[key]
|
||||
|
||||
// They are ignored.
|
||||
if onlyReal {
|
||||
inserted.nodes = append(inserted.nodes, &node{v: VRegInvalid.SetRealReg(RealRegInvalid)}) // non-real reg should be ignored.
|
||||
} else {
|
||||
inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeInt)})
|
||||
inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeFloat)})
|
||||
inserted.nodes = append(inserted.nodes, &node{v: VReg(1)})
|
||||
}
|
||||
}
|
||||
manager.build()
|
||||
for _, qc := range tc.queryCases {
|
||||
t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) {
|
||||
var collected []*node
|
||||
manager.collectActiveNodes(qc.query, &collected, onlyReal)
|
||||
require.Equal(t, len(qc.exp), len(collected))
|
||||
var foundIDs []int
|
||||
for _, n := range collected {
|
||||
foundIDs = append(foundIDs, n.id)
|
||||
}
|
||||
sort.Slice(foundIDs, func(i, j int) bool {
|
||||
return foundIDs[i] < foundIDs[j]
|
||||
})
|
||||
require.Equal(t, qc.exp, foundIDs)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -80,16 +80,16 @@ type (
|
||||
lastUses map[VReg]programCounter
|
||||
kills map[VReg]programCounter
|
||||
// Pre-colored real registers can have multiple live ranges in one block.
|
||||
realRegUses map[VReg][]programCounter
|
||||
realRegDefs map[VReg][]programCounter
|
||||
intervalTree *intervalTree
|
||||
realRegUses map[VReg][]programCounter
|
||||
realRegDefs map[VReg][]programCounter
|
||||
intervalMng *intervalManager
|
||||
}
|
||||
|
||||
// node represents a VReg.
|
||||
node struct {
|
||||
id int
|
||||
v VReg
|
||||
ranges []*intervalTreeNode
|
||||
ranges []*interval
|
||||
// r is the real register assigned to this node. It is either a pre-colored register or a register assigned during allocation.
|
||||
r RealReg
|
||||
// neighbors are the nodes that this node interferes with. Such neighbors have the same RegType as this node.
|
||||
@@ -300,6 +300,7 @@ func (a *Allocator) buildLiveRanges(f Function) {
|
||||
info := a.blockInfoAt(blkID)
|
||||
a.buildLiveRangesForNonReals(info)
|
||||
a.buildLiveRangesForReals(info)
|
||||
info.intervalMng.build()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,7 +335,7 @@ func (a *Allocator) buildLiveRangesForNonReals(info *blockInfo) {
|
||||
begin, end = 0, killPos
|
||||
}
|
||||
n := a.getOrAllocateNode(v)
|
||||
intervalNode := info.intervalTree.insert(n, begin, end)
|
||||
intervalNode := info.intervalMng.insert(n, begin, end)
|
||||
n.ranges = append(n.ranges, intervalNode)
|
||||
}
|
||||
|
||||
@@ -375,7 +376,7 @@ func (a *Allocator) buildLiveRangesForNonReals(info *blockInfo) {
|
||||
}
|
||||
}
|
||||
n := a.getOrAllocateNode(v)
|
||||
intervalNode := info.intervalTree.insert(n, defPos, end)
|
||||
intervalNode := info.intervalMng.insert(n, defPos, end)
|
||||
n.ranges = append(n.ranges, intervalNode)
|
||||
}
|
||||
|
||||
@@ -431,7 +432,7 @@ func (a *Allocator) buildLiveRangesForReals(info *blockInfo) {
|
||||
n.r = v.RealReg()
|
||||
n.v = v
|
||||
defined, used := defs[i], uses[i]
|
||||
intervalNode := info.intervalTree.insert(n, defined, used)
|
||||
intervalNode := info.intervalMng.insert(n, defined, used)
|
||||
n.ranges = append(n.ranges, intervalNode)
|
||||
}
|
||||
}
|
||||
@@ -517,10 +518,10 @@ func resetMap[T any](a *Allocator, m map[VReg]T) {
|
||||
}
|
||||
|
||||
func (a *Allocator) initBlockInfo(i *blockInfo) {
|
||||
if i.intervalTree == nil {
|
||||
i.intervalTree = newIntervalTree()
|
||||
if i.intervalMng == nil {
|
||||
i.intervalMng = newIntervalManager()
|
||||
} else {
|
||||
i.intervalTree.reset()
|
||||
i.intervalMng.reset()
|
||||
}
|
||||
if i.liveOuts == nil {
|
||||
i.liveOuts = make(map[VReg]struct{})
|
||||
|
||||
@@ -445,10 +445,10 @@ func TestAllocator_livenessAnalysis(t *testing.T) {
|
||||
actual := &a.blockInfos[blockID]
|
||||
exp := tc.exp[blockID]
|
||||
initMapInInfo(exp)
|
||||
saved := actual.intervalTree
|
||||
actual.intervalTree = nil // Don't compare intervalTree.
|
||||
saved := actual.intervalMng
|
||||
actual.intervalMng = nil // Don't compare intervalManager.
|
||||
require.Equal(t, exp, actual, "\n[exp for block[%d]]\n%s\n[actual for block[%d]]\n%s", blockID, exp, blockID, actual)
|
||||
actual.intervalTree = saved
|
||||
actual.intervalMng = saved
|
||||
}
|
||||
|
||||
// Sanity check: buildLiveRanges should not panic.
|
||||
|
||||
Reference in New Issue
Block a user