Add experimental support for snapshot/restore (#1808)

Signed-off-by: Anuraag Agrawal <anuraaga@gmail.com>
This commit is contained in:
Anuraag (Rag) Agrawal
2024-01-24 09:46:58 +09:00
committed by GitHub
parent ebb5ea2eb7
commit 4185e533bb
8 changed files with 438 additions and 10 deletions

View File

@@ -0,0 +1,25 @@
package experimental
// Snapshot holds the execution state at the time of a Snapshotter.Snapshot call.
type Snapshot interface {
// Restore sets the Wasm execution state to the capture. Because a host function
// calling this is resetting the pointer to the execution stack, the host function
// will not be able to return values in the normal way. ret is a slice of values the
// host function intends to return from the restored function.
Restore(ret []uint64)
}
// Snapshotter allows host functions to snapshot the WebAssembly execution environment.
// Host functions obtain it from their context.Context using SnapshotterKey.
type Snapshotter interface {
// Snapshot captures the current execution state. The returned Snapshot is applied
// later by calling its Restore method.
Snapshot() Snapshot
}
// EnableSnapshotterKey is a context key to indicate that snapshotting should be enabled.
// The context.Context passed to an exported function invocation should have this key set
// to a non-nil value; host functions will then be able to retrieve a Snapshotter using
// SnapshotterKey.
type EnableSnapshotterKey struct{}
// SnapshotterKey is a context key to access a Snapshotter from a host function.
// It is only present if EnableSnapshotterKey was set in the function invocation context.
type SnapshotterKey struct{}

View File

@@ -0,0 +1,100 @@
package experimental_test
import (
"context"
_ "embed"
"fmt"
"log"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
)
// snapshotWasm holds the embedded contents of testdata/snapshot.wasm, which was
// generated by the following:
//
// cd testdata; wat2wasm snapshot.wat
//
//go:embed testdata/snapshot.wasm
var snapshotWasm []byte
// snapshotsKey is the context key under which a *[]experimental.Snapshot is stored, so
// that the "snapshot" and "restore" host functions can share the snapshots taken so far.
type snapshotsKey struct{}
// Example_enableSnapshotterKey shows a pair of host functions, "snapshot" and
// "restore", that capture the guest's execution state and later jump back to it.
func Example_enableSnapshotterKey() {
	baseCtx := context.Background()

	rt := wazero.NewRuntime(baseCtx)
	defer rt.Close(baseCtx) // Closing the Runtime also closes everything it created.

	// Snapshots taken by the host are collected here so that restore can look
	// them up again by index.
	var taken []experimental.Snapshot

	// Opt in to the experimental snapshotting functionality through the context
	// used for function invocations, and expose the mutable snapshot holder to
	// the host functions through the same context.
	ctx := context.WithValue(baseCtx, experimental.EnableSnapshotterKey{}, struct{}{})
	ctx = context.WithValue(ctx, snapshotsKey{}, &taken)

	// takeSnapshot captures the current state and hands the guest an opaque
	// handle (the index into the holder) by writing it at snapshotPtr.
	takeSnapshot := func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
		// EnableSnapshotterKey was set on the context, so the Snapshotter is present.
		snapshotter := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter)
		snap := snapshotter.Snapshot()

		// This example takes a single snapshot, but real programs commonly nest
		// them at several layers of a call stack, e.g. for try/catch statements.
		holder := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
		handle := uint32(len(*holder))
		*holder = append(*holder, snap)

		if !mod.Memory().WriteUint32Le(snapshotPtr, handle) {
			log.Panicln("failed to write snapshot index")
		}
		return 0
	}

	// restoreSnapshot reads the handle back from guest memory and rewinds
	// execution to the matching snapshot.
	restoreSnapshot := func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
		handle, ok := mod.Memory().ReadUint32Le(snapshotPtr)
		if !ok {
			log.Panicln("failed to read snapshot index")
		}

		holder := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
		snap := (*holder)[handle]

		// This host function's invocation ends as soon as Restore is called,
		// so the value the guest should observe is passed to Restore instead
		// of being returned normally.
		snap.Restore([]uint64{5})
	}

	_, err := rt.NewHostModuleBuilder("example").
		NewFunctionBuilder().WithFunc(takeSnapshot).Export("snapshot").
		NewFunctionBuilder().WithFunc(restoreSnapshot).Export("restore").
		Instantiate(ctx)
	if err != nil {
		log.Panicln(err)
	}

	// Instantiate the actual guest code.
	guest, err := rt.Instantiate(ctx, snapshotWasm)
	if err != nil {
		log.Panicln(err)
	}

	// Invoke the guest entrypoint.
	results, err := guest.ExportedFunction("run").Call(ctx)
	if err != nil {
		log.Panicln(err)
	}

	// The value handed to Restore is the result. Had restore been a no-op, the
	// normal code flow would have produced 10 instead.
	fmt.Println(results[0])
	// Output:
	// 5
}

