Close and return immediately if the context is already canceled (#1158)
Signed-off-by: Clifton Kaznocha <ckaznocha@users.noreply.github.com> Co-authored-by: Clifton Kaznocha <ckaznocha@users.noreply.github.com>
This commit is contained in:
@@ -636,6 +636,17 @@ func functionFromUintptr(ptr uintptr) *function {
|
||||
|
||||
// Call implements the same method as documented on wasm.ModuleEngine.
|
||||
func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, params []uint64) (results []uint64, err error) {
|
||||
if ce.fn.parent.withEnsureTermination {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// If the provided context is already done, close the call context
|
||||
// and return the error.
|
||||
callCtx.CloseWithCtxErr(ctx)
|
||||
return nil, callCtx.FailIfClosed()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
tp := ce.initialFn.source.Type
|
||||
|
||||
paramCount := len(params)
|
||||
|
||||
@@ -777,6 +777,17 @@ func (ce *callEngine) Call(ctx context.Context, m *wasm.CallContext, params []ui
|
||||
}
|
||||
|
||||
func (ce *callEngine) call(ctx context.Context, callCtx *wasm.CallContext, tf *function, params []uint64) (results []uint64, err error) {
|
||||
if ce.compiled.parent.ensureTermination {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// If the provided context is already done, close the call context
|
||||
// and return the error.
|
||||
callCtx.CloseWithCtxErr(ctx)
|
||||
return nil, callCtx.FailIfClosed()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
ft := tf.source.Type
|
||||
paramSignature := ft.ParamNumInUint64
|
||||
paramCount := len(params)
|
||||
|
||||
@@ -117,6 +117,15 @@ func testEnsureTerminationOnClose(t *testing.T, r wazero.Runtime) {
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("module \"%s\" closed with context canceled", t.Name()))
|
||||
})
|
||||
|
||||
t.Run("context cancel in advance", func(t *testing.T) {
|
||||
_, infinite := newInfiniteLoopFn(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = infinite.Call(ctx)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("module \"%s\" closed with context canceled", t.Name()))
|
||||
})
|
||||
|
||||
t.Run("context timeout", func(t *testing.T) {
|
||||
_, infinite := newInfiniteLoopFn(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
|
||||
@@ -94,18 +94,27 @@ func (m *CallContext) closeModuleOnCanceledOrTimeout(ctx context.Context, cancel
|
||||
// case go will randomize which branch of the outer select to enter
|
||||
// and we don't want to close the module.
|
||||
default:
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
// TODO: figure out how to report error here.
|
||||
_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
|
||||
} else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
// TODO: figure out how to report error here.
|
||||
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
|
||||
}
|
||||
m.CloseWithCtxErr(ctx)
|
||||
}
|
||||
case <-cancelChan:
|
||||
}
|
||||
}
|
||||
|
||||
// CloseWithExitCode closes the module with an exit code based on the type of
|
||||
// error reported by the context.
|
||||
//
|
||||
// If the context's error is unknown or nil, the module does not close.
|
||||
func (m *CallContext) CloseWithCtxErr(ctx context.Context) {
|
||||
switch {
|
||||
case errors.Is(ctx.Err(), context.Canceled):
|
||||
// TODO: figure out how to report error here.
|
||||
_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
|
||||
case errors.Is(ctx.Err(), context.DeadlineExceeded):
|
||||
// TODO: figure out how to report error here.
|
||||
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements the same method as documented on api.Module
|
||||
func (m *CallContext) Name() string {
|
||||
return m.module.Name
|
||||
|
||||
@@ -388,6 +388,44 @@ func TestCallContext_CloseModuleOnCanceledOrTimeout(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCallContext_CloseWithCtxErr(t *testing.T) {
|
||||
s := newStore()
|
||||
|
||||
t.Run("context canceled", func(t *testing.T) {
|
||||
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
cc.CloseWithCtxErr(ctx)
|
||||
|
||||
err := cc.FailIfClosed()
|
||||
require.EqualError(t, err, "module \"test\" closed with context canceled")
|
||||
})
|
||||
|
||||
t.Run("context timeout", func(t *testing.T) {
|
||||
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
|
||||
duration := time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||
defer cancel()
|
||||
|
||||
time.Sleep(duration * 2)
|
||||
|
||||
cc.CloseWithCtxErr(ctx)
|
||||
|
||||
err := cc.FailIfClosed()
|
||||
require.EqualError(t, err, "module \"test\" closed with context deadline exceeded")
|
||||
})
|
||||
|
||||
t.Run("no error", func(t *testing.T) {
|
||||
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
|
||||
|
||||
cc.CloseWithCtxErr(context.Background())
|
||||
|
||||
err := cc.FailIfClosed()
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
type mockCloser struct{ called int }
|
||||
|
||||
func (m *mockCloser) Close(context.Context) error {
|
||||
|
||||
Reference in New Issue
Block a user