Fix issue with a delayed check of the context cancelation (#1156)

If a context is canceled after a gust func call has returned it could
cause the module to close in some cases. This change ensures that
a delayed check of the context cancelation is ignored.

I've also reduced the cost of context cancelation a bit.

Signed-off-by: Clifton Kaznocha <ckaznocha@users.noreply.github.com>
This commit is contained in:
Clifton Kaznocha
2023-02-23 17:42:13 -08:00
committed by GitHub
parent 9ebbb41adb
commit f0132ee346
2 changed files with 43 additions and 24 deletions

View File

@@ -73,29 +73,36 @@ func (m *CallContext) FailIfClosed() (err error) {
// //
// Callers of this function must invoke the returned context.CancelFunc to release the spawned Goroutine. // Callers of this function must invoke the returned context.CancelFunc to release the spawned Goroutine.
func (m *CallContext) CloseModuleOnCanceledOrTimeout(ctx context.Context) context.CancelFunc { func (m *CallContext) CloseModuleOnCanceledOrTimeout(ctx context.Context) context.CancelFunc {
goroutineDone, cancelFn := context.WithCancel(context.Background()) // Creating an empty channel in this case is a bit more efficient than
go m.closeModuleOnCanceledOrTimeoutClosure(ctx, goroutineDone)() // creating a context.Context and canceling it with the same effect. We
return cancelFn // really just need to be notified when to stop listening to the users
// context. Closing the channel will unblock the select in the goroutine
// causing it to return an stop listening to ctx.Done().
cancelChan := make(chan struct{})
go m.closeModuleOnCanceledOrTimeout(ctx, cancelChan)
return func() { close(cancelChan) }
} }
// closeModuleOnCanceledOrTimeoutClosure is extracted from CloseModuleOnCanceledOrTimeout for testing. // closeModuleOnCanceledOrTimeout is extracted from CloseModuleOnCanceledOrTimeout for testing.
func (m *CallContext) closeModuleOnCanceledOrTimeoutClosure(ctx, goroutineDone context.Context) func() { func (m *CallContext) closeModuleOnCanceledOrTimeout(ctx context.Context, cancelChan <-chan struct{}) {
return func() { select {
for { case <-ctx.Done():
select { select {
case <-ctx.Done(): case <-cancelChan:
if errors.Is(ctx.Err(), context.Canceled) { // In some cases by the time this goroutine is scheduled, the caller
// TODO: figure out how to report error here. // has already closed both the context and the cancelChan. In this
_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled) // case go will randomize which branch of the outer select to enter
} else if errors.Is(ctx.Err(), context.DeadlineExceeded) { // and we don't want to close the module.
// TODO: figure out how to report error here. default:
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded) if errors.Is(ctx.Err(), context.Canceled) {
} // TODO: figure out how to report error here.
return _ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
case <-goroutineDone.Done(): } else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return // TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
} }
} }
case <-cancelChan:
} }
} }

View File

@@ -361,19 +361,31 @@ func TestCallContext_CloseModuleOnCanceledOrTimeout(t *testing.T) {
t.Run("cancel works", func(t *testing.T) { t.Run("cancel works", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s} cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
goroutineDone, cancelFn := context.WithCancel(context.Background()) cancelChan := make(chan struct{})
fn := cc.closeModuleOnCanceledOrTimeoutClosure(context.Background(), goroutineDone)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
// Ensure that fn returned by closeModuleOnCanceledOrTimeoutClosure exists after cancelFn is called. // Ensure that fn returned by closeModuleOnCanceledOrTimeout exists after cancelFn is called.
go func() { go func() {
defer wg.Done() defer wg.Done()
fn() cc.closeModuleOnCanceledOrTimeout(context.Background(), cancelChan)
}() }()
cancelFn() close(cancelChan)
wg.Wait() wg.Wait()
}) })
t.Run("no close on all resources canceled", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
cancelChan := make(chan struct{})
close(cancelChan)
ctx, cancel := context.WithCancel(context.Background())
cancel()
cc.closeModuleOnCanceledOrTimeout(ctx, cancelChan)
err := cc.FailIfClosed()
require.Nil(t, err)
})
} }
type mockCloser struct{ called int } type mockCloser struct{ called int }