Files
next.orly.dev/app/listener.go
mleku 88ebf6eccc
Some checks failed
Go / build (push) Has been cancelled
Go / release (push) Has been cancelled
Update WebSocket implementation to use Gorilla WebSocket library
- Replaced the existing `github.com/coder/websocket` package with `github.com/gorilla/websocket` for improved functionality and compatibility.
- Adjusted WebSocket connection handling, including message reading and writing, to align with the new library's API.
- Enhanced error handling and logging for WebSocket operations.
- Bumped version to v0.20.0 to reflect the changes made.
2025-10-30 15:20:39 +00:00

158 lines
4.3 KiB
Go

package app
import (
"context"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"next.orly.dev/pkg/acl"
"next.orly.dev/pkg/database"
"next.orly.dev/pkg/encoders/event"
"next.orly.dev/pkg/encoders/filter"
"next.orly.dev/pkg/utils"
"next.orly.dev/pkg/utils/atomic"
)
type Listener struct {
*Server
conn *websocket.Conn
ctx context.Context
remote string
req *http.Request
challenge atomic.Bytes
authedPubkey atomic.Bytes
startTime time.Time
isBlacklisted bool // Marker to identify blacklisted IPs
blacklistTimeout time.Time // When to timeout blacklisted connections
// Diagnostics: per-connection counters
msgCount int
reqCount int
eventCount int
}
// Ctx returns the listener's context, but creates a new context for each operation
// to prevent cancellation from affecting subsequent operations
func (l *Listener) Ctx() context.Context {
return l.ctx
}
func (l *Listener) Write(p []byte) (n int, err error) {
start := time.Now()
msgLen := len(p)
// Log message attempt with content preview (first 200 chars for diagnostics)
preview := string(p)
if len(preview) > 200 {
preview = preview[:200] + "..."
}
log.T.F(
"ws->%s attempting write: len=%d preview=%q", l.remote, msgLen, preview,
)
// Use a separate context with timeout for writes to prevent race conditions
// where the main connection context gets cancelled while writing events
deadline := time.Now().Add(DefaultWriteTimeout)
l.conn.SetWriteDeadline(deadline)
// Attempt the write operation
writeStart := time.Now()
if err = l.conn.WriteMessage(websocket.TextMessage, p); err != nil {
writeDuration := time.Since(writeStart)
totalDuration := time.Since(start)
// Log detailed failure information
log.E.F(
"ws->%s WRITE FAILED: len=%d duration=%v write_duration=%v error=%v preview=%q",
l.remote, msgLen, totalDuration, writeDuration, err, preview,
)
// Check if this is a context timeout
if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") {
log.E.F(
"ws->%s write timeout after %v (limit=%v)", l.remote,
writeDuration, DefaultWriteTimeout,
)
}
// Check connection state
if l.conn != nil {
log.T.F(
"ws->%s connection state during failure: remote_addr=%v",
l.remote, l.req.RemoteAddr,
)
}
chk.E(err) // Still call the original error handler
return
}
// Log successful write with timing
writeDuration := time.Since(writeStart)
totalDuration := time.Since(start)
n = msgLen
log.T.F(
"ws->%s WRITE SUCCESS: len=%d duration=%v write_duration=%v",
l.remote, n, totalDuration, writeDuration,
)
// Log slow writes for performance diagnostics
if writeDuration > time.Millisecond*100 {
log.T.F(
"ws->%s SLOW WRITE detected: %v (>100ms) len=%d", l.remote,
writeDuration, n,
)
}
return
}
// getManagedACL returns the managed ACL instance if available
func (l *Listener) getManagedACL() *database.ManagedACL {
// Get the managed ACL instance from the ACL registry
for _, aclInstance := range acl.Registry.ACL {
if aclInstance.Type() == "managed" {
if managed, ok := aclInstance.(*acl.Managed); ok {
return managed.GetManagedACL()
}
}
}
return nil
}
// QueryEvents queries events using the database QueryEvents method
func (l *Listener) QueryEvents(ctx context.Context, f *filter.F) (event.S, error) {
return l.D.QueryEvents(ctx, f)
}
// QueryAllVersions queries events using the database QueryAllVersions method
func (l *Listener) QueryAllVersions(ctx context.Context, f *filter.F) (event.S, error) {
return l.D.QueryAllVersions(ctx, f)
}
// canSeePrivateEvent checks if the authenticated user can see an event with a private tag
func (l *Listener) canSeePrivateEvent(authedPubkey, privatePubkey []byte) (canSee bool) {
// If no authenticated user, deny access
if len(authedPubkey) == 0 {
return false
}
// If the authenticated user matches the private tag pubkey, allow access
if len(privatePubkey) > 0 && utils.FastEqual(authedPubkey, privatePubkey) {
return true
}
// Check if user is an admin or owner (they can see all private events)
accessLevel := acl.Registry.GetAccessLevel(authedPubkey, l.remote)
if accessLevel == "admin" || accessLevel == "owner" {
return true
}
// Default deny
return false
}