// File: realy/ws/pool_test.go (299 lines, 6.6 KiB, Go)

package ws
import (
	"context"
	"strings"
	"sync"
	"testing"
	"time"

	"realy.lol/event"
	"realy.lol/filter"
	"realy.lol/filters"
	"realy.lol/kind"
	"realy.lol/kinds"
	"realy.lol/signer"
	"realy.lol/timestamp"
)
// mockSigner is a stub implementation of signer.I for the pool tests.
// Nothing is really generated, signed or verified: every method either does
// nothing or hands back a fixed placeholder value.
type mockSigner struct {
	// pubkey is returned unchanged by Pub.
	pubkey []byte
}

// Generate is a no-op that always succeeds.
func (ms *mockSigner) Generate() error { return nil }

// InitSec discards the given secret key and always succeeds.
func (ms *mockSigner) InitSec([]byte) error { return nil }

// InitPub discards the given public key and always succeeds.
func (ms *mockSigner) InitPub([]byte) error { return nil }

// Sec returns a fixed placeholder secret key.
func (ms *mockSigner) Sec() []byte {
	placeholder := []byte("mock-secret")
	return placeholder
}

// Pub returns the public key the mock was constructed with.
func (ms *mockSigner) Pub() []byte { return ms.pubkey }

// Sign ignores the message and returns a fixed placeholder signature.
func (ms *mockSigner) Sign([]byte) ([]byte, error) {
	placeholder := []byte("mock-signature")
	return placeholder, nil
}

// Verify reports every signature as valid.
func (ms *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil }

// Zero is a no-op; the mock holds no secret material to wipe.
func (ms *mockSigner) Zero() {}

// ECDH ignores the peer key and returns a fixed placeholder shared secret.
func (ms *mockSigner) ECDH([]byte) ([]byte, error) {
	placeholder := []byte("mock-shared-secret")
	return placeholder, nil
}
// TestNewPool checks that NewPool returns a non-nil pool whose Relays map and
// Context are both populated.
func TestNewPool(t *testing.T) {
	p := NewPool(context.Background())
	if p == nil {
		t.Fatal("NewPool returned nil")
	}
	if p.Relays == nil {
		t.Error("Pool should have initialized Relays map")
	}
	if p.Context == nil {
		t.Error("Pool should have a context")
	}
}
// TestPoolWithAuthHandler checks that WithAuthHandler installs the supplied
// callback on the pool, and that calling the installed callback yields the
// signer it constructs.
func TestPoolWithAuthHandler(t *testing.T) {
	ctx := context.Background()
	authHandler := WithAuthHandler(func() signer.I {
		return &mockSigner{pubkey: []byte("test-pubkey")}
	})
	pool := NewPool(ctx, authHandler)
	if pool.authHandler == nil {
		// Fatal, not Error: the handler is invoked below, and calling a nil
		// func would panic instead of reporting a clean test failure.
		t.Fatal("Pool should have auth handler set")
	}
	// Named s (not signer) so the imported signer package is not shadowed.
	s := pool.authHandler()
	if s == nil {
		t.Fatal("auth handler returned a nil signer")
	}
	if got := string(s.Pub()); got != "test-pubkey" {
		t.Errorf("Expected pubkey 'test-pubkey', got '%s'", got)
	}
}
// TestPoolWithEventMiddleware checks that a single WithEventMiddleware option
// registers exactly one middleware on the pool, and that invoking it runs the
// supplied callback.
func TestPoolWithEventMiddleware(t *testing.T) {
	ctx := context.Background()
	var middlewareCalled bool
	middleware := WithEventMiddleware(func(IncomingEvent) {
		middlewareCalled = true
	})
	pool := NewPool(ctx, middleware)
	if len(pool.eventMiddleware) != 1 {
		// Fatalf, not Errorf: eventMiddleware[0] is indexed below and would
		// panic on an empty slice instead of reporting a clean failure.
		t.Fatalf("Expected 1 middleware, got %d", len(pool.eventMiddleware))
	}
	// Invoke the registered middleware directly with a synthetic event.
	testEvent := &event.T{
		Kind:      kind.TextNote,
		Content:   []byte("test"),
		CreatedAt: timestamp.Now(),
	}
	ie := IncomingEvent{Event: testEvent, Client: nil}
	pool.eventMiddleware[0](ie)
	if !middlewareCalled {
		t.Error("Expected middleware to be called")
	}
}
func TestIncomingEventString(t *testing.T) {
testEvent := &event.T{
Kind: kind.TextNote,
Content: []byte("test content"),
CreatedAt: timestamp.Now(),
}
client := &Client{URL: "wss://test.relay"}
ie := IncomingEvent{Event: testEvent, Client: client}
str := ie.String()
if !contains(str, "wss://test.relay") {
t.Errorf("Expected string to contain relay URL, got: %s", str)
}
if !contains(str, "test content") {
t.Errorf("Expected string to contain event content, got: %s", str)
}
}
// TestNamedLock checks that namedLock gives mutual exclusion to callers using
// the same name. Ten goroutines each do a deliberately racy read/sleep/write
// increment of a shared counter, guarded ONLY by the named lock. If the lock
// did not serialize them, increments would be lost and the final count would
// fall short of 10 (and `go test -race` would flag the access).
func TestNamedLock(t *testing.T) {
	const workers = 10
	var wg sync.WaitGroup
	var counter int
	lockName := "test-lock"
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			unlock := namedLock(lockName)
			defer unlock()
			// Critical section, protected solely by namedLock. A second mutex
			// here would make the test pass whether or not namedLock works.
			temp := counter
			time.Sleep(1 * time.Millisecond) // widen the race window
			counter = temp + 1
		}()
	}
	wg.Wait()
	if counter != workers {
		t.Errorf("Expected counter to be 10, got %d", counter)
	}
}
func TestDirectedFilters(t *testing.T) {
f := &filter.T{
Kinds: kinds.New(kind.TextNote),
Limit: uintPtr(10),
}
df := DirectedFilters{
Filters: filters.New(f),
Client: "wss://test.relay",
}
if df.Client != "wss://test.relay" {
t.Errorf("Expected client to be 'wss://test.relay', got '%s'", df.Client)
}
if df.Filters == nil {
t.Error("Expected filters to be set")
}
}
// TestPoolEnsureRelayInvalidURL checks that EnsureRelay returns an error when
// handed a string that is not a relay URL.
func TestPoolEnsureRelayInvalidURL(t *testing.T) {
	p := NewPool(context.Background())
	if _, err := p.EnsureRelay("invalid-url"); err == nil {
		t.Error("Expected error for invalid URL")
	}
}
// TestPoolQuerySingle checks that QuerySingle returns nil when it is given an
// empty list of relay URLs to query.
func TestPoolQuerySingle(t *testing.T) {
	ctx := context.Background()
	p := NewPool(ctx)
	noRelays := []string{}
	if got := p.QuerySingle(ctx, noRelays, &filter.T{}); got != nil {
		t.Error("Expected nil result for empty URLs")
	}
}
// Helper functions

