wazevo(regalloc): simplifies live range management (#1798)

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
This commit is contained in:
Takeshi Yoneda
2023-10-19 11:37:51 +09:00
committed by GitHub
parent 9264104c0b
commit 06136049e5
9 changed files with 471 additions and 603 deletions

View File

@@ -23,20 +23,21 @@ func (a *Allocator) assignRegistersPerBlock(f Function, blk Block, vRegIDToNode
blkID := blk.ID()
var pc programCounter
for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() {
tree := a.blockInfos[blkID].intervalTree
tree := a.blockInfos[blkID].intervalMng
a.assignRegistersPerInstr(f, pc, instr, vRegIDToNode, tree)
pc += pcStride
}
}
func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, tree *intervalTree) {
func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr Instr, vRegIDToNode []*node, intervalMng *intervalManager) {
if indirect := instr.IsIndirectCall(); instr.IsCall() || indirect {
// Only take care of non-real VRegs (e.g. VReg.IsRealReg() == false) since
// the real VRegs are already placed in the right registers at this point.
tree.collectActiveNonRealVRegsAt(
intervalMng.collectActiveNodes(
// To find the all the live registers "after" call, we need to add pcDefOffset for search.
pc+pcDefOffset,
&a.nodes1,
// Only take care of non-real VRegs (e.g. VReg.IsRealReg() == false) since
// the real VRegs are already placed in the right registers at this point.
false,
)
for _, active := range a.nodes1 {
if r := active.r; a.regInfo.isCallerSaved(r) {
@@ -101,19 +102,19 @@ func (a *Allocator) assignRegistersPerInstr(f Function, pc programCounter, instr
panic("BUG: multiple def instructions must be special cased")
}
a.handleSpills(f, pc, instr, usesSpills, defSpill, tree)
a.handleSpills(f, pc, instr, usesSpills, defSpill, intervalMng)
a.vs = usesSpills[:0] // for reuse.
}
func (a *Allocator) handleSpills(
f Function, pc programCounter, instr Instr,
usesSpills []VReg, defSpill VReg, tree *intervalTree,
usesSpills []VReg, defSpill VReg, intervalMng *intervalManager,
) {
_usesSpills, _defSpill := len(usesSpills) > 0, defSpill.Valid()
switch {
case !_usesSpills && !_defSpill: // Nothing to do.
case !_usesSpills && _defSpill: // Only definition is spilled.
tree.collectActiveRealRegNodesAt(pc+pcDefOffset, &a.nodes1)
intervalMng.collectActiveNodes(pc+pcDefOffset, &a.nodes1, true)
a.spillHandler.init(a.nodes1, instr)
r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo)
@@ -129,7 +130,7 @@ func (a *Allocator) handleSpills(
f.StoreRegisterAfter(defSpill, instr)
case _usesSpills:
tree.collectActiveRealRegNodesAt(pc, &a.nodes1)
intervalMng.collectActiveNodes(pc, &a.nodes1, true)
a.spillHandler.init(a.nodes1, instr)
var evicted [3]*node
@@ -173,7 +174,7 @@ func (a *Allocator) handleSpills(
if !defSpill.IsRealReg() {
// This case, the destination register type is different from the source registers.
tree.collectActiveRealRegNodesAt(pc+pcDefOffset, &a.nodes1)
intervalMng.collectActiveNodes(pc+pcDefOffset, &a.nodes1, true)
a.spillHandler.init(a.nodes1, instr)
r, evictedNode := a.spillHandler.getUnusedOrEvictReg(defSpill.RegType(), a.regInfo)
if evictedNode != nil {

View File

@@ -11,7 +11,7 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
a := NewAllocator(&RegisterInfo{CallerSavedRegisters: [RealRegsNumMax]bool{1: true, 3: true}})
pc := programCounter(5)
tree := newIntervalTree()
manager := newIntervalManager()
liveNodes := []*node{
{r: 1, v: 0xa},
{r: RealRegInvalid, v: 0xb}, // Spill. not save target.
@@ -20,12 +20,13 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
{r: 4, v: 0xd}, // real reg, but not caller saved. not save target
}
for _, n := range liveNodes {
tree.insert(n, 5, 20)
manager.insert(n, 5, 20)
}
manager.build()
call := newMockInstr().asCall()
blk := newMockBlock(0, call).entry()
f := newMockFunction(blk)
a.assignRegistersPerInstr(f, pc, call, nil, tree)
a.assignRegistersPerInstr(f, pc, call, nil, manager)
require.Equal(t, 2, len(f.befores))
require.Equal(t, 2, len(f.afters))
@@ -36,8 +37,8 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
functionPtrVRegID := 0x0
functionPtrVReg := VReg(functionPtrVRegID).SetRegType(RegTypeInt)
functionPtrLiveNode := &node{r: 0xf, v: functionPtrVReg}
tree := newIntervalTree()
tree.insert(functionPtrLiveNode, 4, pc)
manager := newIntervalManager()
manager.insert(functionPtrLiveNode, 4, pc)
liveNodes := []*node{
{r: 1, v: 0xa},
{r: 2, v: FromRealReg(1, RegTypeInt)}, // Real reg-backed VReg. not target
@@ -45,12 +46,13 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
{r: 4, v: 0xd}, // real reg, but not caller saved. not save target
}
for _, n := range liveNodes {
tree.insert(n, 5, 20)
manager.insert(n, 5, 20)
}
manager.build()
callInd := newMockInstr().asIndirectCall().use(functionPtrVReg)
blk := newMockBlock(0, callInd).entry()
f := newMockFunction(blk)
a.assignRegistersPerInstr(f, pc, callInd, []*node{0: functionPtrLiveNode}, tree)
a.assignRegistersPerInstr(f, pc, callInd, []*node{0: functionPtrLiveNode}, manager)
require.Equal(t, 2, len(f.befores))
require.Equal(t, 2, len(f.afters))
@@ -71,17 +73,17 @@ func TestAllocator_assignRegistersPerInstr(t *testing.T) {
{r: 3, v: 0xc},
{r: 4, v: 0xd}, // real reg, but not caller saved. not save target
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, 5, 20)
manager.insert(n, 5, 20)
}
manager.build()
callInd := newMockInstr().asIndirectCall().use(functionPtrVReg)
blk := newMockBlock(0, callInd).entry()
f := newMockFunction(blk)
a.assignRegistersPerInstr(f, pc, callInd, []*node{
0: {r: RealRegInvalid},
}, tree)
}, manager)
require.Equal(t, 3, len(f.befores))
require.Equal(t, 2, len(f.afters))
@@ -132,10 +134,11 @@ func TestAllocator_handleSpills(t *testing.T) {
{r: RealReg(0xb), v: 0xa},
{r: RealReg(0xc), v: 0xc},
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, pc, 20)
manager.insert(n, pc, 20)
}
manager.build()
a := NewAllocator(&RegisterInfo{
AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xa, 0xb, 0xc}}, // Only live nodes are allocatable.
})
@@ -144,7 +147,7 @@ func TestAllocator_handleSpills(t *testing.T) {
vr := VReg(100).SetRegType(RegTypeInt)
instr := newMockInstr().def(vr)
a.handleSpills(f, pc, instr, nil, vr, tree)
a.handleSpills(f, pc, instr, nil, vr, manager)
require.Equal(t, 1, len(instr.defs))
require.Equal(t, RealReg(0xa), instr.defs[0].RealReg())
@@ -162,11 +165,11 @@ func TestAllocator_handleSpills(t *testing.T) {
{r: RealReg(0xb), v: 0xa},
{r: RealReg(0xc), v: 0xc},
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, pc, 20)
manager.insert(n, pc, 20)
}
manager.build()
a := NewAllocator(&RegisterInfo{
AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xb, 0xc}}, // Only live nodes are allocatable.
})
@@ -175,7 +178,7 @@ func TestAllocator_handleSpills(t *testing.T) {
vr := VReg(100).SetRegType(RegTypeInt)
instr := newMockInstr().def(vr)
a.handleSpills(f, pc, instr, nil, vr, tree)
a.handleSpills(f, pc, instr, nil, vr, manager)
require.Equal(t, 1, len(instr.defs))
require.Equal(t, RealReg(0xb), instr.defs[0].RealReg())
@@ -194,11 +197,11 @@ func TestAllocator_handleSpills(t *testing.T) {
{r: RealReg(0xc), v: 0xc},
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, pc, 20)
manager.insert(n, pc, 20)
}
manager.build()
a := NewAllocator(&RegisterInfo{
AllocatableRegisters: [3][]RealReg{RegTypeInt: {0xb, 0xc, 0xf /* free */}},
})
@@ -207,7 +210,7 @@ func TestAllocator_handleSpills(t *testing.T) {
vr := VReg(100).SetRegType(RegTypeInt)
instr := newMockInstr().def(vr)
a.handleSpills(f, pc, instr, nil, vr, tree)
a.handleSpills(f, pc, instr, nil, vr, manager)
require.Equal(t, 1, len(instr.defs))
require.Equal(t, RealReg(0xf), instr.defs[0].RealReg())
@@ -223,11 +226,11 @@ func TestAllocator_handleSpills(t *testing.T) {
{r: RealReg(0xc), v: 0xc},
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, pc, 20)
manager.insert(n, pc, 20)
}
manager.build()
a := NewAllocator(&RegisterInfo{
AllocatableRegisters: [3][]RealReg{
RegTypeInt: {0xb, 0xc, 0xa /* free */},
@@ -241,7 +244,7 @@ func TestAllocator_handleSpills(t *testing.T) {
VReg(102).SetRegType(RegTypeFloat)
d1 := VReg(104).SetRegType(RegTypeFloat)
instr := newMockInstr().use(u1, u2, u3).def(d1)
a.handleSpills(f, pc, instr, []VReg{u1, u3}, d1, tree)
a.handleSpills(f, pc, instr, []VReg{u1, u3}, d1, manager)
require.Equal(t, []VReg{u1.SetRealReg(0xa), u2, u3.SetRealReg(0xf)}, instr.uses)
require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs)
@@ -261,11 +264,11 @@ func TestAllocator_handleSpills(t *testing.T) {
{r: RealReg(0xf), v: 0xb},
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, pc, 20)
manager.insert(n, pc, 20)
}
manager.build()
a := NewAllocator(&RegisterInfo{
AllocatableRegisters: [3][]RealReg{
RegTypeInt: {0xb, 0xa /* free */},
@@ -277,7 +280,7 @@ func TestAllocator_handleSpills(t *testing.T) {
u1 := VReg(100).SetRegType(RegTypeInt)
d1 := VReg(104).SetRegType(RegTypeFloat)
instr := newMockInstr().use(u1).def(d1)
a.handleSpills(f, pc, instr, []VReg{u1}, d1, tree)
a.handleSpills(f, pc, instr, []VReg{u1}, d1, manager)
require.Equal(t, []VReg{u1.SetRealReg(0xa)}, instr.uses)
require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs)
@@ -297,11 +300,11 @@ func TestAllocator_handleSpills(t *testing.T) {
{r: RealReg(0xc), v: 0xa},
{r: RealReg(0xf), v: 0xb},
}
tree := newIntervalTree()
manager := newIntervalManager()
for _, n := range liveNodes {
tree.insert(n, pc, 20)
manager.insert(n, pc, 20)
}
manager.build()
a := NewAllocator(&RegisterInfo{
AllocatableRegisters: [3][]RealReg{
RegTypeInt: {0xb, 0xc},
@@ -313,7 +316,7 @@ func TestAllocator_handleSpills(t *testing.T) {
u1 := VReg(100).SetRegType(RegTypeInt)
d1 := VReg(104).SetRegType(RegTypeFloat)
instr := newMockInstr().use(u1).def(d1)
a.handleSpills(f, pc, instr, []VReg{u1}, d1, tree)
a.handleSpills(f, pc, instr, []VReg{u1}, d1, manager)
require.Equal(t, []VReg{u1.SetRealReg(0xb)}, instr.uses)
require.Equal(t, []VReg{d1.SetRealReg(0xf)}, instr.defs)

View File

@@ -259,7 +259,7 @@ func TestAllocator_buildNeighbors(t *testing.T) {
{n: newNode(0)},
{
n: &node{
ranges: []*intervalTreeNode{
ranges: []*interval{
{nodes: newNodes(1, 2, 3)},
{nodes: newNodes(4, 5, 1, 2, 3)},
},
@@ -268,7 +268,7 @@ func TestAllocator_buildNeighbors(t *testing.T) {
},
{
n: &node{
ranges: []*intervalTreeNode{
ranges: []*interval{
{nodes: newNodes(1, 2, 3)},
{nodes: newNodes(1, 2, 3)},
{nodes: newNodes(1, 2, 3)},
@@ -280,9 +280,9 @@ func TestAllocator_buildNeighbors(t *testing.T) {
},
{
n: &node{
ranges: []*intervalTreeNode{
ranges: []*interval{
{nodes: newNodes(1, 2, 3)},
{nodes: newNodes(4), neighbors: []*intervalTreeNode{
{nodes: newNodes(4), neighbors: []*interval{
{nodes: newNodes(5, 6)},
{nodes: newNodes(100, 200)},
}},

View File

@@ -1,157 +0,0 @@
package regalloc
import "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
// intervalTree indexes live intervals in an (unbalanced) binary search tree
// ordered by interval begin, plus a map for exact (begin, end) lookups.
type intervalTree struct {
// root is the top of the tree; it is nil while the tree is empty.
root *intervalTreeNode
// allocator is the pool from which tree nodes are allocated.
allocator intervalTreeNodeAllocator
// intervals maps the packed (begin, end) key to the tree node holding exactly that interval.
intervals map[uint64]*intervalTreeNode
}
// intervalTreeNodeKey packs an interval's bounds into a single map key:
// begin occupies the low 32 bits and end the high 32 bits.
//
// Each bound is truncated to 32 bits before packing. Without the truncation a
// negative begin of a signed programCounter would sign-extend into the upper
// half and overwrite the bits that encode end, making distinct intervals
// collide on the same key.
func intervalTreeNodeKey(begin, end programCounter) uint64 {
	low := uint64(uint32(begin))
	high := uint64(uint32(end)) << 32
	return low | high
}
// insert records that node n is live over [begin, end] and returns the tree
// node representing that interval.
//
// When an interval with the same bounds already exists, n is appended to it
// and no new tree node is created. Otherwise a new tree node is inserted and
// its neighbor links are built immediately.
func (t *intervalTree) insert(n *node, begin, end programCounter) *intervalTreeNode {
	// Use the shared key helper so that the lookup here and the registration
	// done by intervalTreeNode.insert can never disagree on the encoding.
	key := intervalTreeNodeKey(begin, end)
	if i, ok := t.intervals[key]; ok {
		i.nodes = append(i.nodes, n)
		return i
	}
	t.root = t.root.insert(t, n, begin, end)
	// The recursive insert above registered the new tree node under key.
	ret := t.intervals[key]
	t.buildNeighbors(ret) // TODO: this can be done while inserting.
	return ret
}
// reset empties the tree so it can be reused: the exact-match index is
// replaced by a fresh map, the node pool is reset, and the root is dropped.
func (t *intervalTree) reset() {
	t.intervals = make(map[uint64]*intervalTreeNode)
	t.allocator.Reset()
	t.root = nil
}
// newIntervalTree returns an empty tree whose node pool resets nodes with
// resetIntervalTreeNode and whose exact-match index is ready for use.
func newIntervalTree() *intervalTree {
	tree := new(intervalTree)
	tree.allocator = wazevoapi.NewPool[intervalTreeNode](resetIntervalTreeNode)
	tree.intervals = make(map[uint64]*intervalTreeNode)
	return tree
}
// intervalTreeNodeAllocator is the pool type used to allocate intervalTreeNode values.
type intervalTreeNodeAllocator = wazevoapi.Pool[intervalTreeNode]
// intervalTreeNode is one interval stored in intervalTree, together with the
// nodes that are live over exactly that interval.
type intervalTreeNode struct {
// begin and end are the bounds of the interval; queries treat both as inclusive.
begin, end programCounter
// nodes are the nodes registered for this exact (begin, end) pair.
nodes []*node
// maxEnd is the largest end among this tree node and its subtree; it lets searches prune.
maxEnd programCounter
// neighbors are the other intervals found to intersect this one.
neighbors []*intervalTreeNode
// left holds intervals whose begin is smaller than this begin; right holds the rest.
left, right *intervalTreeNode
// TODO: color for red-black balancing.
}
// resetIntervalTreeNode clears a pooled tree node for reuse. Scalar fields and
// child links are zeroed, while the nodes and neighbors slices are truncated
// to length zero so that their backing arrays are kept.
func resetIntervalTreeNode(i *intervalTreeNode) {
	i.begin, i.end, i.maxEnd = 0, 0, 0
	i.left, i.right = nil, nil
	i.nodes = i.nodes[:0]
	i.neighbors = i.neighbors[:0]
}
// insert places the interval [begin, end] carrying node n into the subtree
// rooted at i and returns the (possibly new) root of that subtree.
//
// A nil receiver is the empty subtree: a tree node is allocated from t's pool,
// registered in t.intervals, and returned. Otherwise the interval goes left
// when begin is strictly smaller than i.begin and right in every other case,
// and i.maxEnd is raised to cover end.
func (i *intervalTreeNode) insert(t *intervalTree, n *node, begin, end programCounter) *intervalTreeNode {
	if i != nil {
		if begin < i.begin {
			i.left = i.left.insert(t, n, begin, end)
		} else {
			i.right = i.right.insert(t, n, begin, end)
		}
		if end > i.maxEnd {
			i.maxEnd = end
		}
		// TODO: balancing logic so that collection functions are faster.
		return i
	}

	leaf := t.allocator.Allocate()
	leaf.nodes = append(leaf.nodes, n)
	leaf.begin, leaf.end, leaf.maxEnd = begin, end, end
	t.intervals[intervalTreeNodeKey(begin, end)] = leaf
	return leaf
}
// buildNeighbors links `from` with the intervals in the tree that intersect it,
// by walking the tree from the root.
func (t *intervalTree) buildNeighbors(from *intervalTreeNode) {
t.root.buildNeighbors(from)
}
// buildNeighbors walks the subtree rooted at i and links `from` with every
// interval in it that intersects `from`. Links are symmetric: each side is
// appended to the other's neighbors.
//
// `from` is already part of the tree when this runs, so the walk reaches it.
// It is never linked to itself; previously the self-intersection check
// appended `from` to its own neighbors twice.
//
// A nil receiver is the empty subtree and is a no-op.
func (i *intervalTreeNode) buildNeighbors(from *intervalTreeNode) {
	if i == nil {
		return
	}
	if i != from && i.intersects(from) {
		from.neighbors = append(from.neighbors, i)
		i.neighbors = append(i.neighbors, from)
	}
	// The left subtree can only contain a match when some interval in it
	// ends at or after from.begin.
	if i.left != nil && i.left.maxEnd >= from.begin {
		i.left.buildNeighbors(from)
	}
	// Everything on the right begins at or after i.begin, so the right
	// subtree is skipped once i.begin lies past from.end.
	if i.begin <= from.end {
		i.right.buildNeighbors(from)
	}
}
// collectActiveNonRealVRegsAt truncates *overlaps to length zero and then fills it
// with the nodes live at pc that are neither spilled nor backed by a real-register VReg.
func (t *intervalTree) collectActiveNonRealVRegsAt(pc programCounter, overlaps *[]*node) {
// Reuse the caller's backing array instead of allocating a new slice.
*overlaps = (*overlaps)[:0]
t.root.collectActiveNonRealVRegsAt(pc, overlaps)
}
// collectActiveNonRealVRegsAt appends to *overlaps every node in the subtree
// rooted at i whose interval contains pc (both bounds inclusive), leaving out
// nodes that are spilled or whose VReg is a real register.
//
// A nil receiver is the empty subtree and is a no-op.
func (i *intervalTreeNode) collectActiveNonRealVRegsAt(pc programCounter, overlaps *[]*node) {
	if i == nil {
		return
	}
	if pc >= i.begin && pc <= i.end {
		for _, candidate := range i.nodes {
			if !candidate.spill() && !candidate.v.IsRealReg() {
				*overlaps = append(*overlaps, candidate)
			}
		}
	}
	// Descend left only if something there is still live at pc.
	if left := i.left; left != nil && left.maxEnd >= pc {
		left.collectActiveNonRealVRegsAt(pc, overlaps)
	}
	// Right-hand intervals begin at or after i.begin.
	if pc >= i.begin {
		i.right.collectActiveNonRealVRegsAt(pc, overlaps)
	}
}
// collectActiveRealRegNodesAt truncates *overlaps to length zero and then fills it
// with the nodes live at pc whose assigned real register is not RealRegInvalid.
func (t *intervalTree) collectActiveRealRegNodesAt(pc programCounter, overlaps *[]*node) {
// Reuse the caller's backing array instead of allocating a new slice.
*overlaps = (*overlaps)[:0]
t.root.collectActiveRealRegNodesAt(pc, overlaps)
}
// collectActiveRealRegNodesAt appends to *overlaps every node in the subtree
// rooted at i whose interval contains pc (both bounds inclusive) and whose
// assigned real register is not RealRegInvalid.
//
// A nil receiver is the empty subtree and is a no-op.
func (i *intervalTreeNode) collectActiveRealRegNodesAt(pc programCounter, overlaps *[]*node) {
	if i == nil {
		return
	}
	if pc >= i.begin && pc <= i.end {
		for _, candidate := range i.nodes {
			if candidate.assignedRealReg() == RealRegInvalid {
				continue
			}
			*overlaps = append(*overlaps, candidate)
		}
	}
	// Descend left only if something there is still live at pc.
	if left := i.left; left != nil && left.maxEnd >= pc {
		left.collectActiveRealRegNodesAt(pc, overlaps)
	}
	// Right-hand intervals begin at or after i.begin.
	if pc >= i.begin {
		i.right.collectActiveRealRegNodesAt(pc, overlaps)
	}
}
// intersects reports whether the closed intervals [i.begin, i.end] and
// [j.begin, j.end] share at least one program counter. Touching end points
// count as an intersection.
func (i *intervalTreeNode) intersects(j *intervalTreeNode) bool {
	// The overlap condition is symmetric in i and j, so one check covers
	// both orderings.
	startsBeforeOtherEnds := i.begin <= j.end
	endsAfterOtherStarts := i.end >= j.begin
	return startsBeforeOtherEnds && endsAfterOtherStarts
}

View File

@@ -1,385 +0,0 @@
package regalloc
import (
"fmt"
"sort"
"testing"
"github.com/tetratelabs/wazero/internal/testing/require"
)
// TestIntervalTree_reset checks that reset drops the root and leaves the node
// pool with nothing allocated.
func TestIntervalTree_reset(t *testing.T) {
	tree := newIntervalTree()
	tree.root = tree.allocator.Allocate()

	tree.reset()

	require.Nil(t, tree.root)
	require.Equal(t, 0, tree.allocator.Allocated())
}
// TestIntervalTreeInsert checks that the first insertion becomes the root and
// is also reachable through the exact-match index.
func TestIntervalTreeInsert(t *testing.T) {
	const begin, end = programCounter(100), programCounter(200)

	inserted := &node{}
	tree := newIntervalTree()
	tree.insert(inserted, begin, end)

	root := tree.root
	require.NotNil(t, root)
	require.NotNil(t, root.nodes)
	require.Equal(t, inserted, root.nodes[0])

	indexed, found := tree.intervals[intervalTreeNodeKey(begin, end)]
	require.True(t, found)
	require.Equal(t, inserted, indexed.nodes[0])
}
// TestIntervalTreeNodeInsert exercises intervalTreeNode.insert for the three
// possible placements: into an empty subtree, to the left, and to the right.
func TestIntervalTreeNodeInsert(t *testing.T) {
	t.Run("nil", func(t *testing.T) {
		tree := newIntervalTree()
		var empty *intervalTreeNode

		created := empty.insert(tree, &node{}, 0, 100)

		require.Equal(t, 1, tree.allocator.Allocated())
		require.NotNil(t, created)
		require.Equal(t, created, tree.allocator.View(0))
		require.Equal(t, programCounter(100), created.maxEnd)
		require.Equal(t, programCounter(0), created.begin)
		require.Equal(t, programCounter(100), created.end)
		require.Equal(t, 1, len(created.nodes))

		indexed, found := tree.intervals[intervalTreeNodeKey(0, 100)]
		require.True(t, found)
		require.Equal(t, created, indexed)
	})

	t.Run("left", func(t *testing.T) {
		tree := newIntervalTree()
		parent := &intervalTreeNode{begin: 50, end: 100, maxEnd: 100}
		live := &node{}

		// begin (0) is smaller than the parent's begin (50), so it goes left.
		returned := parent.insert(tree, live, 0, 200)

		require.Equal(t, returned, parent)
		require.Equal(t, 1, tree.allocator.Allocated())
		child := tree.allocator.View(0)
		require.Equal(t, parent.left, child)
		require.Nil(t, parent.right)
		require.Equal(t, programCounter(200), parent.maxEnd)
		require.Equal(t, child.nodes[0], live)
	})

	t.Run("right", func(t *testing.T) {
		tree := newIntervalTree()
		parent := &intervalTreeNode{begin: 50, end: 100, maxEnd: 100}
		live := &node{}

		// begin (150) is not smaller than the parent's begin (50), so it goes right.
		returned := parent.insert(tree, live, 150, 200)

		require.Equal(t, returned, parent)
		require.Equal(t, 1, tree.allocator.Allocated())
		child := tree.allocator.View(0)
		require.Equal(t, parent.right, child)
		require.Nil(t, parent.left)
		require.Equal(t, programCounter(200), parent.maxEnd)
		require.Equal(t, child.nodes[0], live)
	})
}
type (
// interval is a test fixture: an interval to insert, tagged with the id given to its node.
interval struct {
begin, end programCounter
id int
}
// queryCase pairs a program counter to query with the node ids expected to be collected there.
queryCase struct {
query programCounter
exp []int
}
)
// newQueryCase builds a queryCase for program counter s that expects the given
// node ids. With no ids supplied, exp is left nil.
func newQueryCase(s programCounter, exp ...int) queryCase {
	var qc queryCase
	qc.query, qc.exp = s, exp
	return qc
}
// TestIntervalTree_collectActiveNodesAt checks collectActiveRealRegNodesAt:
// for each query program counter, exactly the nodes with an assigned real
// register whose interval contains the query must be collected.
func TestIntervalTree_collectActiveNodesAt(t *testing.T) {
	for _, tc := range []struct {
		name       string
		intervals  []interval
		queryCases []queryCase
	}{
		{
			name:      "single",
			intervals: []interval{{begin: 0, end: 100, id: 0}},
			queryCases: []queryCase{
				newQueryCase(0, 0),
				newQueryCase(0, 0),
				newQueryCase(1, 0),
				newQueryCase(1, 0),
				newQueryCase(101),
			},
		},
		{
			name:      "single/2",
			intervals: []interval{{begin: 50, end: 100, id: 0}},
			queryCases: []queryCase{
				newQueryCase(50, 0),
				newQueryCase(50, 0),
				newQueryCase(51, 0),
				newQueryCase(51, 0),
				newQueryCase(101),
				newQueryCase(48),
			},
		},
		{
			name:      "same id for different intervals",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}},
			queryCases: []queryCase{
				newQueryCase(0),
				newQueryCase(50, 0xa),
				newQueryCase(101),
				newQueryCase(150, 0xa),
				newQueryCase(101),
			},
		},
		{
			name:      "two disjoint intervals",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}},
			queryCases: []queryCase{
				newQueryCase(50, 0xa),
				newQueryCase(0),
				newQueryCase(51, 0xa),
				newQueryCase(101),
				newQueryCase(150, 0xb),
				newQueryCase(200, 0xb),
				newQueryCase(101),
				newQueryCase(201),
			},
		},
		{
			name:      "two intersecting intervals",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}},
			queryCases: []queryCase{
				newQueryCase(0),
				newQueryCase(70, 0xa, 0xb),
				newQueryCase(1),
				newQueryCase(50, 0xa),
				newQueryCase(51, 0xa, 0xb),
				newQueryCase(100, 0xa, 0xb),
				newQueryCase(101, 0xb),
				newQueryCase(49),
				newQueryCase(1001),
			},
		},
		{
			name:      "two enclosing interval",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}},
			queryCases: []queryCase{
				newQueryCase(24),
				newQueryCase(25, 0xb),
				newQueryCase(39, 0xb),
				newQueryCase(40, 0xb, 0xc),
				newQueryCase(100, 0xa, 0xb, 0xc),
				newQueryCase(99, 0xa, 0xb, 0xc),
				newQueryCase(100, 0xa, 0xb, 0xc),
				newQueryCase(50, 0xa, 0xb, 0xc),
				newQueryCase(51, 0xa, 0xb, 0xc),
				newQueryCase(101, 0xb, 0xc),
				newQueryCase(49, 0xb, 0xc),
				newQueryCase(200, 0xb, 0xc),
				newQueryCase(201, 0xc),
				newQueryCase(1000, 0xc),
				newQueryCase(1001),
			},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			tree := newIntervalTree()
			for _, inter := range tc.intervals {
				n := &node{id: inter.id, r: RealReg(1)}
				tree.insert(n, inter.begin, inter.end)
				// Add a decoy to the same interval; it must not show up in the results.
				key := intervalTreeNodeKey(inter.begin, inter.end)
				inserted := tree.intervals[key]
				inserted.nodes = append(inserted.nodes, &node{v: VRegInvalid.SetRealReg(RealRegInvalid)}) // non-real reg should be ignored.
			}
			for _, qc := range tc.queryCases {
				t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) {
					var collected []*node
					tree.collectActiveRealRegNodesAt(qc.query, &collected)
					require.Equal(t, len(qc.exp), len(collected))
					var foundIDs []int
					for _, n := range collected {
						foundIDs = append(foundIDs, n.id)
					}
					// Collection order follows the tree walk, so compare in sorted order.
					sort.Ints(foundIDs)
					require.Equal(t, qc.exp, foundIDs)
				})
			}
		})
	}
}
// TestIntervalTree_collectActiveNonRealVRegsAt checks
// collectActiveNonRealVRegsAt: for each query program counter, exactly the
// nodes whose interval contains the query and that are neither spilled nor
// backed by a real-register VReg must be collected.
func TestIntervalTree_collectActiveNonRealVRegsAt(t *testing.T) {
	for _, tc := range []struct {
		name       string
		intervals  []interval
		queryCases []queryCase
	}{
		{
			name:      "single",
			intervals: []interval{{begin: 0, end: 100, id: 0}},
			queryCases: []queryCase{
				newQueryCase(0, 0),
				newQueryCase(0, 0),
				newQueryCase(1, 0),
				newQueryCase(1, 0),
				newQueryCase(101),
			},
		},
		{
			name:      "single/2",
			intervals: []interval{{begin: 50, end: 100, id: 0}},
			queryCases: []queryCase{
				newQueryCase(50, 0),
				newQueryCase(50, 0),
				newQueryCase(51, 0),
				newQueryCase(51, 0),
				newQueryCase(101),
				newQueryCase(48),
			},
		},
		{
			name:      "same id for different intervals",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xa}},
			queryCases: []queryCase{
				newQueryCase(0),
				newQueryCase(50, 0xa),
				newQueryCase(101),
				newQueryCase(150, 0xa),
				newQueryCase(101),
			},
		},
		{
			name:      "two disjoint intervals",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 150, end: 200, id: 0xb}},
			queryCases: []queryCase{
				newQueryCase(50, 0xa),
				newQueryCase(0),
				newQueryCase(51, 0xa),
				newQueryCase(101),
				newQueryCase(150, 0xb),
				newQueryCase(200, 0xb),
				newQueryCase(101),
				newQueryCase(201),
			},
		},
		{
			name:      "two intersecting intervals",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 51, end: 200, id: 0xb}},
			queryCases: []queryCase{
				newQueryCase(0),
				newQueryCase(70, 0xa, 0xb),
				newQueryCase(1),
				newQueryCase(50, 0xa),
				newQueryCase(51, 0xa, 0xb),
				newQueryCase(100, 0xa, 0xb),
				newQueryCase(101, 0xb),
				newQueryCase(49),
				newQueryCase(1001),
			},
		},
		{
			name:      "two enclosing interval",
			intervals: []interval{{begin: 50, end: 100, id: 0xa}, {begin: 25, end: 200, id: 0xb}, {begin: 40, end: 1000, id: 0xc}},
			queryCases: []queryCase{
				newQueryCase(24),
				newQueryCase(25, 0xb),
				newQueryCase(39, 0xb),
				newQueryCase(40, 0xb, 0xc),
				newQueryCase(100, 0xa, 0xb, 0xc),
				newQueryCase(99, 0xa, 0xb, 0xc),
				newQueryCase(100, 0xa, 0xb, 0xc),
				newQueryCase(50, 0xa, 0xb, 0xc),
				newQueryCase(51, 0xa, 0xb, 0xc),
				newQueryCase(101, 0xb, 0xc),
				newQueryCase(49, 0xb, 0xc),
				newQueryCase(200, 0xb, 0xc),
				newQueryCase(201, 0xc),
				newQueryCase(1000, 0xc),
				newQueryCase(1001),
			},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			tree := newIntervalTree()
			for _, inter := range tc.intervals {
				n := &node{id: inter.id, r: RealReg(1)}
				tree.insert(n, inter.begin, inter.end)
				key := intervalTreeNodeKey(inter.begin, inter.end)
				inserted := tree.intervals[key]
				// They are ignored.
				inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeInt)})
				inserted.nodes = append(inserted.nodes, &node{v: FromRealReg(1, RegTypeFloat)})
				inserted.nodes = append(inserted.nodes, &node{v: VReg(1)})
			}
			for _, qc := range tc.queryCases {
				t.Run(fmt.Sprintf("%d", qc.query), func(t *testing.T) {
					var collected []*node
					tree.collectActiveNonRealVRegsAt(qc.query, &collected)
					require.Equal(t, len(qc.exp), len(collected))
					var foundIDs []int
					for _, n := range collected {
						foundIDs = append(foundIDs, n.id)
					}
					// Collection order follows the tree walk, so compare in sorted order.
					sort.Ints(foundIDs)
					require.Equal(t, qc.exp, foundIDs)
				})
			}
		})
	}
}
// TestIntervalTreeNode_intersects checks the closed-interval overlap
// predicate, including identical intervals, shared end points, and a
// disjoint pair.
func TestIntervalTreeNode_intersects(t *testing.T) {
	type testCase struct {
		rhs, lhs intervalTreeNode
		exp      bool
	}
	cases := []testCase{
		{
			rhs: intervalTreeNode{begin: 0, end: 100},
			lhs: intervalTreeNode{begin: 0, end: 100},
			exp: true,
		},
		{
			rhs: intervalTreeNode{begin: 0, end: 100},
			lhs: intervalTreeNode{begin: 0, end: 99},
			exp: true,
		},
		{
			rhs: intervalTreeNode{begin: 0, end: 100},
			lhs: intervalTreeNode{begin: 1, end: 100},
			exp: true,
		},
		{
			rhs: intervalTreeNode{begin: 50, end: 100},
			lhs: intervalTreeNode{begin: 1, end: 49},
			exp: false,
		},
		{
			rhs: intervalTreeNode{begin: 50, end: 100},
			lhs: intervalTreeNode{begin: 1, end: 50},
			exp: true,
		},
		{
			rhs: intervalTreeNode{begin: 50, end: 100},
			lhs: intervalTreeNode{begin: 1, end: 51},
			exp: true,
		},
		{
			rhs: intervalTreeNode{begin: 50, end: 100},
			lhs: intervalTreeNode{begin: 99, end: 102},
			exp: true,
		},
	}
	for idx := range cases {
		tc := &cases[idx]
		require.Equal(t, tc.exp, tc.rhs.intersects(&tc.lhs))
	}
}

View File

@@ -0,0 +1,147 @@
package regalloc
import (
"sort"
"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
)
type (
// intervalManager manages intervals for each block.
// Intervals are first registered with insert, then build must be called once to sort them
// and link neighbors, and only after that can collectActiveNodes be queried.
intervalManager struct {
// allocator is the pool from which intervals are allocated.
allocator wazevoapi.Pool[interval]
// intervals maps the packed (begin, end) key to the interval with exactly those bounds.
intervals map[intervalKey]*interval
// sortedIntervals holds every interval; it is in insertion order until build sorts it
// by begin, with ties broken by end.
sortedIntervals []*interval
// collectionCur is the index in sortedIntervals where collectActiveNodes starts scanning.
// It only moves forward and is rewound to zero by reset.
collectionCur int
}
// interval represents an interval in the block, which is a range of program counters.
// Each interval has a list of nodes which are live in the interval.
interval struct {
// begin and end are the bounds of the interval; queries treat both as inclusive.
begin, end programCounter
// nodes are nodes which are alive in this interval.
nodes []*node
// neighbors are the other intervals which overlap this interval; they are filled in by build.
neighbors []*interval
}
// intervalKey is a key for intervalManager.intervals which consists of begin and end.
intervalKey uint64
)
// newIntervalManager returns an empty manager whose interval pool resets
// entries with resetIntervalTreeNode and whose exact-match index is ready for
// use.
func newIntervalManager() *intervalManager {
	manager := new(intervalManager)
	manager.allocator = wazevoapi.NewPool[interval](resetIntervalTreeNode)
	manager.intervals = make(map[intervalKey]*interval)
	return manager
}
// resetIntervalTreeNode clears a pooled interval for reuse. The bounds are
// zeroed, while the nodes and neighbors slices are truncated to length zero
// so that their backing arrays are kept.
func resetIntervalTreeNode(i *interval) {
	i.begin, i.end = 0, 0
	i.neighbors = i.neighbors[:0]
	i.nodes = i.nodes[:0]
}
// intervalTreeNodeKey returns a key for intervalManager.intervals: begin
// occupies the low 32 bits and end the high 32 bits.
//
// Each bound is truncated to 32 bits before packing. Without the truncation a
// negative begin of the signed programCounter type would sign-extend into the
// upper half and overwrite the bits that encode end, making distinct intervals
// collide on the same key.
func intervalTreeNodeKey(begin, end programCounter) intervalKey {
	low := intervalKey(uint32(begin))
	high := intervalKey(uint32(end)) << 32
	return low | high
}
// insert records that node n is live over [begin, end] and returns the
// interval representing those bounds.
//
// If an interval with the same bounds was inserted before, n is appended to
// it. Otherwise a new interval is allocated from the pool, indexed by its key,
// and appended to sortedIntervals, which stays unsorted until build is called.
func (t *intervalManager) insert(n *node, begin, end programCounter) *interval {
	key := intervalTreeNodeKey(begin, end)
	existing, found := t.intervals[key]
	if found {
		existing.nodes = append(existing.nodes, n)
		return existing
	}

	created := t.allocator.Allocate()
	created.nodes = append(created.nodes, n)
	created.begin, created.end = begin, end
	t.intervals[key] = created
	t.sortedIntervals = append(t.sortedIntervals, created) // Will be sorted later.
	return created
}
// reset empties the manager so it can be reused: the collection cursor is
// rewound, the exact-match index is replaced by a fresh map, sortedIntervals
// is truncated (keeping its backing array), and the interval pool is reset.
func (t *intervalManager) reset() {
	t.collectionCur = 0
	t.intervals = make(map[intervalKey]*interval)
	t.sortedIntervals = t.sortedIntervals[:0]
	t.allocator.Reset()
}
// build is called after all the intervals are inserted. It sorts the
// intervals by begin (ties broken by end) and then links every pair of
// overlapping intervals as neighbors of each other.
//
// Neighbors are appended, not replaced, so calling build again without a
// reset in between would record each link a second time.
func (t *intervalManager) build() {
	sorted := t.sortedIntervals
	sort.Slice(sorted, func(a, b int) bool {
		lhs, rhs := sorted[a], sorted[b]
		if lhs.begin != rhs.begin {
			return lhs.begin < rhs.begin
		}
		return lhs.end < rhs.end
	})

	// clusterStart is the index of the first interval in the current run of
	// transitively overlapping intervals, and clusterEnd is the largest end
	// seen in that run.
	clusterStart := 0
	clusterEnd := programCounter(-1)
	for idx, current := range sorted {
		if current.begin > clusterEnd {
			// Nothing seen so far reaches this interval: a new run starts here.
			clusterStart, clusterEnd = idx, current.end
			continue
		}
		for _, earlier := range sorted[clusterStart:idx] {
			if earlier.end < current.begin {
				continue
			}
			// Sorting by begin guarantees earlier.begin <= current.begin.
			if earlier.begin > current.end {
				panic("BUG")
			}
			current.neighbors = append(current.neighbors, earlier)
			earlier.neighbors = append(earlier.neighbors, current)
		}
		if current.end > clusterEnd {
			clusterEnd = current.end
		}
	}
}
// collectActiveNodes collects nodes which are alive at pc, and the result is
// stored in `collected`, which is truncated first so its backing array is
// reused.
//
// If `real` is true, only nodes which are assigned to a real register are
// collected. Otherwise nodes that are spilled or whose VReg is a real
// register are left out.
//
// The scan starts at an internal cursor that is advanced past every leading
// interval ending before pc and is never moved back except by reset. Calls
// are therefore expected to use non-decreasing pc values; an interval skipped
// for a larger pc is not revisited for a smaller one.
// NOTE(review): confirm every call site queries in non-decreasing order.
func (t *intervalManager) collectActiveNodes(pc programCounter, collected *[]*node, real bool) {
	*collected = (*collected)[:0]

	// Advance the collection cursor until the current interval's end is not less than pc.
	total := len(t.sortedIntervals)
	for t.collectionCur < total && t.sortedIntervals[t.collectionCur].end < pc {
		t.collectionCur++
	}

	for idx := t.collectionCur; idx < total; idx++ {
		candidate := t.sortedIntervals[idx]
		if candidate.end < pc {
			// Already finished before pc, but later intervals may still cover it.
			continue
		}
		if candidate.begin > pc {
			// Sorted by begin: nothing from here on can contain pc.
			break
		}
		for _, n := range candidate.nodes {
			if real {
				if n.assignedRealReg() != RealRegInvalid {
					*collected = append(*collected, n)
				}
			} else if !n.spill() && !n.v.IsRealReg() {
				*collected = append(*collected, n)
			}
		}
	}
}

View File

@@ -0,0 +1,258 @@
package regalloc
import (
"fmt"
"sort"
"testing"
"github.com/tetratelabs/wazero/internal/testing/require"
)
// TestIntervalsManager_build checks that build links every interval with
// exactly the intervals that overlap it, and with nothing else (in particular
// not with itself).
func TestIntervalsManager_build(t *testing.T) {
	type (
		intervalCase struct {
			begin, end programCounter
		}
		// expNeighborCase lists the neighbors expected for tc.intervals[index],
		// ordered by begin and then by end.
		expNeighborCase struct {
			index     int
			neighbors []intervalCase
		}
	)

	for _, tc := range []struct {
		name         string
		intervals    []intervalCase
		expNeighbors []expNeighborCase
	}{
		{
			name:         "single",
			intervals:    []intervalCase{{begin: 0, end: 100}},
			expNeighbors: []expNeighborCase{{index: 0}},
		},
		{
			name:         "disjoints",
			intervals:    []intervalCase{{begin: 50, end: 100}, {begin: 1, end: 2}},
			expNeighbors: []expNeighborCase{{index: 0}, {index: 1}},
		},
		{
			name:         "disjoints duplicate",
			intervals:    []intervalCase{{begin: 50, end: 100}, {begin: 1, end: 2}, {begin: 50, end: 100}, {begin: 1, end: 2}},
			expNeighbors: []expNeighborCase{{index: 0}, {index: 1}, {index: 2}, {index: 3}},
		},
		{
			name: "two intersecting",
			intervals: []intervalCase{
				{begin: 70, end: 200},
				{begin: 50, end: 100},
			},
			expNeighbors: []expNeighborCase{
				{index: 0, neighbors: []intervalCase{{begin: 50, end: 100}}},
				{index: 1, neighbors: []intervalCase{{begin: 70, end: 200}}},
			},
		},
		{
			name: "same beginnings",
			intervals: []intervalCase{
				{begin: 50, end: 200},
				{begin: 50, end: 401},
				{begin: 50, end: 201},
				{begin: 50, end: 302},
			},
			expNeighbors: []expNeighborCase{
				{index: 0, neighbors: []intervalCase{{begin: 50, end: 201}, {begin: 50, end: 302}, {begin: 50, end: 401}}},
				{index: 1, neighbors: []intervalCase{{begin: 50, end: 200}, {begin: 50, end: 201}, {begin: 50, end: 302}}},
				{index: 2, neighbors: []intervalCase{{begin: 50, end: 200}, {begin: 50, end: 302}, {begin: 50, end: 401}}},
				{index: 3, neighbors: []intervalCase{{begin: 50, end: 200}, {begin: 50, end: 201}, {begin: 50, end: 401}}},
			},
		},
		{
			name: "three intersecting",
			intervals: []intervalCase{
				{begin: 70, end: 200},
				{begin: 71, end: 150},
				{begin: 50, end: 100},
			},
			expNeighbors: []expNeighborCase{
				{index: 0, neighbors: []intervalCase{{begin: 50, end: 100}, {begin: 71, end: 150}}},
				{index: 1, neighbors: []intervalCase{{begin: 50, end: 100}, {begin: 70, end: 200}}},
				{index: 2, neighbors: []intervalCase{{begin: 70, end: 200}, {begin: 71, end: 150}}},
			},
		},
		{
			name: "two enclosing interval",
			intervals: []intervalCase{
				{begin: 50, end: 100},
				{begin: 25, end: 200},
				{begin: 40, end: 1000},
			},
			expNeighbors: []expNeighborCase{
				{index: 0, neighbors: []intervalCase{{begin: 25, end: 200}, {begin: 40, end: 1000}}},
				{index: 1, neighbors: []intervalCase{{begin: 40, end: 1000}, {begin: 50, end: 100}}},
				{index: 2, neighbors: []intervalCase{{begin: 25, end: 200}, {begin: 50, end: 100}}},
			},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			manager := newIntervalManager()
			for i, inter := range tc.intervals {
				n := &node{id: i, r: RealReg(1)}
				manager.insert(n, inter.begin, inter.end)
			}
			manager.build()

			for i, exp := range tc.expNeighbors {
				it := tc.intervals[exp.index]
				key := intervalTreeNodeKey(it.begin, it.end)
				var found []intervalCase
				for _, n := range manager.intervals[key].neighbors {
					found = append(found, intervalCase{begin: n.begin, end: n.end})
				}
				// Order by (begin, end). Comparing begin alone would leave the
				// order of equal-begin neighbors up to sort.Slice, which is
				// not guaranteed to be stable.
				sort.Slice(found, func(a, b int) bool {
					if found[a].begin != found[b].begin {
						return found[a].begin < found[b].begin
					}
					return found[a].end < found[b].end
				})
				require.Equal(t, exp.neighbors, found, fmt.Sprintf("case=%d", i))
			}
		})
	}
}
// TestIntervalManager_collectActiveNodes verifies that, for each queried program counter,
// collectActiveNodes hands back exactly the nodes whose inserted interval covers that
// position (the tables below treat both interval endpoints as covered). Every scenario
// is executed once per value of the onlyReal flag.
func TestIntervalManager_collectActiveNodes(t *testing.T) {
	type (
		// span is one interval inserted into the manager on behalf of the node with the given id.
		span struct {
			begin, end programCounter
			id         int
		}
		// probe is a single query position plus the node ids expected back, in ascending order.
		probe struct {
			at      programCounter
			wantIDs []int
		}
		// scenario bundles the intervals to insert with the probes to run against them.
		scenario struct {
			name   string
			spans  []span
			probes []probe
		}
	)

	// probeAt builds a probe. Called without ids it leaves wantIDs nil, which is what the
	// (also nil) id slice gathered from an empty result is compared against below.
	probeAt := func(pc programCounter, ids ...int) probe {
		return probe{at: pc, wantIDs: ids}
	}

	scenarios := []scenario{
		{
			name:  "single",
			spans: []span{{begin: 0, end: 100, id: 0}},
			probes: []probe{
				probeAt(0, 0),
				probeAt(1, 0),
				probeAt(101),
			},
		},
		{
			name:  "single/2",
			spans: []span{{begin: 50, end: 100, id: 0}},
			probes: []probe{
				probeAt(48),
				probeAt(50, 0),
				probeAt(51, 0),
				probeAt(101),
			},
		},
		{
			name: "same id for different intervals",
			spans: []span{
				{begin: 50, end: 100, id: 0xa},
				{begin: 150, end: 200, id: 0xa},
			},
			probes: []probe{
				probeAt(0),
				probeAt(50, 0xa),
				probeAt(101),
				probeAt(150, 0xa),
			},
		},
		{
			name: "two disjoint intervals",
			spans: []span{
				{begin: 50, end: 100, id: 0xa},
				{begin: 150, end: 200, id: 0xb},
			},
			probes: []probe{
				probeAt(0),
				probeAt(50, 0xa),
				probeAt(51, 0xa),
				probeAt(101),
				probeAt(150, 0xb),
				probeAt(200, 0xb),
				probeAt(201),
			},
		},
		{
			name: "two intersecting intervals",
			spans: []span{
				{begin: 50, end: 100, id: 0xa},
				{begin: 51, end: 200, id: 0xb},
			},
			probes: []probe{
				probeAt(0),
				probeAt(1),
				probeAt(49),
				probeAt(50, 0xa),
				probeAt(51, 0xa, 0xb),
				probeAt(70, 0xa, 0xb),
				probeAt(100, 0xa, 0xb),
				probeAt(101, 0xb),
				probeAt(1001),
			},
		},
		{
			name: "two enclosing interval",
			spans: []span{
				{begin: 50, end: 100, id: 0xa},
				{begin: 25, end: 200, id: 0xb},
				{begin: 40, end: 1000, id: 0xc},
			},
			probes: []probe{
				probeAt(24),
				probeAt(25, 0xb),
				probeAt(39, 0xb),
				probeAt(40, 0xb, 0xc),
				probeAt(49, 0xb, 0xc),
				probeAt(50, 0xa, 0xb, 0xc),
				probeAt(51, 0xa, 0xb, 0xc),
				probeAt(99, 0xa, 0xb, 0xc),
				probeAt(100, 0xa, 0xb, 0xc),
				probeAt(101, 0xb, 0xc),
				probeAt(200, 0xb, 0xc),
				probeAt(201, 0xc),
				probeAt(1000, 0xc),
				probeAt(1001),
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			for _, onlyReal := range []bool{false, true} {
				t.Run(fmt.Sprintf("onlyReal=%t", onlyReal), func(t *testing.T) {
					mgr := newIntervalManager()
					for _, s := range sc.spans {
						mgr.insert(&node{id: s.id, r: RealReg(1)}, s.begin, s.end)

						// Attach extra nodes straight onto the stored interval. The
						// assertions further down expect none of them to be reported.
						entry := mgr.intervals[intervalTreeNodeKey(s.begin, s.end)]
						var decoys []*node
						if onlyReal {
							// A node without a real register behind it.
							decoys = []*node{
								{v: VRegInvalid.SetRealReg(RealRegInvalid)},
							}
						} else {
							decoys = []*node{
								{v: FromRealReg(1, RegTypeInt)},
								{v: FromRealReg(1, RegTypeFloat)},
								{v: VReg(1)},
							}
						}
						entry.nodes = append(entry.nodes, decoys...)
					}
					mgr.build()

					for _, p := range sc.probes {
						t.Run(fmt.Sprintf("%d", p.at), func(t *testing.T) {
							var active []*node
							mgr.collectActiveNodes(p.at, &active, onlyReal)
							require.Equal(t, len(p.wantIDs), len(active))

							// Deliberately left nil when nothing was collected so that it
							// matches a probe built without expected ids.
							var gotIDs []int
							for _, n := range active {
								gotIDs = append(gotIDs, n.id)
							}
							// The collection order is not asserted; normalise before comparing.
							sort.Ints(gotIDs)
							require.Equal(t, p.wantIDs, gotIDs)
						})
					}
				})
			}
		})
	}
}

View File

@@ -80,16 +80,16 @@ type (
lastUses map[VReg]programCounter
kills map[VReg]programCounter
// Pre-colored real registers can have multiple live ranges in one block.
realRegUses map[VReg][]programCounter
realRegDefs map[VReg][]programCounter
intervalTree *intervalTree
realRegUses map[VReg][]programCounter
realRegDefs map[VReg][]programCounter
intervalMng *intervalManager
}
// node represents a VReg.
node struct {
id int
v VReg
ranges []*intervalTreeNode
ranges []*interval
// r is the real register assigned to this node. It is either a pre-colored register or a register assigned during allocation.
r RealReg
// neighbors are the nodes that this node interferes with. Such neighbors have the same RegType as this node.
@@ -300,6 +300,7 @@ func (a *Allocator) buildLiveRanges(f Function) {
info := a.blockInfoAt(blkID)
a.buildLiveRangesForNonReals(info)
a.buildLiveRangesForReals(info)
info.intervalMng.build()
}
}
@@ -334,7 +335,7 @@ func (a *Allocator) buildLiveRangesForNonReals(info *blockInfo) {
begin, end = 0, killPos
}
n := a.getOrAllocateNode(v)
intervalNode := info.intervalTree.insert(n, begin, end)
intervalNode := info.intervalMng.insert(n, begin, end)
n.ranges = append(n.ranges, intervalNode)
}
@@ -375,7 +376,7 @@ func (a *Allocator) buildLiveRangesForNonReals(info *blockInfo) {
}
}
n := a.getOrAllocateNode(v)
intervalNode := info.intervalTree.insert(n, defPos, end)
intervalNode := info.intervalMng.insert(n, defPos, end)
n.ranges = append(n.ranges, intervalNode)
}
@@ -431,7 +432,7 @@ func (a *Allocator) buildLiveRangesForReals(info *blockInfo) {
n.r = v.RealReg()
n.v = v
defined, used := defs[i], uses[i]
intervalNode := info.intervalTree.insert(n, defined, used)
intervalNode := info.intervalMng.insert(n, defined, used)
n.ranges = append(n.ranges, intervalNode)
}
}
@@ -517,10 +518,10 @@ func resetMap[T any](a *Allocator, m map[VReg]T) {
}
func (a *Allocator) initBlockInfo(i *blockInfo) {
if i.intervalTree == nil {
i.intervalTree = newIntervalTree()
if i.intervalMng == nil {
i.intervalMng = newIntervalManager()
} else {
i.intervalTree.reset()
i.intervalMng.reset()
}
if i.liveOuts == nil {
i.liveOuts = make(map[VReg]struct{})

View File

@@ -445,10 +445,10 @@ func TestAllocator_livenessAnalysis(t *testing.T) {
actual := &a.blockInfos[blockID]
exp := tc.exp[blockID]
initMapInInfo(exp)
saved := actual.intervalTree
actual.intervalTree = nil // Don't compare intervalTree.
saved := actual.intervalMng
actual.intervalMng = nil // Don't compare intervalManager.
require.Equal(t, exp, actual, "\n[exp for block[%d]]\n%s\n[actual for block[%d]]\n%s", blockID, exp, blockID, actual)
actual.intervalTree = saved
actual.intervalMng = saved
}
// Sanity check: buildLiveRanges should not panic.