Update WebSocket implementation to use Gorilla WebSocket library
- Replaced the existing `github.com/coder/websocket` package with `github.com/gorilla/websocket` for improved functionality and compatibility. - Adjusted WebSocket connection handling, including message reading and writing, to align with the new library's API. - Enhanced error handling and logging for WebSocket operations. - Bumped version to v0.20.0 to reflect the changes made.
This commit is contained in:
@@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/gorilla/websocket"
|
||||
"lol.mleku.dev/chk"
|
||||
"lol.mleku.dev/errorf"
|
||||
"lol.mleku.dev/log"
|
||||
@@ -396,12 +396,15 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
|
||||
headers.Set("Origin", "https://orly.dev")
|
||||
|
||||
// Use proper WebSocket dial options
|
||||
dialOptions := &websocket.DialOptions{
|
||||
HTTPHeader: headers,
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
c, _, err := websocket.Dial(connCtx, u, dialOptions)
|
||||
c, resp, err := dialer.DialContext(connCtx, u, headers)
|
||||
cancel()
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
log.W.F("follows syncer: dial %s failed: %v", u, err)
|
||||
|
||||
@@ -480,13 +483,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
|
||||
req := reqenvelope.NewFrom([]byte(subID), ff)
|
||||
reqBytes := req.Marshal(nil)
|
||||
log.T.F("follows syncer: outbound REQ to %s: %s", u, string(reqBytes))
|
||||
if err = c.Write(
|
||||
ctx, websocket.MessageText, reqBytes,
|
||||
); chk.E(err) {
|
||||
c.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err = c.WriteMessage(websocket.TextMessage, reqBytes); chk.E(err) {
|
||||
log.W.F(
|
||||
"follows syncer: failed to send event REQ to %s: %v", u, err,
|
||||
)
|
||||
_ = c.Close(websocket.StatusInternalError, "write failed")
|
||||
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "write failed"), time.Now().Add(time.Second))
|
||||
continue
|
||||
}
|
||||
log.T.F(
|
||||
@@ -501,11 +503,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = c.Close(websocket.StatusNormalClosure, "ctx done")
|
||||
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "ctx done"), time.Now().Add(time.Second))
|
||||
return
|
||||
case <-keepaliveTicker.C:
|
||||
// Send ping to keep connection alive
|
||||
if err := c.Ping(ctx); err != nil {
|
||||
c.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
|
||||
log.T.F("follows syncer: ping failed for %s: %v", u, err)
|
||||
break readLoop
|
||||
}
|
||||
@@ -513,11 +516,10 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
|
||||
continue
|
||||
default:
|
||||
// Set a read timeout to avoid hanging
|
||||
readCtx, readCancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
_, data, err := c.Read(readCtx)
|
||||
readCancel()
|
||||
c.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
_, data, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
_ = c.Close(websocket.StatusNormalClosure, "read err")
|
||||
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "read err"), time.Now().Add(time.Second))
|
||||
break readLoop
|
||||
}
|
||||
label, rem, err := envelopes.Identify(data)
|
||||
@@ -714,16 +716,19 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
|
||||
headers.Set("Origin", "https://orly.dev")
|
||||
|
||||
// Use proper WebSocket dial options
|
||||
dialOptions := &websocket.DialOptions{
|
||||
HTTPHeader: headers,
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
c, _, err := websocket.Dial(ctx, relayURL, dialOptions)
|
||||
c, resp, err := dialer.DialContext(ctx, relayURL, headers)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
log.W.F("follows syncer: failed to connect to %s for follow list fetch: %v", relayURL, err)
|
||||
return
|
||||
}
|
||||
defer c.Close(websocket.StatusNormalClosure, "follow list fetch complete")
|
||||
defer c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "follow list fetch complete"), time.Now().Add(time.Second))
|
||||
|
||||
log.I.F("follows syncer: fetching follow lists from relay %s", relayURL)
|
||||
|
||||
@@ -746,7 +751,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
|
||||
req := reqenvelope.NewFrom([]byte(subID), ff)
|
||||
reqBytes := req.Marshal(nil)
|
||||
log.T.F("follows syncer: outbound REQ to %s: %s", relayURL, string(reqBytes))
|
||||
if err = c.Write(ctx, websocket.MessageText, reqBytes); chk.E(err) {
|
||||
c.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err = c.WriteMessage(websocket.TextMessage, reqBytes); chk.E(err) {
|
||||
log.W.F("follows syncer: failed to send follow list REQ to %s: %v", relayURL, err)
|
||||
return
|
||||
}
|
||||
@@ -769,7 +775,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
|
||||
default:
|
||||
}
|
||||
|
||||
_, data, err := c.Read(ctx)
|
||||
c.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
_, data, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.T.F("follows syncer: error reading events from %s: %v", relayURL, err)
|
||||
goto processEvents
|
||||
|
||||
@@ -3,21 +3,19 @@ package ws
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"lol.mleku.dev/errorf"
|
||||
"next.orly.dev/pkg/utils/units"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// Connection represents a websocket connection to a Nostr relay.
|
||||
type Connection struct {
|
||||
conn *ws.Conn
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
// NewConnection creates a new websocket connection to a Nostr relay.
|
||||
@@ -25,10 +23,23 @@ func NewConnection(
|
||||
ctx context.Context, url string, reqHeader http.Header,
|
||||
tlsConfig *tls.Config,
|
||||
) (c *Connection, err error) {
|
||||
var conn *ws.Conn
|
||||
if conn, _, err = ws.Dial(
|
||||
ctx, url, getConnectionOptions(reqHeader, tlsConfig),
|
||||
); err != nil {
|
||||
var conn *websocket.Conn
|
||||
var resp *http.Response
|
||||
dialer := getConnectionOptions(reqHeader, tlsConfig)
|
||||
|
||||
// Prepare headers with default User-Agent if not present
|
||||
headers := reqHeader
|
||||
if headers == nil {
|
||||
headers = make(http.Header)
|
||||
}
|
||||
if headers.Get("User-Agent") == "" {
|
||||
headers.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
|
||||
}
|
||||
|
||||
if conn, resp, err = dialer.DialContext(ctx, url, headers); err != nil {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
conn.SetReadLimit(33 * units.Mb)
|
||||
@@ -41,7 +52,14 @@ func NewConnection(
|
||||
func (c *Connection) WriteMessage(
|
||||
ctx context.Context, data []byte,
|
||||
) (err error) {
|
||||
if err = c.conn.Write(ctx, ws.MessageText, data); err != nil {
|
||||
deadline := time.Now().Add(10 * time.Second)
|
||||
if ctx != nil {
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
}
|
||||
}
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
if err = c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
err = errorf.E("failed to write message: %w", err)
|
||||
return
|
||||
}
|
||||
@@ -52,11 +70,22 @@ func (c *Connection) WriteMessage(
|
||||
func (c *Connection) ReadMessage(
|
||||
ctx context.Context, buf io.Writer,
|
||||
) (err error) {
|
||||
var reader io.Reader
|
||||
if _, reader, err = c.conn.Reader(ctx); err != nil {
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
if ctx != nil {
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
}
|
||||
}
|
||||
c.conn.SetReadDeadline(deadline)
|
||||
messageType, reader, err := c.conn.NextReader()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get reader: %w", err)
|
||||
return
|
||||
}
|
||||
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
|
||||
err = fmt.Errorf("unexpected message type: %d", messageType)
|
||||
return
|
||||
}
|
||||
if _, err = io.Copy(buf, reader); err != nil {
|
||||
err = fmt.Errorf("failed to read message: %w", err)
|
||||
return
|
||||
@@ -66,14 +95,18 @@ func (c *Connection) ReadMessage(
|
||||
|
||||
// Close closes the websocket connection.
|
||||
func (c *Connection) Close() error {
|
||||
return c.conn.Close(ws.StatusNormalClosure, "")
|
||||
c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Ping sends a ping message to the websocket connection.
|
||||
func (c *Connection) Ping(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeoutCause(
|
||||
ctx, time.Millisecond*800, errors.New("ping took too long"),
|
||||
)
|
||||
defer cancel()
|
||||
return c.conn.Ping(ctx)
|
||||
deadline := time.Now().Add(800 * time.Millisecond)
|
||||
if ctx != nil {
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
}
|
||||
}
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
|
||||
}
|
||||
|
||||
@@ -5,32 +5,21 @@ package ws
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"time"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var defaultConnectionOptions = &ws.DialOptions{
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
|
||||
},
|
||||
}
|
||||
|
||||
func getConnectionOptions(
|
||||
requestHeader http.Header, tlsConfig *tls.Config,
|
||||
) *ws.DialOptions {
|
||||
if requestHeader == nil && tlsConfig == nil {
|
||||
return defaultConnectionOptions
|
||||
}
|
||||
|
||||
return &ws.DialOptions{
|
||||
HTTPHeader: requestHeader,
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
},
|
||||
) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
TLSClientConfig: tlsConfig,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
// Headers are passed directly to DialContext, not set on Dialer
|
||||
// The User-Agent header will be set when calling DialContext if not present
|
||||
return dialer
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
v0.19.9
|
||||
v0.20.0
|
||||
Reference in New Issue
Block a user