Enhance WebSocket write handling and connection management
Some checks failed
Go / build (push) Has been cancelled
Go / release (push) Has been cancelled

- Introduced a buffered write channel and a dedicated write worker goroutine to serialize write operations, preventing concurrent write panics.
- Updated the Write and WriteControl methods to send messages through the write channel, improving error handling and connection stability.
- Refactored ping and pong handlers to utilize the new write channel for sending control messages.
- Enhanced publisher logic to manage write channels for WebSocket connections, ensuring efficient message delivery and error handling.
- Bumped version to v0.23.0 to reflect these changes.
This commit is contained in:
2025-11-02 17:02:28 +00:00
parent 0123c2d6f5
commit 354a2f1cda
5 changed files with 187 additions and 130 deletions

View File

@@ -83,6 +83,16 @@ whitelist:
remote: remote, remote: remote,
req: r, req: r,
startTime: time.Now(), startTime: time.Now(),
writeChan: make(chan WriteRequest, 100), // Buffered channel for writes
writeDone: make(chan struct{}),
}
// Start write worker goroutine
go listener.writeWorker()
// Register write channel with publisher
if socketPub := listener.publishers.GetSocketPublisher(); socketPub != nil {
socketPub.SetWriteChan(conn, listener.writeChan)
} }
// Check for blacklisted IPs // Check for blacklisted IPs
@@ -110,12 +120,14 @@ whitelist:
return nil return nil
}) })
// Set ping handler - extends read deadline when pings are received // Set ping handler - extends read deadline when pings are received
conn.SetPingHandler(func(string) error { // Send pong through write channel
conn.SetPingHandler(func(msg string) error {
conn.SetReadDeadline(time.Now().Add(DefaultPongWait)) conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(DefaultWriteTimeout)) deadline := time.Now().Add(DefaultWriteTimeout)
return listener.WriteControl(websocket.PongMessage, []byte{}, deadline)
}) })
// Don't pass cancel to Pinger - it should not be able to cancel the connection context // Don't pass cancel to Pinger - it should not be able to cancel the connection context
go s.Pinger(ctx, conn, ticker) go s.Pinger(ctx, listener, ticker)
defer func() { defer func() {
log.D.F("closing websocket connection from %s", remote) log.D.F("closing websocket connection from %s", remote)
@@ -123,6 +135,11 @@ whitelist:
cancel() cancel()
ticker.Stop() ticker.Stop()
// Close write channel to signal worker to exit
close(listener.writeChan)
// Wait for write worker to finish
<-listener.writeDone
// Cancel all subscriptions for this connection // Cancel all subscriptions for this connection
log.D.F("cancelling subscriptions for %s", remote) log.D.F("cancelling subscriptions for %s", remote)
listener.publishers.Receive(&W{ listener.publishers.Receive(&W{
@@ -222,11 +239,10 @@ whitelist:
} }
if typ == websocket.PingMessage { if typ == websocket.PingMessage {
log.D.F("received PING from %s, sending PONG", remote) log.D.F("received PING from %s, sending PONG", remote)
// Create a write context with timeout for pong response // Send pong through write channel
deadline := time.Now().Add(DefaultWriteTimeout) deadline := time.Now().Add(DefaultWriteTimeout)
conn.SetWriteDeadline(deadline)
pongStart := time.Now() pongStart := time.Now()
if err = conn.WriteControl(websocket.PongMessage, msg, deadline); err != nil { if err = listener.WriteControl(websocket.PongMessage, msg, deadline); err != nil {
pongDuration := time.Since(pongStart) pongDuration := time.Since(pongStart)
// Check if this is a timeout vs a connection error // Check if this is a timeout vs a connection error
@@ -279,7 +295,7 @@ whitelist:
} }
func (s *Server) Pinger( func (s *Server) Pinger(
ctx context.Context, conn *websocket.Conn, ticker *time.Ticker, ctx context.Context, listener *Listener, ticker *time.Ticker,
) { ) {
defer func() { defer func() {
log.D.F("pinger shutting down") log.D.F("pinger shutting down")
@@ -295,12 +311,11 @@ func (s *Server) Pinger(
pingCount++ pingCount++
log.D.F("sending PING #%d", pingCount) log.D.F("sending PING #%d", pingCount)
// Set write deadline for ping operation // Send ping through write channel
deadline := time.Now().Add(DefaultWriteTimeout) deadline := time.Now().Add(DefaultWriteTimeout)
conn.SetWriteDeadline(deadline)
pingStart := time.Now() pingStart := time.Now()
if err = conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil { if err = listener.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
pingDuration := time.Since(pingStart) pingDuration := time.Since(pingStart)
// Check if this is a timeout vs a connection error // Check if this is a timeout vs a connection error

View File

@@ -7,16 +7,20 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"lol.mleku.dev/chk" "lol.mleku.dev/errorf"
"lol.mleku.dev/log" "lol.mleku.dev/log"
"next.orly.dev/pkg/acl" "next.orly.dev/pkg/acl"
"next.orly.dev/pkg/database" "next.orly.dev/pkg/database"
"next.orly.dev/pkg/encoders/event" "next.orly.dev/pkg/encoders/event"
"next.orly.dev/pkg/encoders/filter" "next.orly.dev/pkg/encoders/filter"
"next.orly.dev/pkg/protocol/publish"
"next.orly.dev/pkg/utils" "next.orly.dev/pkg/utils"
"next.orly.dev/pkg/utils/atomic" "next.orly.dev/pkg/utils/atomic"
) )
// WriteRequest represents a write operation to be performed by the write worker
type WriteRequest = publish.WriteRequest
type Listener struct { type Listener struct {
*Server *Server
conn *websocket.Conn conn *websocket.Conn
@@ -28,6 +32,8 @@ type Listener struct {
startTime time.Time startTime time.Time
isBlacklisted bool // Marker to identify blacklisted IPs isBlacklisted bool // Marker to identify blacklisted IPs
blacklistTimeout time.Time // When to timeout blacklisted connections blacklistTimeout time.Time // When to timeout blacklisted connections
writeChan chan WriteRequest // Channel for write requests
writeDone chan struct{} // Closed when write worker exits
// Diagnostics: per-connection counters // Diagnostics: per-connection counters
msgCount int msgCount int
reqCount int reqCount int
@@ -40,75 +46,80 @@ func (l *Listener) Ctx() context.Context {
return l.ctx return l.ctx
} }
// writeWorker is the single goroutine that handles all writes to the websocket connection.
// This serializes all writes to prevent concurrent write panics.
func (l *Listener) writeWorker() {
defer close(l.writeDone)
for {
select {
case <-l.ctx.Done():
return
case req, ok := <-l.writeChan:
if !ok {
return
}
deadline := req.Deadline
if deadline.IsZero() {
deadline = time.Now().Add(DefaultWriteTimeout)
}
l.conn.SetWriteDeadline(deadline)
writeStart := time.Now()
var err error
if req.IsControl {
err = l.conn.WriteControl(req.MsgType, req.Data, deadline)
} else {
err = l.conn.WriteMessage(req.MsgType, req.Data)
}
if err != nil {
writeDuration := time.Since(writeStart)
log.E.F("ws->%s write worker FAILED: len=%d duration=%v error=%v",
l.remote, len(req.Data), writeDuration, err)
// Check for connection errors - if so, stop the worker
isConnectionError := strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "broken pipe") ||
strings.Contains(err.Error(), "connection reset") ||
websocket.IsCloseError(err, websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived)
if isConnectionError {
return
}
// Continue for other errors (timeouts, etc.)
} else {
writeDuration := time.Since(writeStart)
if writeDuration > time.Millisecond*100 {
log.D.F("ws->%s write worker SLOW: len=%d duration=%v",
l.remote, len(req.Data), writeDuration)
}
}
}
}
}
func (l *Listener) Write(p []byte) (n int, err error) { func (l *Listener) Write(p []byte) (n int, err error) {
start := time.Now() // Send write request to channel - non-blocking with timeout
msgLen := len(p) select {
case <-l.ctx.Done():
// Log message attempt with content preview (first 200 chars for diagnostics) return 0, l.ctx.Err()
preview := string(p) case l.writeChan <- WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}:
if len(preview) > 200 { return len(p), nil
preview = preview[:200] + "..." case <-time.After(DefaultWriteTimeout):
log.E.F("ws->%s write channel timeout", l.remote)
return 0, errorf.E("write channel timeout")
} }
log.T.F( }
"ws->%s attempting write: len=%d preview=%q", l.remote, msgLen, preview,
)
// Use a separate context with timeout for writes to prevent race conditions // WriteControl sends a control message through the write channel
// where the main connection context gets cancelled while writing events func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time) (err error) {
deadline := time.Now().Add(DefaultWriteTimeout) select {
l.conn.SetWriteDeadline(deadline) case <-l.ctx.Done():
return l.ctx.Err()
// Attempt the write operation case l.writeChan <- WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}:
writeStart := time.Now() return nil
if err = l.conn.WriteMessage(websocket.TextMessage, p); err != nil { case <-time.After(DefaultWriteTimeout):
writeDuration := time.Since(writeStart) log.E.F("ws->%s writeControl channel timeout", l.remote)
totalDuration := time.Since(start) return errorf.E("writeControl channel timeout")
// Log detailed failure information
log.E.F(
"ws->%s WRITE FAILED: len=%d duration=%v write_duration=%v error=%v preview=%q",
l.remote, msgLen, totalDuration, writeDuration, err, preview,
)
// Check if this is a context timeout
if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") {
log.E.F(
"ws->%s write timeout after %v (limit=%v)", l.remote,
writeDuration, DefaultWriteTimeout,
)
}
// Check connection state
if l.conn != nil {
log.T.F(
"ws->%s connection state during failure: remote_addr=%v",
l.remote, l.req.RemoteAddr,
)
}
chk.E(err) // Still call the original error handler
return
} }
// Log successful write with timing
writeDuration := time.Since(writeStart)
totalDuration := time.Since(start)
n = msgLen
log.T.F(
"ws->%s WRITE SUCCESS: len=%d duration=%v write_duration=%v",
l.remote, n, totalDuration, writeDuration,
)
// Log slow writes for performance diagnostics
if writeDuration > time.Millisecond*100 {
log.T.F(
"ws->%s SLOW WRITE detected: %v (>100ms) len=%d", l.remote,
writeDuration, n,
)
}
return
} }
// getManagedACL returns the managed ACL instance if available // getManagedACL returns the managed ACL instance if available

View File

@@ -3,7 +3,6 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
@@ -18,6 +17,7 @@ import (
"next.orly.dev/pkg/encoders/kind" "next.orly.dev/pkg/encoders/kind"
"next.orly.dev/pkg/interfaces/publisher" "next.orly.dev/pkg/interfaces/publisher"
"next.orly.dev/pkg/interfaces/typer" "next.orly.dev/pkg/interfaces/typer"
"next.orly.dev/pkg/protocol/publish"
"next.orly.dev/pkg/utils" "next.orly.dev/pkg/utils"
) )
@@ -33,6 +33,9 @@ type Subscription struct {
// connections. // connections.
type Map map[*websocket.Conn]map[string]Subscription type Map map[*websocket.Conn]map[string]Subscription
// WriteChanMap maps websocket connections to their write channels
type WriteChanMap map[*websocket.Conn]chan<- publish.WriteRequest
type W struct { type W struct {
*websocket.Conn *websocket.Conn
@@ -69,19 +72,37 @@ type P struct {
Mx sync.RWMutex Mx sync.RWMutex
// Map is the map of subscribers and subscriptions from the websocket api. // Map is the map of subscribers and subscriptions from the websocket api.
Map Map
// WriteChans maps websocket connections to their write channels
WriteChans WriteChanMap
} }
var _ publisher.I = &P{} var _ publisher.I = &P{}
func NewPublisher(c context.Context) (publisher *P) { func NewPublisher(c context.Context) (publisher *P) {
return &P{ return &P{
c: c, c: c,
Map: make(Map), Map: make(Map),
WriteChans: make(WriteChanMap, 100),
} }
} }
func (p *P) Type() (typeName string) { return Type } func (p *P) Type() (typeName string) { return Type }
// SetWriteChan stores the write channel for a websocket connection
func (p *P) SetWriteChan(conn *websocket.Conn, writeChan chan<- publish.WriteRequest) {
p.Mx.Lock()
defer p.Mx.Unlock()
p.WriteChans[conn] = writeChan
}
// GetWriteChan returns the write channel for a websocket connection
func (p *P) GetWriteChan(conn *websocket.Conn) (chan<- publish.WriteRequest, bool) {
p.Mx.RLock()
defer p.Mx.RUnlock()
ch, ok := p.WriteChans[conn]
return ch, ok
}
// Receive handles incoming messages to manage websocket listener subscriptions // Receive handles incoming messages to manage websocket listener subscriptions
// and associated filters. // and associated filters.
// //
@@ -269,61 +290,40 @@ func (p *P) Deliver(ev *event.E) {
log.D.F("attempting delivery of event %s (kind=%d, len=%d) to subscription %s @ %s", log.D.F("attempting delivery of event %s (kind=%d, len=%d) to subscription %s @ %s",
hex.Enc(ev.ID), ev.Kind, len(msgData), d.id, d.sub.remote) hex.Enc(ev.ID), ev.Kind, len(msgData), d.id, d.sub.remote)
// Use a separate context with timeout for writes to prevent race conditions // Get write channel for this connection
// where the publisher context gets cancelled while writing events p.Mx.RLock()
deadline := time.Now().Add(DefaultWriteTimeout) writeChan, hasChan := p.GetWriteChan(d.w)
d.w.SetWriteDeadline(deadline) stillSubscribed := p.Map[d.w] != nil
p.Mx.RUnlock()
deliveryStart := time.Now() if !stillSubscribed {
if err = d.w.WriteMessage(websocket.TextMessage, msgData); err != nil { log.D.F("skipping delivery to %s - connection no longer subscribed", d.sub.remote)
deliveryDuration := time.Since(deliveryStart)
// Log detailed failure information
log.E.F("subscription delivery FAILED: event=%s to=%s sub=%s duration=%v error=%v",
hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, err)
// Check for timeout specifically
isTimeout := strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded")
if isTimeout {
log.E.F("subscription delivery TIMEOUT: event=%s to=%s after %v (limit=%v)",
hex.Enc(ev.ID), d.sub.remote, deliveryDuration, DefaultWriteTimeout)
}
// Only close connection on permanent errors, not transient timeouts
// WebSocket write errors typically indicate connection issues, but we should
// distinguish between timeouts (client might be slow) and connection errors
isConnectionError := strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "broken pipe") ||
strings.Contains(err.Error(), "connection reset") ||
websocket.IsCloseError(err, websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived)
if isConnectionError {
log.D.F("removing failed subscriber connection due to connection error: %s", d.sub.remote)
p.removeSubscriber(d.w)
_ = d.w.Close()
} else if isTimeout {
// For timeouts, log but don't immediately close - give it another chance
// The read deadline will catch dead connections eventually
log.W.F("subscription delivery timeout for %s (client may be slow), skipping event but keeping connection", d.sub.remote)
} else {
// Unknown error - be conservative and close
log.D.F("removing failed subscriber connection due to unknown error: %s", d.sub.remote)
p.removeSubscriber(d.w)
_ = d.w.Close()
}
continue continue
} }
deliveryDuration := time.Since(deliveryStart) if !hasChan {
log.D.F("subscription delivery SUCCESS: event=%s to=%s sub=%s duration=%v len=%d", log.D.F("skipping delivery to %s - no write channel available", d.sub.remote)
hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, len(msgData)) continue
}
// Log slow deliveries for performance monitoring // Send to write channel - non-blocking with timeout
if deliveryDuration > time.Millisecond*50 { select {
log.D.F("SLOW subscription delivery: event=%s to=%s duration=%v (>50ms)", case <-p.c.Done():
hex.Enc(ev.ID), d.sub.remote, deliveryDuration) continue
case writeChan <- publish.WriteRequest{Data: msgData, MsgType: websocket.TextMessage, IsControl: false}:
log.D.F("subscription delivery QUEUED: event=%s to=%s sub=%s len=%d",
hex.Enc(ev.ID), d.sub.remote, d.id, len(msgData))
case <-time.After(DefaultWriteTimeout):
log.E.F("subscription delivery TIMEOUT: event=%s to=%s sub=%s (write channel full)",
hex.Enc(ev.ID), d.sub.remote, d.id)
// Check if connection is still valid
p.Mx.RLock()
stillSubscribed = p.Map[d.w] != nil
p.Mx.RUnlock()
if !stillSubscribed {
log.D.F("removing failed subscriber connection due to channel timeout: %s", d.sub.remote)
p.removeSubscriber(d.w)
}
} }
} }
} }
@@ -340,6 +340,7 @@ func (p *P) removeSubscriberId(ws *websocket.Conn, id string) {
// Check the actual map after deletion, not the original reference // Check the actual map after deletion, not the original reference
if len(p.Map[ws]) == 0 { if len(p.Map[ws]) == 0 {
delete(p.Map, ws) delete(p.Map, ws)
delete(p.WriteChans, ws)
} }
} }
} }
@@ -350,6 +351,7 @@ func (p *P) removeSubscriber(ws *websocket.Conn) {
defer p.Mx.Unlock() defer p.Mx.Unlock()
clear(p.Map[ws]) clear(p.Map[ws])
delete(p.Map, ws) delete(p.Map, ws)
delete(p.WriteChans, ws)
} }
// canSeePrivateEvent checks if the authenticated user can see an event with a private tag // canSeePrivateEvent checks if the authenticated user can see an event with a private tag

View File

@@ -1,11 +1,28 @@
package publish package publish
import ( import (
"time"
"github.com/gorilla/websocket"
"next.orly.dev/pkg/encoders/event" "next.orly.dev/pkg/encoders/event"
"next.orly.dev/pkg/interfaces/publisher" "next.orly.dev/pkg/interfaces/publisher"
"next.orly.dev/pkg/interfaces/typer" "next.orly.dev/pkg/interfaces/typer"
) )
// WriteRequest represents a write operation to be performed by the write worker
type WriteRequest struct {
Data []byte
MsgType int
IsControl bool
Deadline time.Time
}
// WriteChanSetter defines the interface for setting write channels
type WriteChanSetter interface {
SetWriteChan(*websocket.Conn, chan<- WriteRequest)
GetWriteChan(*websocket.Conn) (chan<- WriteRequest, bool)
}
// S is the control structure for the subscription management scheme. // S is the control structure for the subscription management scheme.
type S struct { type S struct {
publisher.Publishers publisher.Publishers
@@ -36,3 +53,15 @@ func (s *S) Receive(msg typer.T) {
} }
} }
} }
// GetSocketPublisher returns the socketapi publisher instance
func (s *S) GetSocketPublisher() WriteChanSetter {
for _, p := range s.Publishers {
if p.Type() == "socketapi" {
if socketPub, ok := p.(WriteChanSetter); ok {
return socketPub
}
}
}
return nil
}

View File

@@ -1 +1 @@
v0.21.4 v0.23.0