Externalize compilation cache by compilers (#747)

This adds the experimental support of the file system compilation cache.
Notably, experimental.WithCompilationCacheDirName allows users to configure
where the compiler writes the cache into.

Versioning/validation of binary compatibility has been done via the release tag
(which will be created from the end of this month). More specifically, the cache
file starts with a header with the hardcoded wazero version.


Fixes #618

Signed-off-by: Takeshi Yoneda <takeshi@tetrate.io>
Co-authored-by: Crypt Keeper <64215+codefromthecrypt@users.noreply.github.com>
This commit is contained in:
Takeshi Yoneda
2022-08-18 19:37:11 +09:00
committed by GitHub
parent 076d3245e3
commit 3b32c2028b
20 changed files with 1015 additions and 89 deletions

View File

@@ -0,0 +1,31 @@
package experimental
import (
"context"
"github.com/tetratelabs/wazero/internal/compilationcache"
)
// WithCompilationCacheDirName configures the destination directory of the compilation cache.
// Regardless of the usage of this, the compiled functions are cached in memory, but its lifetime is
// bound to the lifetime of wazero.Runtime or wazero.CompiledModule.
//
// With the given non-empty directory, wazero persists the cache into the directory and that cache
// will be used as long as the running wazero version match the version of compilation wazero.
//
// A cache is only valid for use in one wazero.Runtime at a time. Concurrent use
// of a wazero.Runtime is supported, but multiple runtimes must not share the
// same directory.
//
// Note: The embedder must safeguard this directory from external changes.
//
// Usage:
//
// ctx := experimental.WithCompilationCacheDirName(context.Background(), "/home/me/.cache/wazero")
// r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler())
func WithCompilationCacheDirName(ctx context.Context, dirname string) context.Context {
if len(dirname) != 0 {
ctx = context.WithValue(ctx, compilationcache.FileCachePathKey{}, dirname)
}
return ctx
}

View File

@@ -0,0 +1,42 @@
package compilationcache
import (
"crypto/sha256"
"io"
)
// Cache allows the compiler engine to skip compilation of wasm to machine code
// where doing so is redundant for the same wasm binary and version of wazero.
//
// This augments the default in-memory cache of compiled functions, by
// decoupling it from a wazero.Runtime instance. Concretely, a runtime loses
// its cache once closed. This cache allows the runtime to rebuild its
// in-memory cache quicker, significantly reducing first-hit penalty on a hit.
//
// See NewFileCache for the example implementation.
type Cache interface {
// Get is called when the runtime is trying to get the cached compiled functions.
// Implementations are supposed to return compiled function in io.Reader with ok=true
// if the key exists on the cache. In the case of not-found, this should return
// ok=false with err=nil. content.Close() is automatically called by
// the caller of this Get.
//
// Note: the returned content won't go through the validation pass of Wasm binary
// which is applied when the binary is compiled from scratch without cache hit.
Get(key Key) (content io.ReadCloser, ok bool, err error)
//
// Add is called when the runtime is trying to add the new cache entry.
// The given `content` must be un-modified, and returned as-is in Get method.
//
// Note: the `content` is ensured to be safe through the validation phase applied on the Wasm binary.
Add(key Key, content io.Reader) (err error)
//
// Delete is called when the cache on the `key` returned by Get is no longer usable, and
// must be purged. Specifically, this is called happens when the wazero's version has been changed.
// For example, that is when there's a difference between the version of compiling wazero and the
// version of the currently used wazero.
Delete(key Key) (err error)
}
// Key represents the 256-bit unique identifier assigned to each cache entry.
type Key = [sha256.Size]byte

View File

@@ -0,0 +1,99 @@
package compilationcache
import (
"context"
"encoding/hex"
"errors"
"io"
"os"
"path"
"sync"
)
// FileCachePathKey is a context.Context Value key. Its value is a string
// representing the compilation cache directory.
type FileCachePathKey struct{}
// NewFileCache returns a new Cache implemented by fileCache.
func NewFileCache(ctx context.Context) Cache {
if fsValue := ctx.Value(FileCachePathKey{}); fsValue != nil {
return newFileCache(fsValue.(string))
}
return nil
}
func newFileCache(dir string) *fileCache {
return &fileCache{dirPath: dir}
}
// fileCache persists compiled functions into dirPath.
//
// Note: this can be expanded to do binary signing/verification, set TTL on each entry, etc.
type fileCache struct {
dirPath string
mux sync.RWMutex
}
type fileReadCloser struct {
*os.File
fc *fileCache
}
func (fc *fileCache) path(key Key) string {
return path.Join(fc.dirPath, hex.EncodeToString(key[:]))
}
func (fc *fileCache) Get(key Key) (content io.ReadCloser, ok bool, err error) {
// TODO: take lock per key for more efficiency vs the complexity of impl.
fc.mux.RLock()
unlock := fc.mux.RUnlock
defer func() {
if unlock != nil {
unlock()
}
}()
f, err := os.Open(fc.path(key))
if errors.Is(err, os.ErrNotExist) {
return nil, false, nil
} else if err != nil {
return nil, false, err
} else {
// Unlock is done inside the content.Close() at the call site.
unlock = nil
return &fileReadCloser{File: f, fc: fc}, true, nil
}
}
// Close wraps the os.File Close to release the read lock on fileCache.
func (f *fileReadCloser) Close() (err error) {
defer f.fc.mux.RUnlock()
err = f.File.Close()
return
}
func (fc *fileCache) Add(key Key, content io.Reader) (err error) {
// TODO: take lock per key for more efficiency vs the complexity of impl.
fc.mux.Lock()
defer fc.mux.Unlock()
file, err := os.Create(fc.path(key))
if err != nil {
return
}
defer file.Close()
_, err = io.Copy(file, content)
return
}
func (fc *fileCache) Delete(key Key) (err error) {
// TODO: take lock per key for more efficiency vs the complexity of impl.
fc.mux.Lock()
defer fc.mux.Unlock()
err = os.Remove(fc.path(key))
if errors.Is(err, os.ErrNotExist) {
err = nil
}
return
}

View File

@@ -0,0 +1,135 @@
package compilationcache
import (
"bytes"
"io"
"os"
"testing"
"github.com/tetratelabs/wazero/internal/testing/require"
)
func TestFileReadCloser_Close(t *testing.T) {
fc := newFileCache(t.TempDir())
key := Key{1, 2, 3}
err := fc.Add(key, bytes.NewReader([]byte{1, 2, 3, 4}))
require.NoError(t, err)
c, ok, err := fc.Get(key)
require.NoError(t, err)
require.True(t, ok)
// At this point, file is not closed, therefore TryLock should fail.
require.False(t, fc.mux.TryLock())
// Close, and then TryLock should succeed this time.
require.NoError(t, c.Close())
require.True(t, fc.mux.TryLock())
}
func TestFileCache_Add(t *testing.T) {
fc := newFileCache(t.TempDir())
t.Run("not exist", func(t *testing.T) {
content := []byte{1, 2, 3, 4, 5}
id := Key{1, 2, 3, 4, 5, 6, 7}
err := fc.Add(id, bytes.NewReader(content))
require.NoError(t, err)
// Ensures that file exists.
cached, err := os.ReadFile(fc.path(id))
require.NoError(t, err)
// Check if the saved content is the same as the given one.
require.Equal(t, content, cached)
})
t.Run("already exists", func(t *testing.T) {
content := []byte{1, 2, 3, 4, 5}
id := Key{1, 2, 3}
// Writes the pre-existing file for the same ID.
p := fc.path(id)
f, err := os.Create(p)
require.NoError(t, err)
_, err = f.Write(content)
require.NoError(t, err)
require.NoError(t, f.Close())
err = fc.Add(id, bytes.NewReader(content))
require.NoError(t, err)
// Ensures that file exists.
cached, err := os.ReadFile(fc.path(id))
require.NoError(t, err)
// Check if the saved content is the same as the given one.
require.Equal(t, content, cached)
})
}
func TestFileCache_Delete(t *testing.T) {
fc := newFileCache(t.TempDir())
t.Run("non-exist", func(t *testing.T) {
id := Key{0}
err := fc.Delete(id)
require.NoError(t, err)
})
t.Run("exist", func(t *testing.T) {
id := Key{1, 2, 3}
p := fc.path(id)
f, err := os.Create(p)
require.NoError(t, err)
require.NoError(t, f.Close())
// Ensures that file exists now.
f, err = os.Open(p)
require.NoError(t, err)
require.NoError(t, f.Close())
// Delete the cache.
err = fc.Delete(id)
require.NoError(t, err)
// Ensures that file no longer exists.
_, err = os.Open(p)
require.ErrorIs(t, err, os.ErrNotExist)
})
}
func TestFileCache_Get(t *testing.T) {
fc := newFileCache(t.TempDir())
t.Run("exist", func(t *testing.T) {
content := []byte{1, 2, 3, 4, 5}
id := Key{1, 2, 3}
// Writes the pre-existing file for the ID.
p := fc.path(id)
f, err := os.Create(p)
require.NoError(t, err)
_, err = f.Write(content)
require.NoError(t, err)
require.NoError(t, f.Close())
result, ok, err := fc.Get(id)
require.NoError(t, err)
require.True(t, ok)
defer func() {
require.NoError(t, result.Close())
}()
actual, err := io.ReadAll(result)
require.NoError(t, err)
require.Equal(t, content, actual)
})
t.Run("not exist", func(t *testing.T) {
_, ok, err := fc.Get(Key{0xf})
// Non-exist should not be error.
require.NoError(t, err)
require.False(t, ok)
})
}

View File

@@ -10,7 +10,9 @@ import (
"unsafe"
"github.com/tetratelabs/wazero/internal/buildoptions"
"github.com/tetratelabs/wazero/internal/compilationcache"
"github.com/tetratelabs/wazero/internal/platform"
"github.com/tetratelabs/wazero/internal/version"
"github.com/tetratelabs/wazero/internal/wasm"
"github.com/tetratelabs/wazero/internal/wasmdebug"
"github.com/tetratelabs/wazero/internal/wasmruntime"
@@ -22,9 +24,11 @@ type (
engine struct {
enabledFeatures wasm.Features
codes map[wasm.ModuleID][]*code // guarded by mutex.
Cache compilationcache.Cache
mux sync.RWMutex
// setFinalizer defaults to runtime.SetFinalizer, but overridable for tests.
setFinalizer func(obj interface{}, finalizer interface{})
wazeroVersion string
}
// moduleEngine implements wasm.ModuleEngine
@@ -411,8 +415,10 @@ func (e *engine) DeleteCompiledModule(module *wasm.Module) {
// CompileModule implements the same method as documented on wasm.Engine.
func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error {
if _, ok := e.getCodes(module); ok { // cache hit!
if _, ok, err := e.getCodes(module); ok { // cache hit!
return nil
} else if err != nil {
return err
}
funcs := make([]*code, 0, len(module.FunctionSection))
@@ -441,8 +447,7 @@ func (e *engine) CompileModule(ctx context.Context, module *wasm.Module) error {
funcs = append(funcs, compiled)
}
e.addCodes(module, funcs)
return nil
return e.addCodes(module, funcs)
}
// NewModuleEngine implements the same method as documented on wasm.Engine.
@@ -459,9 +464,11 @@ func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunct
me.functions = append(me.functions, cf)
}
codes, ok := e.getCodes(module)
codes, ok, err := e.getCodes(module)
if !ok {
return nil, fmt.Errorf("source module for %s must be compiled before instantiation", name)
} else if err != nil {
return nil, err
}
for i, c := range codes {
@@ -485,25 +492,6 @@ func (e *engine) NewModuleEngine(name string, module *wasm.Module, importedFunct
return me, nil
}
func (e *engine) deleteCodes(module *wasm.Module) {
e.mux.Lock()
defer e.mux.Unlock()
delete(e.codes, module.ID)
}
func (e *engine) addCodes(module *wasm.Module, fs []*code) {
e.mux.Lock()
defer e.mux.Unlock()
e.codes[module.ID] = fs
}
func (e *engine) getCodes(module *wasm.Module) (fs []*code, ok bool) {
e.mux.RLock()
defer e.mux.RUnlock()
fs, ok = e.codes[module.ID]
return
}
// Name implements the same method as documented on wasm.ModuleEngine.
func (e *moduleEngine) Name() string {
return e.name
@@ -594,11 +582,17 @@ func NewEngine(ctx context.Context, enabledFeatures wasm.Features) wasm.Engine {
return newEngine(ctx, enabledFeatures)
}
func newEngine(_ context.Context, enabledFeatures wasm.Features) *engine {
func newEngine(ctx context.Context, enabledFeatures wasm.Features) *engine {
var wazeroVersion string
if v := ctx.Value(version.WazeroVersionKey{}); v != nil {
wazeroVersion = v.(string)
}
return &engine{
enabledFeatures: enabledFeatures,
codes: map[wasm.ModuleID][]*code{},
setFinalizer: runtime.SetFinalizer,
Cache: compilationcache.NewFileCache(ctx),
wazeroVersion: wazeroVersion,
}
}

View File

@@ -0,0 +1,188 @@
package compiler
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/tetratelabs/wazero/internal/platform"
"github.com/tetratelabs/wazero/internal/u32"
"github.com/tetratelabs/wazero/internal/u64"
"github.com/tetratelabs/wazero/internal/wasm"
)
func (e *engine) deleteCodes(module *wasm.Module) {
e.mux.Lock()
defer e.mux.Unlock()
delete(e.codes, module.ID)
// Note: we do not call e.Cache.Delete, as the lifetime of
// the content is up to the implementation of extencache.Cache interface.
}
func (e *engine) addCodes(module *wasm.Module, codes []*code) (err error) {
e.addCodesToMemory(module, codes)
err = e.addCodesToCache(module, codes)
return
}
func (e *engine) getCodes(module *wasm.Module) (codes []*code, ok bool, err error) {
codes, ok = e.getCodesFromMemory(module)
if ok {
return
}
codes, ok, err = e.getCodesFromCache(module)
if ok {
e.addCodesToMemory(module, codes)
}
return
}
func (e *engine) addCodesToMemory(module *wasm.Module, codes []*code) {
e.mux.Lock()
defer e.mux.Unlock()
e.codes[module.ID] = codes
}
func (e *engine) getCodesFromMemory(module *wasm.Module) (codes []*code, ok bool) {
e.mux.RLock()
defer e.mux.RUnlock()
codes, ok = e.codes[module.ID]
return
}
func (e *engine) addCodesToCache(module *wasm.Module, codes []*code) (err error) {
if e.Cache == nil {
return
}
err = e.Cache.Add(module.ID, serializeCodes(e.wazeroVersion, codes))
return
}
func (e *engine) getCodesFromCache(module *wasm.Module) (codes []*code, hit bool, err error) {
if e.Cache == nil {
return
}
// Check if the entries exist in the external cache.
var cached io.ReadCloser
cached, hit, err = e.Cache.Get(module.ID)
if !hit || err != nil {
return
}
defer cached.Close()
// Otherwise, we hit the cache on external cache.
// We retrieve *code structures from `cached`.
var staleCache bool
codes, staleCache, err = deserializeCodes(e.wazeroVersion, cached)
if err != nil {
hit = false
return
} else if staleCache {
return nil, false, e.Cache.Delete(module.ID)
}
for i, c := range codes {
c.indexInModule = wasm.Index(i)
c.sourceModule = module
}
return
}
var (
wazeroMagic = "WAZERO"
// version must be synced with the tag of the wazero library.
)
func serializeCodes(wazeroVersion string, codes []*code) io.Reader {
buf := bytes.NewBuffer(nil)
// First 6 byte: WAZERO header.
buf.WriteString(wazeroMagic)
// Next 1 byte: length of version:
buf.WriteByte(byte(len(wazeroVersion)))
// Version of wazero.
buf.WriteString(wazeroVersion)
// Number of *code (== locally defined functions in the module): 4 bytes.
buf.Write(u32.LeBytes(uint32(len(codes))))
for _, c := range codes {
// The stack pointer ceil (8 bytes).
buf.Write(u64.LeBytes(c.stackPointerCeil))
// The length of code segment (8 bytes).
buf.Write(u64.LeBytes(uint64(len(c.codeSegment))))
// Append the native code.
buf.Write(c.codeSegment)
}
return bytes.NewReader(buf.Bytes())
}
func deserializeCodes(wazeroVersion string, reader io.Reader) (codes []*code, staleCache bool, err error) {
cacheHeaderSize := len(wazeroMagic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */
// Read the header before the native code.
header := make([]byte, cacheHeaderSize)
n, err := reader.Read(header)
if err != nil {
return nil, false, err
}
if n != cacheHeaderSize {
return nil, false, fmt.Errorf("invalid header length: %d", n)
}
// Check the version compatibility.
versionSize := int(header[len(wazeroMagic)])
cachedVersionBegin, cachedVersionEnd := len(wazeroMagic)+1, len(wazeroMagic)+1+versionSize
if cachedVersionEnd >= len(header) {
staleCache = true
return
} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
staleCache = true
return
}
functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
codes = make([]*code, 0, functionsNum)
var eightBytes [8]byte
for i := uint32(0); i < functionsNum; i++ {
c := &code{}
// Read the stack pointer ceil.
_, err = reader.Read(eightBytes[:])
if err != nil {
err = fmt.Errorf("reading stack pointer ceil: %v", err)
break
}
c.stackPointerCeil = binary.LittleEndian.Uint64(eightBytes[:])
// Read (and mmap) the native code.
_, err = reader.Read(eightBytes[:])
if err != nil {
err = fmt.Errorf("reading native code size: %v", err)
break
}
c.codeSegment, err = platform.MmapCodeSegment(reader, int(binary.LittleEndian.Uint64(eightBytes[:])))
if err != nil {
err = fmt.Errorf("mmaping function: %v", err)
break
}
codes = append(codes, c)
}
if err != nil {
for _, c := range codes {
if errMunmap := platform.MunmapCodeSegment(c.codeSegment); errMunmap != nil {
// Munmap failure shouldn't happen.
panic(errMunmap)
}
}
codes = nil
}
return
}

View File

@@ -0,0 +1,370 @@
package compiler
import (
"bytes"
"fmt"
"io"
"testing"
"github.com/tetratelabs/wazero/internal/testing/require"
"github.com/tetratelabs/wazero/internal/u32"
"github.com/tetratelabs/wazero/internal/u64"
"github.com/tetratelabs/wazero/internal/wasm"
)
var testVersion string
func concat(ins ...[]byte) (ret []byte) {
for _, in := range ins {
ret = append(ret, in...)
}
return
}
func TestSerializeCodes(t *testing.T) {
tests := []struct {
in []*code
exp []byte
}{
{
in: []*code{{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}}},
exp: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
),
},
{
in: []*code{
{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}},
{stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}},
},
exp: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(2), // number of functions.
// Function index = 0.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
// Function index = 1.
u64.LeBytes(0xffffffff), // stack pointer ceil.
u64.LeBytes(3), // length of code.
[]byte{1, 2, 3}, // code.
),
},
}
for i, tc := range tests {
actual, err := io.ReadAll(serializeCodes(testVersion, tc.in))
require.NoError(t, err, i)
require.Equal(t, tc.exp, actual, i)
}
}
func TestDeserializeCodes(t *testing.T) {
tests := []struct {
name string
in []byte
expCodes []*code
expStaleCache bool
expErr string
}{
{
name: "invalid header",
in: []byte{1},
expErr: "invalid header length: 1",
},
{
name: "version mismatch",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len("1233123.1.1"))},
[]byte("1233123.1.1"),
u32.LeBytes(1), // number of functions.
),
expStaleCache: true,
},
{
name: "version mismatch",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len("1"))},
[]byte("1"),
u32.LeBytes(1), // number of functions.
),
expStaleCache: true,
},
{
name: "one function",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
),
expCodes: []*code{
{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}},
},
expStaleCache: false,
expErr: "",
},
{
name: "two functions",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(2), // number of functions.
// Function index = 0.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
// Function index = 1.
u64.LeBytes(0xffffffff), // stack pointer ceil.
u64.LeBytes(3), // length of code.
[]byte{1, 2, 3}, // code.
),
expCodes: []*code{
{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}},
{stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}},
},
expStaleCache: false,
expErr: "",
},
{
name: "reading stack pointer",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(2), // number of functions.
// Function index = 0.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
// Function index = 1.
),
expErr: "reading stack pointer ceil: EOF",
},
{
name: "reading native code size",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(2), // number of functions.
// Function index = 0.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
// Function index = 1.
u64.LeBytes(12345), // stack pointer ceil.
),
expErr: "reading native code size: EOF",
},
{
name: "mmapping",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(2), // number of functions.
// Function index = 0.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
// Function index = 1.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
// Lack of code here.
),
expErr: "mmaping function: EOF",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
codes, staleCache, err := deserializeCodes(testVersion, bytes.NewReader(tc.in))
if tc.expErr != "" {
require.EqualError(t, err, tc.expErr)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.expCodes, codes)
require.Equal(t, tc.expStaleCache, staleCache)
})
}
}
func TestEngine_getCodesFromCache(t *testing.T) {
tests := []struct {
name string
ext *testCache
key wasm.ModuleID
expCodes []*code
expHit bool
expErr string
expDeleted bool
}{
{name: "extern cache not given"},
{
name: "not hit",
ext: &testCache{caches: map[wasm.ModuleID][]byte{}},
},
{
name: "error in Cache.Get",
ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: {}}},
expErr: "some error from extern cache",
},
{
name: "error in deserialization",
ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: {1, 2, 3}}},
expErr: "invalid header length: 3",
},
{
name: "stale cache",
ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: concat(
[]byte(wazeroMagic),
[]byte{byte(len("1233123.1.1"))},
[]byte("1233123.1.1"),
u32.LeBytes(1), // number of functions.
)}},
expDeleted: true,
},
{
name: "hit",
ext: &testCache{caches: map[wasm.ModuleID][]byte{
{}: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(2), // number of functions.
// Function index = 0.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
// Function index = 1.
u64.LeBytes(0xffffffff), // stack pointer ceil.
u64.LeBytes(3), // length of code.
[]byte{1, 2, 3}, // code.
),
}},
expHit: true,
expCodes: []*code{
{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}, indexInModule: 0},
{stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}, indexInModule: 1},
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
m := &wasm.Module{ID: tc.key}
for _, expC := range tc.expCodes {
expC.sourceModule = m
}
e := engine{}
if tc.ext != nil {
e.Cache = tc.ext
}
codes, hit, err := e.getCodesFromCache(m)
if tc.expErr != "" {
require.EqualError(t, err, tc.expErr)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.expHit, hit)
require.Equal(t, tc.expCodes, codes)
if tc.expDeleted {
require.Equal(t, tc.ext.deleted, tc.key)
}
})
}
}
func TestEngine_addCodesToCache(t *testing.T) {
t.Run("not defined", func(t *testing.T) {
e := engine{}
err := e.addCodesToCache(nil, nil)
require.NoError(t, err)
})
t.Run("add", func(t *testing.T) {
ext := &testCache{caches: map[wasm.ModuleID][]byte{}}
e := engine{Cache: ext}
m := &wasm.Module{}
codes := []*code{{stackPointerCeil: 123, codeSegment: []byte{1, 2, 3}}}
err := e.addCodesToCache(m, codes)
require.NoError(t, err)
content, ok := ext.caches[m.ID]
require.True(t, ok)
require.Equal(t, concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
u32.LeBytes(1), // number of functions.
u64.LeBytes(123), // stack pointer ceil.
u64.LeBytes(3), // length of code.
[]byte{1, 2, 3}, // code.
), content)
})
}
// testCache implements compilationcache.Cache
type testCache struct {
caches map[wasm.ModuleID][]byte
deleted wasm.ModuleID
}
// Get implements compilationcache.Cache Get
func (tc *testCache) Get(key wasm.ModuleID) (content io.ReadCloser, ok bool, err error) {
var raw []byte
raw, ok = tc.caches[key]
if !ok {
return
}
if len(raw) == 0 {
ok = false
err = fmt.Errorf("some error from extern cache")
return
}
content = io.NopCloser(bytes.NewReader(raw))
return
}
// Add implements compilationcache.Cache Add
func (tc *testCache) Add(key wasm.ModuleID, content io.Reader) (err error) {
raw, err := io.ReadAll(content)
if err != nil {
return err
}
tc.caches[key] = raw
return
}
// Delete implements compilationcache.Cache Delete
func (tc *testCache) Delete(key wasm.ModuleID) (err error) {
tc.deleted = key
return
}

