diff --git a/internal/platform/select_windows.go b/internal/platform/select_windows.go index 164ed17f..55f9751d 100644 --- a/internal/platform/select_windows.go +++ b/internal/platform/select_windows.go @@ -1,21 +1,37 @@ package platform import ( + "context" "syscall" "time" + "unsafe" ) // wasiFdStdin is the constant value for stdin on Wasi. // We need this constant because on Windows os.Stdin.Fd() != 0. const wasiFdStdin = 0 +// pollInterval is the interval between each calls to peekNamedPipe in pollNamedPipe +const pollInterval = 100 * time.Millisecond + +// procPeekNamedPipe is the syscall.LazyProc in kernel32 for PeekNamedPipe +var procPeekNamedPipe = kernel32.NewProc("PeekNamedPipe") + // syscall_select emulates the select syscall on Windows for two, well-known cases, returns syscall.ENOSYS for all others. -// If r contains fd 0, then it immediately returns 1 (data ready on stdin) and r will have the fd 0 bit set. +// If r contains fd 0, and it is a regular file, then it immediately returns 1 (data ready on stdin) +// and r will have the fd 0 bit set. +// If r contains fd 0, and it is a FILE_TYPE_CHAR, then it invokes PeekNamedPipe to check the buffer for input; +// if there is data ready, then it returns 1 and r will have fd 0 bit set. // If n==0 it will wait for the given timeout duration, but it will return syscall.ENOSYS if timeout is nil, // i.e. it won't block indefinitely. +// +// Note: idea taken from https://stackoverflow.com/questions/6839508/test-if-stdin-has-input-for-c-windows-and-or-linux +// PeekNamedPipe: https://learn.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-peeknamedpipe +// "GetFileType can assist in determining what device type the handle refers to. A console handle presents as FILE_TYPE_CHAR." +// https://learn.microsoft.com/en-us/windows/console/console-handles func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error) { if n == 0 { - // don't block indefinitely + // Don't block indefinitely. if timeout == nil { return -1, syscall.ENOSYS } @@ -23,9 +39,85 @@ func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error) return 0, nil } if r.IsSet(wasiFdStdin) { + fileType, err := syscall.GetFileType(syscall.Stdin) + if err != nil { + return 0, err + } + if fileType&syscall.FILE_TYPE_CHAR != 0 { + res, err := pollNamedPipe(context.TODO(), syscall.Stdin, timeout) + if err != nil { + return -1, err + } + if !res { + r.Zero() + return 0, nil + } + } r.Zero() r.Set(wasiFdStdin) return 1, nil } return -1, syscall.ENOSYS } + +// pollNamedPipe polls the given named pipe handle for the given duration. +// +// The implementation actually polls every 100 milliseconds until it reaches the given duration. +// The duration may be nil, in which case it will wait undefinely. The given ctx is +// used to allow for cancellation. Currently used only in tests. +func pollNamedPipe(ctx context.Context, pipeHandle syscall.Handle, duration *time.Duration) (bool, error) { + // Short circuit when the duration is zero. + if duration != nil && *duration == time.Duration(0) { + return peekNamedPipe(pipeHandle) + } + + // Ticker that emits at every pollInterval. + tick := time.NewTicker(pollInterval) + tichCh := tick.C + defer tick.Stop() + + // Timer that expires after the given duration. + // Initialize afterCh as nil: the select below will wait forever. + var afterCh <-chan time.Time + if duration != nil { + // If duration is not nil, instantiate the timer. + after := time.NewTimer(*duration) + defer after.Stop() + afterCh = after.C + } + + for { + select { + case <-ctx.Done(): + return false, nil + case <-afterCh: + return false, nil + case <-tichCh: + res, err := peekNamedPipe(pipeHandle) + if err != nil { + return false, err + } + if res { + return true, nil + } + } + } +} + +// peekNamedPipe partially exposes PeekNamedPipe from the Win32 API +// see https://learn.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-peeknamedpipe +func peekNamedPipe(handle syscall.Handle) (bool, error) { + var totalBytesAvail uint32 + totalBytesPtr := uintptr(unsafe.Pointer(&totalBytesAvail)) + _, _, err := procPeekNamedPipe.Call( + uintptr(handle), // [in] HANDLE hNamedPipe, + 0, // [out, optional] LPVOID lpBuffer, + 0, // [in] DWORD nBufferSize, + 0, // [out, optional] LPDWORD lpBytesRead + totalBytesPtr, // [out, optional] LPDWORD lpTotalBytesAvail, + 0) // [out, optional] LPDWORD lpBytesLeftThisMessage + if err == syscall.Errno(0) { + return totalBytesAvail > 0, nil + } + return totalBytesAvail > 0, err +} diff --git a/internal/platform/select_windows_test.go b/internal/platform/select_windows_test.go new file mode 100644 index 00000000..561eb9cf --- /dev/null +++ b/internal/platform/select_windows_test.go @@ -0,0 +1,194 @@ +package platform + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/tetratelabs/wazero/internal/testing/require" +) + +func TestSelect_Windows(t *testing.T) { + type result struct { + hasData bool + err error + } + + testCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pollToChannel := func(readHandle syscall.Handle, duration *time.Duration, ch chan result) { + r := result{} + r.hasData, r.err = pollNamedPipe(testCtx, readHandle, duration) + ch <- r + close(ch) + } + + t.Run("peekNamedPipe should report the correct state of incoming data in the pipe", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + // Ensure the pipe has data. + hasData, err := peekNamedPipe(rh) + require.NoError(t, err) + require.False(t, hasData) + + // Write to the channel. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Ensure the pipe has data. + hasData, err = peekNamedPipe(rh) + require.NoError(t, err) + require.True(t, hasData) + }) + + t.Run("pollNamedPipe should return immediately when duration is nil (no data)", func(t *testing.T) { + r, _, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + d := time.Duration(0) + hasData, err := pollNamedPipe(testCtx, rh, &d) + require.NoError(t, err) + require.False(t, hasData) + }) + + t.Run("pollNamedPipe should return immediately when duration is nil (data)", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + // Write to the channel immediately. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Verify that the write is reported. + d := time.Duration(0) + hasData, err := pollNamedPipe(testCtx, rh, &d) + require.NoError(t, err) + require.True(t, hasData) + }) + + t.Run("pollNamedPipe should wait forever when duration is nil", func(t *testing.T) { + r, _, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + + ch := make(chan result, 1) + go pollToChannel(rh, nil, ch) + + // Wait a little, then ensure no writes occurred. + <-time.After(500 * time.Millisecond) + require.Equal(t, 0, len(ch)) + }) + + t.Run("pollNamedPipe should wait forever when duration is nil", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + ch := make(chan result, 1) + go pollToChannel(rh, nil, ch) + + // Wait a little, then ensure no writes occurred. + <-time.After(100 * time.Millisecond) + require.Equal(t, 0, len(ch)) + + // Write a message to the pipe. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Ensure that the write occurs (panic after an arbitrary timeout). + select { + case <-time.After(500 * time.Millisecond): + panic("unreachable!") + case r := <-ch: + require.NoError(t, r.err) + require.True(t, r.hasData) + } + }) + + t.Run("pollNamedPipe should wait for the given duration", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + d := 500 * time.Millisecond + ch := make(chan result, 1) + go pollToChannel(rh, &d, ch) + + // Wait a little, then ensure no writes occurred. + <-time.After(100 * time.Millisecond) + require.Equal(t, 0, len(ch)) + + // Write a message to the pipe. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Ensure that the write occurs before the timer expires. + select { + case <-time.After(500 * time.Millisecond): + panic("no data!") + case r := <-ch: + require.NoError(t, r.err) + require.True(t, r.hasData) + } + }) + + t.Run("pollNamedPipe should timeout after the given duration", func(t *testing.T) { + r, _, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + + d := 200 * time.Millisecond + ch := make(chan result, 1) + go pollToChannel(rh, &d, ch) + + // Wait a little, then ensure a message has been written to the channel. + <-time.After(300 * time.Millisecond) + require.Equal(t, 1, len(ch)) + + // Ensure that the timer has expired. + res := <-ch + require.NoError(t, res.err) + require.False(t, res.hasData) + }) + + t.Run("pollNamedPipe should return when a write occurs before the given duration", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + d := 600 * time.Millisecond + ch := make(chan result, 1) + go pollToChannel(rh, &d, ch) + + <-time.After(300 * time.Millisecond) + require.Equal(t, 0, len(ch)) + + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + res := <-ch + require.NoError(t, res.err) + require.True(t, res.hasData) + }) +}