diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go index 83eb62b..eea55b4 100644 --- a/cmd/benchmark/main.go +++ b/cmd/benchmark/main.go @@ -38,41 +38,107 @@ type BenchmarkResults struct { func main() { var ( - relayURL = flag.String("relay", "ws://localhost:7447", "Relay URL to benchmark") - eventCount = flag.Int("events", 10000, "Number of events to publish") - eventSize = flag.Int("size", 1024, "Average size of event content in bytes") - concurrency = flag.Int("concurrency", 10, "Number of concurrent publishers") - queryCount = flag.Int("queries", 100, "Number of queries to execute") - queryLimit = flag.Int("query-limit", 100, "Limit for each query") - skipPublish = flag.Bool("skip-publish", false, "Skip publishing phase") - skipQuery = flag.Bool("skip-query", false, "Skip query phase") - verbose = flag.Bool("v", false, "Verbose output") - multiRelay = flag.Bool("multi-relay", false, "Use multi-relay harness") - relayBinPath = flag.String("relay-bin", "", "Path to relay binary (for multi-relay mode)") - profileQueries = flag.Bool("profile", false, "Run query performance profiling") - profileSubs = flag.Bool("profile-subs", false, "Profile subscription performance") - subCount = flag.Int("sub-count", 100, "Number of concurrent subscriptions for profiling") - subDuration = flag.Duration("sub-duration", 30*time.Second, "Duration for subscription profiling") - installRelays = flag.Bool("install", false, "Install relay dependencies and binaries") - installSecp = flag.Bool("install-secp", false, "Install only secp256k1 library") - workDir = flag.String("work-dir", "/tmp/relay-build", "Working directory for builds") - installDir = flag.String("install-dir", "/usr/local/bin", "Installation directory for binaries") - generateReport = flag.Bool("report", false, "Generate comparative report") - reportFormat = flag.String("report-format", "markdown", "Report format: markdown, json, csv") - reportFile = flag.String("report-file", "benchmark_report", "Report output filename (without extension)") - reportTitle = flag.String("report-title", "Relay Benchmark Comparison", "Report title") - timingMode = flag.Bool("timing", false, "Run end-to-end timing instrumentation") - timingEvents = flag.Int("timing-events", 100, "Number of events for timing instrumentation") - timingSubs = flag.Bool("timing-subs", false, "Test subscription timing") - timingDuration = flag.Duration("timing-duration", 10*time.Second, "Duration for subscription timing test") - loadTest = flag.Bool("load", false, "Run load pattern simulation") - loadPattern = flag.String("load-pattern", "constant", "Load pattern: constant, spike, burst, sine, ramp") - loadDuration = flag.Duration("load-duration", 60*time.Second, "Duration for load test") - loadBase = flag.Int("load-base", 50, "Base load (events/sec)") - loadPeak = flag.Int("load-peak", 200, "Peak load (events/sec)") - loadPool = flag.Int("load-pool", 10, "Connection pool size for load testing") - loadSuite = flag.Bool("load-suite", false, "Run comprehensive load test suite") - loadConstraints = flag.Bool("load-constraints", false, "Test under resource constraints") + relayURL = flag.String( + "relay", "ws://localhost:7447", "Client URL to benchmark", + ) + eventCount = flag.Int( + "events", 10000, "Number of events to publish", + ) + eventSize = flag.Int( + "size", 1024, "Average size of event content in bytes", + ) + concurrency = flag.Int( + "concurrency", 10, "Number of concurrent publishers", + ) + queryCount = flag.Int( + "queries", 100, "Number of queries to execute", + ) + queryLimit = flag.Int("query-limit", 100, "Limit for each query") + skipPublish = flag.Bool( + "skip-publish", false, "Skip publishing phase", + ) + skipQuery = flag.Bool("skip-query", false, "Skip query phase") + verbose = flag.Bool("v", false, "Verbose output") + multiRelay = flag.Bool( + "multi-relay", false, "Use multi-relay harness", + ) + relayBinPath = flag.String( + "relay-bin", "", "Path to relay binary (for multi-relay mode)", + ) + profileQueries = flag.Bool( + "profile", false, "Run query performance profiling", + ) + profileSubs = flag.Bool( + "profile-subs", false, "Profile subscription performance", + ) + subCount = flag.Int( + "sub-count", 100, + "Number of concurrent subscriptions for profiling", + ) + subDuration = flag.Duration( + "sub-duration", 30*time.Second, + "Duration for subscription profiling", + ) + installRelays = flag.Bool( + "install", false, "Install relay dependencies and binaries", + ) + installSecp = flag.Bool( + "install-secp", false, "Install only secp256k1 library", + ) + workDir = flag.String( + "work-dir", "/tmp/relay-build", "Working directory for builds", + ) + installDir = flag.String( + "install-dir", "/usr/local/bin", + "Installation directory for binaries", + ) + generateReport = flag.Bool( + "report", false, "Generate comparative report", + ) + reportFormat = flag.String( + "report-format", "markdown", "Report format: markdown, json, csv", + ) + reportFile = flag.String( + "report-file", "benchmark_report", + "Report output filename (without extension)", + ) + reportTitle = flag.String( + "report-title", "Client Benchmark Comparison", "Report title", + ) + timingMode = flag.Bool( + "timing", false, "Run end-to-end timing instrumentation", + ) + timingEvents = flag.Int( + "timing-events", 100, "Number of events for timing instrumentation", + ) + timingSubs = flag.Bool( + "timing-subs", false, "Test subscription timing", + ) + timingDuration = flag.Duration( + "timing-duration", 10*time.Second, + "Duration for subscription timing test", + ) + loadTest = flag.Bool( + "load", false, "Run load pattern simulation", + ) + loadPattern = flag.String( + "load-pattern", "constant", + "Load pattern: constant, spike, burst, sine, ramp", + ) + loadDuration = flag.Duration( + "load-duration", 60*time.Second, "Duration for load test", + ) + loadBase = flag.Int("load-base", 50, "Base load (events/sec)") + loadPeak = flag.Int("load-peak", 200, "Peak load (events/sec)") + loadPool = flag.Int( + "load-pool", 10, "Connection pool size for load testing", + ) + loadSuite = flag.Bool( + "load-suite", false, "Run comprehensive load test suite", + ) + loadConstraints = flag.Bool( + "load-constraints", false, "Test under resource constraints", + ) ) flag.Parse() @@ -89,25 +155,46 @@ func main() { } else if *generateReport { runReportGeneration(*reportTitle, *reportFormat, *reportFile) } else if *loadTest || *loadSuite || *loadConstraints { - runLoadSimulation(c, *relayURL, *loadPattern, *loadDuration, *loadBase, *loadPeak, *loadPool, *eventSize, *loadSuite, *loadConstraints) + runLoadSimulation( + c, *relayURL, *loadPattern, *loadDuration, *loadBase, *loadPeak, + *loadPool, *eventSize, *loadSuite, *loadConstraints, + ) } else if *timingMode || *timingSubs { - runTimingInstrumentation(c, *relayURL, *timingEvents, *eventSize, *timingSubs, *timingDuration) + runTimingInstrumentation( + c, *relayURL, *timingEvents, *eventSize, *timingSubs, + *timingDuration, + ) } else if *profileQueries || *profileSubs { - runQueryProfiler(c, *relayURL, *queryCount, *concurrency, *profileSubs, *subCount, *subDuration) + runQueryProfiler( + c, *relayURL, *queryCount, *concurrency, *profileSubs, *subCount, + *subDuration, + ) } else if *multiRelay { - runMultiRelayBenchmark(c, *relayBinPath, *eventCount, *eventSize, *concurrency, *queryCount, *queryLimit, *skipPublish, *skipQuery) + runMultiRelayBenchmark( + c, *relayBinPath, *eventCount, *eventSize, *concurrency, + *queryCount, *queryLimit, *skipPublish, *skipQuery, + ) } else { - runSingleRelayBenchmark(c, *relayURL, *eventCount, *eventSize, *concurrency, *queryCount, *queryLimit, *skipPublish, *skipQuery) + runSingleRelayBenchmark( + c, *relayURL, *eventCount, *eventSize, *concurrency, *queryCount, + *queryLimit, *skipPublish, *skipQuery, + ) } } -func runSingleRelayBenchmark(c context.T, relayURL string, eventCount, eventSize, concurrency, queryCount, queryLimit int, skipPublish, skipQuery bool) { +func runSingleRelayBenchmark( + c context.T, relayURL string, + eventCount, eventSize, concurrency, queryCount, queryLimit int, + skipPublish, skipQuery bool, +) { results := &BenchmarkResults{} // Phase 1: Publish events if !skipPublish { fmt.Printf("Publishing %d events to %s...\n", eventCount, relayURL) - if err := benchmarkPublish(c, relayURL, eventCount, eventSize, concurrency, results); chk.E(err) { + if err := benchmarkPublish( + c, relayURL, eventCount, eventSize, concurrency, results, + ); chk.E(err) { fmt.Fprintf(os.Stderr, "Error during publish benchmark: %v\n", err) os.Exit(1) } @@ -116,7 +203,9 @@ func runSingleRelayBenchmark(c context.T, relayURL string, eventCount, eventSize // Phase 2: Query events if !skipQuery { fmt.Printf("\nQuerying events from %s...\n", relayURL) - if err := benchmarkQuery(c, relayURL, queryCount, queryLimit, results); chk.E(err) { + if err := benchmarkQuery( + c, relayURL, queryCount, queryLimit, results, + ); chk.E(err) { fmt.Fprintf(os.Stderr, "Error during query benchmark: %v\n", err) os.Exit(1) } @@ -126,7 +215,11 @@ func runSingleRelayBenchmark(c context.T, relayURL string, eventCount, eventSize printResults(results) } -func runMultiRelayBenchmark(c context.T, relayBinPath string, eventCount, eventSize, concurrency, queryCount, queryLimit int, skipPublish, skipQuery bool) { +func runMultiRelayBenchmark( + c context.T, relayBinPath string, + eventCount, eventSize, concurrency, queryCount, queryLimit int, + skipPublish, skipQuery bool, +) { harness := NewMultiRelayHarness() generator := NewReportGenerator() @@ -165,16 +258,26 @@ func runMultiRelayBenchmark(c context.T, relayBinPath string, eventCount, eventS if !skipPublish { fmt.Printf("Publishing %d events to %s...\n", eventCount, relayURL) - if err := benchmarkPublish(c, relayURL, eventCount, eventSize, concurrency, results); chk.E(err) { - fmt.Fprintf(os.Stderr, "Error during publish benchmark for %s: %v\n", relayType, err) + if err := benchmarkPublish( + c, relayURL, eventCount, eventSize, concurrency, results, + ); chk.E(err) { + fmt.Fprintf( + os.Stderr, "Error during publish benchmark for %s: %v\n", + relayType, err, + ) continue } } if !skipQuery { fmt.Printf("\nQuerying events from %s...\n", relayURL) - if err := benchmarkQuery(c, relayURL, queryCount, queryLimit, results); chk.E(err) { - fmt.Fprintf(os.Stderr, "Error during query benchmark for %s: %v\n", relayType, err) + if err := benchmarkQuery( + c, relayURL, queryCount, queryLimit, results, + ); chk.E(err) { + fmt.Fprintf( + os.Stderr, "Error during query benchmark for %s: %v\n", + relayType, err, + ) continue } } @@ -190,16 +293,21 @@ func runMultiRelayBenchmark(c context.T, relayBinPath string, eventCount, eventS generator.AddRelayData(relayType.String(), results, metrics, nil) } - generator.GenerateReport("Multi-Relay Benchmark Results") + generator.GenerateReport("Multi-Client Benchmark Results") - if err := SaveReportToFile("BENCHMARK_RESULTS.md", "markdown", generator); chk.E(err) { + if err := SaveReportToFile( + "BENCHMARK_RESULTS.md", "markdown", generator, + ); chk.E(err) { fmt.Printf("Warning: Failed to save benchmark results: %v\n", err) } else { fmt.Printf("\nBenchmark results saved to: BENCHMARK_RESULTS.md\n") } } -func benchmarkPublish(c context.T, relayURL string, eventCount, eventSize, concurrency int, results *BenchmarkResults) error { +func benchmarkPublish( + c context.T, relayURL string, eventCount, eventSize, concurrency int, + results *BenchmarkResults, +) error { // Generate signers for each concurrent publisher signers := make([]*testSigner, concurrency) for i := range signers { @@ -244,7 +352,7 @@ func benchmarkPublish(c context.T, relayURL string, eventCount, eventSize, concu for j := 0; j < eventsToPublish; j++ { ev := generateEvent(signer, eventSize, time.Duration(0), 0) - if err := relay.Publish(c, ev); err != nil { + if err = relay.Publish(c, ev); err != nil { log.E.F( "Publisher %d failed to publish event: %v", publisherID, err, @@ -372,7 +480,9 @@ func benchmarkQuery( return nil } -func generateEvent(signer *testSigner, contentSize int, rateLimit time.Duration, burstSize int) *event.E { +func generateEvent( + signer *testSigner, contentSize int, rateLimit time.Duration, burstSize int, +) *event.E { return generateSimpleEvent(signer, contentSize) } @@ -450,18 +560,31 @@ func printHarnessMetrics(relayType RelayType, metrics *HarnessMetrics) { } } -func runQueryProfiler(c context.T, relayURL string, queryCount, concurrency int, profileSubs bool, subCount int, subDuration time.Duration) { +func runQueryProfiler( + c context.T, relayURL string, queryCount, concurrency int, profileSubs bool, + subCount int, subDuration time.Duration, +) { profiler := NewQueryProfiler(relayURL) if profileSubs { - fmt.Printf("Profiling %d concurrent subscriptions for %v...\n", subCount, subDuration) - if err := profiler.TestSubscriptionPerformance(c, subDuration, subCount); chk.E(err) { + fmt.Printf( + "Profiling %d concurrent subscriptions for %v...\n", subCount, + subDuration, + ) + if err := profiler.TestSubscriptionPerformance( + c, subDuration, subCount, + ); chk.E(err) { fmt.Fprintf(os.Stderr, "Subscription profiling failed: %v\n", err) os.Exit(1) } } else { - fmt.Printf("Profiling %d queries with %d concurrent workers...\n", queryCount, concurrency) - if err := profiler.ExecuteProfile(c, queryCount, concurrency); chk.E(err) { + fmt.Printf( + "Profiling %d queries with %d concurrent workers...\n", queryCount, + concurrency, + ) + if err := profiler.ExecuteProfile( + c, queryCount, concurrency, + ); chk.E(err) { fmt.Fprintf(os.Stderr, "Query profiling failed: %v\n", err) os.Exit(1) } @@ -488,7 +611,10 @@ func runSecp256k1Installer(workDir, installDir string) { } } -func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.Duration, baseLoad, peakLoad, poolSize, eventSize int, runSuite, runConstraints bool) { +func runLoadSimulation( + c context.T, relayURL, patternStr string, duration time.Duration, + baseLoad, peakLoad, poolSize, eventSize int, runSuite, runConstraints bool, +) { if runSuite { suite := NewLoadTestSuite(relayURL, poolSize, eventSize) if err := suite.RunAllPatterns(c); chk.E(err) { @@ -515,7 +641,9 @@ func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.D os.Exit(1) } - simulator := NewLoadSimulator(relayURL, pattern, duration, baseLoad, peakLoad, poolSize, eventSize) + simulator := NewLoadSimulator( + relayURL, pattern, duration, baseLoad, peakLoad, poolSize, eventSize, + ) if err := simulator.Run(c); chk.E(err) { fmt.Fprintf(os.Stderr, "Load simulation failed: %v\n", err) @@ -524,8 +652,12 @@ func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.D if runConstraints { fmt.Printf("\n") - if err := simulator.SimulateResourceConstraints(c, 512, 80); chk.E(err) { - fmt.Fprintf(os.Stderr, "Resource constraint simulation failed: %v\n", err) + if err := simulator.SimulateResourceConstraints( + c, 512, 80, + ); chk.E(err) { + fmt.Fprintf( + os.Stderr, "Resource constraint simulation failed: %v\n", err, + ) } } @@ -540,7 +672,10 @@ func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.D fmt.Printf("Peak latency: %vms\n", metrics["peak_latency_ms"]) } -func runTimingInstrumentation(c context.T, relayURL string, eventCount, eventSize int, testSubs bool, duration time.Duration) { +func runTimingInstrumentation( + c context.T, relayURL string, eventCount, eventSize int, testSubs bool, + duration time.Duration, +) { instrumentation := NewTimingInstrumentation(relayURL) fmt.Printf("Connecting to relay at %s...\n", relayURL) @@ -552,13 +687,17 @@ func runTimingInstrumentation(c context.T, relayURL string, eventCount, eventSiz if testSubs { fmt.Printf("\n=== Subscription Timing Test ===\n") - if err := instrumentation.TestSubscriptionTiming(c, duration); chk.E(err) { + if err := instrumentation.TestSubscriptionTiming( + c, duration, + ); chk.E(err) { fmt.Fprintf(os.Stderr, "Subscription timing test failed: %v\n", err) os.Exit(1) } } else { fmt.Printf("\n=== Full Event Lifecycle Instrumentation ===\n") - if err := instrumentation.RunFullInstrumentation(c, eventCount, eventSize); chk.E(err) { + if err := instrumentation.RunFullInstrumentation( + c, eventCount, eventSize, + ); chk.E(err) { fmt.Fprintf(os.Stderr, "Timing instrumentation failed: %v\n", err) os.Exit(1) } @@ -574,12 +713,14 @@ func runTimingInstrumentation(c context.T, relayURL string, eventCount, eventSiz if bottlenecks, ok := metrics["bottlenecks"].(map[string]map[string]interface{}); ok { fmt.Printf("\n=== Pipeline Stage Analysis ===\n") for stage, data := range bottlenecks { - fmt.Printf("%s: avg=%vms, p95=%vms, p99=%vms, throughput=%.2f ops/s\n", + fmt.Printf( + "%s: avg=%vms, p95=%vms, p99=%vms, throughput=%.2f ops/s\n", stage, data["avg_latency_ms"], data["p95_latency_ms"], data["p99_latency_ms"], - data["throughput_ops_sec"]) + data["throughput_ops_sec"], + ) } } } diff --git a/cmd/benchmark/report_generator.go b/cmd/benchmark/report_generator.go index 34e4e4a..48bea49 100644 --- a/cmd/benchmark/report_generator.go +++ b/cmd/benchmark/report_generator.go @@ -60,7 +60,10 @@ func NewReportGenerator() *ReportGenerator { } } -func (rg *ReportGenerator) AddRelayData(relayType string, results *BenchmarkResults, metrics *HarnessMetrics, profilerMetrics *QueryMetrics) { +func (rg *ReportGenerator) AddRelayData( + relayType string, results *BenchmarkResults, metrics *HarnessMetrics, + profilerMetrics *QueryMetrics, +) { data := RelayBenchmarkData{ RelayType: relayType, EventsPublished: results.EventsPublished, @@ -148,19 +151,25 @@ func (rg *ReportGenerator) detectAnomalies() { for _, data := range rg.data { if math.Abs(data.PublishRate-publishMean) > 2*publishStdDev { - anomaly := fmt.Sprintf("%s publish rate (%.2f) deviates significantly from average (%.2f)", - data.RelayType, data.PublishRate, publishMean) + anomaly := fmt.Sprintf( + "%s publish rate (%.2f) deviates significantly from average (%.2f)", + data.RelayType, data.PublishRate, publishMean, + ) rg.report.Anomalies = append(rg.report.Anomalies, anomaly) } if math.Abs(data.QueryRate-queryMean) > 2*queryStdDev { - anomaly := fmt.Sprintf("%s query rate (%.2f) deviates significantly from average (%.2f)", - data.RelayType, data.QueryRate, queryMean) + anomaly := fmt.Sprintf( + "%s query rate (%.2f) deviates significantly from average (%.2f)", + data.RelayType, data.QueryRate, queryMean, + ) rg.report.Anomalies = append(rg.report.Anomalies, anomaly) } if data.Errors > 0 { - anomaly := fmt.Sprintf("%s had %d errors during benchmark", data.RelayType, data.Errors) + anomaly := fmt.Sprintf( + "%s had %d errors during benchmark", data.RelayType, data.Errors, + ) rg.report.Anomalies = append(rg.report.Anomalies, anomaly) } } @@ -171,9 +180,11 @@ func (rg *ReportGenerator) generateRecommendations() { return } - sort.Slice(rg.data, func(i, j int) bool { - return rg.data[i].PublishRate > rg.data[j].PublishRate - }) + sort.Slice( + rg.data, func(i, j int) bool { + return rg.data[i].PublishRate > rg.data[j].PublishRate + }, + ) if len(rg.data) > 1 { best := rg.data[0] @@ -181,16 +192,20 @@ func (rg *ReportGenerator) generateRecommendations() { improvement := (best.PublishRate - worst.PublishRate) / worst.PublishRate * 100 if improvement > 20 { - rec := fmt.Sprintf("Consider using %s for high-throughput scenarios (%.1f%% faster than %s)", - best.RelayType, improvement, worst.RelayType) + rec := fmt.Sprintf( + "Consider using %s for high-throughput scenarios (%.1f%% faster than %s)", + best.RelayType, improvement, worst.RelayType, + ) rg.report.Recommendations = append(rg.report.Recommendations, rec) } } for _, data := range rg.data { if data.MemoryUsageMB > 500 { - rec := fmt.Sprintf("%s shows high memory usage (%.1f MB) - monitor for memory leaks", - data.RelayType, data.MemoryUsageMB) + rec := fmt.Sprintf( + "%s shows high memory usage (%.1f MB) - monitor for memory leaks", + data.RelayType, data.MemoryUsageMB, + ) rg.report.Recommendations = append(rg.report.Recommendations, rec) } } @@ -198,25 +213,39 @@ func (rg *ReportGenerator) generateRecommendations() { func (rg *ReportGenerator) OutputMarkdown(writer io.Writer) error { fmt.Fprintf(writer, "# %s\n\n", rg.report.Title) - fmt.Fprintf(writer, "Generated: %s\n\n", rg.report.GeneratedAt.Format(time.RFC3339)) + fmt.Fprintf( + writer, "Generated: %s\n\n", rg.report.GeneratedAt.Format(time.RFC3339), + ) fmt.Fprintf(writer, "## Performance Summary\n\n") - fmt.Fprintf(writer, "| Relay | Publish Rate | Publish BW | Query Rate | Avg Events/Query | Memory (MB) |\n") - fmt.Fprintf(writer, "|-------|--------------|------------|------------|------------------|-------------|\n") + fmt.Fprintf( + writer, + "| Client | Publish Rate | Publish BW | Query Rate | Avg Events/Query | Memory (MB) |\n", + ) + fmt.Fprintf( + writer, + "|-------|--------------|------------|------------|------------------|-------------|\n", + ) for _, data := range rg.data { - fmt.Fprintf(writer, "| %s | %.2f/s | %.2f MB/s | %.2f/s | %.2f | %.1f |\n", + fmt.Fprintf( + writer, "| %s | %.2f/s | %.2f MB/s | %.2f/s | %.2f | %.1f |\n", data.RelayType, data.PublishRate, data.PublishBandwidth, - data.QueryRate, data.AvgEventsPerQuery, data.MemoryUsageMB) + data.QueryRate, data.AvgEventsPerQuery, data.MemoryUsageMB, + ) } if rg.report.WinnerPublish != "" || rg.report.WinnerQuery != "" { fmt.Fprintf(writer, "\n## Winners\n\n") if rg.report.WinnerPublish != "" { - fmt.Fprintf(writer, "- **Best Publisher**: %s\n", rg.report.WinnerPublish) + fmt.Fprintf( + writer, "- **Best Publisher**: %s\n", rg.report.WinnerPublish, + ) } if rg.report.WinnerQuery != "" { - fmt.Fprintf(writer, "- **Best Query Engine**: %s\n", rg.report.WinnerQuery) + fmt.Fprintf( + writer, "- **Best Query Engine**: %s\n", rg.report.WinnerQuery, + ) } } @@ -237,12 +266,18 @@ func (rg *ReportGenerator) OutputMarkdown(writer io.Writer) error { fmt.Fprintf(writer, "\n## Detailed Results\n\n") for _, data := range rg.data { fmt.Fprintf(writer, "### %s\n\n", data.RelayType) - fmt.Fprintf(writer, "- Events Published: %d (%.2f MB)\n", data.EventsPublished, data.EventsPublishedMB) + fmt.Fprintf( + writer, "- Events Published: %d (%.2f MB)\n", data.EventsPublished, + data.EventsPublishedMB, + ) fmt.Fprintf(writer, "- Publish Duration: %s\n", data.PublishDuration) fmt.Fprintf(writer, "- Queries Executed: %d\n", data.QueriesExecuted) fmt.Fprintf(writer, "- Query Duration: %s\n", data.QueryDuration) if data.P50Latency != "" { - fmt.Fprintf(writer, "- Latency P50/P95/P99: %s/%s/%s\n", data.P50Latency, data.P95Latency, data.P99Latency) + fmt.Fprintf( + writer, "- Latency P50/P95/P99: %s/%s/%s\n", data.P50Latency, + data.P95Latency, data.P99Latency, + ) } if data.StartupTime != "" { fmt.Fprintf(writer, "- Startup Time: %s\n", data.StartupTime) @@ -264,9 +299,12 @@ func (rg *ReportGenerator) OutputCSV(writer io.Writer) error { defer w.Flush() header := []string{ - "relay_type", "events_published", "events_published_mb", "publish_duration", - "publish_rate", "publish_bandwidth", "queries_executed", "events_returned", - "query_duration", "query_rate", "avg_events_per_query", "memory_usage_mb", + "relay_type", "events_published", "events_published_mb", + "publish_duration", + "publish_rate", "publish_bandwidth", "queries_executed", + "events_returned", + "query_duration", "query_rate", "avg_events_per_query", + "memory_usage_mb", "p50_latency", "p95_latency", "p99_latency", "startup_time", "errors", } @@ -315,9 +353,11 @@ func (rg *ReportGenerator) GenerateThroughputCurve() []ThroughputPoint { points = append(points, point) } - sort.Slice(points, func(i, j int) bool { - return points[i].Throughput < points[j].Throughput - }) + sort.Slice( + points, func(i, j int) bool { + return points[i].Throughput < points[j].Throughput + }, + ) return points } @@ -370,7 +410,9 @@ func stdDev(values []float64, mean float64) float64 { return math.Sqrt(variance) } -func SaveReportToFile(filename, format string, generator *ReportGenerator) error { +func SaveReportToFile( + filename, format string, generator *ReportGenerator, +) error { file, err := os.Create(filename) if err != nil { return err diff --git a/cmd/benchmark/run.sh b/cmd/benchmark/run.sh index f1f641a..236351c 100755 --- a/cmd/benchmark/run.sh +++ b/cmd/benchmark/run.sh @@ -1 +1,31 @@ #!/usr/bin/env bash + +# khatru + +khatru & +KHATRU_PID=$! +printf "khatru started pid: %s\n" $KHATRU_PID +sleep 2s +LOG_LEVEL=info relay-benchmark -relay ws://localhost:3334 -events 10000 -queries 100 +kill $KHATRU_PID +printf "khatru stopped\n" +sleep 1s + +# ORLY + +LOG_LEVEL=off \ +ORLY_LOG_LEVEL=off \ +ORLY_DB_LOG_LEVEL=off \ +ORLY_SPIDER_TYPE=none \ +ORLY_LISTEN=localhost \ +ORLY_PORT=7447 \ +ORLY_AUTH_REQUIRED=false \ +ORLY_PRIVATE=true \ +orly & +ORLY_PID=$! +printf "ORLY started pid: %s\n" $ORLY_PID +sleep 2s +LOG_LEVEL=info relay-benchmark -relay ws://localhost:7447 -events 100 -queries 100 +kill $ORLY_PID +printf "ORLY stopped\n" +sleep 1s diff --git a/pkg/encoders/envelopes/okenvelope/okenvelope.go b/pkg/encoders/envelopes/okenvelope/okenvelope.go index 16d5125..6b10969 100644 --- a/pkg/encoders/envelopes/okenvelope/okenvelope.go +++ b/pkg/encoders/envelopes/okenvelope/okenvelope.go @@ -90,17 +90,19 @@ func (en *T) Marshal(dst []byte) (b []byte) { // subscription.Id strings are correctly unescaped by NIP-01 escaping rules. func (en *T) Unmarshal(b []byte) (r []byte, err error) { r = b - var idHex []byte - if idHex, r, err = text2.UnmarshalHex(r); chk.E(err) { + var idBytes []byte + // Parse event id as quoted hex (NIP-20 compliant) + if idBytes, r, err = text2.UnmarshalHex(r); err != nil { return } - if len(idHex) != sha256.Size { + if len(idBytes) != sha256.Size { err = errorf.E( "invalid size for ID, require %d got %d", - len(idHex), sha256.Size, + sha256.Size, len(idBytes), ) + return } - en.EventID = eventid.NewWith(idHex) + en.EventID = eventid.NewWith(idBytes) if r, err = text2.Comma(r); chk.E(err) { return } diff --git a/pkg/encoders/filter/filter.go b/pkg/encoders/filter/filter.go index 5d89cf6..bef9657 100644 --- a/pkg/encoders/filter/filter.go +++ b/pkg/encoders/filter/filter.go @@ -23,7 +23,6 @@ import ( "orly.dev/pkg/encoders/timestamp" "orly.dev/pkg/utils/chk" "orly.dev/pkg/utils/errorf" - "orly.dev/pkg/utils/log" "orly.dev/pkg/utils/pointers" "lukechampine.com/frand" @@ -446,21 +445,15 @@ invalid: // determines if the event matches the filter, ignoring timestamp constraints.. func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool { if ev == nil { - log.I.F("nil event") return false } if f.Ids.Len() > 0 && !f.Ids.Contains(ev.ID) { - log.I.F("no ids in filter match event") return false } if f.Kinds.Len() > 0 && !f.Kinds.Contains(ev.Kind) { - log.I.F( - "no matching kinds in filter", - ) return false } if f.Authors.Len() > 0 && !f.Authors.Contains(ev.Pubkey) { - log.I.F("no matching authors in filter") return false } // if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) { @@ -470,7 +463,6 @@ func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool { for _, v := range f.Tags.ToSliceOfTags() { tvs := v.ToSliceOfBytes() if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) { - log.I.F("no matching tags in filter") return false } } @@ -485,11 +477,9 @@ func (f *F) Matches(ev *event.E) (match bool) { return } if f.Since.Int() != 0 && ev.CreatedAt.I64() < f.Since.I64() { - log.I.F("event is older than since") return } if f.Until.Int() != 0 && ev.CreatedAt.I64() > f.Until.I64() { - log.I.F("event is newer than until") return } return true diff --git a/pkg/protocol/relayinfo/types.go b/pkg/protocol/relayinfo/types.go index 6287880..35a148b 100644 --- a/pkg/protocol/relayinfo/types.go +++ b/pkg/protocol/relayinfo/types.go @@ -57,7 +57,7 @@ var ( NIP8 = HandlingMentions EventDeletion = NIP{"Event Deletion", 9} NIP9 = EventDeletion - RelayInformationDocument = NIP{"Relay Information Document", 11} + RelayInformationDocument = NIP{"Client Information Document", 11} NIP11 = RelayInformationDocument GenericTagQueries = NIP{"Generic Tag Queries", 12} NIP12 = GenericTagQueries @@ -133,7 +133,7 @@ var ( NIP57 = LightningZaps Badges = NIP{"Badges", 58} NIP58 = Badges - RelayListMetadata = NIP{"Relay List Metadata", 65} + RelayListMetadata = NIP{"Client List Metadata", 65} NIP65 = RelayListMetadata ProtectedEvents = NIP{"Protected Events", 70} NIP70 = ProtectedEvents diff --git a/pkg/protocol/ws/client.go b/pkg/protocol/ws/client.go index 15b8f61..d4ce834 100644 --- a/pkg/protocol/ws/client.go +++ b/pkg/protocol/ws/client.go @@ -2,17 +2,11 @@ package ws import ( "bytes" + "context" "crypto/tls" "errors" "fmt" "net/http" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/puzpuzpuz/xsync/v3" "orly.dev/pkg/encoders/envelopes" "orly.dev/pkg/encoders/envelopes/authenvelope" "orly.dev/pkg/encoders/envelopes/closedenvelope" @@ -23,21 +17,28 @@ import ( "orly.dev/pkg/encoders/event" "orly.dev/pkg/encoders/filter" "orly.dev/pkg/encoders/filters" - "orly.dev/pkg/encoders/hex" "orly.dev/pkg/encoders/kind" + "orly.dev/pkg/encoders/subscription" "orly.dev/pkg/encoders/tag" "orly.dev/pkg/encoders/tags" "orly.dev/pkg/encoders/timestamp" + "orly.dev/pkg/interfaces/codec" "orly.dev/pkg/interfaces/signer" "orly.dev/pkg/utils/chk" - "orly.dev/pkg/utils/context" "orly.dev/pkg/utils/log" "orly.dev/pkg/utils/normalize" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/puzpuzpuz/xsync/v3" ) var subscriptionIDCounter atomic.Int64 -// Relay represents a connection to a Nostr relay. +// Client represents a connection to a Nostr relay. type Client struct { closeMutex sync.Mutex @@ -45,14 +46,14 @@ type Client struct { requestHeader http.Header // e.g. for origin header Connection *Connection - Subscriptions *xsync.MapOf[int64, *Subscription] + Subscriptions *xsync.MapOf[string, *Subscription] ConnectionError error - connectionContext context.T // will be canceled when the connection closes - connectionContextCancel context.C + connectionContext context.Context // will be canceled when the connection closes + connectionContextCancel context.CancelCauseFunc challenge []byte // NIP-42 challenge, we only keep the last - noticeHandler func(string) // NIP-01 NOTICEs + notices chan []byte // NIP-01 NOTICEs customHandler func(string) // nonstandard unparseable messages okCallbacks *xsync.MapOf[string, func(bool, string)] writeQueue chan writeRequest @@ -69,13 +70,13 @@ type writeRequest struct { } // NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection. -func NewRelay(ctx context.T, url string, opts ...RelayOption) *Client { - ctx, cancel := context.Cause(ctx) +func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Client { + ctx, cancel := context.WithCancelCause(ctx) r := &Client{ URL: string(normalize.URL(url)), connectionContext: ctx, connectionContextCancel: cancel, - Subscriptions: xsync.NewMapOf[int64, *Subscription](), + Subscriptions: xsync.NewMapOf[string, *Subscription](), okCallbacks: xsync.NewMapOf[string, func( bool, string, )](), @@ -97,10 +98,10 @@ func NewRelay(ctx context.T, url string, opts ...RelayOption) *Client { // // The ongoing relay connection uses a background context. To close the connection, call r.Close(). // If you need fine grained long-term connection contexts, use NewRelay() instead. -func RelayConnect(ctx context.T, url string, opts ...RelayOption) ( +func RelayConnect(ctx context.Context, url string, opts ...RelayOption) ( *Client, error, ) { - r := NewRelay(context.Bg(), url, opts...) + r := NewRelay(context.Background(), url, opts...) err := r.Connect(ctx) return r, err } @@ -111,19 +112,10 @@ type RelayOption interface { } var ( - _ RelayOption = (WithNoticeHandler)(nil) _ RelayOption = (WithCustomHandler)(nil) _ RelayOption = (WithRequestHeader)(nil) ) -// WithNoticeHandler just takes notices and is expected to do something with them. -// when not given, defaults to logging the notices. -type WithNoticeHandler func(notice string) - -func (nh WithNoticeHandler) ApplyRelayOption(r *Client) { - r.noticeHandler = nh -} - // WithCustomHandler must be a function that handles any relay message that couldn't be // parsed as a standard envelope. type WithCustomHandler func(data string) @@ -146,7 +138,7 @@ func (r *Client) String() string { // Context retrieves the context that is associated with this relay connection. // It will be closed when the relay is disconnected. -func (r *Client) Context() context.T { return r.connectionContext } +func (r *Client) Context() context.Context { return r.connectionContext } // IsConnected returns true if the connection to this relay seems to be active. func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil } @@ -158,10 +150,32 @@ func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil } // // The given context here is only used during the connection phase. The long-living // relay connection will be based on the context given to NewRelay(). -func (r *Client) Connect(ctx context.T) error { +func (r *Client) Connect(ctx context.Context) error { return r.ConnectWithTLS(ctx, nil) } +func extractSubID(jsonStr string) string { + // look for "EVENT" pattern + start := strings.Index(jsonStr, `"EVENT"`) + if start == -1 { + return "" + } + + // move to the next quote + offset := strings.Index(jsonStr[start+7:], `"`) + if offset == -1 { + return "" + } + + start += 7 + offset + 1 + + // find the ending quote + end := strings.Index(jsonStr[start:], `"`) + + // get the contents + return jsonStr[start : start+end] +} + func subIdToSerial(subId string) int64 { n := strings.Index(subId, ":") if n < 0 || n > len(subId) { @@ -173,8 +187,8 @@ func subIdToSerial(subId string) int64 { // ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that. func (r *Client) ConnectWithTLS( - ctx context.T, tlsConfig *tls.Config, -) (err error) { + ctx context.Context, tlsConfig *tls.Config, +) error { if r.connectionContext == nil || r.Subscriptions == nil { return fmt.Errorf("relay must be initialized with a call to NewRelay()") } @@ -185,209 +199,182 @@ func (r *Client) ConnectWithTLS( if _, ok := ctx.Deadline(); !ok { // if no timeout is set, force it to 7 seconds - var cancel context.F - ctx, cancel = context.TimeoutCause( + var cancel context.CancelFunc + ctx, cancel = context.WithTimeoutCause( ctx, 7*time.Second, errors.New("connection took too long"), ) defer cancel() } - var conn *Connection - if conn, err = NewConnection( - ctx, r.URL, r.requestHeader, tlsConfig, - ); chk.E(err) { - err = fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) - return + + conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig) + if err != nil { + return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) } r.Connection = conn + // ping every 29 seconds ticker := time.NewTicker(29 * time.Second) + // queue all write operations here so we don't do mutex spaghetti go func() { + var err error for { select { case <-r.connectionContext.Done(): ticker.Stop() r.Connection = nil + for _, sub := range r.Subscriptions.Range { sub.unsub( fmt.Errorf( "relay connection closed: %w / %w", - context.GetCause(r.connectionContext), + context.Cause(r.connectionContext), r.ConnectionError, ), ) } return + case <-ticker.C: - err := r.Connection.Ping(r.connectionContext) + err = r.Connection.Ping(r.connectionContext) if err != nil && !strings.Contains( err.Error(), "failed to wait for pong", ) { - log.T.C( - func() string { - return fmt.Sprintf( - "{%s} error writing ping: %v; closing websocket", - r.URL, - err, - ) - }, + log.I.F( + "{%s} error writing ping: %v; closing websocket", r.URL, + err, ) r.Close() // this should trigger a context cancelation return } - case writeRequest := <-r.writeQueue: + + case wr := <-r.writeQueue: // all write requests will go through this to prevent races - log.T.C( - func() string { - return fmt.Sprintf( - "{%s} sending %v\n", r.URL, - string(writeRequest.msg), - ) - }, - ) - if err := r.Connection.WriteMessage( - r.connectionContext, writeRequest.msg, + log.D.F("{%s} sending %v\n", r.URL, string(wr.msg)) + if err = r.Connection.WriteMessage( + r.connectionContext, wr.msg, ); err != nil { - writeRequest.answer <- err + wr.answer <- err } - close(writeRequest.answer) + close(wr.answer) } } }() + // general message reader loop go func() { for { buf := new(bytes.Buffer) - if err := conn.ReadMessage(r.connectionContext, buf); err != nil { - r.ConnectionError = err - r.close(err) - break - } - var err error - var t string - var rem []byte - if t, rem, err = envelopes.Identify(buf.Bytes()); chk.E(err) { - continue - } - switch t { - case noticeenvelope.L: - env := noticeenvelope.NewFrom(rem) - // see WithNoticeHandler - if r.noticeHandler != nil { - r.noticeHandler(string(env.Message)) - } else { - log.D.F( - "NOTICE from %s: '%s'\n", r.URL, string(env.Message), - ) + for { + buf.Reset() + if err := conn.ReadMessage( + r.connectionContext, buf, + ); err != nil { + r.ConnectionError = err + r.Close() + break } - case authenvelope.L: - env := authenvelope.NewChallengeWith(rem) - if env.Challenge == nil { + message := buf.Bytes() + log.D.F("{%s} %v\n", r.URL, message) + + var t string + if t, message, err = envelopes.Identify(message); chk.E(err) { continue } - r.challenge = env.Challenge - case eventenvelope.L: - // log.I.F("%s", rem) - var env *eventenvelope.Result - env = eventenvelope.NewResult() - if _, err = env.Unmarshal(rem); err != nil { - continue - } - subid := env.Subscription.String() - sub, ok := r.Subscriptions.Load(subIdToSerial(subid)) - if !ok { - log.W.F( - "unknown subscription with id '%s'\n", - subid, - ) - continue - } - if !sub.Filters.Match(env.Event) { - log.T.C( - func() string { - return fmt.Sprintf( - "{%s} filter does not match: %v ~ %s\n", r.URL, - sub.Filters, env.Event.Marshal(nil), - ) - }, - ) - continue - } - if !r.AssumeValid { - if ok, err = env.Event.Verify(); !ok || chk.E(err) { - log.T.C( - func() string { - return fmt.Sprintf( - "{%s} bad signature on %s\n", r.URL, - env.Event.ID, - ) - }, - ) + switch t { + case noticeenvelope.L: + env := noticeenvelope.New() + if env, message, err = noticeenvelope.Parse(message); chk.E(err) { continue } - } - sub.dispatchEvent(env.Event) - case eoseenvelope.L: - var env *eoseenvelope.T - if env, rem, err = eoseenvelope.Parse(rem); chk.E(err) { - continue - } - if len(rem) != 0 { - log.W.F( - "{%s} unexpected data after EOSE: %s\n", r.URL, - string(rem), - ) - } - sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String())) - if !ok { - log.W.F( - "unknown subscription with id '%s'\n", - env.Subscription.String(), - ) - continue - } - sub.dispatchEose() - case closedenvelope.L: - var env *closedenvelope.T - if env, rem, err = closedenvelope.Parse(rem); chk.E(err) { - continue - } - sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String())) - if !ok { - log.W.F( - "unknown subscription with id '%s'\n", - env.Subscription.String(), - ) - continue - } - sub.handleClosed(env.ReasonString()) - case okenvelope.L: - var env *okenvelope.T - if env, rem, err = okenvelope.Parse(rem); chk.E(err) { - continue - } - eventIDStr := env.EventID.String() - if okCallback, exist := r.okCallbacks.Load(eventIDStr); exist { - okCallback(env.OK, string(env.Reason)) - } else { - log.T.C( - func() string { - return fmt.Sprintf( - "{%s} got an unexpected OK message for event %s", - r.URL, - eventIDStr, + // see WithNoticeHandler + if r.notices != nil { + r.notices <- env.Message + } else { + log.E.F("NOTICE from %s: '%s'\n", r.URL, env.Message) + } + case authenvelope.L: + env := authenvelope.NewChallenge() + if env, message, err = authenvelope.ParseChallenge(message); chk.E(err) { + continue + } + if len(env.Challenge) == 0 { + continue + } + r.challenge = env.Challenge + case eventenvelope.L: + env := eventenvelope.NewResult() + if env, message, err = eventenvelope.ParseResult(message); chk.E(err) { + continue + } + if len(env.Subscription.T) == 0 { + continue + } + if sub, ok := r.Subscriptions.Load(env.Subscription.String()); !ok { + log.D.F( + "{%s} no subscription with id '%s'\n", r.URL, + env.Subscription, + ) + continue + } else { + // check if the event matches the desired filter, ignore otherwise + if !sub.Filters.Match(env.Event) { + log.D.F( + "{%s} filter does not match: %v ~ %v\n", r.URL, + sub.Filters, env.Event, ) - }, - ) + continue + } + // check signature, ignore invalid, except from trusted (AssumeValid) relays + if !r.AssumeValid { + if ok, err = env.Event.Verify(); !ok { + log.E.F( + "{%s} bad signature on %s\n", r.URL, + env.Event.IdString(), + ) + continue + } + } + // dispatch this to the internal .events channel of the subscription + sub.dispatchEvent(env.Event) + } + case eoseenvelope.L: + env := eoseenvelope.New() + if env, message, err = eoseenvelope.Parse(message); chk.E(err) { + continue + } + if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok { + subscription.dispatchEose() + } + case closedenvelope.L: + env := closedenvelope.New() + if env, message, err = closedenvelope.Parse(message); chk.E(err) { + continue + } + if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok { + subscription.handleClosed(env.ReasonString()) + } + case okenvelope.L: + env := okenvelope.New() + if env, message, err = okenvelope.Parse(message); chk.E(err) { + continue + } + if okCallback, exist := r.okCallbacks.Load(env.EventID.String()); exist { + okCallback(env.OK, env.ReasonString()) + } else { + log.I.F( + "{%s} got an unexpected OK message for event %s", + r.URL, + env.EventID, + ) + } } - default: - log.W.F("unknown envelope type %s\n%s", t, rem) - continue } } }() - return + + return nil } // Write queues an arbitrary message to be sent to the relay. @@ -401,55 +388,60 @@ func (r *Client) Write(msg []byte) <-chan error { return ch } -// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an -// OK response. -func (r *Client) Publish(ctx context.T, ev *event.E) error { - return r.publish(ctx, ev.ID, ev) +// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response. +func (r *Client) Publish(ctx context.Context, event *event.E) error { + return r.publish( + ctx, event.IdString(), eventenvelope.NewSubmissionWith(event), + ) } -// Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK -// response. +// Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK response. +// +// You don't have to build the AUTH event yourself, this function takes a function to which the +// event that must be signed will be passed, so it's only necessary to sign that. func (r *Client) Auth( - ctx context.T, sign signer.I, + ctx context.Context, sign signer.I, ) (err error) { authEvent := &event.E{ CreatedAt: timestamp.Now(), Kind: kind.ClientAuthentication, Tags: tags.New( tag.New("relay", r.URL), - tag.New([]byte("challenge"), r.challenge), + tag.New("challenge", string(r.challenge)), ), + Content: nil, } - if err = authEvent.Sign(sign); chk.E(err) { - err = fmt.Errorf("error signing auth event: %w", err) - return + if err = authEvent.Sign(sign); err != nil { + return fmt.Errorf("error signing auth event: %w", err) } - return r.publish(ctx, authEvent.ID, authEvent) + + return r.publish( + ctx, authEvent.IdString(), authenvelope.NewResponseWith(authEvent), + ) } func (r *Client) publish( - ctx context.T, id []byte, ev *event.E, + ctx context.Context, id string, env codec.Envelope, ) error { var err error - var cancel context.F + var cancel context.CancelFunc + if _, ok := ctx.Deadline(); !ok { // if no timeout is set, force it to 7 seconds - ctx, cancel = context.TimeoutCause( + ctx, cancel = context.WithTimeoutCause( ctx, 7*time.Second, fmt.Errorf("given up waiting for an OK"), ) defer cancel() } else { - // otherwise make the context cancellable so we can stop everything upon - // receiving an "OK" - ctx, cancel = context.Cancel(ctx) + // otherwise make the context cancellable so we can stop everything upon receiving an "OK" + ctx, cancel = context.WithCancel(ctx) defer cancel() } // listen for an OK callback gotOk := false - ids := hex.Enc(id) r.okCallbacks.Store( - ids, func(ok bool, reason string) { + id, func(ok bool, reason string) { gotOk = true if !ok { err = fmt.Errorf("msg: %s", reason) @@ -457,18 +449,18 @@ func (r *Client) publish( cancel() }, ) - defer r.okCallbacks.Delete(ids) + defer r.okCallbacks.Delete(id) + // publish event - envb := eventenvelope.NewSubmissionWith(ev).Marshal(nil) - // envb := ev.Marshal(nil) + envb := env.Marshal(nil) if err = <-r.Write(envb); err != nil { return err } + for { select { case <-ctx.Done(): - // this will be called when we get an OK or when the context has - // been canceled + // this will be called when we get an OK or when the context has been canceled if gotOk { return err } @@ -480,40 +472,41 @@ func (r *Client) publish( } } -// Subscribe sends a "REQ" command to the relay r as in NIP-01. Events are -// returned through the channel sub.Events. The subscription is closed when -// context ctx is cancelled ("CLOSE" in NIP-01). +// Subscribe sends a "REQ" command to the relay r as in NIP-01. +// Events are returned through the channel sub.Events. +// The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01). // -// Remember to cancel subscriptions, either by calling `.Unsub()` on them or -// ensuring their `context.T` will be canceled at some point. Failure to -// do that will result in a huge number of halted goroutines being created. +// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point. +// Failure to do that will result in a huge number of halted goroutines being created. func (r *Client) Subscribe( - ctx context.T, ff *filters.T, opts ...SubscriptionOption, -) (sub *Subscription, err error) { - sub = r.PrepareSubscription(ctx, ff, opts...) + ctx context.Context, ff *filters.T, opts ...SubscriptionOption, +) (*Subscription, error) { + sub := r.PrepareSubscription(ctx, ff, opts...) + if r.Connection == nil { return nil, fmt.Errorf("not connected to %s", r.URL) } - if err = sub.Fire(); err != nil { - err = fmt.Errorf( - "couldn't subscribe to %v at %s: %w", ff.Marshal(nil), r.URL, err, + + if err := sub.Fire(); err != nil { + return nil, fmt.Errorf( + "couldn't subscribe to %v at %s: %w", ff, r.URL, err, ) - return } - return + + return sub, nil } // PrepareSubscription creates a subscription, but doesn't fire it. // -// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.T` will be canceled at some point. +// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point. // Failure to do that will result in a huge number of halted goroutines being created. func (r *Client) PrepareSubscription( - ctx context.T, ff *filters.T, opts ...SubscriptionOption, -) *Subscription { + ctx context.Context, ff *filters.T, opts ...SubscriptionOption, +) (sub *Subscription) { current := subscriptionIDCounter.Add(1) - ctx, cancel := context.Cause(ctx) - sub := &Subscription{ - Relay: r, + ctx, cancel := context.WithCancelCause(ctx) + sub = &Subscription{ + Client: r, Context: ctx, cancel: cancel, counter: current, @@ -528,10 +521,10 @@ func (r *Client) PrepareSubscription( switch o := opt.(type) { case WithLabel: label = string(o) - // case WithCheckDuplicate: - // sub.checkDuplicate = o - // case WithCheckDuplicateReplaceable: - // sub.checkDuplicateReplaceable = o + case WithCheckDuplicate: + sub.checkDuplicate = o + case WithCheckDuplicateReplaceable: + sub.checkDuplicateReplaceable = o } } // subscription id computation @@ -540,9 +533,8 @@ func (r *Client) PrepareSubscription( buf = append(buf, ':') buf = append(buf, label...) defer subIdPool.Put(buf) - sub.id = string(buf) - // we track subscriptions only by their counter, no need for the full id - r.Subscriptions.Store(int64(sub.counter), sub) + sub.id = &subscription.Id{T: buf} + r.Subscriptions.Store(string(buf), sub) // start handling events, eose, unsub etc: go sub.start() return sub @@ -550,8 +542,8 @@ func (r *Client) PrepareSubscription( // QueryEvents subscribes to events matching the given filter and returns a channel of events. // -// In most cases it's better to use Pool instead of this method. -func (r *Client) QueryEvents(ctx context.T, f *filter.F) ( +// In most cases it's better to use SimplePool instead of this method. +func (r *Client) QueryEvents(ctx context.Context, f *filter.F) ( evc event.C, err error, ) { var sub *Subscription @@ -570,44 +562,31 @@ func (r *Client) QueryEvents(ctx context.T, f *filter.F) ( return } }() - return sub.Events, nil + evc = sub.Events + return } -// QuerySync subscribes to events matching the given filter and returns a slice -// of events. This method blocks until all events are received or the context is -// canceled. +// QuerySync subscribes to events matching the given filter and returns a slice of events. +// This method blocks until all events are received or the context is canceled. // -// If the filter causes a subscription to open, it will stay open until the -// limit is exceeded. So this method will return an error if the limit is nil. -// If the query blocks, the caller needs to cancel the context to prevent the -// thread stalling. -func (r *Client) QuerySync(ctx context.T, f *filter.F) ( - evs event.S, err error, +// In most cases it's better to use SimplePool instead of this method. +func (r *Client) QuerySync(ctx context.Context, ff *filter.F) ( + []*event.E, error, ) { - if f.Limit == nil { - err = errors.New("limit must be set for a sync query to prevent blocking") - return - } - var sub *Subscription - if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) { - return - } - defer sub.unsub(errors.New("QuerySync() ended")) - evs = make(event.S, 0, *f.Limit) if _, ok := ctx.Deadline(); !ok { // if no timeout is set, force it to 7 seconds - var cancel context.F - ctx, cancel = context.TimeoutCause( + var cancel context.CancelFunc + ctx, cancel = context.WithTimeoutCause( ctx, 7*time.Second, errors.New("QuerySync() took too long"), ) defer cancel() } - lim := 250 - if f.Limit != nil { - lim = int(*f.Limit) + var lim int + if ff.Limit != nil { + lim = int(*ff.Limit) } - events := make(event.S, 0, max(lim, 250)) - ch, err := r.QueryEvents(ctx, f) + events := make([]*event.E, 0, max(lim, 250)) + ch, err := r.QueryEvents(ctx, ff) if err != nil { return nil, err } @@ -619,50 +598,9 @@ func (r *Client) QuerySync(ctx context.T, f *filter.F) ( return events, nil } -// // Count sends a "COUNT" command to the relay and returns the count of events matching the filters. -// func (r *Relay) Count( -// ctx context.T, -// filters Filters, -// opts ...SubscriptionOption, -// ) (int64, []byte, error) { -// v, err := r.countInternal(ctx, filters, opts...) -// if err != nil { -// return 0, nil, err -// } -// -// return *v.Count, v.HyperLogLog, nil -// } -// -// func (r *Relay) countInternal(ctx context.T, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) { -// sub := r.PrepareSubscription(ctx, filters, opts...) -// sub.countResult = make(chan CountEnvelope) -// -// if err := sub.Fire(); err != nil { -// return CountEnvelope{}, err -// } -// -// defer sub.unsub(errors.New("countInternal() ended")) -// -// if _, ok := ctx.Deadline(); !ok { -// // if no timeout is set, force it to 7 seconds -// var cancel context.CancelFunc -// ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("countInternal took too long")) -// defer cancel() -// } -// -// for { -// select { -// case count := <-sub.countResult: -// return count, nil -// case <-ctx.Done(): -// return CountEnvelope{}, ctx.Err() -// } -// } -// } - // Close closes the relay connection. func (r *Client) Close() error { - return r.close(errors.New("relay connection closed")) + return r.close(errors.New("Close() called")) } func (r *Client) close(reason error) error { diff --git a/pkg/protocol/ws/client_test.go b/pkg/protocol/ws/client_test.go index d6a3c09..c3876c2 100644 --- a/pkg/protocol/ws/client_test.go +++ b/pkg/protocol/ws/client_test.go @@ -1,176 +1,118 @@ +//go:build !js + package ws import ( + "bytes" "context" - "errors" + "encoding/json" "io" "net/http" "net/http/httptest" + "orly.dev/pkg/crypto/p256k" + "orly.dev/pkg/encoders/event" + "orly.dev/pkg/encoders/filter" + "orly.dev/pkg/encoders/filters" + "orly.dev/pkg/encoders/kind" + "orly.dev/pkg/encoders/tag" + "orly.dev/pkg/encoders/tags" + "orly.dev/pkg/encoders/timestamp" + "orly.dev/pkg/utils/chk" + "orly.dev/pkg/utils/normalize" "sync" "testing" "time" - "orly.dev/pkg/crypto/p256k" - "orly.dev/pkg/encoders/event" - "orly.dev/pkg/encoders/kind" - "orly.dev/pkg/encoders/timestamp" - "orly.dev/pkg/utils/chk" - "orly.dev/pkg/utils/normalize" - + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/websocket" ) -// func TestPublish(t *testing.T) { -// // test note to be sent over websocket -// var err error -// signer := &p256k.Signer{} -// if err = signer.Generate(); chk.E(err) { -// t.Fatal(err) -// } -// textNote := &event.E{ -// Kind: kind.TextNote, -// Content: []byte("hello"), -// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp -// Pubkey: signer.Pub(), -// } -// if err = textNote.Sign(signer); chk.E(err) { -// t.Fatalf("textNote.Sign: %v", err) -// } -// // fake relay server -// var published bool -// ws := newWebsocketServer( -// func(conn *websocket.Conn) { -// // receive message -// var raw []json.RawMessage -// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) { -// t.Errorf("websocket.JSON.Receive: %v", err) -// } -// // check that it's an EVENT message -// if len(raw) < 2 { -// t.Errorf("message too short: %v", raw) -// } -// var msgType string -// if err := json.Unmarshal(raw[0], &msgType); chk.T(err) { -// t.Errorf("json.Unmarshal: %v", err) -// } -// if msgType != "EVENT" { -// t.Errorf("expected EVENT message, got %q", msgType) -// } -// // check that the event is the one we sent -// var ev event.E -// if err := json.Unmarshal(raw[1], &ev); chk.T(err) { -// t.Errorf("json.Unmarshal: %v", err) -// } -// published = true -// if !bytes.Equal(ev.ID, textNote.ID) { -// t.Errorf( -// "event ID mismatch: got %x, want %x", -// ev.ID, textNote.ID, -// ) -// } -// if !bytes.Equal(ev.Pubkey, textNote.Pubkey) { -// t.Errorf( -// "event pubkey mismatch: got %x, want %x", -// ev.Pubkey, textNote.Pubkey, -// ) -// } -// if !bytes.Equal(ev.Content, textNote.Content) { -// t.Errorf( -// "event content mismatch: got %q, want %q", -// ev.Content, textNote.Content, -// ) -// } -// fmt.Printf( -// "received event: %s\n", -// textNote.Serialize(), -// ) -// // send back an ok nip-20 command result -// var res []byte -// if res = okenvelope.NewFrom( -// textNote.ID, true, nil, -// ).Marshal(res); chk.E(err) { -// t.Fatal(err) -// } -// if err := websocket.Message.Send(conn, res); chk.T(err) { -// t.Errorf("websocket.Message.Send: %v", err) -// } -// }, -// ) -// defer ws.Close() -// // connect a client and send the text note -// rl := mustRelayConnect(ws.URL) -// err = rl.Publish(context.Background(), textNote) -// if err != nil { -// t.Errorf("publish should have succeeded") -// } -// if !published { -// t.Errorf("fake relay server saw no event") -// } -// } -// -// func TestPublishBlocked(t *testing.T) { -// // test note to be sent over websocket -// var err error -// signer := &p256k.Signer{} -// if err = signer.Generate(); chk.E(err) { -// t.Fatal(err) -// } -// textNote := &event.E{ -// Kind: kind.TextNote, -// Content: []byte("hello"), -// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp -// Pubkey: signer.Pub(), -// } -// if err = textNote.Sign(signer); chk.E(err) { -// t.Fatalf("textNote.Sign: %v", err) -// } -// // fake relay server -// ws := newWebsocketServer( -// func(conn *websocket.Conn) { -// // discard received message; not interested -// var raw []json.RawMessage -// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) { -// t.Errorf("websocket.JSON.Receive: %v", err) -// } -// // send back a not ok nip-20 command result -// var res []byte -// if res = okenvelope.NewFrom( -// textNote.ID, false, -// normalize.Msg(normalize.Blocked, "no reason"), -// ).Marshal(res); chk.E(err) { -// t.Fatal(err) -// } -// if err := websocket.Message.Send(conn, res); chk.T(err) { -// t.Errorf("websocket.Message.Send: %v", err) -// } -// // res := []any{"OK", textNote.ID, false, "blocked"} -// }, -// ) -// defer ws.Close() -// -// // connect a client and send a text note -// rl := mustRelayConnect(ws.URL) -// if err = rl.Publish(context.Background(), textNote); !chk.E(err) { -// t.Errorf("should have failed to publish") -// } -// } - -func TestPublishWriteFailed(t *testing.T) { +func TestPublish(t *testing.T) { // test note to be sent over websocket - var err error - signer := &p256k.Signer{} - if err = signer.Generate(); chk.E(err) { - t.Fatal(err) - } + priv, pub := makeKeyPair(t) textNote := &event.E{ Kind: kind.TextNote, Content: []byte("hello"), - CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp - Pubkey: signer.Pub(), + CreatedAt: timestamp.New(1672068534), // random fixed timestamp + Tags: tags.New(tag.New("foo", "bar")), + Pubkey: pub, } - if err = textNote.Sign(signer); chk.E(err) { - t.Fatalf("textNote.Sign: %v", err) + sign := &p256k.Signer{} + var err error + if err = sign.InitSec(priv); chk.E(err) { } + err = textNote.Sign(sign) + assert.NoError(t, err) + + // fake relay server + var mu sync.Mutex // guards published to satisfy go test -race + var published bool + ws := newWebsocketServer( + func(conn *websocket.Conn) { + mu.Lock() + published = true + mu.Unlock() + // verify the client sent exactly the textNote + var raw []json.RawMessage + err := websocket.JSON.Receive(conn, &raw) + assert.NoError(t, err) + + event := parseEventMessage(t, raw) + assert.True(t, bytes.Equal(event.Serialize(), textNote.Serialize())) + + // send back an ok nip-20 command result + res := []any{"OK", textNote.IdString(), true, ""} + err = websocket.JSON.Send(conn, res) + assert.NoError(t, err) + }, + ) + defer ws.Close() + + // connect a client and send the text note + rl := mustRelayConnect(t, ws.URL) + err = rl.Publish(context.Background(), textNote) + assert.NoError(t, err) + + assert.True(t, published, "fake relay server saw no event") +} + +func TestPublishBlocked(t *testing.T) { + // test note to be sent over websocket + textNote := &event.E{ + Kind: kind.TextNote, Content: []byte("hello"), + CreatedAt: timestamp.Now(), + } + textNote.ID = textNote.GetIDBytes() + + // fake relay server + ws := newWebsocketServer( + func(conn *websocket.Conn) { + // discard received message; not interested + var raw []json.RawMessage + err := websocket.JSON.Receive(conn, &raw) + assert.NoError(t, err) + + // send back a not ok nip-20 command result + res := []any{"OK", textNote.IdString(), false, "blocked"} + websocket.JSON.Send(conn, res) + }, + ) + defer ws.Close() + + // connect a client and send a text note + rl := mustRelayConnect(t, ws.URL) + err := rl.Publish(context.Background(), textNote) + assert.Error(t, err) +} + +func TestPublishWriteFailed(t *testing.T) { + // test note to be sent over websocket + textNote := &event.E{ + Kind: kind.TextNote, Content: []byte("hello"), + CreatedAt: timestamp.Now(), + } + textNote.ID = textNote.GetIDBytes() // fake relay server ws := newWebsocketServer( func(conn *websocket.Conn) { @@ -179,15 +121,12 @@ func TestPublishWriteFailed(t *testing.T) { }, ) defer ws.Close() - // connect a client and send a text note - rl := mustRelayConnect(ws.URL) + rl := mustRelayConnect(t, ws.URL) // Force brief period of time so that publish always fails on closed socket. time.Sleep(1 * time.Millisecond) - err = rl.Publish(context.Background(), textNote) - if err == nil { - t.Errorf("should have failed to publish") - } + err := rl.Publish(context.Background(), textNote) + assert.Error(t, err) } func TestConnectContext(t *testing.T) { @@ -208,16 +147,13 @@ func TestConnectContext(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() r, err := RelayConnect(ctx, ws.URL) - if err != nil { - t.Fatalf("RelayConnectContext: %v", err) - } + assert.NoError(t, err) + defer r.Close() mu.Lock() defer mu.Unlock() - if !connected { - t.Error("fake relay server saw no client connect") - } + assert.True(t, connected, "fake relay server saw no client connect") } func TestConnectContextCanceled(t *testing.T) { @@ -229,11 +165,7 @@ func TestConnectContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // make ctx expired _, err := RelayConnect(ctx, ws.URL) - if !errors.Is(err, context.Canceled) { - t.Errorf( - "RelayConnectContext returned %v error; want context.Canceled", err, - ) - } + assert.ErrorIs(t, err, context.Canceled) } func TestConnectWithOrigin(t *testing.T) { @@ -243,21 +175,21 @@ func TestConnectWithOrigin(t *testing.T) { defer ws.Close() // relay client - r := NewRelay(context.Background(), string(normalize.URL(ws.URL))) - r.requestHeader = http.Header{"origin": {"https://example.com"}} + r := NewRelay( + context.Background(), string(normalize.URL(ws.URL)), + WithRequestHeader(http.Header{"origin": {"https://example.com"}}), + ) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := r.Connect(ctx) - if err != nil { - t.Errorf("unexpected error: %v", err) - } + assert.NoError(t, err) } func discardingHandler(conn *websocket.Conn) { io.ReadAll(conn) // discard all input } -func newWebsocketServer(handler func(*websocket.Conn)) (server *httptest.Server) { +func newWebsocketServer(handler func(*websocket.Conn)) *httptest.Server { return httptest.NewServer( &websocket.Server{ Handshake: anyOriginHandshake, @@ -269,16 +201,76 @@ func newWebsocketServer(handler func(*websocket.Conn)) (server *httptest.Server) // anyOriginHandshake is an alternative to default in golang.org/x/net/websocket // which checks for origin. nostr client sends no origin and it makes no difference // for the tests here anyway. -var anyOriginHandshake = func( - conf *websocket.Config, r *http.Request, -) (err error) { +var anyOriginHandshake = func(conf *websocket.Config, r *http.Request) error { return nil } -func mustRelayConnect(url string) (client *Client) { - rl, err := RelayConnect(context.Background(), url) - if err != nil { - panic(err.Error()) +func makeKeyPair(t *testing.T) (sec, pub []byte) { + t.Helper() + sign := &p256k.Signer{} + var err error + if err = sign.Generate(); chk.E(err) { + return } + sec = sign.Sec() + pub = sign.Pub() + assert.NoError(t, err) + + return +} + +func mustRelayConnect(t *testing.T, url string) *Client { + t.Helper() + + rl, err := RelayConnect(context.Background(), url) + require.NoError(t, err) + return rl } + +func parseEventMessage(t *testing.T, raw []json.RawMessage) *event.E { + t.Helper() + + assert.Condition( + t, func() (success bool) { + return len(raw) >= 2 + }, + ) + + var typ string + err := json.Unmarshal(raw[0], &typ) + assert.NoError(t, err) + assert.Equal(t, "EVENT", typ) + + event := &event.E{} + _, err = event.Unmarshal(raw[1]) + require.NoError(t, err) + + return event +} + +func parseSubscriptionMessage( + t *testing.T, raw []json.RawMessage, +) (subid string, ff *filters.T) { + t.Helper() + + assert.Greater(t, len(raw), 3) + + var typ string + err := json.Unmarshal(raw[0], &typ) + + assert.NoError(t, err) + assert.Equal(t, "REQ", typ) + + var id string + err = json.Unmarshal(raw[1], &id) + assert.NoError(t, err) + ff = &filters.T{} + for _, b := range raw[2:] { + var f *filter.F + err = json.Unmarshal(b, &f) + assert.NoError(t, err) + ff.F = append(ff.F, f) + } + return id, ff +} diff --git a/pkg/protocol/ws/connection.go b/pkg/protocol/ws/connection.go index c9eeee0..86410bd 100644 --- a/pkg/protocol/ws/connection.go +++ b/pkg/protocol/ws/connection.go @@ -1,43 +1,18 @@ package ws import ( - "context" "crypto/tls" "errors" "fmt" "io" "net/http" - "net/textproto" + "orly.dev/pkg/utils/context" + "orly.dev/pkg/utils/units" "time" ws "github.com/coder/websocket" ) -var defaultConnectionOptions = &ws.DialOptions{ - CompressionMode: ws.CompressionContextTakeover, - HTTPHeader: http.Header{ - textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"}, - }, -} - -func getConnectionOptions( - requestHeader http.Header, tlsConfig *tls.Config, -) *ws.DialOptions { - if requestHeader == nil && tlsConfig == nil { - return defaultConnectionOptions - } - - return &ws.DialOptions{ - HTTPHeader: requestHeader, - CompressionMode: ws.CompressionContextTakeover, - HTTPClient: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, - }, - } -} - // Connection represents a websocket connection to a Nostr relay. type Connection struct { conn *ws.Conn @@ -45,42 +20,46 @@ type Connection struct { // NewConnection creates a new websocket connection to a Nostr relay. func NewConnection( - ctx context.Context, url string, requestHeader http.Header, + ctx context.T, url string, reqHeader http.Header, tlsConfig *tls.Config, -) (*Connection, error) { - c, _, err := ws.Dial( - ctx, url, getConnectionOptions(requestHeader, tlsConfig), - ) - if err != nil { - return nil, err +) (c *Connection, err error) { + var conn *ws.Conn + if conn, _, err = ws.Dial( + ctx, url, getConnectionOptions(reqHeader, tlsConfig), + ); err != nil { + return } - - c.SetReadLimit(2 << 24) // 33MB - + conn.SetReadLimit(33 * units.Mb) return &Connection{ - conn: c, + conn: conn, }, nil } // WriteMessage writes arbitrary bytes to the websocket connection. -func (c *Connection) WriteMessage(ctx context.Context, data []byte) error { - if err := c.conn.Write(ctx, ws.MessageText, data); err != nil { - return fmt.Errorf("failed to write message: %w", err) +func (c *Connection) WriteMessage( + ctx context.T, data []byte, +) (err error) { + if err = c.conn.Write(ctx, ws.MessageText, data); err != nil { + err = fmt.Errorf("failed to write message: %w", err) + return } - return nil } // ReadMessage reads arbitrary bytes from the websocket connection into the provided buffer. -func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { - _, reader, err := c.conn.Reader(ctx) - if err != nil { - return fmt.Errorf("failed to get reader: %w", err) +func (c *Connection) ReadMessage( + ctx context.T, buf io.Writer, +) (err error) { + var reader io.Reader + if _, reader, err = c.conn.Reader(ctx); err != nil { + err = fmt.Errorf("failed to get reader: %w", err) + return } - if _, err := io.Copy(buf, reader); err != nil { - return fmt.Errorf("failed to read message: %w", err) + if _, err = io.Copy(buf, reader); err != nil { + err = fmt.Errorf("failed to read message: %w", err) + return } - return nil + return } // Close closes the websocket connection. @@ -89,8 +68,8 @@ func (c *Connection) Close() error { } // Ping sends a ping message to the websocket connection. -func (c *Connection) Ping(ctx context.Context) error { - ctx, cancel := context.WithTimeoutCause( +func (c *Connection) Ping(ctx context.T) error { + ctx, cancel := context.TimeoutCause( ctx, time.Millisecond*800, errors.New("ping took too long"), ) defer cancel() diff --git a/pkg/protocol/ws/connection_options.go b/pkg/protocol/ws/connection_options.go new file mode 100644 index 0000000..ae187fe --- /dev/null +++ b/pkg/protocol/ws/connection_options.go @@ -0,0 +1,36 @@ +//go:build !js + +package ws + +import ( + "crypto/tls" + "net/http" + "net/textproto" + + ws "github.com/coder/websocket" +) + +var defaultConnectionOptions = &ws.DialOptions{ + CompressionMode: ws.CompressionContextTakeover, + HTTPHeader: http.Header{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"}, + }, +} + +func getConnectionOptions( + requestHeader http.Header, tlsConfig *tls.Config, +) *ws.DialOptions { + if requestHeader == nil && tlsConfig == nil { + return defaultConnectionOptions + } + + return &ws.DialOptions{ + HTTPHeader: requestHeader, + CompressionMode: ws.CompressionContextTakeover, + HTTPClient: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + }, + } +} diff --git a/pkg/protocol/ws/pool.go b/pkg/protocol/ws/pool.go deleted file mode 100644 index 5830c3f..0000000 --- a/pkg/protocol/ws/pool.go +++ /dev/null @@ -1,905 +0,0 @@ -package ws - -import ( - "errors" - "fmt" - "math" - "net/http" - "slices" - "strings" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/puzpuzpuz/xsync/v3" - "orly.dev/pkg/encoders/event" - "orly.dev/pkg/encoders/filter" - "orly.dev/pkg/encoders/filters" - "orly.dev/pkg/encoders/hex" - "orly.dev/pkg/encoders/timestamp" - "orly.dev/pkg/interfaces/signer" - "orly.dev/pkg/utils/context" - "orly.dev/pkg/utils/log" - "orly.dev/pkg/utils/normalize" -) - -const ( - seenAlreadyDropTick = time.Minute -) - -// Pool manages connections to multiple relays, ensures they are reopened when necessary and not duplicated. -type Pool struct { - Relays *xsync.MapOf[string, *Client] - Context context.T - - authHandler func() signer.I - cancel context.C - - eventMiddleware func(RelayEvent) - duplicateMiddleware func(relay, id string) - queryMiddleware func(relay, pubkey string, kind uint16) - - // custom things not often used - penaltyBoxMu sync.Mutex - penaltyBox map[string][2]float64 - relayOptions []RelayOption -} - -// DirectedFilter combines a Filter with a specific relay URL. -type DirectedFilter struct { - *filter.F - Relay string -} - -// RelayEvent represents an event received from a specific relay. -type RelayEvent struct { - *event.E - Relay *Client -} - -func (ie RelayEvent) String() string { - return fmt.Sprintf( - "[%s] >> %s", ie.Relay.URL, ie.E.Marshal(nil), - ) -} - -// PoolOption is an interface for options that can be applied to a Pool. -type PoolOption interface { - ApplyPoolOption(*Pool) -} - -// NewPool creates a new Pool with the given context and options. -func NewPool(c context.T, opts ...PoolOption) (pool *Pool) { - ctx, cancel := context.Cause(c) - pool = &Pool{ - Relays: xsync.NewMapOf[string, *Client](), - Context: ctx, - cancel: cancel, - } - for _, opt := range opts { - opt.ApplyPoolOption(pool) - } - return pool -} - -// WithRelayOptions sets options that will be used on every relay instance created by this pool. -func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt { - return ropts -} - -type withRelayOptionsOpt []RelayOption - -func (h withRelayOptionsOpt) ApplyPoolOption(pool *Pool) { - pool.relayOptions = h -} - -// WithAuthHandler must be a function that signs the auth event when called. -// it will be called whenever any relay in the pool returns a `CLOSED` message -// with the "auth-required:" prefix, only once for each relay -type WithAuthHandler func() signer.I - -func (h WithAuthHandler) ApplyPoolOption(pool *Pool) { - pool.authHandler = h -} - -// WithPenaltyBox just sets the penalty box mechanism so relays that fail to connect -// or that disconnect will be ignored for a while and we won't attempt to connect again. -func WithPenaltyBox() withPenaltyBoxOpt { return withPenaltyBoxOpt{} } - -type withPenaltyBoxOpt struct{} - -func (h withPenaltyBoxOpt) ApplyPoolOption(pool *Pool) { - pool.penaltyBox = make(map[string][2]float64) - go func() { - sleep := 30.0 - for { - time.Sleep(time.Duration(sleep) * time.Second) - - pool.penaltyBoxMu.Lock() - nextSleep := 300.0 - for url, v := range pool.penaltyBox { - remainingSeconds := v[1] - remainingSeconds -= sleep - if remainingSeconds <= 0 { - pool.penaltyBox[url] = [2]float64{v[0], 0} - continue - } else { - pool.penaltyBox[url] = [2]float64{v[0], remainingSeconds} - } - - if remainingSeconds < nextSleep { - nextSleep = remainingSeconds - } - } - - sleep = nextSleep - pool.penaltyBoxMu.Unlock() - } - }() -} - -// WithEventMiddleware is a function that will be called with all events received. -type WithEventMiddleware func(RelayEvent) - -func (h WithEventMiddleware) ApplyPoolOption(pool *Pool) { - pool.eventMiddleware = h -} - -// WithDuplicateMiddleware is a function that will be called with all duplicate ids received. -type WithDuplicateMiddleware func(relay string, id string) - -func (h WithDuplicateMiddleware) ApplyPoolOption(pool *Pool) { - pool.duplicateMiddleware = h -} - -// WithAuthorKindQueryMiddleware is a function that will be called with every combination of relay+pubkey+kind queried -// in a .SubMany*() call -- when applicable (i.e. when the query contains a pubkey and a kind). -type WithAuthorKindQueryMiddleware func(relay, pubkey string, kind uint16) - -func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *Pool) { - pool.queryMiddleware = h -} - -var ( - _ PoolOption = (WithAuthHandler)(nil) - _ PoolOption = (WithEventMiddleware)(nil) - _ PoolOption = WithPenaltyBox() - _ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{})) -) - -const MAX_LOCKS = 50 - -var namedMutexPool = make([]sync.Mutex, MAX_LOCKS) - -//go:noescape -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, h, s uintptr) uintptr - -func namedLock[V ~[]byte | ~string](name V) (unlock func()) { - sptr := unsafe.StringData(string(name)) - idx := uint64( - memhash( - unsafe.Pointer(sptr), 0, uintptr(len(name)), - ), - ) % MAX_LOCKS - namedMutexPool[idx].Lock() - return namedMutexPool[idx].Unlock -} - -// EnsureRelay ensures that a relay connection exists and is active. -// If the relay is not connected, it attempts to connect. -func (p *Pool) EnsureRelay(url string) (*Client, error) { - nm := string(normalize.URL(url)) - defer namedLock(nm)() - relay, ok := p.Relays.Load(nm) - if ok && relay == nil { - if p.penaltyBox != nil { - p.penaltyBoxMu.Lock() - defer p.penaltyBoxMu.Unlock() - v, _ := p.penaltyBox[nm] - if v[1] > 0 { - return nil, fmt.Errorf("in penalty box, %fs remaining", v[1]) - } - } - } else if ok && relay.IsConnected() { - // already connected, unlock and return - return relay, nil - } - // try to connect - // we use this ctx here so when the p dies everything dies - ctx, cancel := context.TimeoutCause( - p.Context, - time.Second*15, - errors.New("connecting to the relay took too long"), - ) - defer cancel() - relay = NewRelay(context.Bg(), url, p.relayOptions...) - if err := relay.Connect(ctx); err != nil { - if p.penaltyBox != nil { - // putting relay in penalty box - p.penaltyBoxMu.Lock() - defer p.penaltyBoxMu.Unlock() - v, _ := p.penaltyBox[nm] - p.penaltyBox[nm] = [2]float64{ - v[0] + 1, 30.0 + math.Pow(2, v[0]+1), - } - } - return nil, fmt.Errorf("failed to connect: %w", err) - } - - p.Relays.Store(nm, relay) - return relay, nil -} - -// PublishResult represents the result of publishing an event to a relay. -type PublishResult struct { - Error error - RelayURL string - Relay *Client -} - -// todo: this didn't used to be in this package... probably don't want to add it -// either. -// -// PublishMany publishes an event to multiple relays and returns a -// channel of results emitted as they're received. - -// func (pool *Pool) PublishMany( -// ctx context.T, urls []string, evt *event.E, -// ) chan PublishResult { -// ch := make(chan PublishResult, len(urls)) -// wg := sync.WaitGroup{} -// wg.Add(len(urls)) -// go func() { -// for _, url := range urls { -// go func() { -// defer wg.Done() -// relay, err := pool.EnsureRelay(url) -// if err != nil { -// ch <- PublishResult{err, url, nil} -// return -// } -// if err = relay.Publish(ctx, evt); err == nil { -// // success with no auth required -// ch <- PublishResult{nil, url, relay} -// } else if strings.HasPrefix( -// err.Error(), "msg: auth-required:", -// ) && pool.authHandler != nil { -// // try to authenticate if we can -// if authErr := relay.Auth( -// ctx, pool.authHandler(), -// ); authErr == nil { -// if err := relay.Publish(ctx, evt); err == nil { -// // success after auth -// ch <- PublishResult{nil, url, relay} -// } else { -// // failure after auth -// ch <- PublishResult{err, url, relay} -// } -// } else { -// // failure to auth -// ch <- PublishResult{ -// fmt.Errorf( -// "failed to auth: %w", authErr, -// ), url, relay, -// } -// } -// } else { -// // direct failure -// ch <- PublishResult{err, url, relay} -// } -// }() -// } -// -// wg.Wait() -// close(ch) -// }() -// -// return ch -// } - -// SubscribeMany opens a subscription with the given filter to multiple relays -// the subscriptions ends when the context is canceled or when all relays return a CLOSED. -func (p *Pool) SubscribeMany( - ctx context.T, - urls []string, - filter *filter.F, - opts ...SubscriptionOption, -) chan RelayEvent { - return p.subMany(ctx, urls, filters.New(filter), nil, opts...) -} - -// FetchMany opens a subscription, much like SubscribeMany, but it ends as soon as all Relays -// return an EOSE message. -func (p *Pool) FetchMany( - ctx context.T, - urls []string, - filter *filter.F, - opts ...SubscriptionOption, -) chan RelayEvent { - return p.SubManyEose(ctx, urls, filters.New(filter), opts...) -} - -// Deprecated: SubMany is deprecated: use SubscribeMany instead. -func (p *Pool) SubMany( - ctx context.T, - urls []string, - filters *filters.T, - opts ...SubscriptionOption, -) chan RelayEvent { - return p.subMany(ctx, urls, filters, nil, opts...) -} - -// SubscribeManyNotifyEOSE is like SubscribeMany, but takes a channel that is closed when -// all subscriptions have received an EOSE -func (p *Pool) SubscribeManyNotifyEOSE( - ctx context.T, - urls []string, - filter *filter.F, - eoseChan chan struct{}, - opts ...SubscriptionOption, -) chan RelayEvent { - return p.subMany(ctx, urls, filters.New(filter), eoseChan, opts...) -} - -type ReplaceableKey struct { - PubKey string - D string -} - -// FetchManyReplaceable is like FetchMany, but deduplicates replaceable and addressable events and returns -// only the latest for each "d" tag. -func (p *Pool) FetchManyReplaceable( - ctx context.T, - urls []string, - f *filter.F, - opts ...SubscriptionOption, -) *xsync.MapOf[ReplaceableKey, *event.E] { - ctx, cancel := context.Cause(ctx) - results := xsync.NewMapOf[ReplaceableKey, *event.E]() - wg := sync.WaitGroup{} - wg.Add(len(urls)) - // todo: this is a hack for compensating for retarded relays that don't - // filter replaceable events because it streams them back over a channel. - // this is out of spec anyway so should not be handled. replaceable events - // are supposed to delete old versions. the end. this is for the incorrect - // behaviour of fiatjaf's database code, which he obviously thinks is clever - // for using channels, and not sorting results before dispatching them - // before EOSE. - _ = 0 - // seenAlreadyLatest := xsync.NewMapOf[ReplaceableKey, - // *timestamp.T]() opts = append( - // opts, WithCheckDuplicateReplaceable( - // func(rk ReplaceableKey, ts Timestamp) bool { - // updated := false - // seenAlreadyLatest.Compute( - // rk, func(latest Timestamp, _ bool) ( - // newValue Timestamp, delete bool, - // ) { - // if ts > latest { - // updated = true // we are updating the most recent - // return ts, false - // } - // return latest, false // the one we had was already more recent - // }, - // ) - // return updated - // }, - // ), - // ) - - for _, url := range urls { - go func(nm string) { - defer wg.Done() - - if mh := p.queryMiddleware; mh != nil { - if f.Kinds != nil && f.Authors != nil { - for _, kind := range f.Kinds.K { - for _, author := range f.Authors.ToStringSlice() { - mh(nm, author, kind.K) - } - } - } - } - - relay, err := p.EnsureRelay(nm) - if err != nil { - log.D.C( - func() string { - return fmt.Sprintf( - "error connecting to %s with %v: %s", nm, f, err, - ) - }, - ) - return - } - - hasAuthed := false - - subscribe: - sub, err := relay.Subscribe(ctx, filters.New(f), opts...) - if err != nil { - log.D.C( - func() string { - return fmt.Sprintf( - "error subscribing to %s with %v: %s", relay, f, - err, - ) - }, - ) - return - } - - for { - select { - case <-ctx.Done(): - return - case <-sub.EndOfStoredEvents: - return - case reason := <-sub.ClosedReason: - if strings.HasPrefix( - reason, "auth-required:", - ) && p.authHandler != nil && !hasAuthed { - // relay is requesting auth. if we can we will perform auth and try again - err = relay.Auth( - ctx, p.authHandler(), - ) - if err == nil { - hasAuthed = true // so we don't keep doing AUTH again and again - goto subscribe - } - } - log.D.F("CLOSED from %s: '%s'\n", nm, reason) - return - case evt, more := <-sub.Events: - if !more { - return - } - - ie := RelayEvent{E: evt, Relay: relay} - if mh := p.eventMiddleware; mh != nil { - mh(ie) - } - - results.Store( - ReplaceableKey{hex.Enc(evt.Pubkey), evt.Tags.GetD()}, - evt, - ) - } - } - }(string(normalize.URL(url))) - } - - // this will happen when all subscriptions get an eose (or when they die) - wg.Wait() - cancel(errors.New("all subscriptions ended")) - - return results -} - -func (p *Pool) subMany( - ctx context.T, - urls []string, - ff *filters.T, - eoseChan chan struct{}, - opts ...SubscriptionOption, -) chan RelayEvent { - ctx, cancel := context.Cause(ctx) - _ = cancel // do this so `go vet` will stop complaining - events := make(chan RelayEvent) - seenAlready := xsync.NewMapOf[string, *timestamp.T]() - ticker := time.NewTicker(seenAlreadyDropTick) - - eoseWg := sync.WaitGroup{} - eoseWg.Add(len(urls)) - if eoseChan != nil { - go func() { - eoseWg.Wait() - close(eoseChan) - }() - } - - pending := xsync.NewCounter() - pending.Add(int64(len(urls))) - for i, url := range urls { - url = string(normalize.URL(url)) - urls[i] = url - if idx := slices.Index(urls, url); idx != i { - // skip duplicate relays in the list - eoseWg.Done() - continue - } - - eosed := atomic.Bool{} - firstConnection := true - - go func(nm string) { - defer func() { - pending.Dec() - if pending.Value() == 0 { - close(events) - cancel(fmt.Errorf("aborted: %w", context.GetCause(ctx))) - } - if eosed.CompareAndSwap(false, true) { - eoseWg.Done() - } - }() - - hasAuthed := false - interval := 3 * time.Second - for { - select { - case <-ctx.Done(): - return - default: - } - - var sub *Subscription - - if mh := p.queryMiddleware; mh != nil { - for _, f := range ff.F { - if f.Kinds != nil && f.Authors != nil { - for _, k := range f.Kinds.K { - for _, author := range f.Authors.ToSliceOfBytes() { - mh(nm, hex.Enc(author), k.K) - } - } - } - } - } - - relay, err := p.EnsureRelay(nm) - if err != nil { - // if we never connected to this just fail - if firstConnection { - return - } - - // otherwise (if we were connected and got disconnected) keep trying to reconnect - log.D.F("%s reconnecting because connection failed", nm) - goto reconnect - } - firstConnection = false - hasAuthed = false - - subscribe: - sub, err = relay.Subscribe( - ctx, ff, - // append( - opts..., - // WithCheckDuplicate( - // func(id, relay string) bool { - // _, exists := seenAlready.Load(id) - // if exists && p.duplicateMiddleware != nil { - // p.duplicateMiddleware(relay, id) - // } - // return exists - // }, - // ), - // )..., - ) - if err != nil { - log.D.F("%s reconnecting because subscription died", nm) - goto reconnect - } - - go func() { - <-sub.EndOfStoredEvents - - // guard here otherwise a resubscription will trigger a duplicate call to eoseWg.Done() - if eosed.CompareAndSwap(false, true) { - eoseWg.Done() - } - }() - - // reset interval when we get a good subscription - interval = 3 * time.Second - - for { - select { - case evt, more := <-sub.Events: - if !more { - // this means the connection was closed for weird reasons, like the server shut down - // so we will update the filters here to include only events seem from now on - // and try to reconnect until we succeed - now := timestamp.Now() - for i := range ff.F { - ff.F[i].Since = now - } - log.D.F( - "%s reconnecting because sub.Events is broken", - nm, - ) - goto reconnect - } - - ie := RelayEvent{E: evt, Relay: relay} - if mh := p.eventMiddleware; mh != nil { - mh(ie) - } - - select { - case events <- ie: - case <-ctx.Done(): - return - } - case <-ticker.C: - if eosed.Load() { - old := timestamp.New(time.Now().Add(-seenAlreadyDropTick).Unix()) - for id, value := range seenAlready.Range { - if value.I64() < old.I64() { - seenAlready.Delete(id) - } - } - } - case reason := <-sub.ClosedReason: - if strings.HasPrefix( - reason, "auth-required:", - ) && p.authHandler != nil && !hasAuthed { - // relay is requesting auth. if we can we will perform auth and try again - err = relay.Auth( - ctx, p.authHandler(), - ) - if err == nil { - hasAuthed = true // so we don't keep doing AUTH again and again - goto subscribe - } - } else { - log.D.F("CLOSED from %s: '%s'\n", nm, reason) - } - - return - case <-ctx.Done(): - return - } - } - - reconnect: - // we will go back to the beginning of the loop and try to connect again and again - // until the context is canceled - time.Sleep(interval) - interval = interval * 17 / 10 // the next time we try we will wait longer - } - }(url) - } - - return events -} - -// Deprecated: SubManyEose is deprecated: use FetchMany instead. -func (p *Pool) SubManyEose( - ctx context.T, - urls []string, - filters *filters.T, - opts ...SubscriptionOption, -) chan RelayEvent { - // seenAlready := xsync.NewMapOf[string, struct{}]() - return p.subManyEoseNonOverwriteCheckDuplicate( - ctx, urls, filters, - // WithCheckDuplicate( - // func(id, relay string) bool { - // _, exists := seenAlready.LoadOrStore(id, struct{}{}) - // if exists && p.duplicateMiddleware != nil { - // p.duplicateMiddleware(relay, id) - // } - // return exists - // }, - // ), - opts..., - ) -} - -func (p *Pool) subManyEoseNonOverwriteCheckDuplicate( - ctx context.T, - urls []string, - filters *filters.T, - // wcd WithCheckDuplicate, - opts ...SubscriptionOption, -) chan RelayEvent { - ctx, cancel := context.Cause(ctx) - - events := make(chan RelayEvent) - wg := sync.WaitGroup{} - wg.Add(len(urls)) - - // opts = append(opts, wcd) - - go func() { - // this will happen when all subscriptions get an eose (or when they die) - wg.Wait() - cancel(errors.New("all subscriptions ended")) - close(events) - }() - - for _, url := range urls { - go func(nm string) { - defer wg.Done() - - if mh := p.queryMiddleware; mh != nil { - for _, filter := range filters.F { - if filter.Kinds != nil && filter.Authors != nil { - for _, k := range filter.Kinds.K { - for _, author := range filter.Authors.ToSliceOfBytes() { - mh(nm, hex.Enc(author), k.K) - } - } - } - } - } - - relay, err := p.EnsureRelay(nm) - if err != nil { - log.D.C( - func() string { - return fmt.Sprintf( - "error connecting to %s with %v: %s", nm, filters, - err, - ) - }, - ) - return - } - - hasAuthed := false - - subscribe: - sub, err := relay.Subscribe(ctx, filters, opts...) - if err != nil { - log.D.C( - func() string { - return fmt.Sprintf( - "error subscribing to %s with %v: %s", relay, - filters, - err, - ) - }, - ) - return - } - - for { - select { - case <-ctx.Done(): - return - case <-sub.EndOfStoredEvents: - return - case reason := <-sub.ClosedReason: - if strings.HasPrefix( - reason, "auth-required:", - ) && p.authHandler != nil && !hasAuthed { - // relay is requesting auth. if we can we will perform auth and try again - err = relay.Auth( - ctx, p.authHandler(), - ) - if err == nil { - hasAuthed = true // so we don't keep doing AUTH again and again - goto subscribe - } - } - log.D.F("CLOSED from %s: '%s'\n", nm, reason) - return - case evt, more := <-sub.Events: - if !more { - return - } - - ie := RelayEvent{E: evt, Relay: relay} - if mh := p.eventMiddleware; mh != nil { - mh(ie) - } - - select { - case events <- ie: - case <-ctx.Done(): - return - } - } - } - }(string(normalize.URL(url))) - } - - return events -} - -// // CountMany aggregates count results from multiple relays using NIP-45 HyperLogLog -// func (pool *Pool) CountMany( -// ctx context.T, -// urls []string, -// filter *filter.F, -// opts []SubscriptionOption, -// ) int { -// hll := hyperloglog.New(0) // offset is irrelevant here -// -// wg := sync.WaitGroup{} -// wg.Add(len(urls)) -// for _, url := range urls { -// go func(nm string) { -// defer wg.Done() -// relay, err := pool.EnsureRelay(url) -// if err != nil { -// return -// } -// ce, err := relay.countInternal(ctx, Filters{filter}, opts...) -// if err != nil { -// return -// } -// if len(ce.HyperLogLog) != 256 { -// return -// } -// hll.MergeRegisters(ce.HyperLogLog) -// }(NormalizeURL(url)) -// } -// -// wg.Wait() -// return int(hll.Count()) -// } - -// QuerySingle returns the first event returned by the first relay, cancels everything else. -func (p *Pool) QuerySingle( - ctx context.T, - urls []string, - filter *filter.F, - opts ...SubscriptionOption, -) *RelayEvent { - ctx, cancel := context.Cause(ctx) - for ievt := range p.SubManyEose( - ctx, urls, filters.New(filter), opts..., - ) { - cancel(errors.New("got the first event and ended successfully")) - return &ievt - } - cancel(errors.New("SubManyEose() didn't get yield events")) - return nil -} - -// BatchedSubManyEose performs batched subscriptions to multiple relays with different filters. -func (p *Pool) BatchedSubManyEose( - ctx context.T, - dfs []DirectedFilter, - opts ...SubscriptionOption, -) chan RelayEvent { - res := make(chan RelayEvent) - wg := sync.WaitGroup{} - wg.Add(len(dfs)) - // seenAlready := xsync.NewMapOf[string, struct{}]() - for _, df := range dfs { - go func(df DirectedFilter) { - for ie := range p.subManyEoseNonOverwriteCheckDuplicate( - ctx, - []string{df.Relay}, - filters.New(df.F), - // WithCheckDuplicate( - // func(id, relay string) bool { - // _, exists := seenAlready.LoadOrStore(id, struct{}{}) - // if exists && p.duplicateMiddleware != nil { - // p.duplicateMiddleware(relay, id) - // } - // return exists - // }, - // ), - opts..., - ) { - select { - case res <- ie: - case <-ctx.Done(): - wg.Done() - return - } - } - wg.Done() - }(df) - } - - go func() { - wg.Wait() - close(res) - }() - - return res -} - -// Close closes the pool with the given reason. -func (p *Pool) Close(reason string) { - p.cancel(fmt.Errorf("pool closed with reason: '%s'", reason)) -} diff --git a/pkg/protocol/ws/pool_test.go b/pkg/protocol/ws/pool_test.go deleted file mode 100644 index b7a3ac9..0000000 --- a/pkg/protocol/ws/pool_test.go +++ /dev/null @@ -1,216 +0,0 @@ -package ws - -import ( - "context" - "sync" - "testing" - "time" - - "orly.dev/pkg/encoders/event" - "orly.dev/pkg/encoders/filter" - "orly.dev/pkg/encoders/kind" - "orly.dev/pkg/encoders/timestamp" - "orly.dev/pkg/interfaces/signer" -) - -// mockSigner implements signer.I for testing -type mockSigner struct { - pubkey []byte -} - -func (m *mockSigner) Pub() []byte { return m.pubkey } -func (m *mockSigner) Sign([]byte) ( - []byte, error, -) { - return []byte("mock-signature"), nil -} -func (m *mockSigner) Generate() error { return nil } -func (m *mockSigner) InitSec([]byte) error { return nil } -func (m *mockSigner) InitPub([]byte) error { return nil } -func (m *mockSigner) Sec() []byte { return []byte("mock-secret") } -func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil } -func (m *mockSigner) Zero() {} -func (m *mockSigner) ECDH([]byte) ( - []byte, error, -) { - return []byte("mock-shared-secret"), nil -} - -func TestNewPool(t *testing.T) { - ctx := context.Background() - pool := NewPool(ctx) - - if pool == nil { - t.Fatal("NewPool returned nil") - } - - if pool.Relays == nil { - t.Error("Pool should have initialized Relays map") - } - - if pool.Context == nil { - t.Error("Pool should have a context") - } -} - -func TestPoolWithAuthHandler(t *testing.T) { - ctx := context.Background() - - authHandler := WithAuthHandler( - func() signer.I { - return &mockSigner{pubkey: []byte("test-pubkey")} - }, - ) - - pool := NewPool(ctx, authHandler) - - if pool.authHandler == nil { - t.Error("Pool should have auth handler set") - } - - // Test that auth handler returns the expected signer - signer := pool.authHandler() - if string(signer.Pub()) != "test-pubkey" { - t.Errorf( - "Expected pubkey 'test-pubkey', got '%s'", string(signer.Pub()), - ) - } -} - -func TestPoolWithEventMiddleware(t *testing.T) { - ctx := context.Background() - - var middlewareCalled bool - middleware := WithEventMiddleware( - func(ie RelayEvent) { - middlewareCalled = true - }, - ) - - pool := NewPool(ctx, middleware) - - // Test that middleware is called - testEvent := &event.E{ - Kind: kind.TextNote, - Content: []byte("test"), - CreatedAt: timestamp.Now(), - } - - ie := RelayEvent{E: testEvent, Relay: nil} - pool.eventMiddleware(ie) - - if !middlewareCalled { - t.Error("Expected middleware to be called") - } -} - -func TestRelayEventString(t *testing.T) { - testEvent := &event.E{ - Kind: kind.TextNote, - Content: []byte("test content"), - CreatedAt: timestamp.Now(), - } - - client := &Client{URL: "wss://test.relay"} - ie := RelayEvent{E: testEvent, Relay: client} - - str := ie.String() - if !contains(str, "wss://test.relay") { - t.Errorf("Expected string to contain relay URL, got: %s", str) - } - - if !contains(str, "test content") { - t.Errorf("Expected string to contain event content, got: %s", str) - } -} - -func TestNamedLock(t *testing.T) { - // Test that named locks work correctly - var wg sync.WaitGroup - var counter int - var mu sync.Mutex - - lockName := "test-lock" - - // Start multiple goroutines that try to increment counter - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - unlock := namedLock(lockName) - defer unlock() - - // Critical section - mu.Lock() - temp := counter - time.Sleep(1 * time.Millisecond) // Simulate work - counter = temp + 1 - mu.Unlock() - }() - } - - wg.Wait() - - if counter != 10 { - t.Errorf("Expected counter to be 10, got %d", counter) - } -} - -func TestPoolEnsureRelayInvalidURL(t *testing.T) { - ctx := context.Background() - pool := NewPool(ctx) - - // Test with invalid URL - _, err := pool.EnsureRelay("invalid-url") - if err == nil { - t.Error("Expected error for invalid URL") - } -} - -func TestPoolQuerySingle(t *testing.T) { - ctx := context.Background() - pool := NewPool(ctx) - - // Test with empty URLs slice - result := pool.QuerySingle(ctx, []string{}, &filter.F{}) - if result != nil { - t.Error("Expected nil result for empty URLs") - } -} - -// Helper functions -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && - (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || - containsSubstring(s, substr))) -} - -func containsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - -func uintPtr(u uint) *uint { - return &u -} - -// // Test pool context cancellation -// func TestPoolContextCancellation(t *testing.T) { -// ctx, cancel := context.WithCancel(context.Background()) -// pool := NewPool(ctx) -// -// // Cancel the context -// cancel() -// -// // Check that pool context is cancelled -// select { -// case <-pool.Context.Done(): -// // Expected -// case <-time.After(100 * time.Millisecond): -// t.Error("Expected pool context to be cancelled") -// } -// } diff --git a/pkg/protocol/ws/subscription.go b/pkg/protocol/ws/subscription.go index eab786f..589c282 100644 --- a/pkg/protocol/ws/subscription.go +++ b/pkg/protocol/ws/subscription.go @@ -4,31 +4,32 @@ import ( "context" "errors" "fmt" - "sync" - "sync/atomic" - "orly.dev/pkg/encoders/envelopes/closeenvelope" "orly.dev/pkg/encoders/envelopes/reqenvelope" "orly.dev/pkg/encoders/event" "orly.dev/pkg/encoders/filters" "orly.dev/pkg/encoders/subscription" - "orly.dev/pkg/utils/chk" + "orly.dev/pkg/encoders/timestamp" + "sync" + "sync/atomic" ) +type ReplaceableKey struct { + PubKey string + D string +} + // Subscription represents a subscription to a relay. type Subscription struct { counter int64 - id string + id *subscription.Id - Relay *Client + Client *Client Filters *filters.T - // // for this to be treated as a COUNT and not a REQ this must be set - // countResult chan CountEnvelope - // the Events channel emits all EVENTs that come in a Subscription // will be closed when the subscription ends - Events event.C + Events chan *event.E mu sync.Mutex // the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription @@ -40,13 +41,13 @@ type Subscription struct { // Context will be .Done() when the subscription ends Context context.Context - // // if it is not nil, checkDuplicate will be called for every event received - // // if it returns true that event will not be processed further. - // checkDuplicate func(id string, relay string) bool - // - // // if it is not nil, checkDuplicateReplaceable will be called for every event received - // // if it returns true that event will not be processed further. - // checkDuplicateReplaceable func(rk ReplaceableKey, ts Timestamp) bool + // if it is not nil, checkDuplicate will be called for every event received + // if it returns true that event will not be processed further. + checkDuplicate func(id string, relay string) bool + + // if it is not nil, checkDuplicateReplaceable will be called for every event received + // if it returns true that event will not be processed further. + checkDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool match func(*event.E) bool // this will be either Filters.Match or Filters.MatchIgnoringTimestampConstraints live atomic.Bool @@ -69,20 +70,20 @@ type WithLabel string func (_ WithLabel) IsSubscriptionOption() {} -// // WithCheckDuplicate sets checkDuplicate on the subscription -// type WithCheckDuplicate func(id, relay string) bool -// -// func (_ WithCheckDuplicate) IsSubscriptionOption() {} -// -// // WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription -// type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool -// -// func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {} +// WithCheckDuplicate sets checkDuplicate on the subscription +type WithCheckDuplicate func(id, relay string) bool + +func (_ WithCheckDuplicate) IsSubscriptionOption() {} + +// WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription +type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool + +func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {} var ( _ SubscriptionOption = (WithLabel)("") - // _ SubscriptionOption = (WithCheckDuplicate)(nil) - // _ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil) + _ SubscriptionOption = (WithCheckDuplicate)(nil) + _ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil) ) func (sub *Subscription) start() { @@ -98,7 +99,7 @@ func (sub *Subscription) start() { } // GetID returns the subscription ID. -func (sub *Subscription) GetID() string { return sub.id } +func (sub *Subscription) GetID() string { return sub.id.String() } func (sub *Subscription) dispatchEvent(evt *event.E) { added := false @@ -117,6 +118,7 @@ func (sub *Subscription) dispatchEvent(evt *event.E) { case <-sub.Context.Done(): } } + if added { sub.storedwg.Done() } @@ -159,19 +161,15 @@ func (sub *Subscription) unsub(err error) { } // remove subscription from our map - sub.Relay.Subscriptions.Delete(sub.counter) + sub.Client.Subscriptions.Delete(sub.id.String()) } // Close just sends a CLOSE message. You probably want Unsub() instead. func (sub *Subscription) Close() { - if sub.Relay.IsConnected() { - id, err := subscription.NewId(sub.id) - if err != nil { - return - } - closeMsg := closeenvelope.NewFrom(id) + if sub.Client.IsConnected() { + closeMsg := closeenvelope.NewFrom(sub.id) closeb := closeMsg.Marshal(nil) - <-sub.Relay.Write(closeb) + <-sub.Client.Write(closeb) } } @@ -184,21 +182,14 @@ func (sub *Subscription) Sub(_ context.Context, ff *filters.T) { // Fire sends the "REQ" command to the relay. func (sub *Subscription) Fire() (err error) { - // if sub.countResult == nil { - req := reqenvelope.NewWithIdString(sub.id, sub.Filters) - if req == nil { - return fmt.Errorf("invalid ID or filters") - } - reqb := req.Marshal(nil) - // } else - // if len(sub.Filters) == 1 { - // reqb, _ = CountEnvelope{sub.id, sub.Filters[0], nil, nil}.MarshalJSON() - // } else { - // return fmt.Errorf("unexpected sub configuration") + var reqb []byte + reqb = reqenvelope.NewFrom(sub.id, sub.Filters).Marshal(nil) sub.live.Store(true) - if err = <-sub.Relay.Write(reqb); chk.E(err) { + if err = <-sub.Client.Write(reqb); err != nil { err = fmt.Errorf("failed to write: %w", err) sub.cancel(err) + return } + return } diff --git a/pkg/protocol/ws/subscription_test.go b/pkg/protocol/ws/subscription_test.go index 402a795..31b1e76 100644 --- a/pkg/protocol/ws/subscription_test.go +++ b/pkg/protocol/ws/subscription_test.go @@ -1,117 +1,120 @@ package ws +import ( + "context" + "orly.dev/pkg/encoders/filter" + "orly.dev/pkg/encoders/filters" + "orly.dev/pkg/encoders/kind" + "orly.dev/pkg/encoders/kinds" + "orly.dev/pkg/encoders/tag" + "orly.dev/pkg/encoders/tags" + "orly.dev/pkg/utils/values" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + const RELAY = "wss://nos.lol" -// // test if we can fetch a couple of random events -// func TestSubscribeBasic(t *testing.F) { -// rl := mustRelayConnect(RELAY) -// defer rl.Close() -// var lim uint = 2 -// sub, err := rl.Subscribe(context.Bg(), -// filters.New(&filter.F{Kinds: kinds.New(kind.TextNote), Limit: &lim})) -// if err != nil { -// t.Fatalf("subscription failed: %v", err) -// return -// } -// timeout := time.After(5 * time.Second) -// n := 0 -// for { -// select { -// case event := <-sub.Events: -// if event == nil { -// t.Fatalf("event is nil: %v", event) -// } -// n++ -// case <-sub.EndOfStoredEvents: -// goto end -// case <-rl.Context().Done(): -// t.Errorf("connection closed: %v", rl.Context().Err()) -// goto end -// case <-timeout: -// t.Errorf("timeout") -// goto end -// } -// } -// end: -// if n != 2 { -// t.Fatalf("expected 2 events, got %d", n) -// } -// } +// test if we can fetch a couple of random events +func TestSubscribeBasic(t *testing.T) { + rl := mustRelayConnect(t, RELAY) + defer rl.Close() -// // test if we can do multiple nested subscriptions -// func TestNestedSubscriptions(t *testing.T) { -// rl := mustRelayConnect(RELAY) -// defer rl.Close() -// -// n := atomic.Uint32{} -// _ = n -// // fetch 2 replies to a note -// var lim3 uint = 3 -// sub, err := rl.Subscribe( -// context.Bg(), -// filters.New( -// &filter.F{ -// Kinds: kinds.New(kind.TextNote), -// Tags: tags.New( -// tag.New( -// "e", -// "0e34a74f8547e3b95d52a2543719b109fd0312aba144e2ef95cba043f42fe8c5", -// ), -// ), -// Limit: &lim3, -// }, -// ), -// ) -// if err != nil { -// t.Fatalf("subscription 1 failed: %v", err) -// return -// } -// -// for { -// select { -// case event := <-sub.Events: -// // now fetch the author of this -// var lim uint = 1 -// sub, err := rl.Subscribe( -// context.Bg(), -// filters.New( -// &filter.F{ -// Kinds: kinds.New(kind.ProfileMetadata), -// Authors: tag.New(event.Pubkey), Limit: &lim, -// }, -// ), -// ) -// if err != nil { -// t.Fatalf("subscription 2 failed: %v", err) -// return -// } -// -// for { -// select { -// case <-sub.Events: -// // do another subscription here in "sync" mode, just so -// // we're sure things are not blocking -// rl.QuerySync(context.Bg(), &filter.F{Limit: &lim}) -// -// n.Add(1) -// if n.Load() == 3 { -// // if we get here it means the test passed -// return -// } -// case <-sub.Context.Done(): -// goto end -// case <-sub.EndOfStoredEvents: -// sub.Unsub() -// } -// } -// end: -// fmt.Println("") -// case <-sub.EndOfStoredEvents: -// sub.Unsub() -// return -// case <-sub.Context.Done(): -// t.Fatalf("connection closed: %v", rl.Context().Err()) -// return -// } -// } -// } + sub, err := rl.Subscribe( + context.Background(), filters.New( + &filter.F{ + Kinds: &kinds.T{K: []*kind.T{kind.TextNote}}, + Limit: values.ToUintPointer(2), + }, + ), + ) + assert.NoError(t, err) + timeout := time.After(5 * time.Second) + n := 0 + for { + select { + case event := <-sub.Events: + assert.NotNil(t, event) + n++ + case <-sub.EndOfStoredEvents: + assert.Equal(t, 2, n) + sub.Unsub() + return + case <-rl.Context().Done(): + t.Fatalf("connection closed: %v", rl.Context().Err()) + case <-timeout: + t.Fatalf("timeout") + } + } +} + +// test if we can do multiple nested subscriptions +func TestNestedSubscriptions(t *testing.T) { + rl := mustRelayConnect(t, RELAY) + defer rl.Close() + + n := atomic.Uint32{} + + // fetch 2 replies to a note + sub, err := rl.Subscribe( + context.Background(), filters.New( + &filter.F{ + Kinds: kinds.New(kind.TextNote), + Tags: tags.New( + tag.New( + "e", + "0e34a74f8547e3b95d52a2543719b109fd0312aba144e2ef95cba043f42fe8c5", + ), + ), + Limit: values.ToUintPointer(3), + }, + ), + ) + assert.NoError(t, err) + + for { + select { + case ev := <-sub.Events: + // now fetch author of this + sub, err := rl.Subscribe( + context.Background(), filters.New( + &filter.F{ + Kinds: kinds.New(kind.ProfileMetadata), + Authors: tag.New(ev.PubKeyString()), + Limit: values.ToUintPointer(1), + }, + ), + ) + assert.NoError(t, err) + + for { + select { + case <-sub.Events: + // do another subscription here in "sync" mode, just so + // we're sure things aren't blocking + rl.QuerySync( + context.Background(), + &filter.F{Limit: values.ToUintPointer(1)}, + ) + n.Add(1) + if n.Load() == 3 { + // if we get here, it means the test passed + return + } + case <-sub.Context.Done(): + case <-sub.EndOfStoredEvents: + sub.Unsub() + } + } + case <-sub.EndOfStoredEvents: + sub.Unsub() + return + case <-sub.Context.Done(): + t.Fatalf("connection closed: %v", rl.Context().Err()) + return + } + } +}