diff --git a/sys/error.go b/sys/error.go index 9aea47b1..c3efbad9 100644 --- a/sys/error.go +++ b/sys/error.go @@ -73,5 +73,11 @@ func (e *ExitError) Is(err error) bool { if target, ok := err.(*ExitError); ok { return e.exitCode == target.exitCode } + if e.exitCode == ExitCodeContextCanceled && err == context.Canceled { + return true + } + if e.exitCode == ExitCodeDeadlineExceeded && err == context.DeadlineExceeded { + return true + } return false } diff --git a/sys/error_test.go b/sys/error_test.go index b1b278c6..bbf60755 100644 --- a/sys/error_test.go +++ b/sys/error_test.go @@ -1,6 +1,7 @@ package sys import ( + "context" "errors" "testing" @@ -55,11 +56,13 @@ func TestExitError_Error(t *testing.T) { err := NewExitError(ExitCodeDeadlineExceeded) require.Equal(t, ExitCodeDeadlineExceeded, err.ExitCode()) require.EqualError(t, err, "module closed with context deadline exceeded") + require.ErrorIs(t, err, context.DeadlineExceeded, "exit code context deadline exceeded should work") }) t.Run("cancel", func(t *testing.T) { err := NewExitError(ExitCodeContextCanceled) require.Equal(t, ExitCodeContextCanceled, err.ExitCode()) require.EqualError(t, err, "module closed with context canceled") + require.ErrorIs(t, err, context.Canceled, "exit code context canceled should work") }) t.Run("normal", func(t *testing.T) { err := NewExitError(123)