diff --git a/internal/wasm/host.go b/internal/wasm/host.go index bb9c40f5..bca686d1 100644 --- a/internal/wasm/host.go +++ b/internal/wasm/host.go @@ -72,7 +72,7 @@ func NewHostModule( // compilation of host modules is not costly as it's merely small trampolines vs the real-world native Wasm binary. // TODO: refactor engines so that we can properly cache compiled machine codes for host modules. m.AssignModuleID([]byte(fmt.Sprintf("@@@@@@@@%p", m)), // @@@@@@@@ = any 8 bytes different from Wasm header. - false, false) + nil, false) return } diff --git a/internal/wasm/module.go b/internal/wasm/module.go index ba2c032a..7c363b2f 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -3,6 +3,7 @@ package wasm import ( "bytes" "crypto/sha256" + "encoding/binary" "errors" "fmt" "io" @@ -11,6 +12,7 @@ import ( "sync" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/ieee754" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/wasmdebug" @@ -198,13 +200,20 @@ const ( // AssignModuleID calculates a sha256 checksum on `wasm` and other args, and set Module.ID to the result. // See the doc on Module.ID on what it's used for. -func (m *Module) AssignModuleID(wasm []byte, withListener, withEnsureTermination bool) { +func (m *Module) AssignModuleID(wasm []byte, listeners []experimental.FunctionListener, withEnsureTermination bool) { h := sha256.New() h.Write(wasm) - // Use the pre-allocated space on m.ID to append the booleans to sha256 hash. - m.ID[0] = boolToByte(withListener) - m.ID[1] = boolToByte(withEnsureTermination) - h.Write(m.ID[:2]) + // Use the pre-allocated space backed by m.ID below. + + // Write the existence of listeners to the checksum per function. + for i, l := range listeners { + binary.LittleEndian.PutUint32(m.ID[:], uint32(i)) + m.ID[4] = boolToByte(l != nil) + h.Write(m.ID[:5]) + } + // Write the flag of ensureTermination to the checksum. + m.ID[0] = boolToByte(withEnsureTermination) + h.Write(m.ID[:1]) // Get checksum by passing the slice underlying m.ID. h.Sum(m.ID[:0]) } diff --git a/internal/wasm/module_test.go b/internal/wasm/module_test.go index 77eb830d..a3f5dab4 100644 --- a/internal/wasm/module_test.go +++ b/internal/wasm/module_test.go @@ -1,11 +1,13 @@ package wasm import ( + "context" "fmt" "math" "testing" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" "github.com/tetratelabs/wazero/internal/leb128" "github.com/tetratelabs/wazero/internal/testing/require" "github.com/tetratelabs/wazero/internal/u64" @@ -1010,30 +1012,88 @@ func TestModule_declaredFunctionIndexes(t *testing.T) { } func TestModule_AssignModuleID(t *testing.T) { - getID := func(bin []byte, withListener, withEnsureTermination bool) ModuleID { + getID := func(bin []byte, lsns []experimental.FunctionListener, withEnsureTermination bool) ModuleID { m := Module{} - m.AssignModuleID(bin, withListener, withEnsureTermination) + m.AssignModuleID(bin, lsns, withEnsureTermination) return m.ID } + ml := &mockListener{} + // Ensures that different args always produce the different IDs. exists := map[ModuleID]struct{}{} - for _, tc := range []struct { - bin []byte - withListener, withEnsureTermination bool + for i, tc := range []struct { + bin []byte + withEnsureTermination bool + listeners []experimental.FunctionListener }{ - {bin: []byte{1, 2, 3}, withListener: false, withEnsureTermination: false}, - {bin: []byte{1, 2, 3}, withListener: false, withEnsureTermination: true}, - {bin: []byte{1, 2, 3}, withListener: true, withEnsureTermination: false}, - {bin: []byte{1, 2, 3}, withListener: true, withEnsureTermination: true}, - {bin: []byte{1, 2, 3, 4}, withListener: false, withEnsureTermination: false}, - {bin: []byte{1, 2, 3, 4}, withListener: false, withEnsureTermination: true}, - {bin: []byte{1, 2, 3, 4}, withListener: true, withEnsureTermination: false}, - {bin: []byte{1, 2, 3, 4}, withListener: true, withEnsureTermination: true}, + {bin: []byte{1, 2, 3}, withEnsureTermination: false}, + {bin: []byte{1, 2, 3}, withEnsureTermination: true}, + { + bin: []byte{1, 2, 3}, + listeners: []experimental.FunctionListener{ml}, + withEnsureTermination: false, + }, + { + bin: []byte{1, 2, 3}, + listeners: []experimental.FunctionListener{ml}, + withEnsureTermination: true, + }, + { + bin: []byte{1, 2, 3}, + listeners: []experimental.FunctionListener{nil, ml}, + withEnsureTermination: true, + }, + { + bin: []byte{1, 2, 3}, + listeners: []experimental.FunctionListener{ml, ml}, + withEnsureTermination: true, + }, + {bin: []byte{1, 2, 3, 4}, withEnsureTermination: false}, + {bin: []byte{1, 2, 3, 4}, withEnsureTermination: true}, + { + bin: []byte{1, 2, 3, 4}, + listeners: []experimental.FunctionListener{ml}, + withEnsureTermination: false, + }, + { + bin: []byte{1, 2, 3, 4}, + listeners: []experimental.FunctionListener{ml}, + withEnsureTermination: true, + }, + { + bin: []byte{1, 2, 3, 4}, + listeners: []experimental.FunctionListener{nil}, + withEnsureTermination: true, + }, + { + bin: []byte{1, 2, 3, 4}, + listeners: []experimental.FunctionListener{nil, ml}, + withEnsureTermination: true, + }, + { + bin: []byte{1, 2, 3, 4}, + listeners: []experimental.FunctionListener{ml, ml}, + withEnsureTermination: true, + }, + { + bin: []byte{1, 2, 3, 4}, + listeners: []experimental.FunctionListener{ml, ml}, + withEnsureTermination: false, + }, } { - id := getID(tc.bin, tc.withListener, tc.withEnsureTermination) + id := getID(tc.bin, tc.listeners, tc.withEnsureTermination) _, exist := exists[id] - require.False(t, exist) + require.False(t, exist, i) exists[id] = struct{}{} } } + +type mockListener struct{} + +func (m mockListener) Before(context.Context, api.Module, api.FunctionDefinition, []uint64, experimental.StackIterator) { +} + +func (m mockListener) After(context.Context, api.Module, api.FunctionDefinition, []uint64) {} + +func (m mockListener) Abort(context.Context, api.Module, api.FunctionDefinition, error) {} diff --git a/runtime.go b/runtime.go index c416be71..220324ed 100644 --- a/runtime.go +++ b/runtime.go @@ -233,7 +233,7 @@ func (r *runtime) CompileModule(ctx context.Context, binary []byte) (CompiledMod if err != nil { return nil, err } - internal.AssignModuleID(binary, len(listeners) > 0, r.ensureTermination) + internal.AssignModuleID(binary, listeners, r.ensureTermination) if err = r.store.Engine.CompileModule(ctx, internal, listeners, r.ensureTermination); err != nil { return nil, err }