wazero/internal/engine/wazevo/ssa/pass.go
Takeshi Yoneda 65650d399d ssa: reuses slices for basicBlock.params (#2247)
This replaces the basicBlock.params field with the reusable
VarLength[Value] type. As a result, compilation uses less
memory and performs fewer allocations.

```
goos: darwin
goarch: arm64
pkg: github.com/tetratelabs/wazero
                      │  old.txt   │             new.txt              │
                      │   sec/op   │   sec/op    vs base              │
Compilation/wazero-10   2.004 ± 2%   2.001 ± 0%       ~ (p=0.620 n=7)
Compilation/zig-10      4.164 ± 1%   4.174 ± 3%       ~ (p=0.097 n=7)
geomean                 2.888        2.890       +0.06%

                      │   old.txt    │              new.txt               │
                      │     B/op     │     B/op      vs base              │
Compilation/wazero-10   297.7Mi ± 0%   297.5Mi ± 0%  -0.06% (p=0.001 n=7)
Compilation/zig-10      594.0Mi ± 0%   593.9Mi ± 0%  -0.01% (p=0.001 n=7)
geomean                 420.5Mi        420.3Mi       -0.03%

                      │   old.txt   │              new.txt              │
                      │  allocs/op  │  allocs/op   vs base              │
Compilation/wazero-10   472.5k ± 0%   457.1k ± 0%  -3.25% (p=0.001 n=7)
Compilation/zig-10      277.2k ± 0%   275.7k ± 0%  -0.53% (p=0.001 n=7)
geomean                 361.9k        355.0k       -1.90%
```

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
2024-06-13 13:01:58 -07:00
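For context, the reuse pattern the commit describes can be sketched as a pool that hands out windows into one shared backing slice; the windows expose `View` and `Cut`, the two methods used in the file below. Everything else in this sketch (the names `varLengthPool`, `allocate`, `append`, `reset`) is assumed for illustration only and is not the actual wazevoapi.VarLength implementation.

```go
package main

import "fmt"

// varLengthPool owns a single backing slice that survives across compilations,
// so appends amortize to zero allocations once the backing array has grown.
type varLengthPool[T any] struct{ backing []T }

// varLength is a window [begin, end) into the pool's backing slice.
type varLength[T any] struct {
	begin, end int
	pool       *varLengthPool[T]
}

// allocate reserves a new, empty segment at the tail of the backing slice.
func (p *varLengthPool[T]) allocate() varLength[T] {
	n := len(p.backing)
	return varLength[T]{begin: n, end: n, pool: p}
}

// append extends the segment; in this simplified sketch it is only valid for
// the most recently allocated segment.
func (v varLength[T]) append(items ...T) varLength[T] {
	v.pool.backing = append(v.pool.backing, items...)
	v.end = len(v.pool.backing)
	return v
}

// View exposes the segment as a plain slice without copying.
func (v varLength[T]) View() []T { return v.pool.backing[v.begin:v.end] }

// Cut truncates the segment to its first n elements, mirroring how the passes
// below compact block parameters and branch arguments in place.
func (v *varLength[T]) Cut(n int) { v.end = v.begin + n }

// reset forgets all segments so the backing memory is reused by the next function.
func (p *varLengthPool[T]) reset() { p.backing = p.backing[:0] }

func main() {
	var pool varLengthPool[int]
	params := pool.allocate().append(10, 20, 30)
	params.Cut(2)
	fmt.Println(params.View()) // [10 20]
	pool.reset()
}
```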

412 lines · 14 KiB · Go

package ssa

import (
	"fmt"

	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
)

// RunPasses implements Builder.RunPasses.
//
// The order here matters; some passes depend on the previous ones.
//
// Note that passes suffixed with "Opt" are optimization passes, meaning that they edit the instructions and blocks,
// while the other passes do not; passEstimateBranchProbabilities, for example, only calculates additional information.
func (b *builder) RunPasses() {
	b.runPreBlockLayoutPasses()
	b.runBlockLayoutPass()
	b.runPostBlockLayoutPasses()
	b.runFinalizingPasses()
}

func (b *builder) runPreBlockLayoutPasses() {
	passSortSuccessors(b)
	passDeadBlockEliminationOpt(b)
	// The result of passCalculateImmediateDominators will be used by various passes below.
	passCalculateImmediateDominators(b)
	passRedundantPhiEliminationOpt(b)
	passNopInstElimination(b)

	// TODO: implement either conversion of an irreducible CFG into a reducible one, or irreducible CFG detection where we panic.
	//	A WebAssembly program shouldn't result in an irreducible CFG, but we should handle it properly just in case.
	//	See the FixIrreducible pass in LLVM: https://llvm.org/doxygen/FixIrreducible_8cpp_source.html
	//
	// TODO: implement more optimization passes, such as:
	//	block coalescing,
	//	copy propagation,
	//	constant folding,
	//	common subexpression elimination,
	//	arithmetic simplifications,
	//	and more!

	// passDeadCodeEliminationOpt could be more accurate if run after the other optimizations.
	passDeadCodeEliminationOpt(b)

	b.donePreBlockLayoutPasses = true
}

func (b *builder) runBlockLayoutPass() {
	if !b.donePreBlockLayoutPasses {
		panic("runBlockLayoutPass must be called after all pre passes are done")
	}
	passLayoutBlocks(b)
	b.doneBlockLayout = true
}

// runPostBlockLayoutPasses runs the post block layout passes. After this point, the CFG is somewhat stable,
// but it can still be modified before the finalizing passes. At this point, critical edges have already been split by passLayoutBlocks.
func (b *builder) runPostBlockLayoutPasses() {
	if !b.doneBlockLayout {
		panic("runPostBlockLayoutPasses must be called after block layout pass is done")
	}
	// TODO: Do more, e.g. tail duplication, loop unrolling, etc.

	b.donePostBlockLayoutPasses = true
}

// runFinalizingPasses runs the finalizing passes. After this point, the CFG should not be modified.
func (b *builder) runFinalizingPasses() {
	if !b.donePostBlockLayoutPasses {
		panic("runFinalizingPasses must be called after post block layout passes are done")
	}
	// Critical edges are split, so we fix the loop nesting forest.
	passBuildLoopNestingForest(b)
	passBuildDominatorTree(b)
	// Now that we know the final placement of the blocks, we can explicitly mark the fallthrough jumps.
	b.markFallthroughJumps()
}

// passDeadBlockEliminationOpt searches for unreachable blocks and marks each one by setting its basicBlock.invalid flag to true.
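// It performs a depth-first traversal from the entry block, using b.blkStack as an explicit stack;
// any block that is never visited during the traversal is unreachable and therefore invalidated.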
func passDeadBlockEliminationOpt(b *builder) {
	entryBlk := b.entryBlk()
	b.blkStack = append(b.blkStack, entryBlk)
	for len(b.blkStack) > 0 {
		reachableBlk := b.blkStack[len(b.blkStack)-1]
		b.blkStack = b.blkStack[:len(b.blkStack)-1]
		reachableBlk.visited = 1

		if !reachableBlk.sealed && !reachableBlk.ReturnBlock() {
			panic(fmt.Sprintf("%s is not sealed", reachableBlk))
		}

		if wazevoapi.SSAValidationEnabled {
			reachableBlk.validate(b)
		}

		for _, succ := range reachableBlk.success {
			if succ.visited == 1 {
				continue
			}
			b.blkStack = append(b.blkStack, succ)
		}
	}

	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
		if blk.visited != 1 {
			blk.invalid = true
		}
		blk.visited = 0
	}
}

// passRedundantPhiEliminationOpt eliminates redundant PHIs (in our terminology, the parameters of a block).
// This requires the reverse post-order traversal to be calculated before calling this function,
// hence passCalculateImmediateDominators must be called before this.
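//
// For example, a block parameter blk(v5) is redundant if every predecessor passes either v5 itself
// (a self-reference) or one and the same value, say v3: in that case v5 is simply aliased to v3,
// the parameter is removed from the block, and the corresponding argument is dropped from every
// predecessor's branch instruction.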
func passRedundantPhiEliminationOpt(b *builder) {
	redundantParameterIndexes := b.ints[:0] // reuse the slice from previous iterations.

	// TODO: this might be costly for large programs, but at least in my experiments it is almost the same as a
	//	single-iteration version in terms of overall compilation time. That *might be* mostly because removing many
	//	PHIs reduces the total number of instructions, rather than because the number of iterations is small.
	//	For example, the sqlite speedtest binary results in a large number of redundant PHIs, and the maximum number
	//	of iterations was 22, which seems acceptable but not that small either, since the worst-case complexity here
	//	is O(BlockNum * Iterations) where BlockNum might be on the order of thousands.
	// -- Note --
	// Currently, each iteration can run in any order of blocks, but empirically it converges quickly when running
	// on the reverse post-order. It might be possible to optimize this further by using the dominator tree.
	for {
		changed := false
		_ = b.blockIteratorReversePostOrderBegin() // skip entry block!
		// Below, we intentionally use the named iteration variable, as this comes with inevitably nested for loops!
		for blk := b.blockIteratorReversePostOrderNext(); blk != nil; blk = b.blockIteratorReversePostOrderNext() {
			params := blk.params.View()
			paramNum := len(params)

			for paramIndex := 0; paramIndex < paramNum; paramIndex++ {
				phiValue := params[paramIndex]
				redundant := true

				nonSelfReferencingValue := ValueInvalid
				for predIndex := range blk.preds {
					br := blk.preds[predIndex].branch
					// Resolve the alias in the arguments so that we can use the previous iteration's result.
					b.resolveArgumentAlias(br)
					pred := br.vs.View()[paramIndex]
					if pred == phiValue {
						// This is self-referencing: a PHI from the same PHI.
						continue
					}

					if !nonSelfReferencingValue.Valid() {
						nonSelfReferencingValue = pred
						continue
					}

					if nonSelfReferencingValue != pred {
						redundant = false
						break
					}
				}

				if !nonSelfReferencingValue.Valid() {
					// This shouldn't happen, and must be a bug in builder.go.
					panic("BUG: params added but only self-referencing")
				}

				if redundant {
					b.redundantParameterIndexToValue[paramIndex] = nonSelfReferencingValue
					redundantParameterIndexes = append(redundantParameterIndexes, paramIndex)
				}
			}

			if len(b.redundantParameterIndexToValue) == 0 {
				continue
			}
			changed = true

			// Remove the redundant PHIs from the argument list of branching instructions.
			for predIndex := range blk.preds {
				var cur int
				predBlk := blk.preds[predIndex]
				branchInst := predBlk.branch
				view := branchInst.vs.View()
				for argIndex, value := range view {
					if _, ok := b.redundantParameterIndexToValue[argIndex]; !ok {
						view[cur] = value
						cur++
					}
				}
				branchInst.vs.Cut(cur)
			}

			// We still need the definition of the value of the PHI (previously a parameter).
			for _, redundantParamIndex := range redundantParameterIndexes {
				phiValue := params[redundantParamIndex]
				onlyValue := b.redundantParameterIndexToValue[redundantParamIndex]
				// Create an alias in this block from the only phi argument to the phi value.
				b.alias(phiValue, onlyValue)
			}

			// Finally, remove the redundant parameters from the blk.
			var cur int
			for paramIndex := 0; paramIndex < paramNum; paramIndex++ {
				param := params[paramIndex]
				if _, ok := b.redundantParameterIndexToValue[paramIndex]; !ok {
					params[cur] = param
					cur++
				}
			}
			blk.params.Cut(cur)

			// Clear the map for the next iteration.
			for _, paramIndex := range redundantParameterIndexes {
				delete(b.redundantParameterIndexToValue, paramIndex)
			}
			redundantParameterIndexes = redundantParameterIndexes[:0]
		}

		if !changed {
			break
		}
	}

	// Reuse the slice for future passes.
	b.ints = redundantParameterIndexes
}

// passDeadCodeEliminationOpt traverses all the instructions, calculates the reference count of each Value, and
// eliminates all the unnecessary instructions whose reference count is zero.
// The results are stored in builder.valueRefCounts. This pass also assigns an InstructionGroupID to each Instruction
// along the way. This is the last SSA-level optimization pass; after this,
// the SSA function is ready to be used by backends.
//
// TODO: the algorithm here might not be efficient. Get back to this later.
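//
// The pass works in three steps: (1) seed a worklist with every instruction that has a side effect
// (a trap or a strict side effect), (2) drain the worklist, transitively marking the producing
// instruction of each argument as live, and (3) unlink every instruction that was never marked live,
// while incrementing the reference counts of the arguments of the surviving instructions.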
func passDeadCodeEliminationOpt(b *builder) {
	nvid := int(b.nextValueID)
	if nvid >= len(b.valueRefCounts) {
		b.valueRefCounts = append(b.valueRefCounts, make([]int, nvid-len(b.valueRefCounts)+1)...)
	}
	if nvid >= len(b.valueIDToInstruction) {
		b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, nvid-len(b.valueIDToInstruction)+1)...)
	}

	// First, we gather all the instructions with side effects.
	liveInstructions := b.instStack[:0]
	// During the process, we assign an InstructionGroupID to each instruction, which is not
	// relevant to dead code elimination but is needed by the backend.
	var gid InstructionGroupID
	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
		for cur := blk.rootInstr; cur != nil; cur = cur.next {
			cur.gid = gid
			switch cur.sideEffect() {
			case sideEffectTraps:
				// Trappable instructions should always be kept alive.
				liveInstructions = append(liveInstructions, cur)
			case sideEffectStrict:
				liveInstructions = append(liveInstructions, cur)
				// A strict side effect should separate instruction groups.
				gid++
			}

			r1, rs := cur.Returns()
			if r1.Valid() {
				b.valueIDToInstruction[r1.ID()] = cur
			}
			for _, r := range rs {
				b.valueIDToInstruction[r.ID()] = cur
			}
		}
	}

	// Find all the instructions referenced by live instructions transitively.
	for len(liveInstructions) > 0 {
		tail := len(liveInstructions) - 1
		live := liveInstructions[tail]
		liveInstructions = liveInstructions[:tail]
		if live.live {
			// If it's already marked alive, this is referenced multiple times,
			// so we can skip it.
			continue
		}
		live.live = true

		// Before we walk, we need to resolve the alias first.
		b.resolveArgumentAlias(live)

		v1, v2, v3, vs := live.Args()
		if v1.Valid() {
			producingInst := b.valueIDToInstruction[v1.ID()]
			if producingInst != nil {
				liveInstructions = append(liveInstructions, producingInst)
			}
		}

		if v2.Valid() {
			producingInst := b.valueIDToInstruction[v2.ID()]
			if producingInst != nil {
				liveInstructions = append(liveInstructions, producingInst)
			}
		}

		if v3.Valid() {
			producingInst := b.valueIDToInstruction[v3.ID()]
			if producingInst != nil {
				liveInstructions = append(liveInstructions, producingInst)
			}
		}

		for _, v := range vs {
			producingInst := b.valueIDToInstruction[v.ID()]
			if producingInst != nil {
				liveInstructions = append(liveInstructions, producingInst)
			}
		}
	}

	// Now that all the live instructions are flagged as live=true, we eliminate all the dead instructions.
	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
		for cur := blk.rootInstr; cur != nil; cur = cur.next {
			if !cur.live {
				// Remove the instruction from the list.
				if prev := cur.prev; prev != nil {
					prev.next = cur.next
				} else {
					blk.rootInstr = cur.next
				}
				if next := cur.next; next != nil {
					next.prev = cur.prev
				}
				continue
			}

			// If the value is alive, we can be sure that its arguments are definitely used.
			// Hence, we can increment the value reference counts.
			v1, v2, v3, vs := cur.Args()
			if v1.Valid() {
				b.incRefCount(v1.ID(), cur)
			}
			if v2.Valid() {
				b.incRefCount(v2.ID(), cur)
			}
			if v3.Valid() {
				b.incRefCount(v3.ID(), cur)
			}
			for _, v := range vs {
				b.incRefCount(v.ID(), cur)
			}
		}
	}

	b.instStack = liveInstructions // we reuse the stack for the next iteration.
}

func (b *builder) incRefCount(id ValueID, from *Instruction) {
	if wazevoapi.SSALoggingEnabled {
		fmt.Printf("v%d referenced from %v\n", id, from.Format(b))
	}
	b.valueRefCounts[id]++
}

// passNopInstElimination eliminates instructions that are essentially no-ops.
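// Currently it covers shifts whose amount is a constant multiple of the operand's bit width:
// for example, (ishl x:i64, 64) or (ushr y:i32, 32) shifts by zero once the amount is masked,
// so the result is simply aliased to the unshifted operand.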
func passNopInstElimination(b *builder) {
	if int(b.nextValueID) >= len(b.valueIDToInstruction) {
		b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, int(b.nextValueID)-len(b.valueIDToInstruction)+1)...)
	}

	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
		for cur := blk.rootInstr; cur != nil; cur = cur.next {
			r1, rs := cur.Returns()
			if r1.Valid() {
				b.valueIDToInstruction[r1.ID()] = cur
			}
			for _, r := range rs {
				b.valueIDToInstruction[r.ID()] = cur
			}
		}
	}

	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
		for cur := blk.rootInstr; cur != nil; cur = cur.next {
			switch cur.Opcode() {
			// TODO: add more patterns here.
			case OpcodeIshl, OpcodeSshr, OpcodeUshr:
				x, amount := cur.Arg2()
				definingInst := b.valueIDToInstruction[amount.ID()]
				if definingInst == nil {
					// If there's no defining instruction, the amount comes from a parameter.
					continue
				}
				if definingInst.Constant() {
					v := definingInst.ConstantVal()

					if x.Type().Bits() == 64 {
						v = v % 64
					} else {
						v = v % 32
					}
					if v == 0 {
						b.alias(cur.Return(), x)
					}
				}
			}
		}
	}
}

// passSortSuccessors sorts the successors of each block in the natural program order.
func passSortSuccessors(b *builder) {
	for i := 0; i < b.basicBlocksPool.Allocated(); i++ {
		blk := b.basicBlocksPool.View(i)
		sortBlocks(blk.success)
	}
}