377 lines
9.2 KiB
Go
377 lines
9.2 KiB
Go
package ws
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/websocket"
|
|
)
|
|
|
|
// mockConn is an in-memory net.Conn stand-in for tests.
//
// Reads drain readData, writes accumulate in writeData, and either direction
// can be forced to fail by setting readErr / writeErr. The mock does no
// locking of its own.
type mockConn struct {
	readData, writeData []byte // bytes still to be served by Read / bytes captured by Write
	readErr, writeErr   error  // when non-nil, returned verbatim by Read / Write
	closed              bool   // set to true by Close
}

// Compile-time proof that *mockConn really satisfies net.Conn, so the mock
// cannot silently fall out of step with the interface it claims to implement.
var _ net.Conn = (*mockConn)(nil)

// Read returns readErr if it is set. Otherwise it copies as much of the
// pending readData as fits into b and consumes those bytes, reporting io.EOF
// once readData is empty.
func (m *mockConn) Read(b []byte) (n int, err error) {
	switch {
	case m.readErr != nil:
		return 0, m.readErr
	case len(m.readData) == 0:
		return 0, io.EOF
	}
	n = copy(b, m.readData)
	m.readData = m.readData[n:]
	return n, nil
}

// Write returns writeErr if it is set (recording nothing). Otherwise it
// appends b to writeData and reports the whole slice as written.
func (m *mockConn) Write(b []byte) (n int, err error) {
	if m.writeErr != nil {
		return 0, m.writeErr
	}
	m.writeData = append(m.writeData, b...)
	return len(b), nil
}

// Close marks the connection as closed and always succeeds. It does not
// prevent later Read or Write calls from working.
func (m *mockConn) Close() error {
	m.closed = true
	return nil
}

// LocalAddr returns a zero-valued TCP address.
func (m *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }

// RemoteAddr returns a zero-valued TCP address.
func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }

// SetDeadline is a no-op that always succeeds.
func (m *mockConn) SetDeadline(t time.Time) error { return nil }

// SetReadDeadline is a no-op that always succeeds.
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }

