wasi: improve select(2) impl on Windows (#1398)

Signed-off-by: Edoardo Vacchi <evacchi@users.noreply.github.com>
This commit is contained in:
Edoardo Vacchi
2023-04-28 01:36:48 +02:00
committed by GitHub
parent c82ad896f6
commit 34a639b770
2 changed files with 288 additions and 2 deletions

View File

@@ -1,21 +1,37 @@
package platform package platform
import ( import (
"context"
"syscall" "syscall"
"time" "time"
"unsafe"
) )
// wasiFdStdin is the constant value for stdin on Wasi. // wasiFdStdin is the constant value for stdin on Wasi.
// We need this constant because on Windows os.Stdin.Fd() != 0. // We need this constant because on Windows os.Stdin.Fd() != 0.
const wasiFdStdin = 0 const wasiFdStdin = 0
// pollInterval is the interval between each calls to peekNamedPipe in pollNamedPipe
const pollInterval = 100 * time.Millisecond
// procPeekNamedPipe is the syscall.LazyProc in kernel32 for PeekNamedPipe
var procPeekNamedPipe = kernel32.NewProc("PeekNamedPipe")
// syscall_select emulates the select syscall on Windows for two, well-known cases, returns syscall.ENOSYS for all others. // syscall_select emulates the select syscall on Windows for two, well-known cases, returns syscall.ENOSYS for all others.
// If r contains fd 0, then it immediately returns 1 (data ready on stdin) and r will have the fd 0 bit set. // If r contains fd 0, and it is a regular file, then it immediately returns 1 (data ready on stdin)
// and r will have the fd 0 bit set.
// If r contains fd 0, and it is a FILE_TYPE_CHAR, then it invokes PeekNamedPipe to check the buffer for input;
// if there is data ready, then it returns 1 and r will have fd 0 bit set.
// If n==0 it will wait for the given timeout duration, but it will return syscall.ENOSYS if timeout is nil, // If n==0 it will wait for the given timeout duration, but it will return syscall.ENOSYS if timeout is nil,
// i.e. it won't block indefinitely. // i.e. it won't block indefinitely.
//
// Note: idea taken from https://stackoverflow.com/questions/6839508/test-if-stdin-has-input-for-c-windows-and-or-linux
// PeekNamedPipe: https://learn.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-peeknamedpipe
// "GetFileType can assist in determining what device type the handle refers to. A console handle presents as FILE_TYPE_CHAR."
// https://learn.microsoft.com/en-us/windows/console/console-handles
func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error) { func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error) {
if n == 0 { if n == 0 {
// don't block indefinitely // Don't block indefinitely.
if timeout == nil { if timeout == nil {
return -1, syscall.ENOSYS return -1, syscall.ENOSYS
} }
@@ -23,9 +39,85 @@ func syscall_select(n int, r, w, e *FdSet, timeout *time.Duration) (int, error)
return 0, nil return 0, nil
} }
if r.IsSet(wasiFdStdin) { if r.IsSet(wasiFdStdin) {
fileType, err := syscall.GetFileType(syscall.Stdin)
if err != nil {
return 0, err
}
if fileType&syscall.FILE_TYPE_CHAR != 0 {
res, err := pollNamedPipe(context.TODO(), syscall.Stdin, timeout)
if err != nil {
return -1, err
}
if !res {
r.Zero()
return 0, nil
}
}
r.Zero() r.Zero()
r.Set(wasiFdStdin) r.Set(wasiFdStdin)
return 1, nil return 1, nil
} }
return -1, syscall.ENOSYS return -1, syscall.ENOSYS
} }
// pollNamedPipe polls the given named pipe handle for the given duration.
//
// The implementation actually polls every 100 milliseconds until it reaches the given duration.
// The duration may be nil, in which case it will wait undefinely. The given ctx is
// used to allow for cancellation. Currently used only in tests.
func pollNamedPipe(ctx context.Context, pipeHandle syscall.Handle, duration *time.Duration) (bool, error) {
// Short circuit when the duration is zero.
if duration != nil && *duration == time.Duration(0) {
return peekNamedPipe(pipeHandle)
}
// Ticker that emits at every pollInterval.
tick := time.NewTicker(pollInterval)
tichCh := tick.C
defer tick.Stop()
// Timer that expires after the given duration.
// Initialize afterCh as nil: the select below will wait forever.
var afterCh <-chan time.Time
if duration != nil {
// If duration is not nil, instantiate the timer.
after := time.NewTimer(*duration)
defer after.Stop()
afterCh = after.C
}
for {
select {
case <-ctx.Done():
return false, nil
case <-afterCh:
return false, nil
case <-tichCh:
res, err := peekNamedPipe(pipeHandle)
if err != nil {
return false, err
}
if res {
return true, nil
}
}
}
}
// peekNamedPipe partially exposes PeekNamedPipe from the Win32 API
// see https://learn.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-peeknamedpipe
func peekNamedPipe(handle syscall.Handle) (bool, error) {
var totalBytesAvail uint32
totalBytesPtr := uintptr(unsafe.Pointer(&totalBytesAvail))
_, _, err := procPeekNamedPipe.Call(
uintptr(handle), // [in] HANDLE hNamedPipe,
0, // [out, optional] LPVOID lpBuffer,
0, // [in] DWORD nBufferSize,
0, // [out, optional] LPDWORD lpBytesRead
totalBytesPtr, // [out, optional] LPDWORD lpTotalBytesAvail,
0) // [out, optional] LPDWORD lpBytesLeftThisMessage
if err == syscall.Errno(0) {
return totalBytesAvail > 0, nil
}
return totalBytesAvail > 0, err
}