View File

@@ -0,0 +1,121 @@
package experimental_test
import (
"context"
"testing"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/internal/testing/require"
)
// TestSnapshotNestedWasmInvocation verifies that a snapshot can be restored from a
// wasm invocation nested inside the very host function that took the snapshot.
func TestSnapshotNestedWasmInvocation(t *testing.T) {
	ctx := context.Background()

	rt := wazero.NewRuntime(ctx)
	defer rt.Close(ctx)

	// Set by a defer inside the snapshot host function; proves that deferred
	// calls still run while the restore unwinds through the host function.
	sidechannel := 0

	takeSnapshot := func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
		defer func() { sidechannel = 10 }()

		snapshotter := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter)
		snap := snapshotter.Snapshot()

		holder := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
		handle := uint32(len(*holder))
		*holder = append(*holder, snap)
		require.True(t, mod.Memory().WriteUint32Le(snapshotPtr, handle))

		// Re-enter wasm from within the host function and restore from there.
		_, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr))
		require.NoError(t, err)
		return 2
	}

	restoreSnapshot := func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
		handle, ok := mod.Memory().ReadUint32Le(snapshotPtr)
		require.True(t, ok)

		holder := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
		(*holder)[handle].Restore([]uint64{12})
	}

	_, err := rt.NewHostModuleBuilder("example").
		NewFunctionBuilder().WithFunc(takeSnapshot).Export("snapshot").
		NewFunctionBuilder().WithFunc(restoreSnapshot).Export("restore").
		Instantiate(ctx)
	require.NoError(t, err)

	guest, err := rt.Instantiate(ctx, snapshotWasm)
	require.NoError(t, err)

	// Only the invocation context carries the snapshot holder and the opt-in key.
	var snapshots []experimental.Snapshot
	callCtx := context.WithValue(ctx, snapshotsKey{}, &snapshots)
	callCtx = context.WithValue(callCtx, experimental.EnableSnapshotterKey{}, struct{}{})

	results, err := guest.ExportedFunction("snapshot").Call(callCtx, uint64(0))
	require.NoError(t, err)

	// The result is the value passed to Restore.
	require.Equal(t, uint64(12), results[0])
	// Host function defers within the call stack work fine.
	require.Equal(t, 10, sidechannel)
}
// TestSnapshotMultipleWasmInvocations verifies that restoring a snapshot from a
// different exported function invocation than the one that took it is rejected.
func TestSnapshotMultipleWasmInvocations(t *testing.T) {
	ctx := context.Background()

	rt := wazero.NewRuntime(ctx)
	defer rt.Close(ctx)

	takeSnapshot := func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
		snapshotter := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter)
		snap := snapshotter.Snapshot()

		holder := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
		handle := uint32(len(*holder))
		*holder = append(*holder, snap)
		require.True(t, mod.Memory().WriteUint32Le(snapshotPtr, handle))
		return 0
	}

	restoreSnapshot := func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
		handle, ok := mod.Memory().ReadUint32Le(snapshotPtr)
		require.True(t, ok)

		holder := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
		(*holder)[handle].Restore([]uint64{12})
	}

	_, err := rt.NewHostModuleBuilder("example").
		NewFunctionBuilder().WithFunc(takeSnapshot).Export("snapshot").
		NewFunctionBuilder().WithFunc(restoreSnapshot).Export("restore").
		Instantiate(ctx)
	require.NoError(t, err)

	guest, err := rt.Instantiate(ctx, snapshotWasm)
	require.NoError(t, err)

	// Only the invocation context carries the snapshot holder and the opt-in key.
	var snapshots []experimental.Snapshot
	callCtx := context.WithValue(ctx, snapshotsKey{}, &snapshots)
	callCtx = context.WithValue(callCtx, experimental.EnableSnapshotterKey{}, struct{}{})

	ptr := uint64(0)

	// First invocation: take the snapshot, which returns zero.
	results, err := guest.ExportedFunction("snapshot").Call(callCtx, ptr)
	require.NoError(t, err)
	require.Equal(t, uint64(0), results[0])

	// Second, separate invocation: restoring here fails because the snapshot was
	// taken in a different wasm invocation. Currently this surfaces as a panic.
	err = require.CapturePanic(func() {
		_, _ = guest.ExportedFunction("restore").Call(callCtx, ptr)
	})
	require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+
		"exported function invocation than snapshot")
}

