wazevo(frontend): simple bounds check elimination on mem access (#1883)

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
This commit is contained in:
Takeshi Yoneda
2023-12-20 07:51:19 -08:00
committed by GitHub
parent d26cbadd46
commit fa2b2fc090
3 changed files with 161 additions and 198 deletions

View File

@@ -50,9 +50,17 @@ type Compiler struct {
br *bytes.Reader
loweringState loweringState
knownSafeBounds []knownSafeBound
knownSafeBoundsSet []ssa.ValueID
execCtxPtrValue, moduleCtxPtrValue ssa.Value
}
type knownSafeBound struct {
bound uint64
absoluteAddr ssa.Value
}
// NewFrontendCompiler returns a frontend Compiler.
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler {
c := &Compiler{
@@ -354,3 +362,42 @@ func SignatureForListener(wasmSig *wasm.FunctionType) (*ssa.Signature, *ssa.Sign
}
return beforeSig, afterSig
}
// isBoundSafe returns true if the given value is known to be safe to access up to the given bound.
func (c *Compiler) getKnownSafeBound(v ssa.ValueID) *knownSafeBound {
if int(v) >= len(c.knownSafeBounds) {
return nil
}
return &c.knownSafeBounds[v]
}
// recordKnownSafeBound records the given safe bound for the given value.
func (c *Compiler) recordKnownSafeBound(v ssa.ValueID, safeBound uint64, absoluteAddr ssa.Value) {
if int(v) >= len(c.knownSafeBounds) {
c.knownSafeBounds = append(c.knownSafeBounds, make([]knownSafeBound, v+1)...)
}
if exiting := c.knownSafeBounds[v]; exiting.bound == 0 {
c.knownSafeBounds[v] = knownSafeBound{
bound: safeBound,
absoluteAddr: absoluteAddr,
}
c.knownSafeBoundsSet = append(c.knownSafeBoundsSet, v)
} else if safeBound > exiting.bound {
c.knownSafeBounds[v].bound = safeBound
}
}
// clearSafeBounds clears the known safe bounds. This must be called
// after the compilation of each block.
func (c *Compiler) clearSafeBounds() {
for _, v := range c.knownSafeBoundsSet {
ptr := &c.knownSafeBounds[v]
ptr.bound = 0
}
c.knownSafeBoundsSet = c.knownSafeBoundsSet[:0]
}
func (k *knownSafeBound) valid() bool {
return k != nil && k.bound > 0
}

View File

@@ -1019,14 +1019,8 @@ blk0: (exec_ctx:i64, module_ctx:i64, v2:i32, v3:i32)
v9:i64 = Load module_ctx, 0x8
v10:i64 = Iadd v9, v5
Store v3, v10, 0x0
v11:i64 = Iconst_64 0x4
v12:i64 = UExtend v2, 32->64
v13:i64 = Iadd v12, v11
v14:i32 = Icmp lt_u, v6, v13
ExitIfTrue v14, exec_ctx, memory_out_of_bounds
v15:i64 = Iadd v9, v12
v16:i32 = Load v15, 0x0
Jump blk_ret, v16
v11:i32 = Load v10, 0x0
Jump blk_ret, v11
`,
},
{
@@ -1142,191 +1136,44 @@ blk0: (exec_ctx:i64, module_ctx:i64, v2:i32)
v13:i64 = Iadd v12, v11
v14:i32 = Icmp lt_u, v5, v13
ExitIfTrue v14, exec_ctx, memory_out_of_bounds
v15:i64 = Iadd v8, v12
v16:i64 = Load v15, 0x0
v17:i64 = Iconst_64 0x4
v18:i64 = UExtend v2, 32->64
v19:i64 = Iadd v18, v17
v20:i32 = Icmp lt_u, v5, v19
ExitIfTrue v20, exec_ctx, memory_out_of_bounds
v21:i64 = Iadd v8, v18
v22:f32 = Load v21, 0x0
v23:i64 = Iconst_64 0x8
v15:i64 = Load v9, 0x0
v16:f32 = Load v9, 0x0
v17:f64 = Load v9, 0x0
v18:i64 = Iconst_64 0x13
v19:i64 = UExtend v2, 32->64
v20:i64 = Iadd v19, v18
v21:i32 = Icmp lt_u, v5, v20
ExitIfTrue v21, exec_ctx, memory_out_of_bounds
v22:i32 = Load v9, 0xf
v23:i64 = Iconst_64 0x17
v24:i64 = UExtend v2, 32->64
v25:i64 = Iadd v24, v23
v26:i32 = Icmp lt_u, v5, v25
ExitIfTrue v26, exec_ctx, memory_out_of_bounds
v27:i64 = Iadd v8, v24
v28:f64 = Load v27, 0x0
v29:i64 = Iconst_64 0x13
v30:i64 = UExtend v2, 32->64
v31:i64 = Iadd v30, v29
v32:i32 = Icmp lt_u, v5, v31
ExitIfTrue v32, exec_ctx, memory_out_of_bounds
v33:i64 = Iadd v8, v30
v34:i32 = Load v33, 0xf
v35:i64 = Iconst_64 0x17
v36:i64 = UExtend v2, 32->64
v37:i64 = Iadd v36, v35
v38:i32 = Icmp lt_u, v5, v37
ExitIfTrue v38, exec_ctx, memory_out_of_bounds
v39:i64 = Iadd v8, v36
v40:i64 = Load v39, 0xf
v41:i64 = Iconst_64 0x13
v42:i64 = UExtend v2, 32->64
v43:i64 = Iadd v42, v41
v44:i32 = Icmp lt_u, v5, v43
ExitIfTrue v44, exec_ctx, memory_out_of_bounds
v45:i64 = Iadd v8, v42
v46:f32 = Load v45, 0xf
v47:i64 = Iconst_64 0x17
v48:i64 = UExtend v2, 32->64
v49:i64 = Iadd v48, v47
v50:i32 = Icmp lt_u, v5, v49
ExitIfTrue v50, exec_ctx, memory_out_of_bounds
v51:i64 = Iadd v8, v48
v52:f64 = Load v51, 0xf
v53:i64 = Iconst_64 0x1
v54:i64 = UExtend v2, 32->64
v55:i64 = Iadd v54, v53
v56:i32 = Icmp lt_u, v5, v55
ExitIfTrue v56, exec_ctx, memory_out_of_bounds
v57:i64 = Iadd v8, v54
v58:i32 = Sload8 v57, 0x0
v59:i64 = Iconst_64 0x10
v60:i64 = UExtend v2, 32->64
v61:i64 = Iadd v60, v59
v62:i32 = Icmp lt_u, v5, v61
ExitIfTrue v62, exec_ctx, memory_out_of_bounds
v63:i64 = Iadd v8, v60
v64:i32 = Sload8 v63, 0xf
v65:i64 = Iconst_64 0x1
v66:i64 = UExtend v2, 32->64
v67:i64 = Iadd v66, v65
v68:i32 = Icmp lt_u, v5, v67
ExitIfTrue v68, exec_ctx, memory_out_of_bounds
v69:i64 = Iadd v8, v66
v70:i32 = Uload8 v69, 0x0
v71:i64 = Iconst_64 0x10
v72:i64 = UExtend v2, 32->64
v73:i64 = Iadd v72, v71
v74:i32 = Icmp lt_u, v5, v73
ExitIfTrue v74, exec_ctx, memory_out_of_bounds
v75:i64 = Iadd v8, v72
v76:i32 = Uload8 v75, 0xf
v77:i64 = Iconst_64 0x2
v78:i64 = UExtend v2, 32->64
v79:i64 = Iadd v78, v77
v80:i32 = Icmp lt_u, v5, v79
ExitIfTrue v80, exec_ctx, memory_out_of_bounds
v81:i64 = Iadd v8, v78
v82:i32 = Sload16 v81, 0x0
v83:i64 = Iconst_64 0x11
v84:i64 = UExtend v2, 32->64
v85:i64 = Iadd v84, v83
v86:i32 = Icmp lt_u, v5, v85
ExitIfTrue v86, exec_ctx, memory_out_of_bounds
v87:i64 = Iadd v8, v84
v88:i32 = Sload16 v87, 0xf
v89:i64 = Iconst_64 0x2
v90:i64 = UExtend v2, 32->64
v91:i64 = Iadd v90, v89
v92:i32 = Icmp lt_u, v5, v91
ExitIfTrue v92, exec_ctx, memory_out_of_bounds
v93:i64 = Iadd v8, v90
v94:i32 = Uload16 v93, 0x0
v95:i64 = Iconst_64 0x11
v96:i64 = UExtend v2, 32->64
v97:i64 = Iadd v96, v95
v98:i32 = Icmp lt_u, v5, v97
ExitIfTrue v98, exec_ctx, memory_out_of_bounds
v99:i64 = Iadd v8, v96
v100:i32 = Uload16 v99, 0xf
v101:i64 = Iconst_64 0x1
v102:i64 = UExtend v2, 32->64
v103:i64 = Iadd v102, v101
v104:i32 = Icmp lt_u, v5, v103
ExitIfTrue v104, exec_ctx, memory_out_of_bounds
v105:i64 = Iadd v8, v102
v106:i64 = Sload8 v105, 0x0
v107:i64 = Iconst_64 0x10
v108:i64 = UExtend v2, 32->64
v109:i64 = Iadd v108, v107
v110:i32 = Icmp lt_u, v5, v109
ExitIfTrue v110, exec_ctx, memory_out_of_bounds
v111:i64 = Iadd v8, v108
v112:i64 = Sload8 v111, 0xf
v113:i64 = Iconst_64 0x1
v114:i64 = UExtend v2, 32->64
v115:i64 = Iadd v114, v113
v116:i32 = Icmp lt_u, v5, v115
ExitIfTrue v116, exec_ctx, memory_out_of_bounds
v117:i64 = Iadd v8, v114
v118:i64 = Uload8 v117, 0x0
v119:i64 = Iconst_64 0x10
v120:i64 = UExtend v2, 32->64
v121:i64 = Iadd v120, v119
v122:i32 = Icmp lt_u, v5, v121
ExitIfTrue v122, exec_ctx, memory_out_of_bounds
v123:i64 = Iadd v8, v120
v124:i64 = Uload8 v123, 0xf
v125:i64 = Iconst_64 0x2
v126:i64 = UExtend v2, 32->64
v127:i64 = Iadd v126, v125
v128:i32 = Icmp lt_u, v5, v127
ExitIfTrue v128, exec_ctx, memory_out_of_bounds
v129:i64 = Iadd v8, v126
v130:i64 = Sload16 v129, 0x0
v131:i64 = Iconst_64 0x11
v132:i64 = UExtend v2, 32->64
v133:i64 = Iadd v132, v131
v134:i32 = Icmp lt_u, v5, v133
ExitIfTrue v134, exec_ctx, memory_out_of_bounds
v135:i64 = Iadd v8, v132
v136:i64 = Sload16 v135, 0xf
v137:i64 = Iconst_64 0x2
v138:i64 = UExtend v2, 32->64
v139:i64 = Iadd v138, v137
v140:i32 = Icmp lt_u, v5, v139
ExitIfTrue v140, exec_ctx, memory_out_of_bounds
v141:i64 = Iadd v8, v138
v142:i64 = Uload16 v141, 0x0
v143:i64 = Iconst_64 0x11
v144:i64 = UExtend v2, 32->64
v145:i64 = Iadd v144, v143
v146:i32 = Icmp lt_u, v5, v145
ExitIfTrue v146, exec_ctx, memory_out_of_bounds
v147:i64 = Iadd v8, v144
v148:i64 = Uload16 v147, 0xf
v149:i64 = Iconst_64 0x4
v150:i64 = UExtend v2, 32->64
v151:i64 = Iadd v150, v149
v152:i32 = Icmp lt_u, v5, v151
ExitIfTrue v152, exec_ctx, memory_out_of_bounds
v153:i64 = Iadd v8, v150
v154:i64 = Sload32 v153, 0x0
v155:i64 = Iconst_64 0x13
v156:i64 = UExtend v2, 32->64
v157:i64 = Iadd v156, v155
v158:i32 = Icmp lt_u, v5, v157
ExitIfTrue v158, exec_ctx, memory_out_of_bounds
v159:i64 = Iadd v8, v156
v160:i64 = Sload32 v159, 0xf
v161:i64 = Iconst_64 0x4
v162:i64 = UExtend v2, 32->64
v163:i64 = Iadd v162, v161
v164:i32 = Icmp lt_u, v5, v163
ExitIfTrue v164, exec_ctx, memory_out_of_bounds
v165:i64 = Iadd v8, v162
v166:i64 = Uload32 v165, 0x0
v167:i64 = Iconst_64 0x13
v168:i64 = UExtend v2, 32->64
v169:i64 = Iadd v168, v167
v170:i32 = Icmp lt_u, v5, v169
ExitIfTrue v170, exec_ctx, memory_out_of_bounds
v171:i64 = Iadd v8, v168
v172:i64 = Uload32 v171, 0xf
Jump blk_ret, v10, v16, v22, v28, v34, v40, v46, v52, v58, v64, v70, v76, v82, v88, v94, v100, v106, v112, v118, v124, v130, v136, v142, v148, v154, v160, v166, v172
v27:i64 = Load v9, 0xf
v28:f32 = Load v9, 0xf
v29:f64 = Load v9, 0xf
v30:i32 = Sload8 v9, 0x0
v31:i32 = Sload8 v9, 0xf
v32:i32 = Uload8 v9, 0x0
v33:i32 = Uload8 v9, 0xf
v34:i32 = Sload16 v9, 0x0
v35:i32 = Sload16 v9, 0xf
v36:i32 = Uload16 v9, 0x0
v37:i32 = Uload16 v9, 0xf
v38:i64 = Sload8 v9, 0x0
v39:i64 = Sload8 v9, 0xf
v40:i64 = Uload8 v9, 0x0
v41:i64 = Uload8 v9, 0xf
v42:i64 = Sload16 v9, 0x0
v43:i64 = Sload16 v9, 0xf
v44:i64 = Uload16 v9, 0x0
v45:i64 = Uload16 v9, 0xf
v46:i64 = Sload32 v9, 0x0
v47:i64 = Sload32 v9, 0xf
v48:i64 = Uload32 v9, 0x0
v49:i64 = Uload32 v9, 0xf
Jump blk_ret, v10, v15, v16, v17, v22, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49
`,
},
{
@@ -1934,3 +1781,49 @@ func TestCompiler_declareSignatures(t *testing.T) {
}
})
}
func TestCompiler_recordKnownSafeBound(t *testing.T) {
c := &Compiler{}
c.recordKnownSafeBound(1, 99, 9999)
require.Equal(t, 1, len(c.knownSafeBoundsSet))
require.True(t, c.getKnownSafeBound(1).valid())
require.Equal(t, uint64(99), c.getKnownSafeBound(1).bound)
require.Equal(t, ssa.Value(9999), c.getKnownSafeBound(1).absoluteAddr)
c.recordKnownSafeBound(1, 150, 9999)
require.Equal(t, 1, len(c.knownSafeBoundsSet))
require.Equal(t, uint64(150), c.getKnownSafeBound(1).bound)
c.recordKnownSafeBound(5, 666, 54321)
require.Equal(t, 2, len(c.knownSafeBoundsSet))
require.Equal(t, uint64(666), c.getKnownSafeBound(5).bound)
require.Equal(t, ssa.Value(54321), c.getKnownSafeBound(5).absoluteAddr)
}
func TestCompiler_getKnownSafeBound(t *testing.T) {
c := &Compiler{
knownSafeBounds: []knownSafeBound{
{}, {bound: 2134},
},
}
require.Nil(t, c.getKnownSafeBound(5))
require.Nil(t, c.getKnownSafeBound(12345))
require.False(t, c.getKnownSafeBound(0).valid())
require.True(t, c.getKnownSafeBound(1).valid())
}
func TestCompiler_clearSafeBounds(t *testing.T) {
c := &Compiler{}
c.knownSafeBounds = []knownSafeBound{{bound: 1}, {}, {bound: 2}, {}, {}, {bound: 3}}
c.knownSafeBoundsSet = []ssa.ValueID{0, 2, 5}
c.clearSafeBounds()
require.Equal(t, 0, len(c.knownSafeBoundsSet))
require.Equal(t, []knownSafeBound{{}, {}, {}, {}, {}, {}}, c.knownSafeBounds)
}
func TestKnownSafeBound_valid(t *testing.T) {
k := &knownSafeBound{bound: 10, absoluteAddr: 12345}
require.True(t, k.valid())
k.bound = 0
require.False(t, k.valid())
}

