diff --git a/internal/gojs/state.go b/internal/gojs/state.go index cbbcdf87..4a4fcc30 100644 --- a/internal/gojs/state.go +++ b/internal/gojs/state.go @@ -7,11 +7,12 @@ import ( "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/internal/gojs/goos" + "github.com/tetratelabs/wazero/internal/gojs/values" ) func NewState(ctx context.Context) *State { return &State{ - values: &values{ids: map[interface{}]uint32{}}, + values: values.NewValues(), valueGlobal: newJsGlobal(getRoundTripper(ctx)), cwd: "/", _nextCallbackTimeoutID: 1, @@ -105,7 +106,7 @@ func LoadValue(ctx context.Context, ref goos.Ref) interface{} { //nolint if f, ok := ref.ParseFloat(); ok { // numbers are passed through as a Ref return f } - return getState(ctx).values.get(uint32(ref)) + return getState(ctx).values.Get(uint32(ref)) } } @@ -131,16 +132,16 @@ func storeRef(ctx context.Context, v interface{}) goos.Ref { //nolint } else if c, ok := v.(*jsVal); ok { return c.ref // already stored } else if _, ok := v.(*event); ok { - id := getState(ctx).values.increment(v) + id := getState(ctx).values.Increment(v) return goos.ValueRef(id, goos.TypeFlagFunction) } else if _, ok := v.(funcWrapper); ok { - id := getState(ctx).values.increment(v) + id := getState(ctx).values.Increment(v) return goos.ValueRef(id, goos.TypeFlagFunction) } else if _, ok := v.(jsFn); ok { - id := getState(ctx).values.increment(v) + id := getState(ctx).values.Increment(v) return goos.ValueRef(id, goos.TypeFlagFunction) } else if _, ok := v.(string); ok { - id := getState(ctx).values.increment(v) + id := getState(ctx).values.Increment(v) return goos.ValueRef(id, goos.TypeFlagString) } else if i32, ok := v.(int32); ok { return toFloatRef(float64(i32)) @@ -153,7 +154,7 @@ func storeRef(ctx context.Context, v interface{}) goos.Ref { //nolint } else if f64, ok := v.(float64); ok { return toFloatRef(f64) } - id := getState(ctx).values.increment(v) + id := getState(ctx).values.Increment(v) return goos.ValueRef(id, goos.TypeFlagObject) } @@ -165,56 +166,10 @@ func toFloatRef(f float64) goos.Ref { return goos.Ref(api.EncodeF64(f)) } -type values struct { - // Below is needed to avoid exhausting the ID namespace finalizeRef reclaims - // See https://go-review.googlesource.com/c/go/+/203600 - - values []interface{} // values indexed by ID, nil - goRefCounts []uint32 // recount pair-indexed with values - ids map[interface{}]uint32 // live values - idPool []uint32 // reclaimed IDs (values[i] = nil, goRefCounts[i] nil -} - -func (j *values) get(id uint32) interface{} { - index := id - goos.NextID - if index >= uint32(len(j.values)) { - panic(fmt.Errorf("id %d is out of range %d", id, len(j.values))) - } - return j.values[index] -} - -func (j *values) increment(v interface{}) uint32 { - id, ok := j.ids[v] - if !ok { - if len(j.idPool) == 0 { - id, j.values, j.goRefCounts = uint32(len(j.values)), append(j.values, v), append(j.goRefCounts, 0) - } else { - id, j.idPool = j.idPool[len(j.idPool)-1], j.idPool[:len(j.idPool)-1] - j.values[id], j.goRefCounts[id] = v, 0 - } - j.ids[v] = id - } - j.goRefCounts[id]++ - return id + goos.NextID -} - -func (j *values) decrement(id uint32) { - // Special IDs are not goos.Refcounted. - if id < goos.NextID { - return - } - id -= goos.NextID - j.goRefCounts[id]-- - if j.goRefCounts[id] == 0 { - j.values[id] = nil - j.idPool = append(j.idPool, id) - } -} - // State holds state used by the "go" imports used by gojs. // Note: This is module-scoped. type State struct { - values *values + values *values.Values _pendingEvent *event // _lastEvent was the last _pendingEvent value _lastEvent *event @@ -257,10 +212,7 @@ func (s *State) close() { // Reset all state recursively to their initial values. This allows our // unit tests to check we closed everything. s._scheduledTimeouts = map[uint32]chan bool{} - s.values.values = nil - s.values.goRefCounts = nil - s.values.ids = map[interface{}]uint32{} - s.values.idPool = nil + s.values.Reset() s._pendingEvent = nil s._lastEvent = nil s._nextCallbackTimeoutID = 1 diff --git a/internal/gojs/syscall.go b/internal/gojs/syscall.go index 9df39949..36e9ff06 100644 --- a/internal/gojs/syscall.go +++ b/internal/gojs/syscall.go @@ -26,7 +26,7 @@ func finalizeRef(ctx context.Context, _ api.Module, stack goos.Stack) { id := uint32(r) // 32-bits of the ref are the ID - getState(ctx).values.decrement(id) + getState(ctx).values.Decrement(id) } // StringVal implements js.stringVal, which is used to load the string for diff --git a/internal/gojs/values/values.go b/internal/gojs/values/values.go new file mode 100644 index 00000000..53dce81a --- /dev/null +++ b/internal/gojs/values/values.go @@ -0,0 +1,73 @@ +package values + +import ( + "fmt" + + "github.com/tetratelabs/wazero/internal/gojs/goos" +) + +func NewValues() *Values { + ret := &Values{} + ret.Reset() + return ret +} + +type Values struct { + // Below is needed to avoid exhausting the ID namespace finalizeRef reclaims + // See https://go-review.googlesource.com/c/go/+/203600 + + values []interface{} // values indexed by ID, nil + goRefCounts []uint32 // recount pair-indexed with values + ids map[interface{}]uint32 // live values + idPool []uint32 // reclaimed IDs (values[i] = nil, goRefCounts[i] nil +} + +func (j *Values) Get(id uint32) interface{} { + index := id - goos.NextID + if index >= uint32(len(j.values)) { + panic(fmt.Errorf("id %d is out of range %d", id, len(j.values))) + } + if v := j.values[index]; v == nil { + panic(fmt.Errorf("value for %d was nil", id)) + } else { + return v + } +} + +func (j *Values) Increment(v interface{}) uint32 { + id, ok := j.ids[v] + if !ok { + if len(j.idPool) == 0 { + id, j.values, j.goRefCounts = uint32(len(j.values)), append(j.values, v), append(j.goRefCounts, 0) + } else { + id, j.idPool = j.idPool[len(j.idPool)-1], j.idPool[:len(j.idPool)-1] + j.values[id], j.goRefCounts[id] = v, 0 + } + j.ids[v] = id + } + j.goRefCounts[id]++ + + return id + goos.NextID +} + +func (j *Values) Decrement(id uint32) { + // Special IDs are not goos.Refcounted. + if id < goos.NextID { + return + } + id -= goos.NextID + j.goRefCounts[id]-- + if j.goRefCounts[id] == 0 { + v := j.values[id] + j.values[id] = nil + delete(j.ids, v) + j.idPool = append(j.idPool, id) + } +} + +func (j *Values) Reset() { + j.values = nil + j.goRefCounts = nil + j.ids = map[interface{}]uint32{} + j.idPool = nil +} diff --git a/internal/gojs/values/values_test.go b/internal/gojs/values/values_test.go new file mode 100644 index 00000000..552683e2 --- /dev/null +++ b/internal/gojs/values/values_test.go @@ -0,0 +1,51 @@ +package values + +import ( + "testing" + + "github.com/tetratelabs/wazero/internal/gojs/goos" + "github.com/tetratelabs/wazero/internal/testing/require" +) + +func Test_Values(t *testing.T) { + t.Parallel() + + vs := NewValues() + + err := require.CapturePanic(func() { + _ = vs.Get(goos.NextID) + }) + require.EqualError(t, err, "id 18 is out of range 0") + + v1 := "foo" + id1 := vs.Increment(v1) + v2 := "bar" + id2 := vs.Increment(v2) + + require.Equal(t, goos.NextID, id1) + require.Equal(t, v1, vs.Get(id1)) + + // Second value should be at a sequential position + require.Equal(t, id1+1, id2) + require.Equal(t, v2, vs.Get(id2)) + + // Incrementing the ref count should return the same ID + require.Equal(t, id1, vs.Increment(v1)) + require.Equal(t, v1, vs.Get(id1)) + + // Decrement and we should still get the value + vs.Decrement(id1) + require.Equal(t, v1, vs.Get(id1)) + + // Decrement again, and we should panic, as go should never attempt to + // get a value it already decremented to zero. + vs.Decrement(id1) + err = require.CapturePanic(func() { + _ = vs.Get(id1) + }) + require.EqualError(t, err, "value for 18 was nil") + + // Since the ID is no longer in use, we should be able to revive it. + require.Equal(t, id1, vs.Increment(v1)) + require.Equal(t, v1, vs.Get(id1)) +}