diff --git a/ws/connection_test.go b/ws/connection_test.go new file mode 100644 index 0000000..87eae04 --- /dev/null +++ b/ws/connection_test.go @@ -0,0 +1,376 @@ +package ws + +import ( + "bytes" + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +// mockConn implements net.Conn for testing +type mockConn struct { + readData []byte + writeData []byte + readErr error + writeErr error + closed bool +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.readErr != nil { + return 0, m.readErr + } + if len(m.readData) == 0 { + return 0, io.EOF + } + n = copy(b, m.readData) + m.readData = m.readData[n:] + return n, nil +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + if m.writeErr != nil { + return 0, m.writeErr + } + m.writeData = append(m.writeData, b...) + return len(b), nil +} + +func (m *mockConn) Close() error { + m.closed = true + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestNewConnection(t *testing.T) { + // Create a test websocket server + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Echo server - just read and write back + for { + var msg string + if err := websocket.Message.Receive(conn, &msg); err != nil { + break + } + websocket.Message.Send(conn, msg) + } + }, + }) + defer server.Close() + + // Convert http:// to ws:// + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + defer conn.Close() + + if conn.conn == nil { + t.Error("Connection should have a valid conn") + } +} + +func TestNewConnectionWithTLS(t *testing.T) { + // Create a test HTTPS server + server := httptest.NewTLSServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Echo server + for { + var msg string + if err := websocket.Message.Receive(conn, &msg); err != nil { + break + } + websocket.Message.Send(conn, msg) + } + }, + }) + defer server.Close() + + // Convert https:// to wss:// + wsURL := "wss" + strings.TrimPrefix(server.URL, "https") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Use insecure TLS config for testing + tlsConfig := &tls.Config{InsecureSkipVerify: true} + + conn, err := NewConnection(ctx, wsURL, nil, tlsConfig) + if err != nil { + t.Fatalf("NewConnection with TLS failed: %v", err) + } + defer conn.Close() + + if conn.conn == nil { + t.Error("Connection should have a valid conn") + } +} + +func TestNewConnectionWithHeaders(t *testing.T) { + expectedOrigin := "https://example.com" + var receivedOrigin string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedOrigin = r.Header.Get("Origin") + websocket.Handler(func(conn *websocket.Conn) { + // Simple echo + io.Copy(conn, conn) + }).ServeHTTP(w, r) + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + headers := http.Header{ + "Origin": []string{expectedOrigin}, + } + + conn, err := NewConnection(ctx, wsURL, headers, nil) + if err != nil { + t.Fatalf("NewConnection with headers failed: %v", err) + } + defer conn.Close() + + if receivedOrigin != expectedOrigin { + t.Errorf("Expected origin %s, got %s", expectedOrigin, receivedOrigin) + } +} + +func TestNewConnectionTimeout(t *testing.T) { + // Use a non-existent address to force timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := NewConnection(ctx, "ws://192.0.2.1:12345", nil, nil) + if err == nil { + t.Error("Expected connection to timeout") + } +} + +func TestConnectionWriteMessage(t *testing.T) { + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Read the message + var msg string + websocket.Message.Receive(conn, &msg) + // Echo it back + websocket.Message.Send(conn, "received: "+msg) + }, + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + defer conn.Close() + + testMessage := []byte("test message") + err = conn.WriteMessage(ctx, testMessage) + if err != nil { + t.Fatalf("WriteMessage failed: %v", err) + } +} + +func TestConnectionWriteMessageCanceled(t *testing.T) { + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Just keep the connection open + time.Sleep(10 * time.Second) + }, + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + defer conn.Close() + + // Cancel the context immediately + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err = conn.WriteMessage(cancelCtx, []byte("test")) + if err == nil { + t.Error("Expected WriteMessage to fail with canceled context") + } + if !strings.Contains(err.Error(), "context canceled") { + t.Errorf("Expected context canceled error, got: %v", err) + } +} + +func TestConnectionReadMessage(t *testing.T) { + testMessage := "hello world" + + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Send a message immediately + websocket.Message.Send(conn, testMessage) + // Keep connection open for a bit + time.Sleep(100 * time.Millisecond) + }, + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + defer conn.Close() + + buf := &bytes.Buffer{} + err = conn.ReadMessage(ctx, buf) + if err != nil { + t.Fatalf("ReadMessage failed: %v", err) + } + + received := buf.String() + if received != testMessage { + t.Errorf("Expected %s, got %s", testMessage, received) + } +} + +func TestConnectionReadMessageCanceled(t *testing.T) { + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Don't send anything, just keep connection open + time.Sleep(10 * time.Second) + }, + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + defer conn.Close() + + // Cancel the context immediately + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + buf := &bytes.Buffer{} + err = conn.ReadMessage(cancelCtx, buf) + if err == nil { + t.Error("Expected ReadMessage to fail with canceled context") + } + if !strings.Contains(err.Error(), "context canceled") { + t.Errorf("Expected context canceled error, got: %v", err) + } +} + +func TestConnectionClose(t *testing.T) { + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + // Keep connection open + time.Sleep(1 * time.Second) + }, + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + + err = conn.Close() + if err != nil { + t.Errorf("Close failed: %v", err) + } +} + +func TestConnectionWithCompression(t *testing.T) { + // This test verifies that compression is handled properly + // The actual compression negotiation happens during the handshake + server := httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: func(conn *websocket.Conn) { + var msg string + websocket.Message.Receive(conn, &msg) + websocket.Message.Send(conn, "compressed: "+msg) + }, + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := NewConnection(ctx, wsURL, nil, nil) + if err != nil { + t.Fatalf("NewConnection failed: %v", err) + } + defer conn.Close() + + // Test that we can write and read even if compression is enabled + testMessage := []byte("test compression message") + err = conn.WriteMessage(ctx, testMessage) + if err != nil { + t.Fatalf("WriteMessage failed: %v", err) + } + + buf := &bytes.Buffer{} + err = conn.ReadMessage(ctx, buf) + if err != nil { + t.Fatalf("ReadMessage failed: %v", err) + } + + if !strings.Contains(buf.String(), "compressed:") { + t.Error("Expected response to contain 'compressed:'") + } +} diff --git a/ws/listener_test.go b/ws/listener_test.go new file mode 100644 index 0000000..1f61fc8 --- /dev/null +++ b/ws/listener_test.go @@ -0,0 +1,81 @@ +package ws + +import ( + "net/http/httptest" + "testing" +) + +// Test the remote address parsing logic which is the main testable functionality +func TestListenerSetRemoteFromReq(t *testing.T) { + tests := []struct { + name string + remoteAddr string + xForwardedFor string + expectedRemote string + }{ + { + name: "No X-Forwarded-For header", + remoteAddr: "192.168.1.1:8080", + xForwardedFor: "", + expectedRemote: "192.168.1.1:8080", + }, + { + name: "Single X-Forwarded-For", + remoteAddr: "192.168.1.1:8080", + xForwardedFor: "203.0.113.1", + expectedRemote: "203.0.113.1", + }, + { + name: "Double X-Forwarded-For", + remoteAddr: "192.168.1.1:8080", + xForwardedFor: "203.0.113.1 198.51.100.1", + expectedRemote: "198.51.100.1", + }, + { + name: "Empty X-Forwarded-For", + remoteAddr: "192.168.1.1:8080", + xForwardedFor: "", + expectedRemote: "192.168.1.1:8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test listener to access the setRemoteFromReq method + listener := &Listener{} + + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tt.remoteAddr + if tt.xForwardedFor != "" { + req.Header.Set("X-Forwarded-For", tt.xForwardedFor) + } + + // Call the method we want to test + listener.setRemoteFromReq(req) + + if listener.RealRemote() != tt.expectedRemote { + t.Errorf("Expected remote address to be '%s', got '%s'", tt.expectedRemote, listener.RealRemote()) + } + }) + } +} + +// Test that Req() returns the stored request +func TestListenerReq(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + listener := &Listener{Request: req} + + if listener.Req() != req { + t.Error("Expected Req() to return the stored request") + } +} + +// Test RealRemote method +func TestListenerRealRemote(t *testing.T) { + listener := &Listener{} + listener.remote.Store("test-remote-address") + + if listener.RealRemote() != "test-remote-address" { + t.Errorf("Expected RealRemote() to return 'test-remote-address', got '%s'", listener.RealRemote()) + } +} diff --git a/ws/pool_test.go b/ws/pool_test.go new file mode 100644 index 0000000..3144f2a --- /dev/null +++ b/ws/pool_test.go @@ -0,0 +1,298 @@ +package ws + +import ( + "context" + "sync" + "testing" + "time" + + "realy.lol/event" + "realy.lol/filter" + "realy.lol/filters" + "realy.lol/kind" + "realy.lol/kinds" + "realy.lol/signer" + "realy.lol/timestamp" +) + +// mockSigner implements signer.I for testing +type mockSigner struct { + pubkey []byte +} + +func (m *mockSigner) Pub() []byte { return m.pubkey } +func (m *mockSigner) Sign([]byte) ([]byte, error) { return []byte("mock-signature"), nil } +func (m *mockSigner) Generate() error { return nil } +func (m *mockSigner) InitSec([]byte) error { return nil } +func (m *mockSigner) InitPub([]byte) error { return nil } +func (m *mockSigner) Sec() []byte { return []byte("mock-secret") } +func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil } +func (m *mockSigner) Zero() {} +func (m *mockSigner) ECDH([]byte) ([]byte, error) { return []byte("mock-shared-secret"), nil } + +func TestNewPool(t *testing.T) { + ctx := context.Background() + pool := NewPool(ctx) + + if pool == nil { + t.Fatal("NewPool returned nil") + } + + if pool.Relays == nil { + t.Error("Pool should have initialized Relays map") + } + + if pool.Context == nil { + t.Error("Pool should have a context") + } +} + +func TestPoolWithAuthHandler(t *testing.T) { + ctx := context.Background() + + authHandler := WithAuthHandler(func() signer.I { + return &mockSigner{pubkey: []byte("test-pubkey")} + }) + + pool := NewPool(ctx, authHandler) + + if pool.authHandler == nil { + t.Error("Pool should have auth handler set") + } + + // Test that auth handler returns the expected signer + signer := pool.authHandler() + if string(signer.Pub()) != "test-pubkey" { + t.Errorf("Expected pubkey 'test-pubkey', got '%s'", string(signer.Pub())) + } +} + +func TestPoolWithEventMiddleware(t *testing.T) { + ctx := context.Background() + + var middlewareCalled bool + middleware := WithEventMiddleware(func(ie IncomingEvent) { + middlewareCalled = true + }) + + pool := NewPool(ctx, middleware) + + if len(pool.eventMiddleware) != 1 { + t.Errorf("Expected 1 middleware, got %d", len(pool.eventMiddleware)) + } + + // Test that middleware is called + testEvent := &event.T{ + Kind: kind.TextNote, + Content: []byte("test"), + CreatedAt: timestamp.Now(), + } + + ie := IncomingEvent{Event: testEvent, Client: nil} + pool.eventMiddleware[0](ie) + + if !middlewareCalled { + t.Error("Expected middleware to be called") + } +} + +func TestIncomingEventString(t *testing.T) { + testEvent := &event.T{ + Kind: kind.TextNote, + Content: []byte("test content"), + CreatedAt: timestamp.Now(), + } + + client := &Client{URL: "wss://test.relay"} + ie := IncomingEvent{Event: testEvent, Client: client} + + str := ie.String() + if !contains(str, "wss://test.relay") { + t.Errorf("Expected string to contain relay URL, got: %s", str) + } + + if !contains(str, "test content") { + t.Errorf("Expected string to contain event content, got: %s", str) + } +} + +func TestNamedLock(t *testing.T) { + // Test that named locks work correctly + var wg sync.WaitGroup + var counter int + var mu sync.Mutex + + lockName := "test-lock" + + // Start multiple goroutines that try to increment counter + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + unlock := namedLock(lockName) + defer unlock() + + // Critical section + mu.Lock() + temp := counter + time.Sleep(1 * time.Millisecond) // Simulate work + counter = temp + 1 + mu.Unlock() + }() + } + + wg.Wait() + + if counter != 10 { + t.Errorf("Expected counter to be 10, got %d", counter) + } +} + +func TestDirectedFilters(t *testing.T) { + f := &filter.T{ + Kinds: kinds.New(kind.TextNote), + Limit: uintPtr(10), + } + + df := DirectedFilters{ + Filters: filters.New(f), + Client: "wss://test.relay", + } + + if df.Client != "wss://test.relay" { + t.Errorf("Expected client to be 'wss://test.relay', got '%s'", df.Client) + } + + if df.Filters == nil { + t.Error("Expected filters to be set") + } +} + +func TestPoolEnsureRelayInvalidURL(t *testing.T) { + ctx := context.Background() + pool := NewPool(ctx) + + // Test with invalid URL + _, err := pool.EnsureRelay("invalid-url") + if err == nil { + t.Error("Expected error for invalid URL") + } +} + +func TestPoolQuerySingle(t *testing.T) { + ctx := context.Background() + pool := NewPool(ctx) + + // Test with empty URLs slice + result := pool.QuerySingle(ctx, []string{}, &filter.T{}) + if result != nil { + t.Error("Expected nil result for empty URLs") + } +} + +// Helper functions +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && + (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + containsSubstring(s, substr))) +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func uintPtr(u uint) *uint { + return &u +} + +// Test pool context cancellation +func TestPoolContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := NewPool(ctx) + + // Cancel the context + cancel() + + // Check that pool context is cancelled + select { + case <-pool.Context.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Error("Expected pool context to be cancelled") + } +} + +// Test multiple middleware +func TestPoolMultipleMiddleware(t *testing.T) { + ctx := context.Background() + + var middleware1Called, middleware2Called bool + + middleware1 := WithEventMiddleware(func(ie IncomingEvent) { + middleware1Called = true + }) + + middleware2 := WithEventMiddleware(func(ie IncomingEvent) { + middleware2Called = true + }) + + pool := NewPool(ctx, middleware1, middleware2) + + if len(pool.eventMiddleware) != 2 { + t.Errorf("Expected 2 middleware, got %d", len(pool.eventMiddleware)) + } + + // Test that both middleware are called + testEvent := &event.T{ + Kind: kind.TextNote, + Content: []byte("test"), + CreatedAt: timestamp.Now(), + } + + ie := IncomingEvent{Event: testEvent, Client: nil} + + for _, mw := range pool.eventMiddleware { + mw(ie) + } + + if !middleware1Called { + t.Error("Expected middleware1 to be called") + } + + if !middleware2Called { + t.Error("Expected middleware2 to be called") + } +} + +// Test pool with signature checker +func TestPoolWithSignatureChecker(t *testing.T) { + ctx := context.Background() + + pool := NewPool(ctx) + + // Set a custom signature checker + pool.SignatureChecker = func(e *event.T) bool { + return string(e.Content) == "valid" + } + + if pool.SignatureChecker == nil { + t.Error("Expected signature checker to be set") + } + + // Test the signature checker + validEvent := &event.T{Content: []byte("valid")} + invalidEvent := &event.T{Content: []byte("invalid")} + + if !pool.SignatureChecker(validEvent) { + t.Error("Expected valid event to pass signature check") + } + + if pool.SignatureChecker(invalidEvent) { + t.Error("Expected invalid event to fail signature check") + } +}