wasi: improve select(2) impl on Windows (#1398)
Signed-off-by: Edoardo Vacchi <evacchi@users.noreply.github.com>
This commit is contained in:
@@ -1,21 +1,37 @@
|
|||||||
package platform
|
package platform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
// wasiFdStdin is the constant value for stdin on Wasi.
|
// wasiFdStdin is the constant value for stdin on Wasi.
|
||||||
// We need this constant because on Windows os.Stdin.Fd() != 0.
|
// We need this constant because on Windows os.Stdin.Fd() != 0.
|
||||||
const wasiFdStdin = 0
|
const wasiFdStdin = 0
|
||||||
|
|
||||||
|
// pollInterval is the interval between each calls to peekNamedPipe in pollNamedPipe
|
||||||
|
const pollInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
|
// procPeekNamedPipe is the syscall.LazyProc in kernel32 for PeekNamedPipe
|
||||||
|
var procPeekNamedPipe = kernel32.NewProc("PeekNamedPipe")
|
||||||
|
|
||||||
// syscall_select emulates the select syscall on Windows for two, well-known cases, returns syscall.ENOSYS for all others.
|
// syscall_select emulates the select syscall on Windows for two, well-known cases, returns syscall.ENOSYS for all others.
|
||||||
// If r contains fd 0, then it immediately returns 1 (data ready on stdin) and r will have the fd 0 bit set.
|
// If r contains fd 0, and it is a regular file, then it immediately returns 1 (data ready on stdin)
|
||||||
|
// and r will have the fd 0 bit set.
|
||||||
|
// If r contains fd 0, and it is a FILE_TYPE_CHAR, then it invokes PeekNamedPipe to check the buffer for input;
|
||||||
|
// if there is data ready, then it returns 1 and r will have fd 0 bit set.
|
||||||
// If n==0 it will wait for the given timeout duration, but it will return syscall.ENOSYS if timeout is nil,
|
// If n==0 it will wait for the given timeout duration, but it will return syscall.ENOSYS if timeout is nil,
|
||||||
// i.e. it won't block indefinitely.
|
// i.e. it won't block indefinitely.
|
||||||
|
//
|
||||||
|
// Note: idea taken from https://stackoverflow.com/questions/6839508/test-if-stdin-has-input-for-c-windows-and-or-linux
|
||||||
|
// PeekNamedPipe: https://learn.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-peeknamedpipe
|
||||||
|
// "GetFileType can assist in determining what device type the handle refers to. A console handle presents as FILE_TYPE_CHAR."
|
||||||
|
// https://learn.microsoft.com/en-us/windows/console/console-handles
|
||||||
func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error) {
|
func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error) {
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
// don't block indefinitely
|
// Don't block indefinitely.
|
||||||
if timeout == nil {
|
if timeout == nil {
|
||||||
return -1, syscall.ENOSYS
|
return -1, syscall.ENOSYS
|
||||||
}
|
}
|
||||||
@@ -23,9 +39,85 @@ func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error)
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
if r.IsSet(wasiFdStdin) {
|
if r.IsSet(wasiFdStdin) {
|
||||||
|
fileType, err := syscall.GetFileType(syscall.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if fileType&syscall.FILE_TYPE_CHAR != 0 {
|
||||||
|
res, err := pollNamedPipe(context.TODO(), syscall.Stdin, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
if !res {
|
||||||
|
r.Zero()
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
r.Zero()
|
r.Zero()
|
||||||
r.Set(wasiFdStdin)
|
r.Set(wasiFdStdin)
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
return -1, syscall.ENOSYS
|
return -1, syscall.ENOSYS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pollNamedPipe polls the given named pipe handle for the given duration.
|
||||||
|
//
|
||||||
|
// The implementation actually polls every 100 milliseconds until it reaches the given duration.
|
||||||
|
// The duration may be nil, in which case it will wait undefinely. The given ctx is
|
||||||
|
// used to allow for cancellation. Currently used only in tests.
|
||||||
|
func pollNamedPipe(ctx context.Context, pipeHandle syscall.Handle, duration *time.Duration) (bool, error) {
|
||||||
|
// Short circuit when the duration is zero.
|
||||||
|
if duration != nil && *duration == time.Duration(0) {
|
||||||
|
return peekNamedPipe(pipeHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ticker that emits at every pollInterval.
|
||||||
|
tick := time.NewTicker(pollInterval)
|
||||||
|
tichCh := tick.C
|
||||||
|
defer tick.Stop()
|
||||||
|
|
||||||
|
// Timer that expires after the given duration.
|
||||||
|
// Initialize afterCh as nil: the select below will wait forever.
|
||||||
|
var afterCh <-chan time.Time
|
||||||
|
if duration != nil {
|
||||||
|
// If duration is not nil, instantiate the timer.
|
||||||
|
after := time.NewTimer(*duration)
|
||||||
|
defer after.Stop()
|
||||||
|
afterCh = after.C
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, nil
|
||||||
|
case <-afterCh:
|
||||||
|
return false, nil
|
||||||
|
case <-tichCh:
|
||||||
|
res, err := peekNamedPipe(pipeHandle)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if res {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// peekNamedPipe partially exposes PeekNamedPipe from the Win32 API
|
||||||
|
// see https://learn.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-peeknamedpipe
|
||||||
|
func peekNamedPipe(handle syscall.Handle) (bool, error) {
|
||||||
|
var totalBytesAvail uint32
|
||||||
|
totalBytesPtr := uintptr(unsafe.Pointer(&totalBytesAvail))
|
||||||
|
_, _, err := procPeekNamedPipe.Call(
|
||||||
|
uintptr(handle), // [in] HANDLE hNamedPipe,
|
||||||
|
0, // [out, optional] LPVOID lpBuffer,
|
||||||
|
0, // [in] DWORD nBufferSize,
|
||||||
|
0, // [out, optional] LPDWORD lpBytesRead
|
||||||
|
totalBytesPtr, // [out, optional] LPDWORD lpTotalBytesAvail,
|
||||||
|
0) // [out, optional] LPDWORD lpBytesLeftThisMessage
|
||||||
|
if err == syscall.Errno(0) {
|
||||||
|
return totalBytesAvail > 0, nil
|
||||||
|
}
|
||||||
|
return totalBytesAvail > 0, err
|
||||||
|
}
|
||||||
|
|||||||
194
internal/platform/select_windows_test.go
Normal file
194
internal/platform/select_windows_test.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package platform
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tetratelabs/wazero/internal/testing/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSelect_Windows(t *testing.T) {
|
||||||
|
type result struct {
|
||||||
|
hasData bool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
testCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pollToChannel := func(readHandle syscall.Handle, duration *time.Duration, ch chan result) {
|
||||||
|
r := result{}
|
||||||
|
r.hasData, r.err = pollNamedPipe(testCtx, readHandle, duration)
|
||||||
|
ch <- r
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("peekNamedPipe should report the correct state of incoming data in the pipe", func(t *testing.T) {
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
wh := syscall.Handle(w.Fd())
|
||||||
|
|
||||||
|
// Ensure the pipe has data.
|
||||||
|
hasData, err := peekNamedPipe(rh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, hasData)
|
||||||
|
|
||||||
|
// Write to the channel.
|
||||||
|
msg, err := syscall.ByteSliceFromString("test\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = syscall.Write(wh, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure the pipe has data.
|
||||||
|
hasData, err = peekNamedPipe(rh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, hasData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should return immediately when duration is nil (no data)", func(t *testing.T) {
|
||||||
|
r, _, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
d := time.Duration(0)
|
||||||
|
hasData, err := pollNamedPipe(testCtx, rh, &d)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, hasData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should return immediately when duration is nil (data)", func(t *testing.T) {
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
wh := syscall.Handle(w.Fd())
|
||||||
|
|
||||||
|
// Write to the channel immediately.
|
||||||
|
msg, err := syscall.ByteSliceFromString("test\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = syscall.Write(wh, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that the write is reported.
|
||||||
|
d := time.Duration(0)
|
||||||
|
hasData, err := pollNamedPipe(testCtx, rh, &d)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, hasData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should wait forever when duration is nil", func(t *testing.T) {
|
||||||
|
r, _, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
go pollToChannel(rh, nil, ch)
|
||||||
|
|
||||||
|
// Wait a little, then ensure no writes occurred.
|
||||||
|
<-time.After(500 * time.Millisecond)
|
||||||
|
require.Equal(t, 0, len(ch))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should wait forever when duration is nil", func(t *testing.T) {
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
wh := syscall.Handle(w.Fd())
|
||||||
|
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
go pollToChannel(rh, nil, ch)
|
||||||
|
|
||||||
|
// Wait a little, then ensure no writes occurred.
|
||||||
|
<-time.After(100 * time.Millisecond)
|
||||||
|
require.Equal(t, 0, len(ch))
|
||||||
|
|
||||||
|
// Write a message to the pipe.
|
||||||
|
msg, err := syscall.ByteSliceFromString("test\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = syscall.Write(wh, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure that the write occurs (panic after an arbitrary timeout).
|
||||||
|
select {
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
panic("unreachable!")
|
||||||
|
case r := <-ch:
|
||||||
|
require.NoError(t, r.err)
|
||||||
|
require.True(t, r.hasData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should wait for the given duration", func(t *testing.T) {
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
wh := syscall.Handle(w.Fd())
|
||||||
|
|
||||||
|
d := 500 * time.Millisecond
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
go pollToChannel(rh, &d, ch)
|
||||||
|
|
||||||
|
// Wait a little, then ensure no writes occurred.
|
||||||
|
<-time.After(100 * time.Millisecond)
|
||||||
|
require.Equal(t, 0, len(ch))
|
||||||
|
|
||||||
|
// Write a message to the pipe.
|
||||||
|
msg, err := syscall.ByteSliceFromString("test\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = syscall.Write(wh, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure that the write occurs before the timer expires.
|
||||||
|
select {
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
panic("no data!")
|
||||||
|
case r := <-ch:
|
||||||
|
require.NoError(t, r.err)
|
||||||
|
require.True(t, r.hasData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should timeout after the given duration", func(t *testing.T) {
|
||||||
|
r, _, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
|
||||||
|
d := 200 * time.Millisecond
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
go pollToChannel(rh, &d, ch)
|
||||||
|
|
||||||
|
// Wait a little, then ensure a message has been written to the channel.
|
||||||
|
<-time.After(300 * time.Millisecond)
|
||||||
|
require.Equal(t, 1, len(ch))
|
||||||
|
|
||||||
|
// Ensure that the timer has expired.
|
||||||
|
res := <-ch
|
||||||
|
require.NoError(t, res.err)
|
||||||
|
require.False(t, res.hasData)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pollNamedPipe should return when a write occurs before the given duration", func(t *testing.T) {
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rh := syscall.Handle(r.Fd())
|
||||||
|
wh := syscall.Handle(w.Fd())
|
||||||
|
|
||||||
|
d := 600 * time.Millisecond
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
go pollToChannel(rh, &d, ch)
|
||||||
|
|
||||||
|
<-time.After(300 * time.Millisecond)
|
||||||
|
require.Equal(t, 0, len(ch))
|
||||||
|
|
||||||
|
msg, err := syscall.ByteSliceFromString("test\n")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = syscall.Write(wh, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
res := <-ch
|
||||||
|
require.NoError(t, res.err)
|
||||||
|
require.True(t, res.hasData)
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user