// Source: realy/ws/connection_test.go

package ws
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"golang.org/x/net/websocket"
)
// mockConn implements net.Conn for testing.
// Reads are served from an in-memory buffer, writes are recorded, and
// either direction can be primed with an error to simulate I/O failures.
type mockConn struct {
	readData  []byte // bytes handed out by Read; consumed from the front as they are read
	writeData []byte // every byte passed to Write, appended in call order
	readErr   error  // when non-nil, Read returns this error instead of data
	writeErr  error  // when non-nil, Write returns this error and records nothing
	closed    bool   // set to true once Close has been called
}
// Read copies as much of the pending readData into b as fits and drops the
// copied prefix. A primed readErr always wins; once the buffer is exhausted
// it reports io.EOF.
func (m *mockConn) Read(b []byte) (int, error) {
	switch {
	case m.readErr != nil:
		return 0, m.readErr
	case len(m.readData) == 0:
		return 0, io.EOF
	}
	copied := copy(b, m.readData)
	m.readData = m.readData[copied:]
	return copied, nil
}
// Write appends b to writeData and reports the full length as written,
// unless a writeErr has been primed, in which case nothing is recorded.
func (m *mockConn) Write(b []byte) (int, error) {
	if failure := m.writeErr; failure != nil {
		return 0, failure
	}
	m.writeData = append(m.writeData, b...)
	return len(b), nil
}
// Close flags the mock as closed so tests can assert on it; it never fails.
func (m *mockConn) Close() error {
	m.closed = true
	return nil
}
// LocalAddr returns a fresh, zero-valued TCP address.
func (m *mockConn) LocalAddr() net.Addr {
	return &net.TCPAddr{}
}

// RemoteAddr returns a fresh, zero-valued TCP address.
func (m *mockConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{}
}