View File

@@ -1416,6 +1416,8 @@ func (c *Compiler) lowerCurrentOpcode() {
builder.Seal(thenBlk)
builder.Seal(elseBlk)
case wasm.OpcodeElse:
c.clearSafeBounds() // Reset the safe bounds since we are entering the Else block.
ifctrl := state.ctrlPeekAt(0)
if unreachable := state.unreachable; unreachable && state.unreachableDepth > 0 {
// If it is currently in unreachable and is a nested if,
@@ -1443,6 +1445,8 @@ func (c *Compiler) lowerCurrentOpcode() {
builder.SetCurrentBlock(elseBlk)
case wasm.OpcodeEnd:
c.clearSafeBounds() // Reset the safe bounds since we are exiting the block.
if state.unreachableDepth > 0 {
state.unreachableDepth--
break
@@ -3368,24 +3372,35 @@ func (c *Compiler) lowerCallIndirect(typeIndex, tableIndex uint32) {
// memOpSetup inserts the bounds check and calculates the address of the memory operation (loads/stores).
func (c *Compiler) memOpSetup(baseAddr ssa.Value, constOffset, operationSizeInBytes uint64) (address ssa.Value) {
address = ssa.ValueInvalid
builder := c.ssaBuilder
baseAddrID := baseAddr.ID()
ceil := constOffset + operationSizeInBytes
if known := c.getKnownSafeBound(baseAddrID); known.valid() {
// We reuse the calculated absolute address even if the bound is not known to be safe.
address = known.absoluteAddr
if ceil <= known.bound {
return
}
}
ceilConst := builder.AllocateInstruction()
ceilConst.AsIconst64(ceil)
builder.InsertInstruction(ceilConst)
// We calculate the offset in 64-bit space.
extBaseAddr := builder.AllocateInstruction()
extBaseAddr.AsUExtend(baseAddr, 32, 64)
builder.InsertInstruction(extBaseAddr)
extBaseAddr := builder.AllocateInstruction().
AsUExtend(baseAddr, 32, 64).
Insert(builder).
Return()
// Note: memLen is already zero extended to 64-bit space at the load time.
memLen := c.getMemoryLenValue(false)
// baseAddrPlusCeil = baseAddr + ceil
baseAddrPlusCeil := builder.AllocateInstruction()
baseAddrPlusCeil.AsIadd(extBaseAddr.Return(), ceilConst.Return())
baseAddrPlusCeil.AsIadd(extBaseAddr, ceilConst.Return())
builder.InsertInstruction(baseAddrPlusCeil)
// Check for out of bounds memory access: `memLen >= baseAddrPlusCeil`.
@@ -3397,11 +3412,15 @@ func (c *Compiler) memOpSetup(baseAddr ssa.Value, constOffset, operationSizeInBy
builder.InsertInstruction(exitIfNZ)
// Load the value from memBase + extBaseAddr.
memBase := c.getMemoryBaseValue(false)
addrCalc := builder.AllocateInstruction()
addrCalc.AsIadd(memBase, extBaseAddr.Return())
builder.InsertInstruction(addrCalc)
return addrCalc.Return()
if address == ssa.ValueInvalid { // Reuse the value if the memBase is already calculated at this point.
memBase := c.getMemoryBaseValue(false)
address = builder.AllocateInstruction().
AsIadd(memBase, extBaseAddr).Insert(builder).Return()
}
// Record the bound ceil for this baseAddr is known to be safe for the subsequent memory access in the same block.
c.recordKnownSafeBound(baseAddrID, ceil, address)
return
}
func (c *Compiler) callMemmove(dst, src, size ssa.Value) {
@@ -3434,6 +3453,10 @@ func (c *Compiler) reloadAfterCall() {
func (c *Compiler) reloadMemoryBaseLen() {
_ = c.getMemoryBaseValue(true)
_ = c.getMemoryLenValue(true)
// This function being called means that the memory base might have changed.
// Therefore, we need to clear the known safe bounds because we cache the absolute address of the memory access per each base offset.
c.clearSafeBounds()
}
// globalInstanceValueOffset is the offsetOf .Value field of wasm.GlobalInstance.