Files
wazero/internal/integration_test/fuzz/wazerolib/nodiff.go
Nuno Cruces 6503e82d3c Follow unsafe rules. (#1449)
Signed-off-by: Nuno Cruces <ncruces@users.noreply.github.com>
2023-05-09 15:34:08 +01:00

367 lines
13 KiB
Go

package main
import "C"
import (
"bytes"
"context"
"errors"
"fmt"
"reflect"
"strings"
"unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/internal/testing/binaryencoding"
"github.com/tetratelabs/wazero/internal/wasm"
)
// require_no_diff is the C entry point of the no-diff check: it runs the given Wasm binary on both the
// compiler and the interpreter and requires identical behavior (see requireNoDiff).
// binaryPtr/binarySize describe the Wasm binary and watPtr/watSize its text form; both buffers are
// viewed in place, not copied. When the check fails, the binary and wat are handed to saveFailedBinary
// before the panic propagates to the caller.
//
//export require_no_diff
func require_no_diff(binaryPtr uintptr, binarySize int, watPtr uintptr, watSize int, checkMemory bool) {
	// TODO: use unsafe.Slice after flooring Go 1.20.
	var wasmBin []byte
	sh := (*reflect.SliceHeader)(unsafe.Pointer(&wasmBin))
	sh.Data, sh.Len, sh.Cap = binaryPtr, binarySize, binarySize

	// TODO: use unsafe.String after flooring Go 1.20.
	var wat string
	strh := (*reflect.StringHeader)(unsafe.Pointer(&wat))
	strh.Data, strh.Len = watPtr, watSize

	// succeeded only flips once requireNoDiff returned normally; any panic leaves it false,
	// which makes the deferred function persist the offending input.
	succeeded := false
	defer func() {
		if !succeeded {
			saveFailedBinary(wasmBin, wat, "TestReRunFailedRequireNoDiffCase")
		}
	}()

	requireNoDiff(wasmBin, checkMemory, func(err error) {
		if err == nil {
			return
		}
		panic(err)
	})
	succeeded = true
}
// We haven't had public APIs for referencing all the imported entries from wazero.CompiledModule,
// so we use the unsafe.Pointer and the internal memory layout to get the internal *wasm.Module
// from wazero.CompiledModule. This must be synced with the struct definition of wazero.compiledModule
// (internal one): the local mirror below only declares the field at offset 0, so reading through it is
// valid only while the real struct's first field is the *wasm.Module.
//
// An "invalid pointer" error is returned when c is nil or wraps a nil pointer, instead of dereferencing it.
func extractInternalWasmModuleFromCompiledModule(c wazero.CompiledModule) (*wasm.Module, error) {
	if c == nil {
		return nil, errors.New("invalid pointer")
	}
	// This is the internal representation of interface in Go.
	// https://research.swtch.com/interfaces
	type iface struct {
		tp   *byte
		data unsafe.Pointer
	}
	// This corresponds to the unexported wazero.compiledModule to get *wasm.Module from wazero.CompiledModule interface.
	type compiledModule struct {
		module *wasm.Module
	}
	// &c is the address of a local and can never be nil; what can be nil is the interface's data word,
	// e.g. when c holds a typed nil pointer. Check that before dereferencing.
	ciface := (*iface)(unsafe.Pointer(&c))
	if ciface.data == nil {
		return nil, errors.New("invalid pointer")
	}
	cm := (*compiledModule)(ciface.data)
	return cm.module, nil
}
// requireNoDiff ensures that the behavior is the same between the compiler and the interpreter for any given binary.
//
// wasmBin is compiled and instantiated on a compiler runtime and on an interpreter runtime, with dummy modules
// satisfying its imports. Compilation errors, instantiation errors and the results of the exported functions
// are compared; when checkMemory is true, the presence and final contents of the module memory are compared too.
// Differences are reported through requireNoError, which is expected not to return for a non-nil error
// (TODO confirm for every caller). Binaries the harness cannot run (an "empty module name" compile error, or
// imports rejected by ensureDummyImports) are skipped without error.
func requireNoDiff(wasmBin []byte, checkMemory bool, requireNoError func(err error)) {
	// Choose the context to use for function calls.
	ctx := context.Background()
	compiler := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler())
	interpreter := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigInterpreter())
	defer compiler.Close(ctx)
	defer interpreter.Close(ctx)

	compilerCompiled, err := compiler.CompileModule(ctx, wasmBin)
	if err != nil && strings.Contains(err.Error(), "has an empty module name") {
		// This is the limitation wazero imposes to allow special-casing of anonymous modules.
		return
	}
	requireNoError(err)
	interpreterCompiled, err := interpreter.CompileModule(ctx, wasmBin)
	requireNoError(err)

	internalMod, err := extractInternalWasmModuleFromCompiledModule(compilerCompiled)
	requireNoError(err)
	if skip := ensureDummyImports(compiler, internalMod, requireNoError); skip {
		return
	}
	// The skip decision depends only on internalMod's imports, already checked above, so the result is ignored.
	ensureDummyImports(interpreter, internalMod, requireNoError)

	// Instantiate module.
	compilerMod, compilerInstErr := compiler.InstantiateModule(ctx, compilerCompiled,
		wazero.NewModuleConfig().WithName(string(internalMod.ID[:])))
	interpreterMod, interpreterInstErr := interpreter.InstantiateModule(ctx, interpreterCompiled,
		wazero.NewModuleConfig().WithName(string(internalMod.ID[:])))

	okToInvoke, err := ensureInstantiationError(compilerInstErr, interpreterInstErr)
	requireNoError(err)
	if !okToInvoke {
		return
	}

	err = ensureInvocationResultMatch(compilerMod, interpreterMod, interpreterCompiled.ExportedFunctions())
	requireNoError(err)

	compilerMem, _ := compilerMod.Memory().(*wasm.MemoryInstance)
	interpreterMem, _ := interpreterMod.Memory().(*wasm.MemoryInstance)
	if !checkMemory {
		return
	}
	switch {
	case (compilerMem == nil) != (interpreterMem == nil):
		// Only one side exposes a memory: that is a difference in itself, not a reason to skip the check.
		requireNoError(fmt.Errorf("memory presence mismatch\ncompiler has memory: %v\ninterpreter has memory: %v",
			compilerMem != nil, interpreterMem != nil))
	case compilerMem != nil && !bytes.Equal(compilerMem.Buffer, interpreterMem.Buffer):
		requireNoError(fmt.Errorf("memory state mismatch\ncompiler: %v\ninterpreter: %v",
			compilerMem.Buffer, interpreterMem.Buffer))
	}
}
// ensureDummyImports instantiates, in r, the modules which are required imports by `origin` *wasm.Module.
//
// One dummy module is built per imported module name. It exports an entry matching every import of origin from
// that module:
//   - functions that return zero/null values of their result types;
//   - zero/null-initialized globals;
//   - memories and tables with the imported descriptions.
//
// Modules are instantiated in the order their names first appear in origin's import section, so repeated runs
// over the same binary behave identically.
//
// skip is true, and nothing is instantiated, when origin imports from an empty module name or imports the same
// "{module_name}.{name}" pair more than once. Instantiation errors are reported through requireNoError.
func ensureDummyImports(r wazero.Runtime, origin *wasm.Module, requireNoError func(err error)) (skip bool) {
	impMods := make(map[string][]wasm.Import)
	var modNames []string // Module names in first-appearance order; map iteration order is random.
	for _, imp := range origin.ImportSection {
		if imp.Module == "" {
			// Importing empty modules are forbidden as future work will allow multiple anonymous modules.
			return true
		}
		if _, seen := impMods[imp.Module]; !seen {
			modNames = append(modNames, imp.Module)
		}
		impMods[imp.Module] = append(impMods[imp.Module], imp)
	}

	// Check every module before instantiating any, so a skipped binary leaves r untouched.
	for _, mName := range modNames {
		imps := impMods[mName]
		usedName := make(map[string]struct{}, len(imps))
		for _, imp := range imps {
			if _, dup := usedName[imp.Name]; dup {
				// Import segment can have duplicated "{module_name}.{name}" pair while it is prohibited for exports.
				// Decision on allowing modules with these "ill" imports or not is up to embedder, and wazero chooses
				// not to allow. Hence, we skip the entire case.
				// See "Note" at https://www.w3.org/TR/wasm-core-2/syntax/modules.html#imports
				return true
			}
			usedName[imp.Name] = struct{}{}
		}
	}

	for _, mName := range modNames {
		m := newDummyImportModule(origin, mName, impMods[mName])
		_, err := r.Instantiate(context.Background(), binaryencoding.EncodeModule(m))
		requireNoError(err)
	}
	return false
}

