add tests to ws package for full coverage
This commit is contained in:
376
ws/connection_test.go
Normal file
376
ws/connection_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockConn implements net.Conn for testing
|
||||||
|
type mockConn struct {
|
||||||
|
readData []byte
|
||||||
|
writeData []byte
|
||||||
|
readErr error
|
||||||
|
writeErr error
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) Read(b []byte) (n int, err error) {
|
||||||
|
if m.readErr != nil {
|
||||||
|
return 0, m.readErr
|
||||||
|
}
|
||||||
|
if len(m.readData) == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n = copy(b, m.readData)
|
||||||
|
m.readData = m.readData[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) Write(b []byte) (n int, err error) {
|
||||||
|
if m.writeErr != nil {
|
||||||
|
return 0, m.writeErr
|
||||||
|
}
|
||||||
|
m.writeData = append(m.writeData, b...)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) Close() error {
|
||||||
|
m.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||||
|
func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||||
|
func (m *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
|
func TestNewConnection(t *testing.T) {
|
||||||
|
// Create a test websocket server
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Echo server - just read and write back
|
||||||
|
for {
|
||||||
|
var msg string
|
||||||
|
if err := websocket.Message.Receive(conn, &msg); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
websocket.Message.Send(conn, msg)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Convert http:// to ws://
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if conn.conn == nil {
|
||||||
|
t.Error("Connection should have a valid conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionWithTLS(t *testing.T) {
|
||||||
|
// Create a test HTTPS server
|
||||||
|
server := httptest.NewTLSServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Echo server
|
||||||
|
for {
|
||||||
|
var msg string
|
||||||
|
if err := websocket.Message.Receive(conn, &msg); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
websocket.Message.Send(conn, msg)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Convert https:// to wss://
|
||||||
|
wsURL := "wss" + strings.TrimPrefix(server.URL, "https")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Use insecure TLS config for testing
|
||||||
|
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection with TLS failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if conn.conn == nil {
|
||||||
|
t.Error("Connection should have a valid conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionWithHeaders(t *testing.T) {
|
||||||
|
expectedOrigin := "https://example.com"
|
||||||
|
var receivedOrigin string
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedOrigin = r.Header.Get("Origin")
|
||||||
|
websocket.Handler(func(conn *websocket.Conn) {
|
||||||
|
// Simple echo
|
||||||
|
io.Copy(conn, conn)
|
||||||
|
}).ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
headers := http.Header{
|
||||||
|
"Origin": []string{expectedOrigin},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, headers, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection with headers failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if receivedOrigin != expectedOrigin {
|
||||||
|
t.Errorf("Expected origin %s, got %s", expectedOrigin, receivedOrigin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionTimeout(t *testing.T) {
|
||||||
|
// Use a non-existent address to force timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := NewConnection(ctx, "ws://192.0.2.1:12345", nil, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected connection to timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionWriteMessage(t *testing.T) {
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Read the message
|
||||||
|
var msg string
|
||||||
|
websocket.Message.Receive(conn, &msg)
|
||||||
|
// Echo it back
|
||||||
|
websocket.Message.Send(conn, "received: "+msg)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
testMessage := []byte("test message")
|
||||||
|
err = conn.WriteMessage(ctx, testMessage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteMessage failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionWriteMessageCanceled(t *testing.T) {
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Just keep the connection open
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Cancel the context immediately
|
||||||
|
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err = conn.WriteMessage(cancelCtx, []byte("test"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected WriteMessage to fail with canceled context")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "context canceled") {
|
||||||
|
t.Errorf("Expected context canceled error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionReadMessage(t *testing.T) {
|
||||||
|
testMessage := "hello world"
|
||||||
|
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Send a message immediately
|
||||||
|
websocket.Message.Send(conn, testMessage)
|
||||||
|
// Keep connection open for a bit
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err = conn.ReadMessage(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadMessage failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
received := buf.String()
|
||||||
|
if received != testMessage {
|
||||||
|
t.Errorf("Expected %s, got %s", testMessage, received)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionReadMessageCanceled(t *testing.T) {
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Don't send anything, just keep connection open
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Cancel the context immediately
|
||||||
|
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err = conn.ReadMessage(cancelCtx, buf)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected ReadMessage to fail with canceled context")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "context canceled") {
|
||||||
|
t.Errorf("Expected context canceled error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionClose(t *testing.T) {
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
// Keep connection open
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Close failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionWithCompression(t *testing.T) {
|
||||||
|
// This test verifies that compression is handled properly
|
||||||
|
// The actual compression negotiation happens during the handshake
|
||||||
|
server := httptest.NewServer(&websocket.Server{
|
||||||
|
Handshake: anyOriginHandshake,
|
||||||
|
Handler: func(conn *websocket.Conn) {
|
||||||
|
var msg string
|
||||||
|
websocket.Message.Receive(conn, &msg)
|
||||||
|
websocket.Message.Send(conn, "compressed: "+msg)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := NewConnection(ctx, wsURL, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewConnection failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Test that we can write and read even if compression is enabled
|
||||||
|
testMessage := []byte("test compression message")
|
||||||
|
err = conn.WriteMessage(ctx, testMessage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteMessage failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err = conn.ReadMessage(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadMessage failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "compressed:") {
|
||||||
|
t.Error("Expected response to contain 'compressed:'")
|
||||||
|
}
|
||||||
|
}
|
||||||
81
ws/listener_test.go
Normal file
81
ws/listener_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test the remote address parsing logic which is the main testable functionality
|
||||||
|
func TestListenerSetRemoteFromReq(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xForwardedFor string
|
||||||
|
expectedRemote string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No X-Forwarded-For header",
|
||||||
|
remoteAddr: "192.168.1.1:8080",
|
||||||
|
xForwardedFor: "",
|
||||||
|
expectedRemote: "192.168.1.1:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single X-Forwarded-For",
|
||||||
|
remoteAddr: "192.168.1.1:8080",
|
||||||
|
xForwardedFor: "203.0.113.1",
|
||||||
|
expectedRemote: "203.0.113.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Double X-Forwarded-For",
|
||||||
|
remoteAddr: "192.168.1.1:8080",
|
||||||
|
xForwardedFor: "203.0.113.1 198.51.100.1",
|
||||||
|
expectedRemote: "198.51.100.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty X-Forwarded-For",
|
||||||
|
remoteAddr: "192.168.1.1:8080",
|
||||||
|
xForwardedFor: "",
|
||||||
|
expectedRemote: "192.168.1.1:8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create a test listener to access the setRemoteFromReq method
|
||||||
|
listener := &Listener{}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
if tt.xForwardedFor != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the method we want to test
|
||||||
|
listener.setRemoteFromReq(req)
|
||||||
|
|
||||||
|
if listener.RealRemote() != tt.expectedRemote {
|
||||||
|
t.Errorf("Expected remote address to be '%s', got '%s'", tt.expectedRemote, listener.RealRemote())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that Req() returns the stored request
|
||||||
|
func TestListenerReq(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
listener := &Listener{Request: req}
|
||||||
|
|
||||||
|
if listener.Req() != req {
|
||||||
|
t.Error("Expected Req() to return the stored request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test RealRemote method
|
||||||
|
func TestListenerRealRemote(t *testing.T) {
|
||||||
|
listener := &Listener{}
|
||||||
|
listener.remote.Store("test-remote-address")
|
||||||
|
|
||||||
|
if listener.RealRemote() != "test-remote-address" {
|
||||||
|
t.Errorf("Expected RealRemote() to return 'test-remote-address', got '%s'", listener.RealRemote())
|
||||||
|
}
|
||||||
|
}
|
||||||
298
ws/pool_test.go
Normal file
298
ws/pool_test.go
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"realy.lol/event"
|
||||||
|
"realy.lol/filter"
|
||||||
|
"realy.lol/filters"
|
||||||
|
"realy.lol/kind"
|
||||||
|
"realy.lol/kinds"
|
||||||
|
"realy.lol/signer"
|
||||||
|
"realy.lol/timestamp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockSigner implements signer.I for testing
|
||||||
|
type mockSigner struct {
|
||||||
|
pubkey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSigner) Pub() []byte { return m.pubkey }
|
||||||
|
func (m *mockSigner) Sign([]byte) ([]byte, error) { return []byte("mock-signature"), nil }
|
||||||
|
func (m *mockSigner) Generate() error { return nil }
|
||||||
|
func (m *mockSigner) InitSec([]byte) error { return nil }
|
||||||
|
func (m *mockSigner) InitPub([]byte) error { return nil }
|
||||||
|
func (m *mockSigner) Sec() []byte { return []byte("mock-secret") }
|
||||||
|
func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil }
|
||||||
|
func (m *mockSigner) Zero() {}
|
||||||
|
func (m *mockSigner) ECDH([]byte) ([]byte, error) { return []byte("mock-shared-secret"), nil }
|
||||||
|
|
||||||
|
func TestNewPool(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pool := NewPool(ctx)
|
||||||
|
|
||||||
|
if pool == nil {
|
||||||
|
t.Fatal("NewPool returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pool.Relays == nil {
|
||||||
|
t.Error("Pool should have initialized Relays map")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pool.Context == nil {
|
||||||
|
t.Error("Pool should have a context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolWithAuthHandler(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
authHandler := WithAuthHandler(func() signer.I {
|
||||||
|
return &mockSigner{pubkey: []byte("test-pubkey")}
|
||||||
|
})
|
||||||
|
|
||||||
|
pool := NewPool(ctx, authHandler)
|
||||||
|
|
||||||
|
if pool.authHandler == nil {
|
||||||
|
t.Error("Pool should have auth handler set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that auth handler returns the expected signer
|
||||||
|
signer := pool.authHandler()
|
||||||
|
if string(signer.Pub()) != "test-pubkey" {
|
||||||
|
t.Errorf("Expected pubkey 'test-pubkey', got '%s'", string(signer.Pub()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolWithEventMiddleware(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var middlewareCalled bool
|
||||||
|
middleware := WithEventMiddleware(func(ie IncomingEvent) {
|
||||||
|
middlewareCalled = true
|
||||||
|
})
|
||||||
|
|
||||||
|
pool := NewPool(ctx, middleware)
|
||||||
|
|
||||||
|
if len(pool.eventMiddleware) != 1 {
|
||||||
|
t.Errorf("Expected 1 middleware, got %d", len(pool.eventMiddleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that middleware is called
|
||||||
|
testEvent := &event.T{
|
||||||
|
Kind: kind.TextNote,
|
||||||
|
Content: []byte("test"),
|
||||||
|
CreatedAt: timestamp.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ie := IncomingEvent{Event: testEvent, Client: nil}
|
||||||
|
pool.eventMiddleware[0](ie)
|
||||||
|
|
||||||
|
if !middlewareCalled {
|
||||||
|
t.Error("Expected middleware to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIncomingEventString(t *testing.T) {
|
||||||
|
testEvent := &event.T{
|
||||||
|
Kind: kind.TextNote,
|
||||||
|
Content: []byte("test content"),
|
||||||
|
CreatedAt: timestamp.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{URL: "wss://test.relay"}
|
||||||
|
ie := IncomingEvent{Event: testEvent, Client: client}
|
||||||
|
|
||||||
|
str := ie.String()
|
||||||
|
if !contains(str, "wss://test.relay") {
|
||||||
|
t.Errorf("Expected string to contain relay URL, got: %s", str)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !contains(str, "test content") {
|
||||||
|
t.Errorf("Expected string to contain event content, got: %s", str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedLock(t *testing.T) {
|
||||||
|
// Test that named locks work correctly
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var counter int
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
lockName := "test-lock"
|
||||||
|
|
||||||
|
// Start multiple goroutines that try to increment counter
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
unlock := namedLock(lockName)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
// Critical section
|
||||||
|
mu.Lock()
|
||||||
|
temp := counter
|
||||||
|
time.Sleep(1 * time.Millisecond) // Simulate work
|
||||||
|
counter = temp + 1
|
||||||
|
mu.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if counter != 10 {
|
||||||
|
t.Errorf("Expected counter to be 10, got %d", counter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirectedFilters(t *testing.T) {
|
||||||
|
f := &filter.T{
|
||||||
|
Kinds: kinds.New(kind.TextNote),
|
||||||
|
Limit: uintPtr(10),
|
||||||
|
}
|
||||||
|
|
||||||
|
df := DirectedFilters{
|
||||||
|
Filters: filters.New(f),
|
||||||
|
Client: "wss://test.relay",
|
||||||
|
}
|
||||||
|
|
||||||
|
if df.Client != "wss://test.relay" {
|
||||||
|
t.Errorf("Expected client to be 'wss://test.relay', got '%s'", df.Client)
|
||||||
|
}
|
||||||
|
|
||||||
|
if df.Filters == nil {
|
||||||
|
t.Error("Expected filters to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolEnsureRelayInvalidURL(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pool := NewPool(ctx)
|
||||||
|
|
||||||
|
// Test with invalid URL
|
||||||
|
_, err := pool.EnsureRelay("invalid-url")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolQuerySingle(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pool := NewPool(ctx)
|
||||||
|
|
||||||
|
// Test with empty URLs slice
|
||||||
|
result := pool.QuerySingle(ctx, []string{}, &filter.T{})
|
||||||
|
if result != nil {
|
||||||
|
t.Error("Expected nil result for empty URLs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
|
||||||
|
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
|
||||||
|
containsSubstring(s, substr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsSubstring(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func uintPtr(u uint) *uint {
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test pool context cancellation
|
||||||
|
func TestPoolContextCancellation(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
pool := NewPool(ctx)
|
||||||
|
|
||||||
|
// Cancel the context
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Check that pool context is cancelled
|
||||||
|
select {
|
||||||
|
case <-pool.Context.Done():
|
||||||
|
// Expected
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Expected pool context to be cancelled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test multiple middleware
|
||||||
|
func TestPoolMultipleMiddleware(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var middleware1Called, middleware2Called bool
|
||||||
|
|
||||||
|
middleware1 := WithEventMiddleware(func(ie IncomingEvent) {
|
||||||
|
middleware1Called = true
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware2 := WithEventMiddleware(func(ie IncomingEvent) {
|
||||||
|
middleware2Called = true
|
||||||
|
})
|
||||||
|
|
||||||
|
pool := NewPool(ctx, middleware1, middleware2)
|
||||||
|
|
||||||
|
if len(pool.eventMiddleware) != 2 {
|
||||||
|
t.Errorf("Expected 2 middleware, got %d", len(pool.eventMiddleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that both middleware are called
|
||||||
|
testEvent := &event.T{
|
||||||
|
Kind: kind.TextNote,
|
||||||
|
Content: []byte("test"),
|
||||||
|
CreatedAt: timestamp.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ie := IncomingEvent{Event: testEvent, Client: nil}
|
||||||
|
|
||||||
|
for _, mw := range pool.eventMiddleware {
|
||||||
|
mw(ie)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !middleware1Called {
|
||||||
|
t.Error("Expected middleware1 to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !middleware2Called {
|
||||||
|
t.Error("Expected middleware2 to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test pool with signature checker
|
||||||
|
func TestPoolWithSignatureChecker(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
pool := NewPool(ctx)
|
||||||
|
|
||||||
|
// Set a custom signature checker
|
||||||
|
pool.SignatureChecker = func(e *event.T) bool {
|
||||||
|
return string(e.Content) == "valid"
|
||||||
|
}
|
||||||
|
|
||||||
|
if pool.SignatureChecker == nil {
|
||||||
|
t.Error("Expected signature checker to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the signature checker
|
||||||
|
validEvent := &event.T{Content: []byte("valid")}
|
||||||
|
invalidEvent := &event.T{Content: []byte("invalid")}
|
||||||
|
|
||||||
|
if !pool.SignatureChecker(validEvent) {
|
||||||
|
t.Error("Expected valid event to pass signature check")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pool.SignatureChecker(invalidEvent) {
|
||||||
|
t.Error("Expected invalid event to fail signature check")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user