Files
wazero/internal/engine/compiler/compiler_test.go
Takeshi Yoneda ad968dc3fe Pass correct api.Module to host functions (#1213)
Fixes #1211

Previously, host functions received the api.Module of the "originating" module,
i.e. the module of the api.Function currently being invoked. Because that
api.Module was wrapped by withMemory with the caller's memory instance, this
caused no problems in most cases. The only issues were with api.Module methods
other than Memory(), and this commit fixes them.

Signed-off-by: Takeshi Yoneda <takeshi@tetrate.io>
2023-03-08 15:05:48 +09:00

296 lines
12 KiB
Go

package compiler
import (
"fmt"
"math"
"os"
"testing"
"unsafe"
"github.com/tetratelabs/wazero/internal/platform"
"github.com/tetratelabs/wazero/internal/testing/require"
"github.com/tetratelabs/wazero/internal/wasm"
"github.com/tetratelabs/wazero/internal/wazeroir"
)
// TestMain runs the package's tests only when the compiler engine is supported
// on the host platform. Otherwise it exits with status 0 without running any
// test.
func TestMain(m *testing.M) {
	exitCode := 0
	if platform.CompilerSupported() {
		exitCode = m.Run()
	}
	os.Exit(exitCode)
}
// Ensures that the offset consts do not drift when we manipulate the target
// structs.
//
// Note: This is a package initializer as many tests could fail if these
// constants are misaligned, hiding the root cause. Each check's label is the
// name of the constant being verified, so a panic points straight at the one
// that drifted.
func init() {
	var me moduleEngine
	// requireEqual panics with the given label when a measured struct layout
	// value (expected) differs from the hard-coded constant (actual).
	requireEqual := func(expected, actual int, name string) {
		if expected != actual {
			panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual))
		}
	}
	requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset")

	var ce callEngine
	// Offsets for callEngine.moduleContext.
	requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset")
	requireEqual(int(unsafe.Offsetof(ce.moduleInstanceAddress)), callEngineModuleContextModuleInstanceAddressOffset, "callEngineModuleContextModuleInstanceAddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset, "callEngineModuleContextMemorySliceLenOffset")
	requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset")
	requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset")

	// Offsets for callEngine.stackContext.
	requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset")
	requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset")
	requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset")

	// Offsets for callEngine.exitContext.
	requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset")
	requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset")
	requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset")
	requireEqual(int(unsafe.Offsetof(ce.callerFunctionInstance)), callEngineExitContextCallerFunctionInstanceOffset, "callEngineExitContextCallerFunctionInstanceOffset")

	// Size of callFrame, expressed in uint64 (8-byte) slots.
	var frame callFrame
	requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSizeInUint64")

	// Offsets and size for function.
	var compiledFunc function
	requireEqual(int(unsafe.Offsetof(compiledFunc.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset")
	requireEqual(int(unsafe.Offsetof(compiledFunc.source)), functionSourceOffset, "functionSourceOffset")
	requireEqual(int(unsafe.Offsetof(compiledFunc.moduleInstanceAddress)), functionModuleInstanceAddressOffset, "functionModuleInstanceAddressOffset")
	requireEqual(int(unsafe.Sizeof(compiledFunc)), functionSize, "functionSize")

	// Offsets for wasm.ModuleInstance.
	var moduleInstance wasm.ModuleInstance
	requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset")
	requireEqual(int(unsafe.Offsetof(moduleInstance.Memory)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset")
	requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset")
	requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset")
	requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset")
	requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset")
	requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset")

	// Offsets for wasm.FunctionInstance.
	var functionInstance wasm.FunctionInstance
	requireEqual(int(unsafe.Offsetof(functionInstance.TypeID)), functionInstanceTypeIDOffset, "functionInstanceTypeIDOffset")

	// Offsets for wasm.TableInstance.
	var tableInstance wasm.TableInstance
	requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset")
	// We add "+8" to get the offset of len(TableInstance.References), since a
	// slice header is laid out as {Data uintptr, Len int64, Cap int64} in memory.
	requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset")

	// Offsets for wasm.MemoryInstance.
	var memoryInstance wasm.MemoryInstance
	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset")
	// "+8" because the slice header is laid out as {Data uintptr, Len int64, Cap int64} in memory.
	requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset")

	// Offsets for wasm.GlobalInstance.
	var globalInstance wasm.GlobalInstance
	requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset")

	// Struct sizes for wasm.DataInstance and wasm.ElementInstance.
	var dataInstance wasm.DataInstance
	requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize")
	var elementInstance wasm.ElementInstance
	requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize")

	// The pointer size must equal 1<<pointerSizeLog2.
	var pointer uintptr
	requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2")
}
// compilerEnv bundles the state used by the architecture-independent compiler
// tests in this package.
type compilerEnv struct {
	// me is the moduleEngine that ce is created from (see newCompilerEnvironment).
	me *moduleEngine
	// ce is the callEngine whose stack, stack pointers and exit status the
	// accessor methods below read and write.
	ce *callEngine
	// moduleInstance is the module instance passed to nativecall by exec and
	// bound to functions created by newFunction.
	moduleInstance *wasm.ModuleInstance
}
// stackTopAsUint32 returns the value on top of the value stack truncated to
// its low 32 bits.
func (j *compilerEnv) stackTopAsUint32() uint32 {
	return uint32(j.stackTopAsUint64())
}
// stackTopAsInt32 returns the value on top of the value stack truncated to
// its low 32 bits and reinterpreted as a signed integer.
func (j *compilerEnv) stackTopAsInt32() int32 {
	raw := j.stackTopAsUint64()
	return int32(raw)
}
// stackTopAsUint64 returns the raw 64-bit slot on top of the value stack,
// i.e. the element just below the callEngine's current stack pointer.
func (j *compilerEnv) stackTopAsUint64() uint64 {
	slots, sp := j.stack(), j.ce.stackPointer
	return slots[sp-1]
}
// stackTopAsInt64 returns the value on top of the value stack reinterpreted
// as a signed 64-bit integer.
func (j *compilerEnv) stackTopAsInt64() int64 {
	raw := j.stackTopAsUint64()
	return int64(raw)
}
// stackTopAsFloat32 interprets the low 32 bits of the value on top of the
// value stack as IEEE-754 float32 bits.
func (j *compilerEnv) stackTopAsFloat32() float32 {
	bits := j.stackTopAsUint32()
	return math.Float32frombits(bits)
}
// stackTopAsFloat64 interprets the value on top of the value stack as
// IEEE-754 float64 bits.
func (j *compilerEnv) stackTopAsFloat64() float64 {
	bits := j.stackTopAsUint64()
	return math.Float64frombits(bits)
}
// stackTopAsV128 returns the two 64-bit halves of the 128-bit value held in
// the top two stack slots: lo comes from stackPointer-2 and hi from
// stackPointer-1.
func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) {
	slots := j.stack()
	top := j.ce.stackPointer
	return slots[top-2], slots[top-1]
}
// memory returns the linear memory buffer of the module instance under test.
func (j *compilerEnv) memory() []byte {
	mem := j.moduleInstance.Memory
	return mem.Buffer
}
// stack returns the callEngine's backing value stack.
func (j *compilerEnv) stack() []uint64 {
	engine := j.ce
	return engine.stack
}
// compilerStatus returns the status code that native code last recorded in
// the callEngine's exit context.
func (j *compilerEnv) compilerStatus() nativeCallStatusCode {
	return j.ce.statusCode
}
// builtinFunctionCallAddress returns the builtin function index recorded in
// the callEngine's exit context.
func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index {
	return j.ce.builtinFunctionCallIndex
}
// stackPointer reports the callEngine's stack pointer with the call frame's
// callFrameDataSizeInUint64 slots subtracted.
func (j *compilerEnv) stackPointer() uint64 {
	raw := j.ce.stackPointer
	return raw - callFrameDataSizeInUint64
}
// stackBasePointer converts the callEngine's byte-based stack base pointer
// into a count of 8-byte (uint64) slots.
func (j *compilerEnv) stackBasePointer() uint64 {
	inBytes := j.ce.stackBasePointerInBytes
	return inBytes / 8
}
// setStackPointer stores sp as the callEngine's raw stack pointer. Unlike
// the stackPointer getter, no call-frame adjustment is applied.
func (j *compilerEnv) setStackPointer(sp uint64) {
	engine := j.ce
	engine.stackPointer = sp
}
// addGlobals appends g to the module instance's globals.
func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) {
	mi := j.moduleInstance
	mi.Globals = append(mi.Globals, g...)
}
// globals returns the module instance's globals.
func (j *compilerEnv) globals() []*wasm.GlobalInstance {
	return j.module().Globals
}
// addTable appends table to the module instance's tables.
func (j *compilerEnv) addTable(table *wasm.TableInstance) {
	mi := j.moduleInstance
	mi.Tables = append(mi.Tables, table)
}
// setStackBasePointer takes a base pointer counted in uint64 slots and stores
// it in bytes (slots * 8) on the callEngine.
func (j *compilerEnv) setStackBasePointer(sp uint64) {
	engine := j.ce
	engine.stackBasePointerInBytes = sp * 8
}
// module returns the module instance the environment executes against.
func (j *compilerEnv) module() (m *wasm.ModuleInstance) {
	m = j.moduleInstance
	return m
}
// moduleEngine returns the moduleEngine backing this environment.
func (j *compilerEnv) moduleEngine() (me *moduleEngine) {
	me = j.me
	return me
}
// callEngine returns the callEngine whose state the tests inspect.
func (j *compilerEnv) callEngine() (ce *callEngine) {
	ce = j.ce
	return ce
}
// newFunction wraps codeSegment in a *function. The result points at the
// first byte of codeSegment and is bound to the environment's module
// instance. Its source is a FunctionInstance with an empty FunctionType.
// codeSegment must be non-empty.
func (j *compilerEnv) newFunction(codeSegment []byte) *function {
	mi := j.moduleInstance
	src := &wasm.FunctionInstance{
		Module: mi,
		Type:   &wasm.FunctionType{},
	}
	return &function{
		source:                src,
		moduleInstanceAddress: uintptr(unsafe.Pointer(mi)),
		codeInitialAddress:    uintptr(unsafe.Pointer(&codeSegment[0])),
		parent:                &code{codeSegment: codeSegment},
	}
}
// exec installs codeSegment as both the initial and the current function of
// the callEngine. It then jumps into it via nativecall, passing the
// callEngine and the module instance.
func (j *compilerEnv) exec(codeSegment []byte) {
	compiled := j.newFunction(codeSegment)
	engine := j.ce
	engine.initialFn, engine.fn = compiled, compiled
	nativecall(
		uintptr(unsafe.Pointer(&codeSegment[0])),
		uintptr(unsafe.Pointer(engine)),
		uintptr(unsafe.Pointer(j.moduleInstance)),
	)
}
// requireNewCompiler builds a compiler with fn and initializes it with ir. If
// ir is nil, it uses a CompilationResult with an empty label-caller map and an
// empty function type. The test fails unless the compiler also implements
// compilerImpl.
func (j *compilerEnv) requireNewCompiler(t *testing.T, fn func() compiler, ir *wazeroir.CompilationResult) compilerImpl {
	requireSupportedOSArch(t)

	compilation := ir
	if compilation == nil {
		compilation = &wazeroir.CompilationResult{
			Signature:    &wasm.FunctionType{},
			LabelCallers: map[wazeroir.LabelID]uint32{},
		}
	}

	c := fn()
	c.Init(compilation, false)

	impl, ok := c.(compilerImpl)
	require.True(t, ok)
	return impl
}
// compilerImpl is the interface used for architecture-independent unit tests
// in this package. It extends compiler with lower-level hooks that the tests
// drive directly. It is currently implemented by the amd64 and arm64
// compilers.
type compilerImpl interface {
	compiler
	compileExitFromNativeCode(nativeCallStatusCode)
	compileMaybeGrowStack() error
	compileReturnFunction() error
	getOnStackPointerCeilDeterminedCallBack() func(uint64)
	setStackPointerCeil(uint64)
	compileReleaseRegisterToStack(loc *runtimeValueLocation)
	setRuntimeValueLocationStack(runtimeValueLocationStack)
	compileEnsureOnRegister(loc *runtimeValueLocation) error
	compileModuleContextInitialization() error
}
const (
	// defaultMemoryPageNumInTest is the number of Wasm memory pages allocated
	// for the module instance built by newCompilerEnvironment.
	defaultMemoryPageNumInTest = 1
)
// newCompilerEnvironment creates a compilerEnv with the following setup:
//   - a fresh moduleEngine;
//   - a module instance backed by that engine, with
//     defaultMemoryPageNumInTest pages of zeroed memory and empty (non-nil)
//     tables and globals;
//   - a callEngine created via newCallEngine(initialStackSize, nil).
func newCompilerEnvironment() *compilerEnv {
	me := &moduleEngine{}
	mem := &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)}
	mi := &wasm.ModuleInstance{
		Engine:  me,
		Globals: []*wasm.GlobalInstance{},
		Tables:  []*wasm.TableInstance{},
		Memory:  mem,
	}
	return &compilerEnv{
		me:             me,
		moduleInstance: mi,
		ce:             me.newCallEngine(initialStackSize, nil),
	}
}
// requireRuntimeLocationStackPointerEqual asserts that the compiler's
// runtimeValueLocationStack pointer equals expSP once the call frame's
// callFrameDataSizeInUint64 slots are subtracted.
func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) {
	relativeSP := c.runtimeValueLocationStack().sp - callFrameDataSizeInUint64
	require.Equal(t, expSP, relativeSP)
}
// TestCompileI32WrapFromI64 is the regression test for https://github.com/tetratelabs/wazero/issues/1008
//
// It checks that compileI32WrapFromI64 retags the top-of-stack location as i32
// in place.
func TestCompileI32WrapFromI64(t *testing.T) {
	c := newCompiler()

	// Seed the location stack with a single i64 operand.
	operand := c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
	operand.valueType = runtimeValueTypeI64

	require.NoError(t, c.compileI32WrapFromI64())
	// The very same location object must now carry the i32 type.
	require.Equal(t, runtimeValueTypeI32, operand.valueType)
}