// newDummyImportModule builds a module named mName exporting one entry per import in imps, typed to satisfy it.
// Function signatures are looked up in origin's type section.
func newDummyImportModule(origin *wasm.Module, mName string, imps []wasm.Import) *wasm.Module {
	m := &wasm.Module{NameSection: &wasm.NameSection{ModuleName: mName}}
	for _, imp := range imps {
		var index uint32
		switch imp.Type {
		case wasm.ExternTypeFunc:
			tp := origin.TypeSection[imp.DescFunc]
			index = uint32(len(m.FunctionSection))
			m.FunctionSection = append(m.FunctionSection, uint32(len(m.TypeSection)))
			m.TypeSection = append(m.TypeSection, tp)
			// The body pushes one zero/null constant per result, then ends.
			var body []byte
			for _, vt := range tp.Results {
				if opcode, data, ok := dummyZeroConst(vt); ok {
					body = append(body, opcode)
					body = append(body, data...)
				}
			}
			body = append(body, wasm.OpcodeEnd)
			m.CodeSection = append(m.CodeSection, wasm.Code{Body: body})
		case wasm.ExternTypeGlobal:
			index = uint32(len(m.GlobalSection))
			// For an unhandled value type this yields opcode 0 and nil data.
			opcode, data, _ := dummyZeroConst(imp.DescGlobal.ValType)
			m.GlobalSection = append(m.GlobalSection, wasm.Global{
				Type: imp.DescGlobal, Init: wasm.ConstantExpression{Opcode: opcode, Data: data},
			})
		case wasm.ExternTypeMemory:
			m.MemorySection = imp.DescMem
			index = 0
		case wasm.ExternTypeTable:
			index = uint32(len(m.TableSection))
			m.TableSection = append(m.TableSection, imp.DescTable)
		}
		m.ExportSection = append(m.ExportSection, wasm.Export{Type: imp.Type, Name: imp.Name, Index: index})
	}
	return m
}