View File

@@ -16,7 +16,7 @@ import (
// listenerWasm was generated by the following:
//
// cd testdata; wat2wasm --debug-names listener.wat
// cd logging/testdata; wat2wasm --debug-names listener.wat
//
//go:embed logging/testdata/listener.wasm
var listenerWasm []byte

BIN
experimental/testdata/snapshot.wasm vendored Normal file

Binary file not shown.

34
experimental/testdata/snapshot.wat vendored Normal file
View File

@@ -0,0 +1,34 @@
;; Guest module used by the snapshot/restore example and tests. It imports the
;; host functions "snapshot" and "restore" from the "example" module and exports
;; thin wrappers around them plus a "run" entrypoint.
(module
;; snapshot takes a pointer into linear memory and returns an i32.
(import "example" "snapshot" (func $snapshot (param i32) (result i32)))
;; restore takes the same pointer and has no result.
(import "example" "restore" (func $restore (param i32)))
;; $helper calls restore from one call level deeper than where snapshot was called.
(func $helper (result i32)
(call $restore (i32.const 0))
;; Not executed when restore rewinds execution; 10 is only returned on normal code flow.
i32.const 10
)
;; run calls snapshot and branches on the value it appears to return, kept in local 0.
(func (export "run") (result i32) (local i32)
(call $snapshot (i32.const 0))
local.set 0
local.get 0
(if (result i32)
(then ;; restore return, finish with the value returned by it
local.get 0
)
(else ;; snapshot return, call helper
(call $helper)
)
)
)
;; Exported pass-through to the imported snapshot host function.
(func (export "snapshot") (param i32) (result i32)
(call $snapshot (local.get 0))
)
;; Exported pass-through to the imported restore host function.
(func (export "restore") (param i32)
(call $restore (local.get 0))
)
;; A single page of memory, through which the host and guest exchange the snapshot index.
(memory (export "memory") 1 1)
)

View File

@@ -787,7 +787,12 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
defer done()
}
ce.execWasmFunction(ctx, m)
snapshotEnabled := ctx.Value(experimental.EnableSnapshotterKey{}) != nil
if snapshotEnabled {
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce)
}
ce.execWasmFunction(ctx, m, snapshotEnabled)
// This returns a safe copy of the results, instead of a slice view. If we
// returned a re-slice, the caller could accidentally or purposefully
@@ -853,6 +858,11 @@ func callFrameOffset(funcType *wasm.FunctionType) (ret int) {
//
// This is defined for testability.
func (ce *callEngine) deferredOnCall(ctx context.Context, m *wasm.ModuleInstance, recovered interface{}) (err error) {
if s, ok := recovered.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}
if recovered != nil {
builder := wasmdebug.NewErrorBuilder()
@@ -1039,7 +1049,7 @@ const (
builtinFunctionMemoryNotify
)
func (ce *callEngine) execWasmFunction(ctx context.Context, m *wasm.ModuleInstance) {
func (ce *callEngine) execWasmFunction(ctx context.Context, m *wasm.ModuleInstance, snapshotEnabled bool) {
codeAddr := ce.initialFn.codeInitialAddress
modAddr := ce.initialFn.moduleInstance
@@ -1065,12 +1075,25 @@ entry:
stack := ce.stack[base : base+stackLen]
fn := calleeHostFunction.parent.goFunc
func() {
if snapshotEnabled {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(*snapshot); ok && s.ce == ce {
s.doRestore()
} else {
panic(r)
}
}
}()
}
switch fn := fn.(type) {
case api.GoModuleFunction:
fn.Call(ctx, ce.callerModuleInstance, stack)
case api.GoFunction:
fn.Call(ctx, stack)
}
}()
codeAddr, modAddr = ce.returnAddress, ce.moduleInstance
goto entry
@@ -1217,6 +1240,58 @@ func (ce *callEngine) builtinFunctionMemoryNotify(mem *wasm.MemoryInstance) {
ce.pushValue(uint64(mem.Notify(offset, uint32(count))))
}
// snapshot implements experimental.Snapshot. It records the callEngine state that
// callEngine.Snapshot captures and that doRestore later writes back.
type snapshot struct {
// stackPointer is the value of ce.stackContext.stackPointer at capture time.
stackPointer uint64
// stackBasePointerInBytes is the engine's stack base pointer, in bytes, at capture time.
stackBasePointerInBytes uint64
// returnAddress is the engine's returnAddress at capture time, widened to uint64.
returnAddress uint64
// hostBase is stackBasePointerInBytes >> 3, used as an index into the []uint64 stack.
// doRestore writes ret starting at this index.
hostBase int
// stack is a copy of the engine stack up to the stack top index at capture time.
stack []uint64
// ret holds the values given to Restore; it is unset until Restore is called.
ret []uint64
// ce is the callEngine that produced this snapshot.
ce *callEngine
}
// Snapshot implements the same method as documented on experimental.Snapshotter.
//
// It copies the live portion of the stack and records the stack pointers and the
// return address, so that doRestore can put them back later.
func (ce *callEngine) Snapshot() experimental.Snapshot {
	// The base pointer is kept in bytes; shifting by 3 converts it to a uint64 slot index.
	baseIndex := int(ce.stackBasePointerInBytes >> 3)

	top := int(ce.stackTopIndex())
	saved := make([]uint64, top)
	copy(saved, ce.stack[:top])

	snap := &snapshot{ce: ce, stack: saved, hostBase: baseIndex}
	snap.stackPointer = ce.stackContext.stackPointer
	snap.stackBasePointerInBytes = ce.stackBasePointerInBytes
	snap.returnAddress = uint64(ce.returnAddress)
	return snap
}
// Restore implements the same method as documented on experimental.Snapshot.
//
// It stores ret on the snapshot and then panics with the snapshot itself, so it never
// returns to its caller. The snapshot is only applied once doRestore is called on the
// recovered value.
// NOTE(review): presumably the engine that invoked the host function recovers the
// panic and calls doRestore — confirm against the call site.
func (s *snapshot) Restore(ret []uint64) {
s.ret = ret
panic(s)
}
// doRestore writes the captured state back into the call engine that produced the
// snapshot, then places the values passed to Restore on the stack starting at
// hostBase.
func (s *snapshot) doRestore() {
	engine := s.ce

	engine.returnAddress = uintptr(s.returnAddress)
	engine.stackContext.stackBasePointerInBytes = s.stackBasePointerInBytes
	engine.stackContext.stackPointer = s.stackPointer

	// Put the saved stack contents back first; the return values are written
	// afterwards so that they overlay the restored slots.
	copy(engine.stack, s.stack)
	resultSlots := engine.stack[s.hostBase:]
	copy(resultSlots, s.ret)
}
// Error implements the same method on error. Its message describes a snapshot whose
// Restore panic was not handled by any call engine.
func (s *snapshot) Error() string {
	const unhandled = "unhandled snapshot restore, this generally indicates restore was called from a different " +
		"exported function invocation than snapshot"
	return unhandled
}
// stackIterator implements experimental.StackIterator.
type stackIterator struct {
stack []uint64

View File

@@ -224,6 +224,53 @@ func functionFromUintptr(ptr uintptr) *function {
return *(**function)(unsafe.Pointer(wrapped))
}
// snapshot is the value returned by callEngine.Snapshot. It records the callEngine
// state that doRestore later puts back.
type snapshot struct {
// stack is a copy of the engine's value stack at capture time.
stack []uint64
// frames is a copy of the engine's frame slice at capture time. Only the slice is
// copied; the *callFrame values are shared with the engine.
frames []*callFrame
// pc is written to the pc of the last frame by doRestore.
// NOTE(review): the visible Snapshot method never sets this field, so it is zero
// unless assigned elsewhere — confirm this is intended.
pc uint64
// ret holds the values given to Restore; it is unset until Restore is called.
ret []uint64
// ce is the callEngine that produced this snapshot.
ce *callEngine
}
// Snapshot implements the same method as documented on experimental.Snapshotter.
//
// The value stack and the frame slice are copied so that later pushes and pops on
// the engine do not change the captured slices. The frames themselves are pointers
// and remain shared with the engine.
func (ce *callEngine) Snapshot() experimental.Snapshot {
	snap := &snapshot{
		ce:     ce,
		stack:  make([]uint64, len(ce.stack)),
		frames: make([]*callFrame, len(ce.frames)),
	}
	copy(snap.stack, ce.stack)
	copy(snap.frames, ce.frames)
	return snap
}
// Restore implements the same method as documented on experimental.Snapshot.
//
// It stores ret on the snapshot and then panics with the snapshot itself, so it never
// returns to its caller. The snapshot is only applied once doRestore is called on the
// recovered value.
// NOTE(review): presumably the engine that invoked the host function recovers the
// panic and calls doRestore — confirm against the call site.
func (s *snapshot) Restore(ret []uint64) {
s.ret = ret
panic(s)
}
// doRestore resets the call engine that produced the snapshot to the captured
// value stack and frames, sets the pc of the last frame to s.pc, and writes the
// values passed to Restore into the last len(s.ret) slots of the stack.
//
// The engine receives private copies of the captured slices rather than the
// snapshot's own slices. Otherwise the engine and the snapshot would share backing
// arrays: writing the return values (and any later in-place stack or frame update
// by the engine) would modify the snapshot, and restoring the same snapshot a
// second time would no longer reproduce the captured state.
func (s *snapshot) doRestore() {
	ce := s.ce

	stack := make([]uint64, len(s.stack))
	copy(stack, s.stack)
	ce.stack = stack

	// Only the slice is copied; the frames are pointers shared with the snapshot.
	frames := make([]*callFrame, len(s.frames))
	copy(frames, s.frames)
	ce.frames = frames

	ce.frames[len(ce.frames)-1].pc = s.pc

	// The return values occupy the top of the restored stack.
	copy(ce.stack[len(ce.stack)-len(s.ret):], s.ret)
}
// Error implements the same method on error. Its message describes a snapshot whose
// Restore panic was not handled by any call engine.
func (s *snapshot) Error() string {
	const unhandled = "unhandled snapshot restore, this generally indicates restore was called from a different " +
		"exported function invocation than snapshot"
	return unhandled
}
// stackIterator implements experimental.StackIterator.
type stackIterator struct {
stack []uint64
@@ -520,6 +567,10 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
}
}
if ctx.Value(experimental.EnableSnapshotterKey{}) != nil {
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce)
}
defer func() {
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
if err == nil {
@@ -563,6 +614,12 @@ type functionListenerInvocation struct {
// with the call frame stack traces. Also, reset the state of callEngine
// so that it can be used for the subsequent calls.
func (ce *callEngine) recoverOnCall(ctx context.Context, m *wasm.ModuleInstance, v interface{}) (err error) {
if s, ok := v.(*snapshot); ok {
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
// let it propagate up to be handled by the caller.
panic(s)
}
builder := wasmdebug.NewErrorBuilder()
frameCount := len(ce.frames)
functionListeners := make([]functionListenerInvocation, 0, 16)
@@ -677,7 +734,23 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance
ce.drop(op.Us[v+1])
frame.pc = op.Us[v]
case wazeroir.OperationKindCall:
func() {
if ctx.Value(experimental.EnableSnapshotterKey{}) != nil {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(*snapshot); ok && s.ce == ce {
s.doRestore()
frame = ce.frames[len(ce.frames)-1]
body = frame.f.parent.body
bodyLen = uint64(len(body))
} else {
panic(r)
}
}
}()
}
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
}()
frame.pc++
case wazeroir.OperationKindCallIndirect:
offset := ce.popValue()