Fixes races around {api.Module, wazero.Runtime}.CloseWithExitCode (#1119)

Signed-off-by: Takeshi Yoneda <takeshi@tetrate.io>
This commit is contained in:
Takeshi Yoneda
2023-02-13 22:17:57 -08:00
committed by GitHub
parent 6287ee48d1
commit f84a68b576
6 changed files with 107 additions and 29 deletions

View File

@@ -672,7 +672,7 @@ func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, param
ce.initializeStack(tp, params)
if ce.fn.parent.withEnsureTermination {
done := callCtx.SetExitCodeOnCanceledOrTimeout(ctx)
done := callCtx.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}

View File

@@ -816,7 +816,7 @@ func (ce *callEngine) call(ctx context.Context, callCtx *wasm.CallContext, tf *f
}
if ce.compiled.parent.ensureTermination {
done := callCtx.SetExitCodeOnCanceledOrTimeout(ctx)
done := callCtx.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}

View File

@@ -67,27 +67,29 @@ func (m *CallContext) FailIfClosed() (err error) {
return nil
}
// SetExitCodeOnCanceledOrTimeout take a context `ctx`, which might be a Cancel or Timeout context,
// CloseModuleOnCanceledOrTimeout takes a context `ctx`, which might be a Cancel or Timeout context,
// and spawns a goroutine to check whether the context is canceled or its deadline is exceeded. If it reaches
// one of the conditions, it closes the module with the appropriate exit code.
//
// Callers of this function must invoke the returned context.CancelFunc to release the spawned Goroutine.
func (m *CallContext) SetExitCodeOnCanceledOrTimeout(ctx context.Context) context.CancelFunc {
func (m *CallContext) CloseModuleOnCanceledOrTimeout(ctx context.Context) context.CancelFunc {
goroutineDone, cancelFn := context.WithCancel(context.Background())
go m.setExitCodeOnCanceledOrTimeoutClosure(ctx, goroutineDone)()
go m.closeModuleOnCanceledOrTimeoutClosure(ctx, goroutineDone)()
return cancelFn
}
// setExitCodeOnCanceledOrTimeoutClosure is extracted from SetExitCodeOnCanceledOrTimeout for testing.
func (m *CallContext) setExitCodeOnCanceledOrTimeoutClosure(ctx, goroutineDone context.Context) func() {
// closeModuleOnCanceledOrTimeoutClosure is extracted from CloseModuleOnCanceledOrTimeout for testing.
func (m *CallContext) closeModuleOnCanceledOrTimeoutClosure(ctx, goroutineDone context.Context) func() {
return func() {
for {
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
m.setExitCode(sys.ExitCodeContextCanceled)
// TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
} else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
m.setExitCode(sys.ExitCodeDeadlineExceeded)
// TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
}
return
case <-goroutineDone.Done():
@@ -122,14 +124,24 @@ func (m *CallContext) Close(ctx context.Context) (err error) {
// CloseWithExitCode implements the same method as documented on api.Module.
func (m *CallContext) CloseWithExitCode(ctx context.Context, exitCode uint32) (err error) {
m.setExitCode(exitCode)
if !m.setExitCode(exitCode) {
return nil // not an error to have already closed
}
_ = m.s.deleteModule(m.Name())
return m.ensureResourcesClosed(ctx)
}
func (m *CallContext) setExitCode(exitCode uint32) {
// closeWithExitCode marks this module as closed with exitCode and releases its
// resources, like CloseWithExitCode, except that it leaves the module in
// Store.moduleList.
//
// Only the call that wins the transition to the closed state releases the
// resources; any later or concurrent call returns nil because closing an
// already closed module is not an error.
func (m *CallContext) closeWithExitCode(ctx context.Context, exitCode uint32) (err error) {
	if m.setExitCode(exitCode) {
		// This call performed the close, so it owns the resource cleanup.
		err = m.ensureResourcesClosed(ctx)
	}
	return err
}
func (m *CallContext) setExitCode(exitCode uint32) bool {
closed := uint64(1) + uint64(exitCode)<<32 // Store exitCode as high-order bits.
atomic.CompareAndSwapUint64(m.Closed, 0, closed)
return atomic.CompareAndSwapUint64(m.Closed, 0, closed)
}
// ensureResourcesClosed ensures that resources assigned to CallContext is released.

View File

@@ -12,6 +12,7 @@ import (
internalsys "github.com/tetratelabs/wazero/internal/sys"
"github.com/tetratelabs/wazero/internal/sysfs"
testfs "github.com/tetratelabs/wazero/internal/testing/fs"
"github.com/tetratelabs/wazero/internal/testing/hammer"
"github.com/tetratelabs/wazero/internal/testing/require"
)
@@ -162,8 +163,15 @@ func TestCallContext_Close(t *testing.T) {
_, ok := fsCtx.LookupFile(3)
require.True(t, ok, "sysCtx.openedFiles was empty")
// Closing should not err.
require.NoError(t, m.Close(testCtx))
// Closing should not err even when concurrently closed.
hammer.NewHammer(t, 100, 10).Run(func(name string) {
require.NoError(t, m.Close(testCtx))
// closeWithExitCode is the one called during Store.CloseWithExitCode.
require.NoError(t, m.closeWithExitCode(testCtx, 0))
}, nil)
if t.Failed() {
return // At least one test failed, so return now.
}
// Verify our intended side-effect
_, ok = fsCtx.LookupFile(3)
@@ -291,13 +299,14 @@ func TestCallContext_CallDynamic(t *testing.T) {
})
}
func TestCallContext_SetExitCodeOnCanceledOrTimeout(t *testing.T) {
func TestCallContext_CloseModuleOnCanceledOrTimeout(t *testing.T) {
s := newStore()
t.Run("timeout", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}}
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
const duration = time.Second
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
done := cc.SetExitCodeOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
done := cc.CloseModuleOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
time.Sleep(duration * 2)
defer done()
@@ -306,9 +315,9 @@ func TestCallContext_SetExitCodeOnCanceledOrTimeout(t *testing.T) {
})
t.Run("cancel", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}}
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
ctx, cancel := context.WithCancel(context.Background())
done := cc.SetExitCodeOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
done := cc.CloseModuleOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
cancel()
// Make sure nothing panics or otherwise gets weird with redundant call to cancel().
cancel()
@@ -321,27 +330,27 @@ func TestCallContext_SetExitCodeOnCanceledOrTimeout(t *testing.T) {
})
t.Run("timeout over cancel", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}}
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
const duration = time.Second
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Wrap the cancel context by timeout.
ctx, cancel = context.WithTimeout(ctx, duration)
defer cancel()
done := cc.SetExitCodeOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
done := cc.CloseModuleOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
time.Sleep(duration * 2)
defer done()
})
t.Run("cancel over timeout", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}}
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
ctx, cancel := context.WithCancel(context.Background())
// Wrap the timeout context by cancel context.
var timeoutDone context.CancelFunc
ctx, timeoutDone = context.WithTimeout(ctx, time.Second*1000)
defer timeoutDone()
done := cc.SetExitCodeOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
done := cc.CloseModuleOnCanceledOrTimeout(context.WithValue(ctx, struct{}{}, 1)) // Wrapping arbitrary context.
cancel()
defer done()
@@ -351,13 +360,13 @@ func TestCallContext_SetExitCodeOnCanceledOrTimeout(t *testing.T) {
})
t.Run("cancel works", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}}
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
goroutineDone, cancelFn := context.WithCancel(context.Background())
fn := cc.setExitCodeOnCanceledOrTimeoutClosure(context.Background(), goroutineDone)
fn := cc.closeModuleOnCanceledOrTimeoutClosure(context.Background(), goroutineDone)
var wg sync.WaitGroup
wg.Add(1)
// Ensure that fn returned by setExitCodeOnCanceledOrTimeoutClosure exists after cancelFn is called.
// Ensure that fn returned by closeModuleOnCanceledOrTimeoutClosure exits after cancelFn is called.
go func() {
defer wg.Done()
fn()

View File

@@ -631,12 +631,11 @@ func (s *Store) CloseWithExitCode(ctx context.Context, exitCode uint32) (err err
s.mux.Lock()
defer s.mux.Unlock()
// Close modules in reverse initialization order.
// Close modules in reverse initialization order.
for node := s.moduleList; node != nil; node = node.next {
// If closing this module errs, proceed anyway to close the others.
if m := node.module; m != nil {
m.CallCtx.setExitCode(exitCode)
if e := m.CallCtx.ensureResourcesClosed(ctx); e != nil && err == nil {
if e := m.CallCtx.closeWithExitCode(ctx, exitCode); e != nil && err == nil {
// TODO: use multiple errors handling in Go 1.20.
err = e // first error
}
}

View File

@@ -234,6 +234,64 @@ func TestStore_hammer(t *testing.T) {
require.Nil(t, s.moduleList)
}
// TestStore_hammer_close verifies that closing modules individually while the
// store itself is also being closed concurrently neither errs nor leaves any
// module behind in Store.moduleList.
func TestStore_hammer_close(t *testing.T) {
	const importedModuleName = "imported"

	m, err := NewHostModule(importedModuleName, map[string]interface{}{"fn": func() {}}, map[string]*HostFuncNames{"fn": {}}, api.CoreFeaturesV1)
	require.NoError(t, err)

	s := newStore()
	imported, err := s.Instantiate(testCtx, m, importedModuleName, nil)
	require.NoError(t, err)

	_, ok := s.nameToNode[imported.Name()]
	require.True(t, ok)

	importingModule := &Module{
		TypeSection:             []*FunctionType{v_v},
		FunctionSection:         []uint32{0},
		CodeSection:             []*Code{{Body: []byte{OpcodeEnd}}},
		MemorySection:           &Memory{Min: 1, Cap: 1},
		MemoryDefinitionSection: []*MemoryDefinition{{}},
		GlobalSection: []*Global{{
			Type: &GlobalType{ValType: ValueTypeI32},
			Init: &ConstantExpression{Opcode: OpcodeI32Const, Data: leb128.EncodeInt32(1)},
		}},
		TableSection: []*Table{{Min: 10}},
		ImportSection: []*Import{
			{Type: ExternTypeFunc, Module: importedModuleName, Name: "fn", DescFunc: 0},
		},
	}
	importingModule.BuildFunctionDefinitions()

	// Instantiate every module up front so the hammer below only exercises closing.
	const instCount = 10000
	instances := make([]api.Module, instCount)
	for i := 0; i < instCount; i++ {
		mod, instantiateErr := s.Instantiate(testCtx, importingModule, strconv.Itoa(i), sys.DefaultContext(nil))
		require.NoError(t, instantiateErr)
		instances[i] = mod
	}

	hammer.NewHammer(t, 100, 2).Run(func(name string) {
		for i := 0; i < instCount; i++ {
			if i == instCount/2 {
				// Close store concurrently as well.
				require.NoError(t, s.CloseWithExitCode(testCtx, 0))
			}
			// Each close error is asserted where it is produced. No variable
			// shadows the outer err, so no stale value is re-checked later.
			require.NoError(t, instances[i].CloseWithExitCode(testCtx, 0))
		}
	}, nil)
	if t.Failed() {
		return // At least one test failed, so return now.
	}

	// All instances are freed.
	require.Nil(t, s.moduleList)
}
func TestStore_Instantiate_Errors(t *testing.T) {
const importedModuleName = "imported"
const importingModuleName = "test"