Add experimental support for snapshot/restore (#1808)
Signed-off-by: Anuraag Agrawal <anuraaga@gmail.com>
This commit is contained in:
committed by
GitHub
parent
ebb5ea2eb7
commit
4185e533bb
25
experimental/checkpoint.go
Normal file
25
experimental/checkpoint.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package experimental
|
||||
|
||||
// Snapshot holds the execution state at the time of a Snapshotter.Snapshot call.
|
||||
type Snapshot interface {
|
||||
// Restore sets the Wasm execution state to the capture. Because a host function
|
||||
// calling this is resetting the pointer to the executation stack, the host function
|
||||
// will not be able to return values in the normal way. ret is a slice of values the
|
||||
// host function intends to return from the restored function.
|
||||
Restore(ret []uint64)
|
||||
}
|
||||
|
||||
// Snapshotter allows host functions to snapshot the WebAssembly execution environment.
|
||||
type Snapshotter interface {
|
||||
// Snapshot captures the current execution state.
|
||||
Snapshot() Snapshot
|
||||
}
|
||||
|
||||
// EnableSnapshotterKey is a context key to indicate that snapshotting should be enabled.
|
||||
// The context.Context passed to a exported function invocation should have this key set
|
||||
// to a non-nil value, and host functions will be able to retrieve it using SnapshotterKey.
|
||||
type EnableSnapshotterKey struct{}
|
||||
|
||||
// SnapshotterKey is a context key to access a Snapshotter from a host function.
|
||||
// It is only present if EnableSnapshotter was set in the function invocation context.
|
||||
type SnapshotterKey struct{}
|
||||
100
experimental/checkpoint_example_test.go
Normal file
100
experimental/checkpoint_example_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package experimental_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
)
|
||||
|
||||
// snapshotWasm was generated by the following:
|
||||
//
|
||||
// cd testdata; wat2wasm snapshot.wat
|
||||
//
|
||||
//go:embed testdata/snapshot.wasm
|
||||
var snapshotWasm []byte
|
||||
|
||||
type snapshotsKey struct{}
|
||||
|
||||
func Example_enableSnapshotterKey() {
|
||||
ctx := context.Background()
|
||||
|
||||
rt := wazero.NewRuntime(ctx)
|
||||
defer rt.Close(ctx) // This closes everything this Runtime created.
|
||||
|
||||
// Enable experimental snapshotting functionality by setting it to context. We use this
|
||||
// context when invoking functions, indicating to wazero to enable it.
|
||||
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})
|
||||
|
||||
// Also place a mutable holder of snapshots to be referenced during restore.
|
||||
var snapshots []experimental.Snapshot
|
||||
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
|
||||
|
||||
// Register host functions using snapshot and restore. Generally snapshot is saved
|
||||
// into a mutable location in context to be referenced during restore.
|
||||
_, err := rt.NewHostModuleBuilder("example").
|
||||
NewFunctionBuilder().
|
||||
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
|
||||
// Because we set EnableSnapshotterKey to context, this is non-nil.
|
||||
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
|
||||
|
||||
// Get our mutable snapshots holder to be able to add to it. Our example only calls snapshot
|
||||
// and restore once but real programs will often call them at multiple layers within a call
|
||||
// stack with various e.g., try/catch statements.
|
||||
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
|
||||
idx := len(*snapshots)
|
||||
*snapshots = append(*snapshots, snapshot)
|
||||
|
||||
// Write a value to be passed back to restore. This is meant to be opaque to the guest
|
||||
// and used to re-reference the snapshot.
|
||||
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
|
||||
if !ok {
|
||||
log.Panicln("failed to write snapshot index")
|
||||
}
|
||||
|
||||
return 0
|
||||
}).
|
||||
Export("snapshot").
|
||||
NewFunctionBuilder().
|
||||
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
|
||||
// Read the value written by snapshot to re-reference the snapshot.
|
||||
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
|
||||
if !ok {
|
||||
log.Panicln("failed to read snapshot index")
|
||||
}
|
||||
|
||||
// Get the snapshot
|
||||
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
|
||||
snapshot := (*snapshots)[idx]
|
||||
|
||||
// Restore! The invocation of this function will end as soon as we invoke
|
||||
// Restore, so we also pass in our return value. The guest function run
|
||||
// will finish with this return value.
|
||||
snapshot.Restore([]uint64{5})
|
||||
}).
|
||||
Export("restore").
|
||||
Instantiate(ctx)
|
||||
if err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
mod, err := rt.Instantiate(ctx, snapshotWasm) // Instantiate the actual code
|
||||
if err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
// Call the guest entrypoint.
|
||||
res, err := mod.ExportedFunction("run").Call(ctx)
|
||||
if err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
// We restored and returned the restore value, so it's our result. If restore
|
||||
// was instead a no-op, we would have returned 10 from normal code flow.
|
||||
fmt.Println(res[0])
|
||||
// Output:
|
||||
// 5
|
||||
}
|
||||
121
experimental/checkpoint_test.go
Normal file
121
experimental/checkpoint_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package experimental_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"github.com/tetratelabs/wazero/experimental"
|
||||
"github.com/tetratelabs/wazero/internal/testing/require"
|
||||
)
|
||||
|
||||
func TestSnapshotNestedWasmInvocation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
rt := wazero.NewRuntime(ctx)
|
||||
defer rt.Close(ctx)
|
||||
|
||||
sidechannel := 0
|
||||
|
||||
_, err := rt.NewHostModuleBuilder("example").
|
||||
NewFunctionBuilder().
|
||||
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
|
||||
defer func() {
|
||||
sidechannel = 10
|
||||
}()
|
||||
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
|
||||
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
|
||||
idx := len(*snapshots)
|
||||
*snapshots = append(*snapshots, snapshot)
|
||||
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
|
||||
require.True(t, ok)
|
||||
|
||||
_, err := mod.ExportedFunction("restore").Call(ctx, uint64(snapshotPtr))
|
||||
require.NoError(t, err)
|
||||
|
||||
return 2
|
||||
}).
|
||||
Export("snapshot").
|
||||
NewFunctionBuilder().
|
||||
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
|
||||
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
|
||||
require.True(t, ok)
|
||||
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
|
||||
snapshot := (*snapshots)[idx]
|
||||
|
||||
snapshot.Restore([]uint64{12})
|
||||
}).
|
||||
Export("restore").
|
||||
Instantiate(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
mod, err := rt.Instantiate(ctx, snapshotWasm)
|
||||
require.NoError(t, err)
|
||||
|
||||
var snapshots []experimental.Snapshot
|
||||
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
|
||||
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})
|
||||
|
||||
snapshotPtr := uint64(0)
|
||||
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
|
||||
require.NoError(t, err)
|
||||
// return value from restore
|
||||
require.Equal(t, uint64(12), res[0])
|
||||
// Host function defers within the call stack work fine
|
||||
require.Equal(t, 10, sidechannel)
|
||||
}
|
||||
|
||||
func TestSnapshotMultipleWasmInvocations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
rt := wazero.NewRuntime(ctx)
|
||||
defer rt.Close(ctx)
|
||||
|
||||
_, err := rt.NewHostModuleBuilder("example").
|
||||
NewFunctionBuilder().
|
||||
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) int32 {
|
||||
snapshot := ctx.Value(experimental.SnapshotterKey{}).(experimental.Snapshotter).Snapshot()
|
||||
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
|
||||
idx := len(*snapshots)
|
||||
*snapshots = append(*snapshots, snapshot)
|
||||
ok := mod.Memory().WriteUint32Le(snapshotPtr, uint32(idx))
|
||||
require.True(t, ok)
|
||||
|
||||
return 0
|
||||
}).
|
||||
Export("snapshot").
|
||||
NewFunctionBuilder().
|
||||
WithFunc(func(ctx context.Context, mod api.Module, snapshotPtr uint32) {
|
||||
idx, ok := mod.Memory().ReadUint32Le(snapshotPtr)
|
||||
require.True(t, ok)
|
||||
snapshots := ctx.Value(snapshotsKey{}).(*[]experimental.Snapshot)
|
||||
snapshot := (*snapshots)[idx]
|
||||
|
||||
snapshot.Restore([]uint64{12})
|
||||
}).
|
||||
Export("restore").
|
||||
Instantiate(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
mod, err := rt.Instantiate(ctx, snapshotWasm)
|
||||
require.NoError(t, err)
|
||||
|
||||
var snapshots []experimental.Snapshot
|
||||
ctx = context.WithValue(ctx, snapshotsKey{}, &snapshots)
|
||||
ctx = context.WithValue(ctx, experimental.EnableSnapshotterKey{}, struct{}{})
|
||||
|
||||
snapshotPtr := uint64(0)
|
||||
res, err := mod.ExportedFunction("snapshot").Call(ctx, snapshotPtr)
|
||||
require.NoError(t, err)
|
||||
// snapshot returned zero
|
||||
require.Equal(t, uint64(0), res[0])
|
||||
|
||||
// Fails, snapshot and restore are called from different wasm invocations. Currently, this
|
||||
// results in a panic.
|
||||
err = require.CapturePanic(func() {
|
||||
_, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr)
|
||||
})
|
||||
require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+
|
||||
"exported function invocation than snapshot")
|
||||
}
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
// listenerWasm was generated by the following:
|
||||
//
|
||||
// cd testdata; wat2wasm --debug-names listener.wat
|
||||
// cd logging/testdata; wat2wasm --debug-names listener.wat
|
||||
//
|
||||
//go:embed logging/testdata/listener.wasm
|
||||
var listenerWasm []byte
|
||||
|
||||
BIN
experimental/testdata/snapshot.wasm
vendored
Normal file
BIN
experimental/testdata/snapshot.wasm
vendored
Normal file
Binary file not shown.
34
experimental/testdata/snapshot.wat
vendored
Normal file
34
experimental/testdata/snapshot.wat
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
(module
|
||||
(import "example" "snapshot" (func $snapshot (param i32) (result i32)))
|
||||
(import "example" "restore" (func $restore (param i32)))
|
||||
|
||||
(func $helper (result i32)
|
||||
(call $restore (i32.const 0))
|
||||
;; Not executed
|
||||
i32.const 10
|
||||
)
|
||||
|
||||
(func (export "run") (result i32) (local i32)
|
||||
(call $snapshot (i32.const 0))
|
||||
local.set 0
|
||||
local.get 0
|
||||
(if (result i32)
|
||||
(then ;; restore return, finish with the value returned by it
|
||||
local.get 0
|
||||
)
|
||||
(else ;; snapshot return, call heloer
|
||||
(call $helper)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
(func (export "snapshot") (param i32) (result i32)
|
||||
(call $snapshot (local.get 0))
|
||||
)
|
||||
|
||||
(func (export "restore") (param i32)
|
||||
(call $restore (local.get 0))
|
||||
)
|
||||
|
||||
(memory (export "memory") 1 1)
|
||||
)
|
||||
@@ -787,7 +787,12 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
|
||||
defer done()
|
||||
}
|
||||
|
||||
ce.execWasmFunction(ctx, m)
|
||||
snapshotEnabled := ctx.Value(experimental.EnableSnapshotterKey{}) != nil
|
||||
if snapshotEnabled {
|
||||
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce)
|
||||
}
|
||||
|
||||
ce.execWasmFunction(ctx, m, snapshotEnabled)
|
||||
|
||||
// This returns a safe copy of the results, instead of a slice view. If we
|
||||
// returned a re-slice, the caller could accidentally or purposefully
|
||||
@@ -853,6 +858,11 @@ func callFrameOffset(funcType *wasm.FunctionType) (ret int) {
|
||||
//
|
||||
// This is defined for testability.
|
||||
func (ce *callEngine) deferredOnCall(ctx context.Context, m *wasm.ModuleInstance, recovered interface{}) (err error) {
|
||||
if s, ok := recovered.(*snapshot); ok {
|
||||
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
|
||||
// let it propagate up to be handled by the caller.
|
||||
panic(s)
|
||||
}
|
||||
if recovered != nil {
|
||||
builder := wasmdebug.NewErrorBuilder()
|
||||
|
||||
@@ -1039,7 +1049,7 @@ const (
|
||||
builtinFunctionMemoryNotify
|
||||
)
|
||||
|
||||
func (ce *callEngine) execWasmFunction(ctx context.Context, m *wasm.ModuleInstance) {
|
||||
func (ce *callEngine) execWasmFunction(ctx context.Context, m *wasm.ModuleInstance, snapshotEnabled bool) {
|
||||
codeAddr := ce.initialFn.codeInitialAddress
|
||||
modAddr := ce.initialFn.moduleInstance
|
||||
|
||||
@@ -1065,12 +1075,25 @@ entry:
|
||||
stack := ce.stack[base : base+stackLen]
|
||||
|
||||
fn := calleeHostFunction.parent.goFunc
|
||||
switch fn := fn.(type) {
|
||||
case api.GoModuleFunction:
|
||||
fn.Call(ctx, ce.callerModuleInstance, stack)
|
||||
case api.GoFunction:
|
||||
fn.Call(ctx, stack)
|
||||
}
|
||||
func() {
|
||||
if snapshotEnabled {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if s, ok := r.(*snapshot); ok && s.ce == ce {
|
||||
s.doRestore()
|
||||
} else {
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
switch fn := fn.(type) {
|
||||
case api.GoModuleFunction:
|
||||
fn.Call(ctx, ce.callerModuleInstance, stack)
|
||||
case api.GoFunction:
|
||||
fn.Call(ctx, stack)
|
||||
}
|
||||
}()
|
||||
|
||||
codeAddr, modAddr = ce.returnAddress, ce.moduleInstance
|
||||
goto entry
|
||||
@@ -1217,6 +1240,58 @@ func (ce *callEngine) builtinFunctionMemoryNotify(mem *wasm.MemoryInstance) {
|
||||
ce.pushValue(uint64(mem.Notify(offset, uint32(count))))
|
||||
}
|
||||
|
||||
// snapshot implements experimental.Snapshot
|
||||
type snapshot struct {
|
||||
stackPointer uint64
|
||||
stackBasePointerInBytes uint64
|
||||
returnAddress uint64
|
||||
hostBase int
|
||||
stack []uint64
|
||||
|
||||
ret []uint64
|
||||
|
||||
ce *callEngine
|
||||
}
|
||||
|
||||
// Snapshot implements the same method as documented on experimental.Snapshotter.
|
||||
func (ce *callEngine) Snapshot() experimental.Snapshot {
|
||||
hostBase := int(ce.stackBasePointerInBytes >> 3)
|
||||
|
||||
stackTop := int(ce.stackTopIndex())
|
||||
stack := make([]uint64, stackTop)
|
||||
copy(stack, ce.stack[:stackTop])
|
||||
|
||||
return &snapshot{
|
||||
stackPointer: ce.stackContext.stackPointer,
|
||||
stackBasePointerInBytes: ce.stackBasePointerInBytes,
|
||||
returnAddress: uint64(ce.returnAddress),
|
||||
hostBase: hostBase,
|
||||
stack: stack,
|
||||
ce: ce,
|
||||
}
|
||||
}
|
||||
|
||||
// Restore implements the same method as documented on experimental.Snapshot.
|
||||
func (s *snapshot) Restore(ret []uint64) {
|
||||
s.ret = ret
|
||||
panic(s)
|
||||
}
|
||||
|
||||
func (s *snapshot) doRestore() {
|
||||
ce := s.ce
|
||||
ce.stackContext.stackPointer = s.stackPointer
|
||||
ce.stackContext.stackBasePointerInBytes = s.stackBasePointerInBytes
|
||||
copy(ce.stack, s.stack)
|
||||
ce.returnAddress = uintptr(s.returnAddress)
|
||||
copy(ce.stack[s.hostBase:], s.ret)
|
||||
}
|
||||
|
||||
// Error implements the same method on error.
|
||||
func (s *snapshot) Error() string {
|
||||
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
|
||||
"exported function invocation than snapshot"
|
||||
}
|
||||
|
||||
// stackIterator implements experimental.StackIterator.
|
||||
type stackIterator struct {
|
||||
stack []uint64
|
||||
|
||||
@@ -224,6 +224,53 @@ func functionFromUintptr(ptr uintptr) *function {
|
||||
return *(**function)(unsafe.Pointer(wrapped))
|
||||
}
|
||||
|
||||
type snapshot struct {
|
||||
stack []uint64
|
||||
frames []*callFrame
|
||||
pc uint64
|
||||
|
||||
ret []uint64
|
||||
|
||||
ce *callEngine
|
||||
}
|
||||
|
||||
// Snapshot implements the same method as documented on experimental.Snapshotter.
|
||||
func (ce *callEngine) Snapshot() experimental.Snapshot {
|
||||
stack := make([]uint64, len(ce.stack))
|
||||
copy(stack, ce.stack)
|
||||
|
||||
frames := make([]*callFrame, len(ce.frames))
|
||||
copy(frames, ce.frames)
|
||||
|
||||
return &snapshot{
|
||||
stack: stack,
|
||||
frames: frames,
|
||||
ce: ce,
|
||||
}
|
||||
}
|
||||
|
||||
// Restore implements the same method as documented on experimental.Snapshot.
|
||||
func (s *snapshot) Restore(ret []uint64) {
|
||||
s.ret = ret
|
||||
panic(s)
|
||||
}
|
||||
|
||||
func (s *snapshot) doRestore() {
|
||||
ce := s.ce
|
||||
|
||||
ce.stack = s.stack
|
||||
ce.frames = s.frames
|
||||
ce.frames[len(ce.frames)-1].pc = s.pc
|
||||
|
||||
copy(ce.stack[len(ce.stack)-len(s.ret):], s.ret)
|
||||
}
|
||||
|
||||
// Error implements the same method on error.
|
||||
func (s *snapshot) Error() string {
|
||||
return "unhandled snapshot restore, this generally indicates restore was called from a different " +
|
||||
"exported function invocation than snapshot"
|
||||
}
|
||||
|
||||
// stackIterator implements experimental.StackIterator.
|
||||
type stackIterator struct {
|
||||
stack []uint64
|
||||
@@ -520,6 +567,10 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Value(experimental.EnableSnapshotterKey{}) != nil {
|
||||
ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
|
||||
if err == nil {
|
||||
@@ -563,6 +614,12 @@ type functionListenerInvocation struct {
|
||||
// with the call frame stack traces. Also, reset the state of callEngine
|
||||
// so that it can be used for the subsequent calls.
|
||||
func (ce *callEngine) recoverOnCall(ctx context.Context, m *wasm.ModuleInstance, v interface{}) (err error) {
|
||||
if s, ok := v.(*snapshot); ok {
|
||||
// A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation,
|
||||
// let it propagate up to be handled by the caller.
|
||||
panic(s)
|
||||
}
|
||||
|
||||
builder := wasmdebug.NewErrorBuilder()
|
||||
frameCount := len(ce.frames)
|
||||
functionListeners := make([]functionListenerInvocation, 0, 16)
|
||||
@@ -677,7 +734,23 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance
|
||||
ce.drop(op.Us[v+1])
|
||||
frame.pc = op.Us[v]
|
||||
case wazeroir.OperationKindCall:
|
||||
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
|
||||
func() {
|
||||
if ctx.Value(experimental.EnableSnapshotterKey{}) != nil {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if s, ok := r.(*snapshot); ok && s.ce == ce {
|
||||
s.doRestore()
|
||||
frame = ce.frames[len(ce.frames)-1]
|
||||
body = frame.f.parent.body
|
||||
bodyLen = uint64(len(body))
|
||||
} else {
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
ce.callFunction(ctx, f.moduleInstance, &functions[op.U1])
|
||||
}()
|
||||
frame.pc++
|
||||
case wazeroir.OperationKindCallIndirect:
|
||||
offset := ce.popValue()
|
||||
|
||||
Reference in New Issue
Block a user