Add WebSocket workaround test and enhance connection handling
Some checks failed
Go / build (push) Has been cancelled
Go / release (push) Has been cancelled

- Introduced a new test file `workaround_test.go` to validate the behavior of a "dumb" WebSocket client that does not handle ping/pong messages correctly, ensuring the connection remains alive through server-side workarounds.
- Updated the `handle-websocket.go` file to improve message size handling and refactor ping/pong logic, allowing for direct message sending and better error management.
- Enhanced the `listener.go` file to support a more robust write channel mechanism, allowing pings to interrupt writes and improving overall connection management.
- Bumped version to v0.23.4 to reflect these changes.
This commit is contained in:
2025-11-03 13:49:14 +00:00
parent 2614b51068
commit ed412dcb7e
6 changed files with 269 additions and 264 deletions

View File

@@ -12,6 +12,7 @@ import (
"lol.mleku.dev/log"
"next.orly.dev/pkg/encoders/envelopes/authenvelope"
"next.orly.dev/pkg/encoders/hex"
"next.orly.dev/pkg/protocol/publish"
"next.orly.dev/pkg/utils/units"
)
@@ -20,7 +21,7 @@ const (
DefaultPongWait = 60 * time.Second
DefaultPingWait = DefaultPongWait / 2
DefaultWriteTimeout = 3 * time.Second
DefaultMaxMessageSize = 100 * units.Mb
DefaultMaxMessageSize = 512000 // Match khatru's MaxMessageSize
// ClientMessageSizeLimit is the maximum message size that clients can handle
// This is set to 100MB to allow large messages
ClientMessageSizeLimit = 100 * 1024 * 1024 // 100MB
@@ -83,7 +84,7 @@ whitelist:
remote: remote,
req: r,
startTime: time.Now(),
writeChan: make(chan WriteRequest, 100), // Buffered channel for writes
writeChan: make(chan publish.WriteRequest, 100), // Buffered channel for writes
writeDone: make(chan struct{}),
}
@@ -119,13 +120,6 @@ whitelist:
conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
return nil
})
// Set ping handler - extends read deadline when pings are received
// Send pong through write channel
conn.SetPingHandler(func(msg string) error {
conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
deadline := time.Now().Add(DefaultWriteTimeout)
return listener.WriteControl(websocket.PongMessage, []byte{}, deadline)
})
// Don't pass cancel to Pinger - it should not be able to cancel the connection context
go s.Pinger(ctx, listener, ticker)
defer func() {
@@ -135,11 +129,6 @@ whitelist:
cancel()
ticker.Stop()
// Close write channel to signal worker to exit
close(listener.writeChan)
// Wait for write worker to finish
<-listener.writeDone
// Cancel all subscriptions for this connection
log.D.F("cancelling subscriptions for %s", remote)
listener.publishers.Receive(&W{
@@ -162,6 +151,11 @@ whitelist:
} else {
log.D.F("ws connection %s was not authenticated", remote)
}
// Close write channel to signal worker to exit
close(listener.writeChan)
// Wait for write worker to finish
<-listener.writeDone
}()
for {
select {
@@ -191,97 +185,25 @@ whitelist:
typ, msg, err = conn.ReadMessage()
if err != nil {
// Check if the error is due to context cancellation
if err == context.Canceled || strings.Contains(err.Error(), "context canceled") {
log.T.F("connection from %s cancelled (context done): %v", remote, err)
return
}
if strings.Contains(
err.Error(), "use of closed network connection",
if websocket.IsUnexpectedCloseError(
err,
websocket.CloseNormalClosure, // 1000
websocket.CloseGoingAway, // 1001
websocket.CloseNoStatusReceived, // 1005
websocket.CloseAbnormalClosure, // 1006
4537, // some client seems to send many of these
) {
return
}
// Handle EOF errors gracefully - these occur when client closes connection
// or sends incomplete/malformed WebSocket frames
if strings.Contains(err.Error(), "EOF") ||
strings.Contains(err.Error(), "failed to read frame header") {
log.T.F("connection from %s closed: %v", remote, err)
return
}
// Handle timeout errors specifically - these can occur on idle connections
// but pongs should extend the deadline, so a timeout usually means dead connection
if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded") {
log.T.F("connection from %s read timeout (likely dead connection): %v", remote, err)
return
}
// Handle message too big errors specifically
if strings.Contains(err.Error(), "message too large") ||
strings.Contains(err.Error(), "read limited at") {
log.D.F("client %s hit message size limit: %v", remote, err)
// Don't log this as an error since it's a client-side limit
// Just close the connection gracefully
return
}
// Check for websocket close errors
if websocket.IsCloseError(err, websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
websocket.CloseUnsupportedData,
websocket.CloseInvalidFramePayloadData) {
log.T.F("connection from %s closed: %v", remote, err)
} else if websocket.IsCloseError(err, websocket.CloseMessageTooBig) {
log.D.F("client %s sent message too big: %v", remote, err)
} else {
log.E.F("unexpected close error from %s: %v", remote, err)
log.I.F("websocket connection closed from %s: %v", remote, err)
}
cancel() // Cancel context like khatru does
return
}
if typ == websocket.PingMessage {
log.D.F("received PING from %s, sending PONG", remote)
// Send pong through write channel
deadline := time.Now().Add(DefaultWriteTimeout)
pongStart := time.Now()
if err = listener.WriteControl(websocket.PongMessage, msg, deadline); err != nil {
pongDuration := time.Since(pongStart)
// Check if this is a timeout vs a connection error
isTimeout := strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded")
isConnectionError := strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "broken pipe") ||
strings.Contains(err.Error(), "connection reset") ||
websocket.IsCloseError(err, websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived)
if isConnectionError {
log.E.F(
"failed to send PONG to %s after %v (connection error): %v", remote,
pongDuration, err,
)
return
} else if isTimeout {
// Timeout on pong - log but don't close immediately
// The read deadline will catch dead connections
log.W.F(
"failed to send PONG to %s after %v (timeout, but connection may still be alive): %v", remote,
pongDuration, err,
)
// Continue - don't close connection on pong timeout
} else {
// Unknown error - log and continue
log.E.F(
"failed to send PONG to %s after %v (unknown error): %v", remote,
pongDuration, err,
)
// Continue - don't close on unknown errors
}
continue
}
pongDuration := time.Since(pongStart)
log.D.F("sent PONG to %s successfully in %v", remote, pongDuration)
if pongDuration > time.Millisecond*50 {
log.D.F("SLOW PONG to %s: %v (>50ms)", remote, pongDuration)
// Send pong directly (like khatru does)
if err = conn.WriteMessage(websocket.PongMessage, nil); err != nil {
log.E.F("failed to send PONG to %s: %v", remote, err)
return
}
continue
}
@@ -300,68 +222,25 @@ func (s *Server) Pinger(
defer func() {
log.D.F("pinger shutting down")
ticker.Stop()
// DO NOT call cancel here - the pinger should not be able to cancel the connection context
// The connection handler will cancel the context when the connection is actually closing
}()
var err error
pingCount := 0
for {
select {
case <-ticker.C:
pingCount++
log.D.F("sending PING #%d", pingCount)
// Send ping through write channel
deadline := time.Now().Add(DefaultWriteTimeout)
pingStart := time.Now()
if err = listener.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
pingDuration := time.Since(pingStart)
// Check if this is a timeout vs a connection error
isTimeout := strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded")
isConnectionError := strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "broken pipe") ||
strings.Contains(err.Error(), "connection reset") ||
websocket.IsCloseError(err, websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived)
if isConnectionError {
log.E.F(
"PING #%d FAILED after %v (connection error): %v", pingCount, pingDuration,
err,
)
chk.E(err)
return
} else if isTimeout {
// Timeout on ping - log but don't stop pinger immediately
// The read deadline will catch dead connections
log.W.F(
"PING #%d timeout after %v (connection may still be alive): %v", pingCount, pingDuration,
err,
)
// Continue - don't stop pinger on timeout
} else {
// Unknown error - log and continue
log.E.F(
"PING #%d FAILED after %v (unknown error): %v", pingCount, pingDuration,
err,
)
// Continue - don't stop pinger on unknown errors
}
continue
}
pingDuration := time.Since(pingStart)
log.D.F("PING #%d sent successfully in %v", pingCount, pingDuration)
if pingDuration > time.Millisecond*100 {
log.D.F("SLOW PING #%d: %v (>100ms)", pingCount, pingDuration)
}
case <-ctx.Done():
log.T.F("pinger context cancelled after %d pings", pingCount)
return
case <-ticker.C:
pingCount++
// Send ping request through write channel - this allows pings to interrupt other writes
select {
case <-ctx.Done():
return
case listener.writeChan <- publish.WriteRequest{IsPing: true, MsgType: pingCount}:
// Ping request queued successfully
case <-time.After(DefaultWriteTimeout):
log.E.F("ping #%d channel timeout - connection may be overloaded", pingCount)
return
}
}
}
}

View File

@@ -18,9 +18,6 @@ import (
"next.orly.dev/pkg/utils/atomic"
)
// WriteRequest represents a write operation to be performed by the write worker
type WriteRequest = publish.WriteRequest
type Listener struct {
*Server
conn *websocket.Conn
@@ -32,7 +29,7 @@ type Listener struct {
startTime time.Time
isBlacklisted bool // Marker to identify blacklisted IPs
blacklistTimeout time.Time // When to timeout blacklisted connections
writeChan chan WriteRequest // Channel for write requests
writeChan chan publish.WriteRequest // Channel for write requests (back to queued approach)
writeDone chan struct{} // Closed when write worker exits
// Diagnostics: per-connection counters
msgCount int
@@ -46,92 +43,13 @@ func (l *Listener) Ctx() context.Context {
return l.ctx
}
// writeWorker is the single goroutine that handles all writes to the websocket connection.
// This serializes all writes to prevent concurrent write panics.
func (l *Listener) writeWorker() {
var channelClosed bool
defer func() {
// Only unregister write channel if connection is actually dead/closing
// Unregister if:
// 1. Context is cancelled (connection closing)
// 2. Channel was closed (connection closing)
// 3. Connection error occurred (already handled inline)
if l.ctx.Err() != nil || channelClosed {
// Connection is closing - safe to unregister
if socketPub := l.publishers.GetSocketPublisher(); socketPub != nil {
log.D.F("ws->%s write worker: unregistering write channel (connection closing)", l.remote)
socketPub.SetWriteChan(l.conn, nil)
}
} else {
// Exiting for other reasons (timeout, etc.) but connection may still be alive
// Don't unregister - let the connection cleanup handle it
log.D.F("ws->%s write worker: exiting but connection may still be alive, keeping write channel registered", l.remote)
}
close(l.writeDone)
}()
for {
select {
case <-l.ctx.Done():
// Context cancelled - connection is closing
log.D.F("ws->%s write worker: context cancelled, exiting", l.remote)
return
case req, ok := <-l.writeChan:
if !ok {
// Channel closed - connection is closing
channelClosed = true
log.D.F("ws->%s write worker: write channel closed, exiting", l.remote)
return
}
deadline := req.Deadline
if deadline.IsZero() {
deadline = time.Now().Add(DefaultWriteTimeout)
}
l.conn.SetWriteDeadline(deadline)
writeStart := time.Now()
var err error
if req.IsControl {
err = l.conn.WriteControl(req.MsgType, req.Data, deadline)
} else {
err = l.conn.WriteMessage(req.MsgType, req.Data)
}
if err != nil {
writeDuration := time.Since(writeStart)
log.E.F("ws->%s write worker FAILED: len=%d duration=%v error=%v",
l.remote, len(req.Data), writeDuration, err)
// Check for connection errors - if so, stop the worker
isConnectionError := strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "broken pipe") ||
strings.Contains(err.Error(), "connection reset") ||
websocket.IsCloseError(err, websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived)
if isConnectionError {
// Connection is dead - unregister channel immediately
log.D.F("ws->%s write worker: connection error detected, unregistering write channel", l.remote)
if socketPub := l.publishers.GetSocketPublisher(); socketPub != nil {
socketPub.SetWriteChan(l.conn, nil)
}
return
}
// Continue for other errors (timeouts, etc.) - connection may still be alive
log.D.F("ws->%s write worker: non-fatal error (timeout?), continuing", l.remote)
} else {
writeDuration := time.Since(writeStart)
if writeDuration > time.Millisecond*100 {
log.D.F("ws->%s write worker SLOW: len=%d duration=%v",
l.remote, len(req.Data), writeDuration)
}
}
}
}
}
func (l *Listener) Write(p []byte) (n int, err error) {
// Send write request to channel - non-blocking with timeout
select {
case <-l.ctx.Done():
return 0, l.ctx.Err()
case l.writeChan <- WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}:
case l.writeChan <- publish.WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}:
return len(p), nil
case <-time.After(DefaultWriteTimeout):
log.E.F("ws->%s write channel timeout", l.remote)
@@ -144,7 +62,7 @@ func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time
select {
case <-l.ctx.Done():
return l.ctx.Err()
case l.writeChan <- WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}:
case l.writeChan <- publish.WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}:
return nil
case <-time.After(DefaultWriteTimeout):
log.E.F("ws->%s writeControl channel timeout", l.remote)
@@ -152,6 +70,72 @@ func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time
}
}
// writeWorker is the single goroutine that handles all writes to the websocket connection.
// This serializes all writes to prevent concurrent write panics and allows pings to interrupt writes.
func (l *Listener) writeWorker() {
defer func() {
// Only unregister write channel if connection is actually dead/closing
// Unregister if:
// 1. Context is cancelled (connection closing)
// 2. Channel was closed (connection closing)
// 3. Connection error occurred (already handled inline)
if l.ctx.Err() != nil {
// Connection is closing - safe to unregister
if socketPub := l.publishers.GetSocketPublisher(); socketPub != nil {
log.D.F("ws->%s write worker: unregistering write channel (connection closing)", l.remote)
socketPub.SetWriteChan(l.conn, nil)
}
} else {
// Exiting for other reasons (timeout, etc.) but connection may still be valid
log.D.F("ws->%s write worker exiting unexpectedly", l.remote)
}
close(l.writeDone)
}()
for {
select {
case <-l.ctx.Done():
log.D.F("ws->%s write worker context cancelled", l.remote)
return
case req, ok := <-l.writeChan:
if !ok {
log.D.F("ws->%s write channel closed", l.remote)
return
}
// Handle the write request
var err error
if req.IsPing {
// Special handling for ping messages
log.D.F("sending PING #%d", req.MsgType)
deadline := time.Now().Add(DefaultWriteTimeout)
err = l.conn.WriteControl(websocket.PingMessage, nil, deadline)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
log.E.F("error writing ping: %v; closing websocket", err)
}
return
}
} else if req.IsControl {
// Control message
err = l.conn.WriteControl(req.MsgType, req.Data, req.Deadline)
if err != nil {
log.E.F("ws->%s control write failed: %v", l.remote, err)
return
}
} else {
// Regular message
l.conn.SetWriteDeadline(time.Now().Add(DefaultWriteTimeout))
err = l.conn.WriteMessage(req.MsgType, req.Data)
if err != nil {
log.E.F("ws->%s write failed: %v", l.remote, err)
return
}
}
}
}
}
// getManagedACL returns the managed ACL instance if available
func (l *Listener) getManagedACL() *database.ManagedACL {
// Get the managed ACL instance from the ACL registry

View File

@@ -23,6 +23,9 @@ import (
const Type = "socketapi"
// WriteChanMap maps websocket connections to their write channels
type WriteChanMap map[*websocket.Conn]chan publish.WriteRequest
type Subscription struct {
remote string
AuthedPubkey []byte
@@ -33,9 +36,6 @@ type Subscription struct {
// connections.
type Map map[*websocket.Conn]map[string]Subscription
// WriteChanMap maps websocket connections to their write channels
type WriteChanMap map[*websocket.Conn]chan<- publish.WriteRequest
type W struct {
*websocket.Conn
@@ -88,25 +88,6 @@ func NewPublisher(c context.Context) (publisher *P) {
func (p *P) Type() (typeName string) { return Type }
// SetWriteChan stores the write channel for a websocket connection
// If writeChan is nil, the entry is removed from the map
func (p *P) SetWriteChan(conn *websocket.Conn, writeChan chan<- publish.WriteRequest) {
p.Mx.Lock()
defer p.Mx.Unlock()
if writeChan == nil {
delete(p.WriteChans, conn)
} else {
p.WriteChans[conn] = writeChan
}
}
// GetWriteChan returns the write channel for a websocket connection
func (p *P) GetWriteChan(conn *websocket.Conn) (chan<- publish.WriteRequest, bool) {
p.Mx.RLock()
defer p.Mx.RUnlock()
ch, ok := p.WriteChans[conn]
return ch, ok
}
// Receive handles incoming messages to manage websocket listener subscriptions
// and associated filters.
@@ -319,14 +300,14 @@ func (p *P) Deliver(ev *event.E) {
log.D.F("subscription delivery QUEUED: event=%s to=%s sub=%s len=%d",
hex.Enc(ev.ID), d.sub.remote, d.id, len(msgData))
case <-time.After(DefaultWriteTimeout):
log.E.F("subscription delivery TIMEOUT: event=%s to=%s sub=%s (write channel full)",
log.E.F("subscription delivery TIMEOUT: event=%s to=%s sub=%s",
hex.Enc(ev.ID), d.sub.remote, d.id)
// Check if connection is still valid
p.Mx.RLock()
stillSubscribed = p.Map[d.w] != nil
p.Mx.RUnlock()
if !stillSubscribed {
log.D.F("removing failed subscriber connection due to channel timeout: %s", d.sub.remote)
log.D.F("removing failed subscriber connection: %s", d.sub.remote)
p.removeSubscriber(d.w)
}
}
@@ -352,6 +333,26 @@ func (p *P) removeSubscriberId(ws *websocket.Conn, id string) {
}
}
// SetWriteChan stores the write channel for a websocket connection
// If writeChan is nil, the entry is removed from the map
func (p *P) SetWriteChan(conn *websocket.Conn, writeChan chan publish.WriteRequest) {
p.Mx.Lock()
defer p.Mx.Unlock()
if writeChan == nil {
delete(p.WriteChans, conn)
} else {
p.WriteChans[conn] = writeChan
}
}
// GetWriteChan returns the write channel for a websocket connection
func (p *P) GetWriteChan(conn *websocket.Conn) (chan publish.WriteRequest, bool) {
p.Mx.RLock()
defer p.Mx.RUnlock()
ch, ok := p.WriteChans[conn]
return ch, ok
}
// removeSubscriber removes a websocket from the P collection.
func (p *P) removeSubscriber(ws *websocket.Conn) {
p.Mx.Lock()