// contains reports whether substr occurs anywhere within s (an empty substr
// always matches). It delegates to strings.Contains instead of hand-rolling
// the search; the name is kept so existing callers are unaffected.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// containsSubstring reports whether substr occurs anywhere within s. It is
// equivalent to strings.Contains and is retained for any remaining callers.
func containsSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}

// uintPtr returns a pointer to a fresh copy of u, which is convenient for
// populating optional *uint fields such as filter.T's Limit.
func uintPtr(u uint) *uint {
	return &u
}
// TestPoolContextCancellation checks that cancelling the context handed to
// NewPool is observed on pool.Context within 100ms.
func TestPoolContextCancellation(t *testing.T) {
	parent, cancel := context.WithCancel(context.Background())
	p := NewPool(parent)
	cancel()
	deadline := time.NewTimer(100 * time.Millisecond)
	defer deadline.Stop()
	select {
	case <-p.Context.Done():
		// Cancellation propagated as expected.
	case <-deadline.C:
		t.Error("Expected pool context to be cancelled")
	}
}
// Test multiple middleware
func TestPoolMultipleMiddleware(t *testing.T) {
ctx := context.Background()
var middleware1Called, middleware2Called bool
middleware1 := WithEventMiddleware(func(ie IncomingEvent) {
middleware1Called = true
})
middleware2 := WithEventMiddleware(func(ie IncomingEvent) {
middleware2Called = true
})
pool := NewPool(ctx, middleware1, middleware2)
if len(pool.eventMiddleware) != 2 {
t.Errorf("Expected 2 middleware, got %d", len(pool.eventMiddleware))
}
// Test that both middleware are called
testEvent := &event.T{
Kind: kind.TextNote,
Content: []byte("test"),
CreatedAt: timestamp.Now(),
}
ie := IncomingEvent{Event: testEvent, Client: nil}
for _, mw := range pool.eventMiddleware {
mw(ie)
}
if !middleware1Called {
t.Error("Expected middleware1 to be called")
}
if !middleware2Called {
t.Error("Expected middleware2 to be called")
}
}
// TestPoolWithSignatureChecker checks that a caller-assigned SignatureChecker
// is stored on the pool and is the function consulted when checking events.
func TestPoolWithSignatureChecker(t *testing.T) {
	p := NewPool(context.Background())
	// Stand-in checker: accepts only events whose content is exactly "valid".
	p.SignatureChecker = func(ev *event.T) bool {
		return string(ev.Content) == "valid"
	}
	if p.SignatureChecker == nil {
		t.Error("Expected signature checker to be set")
	}
	cases := []struct {
		content string
		want    bool
		failMsg string
	}{
		{content: "valid", want: true, failMsg: "Expected valid event to pass signature check"},
		{content: "invalid", want: false, failMsg: "Expected invalid event to fail signature check"},
	}
	for _, tc := range cases {
		if got := p.SignatureChecker(&event.T{Content: []byte(tc.content)}); got != tc.want {
			t.Error(tc.failMsg)
		}
	}
}