// SetWriteDeadline is a no-op that always succeeds.
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
func TestNewConnection(t *testing.T) {
|
|
// Create a test websocket server
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Echo server - just read and write back
|
|
for {
|
|
var msg string
|
|
if err := websocket.Message.Receive(conn, &msg); err != nil {
|
|
break
|
|
}
|
|
websocket.Message.Send(conn, msg)
|
|
}
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
// Convert http:// to ws://
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if conn.conn == nil {
|
|
t.Error("Connection should have a valid conn")
|
|
}
|
|
}
|
|
|
|
func TestNewConnectionWithTLS(t *testing.T) {
|
|
// Create a test HTTPS server
|
|
server := httptest.NewTLSServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Echo server
|
|
for {
|
|
var msg string
|
|
if err := websocket.Message.Receive(conn, &msg); err != nil {
|
|
break
|
|
}
|
|
websocket.Message.Send(conn, msg)
|
|
}
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
// Convert https:// to wss://
|
|
wsURL := "wss" + strings.TrimPrefix(server.URL, "https")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Use insecure TLS config for testing
|
|
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, tlsConfig)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection with TLS failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if conn.conn == nil {
|
|
t.Error("Connection should have a valid conn")
|
|
}
|
|
}
|
|
|
|
func TestNewConnectionWithHeaders(t *testing.T) {
|
|
expectedOrigin := "https://example.com"
|
|
var receivedOrigin string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
receivedOrigin = r.Header.Get("Origin")
|
|
websocket.Handler(func(conn *websocket.Conn) {
|
|
// Simple echo
|
|
io.Copy(conn, conn)
|
|
}).ServeHTTP(w, r)
|
|
}))
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
headers := http.Header{
|
|
"Origin": []string{expectedOrigin},
|
|
}
|
|
|
|
conn, err := NewConnection(ctx, wsURL, headers, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection with headers failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if receivedOrigin != expectedOrigin {
|
|
t.Errorf("Expected origin %s, got %s", expectedOrigin, receivedOrigin)
|
|
}
|
|
}
|
|
|
|
func TestNewConnectionTimeout(t *testing.T) {
|
|
// Use a non-existent address to force timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err := NewConnection(ctx, "ws://192.0.2.1:12345", nil, nil)
|
|
if err == nil {
|
|
t.Error("Expected connection to timeout")
|
|
}
|
|
}
|
|
|
|
func TestConnectionWriteMessage(t *testing.T) {
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Read the message
|
|
var msg string
|
|
websocket.Message.Receive(conn, &msg)
|
|
// Echo it back
|
|
websocket.Message.Send(conn, "received: "+msg)
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
testMessage := []byte("test message")
|
|
err = conn.WriteMessage(ctx, testMessage)
|
|
if err != nil {
|
|
t.Fatalf("WriteMessage failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestConnectionWriteMessageCanceled(t *testing.T) {
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Just keep the connection open
|
|
time.Sleep(10 * time.Second)
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Cancel the context immediately
|
|
cancelCtx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err = conn.WriteMessage(cancelCtx, []byte("test"))
|
|
if err == nil {
|
|
t.Error("Expected WriteMessage to fail with canceled context")
|
|
}
|
|
if !strings.Contains(err.Error(), "context canceled") {
|
|
t.Errorf("Expected context canceled error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestConnectionReadMessage(t *testing.T) {
|
|
testMessage := "hello world"
|
|
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Send a message immediately
|
|
websocket.Message.Send(conn, testMessage)
|
|
// Keep connection open for a bit
|
|
time.Sleep(100 * time.Millisecond)
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
buf := &bytes.Buffer{}
|
|
err = conn.ReadMessage(ctx, buf)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage failed: %v", err)
|
|
}
|
|
|
|
received := buf.String()
|
|
if received != testMessage {
|
|
t.Errorf("Expected %s, got %s", testMessage, received)
|
|
}
|
|
}
|
|
|
|
func TestConnectionReadMessageCanceled(t *testing.T) {
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Don't send anything, just keep connection open
|
|
time.Sleep(10 * time.Second)
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Cancel the context immediately
|
|
cancelCtx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
buf := &bytes.Buffer{}
|
|
err = conn.ReadMessage(cancelCtx, buf)
|
|
if err == nil {
|
|
t.Error("Expected ReadMessage to fail with canceled context")
|
|
}
|
|
if !strings.Contains(err.Error(), "context canceled") {
|
|
t.Errorf("Expected context canceled error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestConnectionClose(t *testing.T) {
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
// Keep connection open
|
|
time.Sleep(1 * time.Second)
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
|
|
err = conn.Close()
|
|
if err != nil {
|
|
t.Errorf("Close failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestConnectionWithCompression(t *testing.T) {
|
|
// This test verifies that compression is handled properly
|
|
// The actual compression negotiation happens during the handshake
|
|
server := httptest.NewServer(&websocket.Server{
|
|
Handshake: anyOriginHandshake,
|
|
Handler: func(conn *websocket.Conn) {
|
|
var msg string
|
|
websocket.Message.Receive(conn, &msg)
|
|
websocket.Message.Send(conn, "compressed: "+msg)
|
|
},
|
|
})
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewConnection failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Test that we can write and read even if compression is enabled
|
|
testMessage := []byte("test compression message")
|
|
err = conn.WriteMessage(ctx, testMessage)
|
|
if err != nil {
|
|
t.Fatalf("WriteMessage failed: %v", err)
|
|
}
|
|
|
|
buf := &bytes.Buffer{}
|
|
err = conn.ReadMessage(ctx, buf)
|
|
if err != nil {
|
|
t.Fatalf("ReadMessage failed: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(buf.String(), "compressed:") {
|
|
t.Error("Expected response to contain 'compressed:'")
|
|
}
|
|
}
|