diff --git a/app/handle-websocket.go b/app/handle-websocket.go index 96f677a..0ff45c8 100644 --- a/app/handle-websocket.go +++ b/app/handle-websocket.go @@ -83,6 +83,16 @@ whitelist: remote: remote, req: r, startTime: time.Now(), + writeChan: make(chan WriteRequest, 100), // Buffered channel for writes + writeDone: make(chan struct{}), + } + + // Start write worker goroutine + go listener.writeWorker() + + // Register write channel with publisher + if socketPub := listener.publishers.GetSocketPublisher(); socketPub != nil { + socketPub.SetWriteChan(conn, listener.writeChan) } // Check for blacklisted IPs @@ -110,12 +120,14 @@ whitelist: return nil }) // Set ping handler - extends read deadline when pings are received - conn.SetPingHandler(func(string) error { + // Send pong through write channel + conn.SetPingHandler(func(msg string) error { conn.SetReadDeadline(time.Now().Add(DefaultPongWait)) - return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(DefaultWriteTimeout)) + deadline := time.Now().Add(DefaultWriteTimeout) + return listener.WriteControl(websocket.PongMessage, []byte{}, deadline) }) // Don't pass cancel to Pinger - it should not be able to cancel the connection context - go s.Pinger(ctx, conn, ticker) + go s.Pinger(ctx, listener, ticker) defer func() { log.D.F("closing websocket connection from %s", remote) @@ -123,6 +135,11 @@ whitelist: cancel() ticker.Stop() + // Close write channel to signal worker to exit + close(listener.writeChan) + // Wait for write worker to finish + <-listener.writeDone + // Cancel all subscriptions for this connection log.D.F("cancelling subscriptions for %s", remote) listener.publishers.Receive(&W{ @@ -222,11 +239,10 @@ whitelist: } if typ == websocket.PingMessage { log.D.F("received PING from %s, sending PONG", remote) - // Create a write context with timeout for pong response + // Send pong through write channel deadline := time.Now().Add(DefaultWriteTimeout) - conn.SetWriteDeadline(deadline) pongStart := time.Now() - if err = conn.WriteControl(websocket.PongMessage, msg, deadline); err != nil { + if err = listener.WriteControl(websocket.PongMessage, msg, deadline); err != nil { pongDuration := time.Since(pongStart) // Check if this is a timeout vs a connection error @@ -279,7 +295,7 @@ whitelist: } func (s *Server) Pinger( - ctx context.Context, conn *websocket.Conn, ticker *time.Ticker, + ctx context.Context, listener *Listener, ticker *time.Ticker, ) { defer func() { log.D.F("pinger shutting down") @@ -295,12 +311,11 @@ func (s *Server) Pinger( pingCount++ log.D.F("sending PING #%d", pingCount) - // Set write deadline for ping operation + // Send ping through write channel deadline := time.Now().Add(DefaultWriteTimeout) - conn.SetWriteDeadline(deadline) pingStart := time.Now() - if err = conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil { + if err = listener.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil { pingDuration := time.Since(pingStart) // Check if this is a timeout vs a connection error diff --git a/app/listener.go b/app/listener.go index 7a71e65..6b451e6 100644 --- a/app/listener.go +++ b/app/listener.go @@ -7,16 +7,20 @@ import ( "time" "github.com/gorilla/websocket" - "lol.mleku.dev/chk" + "lol.mleku.dev/errorf" "lol.mleku.dev/log" "next.orly.dev/pkg/acl" "next.orly.dev/pkg/database" "next.orly.dev/pkg/encoders/event" "next.orly.dev/pkg/encoders/filter" + "next.orly.dev/pkg/protocol/publish" "next.orly.dev/pkg/utils" "next.orly.dev/pkg/utils/atomic" ) +// WriteRequest represents a write operation to be performed by the write worker +type WriteRequest = publish.WriteRequest + type Listener struct { *Server conn *websocket.Conn @@ -28,6 +32,8 @@ type Listener struct { startTime time.Time isBlacklisted bool // Marker to identify blacklisted IPs blacklistTimeout time.Time // When to timeout blacklisted connections + writeChan chan WriteRequest // Channel for write requests + writeDone chan struct{} // Closed when write worker exits // Diagnostics: per-connection counters msgCount int reqCount int @@ -40,75 +46,80 @@ func (l *Listener) Ctx() context.Context { return l.ctx } +// writeWorker is the single goroutine that handles all writes to the websocket connection. +// This serializes all writes to prevent concurrent write panics. +func (l *Listener) writeWorker() { + defer close(l.writeDone) + for { + select { + case <-l.ctx.Done(): + return + case req, ok := <-l.writeChan: + if !ok { + return + } + deadline := req.Deadline + if deadline.IsZero() { + deadline = time.Now().Add(DefaultWriteTimeout) + } + l.conn.SetWriteDeadline(deadline) + writeStart := time.Now() + var err error + if req.IsControl { + err = l.conn.WriteControl(req.MsgType, req.Data, deadline) + } else { + err = l.conn.WriteMessage(req.MsgType, req.Data) + } + if err != nil { + writeDuration := time.Since(writeStart) + log.E.F("ws->%s write worker FAILED: len=%d duration=%v error=%v", + l.remote, len(req.Data), writeDuration, err) + // Check for connection errors - if so, stop the worker + isConnectionError := strings.Contains(err.Error(), "use of closed network connection") || + strings.Contains(err.Error(), "broken pipe") || + strings.Contains(err.Error(), "connection reset") || + websocket.IsCloseError(err, websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived) + if isConnectionError { + return + } + // Continue for other errors (timeouts, etc.) + } else { + writeDuration := time.Since(writeStart) + if writeDuration > time.Millisecond*100 { + log.D.F("ws->%s write worker SLOW: len=%d duration=%v", + l.remote, len(req.Data), writeDuration) + } + } + } + } +} + func (l *Listener) Write(p []byte) (n int, err error) { - start := time.Now() - msgLen := len(p) - - // Log message attempt with content preview (first 200 chars for diagnostics) - preview := string(p) - if len(preview) > 200 { - preview = preview[:200] + "..." + // Send write request to channel - non-blocking with timeout + select { + case <-l.ctx.Done(): + return 0, l.ctx.Err() + case l.writeChan <- WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}: + return len(p), nil + case <-time.After(DefaultWriteTimeout): + log.E.F("ws->%s write channel timeout", l.remote) + return 0, errorf.E("write channel timeout") } - log.T.F( - "ws->%s attempting write: len=%d preview=%q", l.remote, msgLen, preview, - ) +} - // Use a separate context with timeout for writes to prevent race conditions - // where the main connection context gets cancelled while writing events - deadline := time.Now().Add(DefaultWriteTimeout) - l.conn.SetWriteDeadline(deadline) - - // Attempt the write operation - writeStart := time.Now() - if err = l.conn.WriteMessage(websocket.TextMessage, p); err != nil { - writeDuration := time.Since(writeStart) - totalDuration := time.Since(start) - - // Log detailed failure information - log.E.F( - "ws->%s WRITE FAILED: len=%d duration=%v write_duration=%v error=%v preview=%q", - l.remote, msgLen, totalDuration, writeDuration, err, preview, - ) - - // Check if this is a context timeout - if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") { - log.E.F( - "ws->%s write timeout after %v (limit=%v)", l.remote, - writeDuration, DefaultWriteTimeout, - ) - } - - // Check connection state - if l.conn != nil { - log.T.F( - "ws->%s connection state during failure: remote_addr=%v", - l.remote, l.req.RemoteAddr, - ) - } - - chk.E(err) // Still call the original error handler - return +// WriteControl sends a control message through the write channel +func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time) (err error) { + select { + case <-l.ctx.Done(): + return l.ctx.Err() + case l.writeChan <- WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}: + return nil + case <-time.After(DefaultWriteTimeout): + log.E.F("ws->%s writeControl channel timeout", l.remote) + return errorf.E("writeControl channel timeout") } - - // Log successful write with timing - writeDuration := time.Since(writeStart) - totalDuration := time.Since(start) - n = msgLen - - log.T.F( - "ws->%s WRITE SUCCESS: len=%d duration=%v write_duration=%v", - l.remote, n, totalDuration, writeDuration, - ) - - // Log slow writes for performance diagnostics - if writeDuration > time.Millisecond*100 { - log.T.F( - "ws->%s SLOW WRITE detected: %v (>100ms) len=%d", l.remote, - writeDuration, n, - ) - } - - return } // getManagedACL returns the managed ACL instance if available diff --git a/app/publisher.go b/app/publisher.go index c8ea945..a406f19 100644 --- a/app/publisher.go +++ b/app/publisher.go @@ -3,7 +3,6 @@ package app import ( "context" "fmt" - "strings" "sync" "time" @@ -18,6 +17,7 @@ import ( "next.orly.dev/pkg/encoders/kind" "next.orly.dev/pkg/interfaces/publisher" "next.orly.dev/pkg/interfaces/typer" + "next.orly.dev/pkg/protocol/publish" "next.orly.dev/pkg/utils" ) @@ -33,6 +33,9 @@ type Subscription struct { // connections. type Map map[*websocket.Conn]map[string]Subscription +// WriteChanMap maps websocket connections to their write channels +type WriteChanMap map[*websocket.Conn]chan<- publish.WriteRequest + type W struct { *websocket.Conn @@ -69,19 +72,37 @@ type P struct { Mx sync.RWMutex // Map is the map of subscribers and subscriptions from the websocket api. Map + // WriteChans maps websocket connections to their write channels + WriteChans WriteChanMap } var _ publisher.I = &P{} func NewPublisher(c context.Context) (publisher *P) { return &P{ - c: c, - Map: make(Map), + c: c, + Map: make(Map), + WriteChans: make(WriteChanMap, 100), } } func (p *P) Type() (typeName string) { return Type } +// SetWriteChan stores the write channel for a websocket connection +func (p *P) SetWriteChan(conn *websocket.Conn, writeChan chan<- publish.WriteRequest) { + p.Mx.Lock() + defer p.Mx.Unlock() + p.WriteChans[conn] = writeChan +} + +// GetWriteChan returns the write channel for a websocket connection +func (p *P) GetWriteChan(conn *websocket.Conn) (chan<- publish.WriteRequest, bool) { + p.Mx.RLock() + defer p.Mx.RUnlock() + ch, ok := p.WriteChans[conn] + return ch, ok +} + // Receive handles incoming messages to manage websocket listener subscriptions // and associated filters. // @@ -269,61 +290,40 @@ func (p *P) Deliver(ev *event.E) { log.D.F("attempting delivery of event %s (kind=%d, len=%d) to subscription %s @ %s", hex.Enc(ev.ID), ev.Kind, len(msgData), d.id, d.sub.remote) - // Use a separate context with timeout for writes to prevent race conditions - // where the publisher context gets cancelled while writing events - deadline := time.Now().Add(DefaultWriteTimeout) - d.w.SetWriteDeadline(deadline) + // Get write channel for this connection + p.Mx.RLock() + writeChan, hasChan := p.GetWriteChan(d.w) + stillSubscribed := p.Map[d.w] != nil + p.Mx.RUnlock() - deliveryStart := time.Now() - if err = d.w.WriteMessage(websocket.TextMessage, msgData); err != nil { - deliveryDuration := time.Since(deliveryStart) - - // Log detailed failure information - log.E.F("subscription delivery FAILED: event=%s to=%s sub=%s duration=%v error=%v", - hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, err) - - // Check for timeout specifically - isTimeout := strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded") - if isTimeout { - log.E.F("subscription delivery TIMEOUT: event=%s to=%s after %v (limit=%v)", - hex.Enc(ev.ID), d.sub.remote, deliveryDuration, DefaultWriteTimeout) - } - - // Only close connection on permanent errors, not transient timeouts - // WebSocket write errors typically indicate connection issues, but we should - // distinguish between timeouts (client might be slow) and connection errors - isConnectionError := strings.Contains(err.Error(), "use of closed network connection") || - strings.Contains(err.Error(), "broken pipe") || - strings.Contains(err.Error(), "connection reset") || - websocket.IsCloseError(err, websocket.CloseAbnormalClosure, - websocket.CloseGoingAway, - websocket.CloseNoStatusReceived) - - if isConnectionError { - log.D.F("removing failed subscriber connection due to connection error: %s", d.sub.remote) - p.removeSubscriber(d.w) - _ = d.w.Close() - } else if isTimeout { - // For timeouts, log but don't immediately close - give it another chance - // The read deadline will catch dead connections eventually - log.W.F("subscription delivery timeout for %s (client may be slow), skipping event but keeping connection", d.sub.remote) - } else { - // Unknown error - be conservative and close - log.D.F("removing failed subscriber connection due to unknown error: %s", d.sub.remote) - p.removeSubscriber(d.w) - _ = d.w.Close() - } + if !stillSubscribed { + log.D.F("skipping delivery to %s - connection no longer subscribed", d.sub.remote) continue } - deliveryDuration := time.Since(deliveryStart) - log.D.F("subscription delivery SUCCESS: event=%s to=%s sub=%s duration=%v len=%d", - hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, len(msgData)) + if !hasChan { + log.D.F("skipping delivery to %s - no write channel available", d.sub.remote) + continue + } - // Log slow deliveries for performance monitoring - if deliveryDuration > time.Millisecond*50 { - log.D.F("SLOW subscription delivery: event=%s to=%s duration=%v (>50ms)", - hex.Enc(ev.ID), d.sub.remote, deliveryDuration) + // Send to write channel - non-blocking with timeout + select { + case <-p.c.Done(): + continue + case writeChan <- publish.WriteRequest{Data: msgData, MsgType: websocket.TextMessage, IsControl: false}: + log.D.F("subscription delivery QUEUED: event=%s to=%s sub=%s len=%d", + hex.Enc(ev.ID), d.sub.remote, d.id, len(msgData)) + case <-time.After(DefaultWriteTimeout): + log.E.F("subscription delivery TIMEOUT: event=%s to=%s sub=%s (write channel full)", + hex.Enc(ev.ID), d.sub.remote, d.id) + // Check if connection is still valid + p.Mx.RLock() + stillSubscribed = p.Map[d.w] != nil + p.Mx.RUnlock() + if !stillSubscribed { + log.D.F("removing failed subscriber connection due to channel timeout: %s", d.sub.remote) + p.removeSubscriber(d.w) + } } } } @@ -340,6 +340,7 @@ func (p *P) removeSubscriberId(ws *websocket.Conn, id string) { // Check the actual map after deletion, not the original reference if len(p.Map[ws]) == 0 { delete(p.Map, ws) + delete(p.WriteChans, ws) } } } @@ -350,6 +351,7 @@ func (p *P) removeSubscriber(ws *websocket.Conn) { defer p.Mx.Unlock() clear(p.Map[ws]) delete(p.Map, ws) + delete(p.WriteChans, ws) } // canSeePrivateEvent checks if the authenticated user can see an event with a private tag diff --git a/pkg/protocol/publish/publisher.go b/pkg/protocol/publish/publisher.go index 5e7906b..df5f338 100644 --- a/pkg/protocol/publish/publisher.go +++ b/pkg/protocol/publish/publisher.go @@ -1,11 +1,28 @@ package publish import ( + "time" + + "github.com/gorilla/websocket" "next.orly.dev/pkg/encoders/event" "next.orly.dev/pkg/interfaces/publisher" "next.orly.dev/pkg/interfaces/typer" ) +// WriteRequest represents a write operation to be performed by the write worker +type WriteRequest struct { + Data []byte + MsgType int + IsControl bool + Deadline time.Time +} + +// WriteChanSetter defines the interface for setting write channels +type WriteChanSetter interface { + SetWriteChan(*websocket.Conn, chan<- WriteRequest) + GetWriteChan(*websocket.Conn) (chan<- WriteRequest, bool) +} + // S is the control structure for the subscription management scheme. type S struct { publisher.Publishers @@ -36,3 +53,15 @@ func (s *S) Receive(msg typer.T) { } } } + +// GetSocketPublisher returns the socketapi publisher instance +func (s *S) GetSocketPublisher() WriteChanSetter { + for _, p := range s.Publishers { + if p.Type() == "socketapi" { + if socketPub, ok := p.(WriteChanSetter); ok { + return socketPub + } + } + } + return nil +} diff --git a/pkg/version/version b/pkg/version/version index a1849d2..964a04d 100644 --- a/pkg/version/version +++ b/pkg/version/version @@ -1 +1 @@ -v0.21.4 \ No newline at end of file +v0.23.0 \ No newline at end of file