diff --git a/internal/wasm/call_context.go b/internal/wasm/call_context.go index c70c2e88..5443caf8 100644 --- a/internal/wasm/call_context.go +++ b/internal/wasm/call_context.go @@ -73,29 +73,36 @@ func (m *CallContext) FailIfClosed() (err error) { // // Callers of this function must invoke the returned context.CancelFunc to release the spawned Goroutine. func (m *CallContext) CloseModuleOnCanceledOrTimeout(ctx context.Context) context.CancelFunc { - goroutineDone, cancelFn := context.WithCancel(context.Background()) - go m.closeModuleOnCanceledOrTimeoutClosure(ctx, goroutineDone)() - return cancelFn + // Creating an empty channel in this case is a bit more efficient than + // creating a context.Context and canceling it with the same effect. We + // really just need to be notified when to stop listening to the users + // context. Closing the channel will unblock the select in the goroutine + // causing it to return an stop listening to ctx.Done(). + cancelChan := make(chan struct{}) + go m.closeModuleOnCanceledOrTimeout(ctx, cancelChan) + return func() { close(cancelChan) } } -// closeModuleOnCanceledOrTimeoutClosure is extracted from CloseModuleOnCanceledOrTimeout for testing. -func (m *CallContext) closeModuleOnCanceledOrTimeoutClosure(ctx, goroutineDone context.Context) func() { - return func() { - for { - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.Canceled) { - // TODO: figure out how to report error here. - _ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled) - } else if errors.Is(ctx.Err(), context.DeadlineExceeded) { - // TODO: figure out how to report error here. - _ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded) - } - return - case <-goroutineDone.Done(): - return +// closeModuleOnCanceledOrTimeout is extracted from CloseModuleOnCanceledOrTimeout for testing. +func (m *CallContext) closeModuleOnCanceledOrTimeout(ctx context.Context, cancelChan <-chan struct{}) { + select { + case <-ctx.Done(): + select { + case <-cancelChan: + // In some cases by the time this goroutine is scheduled, the caller + // has already closed both the context and the cancelChan. In this + // case go will randomize which branch of the outer select to enter + // and we don't want to close the module. + default: + if errors.Is(ctx.Err(), context.Canceled) { + // TODO: figure out how to report error here. + _ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled) + } else if errors.Is(ctx.Err(), context.DeadlineExceeded) { + // TODO: figure out how to report error here. + _ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded) } } + case <-cancelChan: } } diff --git a/internal/wasm/call_context_test.go b/internal/wasm/call_context_test.go index ab758a86..2f69c535 100644 --- a/internal/wasm/call_context_test.go +++ b/internal/wasm/call_context_test.go @@ -361,19 +361,31 @@ func TestCallContext_CloseModuleOnCanceledOrTimeout(t *testing.T) { t.Run("cancel works", func(t *testing.T) { cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s} - goroutineDone, cancelFn := context.WithCancel(context.Background()) - fn := cc.closeModuleOnCanceledOrTimeoutClosure(context.Background(), goroutineDone) + cancelChan := make(chan struct{}) var wg sync.WaitGroup wg.Add(1) - // Ensure that fn returned by closeModuleOnCanceledOrTimeoutClosure exists after cancelFn is called. + // Ensure that fn returned by closeModuleOnCanceledOrTimeout exists after cancelFn is called. go func() { defer wg.Done() - fn() + cc.closeModuleOnCanceledOrTimeout(context.Background(), cancelChan) }() - cancelFn() + close(cancelChan) wg.Wait() }) + + t.Run("no close on all resources canceled", func(t *testing.T) { + cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s} + cancelChan := make(chan struct{}) + close(cancelChan) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + cc.closeModuleOnCanceledOrTimeout(ctx, cancelChan) + + err := cc.FailIfClosed() + require.Nil(t, err) + }) } type mockCloser struct{ called int }