// SetDeadline is a no-op; the mock has no notion of time.
func (m *mockConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op; the mock has no notion of time.
func (m *mockConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op; the mock has no notion of time.
func (m *mockConn) SetWriteDeadline(time.Time) error {
	return nil
}
// TestNewConnection dials a plain ws:// echo server and checks that the
// returned Connection wraps a non-nil underlying conn.
func TestNewConnection(t *testing.T) {
	echo := func(ws *websocket.Conn) {
		// Bounce every message back until a receive fails (peer gone).
		var msg string
		for websocket.Message.Receive(ws, &msg) == nil {
			websocket.Message.Send(ws, msg)
		}
	}
	srv := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler:   echo,
	})
	defer srv.Close()

	// httptest hands out http:// URLs; swap the scheme for ws://.
	endpoint := strings.Replace(srv.URL, "http", "ws", 1)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := NewConnection(ctx, endpoint, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}
	defer conn.Close()

	if conn.conn == nil {
		t.Error("Connection should have a valid conn")
	}
}
// TestNewConnectionWithTLS dials a wss:// echo server backed by httptest's
// self-signed certificate and checks that a usable Connection comes back.
func TestNewConnectionWithTLS(t *testing.T) {
	echo := func(ws *websocket.Conn) {
		// Bounce every message back until a receive fails (peer gone).
		var msg string
		for websocket.Message.Receive(ws, &msg) == nil {
			websocket.Message.Send(ws, msg)
		}
	}
	srv := httptest.NewTLSServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler:   echo,
	})
	defer srv.Close()

	// httptest hands out https:// URLs; swap the scheme for wss://.
	endpoint := strings.Replace(srv.URL, "https", "wss", 1)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The test server's certificate is self-signed, so skip verification.
	insecure := &tls.Config{InsecureSkipVerify: true}
	conn, err := NewConnection(ctx, endpoint, nil, insecure)
	if err != nil {
		t.Fatalf("NewConnection with TLS failed: %v", err)
	}
	defer conn.Close()

	if conn.conn == nil {
		t.Error("Connection should have a valid conn")
	}
}
// TestNewConnectionWithHeaders checks that headers passed to NewConnection
// are sent on the opening handshake request, using Origin as the probe.
func TestNewConnectionWithHeaders(t *testing.T) {
	const expectedOrigin = "https://example.com"

	// The HTTP handler runs on a server goroutine. Hand the observed header to
	// the test goroutine over a channel rather than a shared variable. Network
	// I/O creates no happens-before edge, so a plain variable would be a data
	// race under -race.
	originCh := make(chan string, 1)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case originCh <- r.Header.Get("Origin"):
		default: // only the first request matters; never block the handler
		}
		websocket.Handler(func(conn *websocket.Conn) {
			// Simple echo; a copy error just ends the session.
			io.Copy(conn, conn)
		}).ServeHTTP(w, r)
	}))
	defer server.Close()

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	headers := http.Header{
		"Origin": []string{expectedOrigin},
	}
	conn, err := NewConnection(ctx, wsURL, headers, nil)
	if err != nil {
		t.Fatalf("NewConnection with headers failed: %v", err)
	}
	defer conn.Close()

	select {
	case receivedOrigin := <-originCh:
		if receivedOrigin != expectedOrigin {
			t.Errorf("Expected origin %s, got %s", expectedOrigin, receivedOrigin)
		}
	case <-ctx.Done():
		t.Fatalf("server never observed the handshake request: %v", ctx.Err())
	}
}
// TestNewConnectionTimeout checks that NewConnection gives up instead of
// hanging when the target cannot be reached before the context expires.
func TestNewConnectionTimeout(t *testing.T) {
	// 192.0.2.0/24 is TEST-NET-1 (RFC 5737), reserved for documentation, so
	// nothing should accept this connection.
	const unreachable = "ws://192.0.2.1:12345"

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if _, err := NewConnection(ctx, unreachable, nil, nil); err == nil {
		t.Error("Expected connection to timeout")
	}
}
// TestConnectionWriteMessage checks that WriteMessage succeeds and that the
// payload actually reaches the server intact.
func TestConnectionWriteMessage(t *testing.T) {
	// The server hands whatever it receives back to the test goroutine so the
	// write can be verified end to end, not just by its nil error.
	received := make(chan string, 1)
	server := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler: func(conn *websocket.Conn) {
			var msg string
			if err := websocket.Message.Receive(conn, &msg); err != nil {
				return
			}
			received <- msg
			// Echo it back
			websocket.Message.Send(conn, "received: "+msg)
		},
	})
	defer server.Close()

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := NewConnection(ctx, wsURL, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}
	defer conn.Close()

	testMessage := []byte("test message")
	if err := conn.WriteMessage(ctx, testMessage); err != nil {
		t.Fatalf("WriteMessage failed: %v", err)
	}

	select {
	case got := <-received:
		if got != string(testMessage) {
			t.Errorf("server received %q, want %q", got, testMessage)
		}
	case <-ctx.Done():
		t.Fatalf("server did not receive the message: %v", ctx.Err())
	}
}
// TestConnectionWriteMessageCanceled checks that WriteMessage refuses to
// proceed when handed an already-canceled context.
func TestConnectionWriteMessageCanceled(t *testing.T) {
	server := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler: func(conn *websocket.Conn) {
			// Keep the connection open until the client goes away. Draining,
			// rather than sleeping a fixed 10s, ends the handler goroutine
			// once the test closes its side. httptest does not wait for
			// hijacked connections, so a sleeping handler would outlive the test.
			io.Copy(io.Discard, conn)
		},
	})
	defer server.Close()

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelDial()

	conn, err := NewConnection(dialCtx, wsURL, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}
	defer conn.Close()

	// A context that is canceled before WriteMessage is ever called.
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	err = conn.WriteMessage(canceledCtx, []byte("test"))
	if err == nil {
		// Fatal, not Error: the check below would dereference a nil error.
		t.Fatal("Expected WriteMessage to fail with canceled context")
	}
	if !strings.Contains(err.Error(), "context canceled") {
		t.Errorf("Expected context canceled error, got: %v", err)
	}
}
// TestConnectionReadMessage has the server push one message right after the
// handshake and checks that ReadMessage delivers exactly that payload.
func TestConnectionReadMessage(t *testing.T) {
	const want = "hello world"

	srv := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler: func(ws *websocket.Conn) {
			websocket.Message.Send(ws, want)
			// Linger briefly so the connection is still up while the client reads.
			time.Sleep(100 * time.Millisecond)
		},
	})
	defer srv.Close()

	endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := NewConnection(ctx, endpoint, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}
	defer conn.Close()

	var buf bytes.Buffer
	if err := conn.ReadMessage(ctx, &buf); err != nil {
		t.Fatalf("ReadMessage failed: %v", err)
	}
	if got := buf.String(); got != want {
		t.Errorf("Expected %s, got %s", want, got)
	}
}
// TestConnectionReadMessageCanceled checks that ReadMessage returns promptly
// with a cancellation error when handed an already-canceled context, even
// though the server never sends anything.
func TestConnectionReadMessageCanceled(t *testing.T) {
	server := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler: func(conn *websocket.Conn) {
			// Send nothing and keep the connection open until the client goes
			// away. Draining, rather than sleeping a fixed 10s, ends the handler
			// goroutine once the test closes its side. httptest does not wait
			// for hijacked connections, so a sleeping handler would outlive the test.
			io.Copy(io.Discard, conn)
		},
	})
	defer server.Close()

	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelDial()

	conn, err := NewConnection(dialCtx, wsURL, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}
	defer conn.Close()

	// A context that is canceled before ReadMessage is ever called.
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	buf := &bytes.Buffer{}
	err = conn.ReadMessage(canceledCtx, buf)
	if err == nil {
		// Fatal, not Error: the check below would dereference a nil error.
		t.Fatal("Expected ReadMessage to fail with canceled context")
	}
	if !strings.Contains(err.Error(), "context canceled") {
		t.Errorf("Expected context canceled error, got: %v", err)
	}
}
// TestConnectionClose opens a connection to a server that stays up briefly and
// checks that closing it from the client side reports no error.
func TestConnectionClose(t *testing.T) {
	srv := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler: func(*websocket.Conn) {
			// Hold the server side open so Close acts on a live connection.
			time.Sleep(time.Second)
		},
	})
	defer srv.Close()

	endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := NewConnection(ctx, endpoint, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}

	if closeErr := conn.Close(); closeErr != nil {
		t.Errorf("Close failed: %v", closeErr)
	}
}
// TestConnectionWithCompression runs a full write-then-read round trip. Any
// compression is negotiated during the handshake, so the point is that
// ordinary message traffic still works whatever was agreed there.
func TestConnectionWithCompression(t *testing.T) {
	srv := httptest.NewServer(&websocket.Server{
		Handshake: anyOriginHandshake,
		Handler: func(ws *websocket.Conn) {
			// Answer a single message with a tagged copy of it.
			var in string
			websocket.Message.Receive(ws, &in)
			websocket.Message.Send(ws, "compressed: "+in)
		},
	})
	defer srv.Close()

	endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := NewConnection(ctx, endpoint, nil, nil)
	if err != nil {
		t.Fatalf("NewConnection failed: %v", err)
	}
	defer conn.Close()

	payload := []byte("test compression message")
	if err := conn.WriteMessage(ctx, payload); err != nil {
		t.Fatalf("WriteMessage failed: %v", err)
	}

	var reply bytes.Buffer
	if err := conn.ReadMessage(ctx, &reply); err != nil {
		t.Fatalf("ReadMessage failed: %v", err)
	}
	if !strings.Contains(reply.String(), "compressed:") {
		t.Error("Expected response to contain 'compressed:'")
	}
}