fixed websocket client bugs

This commit is contained in:
2025-08-17 09:48:01 +01:00
parent 0ad371b06a
commit 0187114918
14 changed files with 957 additions and 1934 deletions

View File

@@ -38,41 +38,107 @@ type BenchmarkResults struct {
func main() {
var (
relayURL = flag.String("relay", "ws://localhost:7447", "Relay URL to benchmark")
eventCount = flag.Int("events", 10000, "Number of events to publish")
eventSize = flag.Int("size", 1024, "Average size of event content in bytes")
concurrency = flag.Int("concurrency", 10, "Number of concurrent publishers")
queryCount = flag.Int("queries", 100, "Number of queries to execute")
queryLimit = flag.Int("query-limit", 100, "Limit for each query")
skipPublish = flag.Bool("skip-publish", false, "Skip publishing phase")
skipQuery = flag.Bool("skip-query", false, "Skip query phase")
verbose = flag.Bool("v", false, "Verbose output")
multiRelay = flag.Bool("multi-relay", false, "Use multi-relay harness")
relayBinPath = flag.String("relay-bin", "", "Path to relay binary (for multi-relay mode)")
profileQueries = flag.Bool("profile", false, "Run query performance profiling")
profileSubs = flag.Bool("profile-subs", false, "Profile subscription performance")
subCount = flag.Int("sub-count", 100, "Number of concurrent subscriptions for profiling")
subDuration = flag.Duration("sub-duration", 30*time.Second, "Duration for subscription profiling")
installRelays = flag.Bool("install", false, "Install relay dependencies and binaries")
installSecp = flag.Bool("install-secp", false, "Install only secp256k1 library")
workDir = flag.String("work-dir", "/tmp/relay-build", "Working directory for builds")
installDir = flag.String("install-dir", "/usr/local/bin", "Installation directory for binaries")
generateReport = flag.Bool("report", false, "Generate comparative report")
reportFormat = flag.String("report-format", "markdown", "Report format: markdown, json, csv")
reportFile = flag.String("report-file", "benchmark_report", "Report output filename (without extension)")
reportTitle = flag.String("report-title", "Relay Benchmark Comparison", "Report title")
timingMode = flag.Bool("timing", false, "Run end-to-end timing instrumentation")
timingEvents = flag.Int("timing-events", 100, "Number of events for timing instrumentation")
timingSubs = flag.Bool("timing-subs", false, "Test subscription timing")
timingDuration = flag.Duration("timing-duration", 10*time.Second, "Duration for subscription timing test")
loadTest = flag.Bool("load", false, "Run load pattern simulation")
loadPattern = flag.String("load-pattern", "constant", "Load pattern: constant, spike, burst, sine, ramp")
loadDuration = flag.Duration("load-duration", 60*time.Second, "Duration for load test")
loadBase = flag.Int("load-base", 50, "Base load (events/sec)")
loadPeak = flag.Int("load-peak", 200, "Peak load (events/sec)")
loadPool = flag.Int("load-pool", 10, "Connection pool size for load testing")
loadSuite = flag.Bool("load-suite", false, "Run comprehensive load test suite")
loadConstraints = flag.Bool("load-constraints", false, "Test under resource constraints")
relayURL = flag.String(
"relay", "ws://localhost:7447", "Client URL to benchmark",
)
eventCount = flag.Int(
"events", 10000, "Number of events to publish",
)
eventSize = flag.Int(
"size", 1024, "Average size of event content in bytes",
)
concurrency = flag.Int(
"concurrency", 10, "Number of concurrent publishers",
)
queryCount = flag.Int(
"queries", 100, "Number of queries to execute",
)
queryLimit = flag.Int("query-limit", 100, "Limit for each query")
skipPublish = flag.Bool(
"skip-publish", false, "Skip publishing phase",
)
skipQuery = flag.Bool("skip-query", false, "Skip query phase")
verbose = flag.Bool("v", false, "Verbose output")
multiRelay = flag.Bool(
"multi-relay", false, "Use multi-relay harness",
)
relayBinPath = flag.String(
"relay-bin", "", "Path to relay binary (for multi-relay mode)",
)
profileQueries = flag.Bool(
"profile", false, "Run query performance profiling",
)
profileSubs = flag.Bool(
"profile-subs", false, "Profile subscription performance",
)
subCount = flag.Int(
"sub-count", 100,
"Number of concurrent subscriptions for profiling",
)
subDuration = flag.Duration(
"sub-duration", 30*time.Second,
"Duration for subscription profiling",
)
installRelays = flag.Bool(
"install", false, "Install relay dependencies and binaries",
)
installSecp = flag.Bool(
"install-secp", false, "Install only secp256k1 library",
)
workDir = flag.String(
"work-dir", "/tmp/relay-build", "Working directory for builds",
)
installDir = flag.String(
"install-dir", "/usr/local/bin",
"Installation directory for binaries",
)
generateReport = flag.Bool(
"report", false, "Generate comparative report",
)
reportFormat = flag.String(
"report-format", "markdown", "Report format: markdown, json, csv",
)
reportFile = flag.String(
"report-file", "benchmark_report",
"Report output filename (without extension)",
)
reportTitle = flag.String(
"report-title", "Client Benchmark Comparison", "Report title",
)
timingMode = flag.Bool(
"timing", false, "Run end-to-end timing instrumentation",
)
timingEvents = flag.Int(
"timing-events", 100, "Number of events for timing instrumentation",
)
timingSubs = flag.Bool(
"timing-subs", false, "Test subscription timing",
)
timingDuration = flag.Duration(
"timing-duration", 10*time.Second,
"Duration for subscription timing test",
)
loadTest = flag.Bool(
"load", false, "Run load pattern simulation",
)
loadPattern = flag.String(
"load-pattern", "constant",
"Load pattern: constant, spike, burst, sine, ramp",
)
loadDuration = flag.Duration(
"load-duration", 60*time.Second, "Duration for load test",
)
loadBase = flag.Int("load-base", 50, "Base load (events/sec)")
loadPeak = flag.Int("load-peak", 200, "Peak load (events/sec)")
loadPool = flag.Int(
"load-pool", 10, "Connection pool size for load testing",
)
loadSuite = flag.Bool(
"load-suite", false, "Run comprehensive load test suite",
)
loadConstraints = flag.Bool(
"load-constraints", false, "Test under resource constraints",
)
)
flag.Parse()
@@ -89,25 +155,46 @@ func main() {
} else if *generateReport {
runReportGeneration(*reportTitle, *reportFormat, *reportFile)
} else if *loadTest || *loadSuite || *loadConstraints {
runLoadSimulation(c, *relayURL, *loadPattern, *loadDuration, *loadBase, *loadPeak, *loadPool, *eventSize, *loadSuite, *loadConstraints)
runLoadSimulation(
c, *relayURL, *loadPattern, *loadDuration, *loadBase, *loadPeak,
*loadPool, *eventSize, *loadSuite, *loadConstraints,
)
} else if *timingMode || *timingSubs {
runTimingInstrumentation(c, *relayURL, *timingEvents, *eventSize, *timingSubs, *timingDuration)
runTimingInstrumentation(
c, *relayURL, *timingEvents, *eventSize, *timingSubs,
*timingDuration,
)
} else if *profileQueries || *profileSubs {
runQueryProfiler(c, *relayURL, *queryCount, *concurrency, *profileSubs, *subCount, *subDuration)
runQueryProfiler(
c, *relayURL, *queryCount, *concurrency, *profileSubs, *subCount,
*subDuration,
)
} else if *multiRelay {
runMultiRelayBenchmark(c, *relayBinPath, *eventCount, *eventSize, *concurrency, *queryCount, *queryLimit, *skipPublish, *skipQuery)
runMultiRelayBenchmark(
c, *relayBinPath, *eventCount, *eventSize, *concurrency,
*queryCount, *queryLimit, *skipPublish, *skipQuery,
)
} else {
runSingleRelayBenchmark(c, *relayURL, *eventCount, *eventSize, *concurrency, *queryCount, *queryLimit, *skipPublish, *skipQuery)
runSingleRelayBenchmark(
c, *relayURL, *eventCount, *eventSize, *concurrency, *queryCount,
*queryLimit, *skipPublish, *skipQuery,
)
}
}
func runSingleRelayBenchmark(c context.T, relayURL string, eventCount, eventSize, concurrency, queryCount, queryLimit int, skipPublish, skipQuery bool) {
func runSingleRelayBenchmark(
c context.T, relayURL string,
eventCount, eventSize, concurrency, queryCount, queryLimit int,
skipPublish, skipQuery bool,
) {
results := &BenchmarkResults{}
// Phase 1: Publish events
if !skipPublish {
fmt.Printf("Publishing %d events to %s...\n", eventCount, relayURL)
if err := benchmarkPublish(c, relayURL, eventCount, eventSize, concurrency, results); chk.E(err) {
if err := benchmarkPublish(
c, relayURL, eventCount, eventSize, concurrency, results,
); chk.E(err) {
fmt.Fprintf(os.Stderr, "Error during publish benchmark: %v\n", err)
os.Exit(1)
}
@@ -116,7 +203,9 @@ func runSingleRelayBenchmark(c context.T, relayURL string, eventCount, eventSize
// Phase 2: Query events
if !skipQuery {
fmt.Printf("\nQuerying events from %s...\n", relayURL)
if err := benchmarkQuery(c, relayURL, queryCount, queryLimit, results); chk.E(err) {
if err := benchmarkQuery(
c, relayURL, queryCount, queryLimit, results,
); chk.E(err) {
fmt.Fprintf(os.Stderr, "Error during query benchmark: %v\n", err)
os.Exit(1)
}
@@ -126,7 +215,11 @@ func runSingleRelayBenchmark(c context.T, relayURL string, eventCount, eventSize
printResults(results)
}
func runMultiRelayBenchmark(c context.T, relayBinPath string, eventCount, eventSize, concurrency, queryCount, queryLimit int, skipPublish, skipQuery bool) {
func runMultiRelayBenchmark(
c context.T, relayBinPath string,
eventCount, eventSize, concurrency, queryCount, queryLimit int,
skipPublish, skipQuery bool,
) {
harness := NewMultiRelayHarness()
generator := NewReportGenerator()
@@ -165,16 +258,26 @@ func runMultiRelayBenchmark(c context.T, relayBinPath string, eventCount, eventS
if !skipPublish {
fmt.Printf("Publishing %d events to %s...\n", eventCount, relayURL)
if err := benchmarkPublish(c, relayURL, eventCount, eventSize, concurrency, results); chk.E(err) {
fmt.Fprintf(os.Stderr, "Error during publish benchmark for %s: %v\n", relayType, err)
if err := benchmarkPublish(
c, relayURL, eventCount, eventSize, concurrency, results,
); chk.E(err) {
fmt.Fprintf(
os.Stderr, "Error during publish benchmark for %s: %v\n",
relayType, err,
)
continue
}
}
if !skipQuery {
fmt.Printf("\nQuerying events from %s...\n", relayURL)
if err := benchmarkQuery(c, relayURL, queryCount, queryLimit, results); chk.E(err) {
fmt.Fprintf(os.Stderr, "Error during query benchmark for %s: %v\n", relayType, err)
if err := benchmarkQuery(
c, relayURL, queryCount, queryLimit, results,
); chk.E(err) {
fmt.Fprintf(
os.Stderr, "Error during query benchmark for %s: %v\n",
relayType, err,
)
continue
}
}
@@ -190,16 +293,21 @@ func runMultiRelayBenchmark(c context.T, relayBinPath string, eventCount, eventS
generator.AddRelayData(relayType.String(), results, metrics, nil)
}
generator.GenerateReport("Multi-Relay Benchmark Results")
generator.GenerateReport("Multi-Client Benchmark Results")
if err := SaveReportToFile("BENCHMARK_RESULTS.md", "markdown", generator); chk.E(err) {
if err := SaveReportToFile(
"BENCHMARK_RESULTS.md", "markdown", generator,
); chk.E(err) {
fmt.Printf("Warning: Failed to save benchmark results: %v\n", err)
} else {
fmt.Printf("\nBenchmark results saved to: BENCHMARK_RESULTS.md\n")
}
}
func benchmarkPublish(c context.T, relayURL string, eventCount, eventSize, concurrency int, results *BenchmarkResults) error {
func benchmarkPublish(
c context.T, relayURL string, eventCount, eventSize, concurrency int,
results *BenchmarkResults,
) error {
// Generate signers for each concurrent publisher
signers := make([]*testSigner, concurrency)
for i := range signers {
@@ -244,7 +352,7 @@ func benchmarkPublish(c context.T, relayURL string, eventCount, eventSize, concu
for j := 0; j < eventsToPublish; j++ {
ev := generateEvent(signer, eventSize, time.Duration(0), 0)
if err := relay.Publish(c, ev); err != nil {
if err = relay.Publish(c, ev); err != nil {
log.E.F(
"Publisher %d failed to publish event: %v", publisherID,
err,
@@ -372,7 +480,9 @@ func benchmarkQuery(
return nil
}
func generateEvent(signer *testSigner, contentSize int, rateLimit time.Duration, burstSize int) *event.E {
func generateEvent(
signer *testSigner, contentSize int, rateLimit time.Duration, burstSize int,
) *event.E {
return generateSimpleEvent(signer, contentSize)
}
@@ -450,18 +560,31 @@ func printHarnessMetrics(relayType RelayType, metrics *HarnessMetrics) {
}
}
func runQueryProfiler(c context.T, relayURL string, queryCount, concurrency int, profileSubs bool, subCount int, subDuration time.Duration) {
func runQueryProfiler(
c context.T, relayURL string, queryCount, concurrency int, profileSubs bool,
subCount int, subDuration time.Duration,
) {
profiler := NewQueryProfiler(relayURL)
if profileSubs {
fmt.Printf("Profiling %d concurrent subscriptions for %v...\n", subCount, subDuration)
if err := profiler.TestSubscriptionPerformance(c, subDuration, subCount); chk.E(err) {
fmt.Printf(
"Profiling %d concurrent subscriptions for %v...\n", subCount,
subDuration,
)
if err := profiler.TestSubscriptionPerformance(
c, subDuration, subCount,
); chk.E(err) {
fmt.Fprintf(os.Stderr, "Subscription profiling failed: %v\n", err)
os.Exit(1)
}
} else {
fmt.Printf("Profiling %d queries with %d concurrent workers...\n", queryCount, concurrency)
if err := profiler.ExecuteProfile(c, queryCount, concurrency); chk.E(err) {
fmt.Printf(
"Profiling %d queries with %d concurrent workers...\n", queryCount,
concurrency,
)
if err := profiler.ExecuteProfile(
c, queryCount, concurrency,
); chk.E(err) {
fmt.Fprintf(os.Stderr, "Query profiling failed: %v\n", err)
os.Exit(1)
}
@@ -488,7 +611,10 @@ func runSecp256k1Installer(workDir, installDir string) {
}
}
func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.Duration, baseLoad, peakLoad, poolSize, eventSize int, runSuite, runConstraints bool) {
func runLoadSimulation(
c context.T, relayURL, patternStr string, duration time.Duration,
baseLoad, peakLoad, poolSize, eventSize int, runSuite, runConstraints bool,
) {
if runSuite {
suite := NewLoadTestSuite(relayURL, poolSize, eventSize)
if err := suite.RunAllPatterns(c); chk.E(err) {
@@ -515,7 +641,9 @@ func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.D
os.Exit(1)
}
simulator := NewLoadSimulator(relayURL, pattern, duration, baseLoad, peakLoad, poolSize, eventSize)
simulator := NewLoadSimulator(
relayURL, pattern, duration, baseLoad, peakLoad, poolSize, eventSize,
)
if err := simulator.Run(c); chk.E(err) {
fmt.Fprintf(os.Stderr, "Load simulation failed: %v\n", err)
@@ -524,8 +652,12 @@ func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.D
if runConstraints {
fmt.Printf("\n")
if err := simulator.SimulateResourceConstraints(c, 512, 80); chk.E(err) {
fmt.Fprintf(os.Stderr, "Resource constraint simulation failed: %v\n", err)
if err := simulator.SimulateResourceConstraints(
c, 512, 80,
); chk.E(err) {
fmt.Fprintf(
os.Stderr, "Resource constraint simulation failed: %v\n", err,
)
}
}
@@ -540,7 +672,10 @@ func runLoadSimulation(c context.T, relayURL, patternStr string, duration time.D
fmt.Printf("Peak latency: %vms\n", metrics["peak_latency_ms"])
}
func runTimingInstrumentation(c context.T, relayURL string, eventCount, eventSize int, testSubs bool, duration time.Duration) {
func runTimingInstrumentation(
c context.T, relayURL string, eventCount, eventSize int, testSubs bool,
duration time.Duration,
) {
instrumentation := NewTimingInstrumentation(relayURL)
fmt.Printf("Connecting to relay at %s...\n", relayURL)
@@ -552,13 +687,17 @@ func runTimingInstrumentation(c context.T, relayURL string, eventCount, eventSiz
if testSubs {
fmt.Printf("\n=== Subscription Timing Test ===\n")
if err := instrumentation.TestSubscriptionTiming(c, duration); chk.E(err) {
if err := instrumentation.TestSubscriptionTiming(
c, duration,
); chk.E(err) {
fmt.Fprintf(os.Stderr, "Subscription timing test failed: %v\n", err)
os.Exit(1)
}
} else {
fmt.Printf("\n=== Full Event Lifecycle Instrumentation ===\n")
if err := instrumentation.RunFullInstrumentation(c, eventCount, eventSize); chk.E(err) {
if err := instrumentation.RunFullInstrumentation(
c, eventCount, eventSize,
); chk.E(err) {
fmt.Fprintf(os.Stderr, "Timing instrumentation failed: %v\n", err)
os.Exit(1)
}
@@ -574,12 +713,14 @@ func runTimingInstrumentation(c context.T, relayURL string, eventCount, eventSiz
if bottlenecks, ok := metrics["bottlenecks"].(map[string]map[string]interface{}); ok {
fmt.Printf("\n=== Pipeline Stage Analysis ===\n")
for stage, data := range bottlenecks {
fmt.Printf("%s: avg=%vms, p95=%vms, p99=%vms, throughput=%.2f ops/s\n",
fmt.Printf(
"%s: avg=%vms, p95=%vms, p99=%vms, throughput=%.2f ops/s\n",
stage,
data["avg_latency_ms"],
data["p95_latency_ms"],
data["p99_latency_ms"],
data["throughput_ops_sec"])
data["throughput_ops_sec"],
)
}
}
}

View File

@@ -60,7 +60,10 @@ func NewReportGenerator() *ReportGenerator {
}
}
func (rg *ReportGenerator) AddRelayData(relayType string, results *BenchmarkResults, metrics *HarnessMetrics, profilerMetrics *QueryMetrics) {
func (rg *ReportGenerator) AddRelayData(
relayType string, results *BenchmarkResults, metrics *HarnessMetrics,
profilerMetrics *QueryMetrics,
) {
data := RelayBenchmarkData{
RelayType: relayType,
EventsPublished: results.EventsPublished,
@@ -148,19 +151,25 @@ func (rg *ReportGenerator) detectAnomalies() {
for _, data := range rg.data {
if math.Abs(data.PublishRate-publishMean) > 2*publishStdDev {
anomaly := fmt.Sprintf("%s publish rate (%.2f) deviates significantly from average (%.2f)",
data.RelayType, data.PublishRate, publishMean)
anomaly := fmt.Sprintf(
"%s publish rate (%.2f) deviates significantly from average (%.2f)",
data.RelayType, data.PublishRate, publishMean,
)
rg.report.Anomalies = append(rg.report.Anomalies, anomaly)
}
if math.Abs(data.QueryRate-queryMean) > 2*queryStdDev {
anomaly := fmt.Sprintf("%s query rate (%.2f) deviates significantly from average (%.2f)",
data.RelayType, data.QueryRate, queryMean)
anomaly := fmt.Sprintf(
"%s query rate (%.2f) deviates significantly from average (%.2f)",
data.RelayType, data.QueryRate, queryMean,
)
rg.report.Anomalies = append(rg.report.Anomalies, anomaly)
}
if data.Errors > 0 {
anomaly := fmt.Sprintf("%s had %d errors during benchmark", data.RelayType, data.Errors)
anomaly := fmt.Sprintf(
"%s had %d errors during benchmark", data.RelayType, data.Errors,
)
rg.report.Anomalies = append(rg.report.Anomalies, anomaly)
}
}
@@ -171,9 +180,11 @@ func (rg *ReportGenerator) generateRecommendations() {
return
}
sort.Slice(rg.data, func(i, j int) bool {
return rg.data[i].PublishRate > rg.data[j].PublishRate
})
sort.Slice(
rg.data, func(i, j int) bool {
return rg.data[i].PublishRate > rg.data[j].PublishRate
},
)
if len(rg.data) > 1 {
best := rg.data[0]
@@ -181,16 +192,20 @@ func (rg *ReportGenerator) generateRecommendations() {
improvement := (best.PublishRate - worst.PublishRate) / worst.PublishRate * 100
if improvement > 20 {
rec := fmt.Sprintf("Consider using %s for high-throughput scenarios (%.1f%% faster than %s)",
best.RelayType, improvement, worst.RelayType)
rec := fmt.Sprintf(
"Consider using %s for high-throughput scenarios (%.1f%% faster than %s)",
best.RelayType, improvement, worst.RelayType,
)
rg.report.Recommendations = append(rg.report.Recommendations, rec)
}
}
for _, data := range rg.data {
if data.MemoryUsageMB > 500 {
rec := fmt.Sprintf("%s shows high memory usage (%.1f MB) - monitor for memory leaks",
data.RelayType, data.MemoryUsageMB)
rec := fmt.Sprintf(
"%s shows high memory usage (%.1f MB) - monitor for memory leaks",
data.RelayType, data.MemoryUsageMB,
)
rg.report.Recommendations = append(rg.report.Recommendations, rec)
}
}
@@ -198,25 +213,39 @@ func (rg *ReportGenerator) generateRecommendations() {
func (rg *ReportGenerator) OutputMarkdown(writer io.Writer) error {
fmt.Fprintf(writer, "# %s\n\n", rg.report.Title)
fmt.Fprintf(writer, "Generated: %s\n\n", rg.report.GeneratedAt.Format(time.RFC3339))
fmt.Fprintf(
writer, "Generated: %s\n\n", rg.report.GeneratedAt.Format(time.RFC3339),
)
fmt.Fprintf(writer, "## Performance Summary\n\n")
fmt.Fprintf(writer, "| Relay | Publish Rate | Publish BW | Query Rate | Avg Events/Query | Memory (MB) |\n")
fmt.Fprintf(writer, "|-------|--------------|------------|------------|------------------|-------------|\n")
fmt.Fprintf(
writer,
"| Client | Publish Rate | Publish BW | Query Rate | Avg Events/Query | Memory (MB) |\n",
)
fmt.Fprintf(
writer,
"|-------|--------------|------------|------------|------------------|-------------|\n",
)
for _, data := range rg.data {
fmt.Fprintf(writer, "| %s | %.2f/s | %.2f MB/s | %.2f/s | %.2f | %.1f |\n",
fmt.Fprintf(
writer, "| %s | %.2f/s | %.2f MB/s | %.2f/s | %.2f | %.1f |\n",
data.RelayType, data.PublishRate, data.PublishBandwidth,
data.QueryRate, data.AvgEventsPerQuery, data.MemoryUsageMB)
data.QueryRate, data.AvgEventsPerQuery, data.MemoryUsageMB,
)
}
if rg.report.WinnerPublish != "" || rg.report.WinnerQuery != "" {
fmt.Fprintf(writer, "\n## Winners\n\n")
if rg.report.WinnerPublish != "" {
fmt.Fprintf(writer, "- **Best Publisher**: %s\n", rg.report.WinnerPublish)
fmt.Fprintf(
writer, "- **Best Publisher**: %s\n", rg.report.WinnerPublish,
)
}
if rg.report.WinnerQuery != "" {
fmt.Fprintf(writer, "- **Best Query Engine**: %s\n", rg.report.WinnerQuery)
fmt.Fprintf(
writer, "- **Best Query Engine**: %s\n", rg.report.WinnerQuery,
)
}
}
@@ -237,12 +266,18 @@ func (rg *ReportGenerator) OutputMarkdown(writer io.Writer) error {
fmt.Fprintf(writer, "\n## Detailed Results\n\n")
for _, data := range rg.data {
fmt.Fprintf(writer, "### %s\n\n", data.RelayType)
fmt.Fprintf(writer, "- Events Published: %d (%.2f MB)\n", data.EventsPublished, data.EventsPublishedMB)
fmt.Fprintf(
writer, "- Events Published: %d (%.2f MB)\n", data.EventsPublished,
data.EventsPublishedMB,
)
fmt.Fprintf(writer, "- Publish Duration: %s\n", data.PublishDuration)
fmt.Fprintf(writer, "- Queries Executed: %d\n", data.QueriesExecuted)
fmt.Fprintf(writer, "- Query Duration: %s\n", data.QueryDuration)
if data.P50Latency != "" {
fmt.Fprintf(writer, "- Latency P50/P95/P99: %s/%s/%s\n", data.P50Latency, data.P95Latency, data.P99Latency)
fmt.Fprintf(
writer, "- Latency P50/P95/P99: %s/%s/%s\n", data.P50Latency,
data.P95Latency, data.P99Latency,
)
}
if data.StartupTime != "" {
fmt.Fprintf(writer, "- Startup Time: %s\n", data.StartupTime)
@@ -264,9 +299,12 @@ func (rg *ReportGenerator) OutputCSV(writer io.Writer) error {
defer w.Flush()
header := []string{
"relay_type", "events_published", "events_published_mb", "publish_duration",
"publish_rate", "publish_bandwidth", "queries_executed", "events_returned",
"query_duration", "query_rate", "avg_events_per_query", "memory_usage_mb",
"relay_type", "events_published", "events_published_mb",
"publish_duration",
"publish_rate", "publish_bandwidth", "queries_executed",
"events_returned",
"query_duration", "query_rate", "avg_events_per_query",
"memory_usage_mb",
"p50_latency", "p95_latency", "p99_latency", "startup_time", "errors",
}
@@ -315,9 +353,11 @@ func (rg *ReportGenerator) GenerateThroughputCurve() []ThroughputPoint {
points = append(points, point)
}
sort.Slice(points, func(i, j int) bool {
return points[i].Throughput < points[j].Throughput
})
sort.Slice(
points, func(i, j int) bool {
return points[i].Throughput < points[j].Throughput
},
)
return points
}
@@ -370,7 +410,9 @@ func stdDev(values []float64, mean float64) float64 {
return math.Sqrt(variance)
}
func SaveReportToFile(filename, format string, generator *ReportGenerator) error {
func SaveReportToFile(
filename, format string, generator *ReportGenerator,
) error {
file, err := os.Create(filename)
if err != nil {
return err

View File

@@ -1 +1,31 @@
#!/usr/bin/env bash
# khatru
khatru &
KHATRU_PID=$!
printf "khatru started pid: %s\n" $KHATRU_PID
sleep 2s
LOG_LEVEL=info relay-benchmark -relay ws://localhost:3334 -events 10000 -queries 100
kill $KHATRU_PID
printf "khatru stopped\n"
sleep 1s
# ORLY
LOG_LEVEL=off \
ORLY_LOG_LEVEL=off \
ORLY_DB_LOG_LEVEL=off \
ORLY_SPIDER_TYPE=none \
ORLY_LISTEN=localhost \
ORLY_PORT=7447 \
ORLY_AUTH_REQUIRED=false \
ORLY_PRIVATE=true \
orly &
ORLY_PID=$!
printf "ORLY started pid: %s\n" $ORLY_PID
sleep 2s
LOG_LEVEL=info relay-benchmark -relay ws://localhost:7447 -events 100 -queries 100
kill $ORLY_PID
printf "ORLY stopped\n"
sleep 1s

View File

@@ -90,17 +90,19 @@ func (en *T) Marshal(dst []byte) (b []byte) {
// subscription.Id strings are correctly unescaped by NIP-01 escaping rules.
func (en *T) Unmarshal(b []byte) (r []byte, err error) {
r = b
var idHex []byte
if idHex, r, err = text2.UnmarshalHex(r); chk.E(err) {
var idBytes []byte
// Parse event id as quoted hex (NIP-20 compliant)
if idBytes, r, err = text2.UnmarshalHex(r); err != nil {
return
}
if len(idHex) != sha256.Size {
if len(idBytes) != sha256.Size {
err = errorf.E(
"invalid size for ID, require %d got %d",
len(idHex), sha256.Size,
sha256.Size, len(idBytes),
)
return
}
en.EventID = eventid.NewWith(idHex)
en.EventID = eventid.NewWith(idBytes)
if r, err = text2.Comma(r); chk.E(err) {
return
}

View File

@@ -23,7 +23,6 @@ import (
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/pointers"
"lukechampine.com/frand"
@@ -446,21 +445,15 @@ invalid:
// determines if the event matches the filter, ignoring timestamp constraints..
func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool {
if ev == nil {
log.I.F("nil event")
return false
}
if f.Ids.Len() > 0 && !f.Ids.Contains(ev.ID) {
log.I.F("no ids in filter match event")
return false
}
if f.Kinds.Len() > 0 && !f.Kinds.Contains(ev.Kind) {
log.I.F(
"no matching kinds in filter",
)
return false
}
if f.Authors.Len() > 0 && !f.Authors.Contains(ev.Pubkey) {
log.I.F("no matching authors in filter")
return false
}
// if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) {
@@ -470,7 +463,6 @@ func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool {
for _, v := range f.Tags.ToSliceOfTags() {
tvs := v.ToSliceOfBytes()
if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) {
log.I.F("no matching tags in filter")
return false
}
}
@@ -485,11 +477,9 @@ func (f *F) Matches(ev *event.E) (match bool) {
return
}
if f.Since.Int() != 0 && ev.CreatedAt.I64() < f.Since.I64() {
log.I.F("event is older than since")
return
}
if f.Until.Int() != 0 && ev.CreatedAt.I64() > f.Until.I64() {
log.I.F("event is newer than until")
return
}
return true

View File

@@ -57,7 +57,7 @@ var (
NIP8 = HandlingMentions
EventDeletion = NIP{"Event Deletion", 9}
NIP9 = EventDeletion
RelayInformationDocument = NIP{"Relay Information Document", 11}
RelayInformationDocument = NIP{"Client Information Document", 11}
NIP11 = RelayInformationDocument
GenericTagQueries = NIP{"Generic Tag Queries", 12}
NIP12 = GenericTagQueries
@@ -133,7 +133,7 @@ var (
NIP57 = LightningZaps
Badges = NIP{"Badges", 58}
NIP58 = Badges
RelayListMetadata = NIP{"Relay List Metadata", 65}
RelayListMetadata = NIP{"Client List Metadata", 65}
NIP65 = RelayListMetadata
ProtectedEvents = NIP{"Protected Events", 70}
NIP70 = ProtectedEvents

View File

@@ -2,17 +2,11 @@ package ws
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/puzpuzpuz/xsync/v3"
"orly.dev/pkg/encoders/envelopes"
"orly.dev/pkg/encoders/envelopes/authenvelope"
"orly.dev/pkg/encoders/envelopes/closedenvelope"
@@ -23,21 +17,28 @@ import (
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/subscription"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/interfaces/codec"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/normalize"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/puzpuzpuz/xsync/v3"
)
var subscriptionIDCounter atomic.Int64
// Relay represents a connection to a Nostr relay.
// Client represents a connection to a Nostr relay.
type Client struct {
closeMutex sync.Mutex
@@ -45,14 +46,14 @@ type Client struct {
requestHeader http.Header // e.g. for origin header
Connection *Connection
Subscriptions *xsync.MapOf[int64, *Subscription]
Subscriptions *xsync.MapOf[string, *Subscription]
ConnectionError error
connectionContext context.T // will be canceled when the connection closes
connectionContextCancel context.C
connectionContext context.Context // will be canceled when the connection closes
connectionContextCancel context.CancelCauseFunc
challenge []byte // NIP-42 challenge, we only keep the last
noticeHandler func(string) // NIP-01 NOTICEs
notices chan []byte // NIP-01 NOTICEs
customHandler func(string) // nonstandard unparseable messages
okCallbacks *xsync.MapOf[string, func(bool, string)]
writeQueue chan writeRequest
@@ -69,13 +70,13 @@ type writeRequest struct {
}
// NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection.
func NewRelay(ctx context.T, url string, opts ...RelayOption) *Client {
ctx, cancel := context.Cause(ctx)
func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Client {
ctx, cancel := context.WithCancelCause(ctx)
r := &Client{
URL: string(normalize.URL(url)),
connectionContext: ctx,
connectionContextCancel: cancel,
Subscriptions: xsync.NewMapOf[int64, *Subscription](),
Subscriptions: xsync.NewMapOf[string, *Subscription](),
okCallbacks: xsync.NewMapOf[string, func(
bool, string,
)](),
@@ -97,10 +98,10 @@ func NewRelay(ctx context.T, url string, opts ...RelayOption) *Client {
//
// The ongoing relay connection uses a background context. To close the connection, call r.Close().
// If you need fine grained long-term connection contexts, use NewRelay() instead.
func RelayConnect(ctx context.T, url string, opts ...RelayOption) (
func RelayConnect(ctx context.Context, url string, opts ...RelayOption) (
*Client, error,
) {
r := NewRelay(context.Bg(), url, opts...)
r := NewRelay(context.Background(), url, opts...)
err := r.Connect(ctx)
return r, err
}
@@ -111,19 +112,10 @@ type RelayOption interface {
}
var (
_ RelayOption = (WithNoticeHandler)(nil)
_ RelayOption = (WithCustomHandler)(nil)
_ RelayOption = (WithRequestHeader)(nil)
)
// WithNoticeHandler just takes notices and is expected to do something with them.
// when not given, defaults to logging the notices.
type WithNoticeHandler func(notice string)
func (nh WithNoticeHandler) ApplyRelayOption(r *Client) {
r.noticeHandler = nh
}
// WithCustomHandler must be a function that handles any relay message that couldn't be
// parsed as a standard envelope.
type WithCustomHandler func(data string)
@@ -146,7 +138,7 @@ func (r *Client) String() string {
// Context retrieves the context that is associated with this relay connection.
// It will be closed when the relay is disconnected.
func (r *Client) Context() context.T { return r.connectionContext }
func (r *Client) Context() context.Context { return r.connectionContext }
// IsConnected returns true if the connection to this relay seems to be active.
func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil }
@@ -158,10 +150,32 @@ func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil }
//
// The given context here is only used during the connection phase. The long-living
// relay connection will be based on the context given to NewRelay().
func (r *Client) Connect(ctx context.T) error {
func (r *Client) Connect(ctx context.Context) error {
return r.ConnectWithTLS(ctx, nil)
}
func extractSubID(jsonStr string) string {
// look for "EVENT" pattern
start := strings.Index(jsonStr, `"EVENT"`)
if start == -1 {
return ""
}
// move to the next quote
offset := strings.Index(jsonStr[start+7:], `"`)
if offset == -1 {
return ""
}
start += 7 + offset + 1
// find the ending quote
end := strings.Index(jsonStr[start:], `"`)
// get the contents
return jsonStr[start : start+end]
}
func subIdToSerial(subId string) int64 {
n := strings.Index(subId, ":")
if n < 0 || n > len(subId) {
@@ -173,8 +187,8 @@ func subIdToSerial(subId string) int64 {
// ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that.
func (r *Client) ConnectWithTLS(
ctx context.T, tlsConfig *tls.Config,
) (err error) {
ctx context.Context, tlsConfig *tls.Config,
) error {
if r.connectionContext == nil || r.Subscriptions == nil {
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
}
@@ -185,209 +199,182 @@ func (r *Client) ConnectWithTLS(
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.F
ctx, cancel = context.TimeoutCause(
var cancel context.CancelFunc
ctx, cancel = context.WithTimeoutCause(
ctx, 7*time.Second, errors.New("connection took too long"),
)
defer cancel()
}
var conn *Connection
if conn, err = NewConnection(
ctx, r.URL, r.requestHeader, tlsConfig,
); chk.E(err) {
err = fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
return
conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig)
if err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
}
r.Connection = conn
// ping every 29 seconds
ticker := time.NewTicker(29 * time.Second)
// queue all write operations here so we don't do mutex spaghetti
go func() {
var err error
for {
select {
case <-r.connectionContext.Done():
ticker.Stop()
r.Connection = nil
for _, sub := range r.Subscriptions.Range {
sub.unsub(
fmt.Errorf(
"relay connection closed: %w / %w",
context.GetCause(r.connectionContext),
context.Cause(r.connectionContext),
r.ConnectionError,
),
)
}
return
case <-ticker.C:
err := r.Connection.Ping(r.connectionContext)
err = r.Connection.Ping(r.connectionContext)
if err != nil && !strings.Contains(
err.Error(), "failed to wait for pong",
) {
log.T.C(
func() string {
return fmt.Sprintf(
"{%s} error writing ping: %v; closing websocket",
r.URL,
err,
)
},
log.I.F(
"{%s} error writing ping: %v; closing websocket", r.URL,
err,
)
r.Close() // this should trigger a context cancelation
return
}
case writeRequest := <-r.writeQueue:
case wr := <-r.writeQueue:
// all write requests will go through this to prevent races
log.T.C(
func() string {
return fmt.Sprintf(
"{%s} sending %v\n", r.URL,
string(writeRequest.msg),
)
},
)
if err := r.Connection.WriteMessage(
r.connectionContext, writeRequest.msg,
log.D.F("{%s} sending %v\n", r.URL, string(wr.msg))
if err = r.Connection.WriteMessage(
r.connectionContext, wr.msg,
); err != nil {
writeRequest.answer <- err
wr.answer <- err
}
close(writeRequest.answer)
close(wr.answer)
}
}
}()
// general message reader loop
go func() {
for {
buf := new(bytes.Buffer)
if err := conn.ReadMessage(r.connectionContext, buf); err != nil {
r.ConnectionError = err
r.close(err)
break
}
var err error
var t string
var rem []byte
if t, rem, err = envelopes.Identify(buf.Bytes()); chk.E(err) {
continue
}
switch t {
case noticeenvelope.L:
env := noticeenvelope.NewFrom(rem)
// see WithNoticeHandler
if r.noticeHandler != nil {
r.noticeHandler(string(env.Message))
} else {
log.D.F(
"NOTICE from %s: '%s'\n", r.URL, string(env.Message),
)
for {
buf.Reset()
if err := conn.ReadMessage(
r.connectionContext, buf,
); err != nil {
r.ConnectionError = err
r.Close()
break
}
case authenvelope.L:
env := authenvelope.NewChallengeWith(rem)
if env.Challenge == nil {
message := buf.Bytes()
log.D.F("{%s} %v\n", r.URL, message)
var t string
if t, message, err = envelopes.Identify(message); chk.E(err) {
continue
}
r.challenge = env.Challenge
case eventenvelope.L:
// log.I.F("%s", rem)
var env *eventenvelope.Result
env = eventenvelope.NewResult()
if _, err = env.Unmarshal(rem); err != nil {
continue
}
subid := env.Subscription.String()
sub, ok := r.Subscriptions.Load(subIdToSerial(subid))
if !ok {
log.W.F(
"unknown subscription with id '%s'\n",
subid,
)
continue
}
if !sub.Filters.Match(env.Event) {
log.T.C(
func() string {
return fmt.Sprintf(
"{%s} filter does not match: %v ~ %s\n", r.URL,
sub.Filters, env.Event.Marshal(nil),
)
},
)
continue
}
if !r.AssumeValid {
if ok, err = env.Event.Verify(); !ok || chk.E(err) {
log.T.C(
func() string {
return fmt.Sprintf(
"{%s} bad signature on %s\n", r.URL,
env.Event.ID,
)
},
)
switch t {
case noticeenvelope.L:
env := noticeenvelope.New()
if env, message, err = noticeenvelope.Parse(message); chk.E(err) {
continue
}
}
sub.dispatchEvent(env.Event)
case eoseenvelope.L:
var env *eoseenvelope.T
if env, rem, err = eoseenvelope.Parse(rem); chk.E(err) {
continue
}
if len(rem) != 0 {
log.W.F(
"{%s} unexpected data after EOSE: %s\n", r.URL,
string(rem),
)
}
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
if !ok {
log.W.F(
"unknown subscription with id '%s'\n",
env.Subscription.String(),
)
continue
}
sub.dispatchEose()
case closedenvelope.L:
var env *closedenvelope.T
if env, rem, err = closedenvelope.Parse(rem); chk.E(err) {
continue
}
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
if !ok {
log.W.F(
"unknown subscription with id '%s'\n",
env.Subscription.String(),
)
continue
}
sub.handleClosed(env.ReasonString())
case okenvelope.L:
var env *okenvelope.T
if env, rem, err = okenvelope.Parse(rem); chk.E(err) {
continue
}
eventIDStr := env.EventID.String()
if okCallback, exist := r.okCallbacks.Load(eventIDStr); exist {
okCallback(env.OK, string(env.Reason))
} else {
log.T.C(
func() string {
return fmt.Sprintf(
"{%s} got an unexpected OK message for event %s",
r.URL,
eventIDStr,
// see WithNoticeHandler
if r.notices != nil {
r.notices <- env.Message
} else {
log.E.F("NOTICE from %s: '%s'\n", r.URL, env.Message)
}
case authenvelope.L:
env := authenvelope.NewChallenge()
if env, message, err = authenvelope.ParseChallenge(message); chk.E(err) {
continue
}
if len(env.Challenge) == 0 {
continue
}
r.challenge = env.Challenge
case eventenvelope.L:
env := eventenvelope.NewResult()
if env, message, err = eventenvelope.ParseResult(message); chk.E(err) {
continue
}
if len(env.Subscription.T) == 0 {
continue
}
if sub, ok := r.Subscriptions.Load(env.Subscription.String()); !ok {
log.D.F(
"{%s} no subscription with id '%s'\n", r.URL,
env.Subscription,
)
continue
} else {
// check if the event matches the desired filter, ignore otherwise
if !sub.Filters.Match(env.Event) {
log.D.F(
"{%s} filter does not match: %v ~ %v\n", r.URL,
sub.Filters, env.Event,
)
},
)
continue
}
// check signature, ignore invalid, except from trusted (AssumeValid) relays
if !r.AssumeValid {
if ok, err = env.Event.Verify(); !ok {
log.E.F(
"{%s} bad signature on %s\n", r.URL,
env.Event.IdString(),
)
continue
}
}
// dispatch this to the internal .events channel of the subscription
sub.dispatchEvent(env.Event)
}
case eoseenvelope.L:
env := eoseenvelope.New()
if env, message, err = eoseenvelope.Parse(message); chk.E(err) {
continue
}
if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok {
subscription.dispatchEose()
}
case closedenvelope.L:
env := closedenvelope.New()
if env, message, err = closedenvelope.Parse(message); chk.E(err) {
continue
}
if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok {
subscription.handleClosed(env.ReasonString())
}
case okenvelope.L:
env := okenvelope.New()
if env, message, err = okenvelope.Parse(message); chk.E(err) {
continue
}
if okCallback, exist := r.okCallbacks.Load(env.EventID.String()); exist {
okCallback(env.OK, env.ReasonString())
} else {
log.I.F(
"{%s} got an unexpected OK message for event %s",
r.URL,
env.EventID,
)
}
}
default:
log.W.F("unknown envelope type %s\n%s", t, rem)
continue
}
}
}()
return
return nil
}
// Write queues an arbitrary message to be sent to the relay.
@@ -401,55 +388,60 @@ func (r *Client) Write(msg []byte) <-chan error {
return ch
}
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an
// OK response.
func (r *Client) Publish(ctx context.T, ev *event.E) error {
return r.publish(ctx, ev.ID, ev)
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response.
func (r *Client) Publish(ctx context.Context, event *event.E) error {
return r.publish(
ctx, event.IdString(), eventenvelope.NewSubmissionWith(event),
)
}
// Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK
// response.
// Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK response.
//
// You don't have to build the AUTH event yourself, this function takes a function to which the
// event that must be signed will be passed, so it's only necessary to sign that.
func (r *Client) Auth(
ctx context.T, sign signer.I,
ctx context.Context, sign signer.I,
) (err error) {
authEvent := &event.E{
CreatedAt: timestamp.Now(),
Kind: kind.ClientAuthentication,
Tags: tags.New(
tag.New("relay", r.URL),
tag.New([]byte("challenge"), r.challenge),
tag.New("challenge", string(r.challenge)),
),
Content: nil,
}
if err = authEvent.Sign(sign); chk.E(err) {
err = fmt.Errorf("error signing auth event: %w", err)
return
if err = authEvent.Sign(sign); err != nil {
return fmt.Errorf("error signing auth event: %w", err)
}
return r.publish(ctx, authEvent.ID, authEvent)
return r.publish(
ctx, authEvent.IdString(), authenvelope.NewResponseWith(authEvent),
)
}
func (r *Client) publish(
ctx context.T, id []byte, ev *event.E,
ctx context.Context, id string, env codec.Envelope,
) error {
var err error
var cancel context.F
var cancel context.CancelFunc
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
ctx, cancel = context.TimeoutCause(
ctx, cancel = context.WithTimeoutCause(
ctx, 7*time.Second, fmt.Errorf("given up waiting for an OK"),
)
defer cancel()
} else {
// otherwise make the context cancellable so we can stop everything upon
// receiving an "OK"
ctx, cancel = context.Cancel(ctx)
// otherwise make the context cancellable so we can stop everything upon receiving an "OK"
ctx, cancel = context.WithCancel(ctx)
defer cancel()
}
// listen for an OK callback
gotOk := false
ids := hex.Enc(id)
r.okCallbacks.Store(
ids, func(ok bool, reason string) {
id, func(ok bool, reason string) {
gotOk = true
if !ok {
err = fmt.Errorf("msg: %s", reason)
@@ -457,18 +449,18 @@ func (r *Client) publish(
cancel()
},
)
defer r.okCallbacks.Delete(ids)
defer r.okCallbacks.Delete(id)
// publish event
envb := eventenvelope.NewSubmissionWith(ev).Marshal(nil)
// envb := ev.Marshal(nil)
envb := env.Marshal(nil)
if err = <-r.Write(envb); err != nil {
return err
}
for {
select {
case <-ctx.Done():
// this will be called when we get an OK or when the context has
// been canceled
// this will be called when we get an OK or when the context has been canceled
if gotOk {
return err
}
@@ -480,40 +472,41 @@ func (r *Client) publish(
}
}
// Subscribe sends a "REQ" command to the relay r as in NIP-01. Events are
// returned through the channel sub.Events. The subscription is closed when
// context ctx is cancelled ("CLOSE" in NIP-01).
// Subscribe sends a "REQ" command to the relay r as in NIP-01.
// Events are returned through the channel sub.Events.
// The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01).
//
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or
// ensuring their `context.T` will be canceled at some point. Failure to
// do that will result in a huge number of halted goroutines being created.
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
// Failure to do that will result in a huge number of halted goroutines being created.
func (r *Client) Subscribe(
ctx context.T, ff *filters.T, opts ...SubscriptionOption,
) (sub *Subscription, err error) {
sub = r.PrepareSubscription(ctx, ff, opts...)
ctx context.Context, ff *filters.T, opts ...SubscriptionOption,
) (*Subscription, error) {
sub := r.PrepareSubscription(ctx, ff, opts...)
if r.Connection == nil {
return nil, fmt.Errorf("not connected to %s", r.URL)
}
if err = sub.Fire(); err != nil {
err = fmt.Errorf(
"couldn't subscribe to %v at %s: %w", ff.Marshal(nil), r.URL, err,
if err := sub.Fire(); err != nil {
return nil, fmt.Errorf(
"couldn't subscribe to %v at %s: %w", ff, r.URL, err,
)
return
}
return
return sub, nil
}
// PrepareSubscription creates a subscription, but doesn't fire it.
//
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.T` will be canceled at some point.
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
// Failure to do that will result in a huge number of halted goroutines being created.
func (r *Client) PrepareSubscription(
ctx context.T, ff *filters.T, opts ...SubscriptionOption,
) *Subscription {
ctx context.Context, ff *filters.T, opts ...SubscriptionOption,
) (sub *Subscription) {
current := subscriptionIDCounter.Add(1)
ctx, cancel := context.Cause(ctx)
sub := &Subscription{
Relay: r,
ctx, cancel := context.WithCancelCause(ctx)
sub = &Subscription{
Client: r,
Context: ctx,
cancel: cancel,
counter: current,
@@ -528,10 +521,10 @@ func (r *Client) PrepareSubscription(
switch o := opt.(type) {
case WithLabel:
label = string(o)
// case WithCheckDuplicate:
// sub.checkDuplicate = o
// case WithCheckDuplicateReplaceable:
// sub.checkDuplicateReplaceable = o
case WithCheckDuplicate:
sub.checkDuplicate = o
case WithCheckDuplicateReplaceable:
sub.checkDuplicateReplaceable = o
}
}
// subscription id computation
@@ -540,9 +533,8 @@ func (r *Client) PrepareSubscription(
buf = append(buf, ':')
buf = append(buf, label...)
defer subIdPool.Put(buf)
sub.id = string(buf)
// we track subscriptions only by their counter, no need for the full id
r.Subscriptions.Store(int64(sub.counter), sub)
sub.id = &subscription.Id{T: buf}
r.Subscriptions.Store(string(buf), sub)
// start handling events, eose, unsub etc:
go sub.start()
return sub
@@ -550,8 +542,8 @@ func (r *Client) PrepareSubscription(
// QueryEvents subscribes to events matching the given filter and returns a channel of events.
//
// In most cases it's better to use Pool instead of this method.
func (r *Client) QueryEvents(ctx context.T, f *filter.F) (
// In most cases it's better to use SimplePool instead of this method.
func (r *Client) QueryEvents(ctx context.Context, f *filter.F) (
evc event.C, err error,
) {
var sub *Subscription
@@ -570,44 +562,31 @@ func (r *Client) QueryEvents(ctx context.T, f *filter.F) (
return
}
}()
return sub.Events, nil
evc = sub.Events
return
}
// QuerySync subscribes to events matching the given filter and returns a slice
// of events. This method blocks until all events are received or the context is
// canceled.
// QuerySync subscribes to events matching the given filter and returns a slice of events.
// This method blocks until all events are received or the context is canceled.
//
// If the filter causes a subscription to open, it will stay open until the
// limit is exceeded. So this method will return an error if the limit is nil.
// If the query blocks, the caller needs to cancel the context to prevent the
// thread stalling.
func (r *Client) QuerySync(ctx context.T, f *filter.F) (
evs event.S, err error,
// In most cases it's better to use SimplePool instead of this method.
func (r *Client) QuerySync(ctx context.Context, ff *filter.F) (
[]*event.E, error,
) {
if f.Limit == nil {
err = errors.New("limit must be set for a sync query to prevent blocking")
return
}
var sub *Subscription
if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) {
return
}
defer sub.unsub(errors.New("QuerySync() ended"))
evs = make(event.S, 0, *f.Limit)
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.F
ctx, cancel = context.TimeoutCause(
var cancel context.CancelFunc
ctx, cancel = context.WithTimeoutCause(
ctx, 7*time.Second, errors.New("QuerySync() took too long"),
)
defer cancel()
}
lim := 250
if f.Limit != nil {
lim = int(*f.Limit)
var lim int
if ff.Limit != nil {
lim = int(*ff.Limit)
}
events := make(event.S, 0, max(lim, 250))
ch, err := r.QueryEvents(ctx, f)
events := make([]*event.E, 0, max(lim, 250))
ch, err := r.QueryEvents(ctx, ff)
if err != nil {
return nil, err
}
@@ -619,50 +598,9 @@ func (r *Client) QuerySync(ctx context.T, f *filter.F) (
return events, nil
}
// // Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
// func (r *Relay) Count(
// ctx context.T,
// filters Filters,
// opts ...SubscriptionOption,
// ) (int64, []byte, error) {
// v, err := r.countInternal(ctx, filters, opts...)
// if err != nil {
// return 0, nil, err
// }
//
// return *v.Count, v.HyperLogLog, nil
// }
//
// func (r *Relay) countInternal(ctx context.T, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
// sub := r.PrepareSubscription(ctx, filters, opts...)
// sub.countResult = make(chan CountEnvelope)
//
// if err := sub.Fire(); err != nil {
// return CountEnvelope{}, err
// }
//
// defer sub.unsub(errors.New("countInternal() ended"))
//
// if _, ok := ctx.Deadline(); !ok {
// // if no timeout is set, force it to 7 seconds
// var cancel context.CancelFunc
// ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("countInternal took too long"))
// defer cancel()
// }
//
// for {
// select {
// case count := <-sub.countResult:
// return count, nil
// case <-ctx.Done():
// return CountEnvelope{}, ctx.Err()
// }
// }
// }
// Close closes the relay connection.
func (r *Client) Close() error {
return r.close(errors.New("relay connection closed"))
return r.close(errors.New("Close() called"))
}
func (r *Client) close(reason error) error {

View File

@@ -1,176 +1,118 @@
//go:build !js
package ws
import (
"bytes"
"context"
"errors"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/normalize"
"sync"
"testing"
"time"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/normalize"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
// func TestPublish(t *testing.T) {
// // test note to be sent over websocket
// var err error
// signer := &p256k.Signer{}
// if err = signer.Generate(); chk.E(err) {
// t.Fatal(err)
// }
// textNote := &event.E{
// Kind: kind.TextNote,
// Content: []byte("hello"),
// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
// Pubkey: signer.Pub(),
// }
// if err = textNote.Sign(signer); chk.E(err) {
// t.Fatalf("textNote.Sign: %v", err)
// }
// // fake relay server
// var published bool
// ws := newWebsocketServer(
// func(conn *websocket.Conn) {
// // receive message
// var raw []json.RawMessage
// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
// t.Errorf("websocket.JSON.Receive: %v", err)
// }
// // check that it's an EVENT message
// if len(raw) < 2 {
// t.Errorf("message too short: %v", raw)
// }
// var msgType string
// if err := json.Unmarshal(raw[0], &msgType); chk.T(err) {
// t.Errorf("json.Unmarshal: %v", err)
// }
// if msgType != "EVENT" {
// t.Errorf("expected EVENT message, got %q", msgType)
// }
// // check that the event is the one we sent
// var ev event.E
// if err := json.Unmarshal(raw[1], &ev); chk.T(err) {
// t.Errorf("json.Unmarshal: %v", err)
// }
// published = true
// if !bytes.Equal(ev.ID, textNote.ID) {
// t.Errorf(
// "event ID mismatch: got %x, want %x",
// ev.ID, textNote.ID,
// )
// }
// if !bytes.Equal(ev.Pubkey, textNote.Pubkey) {
// t.Errorf(
// "event pubkey mismatch: got %x, want %x",
// ev.Pubkey, textNote.Pubkey,
// )
// }
// if !bytes.Equal(ev.Content, textNote.Content) {
// t.Errorf(
// "event content mismatch: got %q, want %q",
// ev.Content, textNote.Content,
// )
// }
// fmt.Printf(
// "received event: %s\n",
// textNote.Serialize(),
// )
// // send back an ok nip-20 command result
// var res []byte
// if res = okenvelope.NewFrom(
// textNote.ID, true, nil,
// ).Marshal(res); chk.E(err) {
// t.Fatal(err)
// }
// if err := websocket.Message.Send(conn, res); chk.T(err) {
// t.Errorf("websocket.Message.Send: %v", err)
// }
// },
// )
// defer ws.Close()
// // connect a client and send the text note
// rl := mustRelayConnect(ws.URL)
// err = rl.Publish(context.Background(), textNote)
// if err != nil {
// t.Errorf("publish should have succeeded")
// }
// if !published {
// t.Errorf("fake relay server saw no event")
// }
// }
//
// func TestPublishBlocked(t *testing.T) {
// // test note to be sent over websocket
// var err error
// signer := &p256k.Signer{}
// if err = signer.Generate(); chk.E(err) {
// t.Fatal(err)
// }
// textNote := &event.E{
// Kind: kind.TextNote,
// Content: []byte("hello"),
// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
// Pubkey: signer.Pub(),
// }
// if err = textNote.Sign(signer); chk.E(err) {
// t.Fatalf("textNote.Sign: %v", err)
// }
// // fake relay server
// ws := newWebsocketServer(
// func(conn *websocket.Conn) {
// // discard received message; not interested
// var raw []json.RawMessage
// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
// t.Errorf("websocket.JSON.Receive: %v", err)
// }
// // send back a not ok nip-20 command result
// var res []byte
// if res = okenvelope.NewFrom(
// textNote.ID, false,
// normalize.Msg(normalize.Blocked, "no reason"),
// ).Marshal(res); chk.E(err) {
// t.Fatal(err)
// }
// if err := websocket.Message.Send(conn, res); chk.T(err) {
// t.Errorf("websocket.Message.Send: %v", err)
// }
// // res := []any{"OK", textNote.ID, false, "blocked"}
// },
// )
// defer ws.Close()
//
// // connect a client and send a text note
// rl := mustRelayConnect(ws.URL)
// if err = rl.Publish(context.Background(), textNote); !chk.E(err) {
// t.Errorf("should have failed to publish")
// }
// }
func TestPublishWriteFailed(t *testing.T) {
func TestPublish(t *testing.T) {
// test note to be sent over websocket
var err error
signer := &p256k.Signer{}
if err = signer.Generate(); chk.E(err) {
t.Fatal(err)
}
priv, pub := makeKeyPair(t)
textNote := &event.E{
Kind: kind.TextNote,
Content: []byte("hello"),
CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
Pubkey: signer.Pub(),
CreatedAt: timestamp.New(1672068534), // random fixed timestamp
Tags: tags.New(tag.New("foo", "bar")),
Pubkey: pub,
}
if err = textNote.Sign(signer); chk.E(err) {
t.Fatalf("textNote.Sign: %v", err)
sign := &p256k.Signer{}
var err error
if err = sign.InitSec(priv); chk.E(err) {
}
err = textNote.Sign(sign)
assert.NoError(t, err)
// fake relay server
var mu sync.Mutex // guards published to satisfy go test -race
var published bool
ws := newWebsocketServer(
func(conn *websocket.Conn) {
mu.Lock()
published = true
mu.Unlock()
// verify the client sent exactly the textNote
var raw []json.RawMessage
err := websocket.JSON.Receive(conn, &raw)
assert.NoError(t, err)
event := parseEventMessage(t, raw)
assert.True(t, bytes.Equal(event.Serialize(), textNote.Serialize()))
// send back an ok nip-20 command result
res := []any{"OK", textNote.IdString(), true, ""}
err = websocket.JSON.Send(conn, res)
assert.NoError(t, err)
},
)
defer ws.Close()
// connect a client and send the text note
rl := mustRelayConnect(t, ws.URL)
err = rl.Publish(context.Background(), textNote)
assert.NoError(t, err)
assert.True(t, published, "fake relay server saw no event")
}
func TestPublishBlocked(t *testing.T) {
// test note to be sent over websocket
textNote := &event.E{
Kind: kind.TextNote, Content: []byte("hello"),
CreatedAt: timestamp.Now(),
}
textNote.ID = textNote.GetIDBytes()
// fake relay server
ws := newWebsocketServer(
func(conn *websocket.Conn) {
// discard received message; not interested
var raw []json.RawMessage
err := websocket.JSON.Receive(conn, &raw)
assert.NoError(t, err)
// send back a not ok nip-20 command result
res := []any{"OK", textNote.IdString(), false, "blocked"}
websocket.JSON.Send(conn, res)
},
)
defer ws.Close()
// connect a client and send a text note
rl := mustRelayConnect(t, ws.URL)
err := rl.Publish(context.Background(), textNote)
assert.Error(t, err)
}
func TestPublishWriteFailed(t *testing.T) {
// test note to be sent over websocket
textNote := &event.E{
Kind: kind.TextNote, Content: []byte("hello"),
CreatedAt: timestamp.Now(),
}
textNote.ID = textNote.GetIDBytes()
// fake relay server
ws := newWebsocketServer(
func(conn *websocket.Conn) {
@@ -179,15 +121,12 @@ func TestPublishWriteFailed(t *testing.T) {
},
)
defer ws.Close()
// connect a client and send a text note
rl := mustRelayConnect(ws.URL)
rl := mustRelayConnect(t, ws.URL)
// Force brief period of time so that publish always fails on closed socket.
time.Sleep(1 * time.Millisecond)
err = rl.Publish(context.Background(), textNote)
if err == nil {
t.Errorf("should have failed to publish")
}
err := rl.Publish(context.Background(), textNote)
assert.Error(t, err)
}
func TestConnectContext(t *testing.T) {
@@ -208,16 +147,13 @@ func TestConnectContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
r, err := RelayConnect(ctx, ws.URL)
if err != nil {
t.Fatalf("RelayConnectContext: %v", err)
}
assert.NoError(t, err)
defer r.Close()
mu.Lock()
defer mu.Unlock()
if !connected {
t.Error("fake relay server saw no client connect")
}
assert.True(t, connected, "fake relay server saw no client connect")
}
func TestConnectContextCanceled(t *testing.T) {
@@ -229,11 +165,7 @@ func TestConnectContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // make ctx expired
_, err := RelayConnect(ctx, ws.URL)
if !errors.Is(err, context.Canceled) {
t.Errorf(
"RelayConnectContext returned %v error; want context.Canceled", err,
)
}
assert.ErrorIs(t, err, context.Canceled)
}
func TestConnectWithOrigin(t *testing.T) {
@@ -243,21 +175,21 @@ func TestConnectWithOrigin(t *testing.T) {
defer ws.Close()
// relay client
r := NewRelay(context.Background(), string(normalize.URL(ws.URL)))
r.requestHeader = http.Header{"origin": {"https://example.com"}}
r := NewRelay(
context.Background(), string(normalize.URL(ws.URL)),
WithRequestHeader(http.Header{"origin": {"https://example.com"}}),
)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := r.Connect(ctx)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
assert.NoError(t, err)
}
func discardingHandler(conn *websocket.Conn) {
io.ReadAll(conn) // discard all input
}
func newWebsocketServer(handler func(*websocket.Conn)) (server *httptest.Server) {
func newWebsocketServer(handler func(*websocket.Conn)) *httptest.Server {
return httptest.NewServer(
&websocket.Server{
Handshake: anyOriginHandshake,
@@ -269,16 +201,76 @@ func newWebsocketServer(handler func(*websocket.Conn)) (server *httptest.Server)
// anyOriginHandshake is an alternative to default in golang.org/x/net/websocket
// which checks for origin. nostr client sends no origin and it makes no difference
// for the tests here anyway.
var anyOriginHandshake = func(
conf *websocket.Config, r *http.Request,
) (err error) {
var anyOriginHandshake = func(conf *websocket.Config, r *http.Request) error {
return nil
}
func mustRelayConnect(url string) (client *Client) {
rl, err := RelayConnect(context.Background(), url)
if err != nil {
panic(err.Error())
func makeKeyPair(t *testing.T) (sec, pub []byte) {
t.Helper()
sign := &p256k.Signer{}
var err error
if err = sign.Generate(); chk.E(err) {
return
}
sec = sign.Sec()
pub = sign.Pub()
assert.NoError(t, err)
return
}
func mustRelayConnect(t *testing.T, url string) *Client {
t.Helper()
rl, err := RelayConnect(context.Background(), url)
require.NoError(t, err)
return rl
}
func parseEventMessage(t *testing.T, raw []json.RawMessage) *event.E {
t.Helper()
assert.Condition(
t, func() (success bool) {
return len(raw) >= 2
},
)
var typ string
err := json.Unmarshal(raw[0], &typ)
assert.NoError(t, err)
assert.Equal(t, "EVENT", typ)
event := &event.E{}
_, err = event.Unmarshal(raw[1])
require.NoError(t, err)
return event
}
func parseSubscriptionMessage(
t *testing.T, raw []json.RawMessage,
) (subid string, ff *filters.T) {
t.Helper()
assert.Greater(t, len(raw), 3)
var typ string
err := json.Unmarshal(raw[0], &typ)
assert.NoError(t, err)
assert.Equal(t, "REQ", typ)
var id string
err = json.Unmarshal(raw[1], &id)
assert.NoError(t, err)
ff = &filters.T{}
for _, b := range raw[2:] {
var f *filter.F
err = json.Unmarshal(b, &f)
assert.NoError(t, err)
ff.F = append(ff.F, f)
}
return id, ff
}

View File

@@ -1,43 +1,18 @@
package ws
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/units"
"time"
ws "github.com/coder/websocket"
)
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
func getConnectionOptions(
requestHeader http.Header, tlsConfig *tls.Config,
) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}
// Connection represents a websocket connection to a Nostr relay.
type Connection struct {
conn *ws.Conn
@@ -45,42 +20,46 @@ type Connection struct {
// NewConnection creates a new websocket connection to a Nostr relay.
func NewConnection(
ctx context.Context, url string, requestHeader http.Header,
ctx context.T, url string, reqHeader http.Header,
tlsConfig *tls.Config,
) (*Connection, error) {
c, _, err := ws.Dial(
ctx, url, getConnectionOptions(requestHeader, tlsConfig),
)
if err != nil {
return nil, err
) (c *Connection, err error) {
var conn *ws.Conn
if conn, _, err = ws.Dial(
ctx, url, getConnectionOptions(reqHeader, tlsConfig),
); err != nil {
return
}
c.SetReadLimit(2 << 24) // 33MB
conn.SetReadLimit(33 * units.Mb)
return &Connection{
conn: c,
conn: conn,
}, nil
}
// WriteMessage writes arbitrary bytes to the websocket connection.
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
return fmt.Errorf("failed to write message: %w", err)
func (c *Connection) WriteMessage(
ctx context.T, data []byte,
) (err error) {
if err = c.conn.Write(ctx, ws.MessageText, data); err != nil {
err = fmt.Errorf("failed to write message: %w", err)
return
}
return nil
}
// ReadMessage reads arbitrary bytes from the websocket connection into the provided buffer.
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
_, reader, err := c.conn.Reader(ctx)
if err != nil {
return fmt.Errorf("failed to get reader: %w", err)
func (c *Connection) ReadMessage(
ctx context.T, buf io.Writer,
) (err error) {
var reader io.Reader
if _, reader, err = c.conn.Reader(ctx); err != nil {
err = fmt.Errorf("failed to get reader: %w", err)
return
}
if _, err := io.Copy(buf, reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
if _, err = io.Copy(buf, reader); err != nil {
err = fmt.Errorf("failed to read message: %w", err)
return
}
return nil
return
}
// Close closes the websocket connection.
@@ -89,8 +68,8 @@ func (c *Connection) Close() error {
}
// Ping sends a ping message to the websocket connection.
func (c *Connection) Ping(ctx context.Context) error {
ctx, cancel := context.WithTimeoutCause(
func (c *Connection) Ping(ctx context.T) error {
ctx, cancel := context.TimeoutCause(
ctx, time.Millisecond*800, errors.New("ping took too long"),
)
defer cancel()

View File

@@ -0,0 +1,36 @@
//go:build !js
package ws
import (
"crypto/tls"
"net/http"
"net/textproto"
ws "github.com/coder/websocket"
)
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
func getConnectionOptions(
requestHeader http.Header, tlsConfig *tls.Config,
) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}

View File

@@ -1,905 +0,0 @@
package ws
import (
"errors"
"fmt"
"math"
"net/http"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/puzpuzpuz/xsync/v3"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/normalize"
)
const (
seenAlreadyDropTick = time.Minute
)
// Pool manages connections to multiple relays, ensures they are reopened when necessary and not duplicated.
type Pool struct {
Relays *xsync.MapOf[string, *Client]
Context context.T
authHandler func() signer.I
cancel context.C
eventMiddleware func(RelayEvent)
duplicateMiddleware func(relay, id string)
queryMiddleware func(relay, pubkey string, kind uint16)
// custom things not often used
penaltyBoxMu sync.Mutex
penaltyBox map[string][2]float64
relayOptions []RelayOption
}
// DirectedFilter combines a Filter with a specific relay URL.
type DirectedFilter struct {
*filter.F
Relay string
}
// RelayEvent represents an event received from a specific relay.
type RelayEvent struct {
*event.E
Relay *Client
}
func (ie RelayEvent) String() string {
return fmt.Sprintf(
"[%s] >> %s", ie.Relay.URL, ie.E.Marshal(nil),
)
}
// PoolOption is an interface for options that can be applied to a Pool.
type PoolOption interface {
ApplyPoolOption(*Pool)
}
// NewPool creates a new Pool with the given context and options.
func NewPool(c context.T, opts ...PoolOption) (pool *Pool) {
ctx, cancel := context.Cause(c)
pool = &Pool{
Relays: xsync.NewMapOf[string, *Client](),
Context: ctx,
cancel: cancel,
}
for _, opt := range opts {
opt.ApplyPoolOption(pool)
}
return pool
}
// WithRelayOptions sets options that will be used on every relay instance created by this pool.
func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt {
return ropts
}
type withRelayOptionsOpt []RelayOption
func (h withRelayOptionsOpt) ApplyPoolOption(pool *Pool) {
pool.relayOptions = h
}
// WithAuthHandler must be a function that signs the auth event when called.
// it will be called whenever any relay in the pool returns a `CLOSED` message
// with the "auth-required:" prefix, only once for each relay
type WithAuthHandler func() signer.I
func (h WithAuthHandler) ApplyPoolOption(pool *Pool) {
pool.authHandler = h
}
// WithPenaltyBox just sets the penalty box mechanism so relays that fail to connect
// or that disconnect will be ignored for a while and we won't attempt to connect again.
func WithPenaltyBox() withPenaltyBoxOpt { return withPenaltyBoxOpt{} }
type withPenaltyBoxOpt struct{}
func (h withPenaltyBoxOpt) ApplyPoolOption(pool *Pool) {
pool.penaltyBox = make(map[string][2]float64)
go func() {
sleep := 30.0
for {
time.Sleep(time.Duration(sleep) * time.Second)
pool.penaltyBoxMu.Lock()
nextSleep := 300.0
for url, v := range pool.penaltyBox {
remainingSeconds := v[1]
remainingSeconds -= sleep
if remainingSeconds <= 0 {
pool.penaltyBox[url] = [2]float64{v[0], 0}
continue
} else {
pool.penaltyBox[url] = [2]float64{v[0], remainingSeconds}
}
if remainingSeconds < nextSleep {
nextSleep = remainingSeconds
}
}
sleep = nextSleep
pool.penaltyBoxMu.Unlock()
}
}()
}
// WithEventMiddleware is a function that will be called with all events received.
type WithEventMiddleware func(RelayEvent)
func (h WithEventMiddleware) ApplyPoolOption(pool *Pool) {
pool.eventMiddleware = h
}
// WithDuplicateMiddleware is a function that will be called with all duplicate ids received.
type WithDuplicateMiddleware func(relay string, id string)
func (h WithDuplicateMiddleware) ApplyPoolOption(pool *Pool) {
pool.duplicateMiddleware = h
}
// WithAuthorKindQueryMiddleware is a function that will be called with every combination of relay+pubkey+kind queried
// in a .SubMany*() call -- when applicable (i.e. when the query contains a pubkey and a kind).
type WithAuthorKindQueryMiddleware func(relay, pubkey string, kind uint16)
func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *Pool) {
pool.queryMiddleware = h
}
var (
_ PoolOption = (WithAuthHandler)(nil)
_ PoolOption = (WithEventMiddleware)(nil)
_ PoolOption = WithPenaltyBox()
_ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{}))
)
const MAX_LOCKS = 50
var namedMutexPool = make([]sync.Mutex, MAX_LOCKS)
//go:noescape
//go:linkname memhash runtime.memhash
func memhash(p unsafe.Pointer, h, s uintptr) uintptr
func namedLock[V ~[]byte | ~string](name V) (unlock func()) {
sptr := unsafe.StringData(string(name))
idx := uint64(
memhash(
unsafe.Pointer(sptr), 0, uintptr(len(name)),
),
) % MAX_LOCKS
namedMutexPool[idx].Lock()
return namedMutexPool[idx].Unlock
}
// EnsureRelay ensures that a relay connection exists and is active.
// If the relay is not connected, it attempts to connect.
func (p *Pool) EnsureRelay(url string) (*Client, error) {
nm := string(normalize.URL(url))
defer namedLock(nm)()
relay, ok := p.Relays.Load(nm)
if ok && relay == nil {
if p.penaltyBox != nil {
p.penaltyBoxMu.Lock()
defer p.penaltyBoxMu.Unlock()
v, _ := p.penaltyBox[nm]
if v[1] > 0 {
return nil, fmt.Errorf("in penalty box, %fs remaining", v[1])
}
}
} else if ok && relay.IsConnected() {
// already connected, unlock and return
return relay, nil
}
// try to connect
// we use this ctx here so when the p dies everything dies
ctx, cancel := context.TimeoutCause(
p.Context,
time.Second*15,
errors.New("connecting to the relay took too long"),
)
defer cancel()
relay = NewRelay(context.Bg(), url, p.relayOptions...)
if err := relay.Connect(ctx); err != nil {
if p.penaltyBox != nil {
// putting relay in penalty box
p.penaltyBoxMu.Lock()
defer p.penaltyBoxMu.Unlock()
v, _ := p.penaltyBox[nm]
p.penaltyBox[nm] = [2]float64{
v[0] + 1, 30.0 + math.Pow(2, v[0]+1),
}
}
return nil, fmt.Errorf("failed to connect: %w", err)
}
p.Relays.Store(nm, relay)
return relay, nil
}
// PublishResult represents the result of publishing an event to a relay.
type PublishResult struct {
Error error
RelayURL string
Relay *Client
}
// todo: this didn't used to be in this package... probably don't want to add it
// either.
//
// PublishMany publishes an event to multiple relays and returns a
// channel of results emitted as they're received.
// func (pool *Pool) PublishMany(
// ctx context.T, urls []string, evt *event.E,
// ) chan PublishResult {
// ch := make(chan PublishResult, len(urls))
// wg := sync.WaitGroup{}
// wg.Add(len(urls))
// go func() {
// for _, url := range urls {
// go func() {
// defer wg.Done()
// relay, err := pool.EnsureRelay(url)
// if err != nil {
// ch <- PublishResult{err, url, nil}
// return
// }
// if err = relay.Publish(ctx, evt); err == nil {
// // success with no auth required
// ch <- PublishResult{nil, url, relay}
// } else if strings.HasPrefix(
// err.Error(), "msg: auth-required:",
// ) && pool.authHandler != nil {
// // try to authenticate if we can
// if authErr := relay.Auth(
// ctx, pool.authHandler(),
// ); authErr == nil {
// if err := relay.Publish(ctx, evt); err == nil {
// // success after auth
// ch <- PublishResult{nil, url, relay}
// } else {
// // failure after auth
// ch <- PublishResult{err, url, relay}
// }
// } else {
// // failure to auth
// ch <- PublishResult{
// fmt.Errorf(
// "failed to auth: %w", authErr,
// ), url, relay,
// }
// }
// } else {
// // direct failure
// ch <- PublishResult{err, url, relay}
// }
// }()
// }
//
// wg.Wait()
// close(ch)
// }()
//
// return ch
// }
// SubscribeMany opens a subscription with the given filter to multiple relays
// the subscriptions ends when the context is canceled or when all relays return a CLOSED.
func (p *Pool) SubscribeMany(
ctx context.T,
urls []string,
filter *filter.F,
opts ...SubscriptionOption,
) chan RelayEvent {
return p.subMany(ctx, urls, filters.New(filter), nil, opts...)
}
// FetchMany opens a subscription, much like SubscribeMany, but it ends as soon as all Relays
// return an EOSE message.
func (p *Pool) FetchMany(
ctx context.T,
urls []string,
filter *filter.F,
opts ...SubscriptionOption,
) chan RelayEvent {
return p.SubManyEose(ctx, urls, filters.New(filter), opts...)
}
// Deprecated: SubMany is deprecated: use SubscribeMany instead.
func (p *Pool) SubMany(
ctx context.T,
urls []string,
filters *filters.T,
opts ...SubscriptionOption,
) chan RelayEvent {
return p.subMany(ctx, urls, filters, nil, opts...)
}
// SubscribeManyNotifyEOSE is like SubscribeMany, but takes a channel that is closed when
// all subscriptions have received an EOSE
func (p *Pool) SubscribeManyNotifyEOSE(
ctx context.T,
urls []string,
filter *filter.F,
eoseChan chan struct{},
opts ...SubscriptionOption,
) chan RelayEvent {
return p.subMany(ctx, urls, filters.New(filter), eoseChan, opts...)
}
type ReplaceableKey struct {
PubKey string
D string
}
// FetchManyReplaceable is like FetchMany, but deduplicates replaceable and addressable events and returns
// only the latest for each "d" tag.
func (p *Pool) FetchManyReplaceable(
ctx context.T,
urls []string,
f *filter.F,
opts ...SubscriptionOption,
) *xsync.MapOf[ReplaceableKey, *event.E] {
ctx, cancel := context.Cause(ctx)
results := xsync.NewMapOf[ReplaceableKey, *event.E]()
wg := sync.WaitGroup{}
wg.Add(len(urls))
// todo: this is a hack for compensating for retarded relays that don't
// filter replaceable events because it streams them back over a channel.
// this is out of spec anyway so should not be handled. replaceable events
// are supposed to delete old versions. the end. this is for the incorrect
// behaviour of fiatjaf's database code, which he obviously thinks is clever
// for using channels, and not sorting results before dispatching them
// before EOSE.
_ = 0
// seenAlreadyLatest := xsync.NewMapOf[ReplaceableKey,
// *timestamp.T]() opts = append(
// opts, WithCheckDuplicateReplaceable(
// func(rk ReplaceableKey, ts Timestamp) bool {
// updated := false
// seenAlreadyLatest.Compute(
// rk, func(latest Timestamp, _ bool) (
// newValue Timestamp, delete bool,
// ) {
// if ts > latest {
// updated = true // we are updating the most recent
// return ts, false
// }
// return latest, false // the one we had was already more recent
// },
// )
// return updated
// },
// ),
// )
for _, url := range urls {
go func(nm string) {
defer wg.Done()
if mh := p.queryMiddleware; mh != nil {
if f.Kinds != nil && f.Authors != nil {
for _, kind := range f.Kinds.K {
for _, author := range f.Authors.ToStringSlice() {
mh(nm, author, kind.K)
}
}
}
}
relay, err := p.EnsureRelay(nm)
if err != nil {
log.D.C(
func() string {
return fmt.Sprintf(
"error connecting to %s with %v: %s", nm, f, err,
)
},
)
return
}
hasAuthed := false
subscribe:
sub, err := relay.Subscribe(ctx, filters.New(f), opts...)
if err != nil {
log.D.C(
func() string {
return fmt.Sprintf(
"error subscribing to %s with %v: %s", relay, f,
err,
)
},
)
return
}
for {
select {
case <-ctx.Done():
return
case <-sub.EndOfStoredEvents:
return
case reason := <-sub.ClosedReason:
if strings.HasPrefix(
reason, "auth-required:",
) && p.authHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again
err = relay.Auth(
ctx, p.authHandler(),
)
if err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again
goto subscribe
}
}
log.D.F("CLOSED from %s: '%s'\n", nm, reason)
return
case evt, more := <-sub.Events:
if !more {
return
}
ie := RelayEvent{E: evt, Relay: relay}
if mh := p.eventMiddleware; mh != nil {
mh(ie)
}
results.Store(
ReplaceableKey{hex.Enc(evt.Pubkey), evt.Tags.GetD()},
evt,
)
}
}
}(string(normalize.URL(url)))
}
// this will happen when all subscriptions get an eose (or when they die)
wg.Wait()
cancel(errors.New("all subscriptions ended"))
return results
}
func (p *Pool) subMany(
ctx context.T,
urls []string,
ff *filters.T,
eoseChan chan struct{},
opts ...SubscriptionOption,
) chan RelayEvent {
ctx, cancel := context.Cause(ctx)
_ = cancel // do this so `go vet` will stop complaining
events := make(chan RelayEvent)
seenAlready := xsync.NewMapOf[string, *timestamp.T]()
ticker := time.NewTicker(seenAlreadyDropTick)
eoseWg := sync.WaitGroup{}
eoseWg.Add(len(urls))
if eoseChan != nil {
go func() {
eoseWg.Wait()
close(eoseChan)
}()
}
pending := xsync.NewCounter()
pending.Add(int64(len(urls)))
for i, url := range urls {
url = string(normalize.URL(url))
urls[i] = url
if idx := slices.Index(urls, url); idx != i {
// skip duplicate relays in the list
eoseWg.Done()
continue
}
eosed := atomic.Bool{}
firstConnection := true
go func(nm string) {
defer func() {
pending.Dec()
if pending.Value() == 0 {
close(events)
cancel(fmt.Errorf("aborted: %w", context.GetCause(ctx)))
}
if eosed.CompareAndSwap(false, true) {
eoseWg.Done()
}
}()
hasAuthed := false
interval := 3 * time.Second
for {
select {
case <-ctx.Done():
return
default:
}
var sub *Subscription
if mh := p.queryMiddleware; mh != nil {
for _, f := range ff.F {
if f.Kinds != nil && f.Authors != nil {
for _, k := range f.Kinds.K {
for _, author := range f.Authors.ToSliceOfBytes() {
mh(nm, hex.Enc(author), k.K)
}
}
}
}
}
relay, err := p.EnsureRelay(nm)
if err != nil {
// if we never connected to this just fail
if firstConnection {
return
}
// otherwise (if we were connected and got disconnected) keep trying to reconnect
log.D.F("%s reconnecting because connection failed", nm)
goto reconnect
}
firstConnection = false
hasAuthed = false
subscribe:
sub, err = relay.Subscribe(
ctx, ff,
// append(
opts...,
// WithCheckDuplicate(
// func(id, relay string) bool {
// _, exists := seenAlready.Load(id)
// if exists && p.duplicateMiddleware != nil {
// p.duplicateMiddleware(relay, id)
// }
// return exists
// },
// ),
// )...,
)
if err != nil {
log.D.F("%s reconnecting because subscription died", nm)
goto reconnect
}
go func() {
<-sub.EndOfStoredEvents
// guard here otherwise a resubscription will trigger a duplicate call to eoseWg.Done()
if eosed.CompareAndSwap(false, true) {
eoseWg.Done()
}
}()
// reset interval when we get a good subscription
interval = 3 * time.Second
for {
select {
case evt, more := <-sub.Events:
if !more {
// this means the connection was closed for weird reasons, like the server shut down
// so we will update the filters here to include only events seem from now on
// and try to reconnect until we succeed
now := timestamp.Now()
for i := range ff.F {
ff.F[i].Since = now
}
log.D.F(
"%s reconnecting because sub.Events is broken",
nm,
)
goto reconnect
}
ie := RelayEvent{E: evt, Relay: relay}
if mh := p.eventMiddleware; mh != nil {
mh(ie)
}
select {
case events <- ie:
case <-ctx.Done():
return
}
case <-ticker.C:
if eosed.Load() {
old := timestamp.New(time.Now().Add(-seenAlreadyDropTick).Unix())
for id, value := range seenAlready.Range {
if value.I64() < old.I64() {
seenAlready.Delete(id)
}
}
}
case reason := <-sub.ClosedReason:
if strings.HasPrefix(
reason, "auth-required:",
) && p.authHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again
err = relay.Auth(
ctx, p.authHandler(),
)
if err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again
goto subscribe
}
} else {
log.D.F("CLOSED from %s: '%s'\n", nm, reason)
}
return
case <-ctx.Done():
return
}
}
reconnect:
// we will go back to the beginning of the loop and try to connect again and again
// until the context is canceled
time.Sleep(interval)
interval = interval * 17 / 10 // the next time we try we will wait longer
}
}(url)
}
return events
}
// Deprecated: SubManyEose is deprecated: use FetchMany instead.
func (p *Pool) SubManyEose(
ctx context.T,
urls []string,
filters *filters.T,
opts ...SubscriptionOption,
) chan RelayEvent {
// seenAlready := xsync.NewMapOf[string, struct{}]()
return p.subManyEoseNonOverwriteCheckDuplicate(
ctx, urls, filters,
// WithCheckDuplicate(
// func(id, relay string) bool {
// _, exists := seenAlready.LoadOrStore(id, struct{}{})
// if exists && p.duplicateMiddleware != nil {
// p.duplicateMiddleware(relay, id)
// }
// return exists
// },
// ),
opts...,
)
}
func (p *Pool) subManyEoseNonOverwriteCheckDuplicate(
ctx context.T,
urls []string,
filters *filters.T,
// wcd WithCheckDuplicate,
opts ...SubscriptionOption,
) chan RelayEvent {
ctx, cancel := context.Cause(ctx)
events := make(chan RelayEvent)
wg := sync.WaitGroup{}
wg.Add(len(urls))
// opts = append(opts, wcd)
go func() {
// this will happen when all subscriptions get an eose (or when they die)
wg.Wait()
cancel(errors.New("all subscriptions ended"))
close(events)
}()
for _, url := range urls {
go func(nm string) {
defer wg.Done()
if mh := p.queryMiddleware; mh != nil {
for _, filter := range filters.F {
if filter.Kinds != nil && filter.Authors != nil {
for _, k := range filter.Kinds.K {
for _, author := range filter.Authors.ToSliceOfBytes() {
mh(nm, hex.Enc(author), k.K)
}
}
}
}
}
relay, err := p.EnsureRelay(nm)
if err != nil {
log.D.C(
func() string {
return fmt.Sprintf(
"error connecting to %s with %v: %s", nm, filters,
err,
)
},
)
return
}
hasAuthed := false
subscribe:
sub, err := relay.Subscribe(ctx, filters, opts...)
if err != nil {
log.D.C(
func() string {
return fmt.Sprintf(
"error subscribing to %s with %v: %s", relay,
filters,
err,
)
},
)
return
}
for {
select {
case <-ctx.Done():
return
case <-sub.EndOfStoredEvents:
return
case reason := <-sub.ClosedReason:
if strings.HasPrefix(
reason, "auth-required:",
) && p.authHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again
err = relay.Auth(
ctx, p.authHandler(),
)
if err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again
goto subscribe
}
}
log.D.F("CLOSED from %s: '%s'\n", nm, reason)
return
case evt, more := <-sub.Events:
if !more {
return
}
ie := RelayEvent{E: evt, Relay: relay}
if mh := p.eventMiddleware; mh != nil {
mh(ie)
}
select {
case events <- ie:
case <-ctx.Done():
return
}
}
}
}(string(normalize.URL(url)))
}
return events
}
// // CountMany aggregates count results from multiple relays using NIP-45 HyperLogLog
// func (pool *Pool) CountMany(
// ctx context.T,
// urls []string,
// filter *filter.F,
// opts []SubscriptionOption,
// ) int {
// hll := hyperloglog.New(0) // offset is irrelevant here
//
// wg := sync.WaitGroup{}
// wg.Add(len(urls))
// for _, url := range urls {
// go func(nm string) {
// defer wg.Done()
// relay, err := pool.EnsureRelay(url)
// if err != nil {
// return
// }
// ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
// if err != nil {
// return
// }
// if len(ce.HyperLogLog) != 256 {
// return
// }
// hll.MergeRegisters(ce.HyperLogLog)
// }(NormalizeURL(url))
// }
//
// wg.Wait()
// return int(hll.Count())
// }
// QuerySingle returns the first event returned by the first relay, cancels everything else.
func (p *Pool) QuerySingle(
ctx context.T,
urls []string,
filter *filter.F,
opts ...SubscriptionOption,
) *RelayEvent {
ctx, cancel := context.Cause(ctx)
for ievt := range p.SubManyEose(
ctx, urls, filters.New(filter), opts...,
) {
cancel(errors.New("got the first event and ended successfully"))
return &ievt
}
cancel(errors.New("SubManyEose() didn't get yield events"))
return nil
}
// BatchedSubManyEose performs batched subscriptions to multiple relays with different filters.
func (p *Pool) BatchedSubManyEose(
ctx context.T,
dfs []DirectedFilter,
opts ...SubscriptionOption,
) chan RelayEvent {
res := make(chan RelayEvent)
wg := sync.WaitGroup{}
wg.Add(len(dfs))
// seenAlready := xsync.NewMapOf[string, struct{}]()
for _, df := range dfs {
go func(df DirectedFilter) {
for ie := range p.subManyEoseNonOverwriteCheckDuplicate(
ctx,
[]string{df.Relay},
filters.New(df.F),
// WithCheckDuplicate(
// func(id, relay string) bool {
// _, exists := seenAlready.LoadOrStore(id, struct{}{})
// if exists && p.duplicateMiddleware != nil {
// p.duplicateMiddleware(relay, id)
// }
// return exists
// },
// ),
opts...,
) {
select {
case res <- ie:
case <-ctx.Done():
wg.Done()
return
}
}
wg.Done()
}(df)
}
go func() {
wg.Wait()
close(res)
}()
return res
}
// Close closes the pool with the given reason.
func (p *Pool) Close(reason string) {
p.cancel(fmt.Errorf("pool closed with reason: '%s'", reason))
}

View File

@@ -1,216 +0,0 @@
package ws
import (
"context"
"sync"
"testing"
"time"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/interfaces/signer"
)
// mockSigner implements signer.I for testing
type mockSigner struct {
pubkey []byte
}
func (m *mockSigner) Pub() []byte { return m.pubkey }
func (m *mockSigner) Sign([]byte) (
[]byte, error,
) {
return []byte("mock-signature"), nil
}
func (m *mockSigner) Generate() error { return nil }
func (m *mockSigner) InitSec([]byte) error { return nil }
func (m *mockSigner) InitPub([]byte) error { return nil }
func (m *mockSigner) Sec() []byte { return []byte("mock-secret") }
func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil }
func (m *mockSigner) Zero() {}
func (m *mockSigner) ECDH([]byte) (
[]byte, error,
) {
return []byte("mock-shared-secret"), nil
}
func TestNewPool(t *testing.T) {
ctx := context.Background()
pool := NewPool(ctx)
if pool == nil {
t.Fatal("NewPool returned nil")
}
if pool.Relays == nil {
t.Error("Pool should have initialized Relays map")
}
if pool.Context == nil {
t.Error("Pool should have a context")
}
}
func TestPoolWithAuthHandler(t *testing.T) {
ctx := context.Background()
authHandler := WithAuthHandler(
func() signer.I {
return &mockSigner{pubkey: []byte("test-pubkey")}
},
)
pool := NewPool(ctx, authHandler)
if pool.authHandler == nil {
t.Error("Pool should have auth handler set")
}
// Test that auth handler returns the expected signer
signer := pool.authHandler()
if string(signer.Pub()) != "test-pubkey" {
t.Errorf(
"Expected pubkey 'test-pubkey', got '%s'", string(signer.Pub()),
)
}
}
func TestPoolWithEventMiddleware(t *testing.T) {
ctx := context.Background()
var middlewareCalled bool
middleware := WithEventMiddleware(
func(ie RelayEvent) {
middlewareCalled = true
},
)
pool := NewPool(ctx, middleware)
// Test that middleware is called
testEvent := &event.E{
Kind: kind.TextNote,
Content: []byte("test"),
CreatedAt: timestamp.Now(),
}
ie := RelayEvent{E: testEvent, Relay: nil}
pool.eventMiddleware(ie)
if !middlewareCalled {
t.Error("Expected middleware to be called")
}
}
func TestRelayEventString(t *testing.T) {
testEvent := &event.E{
Kind: kind.TextNote,
Content: []byte("test content"),
CreatedAt: timestamp.Now(),
}
client := &Client{URL: "wss://test.relay"}
ie := RelayEvent{E: testEvent, Relay: client}
str := ie.String()
if !contains(str, "wss://test.relay") {
t.Errorf("Expected string to contain relay URL, got: %s", str)
}
if !contains(str, "test content") {
t.Errorf("Expected string to contain event content, got: %s", str)
}
}
func TestNamedLock(t *testing.T) {
// Test that named locks work correctly
var wg sync.WaitGroup
var counter int
var mu sync.Mutex
lockName := "test-lock"
// Start multiple goroutines that try to increment counter
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
unlock := namedLock(lockName)
defer unlock()
// Critical section
mu.Lock()
temp := counter
time.Sleep(1 * time.Millisecond) // Simulate work
counter = temp + 1
mu.Unlock()
}()
}
wg.Wait()
if counter != 10 {
t.Errorf("Expected counter to be 10, got %d", counter)
}
}
func TestPoolEnsureRelayInvalidURL(t *testing.T) {
ctx := context.Background()
pool := NewPool(ctx)
// Test with invalid URL
_, err := pool.EnsureRelay("invalid-url")
if err == nil {
t.Error("Expected error for invalid URL")
}
}
func TestPoolQuerySingle(t *testing.T) {
ctx := context.Background()
pool := NewPool(ctx)
// Test with empty URLs slice
result := pool.QuerySingle(ctx, []string{}, &filter.F{})
if result != nil {
t.Error("Expected nil result for empty URLs")
}
}
// Helper functions
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
containsSubstring(s, substr)))
}
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func uintPtr(u uint) *uint {
return &u
}
// // Test pool context cancellation
// func TestPoolContextCancellation(t *testing.T) {
// ctx, cancel := context.WithCancel(context.Background())
// pool := NewPool(ctx)
//
// // Cancel the context
// cancel()
//
// // Check that pool context is cancelled
// select {
// case <-pool.Context.Done():
// // Expected
// case <-time.After(100 * time.Millisecond):
// t.Error("Expected pool context to be cancelled")
// }
// }

View File

@@ -4,31 +4,32 @@ import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"orly.dev/pkg/encoders/envelopes/closeenvelope"
"orly.dev/pkg/encoders/envelopes/reqenvelope"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/subscription"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/encoders/timestamp"
"sync"
"sync/atomic"
)
type ReplaceableKey struct {
PubKey string
D string
}
// Subscription represents a subscription to a relay.
type Subscription struct {
counter int64
id string
id *subscription.Id
Relay *Client
Client *Client
Filters *filters.T
// // for this to be treated as a COUNT and not a REQ this must be set
// countResult chan CountEnvelope
// the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends
Events event.C
Events chan *event.E
mu sync.Mutex
// the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription
@@ -40,13 +41,13 @@ type Subscription struct {
// Context will be .Done() when the subscription ends
Context context.Context
// // if it is not nil, checkDuplicate will be called for every event received
// // if it returns true that event will not be processed further.
// checkDuplicate func(id string, relay string) bool
//
// // if it is not nil, checkDuplicateReplaceable will be called for every event received
// // if it returns true that event will not be processed further.
// checkDuplicateReplaceable func(rk ReplaceableKey, ts Timestamp) bool
// if it is not nil, checkDuplicate will be called for every event received
// if it returns true that event will not be processed further.
checkDuplicate func(id string, relay string) bool
// if it is not nil, checkDuplicateReplaceable will be called for every event received
// if it returns true that event will not be processed further.
checkDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool
match func(*event.E) bool // this will be either Filters.Match or Filters.MatchIgnoringTimestampConstraints
live atomic.Bool
@@ -69,20 +70,20 @@ type WithLabel string
func (_ WithLabel) IsSubscriptionOption() {}
// // WithCheckDuplicate sets checkDuplicate on the subscription
// type WithCheckDuplicate func(id, relay string) bool
//
// func (_ WithCheckDuplicate) IsSubscriptionOption() {}
//
// // WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription
// type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool
//
// func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {}
// WithCheckDuplicate sets checkDuplicate on the subscription
type WithCheckDuplicate func(id, relay string) bool
func (_ WithCheckDuplicate) IsSubscriptionOption() {}
// WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription
type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool
func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {}
var (
_ SubscriptionOption = (WithLabel)("")
// _ SubscriptionOption = (WithCheckDuplicate)(nil)
// _ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil)
_ SubscriptionOption = (WithCheckDuplicate)(nil)
_ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil)
)
func (sub *Subscription) start() {
@@ -98,7 +99,7 @@ func (sub *Subscription) start() {
}
// GetID returns the subscription ID.
func (sub *Subscription) GetID() string { return sub.id }
func (sub *Subscription) GetID() string { return sub.id.String() }
func (sub *Subscription) dispatchEvent(evt *event.E) {
added := false
@@ -117,6 +118,7 @@ func (sub *Subscription) dispatchEvent(evt *event.E) {
case <-sub.Context.Done():
}
}
if added {
sub.storedwg.Done()
}
@@ -159,19 +161,15 @@ func (sub *Subscription) unsub(err error) {
}
// remove subscription from our map
sub.Relay.Subscriptions.Delete(sub.counter)
sub.Client.Subscriptions.Delete(sub.id.String())
}
// Close just sends a CLOSE message. You probably want Unsub() instead.
func (sub *Subscription) Close() {
if sub.Relay.IsConnected() {
id, err := subscription.NewId(sub.id)
if err != nil {
return
}
closeMsg := closeenvelope.NewFrom(id)
if sub.Client.IsConnected() {
closeMsg := closeenvelope.NewFrom(sub.id)
closeb := closeMsg.Marshal(nil)
<-sub.Relay.Write(closeb)
<-sub.Client.Write(closeb)
}
}
@@ -184,21 +182,14 @@ func (sub *Subscription) Sub(_ context.Context, ff *filters.T) {
// Fire sends the "REQ" command to the relay.
func (sub *Subscription) Fire() (err error) {
// if sub.countResult == nil {
req := reqenvelope.NewWithIdString(sub.id, sub.Filters)
if req == nil {
return fmt.Errorf("invalid ID or filters")
}
reqb := req.Marshal(nil)
// } else
// if len(sub.Filters) == 1 {
// reqb, _ = CountEnvelope{sub.id, sub.Filters[0], nil, nil}.MarshalJSON()
// } else {
// return fmt.Errorf("unexpected sub configuration")
var reqb []byte
reqb = reqenvelope.NewFrom(sub.id, sub.Filters).Marshal(nil)
sub.live.Store(true)
if err = <-sub.Relay.Write(reqb); chk.E(err) {
if err = <-sub.Client.Write(reqb); err != nil {
err = fmt.Errorf("failed to write: %w", err)
sub.cancel(err)
return
}
return
}

View File

@@ -1,117 +1,120 @@
package ws
import (
"context"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/utils/values"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const RELAY = "wss://nos.lol"
// // test if we can fetch a couple of random events
// func TestSubscribeBasic(t *testing.F) {
// rl := mustRelayConnect(RELAY)
// defer rl.Close()
// var lim uint = 2
// sub, err := rl.Subscribe(context.Bg(),
// filters.New(&filter.F{Kinds: kinds.New(kind.TextNote), Limit: &lim}))
// if err != nil {
// t.Fatalf("subscription failed: %v", err)
// return
// }
// timeout := time.After(5 * time.Second)
// n := 0
// for {
// select {
// case event := <-sub.Events:
// if event == nil {
// t.Fatalf("event is nil: %v", event)
// }
// n++
// case <-sub.EndOfStoredEvents:
// goto end
// case <-rl.Context().Done():
// t.Errorf("connection closed: %v", rl.Context().Err())
// goto end
// case <-timeout:
// t.Errorf("timeout")
// goto end
// }
// }
// end:
// if n != 2 {
// t.Fatalf("expected 2 events, got %d", n)
// }
// }
// test if we can fetch a couple of random events
func TestSubscribeBasic(t *testing.T) {
rl := mustRelayConnect(t, RELAY)
defer rl.Close()
// // test if we can do multiple nested subscriptions
// func TestNestedSubscriptions(t *testing.T) {
// rl := mustRelayConnect(RELAY)
// defer rl.Close()
//
// n := atomic.Uint32{}
// _ = n
// // fetch 2 replies to a note
// var lim3 uint = 3
// sub, err := rl.Subscribe(
// context.Bg(),
// filters.New(
// &filter.F{
// Kinds: kinds.New(kind.TextNote),
// Tags: tags.New(
// tag.New(
// "e",
// "0e34a74f8547e3b95d52a2543719b109fd0312aba144e2ef95cba043f42fe8c5",
// ),
// ),
// Limit: &lim3,
// },
// ),
// )
// if err != nil {
// t.Fatalf("subscription 1 failed: %v", err)
// return
// }
//
// for {
// select {
// case event := <-sub.Events:
// // now fetch the author of this
// var lim uint = 1
// sub, err := rl.Subscribe(
// context.Bg(),
// filters.New(
// &filter.F{
// Kinds: kinds.New(kind.ProfileMetadata),
// Authors: tag.New(event.Pubkey), Limit: &lim,
// },
// ),
// )
// if err != nil {
// t.Fatalf("subscription 2 failed: %v", err)
// return
// }
//
// for {
// select {
// case <-sub.Events:
// // do another subscription here in "sync" mode, just so
// // we're sure things are not blocking
// rl.QuerySync(context.Bg(), &filter.F{Limit: &lim})
//
// n.Add(1)
// if n.Load() == 3 {
// // if we get here it means the test passed
// return
// }
// case <-sub.Context.Done():
// goto end
// case <-sub.EndOfStoredEvents:
// sub.Unsub()
// }
// }
// end:
// fmt.Println("")
// case <-sub.EndOfStoredEvents:
// sub.Unsub()
// return
// case <-sub.Context.Done():
// t.Fatalf("connection closed: %v", rl.Context().Err())
// return
// }
// }
// }
sub, err := rl.Subscribe(
context.Background(), filters.New(
&filter.F{
Kinds: &kinds.T{K: []*kind.T{kind.TextNote}},
Limit: values.ToUintPointer(2),
},
),
)
assert.NoError(t, err)
timeout := time.After(5 * time.Second)
n := 0
for {
select {
case event := <-sub.Events:
assert.NotNil(t, event)
n++
case <-sub.EndOfStoredEvents:
assert.Equal(t, 2, n)
sub.Unsub()
return
case <-rl.Context().Done():
t.Fatalf("connection closed: %v", rl.Context().Err())
case <-timeout:
t.Fatalf("timeout")
}
}
}
// test if we can do multiple nested subscriptions
func TestNestedSubscriptions(t *testing.T) {
rl := mustRelayConnect(t, RELAY)
defer rl.Close()
n := atomic.Uint32{}
// fetch 2 replies to a note
sub, err := rl.Subscribe(
context.Background(), filters.New(
&filter.F{
Kinds: kinds.New(kind.TextNote),
Tags: tags.New(
tag.New(
"e",
"0e34a74f8547e3b95d52a2543719b109fd0312aba144e2ef95cba043f42fe8c5",
),
),
Limit: values.ToUintPointer(3),
},
),
)
assert.NoError(t, err)
for {
select {
case ev := <-sub.Events:
// now fetch author of this
sub, err := rl.Subscribe(
context.Background(), filters.New(
&filter.F{
Kinds: kinds.New(kind.ProfileMetadata),
Authors: tag.New(ev.PubKeyString()),
Limit: values.ToUintPointer(1),
},
),
)
assert.NoError(t, err)
for {
select {
case <-sub.Events:
// do another subscription here in "sync" mode, just so
// we're sure things aren't blocking
rl.QuerySync(
context.Background(),
&filter.F{Limit: values.ToUintPointer(1)},
)
n.Add(1)
if n.Load() == 3 {
// if we get here, it means the test passed
return
}
case <-sub.Context.Done():
case <-sub.EndOfStoredEvents:
sub.Unsub()
}
}
case <-sub.EndOfStoredEvents:
sub.Unsub()
return
case <-sub.Context.Done():
t.Fatalf("connection closed: %v", rl.Context().Err())
return
}
}
}