View File

@@ -0,0 +1,194 @@
package platform
import (
"context"
"os"
"syscall"
"testing"
"time"
"github.com/tetratelabs/wazero/internal/testing/require"
)
func TestSelect_Windows(t *testing.T) {
type result struct {
hasData bool
err error
}
testCtx, cancel := context.WithCancel(context.Background())
defer cancel()
pollToChannel := func(readHandle syscall.Handle, duration *time.Duration, ch chan result) {
r := result{}
r.hasData, r.err = pollNamedPipe(testCtx, readHandle, duration)
ch <- r
close(ch)
}
t.Run("peekNamedPipe should report the correct state of incoming data in the pipe", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
wh := syscall.Handle(w.Fd())
// Ensure the pipe has data.
hasData, err := peekNamedPipe(rh)
require.NoError(t, err)
require.False(t, hasData)
// Write to the channel.
msg, err := syscall.ByteSliceFromString("test\n")
require.NoError(t, err)
_, err = syscall.Write(wh, msg)
require.NoError(t, err)
// Ensure the pipe has data.
hasData, err = peekNamedPipe(rh)
require.NoError(t, err)
require.True(t, hasData)
})
t.Run("pollNamedPipe should return immediately when duration is nil (no data)", func(t *testing.T) {
r, _, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
d := time.Duration(0)
hasData, err := pollNamedPipe(testCtx, rh, &d)
require.NoError(t, err)
require.False(t, hasData)
})
t.Run("pollNamedPipe should return immediately when duration is nil (data)", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
wh := syscall.Handle(w.Fd())
// Write to the channel immediately.
msg, err := syscall.ByteSliceFromString("test\n")
require.NoError(t, err)
_, err = syscall.Write(wh, msg)
require.NoError(t, err)
// Verify that the write is reported.
d := time.Duration(0)
hasData, err := pollNamedPipe(testCtx, rh, &d)
require.NoError(t, err)
require.True(t, hasData)
})
t.Run("pollNamedPipe should wait forever when duration is nil", func(t *testing.T) {
r, _, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
ch := make(chan result, 1)
go pollToChannel(rh, nil, ch)
// Wait a little, then ensure no writes occurred.
<-time.After(500 * time.Millisecond)
require.Equal(t, 0, len(ch))
})
t.Run("pollNamedPipe should wait forever when duration is nil", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
wh := syscall.Handle(w.Fd())
ch := make(chan result, 1)
go pollToChannel(rh, nil, ch)
// Wait a little, then ensure no writes occurred.
<-time.After(100 * time.Millisecond)
require.Equal(t, 0, len(ch))
// Write a message to the pipe.
msg, err := syscall.ByteSliceFromString("test\n")
require.NoError(t, err)
_, err = syscall.Write(wh, msg)
require.NoError(t, err)
// Ensure that the write occurs (panic after an arbitrary timeout).
select {
case <-time.After(500 * time.Millisecond):
panic("unreachable!")
case r := <-ch:
require.NoError(t, r.err)
require.True(t, r.hasData)
}
})
t.Run("pollNamedPipe should wait for the given duration", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
wh := syscall.Handle(w.Fd())
d := 500 * time.Millisecond
ch := make(chan result, 1)
go pollToChannel(rh, &d, ch)
// Wait a little, then ensure no writes occurred.
<-time.After(100 * time.Millisecond)
require.Equal(t, 0, len(ch))
// Write a message to the pipe.
msg, err := syscall.ByteSliceFromString("test\n")
require.NoError(t, err)
_, err = syscall.Write(wh, msg)
require.NoError(t, err)
// Ensure that the write occurs before the timer expires.
select {
case <-time.After(500 * time.Millisecond):
panic("no data!")
case r := <-ch:
require.NoError(t, r.err)
require.True(t, r.hasData)
}
})
t.Run("pollNamedPipe should timeout after the given duration", func(t *testing.T) {
r, _, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
d := 200 * time.Millisecond
ch := make(chan result, 1)
go pollToChannel(rh, &d, ch)
// Wait a little, then ensure a message has been written to the channel.
<-time.After(300 * time.Millisecond)
require.Equal(t, 1, len(ch))
// Ensure that the timer has expired.
res := <-ch
require.NoError(t, res.err)
require.False(t, res.hasData)
})
t.Run("pollNamedPipe should return when a write occurs before the given duration", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
rh := syscall.Handle(r.Fd())
wh := syscall.Handle(w.Fd())
d := 600 * time.Millisecond
ch := make(chan result, 1)
go pollToChannel(rh, &d, ch)
<-time.After(300 * time.Millisecond)
require.Equal(t, 0, len(ch))
msg, err := syscall.ByteSliceFromString("test\n")
require.NoError(t, err)
_, err = syscall.Write(wh, msg)
require.NoError(t, err)
res := <-ch
require.NoError(t, res.err)
require.True(t, res.hasData)
})
}