Close and return immediately if the context is already canceled (#1158)

Signed-off-by: Clifton Kaznocha <ckaznocha@users.noreply.github.com>
Co-authored-by: Clifton Kaznocha <ckaznocha@users.noreply.github.com>
This commit is contained in:
Clifton Kaznocha
2023-02-23 17:58:07 -08:00
committed by GitHub
parent f0132ee346
commit ecb5b1ad03
5 changed files with 85 additions and 7 deletions

View File

@@ -636,6 +636,17 @@ func functionFromUintptr(ptr uintptr) *function {
// Call implements the same method as documented on wasm.ModuleEngine.
func (ce *callEngine) Call(ctx context.Context, callCtx *wasm.CallContext, params []uint64) (results []uint64, err error) {
if ce.fn.parent.withEnsureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the call context
// and return the error.
callCtx.CloseWithCtxErr(ctx)
return nil, callCtx.FailIfClosed()
default:
}
}
tp := ce.initialFn.source.Type
paramCount := len(params)

View File

@@ -777,6 +777,17 @@ func (ce *callEngine) Call(ctx context.Context, m *wasm.CallContext, params []ui
}
func (ce *callEngine) call(ctx context.Context, callCtx *wasm.CallContext, tf *function, params []uint64) (results []uint64, err error) {
if ce.compiled.parent.ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the call context
// and return the error.
callCtx.CloseWithCtxErr(ctx)
return nil, callCtx.FailIfClosed()
default:
}
}
ft := tf.source.Type
paramSignature := ft.ParamNumInUint64
paramCount := len(params)

View File

@@ -117,6 +117,15 @@ func testEnsureTerminationOnClose(t *testing.T, r wazero.Runtime) {
require.Contains(t, err.Error(), fmt.Sprintf("module \"%s\" closed with context canceled", t.Name()))
})
t.Run("context cancel in advance", func(t *testing.T) {
_, infinite := newInfiniteLoopFn(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = infinite.Call(ctx)
require.Error(t, err)
require.Contains(t, err.Error(), fmt.Sprintf("module \"%s\" closed with context canceled", t.Name()))
})
t.Run("context timeout", func(t *testing.T) {
_, infinite := newInfiniteLoopFn(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View File

@@ -94,18 +94,27 @@ func (m *CallContext) closeModuleOnCanceledOrTimeout(ctx context.Context, cancel
// case go will randomize which branch of the outer select to enter
// and we don't want to close the module.
default:
if errors.Is(ctx.Err(), context.Canceled) {
// TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
} else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
// TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
}
m.CloseWithCtxErr(ctx)
}
case <-cancelChan:
}
}
// CloseWithExitCode closes the module with an exit code based on the type of
// error reported by the context.
//
// If the context's error is unknown or nil, the module does not close.
func (m *CallContext) CloseWithCtxErr(ctx context.Context) {
switch {
case errors.Is(ctx.Err(), context.Canceled):
// TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
case errors.Is(ctx.Err(), context.DeadlineExceeded):
// TODO: figure out how to report error here.
_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
}
}
// Name implements the same method as documented on api.Module
func (m *CallContext) Name() string {
return m.module.Name

View File

@@ -388,6 +388,44 @@ func TestCallContext_CloseModuleOnCanceledOrTimeout(t *testing.T) {
})
}
func TestCallContext_CloseWithCtxErr(t *testing.T) {
s := newStore()
t.Run("context canceled", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
ctx, cancel := context.WithCancel(context.Background())
cancel()
cc.CloseWithCtxErr(ctx)
err := cc.FailIfClosed()
require.EqualError(t, err, "module \"test\" closed with context canceled")
})
t.Run("context timeout", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
duration := time.Second
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
time.Sleep(duration * 2)
cc.CloseWithCtxErr(ctx)
err := cc.FailIfClosed()
require.EqualError(t, err, "module \"test\" closed with context deadline exceeded")
})
t.Run("no error", func(t *testing.T) {
cc := &CallContext{Closed: new(uint64), module: &ModuleInstance{Name: "test"}, s: s}
cc.CloseWithCtxErr(context.Background())
err := cc.FailIfClosed()
require.Nil(t, err)
})
}
type mockCloser struct{ called int }
func (m *mockCloser) Close(context.Context) error {