View File

@@ -284,29 +284,6 @@ func TestCompiler_SliceAllocatedOnHeap(t *testing.T) {
}
}
// TODO: move most of this logic to enginetest.go so that there is less drift between interpreter and compiler
func TestEngine_Cachedcodes(t *testing.T) {
e := newEngine(context.Background(), wasm.Features20191205)
exp := []*code{
{codeSegment: []byte{0x0}},
{codeSegment: []byte{0x0}},
}
m := &wasm.Module{}
e.addCodes(m, exp)
actual, ok := e.getCodes(m)
require.True(t, ok)
require.Equal(t, len(exp), len(actual))
for i := range actual {
require.Equal(t, exp[i], actual[i])
}
e.deleteCodes(m)
_, ok = e.getCodes(m)
require.False(t, ok)
}
func TestCallEngine_builtinFunctionTableGrow(t *testing.T) {
ce := &callEngine{
valueStack: []uint64{

View File

@@ -7,6 +7,7 @@ package compiler
// e.g. MOVQ will be given as amd64.MOVQ.
import (
"bytes"
"fmt"
"math"
"runtime"
@@ -191,10 +192,7 @@ func (c *amd64Compiler) compile() (code []byte, stackPointerCeil uint64, err err
return
}
code, err = platform.MmapCodeSegment(code)
if err != nil {
return
}
code, err = platform.MmapCodeSegment(bytes.NewReader(code), len(code))
return
}

View File

@@ -4,6 +4,7 @@
package compiler
import (
"bytes"
"errors"
"fmt"
"math"
@@ -112,10 +113,7 @@ func (c *arm64Compiler) compile() (code []byte, stackPointerCeil uint64, err err
return
}
code, err = platform.MmapCodeSegment(original)
if err != nil {
return
}
code, err = platform.MmapCodeSegment(bytes.NewReader(original), len(original))
return
}

View File

@@ -10,6 +10,8 @@ import (
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/internal/platform"
"github.com/tetratelabs/wazero/wasi_snapshot_preview1"
)
@@ -50,11 +52,39 @@ func BenchmarkInitialization(b *testing.B) {
}
}
func runInitializationBench(b *testing.B, r wazero.Runtime) {
func BenchmarkCompilation(b *testing.B) {
if !platform.CompilerSupported() {
b.Skip()
}
// Note: recreate runtime each time in the loop to ensure that
// recompilation happens if the extern cache is not used.
b.Run("with extern cache", func(b *testing.B) {
ctx := experimental.WithCompilationCacheDirName(context.Background(), b.TempDir())
for i := 0; i < b.N; i++ {
r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler())
runCompilation(b, r)
}
})
b.Run("without extern cache", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
r := wazero.NewRuntimeWithConfig(context.Background(), wazero.NewRuntimeConfigCompiler())
runCompilation(b, r)
}
})
}
func runCompilation(b *testing.B, r wazero.Runtime) wazero.CompiledModule {
compiled, err := r.CompileModule(testCtx, caseWasm, wazero.NewCompileConfig())
if err != nil {
b.Fatal(err)
}
return compiled
}
func runInitializationBench(b *testing.B, r wazero.Runtime) {
compiled := runCompilation(b, r)
defer compiled.Close(testCtx)
b.ResetTimer()
for i := 0; i < b.N; i++ {

View File

@@ -0,0 +1,20 @@
package platform
// bufWriter implements io.Writer.
//
// This is implemented because bytes.Buffer cannot write from the beginning of the underlying buffer
// without changing the memory location. In this case, the underlying buffer is memory-mapped region,
// and we have to write into that region via io.Copy since sometimes the original native code exists
// as a file for external-cached cases.
type bufWriter struct {
underlying []byte
pos int
}
// Write implements io.Writer Write.
func (b *bufWriter) Write(p []byte) (n int, err error) {
copy(b.underlying[b.pos:], p)
n = len(p)
b.pos += n
return
}

View File

@@ -4,6 +4,7 @@
package platform
import (
"io"
"syscall"
"unsafe"
)
@@ -14,11 +15,11 @@ func munmapCodeSegment(code []byte) error {
// mmapCodeSegmentAMD64 gives all read-write-exec permission to the mmap region
// to enter the function. Otherwise, segmentation fault exception is raised.
func mmapCodeSegmentAMD64(code []byte) ([]byte, error) {
func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
mmapFunc, err := syscall.Mmap(
-1,
0,
len(code),
size,
// The region must be RWX: RW for writing native codes, X for executing the region.
syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC,
// Anonymous as this is not an actual file, but a memory,
@@ -28,19 +29,21 @@ func mmapCodeSegmentAMD64(code []byte) ([]byte, error) {
if err != nil {
return nil, err
}
copy(mmapFunc, code)
return mmapFunc, nil
w := &bufWriter{underlying: mmapFunc}
_, err = io.CopyN(w, code, int64(size))
return mmapFunc, err
}
// mmapCodeSegmentARM64 cannot give all read-write-exec permission to the mmap region.
// Otherwise, the mmap systemcall would raise an error. Here we give read-write
// to the region at first, write the native code and then change the perm to
// read-exec so we can execute the native code.
func mmapCodeSegmentARM64(code []byte) ([]byte, error) {
// read-exec, so we can execute the native code.
func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
mmapFunc, err := syscall.Mmap(
-1,
0,
len(code),
size,
// The region must be RW: RW for writing native codes.
syscall.PROT_READ|syscall.PROT_WRITE,
// Anonymous as this is not an actual file, but a memory,
@@ -51,7 +54,11 @@ func mmapCodeSegmentARM64(code []byte) ([]byte, error) {
return nil, err
}
copy(mmapFunc, code)
w := &bufWriter{underlying: mmapFunc}
_, err = io.CopyN(w, code, int64(size))
if err != nil {
return nil, err
}
// Then we're done with writing code, change the permission to RX.
err = mprotect(mmapFunc, syscall.PROT_READ|syscall.PROT_EXEC)

View File

@@ -1,6 +1,7 @@
package platform
import (
"bytes"
"crypto/rand"
"io"
"testing"
@@ -8,22 +9,23 @@ import (
"github.com/tetratelabs/wazero/internal/testing/require"
)
var testCode, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024))
var testCodeBuf, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024))
func Test_MmapCodeSegment(t *testing.T) {
if !CompilerSupported() {
t.Skip()
}
newCode, err := MmapCodeSegment(testCode)
testCodeReader := bytes.NewReader(testCodeBuf)
newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len())
require.NoError(t, err)
// Verify that the mmap is the same as the original.
require.Equal(t, testCode, newCode)
require.Equal(t, testCodeBuf, newCode)
// TODO: test newCode can executed.
t.Run("panic on zero length", func(t *testing.T) {
captured := require.CapturePanic(func() {
_, _ = MmapCodeSegment(make([]byte, 0))
_, _ = MmapCodeSegment(bytes.NewBuffer(make([]byte, 0)), 0)
})
require.EqualError(t, captured, "BUG: MmapCodeSegment with zero length")
})
@@ -35,9 +37,10 @@ func Test_MunmapCodeSegment(t *testing.T) {
}
// Errors if never mapped
require.Error(t, MunmapCodeSegment(testCode))
require.Error(t, MunmapCodeSegment(testCodeBuf))
newCode, err := MmapCodeSegment(testCode)
testCodeReader := bytes.NewReader(testCodeBuf)
newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len())
require.NoError(t, err)
// First munmap should succeed.
require.NoError(t, MunmapCodeSegment(newCode))

View File

@@ -4,6 +4,7 @@ package platform
import (
"fmt"
"io"
"reflect"
"syscall"
"unsafe"
@@ -30,9 +31,8 @@ func munmapCodeSegment(code []byte) error {
// allocateMemory commits the memory region via the "VirtualAlloc" function.
// See https://docs.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-virtualalloc
func allocateMemory(code []byte, protect uintptr) (uintptr, error) {
func allocateMemory(size uintptr, protect uintptr) (uintptr, error) {
address := uintptr(0) // TODO: document why zero
size := uintptr(len(code))
alloctype := windows_MEM_COMMIT
if r, _, err := procVirtualAlloc.Call(address, size, alloctype, protect); r == 0 {
return 0, fmt.Errorf("compiler: VirtualAlloc error: %w", ensureErr(err))
@@ -60,8 +60,8 @@ func virtualProtect(address, size, newprotect uintptr, oldprotect *uint32) error
return nil
}
func mmapCodeSegmentAMD64(code []byte) ([]byte, error) {
p, err := allocateMemory(code, windows_PAGE_EXECUTE_READWRITE)
func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
p, err := allocateMemory(uintptr(size), windows_PAGE_EXECUTE_READWRITE)
if err != nil {
return nil, err
}
@@ -69,14 +69,16 @@ func mmapCodeSegmentAMD64(code []byte) ([]byte, error) {
var mem []byte
sh := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
sh.Data = p
sh.Len = len(code)
sh.Cap = len(code)
copy(mem, code)
return mem, nil
sh.Len = size
sh.Cap = size
w := &bufWriter{underlying: mem}
_, err = io.CopyN(w, code, int64(size))
return mem, err
}
func mmapCodeSegmentARM64(code []byte) ([]byte, error) {
p, err := allocateMemory(code, windows_PAGE_READWRITE)
func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
p, err := allocateMemory(uintptr(size), windows_PAGE_READWRITE)
if err != nil {
return nil, err
}
@@ -84,12 +86,16 @@ func mmapCodeSegmentARM64(code []byte) ([]byte, error) {
var mem []byte
sh := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
sh.Data = p
sh.Len = len(code)
sh.Cap = len(code)
copy(mem, code)
sh.Len = size
sh.Cap = size
w := &bufWriter{underlying: mem}
_, err = io.CopyN(w, code, int64(size))
if err != nil {
return nil, err
}
old := uint32(windows_PAGE_READWRITE)
err = virtualProtect(p, uintptr(len(code)), windows_PAGE_EXECUTE_READ, &old)
err = virtualProtect(p, uintptr(size), windows_PAGE_EXECUTE_READ, &old)
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ package platform
import (
"errors"
"io"
"runtime"
)
@@ -29,14 +30,14 @@ func CompilerSupported() bool {
// MmapCodeSegment copies the code into the executable region and returns the byte slice of the region.
//
// See https://man7.org/linux/man-pages/man2/mmap.2.html for mmap API and flags.
func MmapCodeSegment(code []byte) ([]byte, error) {
if len(code) == 0 {
func MmapCodeSegment(code io.Reader, size int) ([]byte, error) {
if size == 0 {
panic(errors.New("BUG: MmapCodeSegment with zero length"))
}
if runtime.GOARCH == "amd64" {
return mmapCodeSegmentAMD64(code)
return mmapCodeSegmentAMD64(code, size)
} else {
return mmapCodeSegmentARM64(code)
return mmapCodeSegmentARM64(code, size)
}
}

View File

@@ -0,0 +1,4 @@
package version
// WazeroVersionKey is the key for holding wazero's version in context.Context.
type WazeroVersionKey struct{}

View File

@@ -7,6 +7,7 @@ import (
"github.com/tetratelabs/wazero/api"
experimentalapi "github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/internal/version"
"github.com/tetratelabs/wazero/internal/wasm"
binaryformat "github.com/tetratelabs/wazero/internal/wasm/binary"
)
@@ -126,6 +127,9 @@ func NewRuntime(ctx context.Context) Runtime {
// NewRuntimeWithConfig returns a runtime with the given configuration.
func NewRuntimeWithConfig(ctx context.Context, rConfig RuntimeConfig) Runtime {
if v := ctx.Value(version.WazeroVersionKey{}); v == nil {
ctx = context.WithValue(ctx, version.WazeroVersionKey{}, wazeroVersion)
}
config := rConfig.(*runtimeConfig)
store, ns := wasm.NewStore(config.enabledFeatures, config.newEngine(ctx, config.enabledFeatures))
return &runtime{

View File

@@ -9,6 +9,7 @@ import (
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/internal/leb128"
"github.com/tetratelabs/wazero/internal/testing/require"
"github.com/tetratelabs/wazero/internal/version"
"github.com/tetratelabs/wazero/internal/wasm"
binaryformat "github.com/tetratelabs/wazero/internal/wasm/binary"
"github.com/tetratelabs/wazero/sys"
@@ -20,6 +21,19 @@ var (
testCtx = context.WithValue(context.Background(), struct{}{}, "arbitrary")
)
func TestNewRuntimeWithConfig_version(t *testing.T) {
cfg := NewRuntimeConfig().(*runtimeConfig)
oldNewEngine := cfg.newEngine
cfg.newEngine = func(ctx context.Context, features wasm.Features) wasm.Engine {
// Ensures that wazeroVersion is propagated to the engine.
v := ctx.Value(version.WazeroVersionKey{})
require.NotNil(t, v)
require.Equal(t, wazeroVersion, v.(string))
return oldNewEngine(ctx, features)
}
_ = NewRuntimeWithConfig(testCtx, cfg)
}
func TestRuntime_CompileModule(t *testing.T) {
tests := []struct {
name string

5
version.go Normal file
View File

@@ -0,0 +1,5 @@
package wazero
// wazeroVersion holds the current version of wazero.
// TODO: use debug.ReadBuildInfo automatically set wazeroVersion to the release tag.
var wazeroVersion = "dev"