// dummyZeroConst returns the constant instruction producing the zero (or null) value of vt, as its opcode byte and
// immediate bytes. ok is false, with opcode 0 and nil data, for value types not handled here.
func dummyZeroConst(vt wasm.ValueType) (opcode byte, data []byte, ok bool) {
	switch vt {
	case wasm.ValueTypeI32:
		return wasm.OpcodeI32Const, []byte{0}, true
	case wasm.ValueTypeI64:
		return wasm.OpcodeI64Const, []byte{0}, true
	case wasm.ValueTypeF32:
		return wasm.OpcodeF32Const, []byte{0, 0, 0, 0}, true
	case wasm.ValueTypeF64:
		return wasm.OpcodeF64Const, []byte{0, 0, 0, 0, 0, 0, 0, 0}, true
	case wasm.ValueTypeV128:
		// The vector prefix is the "opcode"; the sub-opcode and the 16 zero bytes are the immediates.
		return wasm.OpcodeVecPrefix, []byte{wasm.OpcodeVecV128Const, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true
	case wasm.ValueTypeExternref:
		return wasm.OpcodeRefNull, []byte{wasm.RefTypeExternref}, true
	case wasm.ValueTypeFuncref:
		return wasm.OpcodeRefNull, []byte{wasm.RefTypeFuncref}, true
	}
	return 0, nil, false
}
// valueTypeVector is the value type byte used here for v128. A v128 value occupies two uint64 slots
// (lower and higher 64 bits) in call parameters and results.
const valueTypeVector = 0x7b

// ensureInvocationResultMatch invokes all the exported functions from the module, and compare all the results between compiler vs interpreter.
//
// Functions with a result type other than i32/i64/f32/f64/v128 are skipped. Each function is called on both modules
// with zero-valued arguments from getDummyValues. Invocation errors must agree (see ensureInvocationError), and
// results are compared per type: the low 32 bits for i32/f32, the full slot for i64/f64, and both slots for v128.
// The first difference found is returned. Map iteration order is random, so which difference that is may vary.
func ensureInvocationResultMatch(compiledMod, interpreterMod api.Module, exportedFunctions map[string]api.FunctionDefinition) (err error) {
	ctx := context.Background()
outer:
	for name, def := range exportedFunctions {
		resultTypes := def.ResultTypes()
		for _, rt := range resultTypes {
			switch rt {
			case api.ValueTypeI32, api.ValueTypeI64, api.ValueTypeF32, api.ValueTypeF64, valueTypeVector:
			default:
				// For the sake of simplicity in the assertion, we only invoke the function with the basic types.
				continue outer
			}
		}

		cmpF := compiledMod.ExportedFunction(name)
		intF := interpreterMod.ExportedFunction(name)
		params := getDummyValues(def.ParamTypes())
		cmpRes, cmpErr := cmpF.Call(ctx, params...)
		intRes, intErr := intF.Call(ctx, params...)
		if errMismatch := ensureInvocationError(cmpErr, intErr); errMismatch != nil {
			// Returned rather than panicked so that it flows through the caller's requireNoError like every other failure.
			return fmt.Errorf("error mismatch on invoking %s: %w", name, errMismatch)
		}

		// Differing slot counts are a mismatch in themselves (and would otherwise index out of range).
		matched := len(cmpRes) == len(intRes)
		typesIndex := 0
		for i := 0; matched && i < len(cmpRes) && typesIndex < len(resultTypes); i++ {
			switch resultTypes[typesIndex] {
			case api.ValueTypeI32, api.ValueTypeF32:
				matched = uint32(cmpRes[i]) == uint32(intRes[i])
			case api.ValueTypeI64, api.ValueTypeF64:
				matched = cmpRes[i] == intRes[i]
			case valueTypeVector:
				matched = i+1 < len(cmpRes) && cmpRes[i] == intRes[i] && cmpRes[i+1] == intRes[i+1]
				i++ // We need to advance twice (lower and higher 64bits)
			}
			typesIndex++
		}
		if !matched {
			return fmt.Errorf("result mismatch on invoking '%s':\n\tinterpreter got: %v\n\tcompiler got: %v", name, intRes, cmpRes)
		}
	}
	return nil
}
// getDummyValues returns zero-valued arguments for a call with the given parameter types.
// Each parameter takes one uint64 slot, except v128 (valueTypeVector), which takes two
// (lower and higher 64 bits). The result is nil when there are no parameters.
func getDummyValues(valueTypes []api.ValueType) (ret []uint64) {
	for _, vt := range valueTypes {
		if vt == valueTypeVector {
			ret = append(ret, 0, 0)
			continue
		}
		ret = append(ret, 0)
	}
	return ret
}
// ensureInvocationError ensures that function invocation errors returned by interpreter and compiler match each other's.
//
// Both nil is a match, and exactly one nil is a mismatch. When both are non-nil, only the first line of each
// message is compared; presumably later lines (e.g. stack traces) may legitimately differ between engines —
// TODO confirm. The returned error describes the mismatch and is nil when the errors agree.
func ensureInvocationError(compilerErr, interpErr error) error {
	switch {
	case compilerErr == nil && interpErr == nil:
		return nil
	case compilerErr == nil:
		return fmt.Errorf("compiler returned no error, but interpreter got: %w", interpErr)
	case interpErr == nil:
		return fmt.Errorf("interpreter returned no error, but compiler got: %w", compilerErr)
	}

	compilerErrMsg, interpErrMsg := compilerErr.Error(), interpErr.Error()
	if idx := strings.IndexByte(compilerErrMsg, '\n'); idx >= 0 {
		compilerErrMsg = compilerErrMsg[:idx]
	}
	if idx := strings.IndexByte(interpErrMsg, '\n'); idx >= 0 {
		interpErrMsg = interpErrMsg[:idx]
	}
	if compilerErrMsg != interpErrMsg {
		return fmt.Errorf("error mismatch:\n\tinterpreter: %v\n\tcompiler: %v", interpErr, compilerErr)
	}
	return nil
}
// ensureInstantiationError ensures that instantiation errors returned by interpreter and compiler match each other's.
//
// okToInvoke is true only when both instantiations succeeded. When exactly one failed, a mismatch error is returned.
// When both failed, each message's first line must be accepted by allowedErrorDuringInstantiation, and the two
// first lines must be equal. Otherwise an error is returned that carries the full original error(s).
func ensureInstantiationError(compilerErr, interpErr error) (okToInvoke bool, err error) {
	switch {
	case compilerErr == nil && interpErr == nil:
		return true, nil
	case compilerErr == nil:
		return false, fmt.Errorf("compiler returned no error, but interpreter got: %w", interpErr)
	case interpErr == nil:
		return false, fmt.Errorf("interpreter returned no error, but compiler got: %w", compilerErr)
	}

	compilerErrMsg, interpErrMsg := compilerErr.Error(), interpErr.Error()
	if idx := strings.IndexByte(compilerErrMsg, '\n'); idx >= 0 {
		compilerErrMsg = compilerErrMsg[:idx]
	}
	if idx := strings.IndexByte(interpErrMsg, '\n'); idx >= 0 {
		interpErrMsg = interpErrMsg[:idx]
	}

	// Report the full error on both sides so the diagnostics are symmetric.
	if !allowedErrorDuringInstantiation(compilerErrMsg) {
		return false, fmt.Errorf("invalid error occur with compiler: %v", compilerErr)
	}
	if !allowedErrorDuringInstantiation(interpErrMsg) {
		return false, fmt.Errorf("invalid error occur with interpreter: %v", interpErr)
	}
	if compilerErrMsg != interpErrMsg {
		return false, fmt.Errorf("error mismatch:\n\tinterpreter: %v\n\tcompiler: %v", interpErr, compilerErr)
	}
	return false, nil
}
// allowedErrorDuringInstantiation reports whether an instantiation error message (first line only) is an
// acceptable outcome rather than a sign of an engine bug. Two shapes are accepted:
//   - "data[...]: out of bounds memory access": a data segment out of bounds, which WebAssembly 2.0 treats as a
//     runtime error;
//   - "start function[...] ... failed: wasm error: ...": a trap in the start function, which is a runtime error,
//     not a compilation or instantiation error.
func allowedErrorDuringInstantiation(errMsg string) bool {
	dataSegmentOOB := strings.HasPrefix(errMsg, "data[") &&
		strings.HasSuffix(errMsg, "]: out of bounds memory access")
	startFunctionTrap := strings.HasPrefix(errMsg, "start function[") &&
		strings.Contains(errMsg, "failed: wasm error:")
	return dataSegmentOOB || startFunctionTrap
}