Update WebSocket implementation to use Gorilla WebSocket library
- Replaced the existing `github.com/coder/websocket` package with `github.com/gorilla/websocket` for improved functionality and compatibility. - Adjusted WebSocket connection handling, including message reading and writing, to align with the new library's API. - Enhanced error handling and logging for WebSocket operations. - Bumped version to v0.20.0 to reflect the changes made.
This commit is contained in:
@@ -3,21 +3,19 @@ package ws
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"lol.mleku.dev/errorf"
|
||||
"next.orly.dev/pkg/utils/units"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// Connection represents a websocket connection to a Nostr relay.
|
||||
type Connection struct {
|
||||
conn *ws.Conn
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
// NewConnection creates a new websocket connection to a Nostr relay.
|
||||
@@ -25,10 +23,23 @@ func NewConnection(
|
||||
ctx context.Context, url string, reqHeader http.Header,
|
||||
tlsConfig *tls.Config,
|
||||
) (c *Connection, err error) {
|
||||
var conn *ws.Conn
|
||||
if conn, _, err = ws.Dial(
|
||||
ctx, url, getConnectionOptions(reqHeader, tlsConfig),
|
||||
); err != nil {
|
||||
var conn *websocket.Conn
|
||||
var resp *http.Response
|
||||
dialer := getConnectionOptions(reqHeader, tlsConfig)
|
||||
|
||||
// Prepare headers with default User-Agent if not present
|
||||
headers := reqHeader
|
||||
if headers == nil {
|
||||
headers = make(http.Header)
|
||||
}
|
||||
if headers.Get("User-Agent") == "" {
|
||||
headers.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
|
||||
}
|
||||
|
||||
if conn, resp, err = dialer.DialContext(ctx, url, headers); err != nil {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
conn.SetReadLimit(33 * units.Mb)
|
||||
@@ -41,7 +52,14 @@ func NewConnection(
|
||||
func (c *Connection) WriteMessage(
|
||||
ctx context.Context, data []byte,
|
||||
) (err error) {
|
||||
if err = c.conn.Write(ctx, ws.MessageText, data); err != nil {
|
||||
deadline := time.Now().Add(10 * time.Second)
|
||||
if ctx != nil {
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
}
|
||||
}
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
if err = c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
err = errorf.E("failed to write message: %w", err)
|
||||
return
|
||||
}
|
||||
@@ -52,11 +70,22 @@ func (c *Connection) WriteMessage(
|
||||
func (c *Connection) ReadMessage(
|
||||
ctx context.Context, buf io.Writer,
|
||||
) (err error) {
|
||||
var reader io.Reader
|
||||
if _, reader, err = c.conn.Reader(ctx); err != nil {
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
if ctx != nil {
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
}
|
||||
}
|
||||
c.conn.SetReadDeadline(deadline)
|
||||
messageType, reader, err := c.conn.NextReader()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get reader: %w", err)
|
||||
return
|
||||
}
|
||||
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
|
||||
err = fmt.Errorf("unexpected message type: %d", messageType)
|
||||
return
|
||||
}
|
||||
if _, err = io.Copy(buf, reader); err != nil {
|
||||
err = fmt.Errorf("failed to read message: %w", err)
|
||||
return
|
||||
@@ -66,14 +95,18 @@ func (c *Connection) ReadMessage(
|
||||
|
||||
// Close closes the websocket connection.
|
||||
func (c *Connection) Close() error {
|
||||
return c.conn.Close(ws.StatusNormalClosure, "")
|
||||
c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Ping sends a ping message to the websocket connection.
|
||||
func (c *Connection) Ping(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeoutCause(
|
||||
ctx, time.Millisecond*800, errors.New("ping took too long"),
|
||||
)
|
||||
defer cancel()
|
||||
return c.conn.Ping(ctx)
|
||||
deadline := time.Now().Add(800 * time.Millisecond)
|
||||
if ctx != nil {
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
}
|
||||
}
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
|
||||
}
|
||||
|
||||
@@ -5,32 +5,21 @@ package ws
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"time"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var defaultConnectionOptions = &ws.DialOptions{
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
|
||||
},
|
||||
}
|
||||
|
||||
func getConnectionOptions(
|
||||
requestHeader http.Header, tlsConfig *tls.Config,
|
||||
) *ws.DialOptions {
|
||||
if requestHeader == nil && tlsConfig == nil {
|
||||
return defaultConnectionOptions
|
||||
}
|
||||
|
||||
return &ws.DialOptions{
|
||||
HTTPHeader: requestHeader,
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
},
|
||||
) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
TLSClientConfig: tlsConfig,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
// Headers are passed directly to DialContext, not set on Dialer
|
||||
// The User-Agent header will be set when calling DialContext if not present
|
||||
return dialer
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user