Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
7ec8698b62
|
|||
|
2514f875e6
|
|||
|
a6350c8e80
|
|||
|
6c3d22cb38
|
|||
|
8adb129fbe
|
|||
|
fd698af1ca
|
|||
|
ac4fd506e5
|
|||
|
8898b20d4b
|
|||
|
b351d0fb78
|
|||
|
9c8ff2976d
|
|||
|
a7dd958585
|
|||
|
8eb5b839b0
|
|||
|
e57169eeae
|
|||
|
109326dfa3
|
|||
| 52911354a7 | |||
|
b74f4757e7
|
|||
| 2d0ebfe032 | |||
| fff61ceca1 | |||
| b7b7dc7353 | |||
| 996fb3aeb7 | |||
| b9a713d81d | |||
|
1e6ce84e26
|
|||
|
0361f3843a
|
|||
|
4317e8ba4a
|
|||
|
9094f36d6e
|
|||
|
9314467f55
|
|||
|
19e6520587
|
|||
|
9e59a6c315
|
|||
|
9449435c65
|
|||
|
df8e66d9a7
|
|||
|
96eab2270d
|
60
.github/workflows/test-and-release.yml
vendored
Normal file
60
.github/workflows/test-and-release.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Test and Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*' # Triggers on tags like v1.2.3
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.22
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
- name: Run tests
|
||||
run: go test -v ./...
|
||||
|
||||
release:
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.22
|
||||
- name: Build binaries
|
||||
run: |
|
||||
mkdir -p dist
|
||||
GOOS=linux GOARCH=amd64 go build -o dist/app-linux-amd64
|
||||
GOOS=darwin GOARCH=amd64 go build -o dist/app-darwin-amd64
|
||||
GOOS=windows GOARCH=amd64 go build -o dist/app-windows-amd64.exe
|
||||
- name: Create Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ github.ref_name }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Upload Release Assets
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: dist/*
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -105,3 +105,4 @@ pkg/database/testrealy
|
||||
/.idea/orly.iml
|
||||
/.idea/go.imports.xml
|
||||
/.idea/inspectionProfiles/Project_Default.xml
|
||||
/.idea/.name
|
||||
|
||||
173
cmd/benchmark/BENCHMARK_RESULTS.md
Normal file
173
cmd/benchmark/BENCHMARK_RESULTS.md
Normal file
@@ -0,0 +1,173 @@
|
||||
# Orly Relay Benchmark Results
|
||||
|
||||
## Test Environment
|
||||
|
||||
- **Date**: August 5, 2025
|
||||
- **Relay**: Orly v0.4.14
|
||||
- **Port**: 3334 (WebSocket)
|
||||
- **System**: Linux 5.15.0-151-generic
|
||||
- **Storage**: BadgerDB v4
|
||||
|
||||
## Benchmark Test Results
|
||||
|
||||
### Test 1: Basic Performance (1,000 events, 1KB each)
|
||||
|
||||
**Parameters:**
|
||||
- Events: 1,000
|
||||
- Event size: 1,024 bytes
|
||||
- Concurrent publishers: 5
|
||||
- Queries: 50
|
||||
|
||||
**Results:**
|
||||
```
|
||||
Publish Performance:
|
||||
Events Published: 1,000
|
||||
Total Data: 4.01 MB
|
||||
Duration: 1.769s
|
||||
Rate: 565.42 events/second
|
||||
Bandwidth: 2.26 MB/second
|
||||
|
||||
Query Performance:
|
||||
Queries Executed: 50
|
||||
Events Returned: 2,000
|
||||
Duration: 3.058s
|
||||
Rate: 16.35 queries/second
|
||||
Avg Events/Query: 40.00
|
||||
```
|
||||
|
||||
### Test 2: Medium Load (10,000 events, 2KB each)
|
||||
|
||||
**Parameters:**
|
||||
- Events: 10,000
|
||||
- Event size: 2,048 bytes
|
||||
- Concurrent publishers: 10
|
||||
- Queries: 100
|
||||
|
||||
**Results:**
|
||||
```
|
||||
Publish Performance:
|
||||
Events Published: 10,000
|
||||
Total Data: 76.81 MB
|
||||
Duration: 598.301ms
|
||||
Rate: 16,714.00 events/second
|
||||
Bandwidth: 128.38 MB/second
|
||||
|
||||
Query Performance:
|
||||
Queries Executed: 100
|
||||
Events Returned: 4,000
|
||||
Duration: 8.923s
|
||||
Rate: 11.21 queries/second
|
||||
Avg Events/Query: 40.00
|
||||
```
|
||||
|
||||
### Test 3: High Concurrency (50,000 events, 512 bytes each)
|
||||
|
||||
**Parameters:**
|
||||
- Events: 50,000
|
||||
- Event size: 512 bytes
|
||||
- Concurrent publishers: 50
|
||||
- Queries: 200
|
||||
|
||||
**Results:**
|
||||
```
|
||||
Publish Performance:
|
||||
Events Published: 50,000
|
||||
Total Data: 108.63 MB
|
||||
Duration: 2.368s
|
||||
Rate: 21,118.66 events/second
|
||||
Bandwidth: 45.88 MB/second
|
||||
|
||||
Query Performance:
|
||||
Queries Executed: 200
|
||||
Events Returned: 8,000
|
||||
Duration: 36.146s
|
||||
Rate: 5.53 queries/second
|
||||
Avg Events/Query: 40.00
|
||||
```
|
||||
|
||||
### Test 4: Large Events (5,000 events, 10KB each)
|
||||
|
||||
**Parameters:**
|
||||
- Events: 5,000
|
||||
- Event size: 10,240 bytes
|
||||
- Concurrent publishers: 10
|
||||
- Queries: 50
|
||||
|
||||
**Results:**
|
||||
```
|
||||
Publish Performance:
|
||||
Events Published: 5,000
|
||||
Total Data: 185.26 MB
|
||||
Duration: 934.328ms
|
||||
Rate: 5,351.44 events/second
|
||||
Bandwidth: 198.28 MB/second
|
||||
|
||||
Query Performance:
|
||||
Queries Executed: 50
|
||||
Events Returned: 2,000
|
||||
Duration: 9.982s
|
||||
Rate: 5.01 queries/second
|
||||
Avg Events/Query: 40.00
|
||||
```
|
||||
|
||||
### Test 5: Query-Only Performance (500 queries)
|
||||
|
||||
**Parameters:**
|
||||
- Skip publishing phase
|
||||
- Queries: 500
|
||||
- Query limit: 100
|
||||
|
||||
**Results:**
|
||||
```
|
||||
Query Performance:
|
||||
Queries Executed: 500
|
||||
Events Returned: 20,000
|
||||
Duration: 1m14.384s
|
||||
Rate: 6.72 queries/second
|
||||
Avg Events/Query: 40.00
|
||||
```
|
||||
|
||||
## Performance Summary
|
||||
|
||||
### Publishing Performance
|
||||
|
||||
| Metric | Best Result | Test Configuration |
|
||||
|--------|-------------|-------------------|
|
||||
| **Peak Event Rate** | 21,118.66 events/sec | 50 concurrent publishers, 512-byte events |
|
||||
| **Peak Bandwidth** | 198.28 MB/sec | 10 concurrent publishers, 10KB events |
|
||||
| **Optimal Balance** | 16,714.00 events/sec @ 128.38 MB/sec | 10 concurrent publishers, 2KB events |
|
||||
|
||||
### Query Performance
|
||||
|
||||
| Query Type | Avg Rate | Notes |
|
||||
|------------|----------|--------|
|
||||
| **Light Load** | 16.35 queries/sec | 50 queries after 1K events |
|
||||
| **Medium Load** | 11.21 queries/sec | 100 queries after 10K events |
|
||||
| **Heavy Load** | 5.53 queries/sec | 200 queries after 50K events |
|
||||
| **Sustained** | 6.72 queries/sec | 500 continuous queries |
|
||||
|
||||
## Key Findings
|
||||
|
||||
1. **Optimal Concurrency**: The relay performs best with 10-50 concurrent publishers, achieving rates of 16,000-21,000 events/second.
|
||||
|
||||
2. **Event Size Impact**:
|
||||
- Smaller events (512B-2KB) achieve higher event rates
|
||||
- Larger events (10KB) achieve higher bandwidth utilization but lower event rates
|
||||
|
||||
3. **Query Performance**: Query performance varies with database size:
|
||||
- Fresh database: ~16 queries/second
|
||||
- After 50K events: ~6 queries/second
|
||||
|
||||
4. **Scalability**: The relay maintains consistent performance up to 50 concurrent connections and can sustain 21,000+ events/second under optimal conditions.
|
||||
|
||||
## Query Filter Distribution
|
||||
|
||||
The benchmark tested 5 different query patterns in rotation:
|
||||
1. Query by kind (20%)
|
||||
2. Query by time range (20%)
|
||||
3. Query by tag (20%)
|
||||
4. Query by author (20%)
|
||||
5. Complex queries with multiple conditions (20%)
|
||||
|
||||
All query types showed similar performance characteristics, indicating well-balanced indexing.
|
||||
|
||||
112
cmd/benchmark/README.md
Normal file
112
cmd/benchmark/README.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# Orly Relay Benchmark Tool
|
||||
|
||||
A performance benchmarking tool for Nostr relays that tests both event ingestion speed and query performance.
|
||||
|
||||
## Quick Start (Simple Version)
|
||||
|
||||
The repository includes a simple standalone benchmark tool that doesn't require the full Orly dependencies:
|
||||
|
||||
```bash
|
||||
# Build the simple benchmark
|
||||
go build -o benchmark-simple ./benchmark_simple.go
|
||||
|
||||
# Run with default settings
|
||||
./benchmark-simple
|
||||
|
||||
# Or use the convenience script
|
||||
chmod +x run_benchmark.sh
|
||||
./run_benchmark.sh --relay ws://localhost:7447 --events 10000
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Event Publishing Benchmark**: Tests how fast a relay can accept and store events
|
||||
- **Query Performance Benchmark**: Tests various filter types and query speeds
|
||||
- **Concurrent Publishing**: Supports multiple concurrent publishers to stress test the relay
|
||||
- **Detailed Metrics**: Reports events/second, bandwidth usage, and query performance
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Build the tool
|
||||
go build -o benchmark ./cmd/benchmark
|
||||
|
||||
# Run a full benchmark (publish and query)
|
||||
./benchmark -relay ws://localhost:7447 -events 10000 -queries 100
|
||||
|
||||
# Benchmark only publishing
|
||||
./benchmark -relay ws://localhost:7447 -events 50000 -concurrency 20 -skip-query
|
||||
|
||||
# Benchmark only querying
|
||||
./benchmark -relay ws://localhost:7447 -queries 500 -skip-publish
|
||||
|
||||
# Use custom event sizes
|
||||
./benchmark -relay ws://localhost:7447 -events 10000 -size 2048
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
- `-relay`: Relay URL to benchmark (default: ws://localhost:7447)
|
||||
- `-events`: Number of events to publish (default: 10000)
|
||||
- `-size`: Average size of event content in bytes (default: 1024)
|
||||
- `-concurrency`: Number of concurrent publishers (default: 10)
|
||||
- `-queries`: Number of queries to execute (default: 100)
|
||||
- `-query-limit`: Limit for each query (default: 100)
|
||||
- `-skip-publish`: Skip the publishing phase
|
||||
- `-skip-query`: Skip the query phase
|
||||
- `-v`: Enable verbose output
|
||||
|
||||
## Query Types Tested
|
||||
|
||||
The benchmark tests various query patterns:
|
||||
1. Query by kind
|
||||
2. Query by time range (last hour)
|
||||
3. Query by tag (p tags)
|
||||
4. Query by author
|
||||
5. Complex queries with multiple conditions
|
||||
|
||||
## Output
|
||||
|
||||
The tool provides detailed metrics including:
|
||||
|
||||
**Publish Performance:**
|
||||
- Total events published
|
||||
- Total data transferred
|
||||
- Publishing rate (events/second)
|
||||
- Bandwidth usage (MB/second)
|
||||
|
||||
**Query Performance:**
|
||||
- Total queries executed
|
||||
- Total events returned
|
||||
- Query rate (queries/second)
|
||||
- Average events per query
|
||||
|
||||
## Example Output
|
||||
|
||||
```
|
||||
Publishing 10000 events to ws://localhost:7447...
|
||||
Published 1000 events...
|
||||
Published 2000 events...
|
||||
...
|
||||
|
||||
Querying events from ws://localhost:7447...
|
||||
Executed 20 queries...
|
||||
Executed 40 queries...
|
||||
...
|
||||
|
||||
=== Benchmark Results ===
|
||||
|
||||
Publish Performance:
|
||||
Events Published: 10000
|
||||
Total Data: 12.34 MB
|
||||
Duration: 5.2s
|
||||
Rate: 1923.08 events/second
|
||||
Bandwidth: 2.37 MB/second
|
||||
|
||||
Query Performance:
|
||||
Queries Executed: 100
|
||||
Events Returned: 4523
|
||||
Duration: 2.1s
|
||||
Rate: 47.62 queries/second
|
||||
Avg Events/Query: 45.23
|
||||
```
|
||||
304
cmd/benchmark/benchmark_simple.go
Normal file
304
cmd/benchmark/benchmark_simple.go
Normal file
@@ -0,0 +1,304 @@
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
)
|
||||
|
||||
// Simple event structure for benchmarking
|
||||
type Event struct {
|
||||
ID string `json:"id"`
|
||||
Pubkey string `json:"pubkey"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Kind int `json:"kind"`
|
||||
Tags [][]string `json:"tags"`
|
||||
Content string `json:"content"`
|
||||
Sig string `json:"sig"`
|
||||
}
|
||||
|
||||
// Generate a test event
|
||||
func generateTestEvent(size int) *Event {
|
||||
content := make([]byte, size)
|
||||
rand.Read(content)
|
||||
|
||||
// Generate random pubkey and sig
|
||||
pubkey := make([]byte, 32)
|
||||
sig := make([]byte, 64)
|
||||
rand.Read(pubkey)
|
||||
rand.Read(sig)
|
||||
|
||||
ev := &Event{
|
||||
Pubkey: hex.EncodeToString(pubkey),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
Kind: 1,
|
||||
Tags: [][]string{},
|
||||
Content: string(content),
|
||||
Sig: hex.EncodeToString(sig),
|
||||
}
|
||||
|
||||
// Generate ID (simplified)
|
||||
serialized, _ := json.Marshal([]interface{}{
|
||||
0,
|
||||
ev.Pubkey,
|
||||
ev.CreatedAt,
|
||||
ev.Kind,
|
||||
ev.Tags,
|
||||
ev.Content,
|
||||
})
|
||||
hash := sha256.Sum256(serialized)
|
||||
ev.ID = hex.EncodeToString(hash[:])
|
||||
|
||||
return ev
|
||||
}
|
||||
|
||||
func publishEvents(relayURL string, count int, size int, concurrency int) (int64, int64, time.Duration, error) {
|
||||
u, err := url.Parse(relayURL)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
|
||||
var publishedEvents atomic.Int64
|
||||
var publishedBytes atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
eventsPerWorker := count / concurrency
|
||||
extraEvents := count % concurrency
|
||||
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
eventsToPublish := eventsPerWorker
|
||||
if i < extraEvents {
|
||||
eventsToPublish++
|
||||
}
|
||||
|
||||
go func(workerID int, eventCount int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Connect to relay
|
||||
ctx := context.Background()
|
||||
conn, _, _, err := ws.Dial(ctx, u.String())
|
||||
if err != nil {
|
||||
log.Printf("Worker %d: connection error: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Publish events
|
||||
for j := 0; j < eventCount; j++ {
|
||||
ev := generateTestEvent(size)
|
||||
|
||||
// Create EVENT message
|
||||
msg, _ := json.Marshal([]interface{}{"EVENT", ev})
|
||||
|
||||
err := wsutil.WriteClientMessage(conn, ws.OpText, msg)
|
||||
if err != nil {
|
||||
log.Printf("Worker %d: write error: %v", workerID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
publishedEvents.Add(1)
|
||||
publishedBytes.Add(int64(len(msg)))
|
||||
|
||||
// Read response (OK or error)
|
||||
_, _, err = wsutil.ReadServerData(conn)
|
||||
if err != nil {
|
||||
log.Printf("Worker %d: read error: %v", workerID, err)
|
||||
}
|
||||
}
|
||||
}(i, eventsToPublish)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
duration := time.Since(start)
|
||||
|
||||
return publishedEvents.Load(), publishedBytes.Load(), duration, nil
|
||||
}
|
||||
|
||||
func queryEvents(relayURL string, queries int, limit int) (int64, int64, time.Duration, error) {
|
||||
u, err := url.Parse(relayURL)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
conn, _, _, err := ws.Dial(ctx, u.String())
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var totalQueries int64
|
||||
var totalEvents int64
|
||||
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < queries; i++ {
|
||||
// Generate various filter types
|
||||
var filter map[string]interface{}
|
||||
|
||||
switch i % 5 {
|
||||
case 0:
|
||||
// Query by kind
|
||||
filter = map[string]interface{}{
|
||||
"kinds": []int{1},
|
||||
"limit": limit,
|
||||
}
|
||||
case 1:
|
||||
// Query by time range
|
||||
now := time.Now().Unix()
|
||||
filter = map[string]interface{}{
|
||||
"since": now - 3600,
|
||||
"until": now,
|
||||
"limit": limit,
|
||||
}
|
||||
case 2:
|
||||
// Query by tag
|
||||
filter = map[string]interface{}{
|
||||
"#p": []string{hex.EncodeToString(randBytes(32))},
|
||||
"limit": limit,
|
||||
}
|
||||
case 3:
|
||||
// Query by author
|
||||
filter = map[string]interface{}{
|
||||
"authors": []string{hex.EncodeToString(randBytes(32))},
|
||||
"limit": limit,
|
||||
}
|
||||
case 4:
|
||||
// Complex query
|
||||
now := time.Now().Unix()
|
||||
filter = map[string]interface{}{
|
||||
"kinds": []int{1, 6},
|
||||
"authors": []string{hex.EncodeToString(randBytes(32))},
|
||||
"since": now - 7200,
|
||||
"limit": limit,
|
||||
}
|
||||
}
|
||||
|
||||
// Send REQ
|
||||
subID := fmt.Sprintf("bench-%d", i)
|
||||
msg, _ := json.Marshal([]interface{}{"REQ", subID, filter})
|
||||
|
||||
err := wsutil.WriteClientMessage(conn, ws.OpText, msg)
|
||||
if err != nil {
|
||||
log.Printf("Query %d: write error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Read events until EOSE
|
||||
eventCount := 0
|
||||
for {
|
||||
data, err := wsutil.ReadServerText(conn)
|
||||
if err != nil {
|
||||
log.Printf("Query %d: read error: %v", i, err)
|
||||
break
|
||||
}
|
||||
|
||||
var msg []interface{}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(msg) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, ok := msg[0].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case "EVENT":
|
||||
eventCount++
|
||||
case "EOSE":
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
|
||||
// Send CLOSE
|
||||
closeMsg, _ := json.Marshal([]interface{}{"CLOSE", subID})
|
||||
wsutil.WriteClientMessage(conn, ws.OpText, closeMsg)
|
||||
|
||||
totalQueries++
|
||||
totalEvents += int64(eventCount)
|
||||
|
||||
if totalQueries%20 == 0 {
|
||||
fmt.Printf(" Executed %d queries...\n", totalQueries)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
return totalQueries, totalEvents, duration, nil
|
||||
}
|
||||
|
||||
func randBytes(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
rand.Read(b)
|
||||
return b
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
relayURL = flag.String("relay", "ws://localhost:7447", "Relay URL to benchmark")
|
||||
eventCount = flag.Int("events", 10000, "Number of events to publish")
|
||||
eventSize = flag.Int("size", 1024, "Average size of event content in bytes")
|
||||
concurrency = flag.Int("concurrency", 10, "Number of concurrent publishers")
|
||||
queryCount = flag.Int("queries", 100, "Number of queries to execute")
|
||||
queryLimit = flag.Int("query-limit", 100, "Limit for each query")
|
||||
skipPublish = flag.Bool("skip-publish", false, "Skip publishing phase")
|
||||
skipQuery = flag.Bool("skip-query", false, "Skip query phase")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
fmt.Printf("=== Nostr Relay Benchmark ===\n\n")
|
||||
|
||||
// Phase 1: Publish events
|
||||
if !*skipPublish {
|
||||
fmt.Printf("Publishing %d events to %s...\n", *eventCount, *relayURL)
|
||||
published, bytes, duration, err := publishEvents(*relayURL, *eventCount, *eventSize, *concurrency)
|
||||
if err != nil {
|
||||
log.Fatalf("Publishing failed: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\nPublish Performance:\n")
|
||||
fmt.Printf(" Events Published: %d\n", published)
|
||||
fmt.Printf(" Total Data: %.2f MB\n", float64(bytes)/1024/1024)
|
||||
fmt.Printf(" Duration: %s\n", duration)
|
||||
fmt.Printf(" Rate: %.2f events/second\n", float64(published)/duration.Seconds())
|
||||
fmt.Printf(" Bandwidth: %.2f MB/second\n", float64(bytes)/duration.Seconds()/1024/1024)
|
||||
}
|
||||
|
||||
// Phase 2: Query events
|
||||
if !*skipQuery {
|
||||
fmt.Printf("\nQuerying events from %s...\n", *relayURL)
|
||||
queries, events, duration, err := queryEvents(*relayURL, *queryCount, *queryLimit)
|
||||
if err != nil {
|
||||
log.Fatalf("Querying failed: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\nQuery Performance:\n")
|
||||
fmt.Printf(" Queries Executed: %d\n", queries)
|
||||
fmt.Printf(" Events Returned: %d\n", events)
|
||||
fmt.Printf(" Duration: %s\n", duration)
|
||||
fmt.Printf(" Rate: %.2f queries/second\n", float64(queries)/duration.Seconds())
|
||||
fmt.Printf(" Avg Events/Query: %.2f\n", float64(events)/float64(queries))
|
||||
}
|
||||
}
|
||||
320
cmd/benchmark/main.go
Normal file
320
cmd/benchmark/main.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"lukechampine.com/frand"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/filter"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/kinds"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/encoders/tags"
|
||||
"orly.dev/pkg/encoders/text"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/protocol/ws"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/utils/lol"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BenchmarkResults struct {
|
||||
EventsPublished int64
|
||||
EventsPublishedBytes int64
|
||||
PublishDuration time.Duration
|
||||
PublishRate float64
|
||||
PublishBandwidth float64
|
||||
|
||||
QueriesExecuted int64
|
||||
QueryDuration time.Duration
|
||||
QueryRate float64
|
||||
EventsReturned int64
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
relayURL = flag.String("relay", "ws://localhost:7447", "Relay URL to benchmark")
|
||||
eventCount = flag.Int("events", 10000, "Number of events to publish")
|
||||
eventSize = flag.Int("size", 1024, "Average size of event content in bytes")
|
||||
concurrency = flag.Int("concurrency", 10, "Number of concurrent publishers")
|
||||
queryCount = flag.Int("queries", 100, "Number of queries to execute")
|
||||
queryLimit = flag.Int("query-limit", 100, "Limit for each query")
|
||||
skipPublish = flag.Bool("skip-publish", false, "Skip publishing phase")
|
||||
skipQuery = flag.Bool("skip-query", false, "Skip query phase")
|
||||
verbose = flag.Bool("v", false, "Verbose output")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *verbose {
|
||||
lol.SetLogLevel("trace")
|
||||
}
|
||||
|
||||
c := context.Bg()
|
||||
results := &BenchmarkResults{}
|
||||
|
||||
// Phase 1: Publish events
|
||||
if !*skipPublish {
|
||||
fmt.Printf("Publishing %d events to %s...\n", *eventCount, *relayURL)
|
||||
if err := benchmarkPublish(c, *relayURL, *eventCount, *eventSize, *concurrency, results); chk.E(err) {
|
||||
fmt.Fprintf(os.Stderr, "Error during publish benchmark: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Query events
|
||||
if !*skipQuery {
|
||||
fmt.Printf("\nQuerying events from %s...\n", *relayURL)
|
||||
if err := benchmarkQuery(c, *relayURL, *queryCount, *queryLimit, results); chk.E(err) {
|
||||
fmt.Fprintf(os.Stderr, "Error during query benchmark: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Print results
|
||||
printResults(results)
|
||||
}
|
||||
|
||||
func benchmarkPublish(c context.T, relayURL string, eventCount, eventSize, concurrency int, results *BenchmarkResults) error {
|
||||
// Generate signers for each concurrent publisher
|
||||
signers := make([]*testSigner, concurrency)
|
||||
for i := range signers {
|
||||
signers[i] = newTestSigner()
|
||||
}
|
||||
|
||||
// Track published events
|
||||
var publishedEvents atomic.Int64
|
||||
var publishedBytes atomic.Int64
|
||||
var errors atomic.Int64
|
||||
|
||||
// Create wait group for concurrent publishers
|
||||
var wg sync.WaitGroup
|
||||
eventsPerPublisher := eventCount / concurrency
|
||||
extraEvents := eventCount % concurrency
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(publisherID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Connect to relay
|
||||
relay, err := ws.RelayConnect(c, relayURL)
|
||||
if err != nil {
|
||||
log.E.F("Publisher %d failed to connect: %v", publisherID, err)
|
||||
errors.Add(1)
|
||||
return
|
||||
}
|
||||
defer relay.Close()
|
||||
|
||||
// Calculate events for this publisher
|
||||
eventsToPublish := eventsPerPublisher
|
||||
if publisherID < extraEvents {
|
||||
eventsToPublish++
|
||||
}
|
||||
|
||||
signer := signers[publisherID]
|
||||
|
||||
// Publish events
|
||||
for j := 0; j < eventsToPublish; j++ {
|
||||
ev := generateEvent(signer, eventSize)
|
||||
|
||||
if err := relay.Publish(c, ev); err != nil {
|
||||
log.E.F("Publisher %d failed to publish event: %v", publisherID, err)
|
||||
errors.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
evBytes := ev.Marshal(nil)
|
||||
publishedEvents.Add(1)
|
||||
publishedBytes.Add(int64(len(evBytes)))
|
||||
|
||||
if publishedEvents.Load()%1000 == 0 {
|
||||
fmt.Printf(" Published %d events...\n", publishedEvents.Load())
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
results.EventsPublished = publishedEvents.Load()
|
||||
results.EventsPublishedBytes = publishedBytes.Load()
|
||||
results.PublishDuration = duration
|
||||
results.PublishRate = float64(results.EventsPublished) / duration.Seconds()
|
||||
results.PublishBandwidth = float64(results.EventsPublishedBytes) / duration.Seconds() / 1024 / 1024 // MB/s
|
||||
|
||||
if errors.Load() > 0 {
|
||||
fmt.Printf(" Warning: %d errors occurred during publishing\n", errors.Load())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func benchmarkQuery(c context.T, relayURL string, queryCount, queryLimit int, results *BenchmarkResults) error {
|
||||
relay, err := ws.RelayConnect(c, relayURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to relay: %w", err)
|
||||
}
|
||||
defer relay.Close()
|
||||
|
||||
var totalEvents atomic.Int64
|
||||
var totalQueries atomic.Int64
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
for i := 0; i < queryCount; i++ {
|
||||
// Generate various filter types
|
||||
var f *filter.F
|
||||
switch i % 5 {
|
||||
case 0:
|
||||
// Query by kind
|
||||
limit := uint(queryLimit)
|
||||
f = &filter.F{
|
||||
Kinds: kinds.New(kind.TextNote),
|
||||
Limit: &limit,
|
||||
}
|
||||
case 1:
|
||||
// Query by time range
|
||||
now := timestamp.Now()
|
||||
since := timestamp.New(now.I64() - 3600) // last hour
|
||||
limit := uint(queryLimit)
|
||||
f = &filter.F{
|
||||
Since: since,
|
||||
Until: now,
|
||||
Limit: &limit,
|
||||
}
|
||||
case 2:
|
||||
// Query by tag
|
||||
limit := uint(queryLimit)
|
||||
f = &filter.F{
|
||||
Tags: tags.New(tag.New([]byte("p"), generateRandomPubkey())),
|
||||
Limit: &limit,
|
||||
}
|
||||
case 3:
|
||||
// Query by author
|
||||
limit := uint(queryLimit)
|
||||
f = &filter.F{
|
||||
Authors: tag.New(generateRandomPubkey()),
|
||||
Limit: &limit,
|
||||
}
|
||||
case 4:
|
||||
// Complex query with multiple conditions
|
||||
now := timestamp.Now()
|
||||
since := timestamp.New(now.I64() - 7200)
|
||||
limit := uint(queryLimit)
|
||||
f = &filter.F{
|
||||
Kinds: kinds.New(kind.TextNote, kind.Repost),
|
||||
Authors: tag.New(generateRandomPubkey()),
|
||||
Since: since,
|
||||
Limit: &limit,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
events, err := relay.QuerySync(c, f, ws.WithLabel("benchmark"))
|
||||
if err != nil {
|
||||
log.E.F("Query %d failed: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
totalEvents.Add(int64(len(events)))
|
||||
totalQueries.Add(1)
|
||||
|
||||
if totalQueries.Load()%20 == 0 {
|
||||
fmt.Printf(" Executed %d queries...\n", totalQueries.Load())
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
results.QueriesExecuted = totalQueries.Load()
|
||||
results.QueryDuration = duration
|
||||
results.QueryRate = float64(results.QueriesExecuted) / duration.Seconds()
|
||||
results.EventsReturned = totalEvents.Load()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateEvent(signer *testSigner, contentSize int) *event.E {
|
||||
// Generate content with some variation
|
||||
size := contentSize + frand.Intn(contentSize/2) - contentSize/4
|
||||
if size < 10 {
|
||||
size = 10
|
||||
}
|
||||
|
||||
content := text.NostrEscape(nil, frand.Bytes(size))
|
||||
|
||||
ev := &event.E{
|
||||
Pubkey: signer.Pub(),
|
||||
Kind: kind.TextNote,
|
||||
CreatedAt: timestamp.Now(),
|
||||
Content: content,
|
||||
Tags: generateRandomTags(),
|
||||
}
|
||||
|
||||
if err := ev.Sign(signer); chk.E(err) {
|
||||
panic(fmt.Sprintf("failed to sign event: %v", err))
|
||||
}
|
||||
|
||||
return ev
|
||||
}
|
||||
|
||||
func generateRandomTags() *tags.T {
|
||||
t := tags.New()
|
||||
|
||||
// Add some random tags
|
||||
numTags := frand.Intn(5)
|
||||
for i := 0; i < numTags; i++ {
|
||||
switch frand.Intn(3) {
|
||||
case 0:
|
||||
// p tag
|
||||
t.AppendUnique(tag.New([]byte("p"), generateRandomPubkey()))
|
||||
case 1:
|
||||
// e tag
|
||||
t.AppendUnique(tag.New([]byte("e"), generateRandomEventID()))
|
||||
case 2:
|
||||
// t tag
|
||||
t.AppendUnique(tag.New([]byte("t"), []byte(fmt.Sprintf("topic%d", frand.Intn(100)))))
|
||||
}
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func generateRandomPubkey() []byte {
|
||||
return frand.Bytes(32)
|
||||
}
|
||||
|
||||
func generateRandomEventID() []byte {
|
||||
return frand.Bytes(32)
|
||||
}
|
||||
|
||||
func printResults(results *BenchmarkResults) {
|
||||
fmt.Println("\n=== Benchmark Results ===")
|
||||
|
||||
if results.EventsPublished > 0 {
|
||||
fmt.Println("\nPublish Performance:")
|
||||
fmt.Printf(" Events Published: %d\n", results.EventsPublished)
|
||||
fmt.Printf(" Total Data: %.2f MB\n", float64(results.EventsPublishedBytes)/1024/1024)
|
||||
fmt.Printf(" Duration: %s\n", results.PublishDuration)
|
||||
fmt.Printf(" Rate: %.2f events/second\n", results.PublishRate)
|
||||
fmt.Printf(" Bandwidth: %.2f MB/second\n", results.PublishBandwidth)
|
||||
}
|
||||
|
||||
if results.QueriesExecuted > 0 {
|
||||
fmt.Println("\nQuery Performance:")
|
||||
fmt.Printf(" Queries Executed: %d\n", results.QueriesExecuted)
|
||||
fmt.Printf(" Events Returned: %d\n", results.EventsReturned)
|
||||
fmt.Printf(" Duration: %s\n", results.QueryDuration)
|
||||
fmt.Printf(" Rate: %.2f queries/second\n", results.QueryRate)
|
||||
avgEventsPerQuery := float64(results.EventsReturned) / float64(results.QueriesExecuted)
|
||||
fmt.Printf(" Avg Events/Query: %.2f\n", avgEventsPerQuery)
|
||||
}
|
||||
}
|
||||
82
cmd/benchmark/run_benchmark.sh
Executable file
82
cmd/benchmark/run_benchmark.sh
Executable file
@@ -0,0 +1,82 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Simple Nostr Relay Benchmark Script
|
||||
|
||||
# Default values
|
||||
RELAY_URL="ws://localhost:7447"
|
||||
EVENTS=10000
|
||||
SIZE=1024
|
||||
CONCURRENCY=10
|
||||
QUERIES=100
|
||||
QUERY_LIMIT=100
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--relay)
|
||||
RELAY_URL="$2"
|
||||
shift 2
|
||||
;;
|
||||
--events)
|
||||
EVENTS="$2"
|
||||
shift 2
|
||||
;;
|
||||
--size)
|
||||
SIZE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--concurrency)
|
||||
CONCURRENCY="$2"
|
||||
shift 2
|
||||
;;
|
||||
--queries)
|
||||
QUERIES="$2"
|
||||
shift 2
|
||||
;;
|
||||
--query-limit)
|
||||
QUERY_LIMIT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--skip-publish)
|
||||
SKIP_PUBLISH="-skip-publish"
|
||||
shift
|
||||
;;
|
||||
--skip-query)
|
||||
SKIP_QUERY="-skip-query"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
echo "Usage: $0 [--relay URL] [--events N] [--size N] [--concurrency N] [--queries N] [--query-limit N] [--skip-publish] [--skip-query]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Build the benchmark tool if it doesn't exist
|
||||
if [ ! -f benchmark-simple ]; then
|
||||
echo "Building benchmark tool..."
|
||||
go build -o benchmark-simple ./benchmark_simple.go
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed to build benchmark tool"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Run the benchmark
|
||||
echo "Running Nostr relay benchmark..."
|
||||
echo "Relay: $RELAY_URL"
|
||||
echo "Events: $EVENTS (size: $SIZE bytes)"
|
||||
echo "Concurrency: $CONCURRENCY"
|
||||
echo "Queries: $QUERIES (limit: $QUERY_LIMIT)"
|
||||
echo ""
|
||||
|
||||
./benchmark-simple \
|
||||
-relay "$RELAY_URL" \
|
||||
-events $EVENTS \
|
||||
-size $SIZE \
|
||||
-concurrency $CONCURRENCY \
|
||||
-queries $QUERIES \
|
||||
-query-limit $QUERY_LIMIT \
|
||||
$SKIP_PUBLISH \
|
||||
$SKIP_QUERY
|
||||
63
cmd/benchmark/test_signer.go
Normal file
63
cmd/benchmark/test_signer.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"lukechampine.com/frand"
|
||||
"orly.dev/pkg/interfaces/signer"
|
||||
)
|
||||
|
||||
// testSigner is a simple signer implementation for benchmarking
|
||||
type testSigner struct {
|
||||
pub []byte
|
||||
sec []byte
|
||||
}
|
||||
|
||||
func newTestSigner() *testSigner {
|
||||
return &testSigner{
|
||||
pub: frand.Bytes(32),
|
||||
sec: frand.Bytes(32),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testSigner) Pub() []byte {
|
||||
return s.pub
|
||||
}
|
||||
|
||||
func (s *testSigner) Sec() []byte {
|
||||
return s.sec
|
||||
}
|
||||
|
||||
func (s *testSigner) Sign(msg []byte) ([]byte, error) {
|
||||
return frand.Bytes(64), nil
|
||||
}
|
||||
|
||||
func (s *testSigner) Verify(msg, sig []byte) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *testSigner) InitSec(sec []byte) error {
|
||||
s.sec = sec
|
||||
s.pub = frand.Bytes(32)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testSigner) InitPub(pub []byte) error {
|
||||
s.pub = pub
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testSigner) Zero() {
|
||||
for i := range s.sec {
|
||||
s.sec[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func (s *testSigner) ECDH(pubkey []byte) ([]byte, error) {
|
||||
return frand.Bytes(32), nil
|
||||
}
|
||||
|
||||
func (s *testSigner) Generate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ signer.I = (*testSigner)(nil)
|
||||
@@ -1,15 +0,0 @@
|
||||
package app
|
||||
|
||||
import "sync"
|
||||
|
||||
var bufferPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
type Pool struct{}
|
||||
|
||||
func (bp Pool) Get() []byte { return *(bufferPool.Get().(*[]byte)) }
|
||||
func (bp Pool) Put(b []byte) { bufferPool.Put(&b) }
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"orly.dev/pkg/crypto/p256k"
|
||||
"orly.dev/pkg/crypto/sha256"
|
||||
"orly.dev/pkg/encoders/bech32encoding"
|
||||
@@ -18,7 +20,6 @@ import (
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"orly.dev/pkg/utils/log"
|
||||
realy_lol "orly.dev/pkg/version"
|
||||
"os"
|
||||
)
|
||||
|
||||
const secEnv = "NOSTR_SECRET_KEY"
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"orly.dev/pkg/protocol/nwc"
|
||||
)
|
||||
|
||||
func printUsage() {
|
||||
fmt.Println("Usage: nwcclient \"<connection URL>\" <method> [parameters...]")
|
||||
fmt.Println("\nSupported methods:")
|
||||
fmt.Println(" get_info - Get wallet information")
|
||||
fmt.Println(" get_balance - Get wallet balance")
|
||||
fmt.Println(" get_budget - Get wallet budget")
|
||||
fmt.Println(" make_invoice - Create an invoice (amount, description, [description_hash], [expiry])")
|
||||
fmt.Println(" pay_invoice - Pay an invoice (invoice, [amount])")
|
||||
fmt.Println(" pay_keysend - Send a keysend payment (amount, pubkey, [preimage])")
|
||||
fmt.Println(" lookup_invoice - Look up an invoice (payment_hash or invoice)")
|
||||
fmt.Println(" list_transactions - List transactions ([from], [until], [limit], [offset], [unpaid], [type])")
|
||||
fmt.Println(" sign_message - Sign a message (message)")
|
||||
fmt.Println("\nUnsupported methods (due to limitations in the nwc package):")
|
||||
fmt.Println(" create_connection - Create a connection")
|
||||
fmt.Println(" make_hold_invoice - Create a hold invoice")
|
||||
fmt.Println(" settle_hold_invoice - Settle a hold invoice")
|
||||
fmt.Println(" cancel_hold_invoice - Cancel a hold invoice")
|
||||
fmt.Println(" multi_pay_invoice - Pay multiple invoices")
|
||||
fmt.Println(" multi_pay_keysend - Send multiple keysend payments")
|
||||
fmt.Println("\nParameters format:")
|
||||
fmt.Println(" - Positional parameters are used for required fields")
|
||||
fmt.Println(" - For list_transactions, named parameters are used: 'from', 'until', 'limit', 'offset', 'unpaid', 'type'")
|
||||
fmt.Println(" Example: nwcclient <url> list_transactions limit 10 type incoming")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Check if we have enough arguments
|
||||
if len(os.Args) < 3 {
|
||||
printUsage()
|
||||
}
|
||||
|
||||
// Parse connection URL and method
|
||||
connectionURL := os.Args[1]
|
||||
methodStr := os.Args[2]
|
||||
method := nwc.Method(methodStr)
|
||||
|
||||
// Parse the wallet connect URL
|
||||
opts, err := nwc.ParseWalletConnectURL(connectionURL)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing connection URL: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create a new NWC client
|
||||
client, err := nwc.NewNWCClient(opts)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating NWC client: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Execute the requested method
|
||||
var result interface{}
|
||||
|
||||
switch method {
|
||||
case nwc.GetInfo:
|
||||
result, err = client.GetInfo()
|
||||
|
||||
case nwc.GetBalance:
|
||||
result, err = client.GetBalance()
|
||||
|
||||
case nwc.GetBudget:
|
||||
result, err = client.GetBudget()
|
||||
|
||||
case nwc.MakeInvoice:
|
||||
if len(os.Args) < 5 {
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"Error: make_invoice requires at least amount and description\n",
|
||||
)
|
||||
printUsage()
|
||||
}
|
||||
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
description := os.Args[4]
|
||||
|
||||
req := &nwc.MakeInvoiceRequest{
|
||||
Amount: amount,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
// Optional parameters
|
||||
if len(os.Args) > 5 {
|
||||
req.DescriptionHash = os.Args[5]
|
||||
}
|
||||
if len(os.Args) > 6 {
|
||||
expiry, err := strconv.ParseInt(os.Args[6], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing expiry: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.Expiry = &expiry
|
||||
}
|
||||
|
||||
result, err = client.MakeInvoice(req)
|
||||
|
||||
case nwc.PayInvoice:
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Fprintf(os.Stderr, "Error: pay_invoice requires an invoice\n")
|
||||
printUsage()
|
||||
}
|
||||
|
||||
req := &nwc.PayInvoiceRequest{
|
||||
Invoice: os.Args[3],
|
||||
}
|
||||
|
||||
// Optional amount parameter
|
||||
if len(os.Args) > 4 {
|
||||
amount, err := strconv.ParseInt(os.Args[4], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.Amount = &amount
|
||||
}
|
||||
|
||||
result, err = client.PayInvoice(req)
|
||||
|
||||
case nwc.PayKeysend:
|
||||
if len(os.Args) < 5 {
|
||||
fmt.Fprintf(
|
||||
os.Stderr, "Error: pay_keysend requires amount and pubkey\n",
|
||||
)
|
||||
printUsage()
|
||||
}
|
||||
|
||||
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
req := &nwc.PayKeysendRequest{
|
||||
Amount: amount,
|
||||
Pubkey: os.Args[4],
|
||||
}
|
||||
|
||||
// Optional preimage
|
||||
if len(os.Args) > 5 {
|
||||
req.Preimage = os.Args[5]
|
||||
}
|
||||
|
||||
result, err = client.PayKeysend(req)
|
||||
|
||||
case nwc.LookupInvoice:
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"Error: lookup_invoice requires a payment_hash or invoice\n",
|
||||
)
|
||||
printUsage()
|
||||
}
|
||||
|
||||
param := os.Args[3]
|
||||
req := &nwc.LookupInvoiceRequest{}
|
||||
|
||||
// Determine if the parameter is a payment hash or an invoice
|
||||
if strings.HasPrefix(param, "ln") {
|
||||
req.Invoice = param
|
||||
} else {
|
||||
req.PaymentHash = param
|
||||
}
|
||||
|
||||
result, err = client.LookupInvoice(req)
|
||||
|
||||
case nwc.ListTransactions:
|
||||
req := &nwc.ListTransactionsRequest{}
|
||||
|
||||
// Parse optional parameters
|
||||
paramIndex := 3
|
||||
for paramIndex < len(os.Args) {
|
||||
if paramIndex+1 >= len(os.Args) {
|
||||
break
|
||||
}
|
||||
|
||||
paramName := os.Args[paramIndex]
|
||||
paramValue := os.Args[paramIndex+1]
|
||||
|
||||
switch paramName {
|
||||
case "from":
|
||||
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing from: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.From = &val
|
||||
case "until":
|
||||
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing until: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.Until = &val
|
||||
case "limit":
|
||||
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing limit: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.Limit = &val
|
||||
case "offset":
|
||||
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing offset: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.Offset = &val
|
||||
case "unpaid":
|
||||
val := paramValue == "true"
|
||||
req.Unpaid = &val
|
||||
case "type":
|
||||
req.Type = ¶mValue
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
paramIndex += 2
|
||||
}
|
||||
|
||||
result, err = client.ListTransactions(req)
|
||||
|
||||
case nwc.SignMessage:
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Fprintf(os.Stderr, "Error: sign_message requires a message\n")
|
||||
printUsage()
|
||||
}
|
||||
|
||||
req := &nwc.SignMessageRequest{
|
||||
Message: os.Args[3],
|
||||
}
|
||||
|
||||
result, err = client.SignMessage(req)
|
||||
|
||||
case nwc.CreateConnection, nwc.MakeHoldInvoice, nwc.SettleHoldInvoice, nwc.CancelHoldInvoice, nwc.MultiPayInvoice, nwc.MultiPayKeysend:
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"Error: Method %s is not directly supported by the CLI tool.\n",
|
||||
methodStr,
|
||||
)
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"This is because these methods don't have exported client methods in the nwc package.\n",
|
||||
)
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"Only the following methods are currently supported: get_info, get_balance, get_budget, make_invoice, pay_invoice, pay_keysend, lookup_invoice, list_transactions, sign_message\n",
|
||||
)
|
||||
os.Exit(1)
|
||||
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Error: Unsupported method: %s\n", methodStr)
|
||||
printUsage()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing method: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Print the result as JSON
|
||||
jsonData, err := json.MarshalIndent(result, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marshaling result to JSON: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println(string(jsonData))
|
||||
}
|
||||
417
cmd/walletcli/main.go
Normal file
417
cmd/walletcli/main.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"orly.dev/pkg/protocol/nwc"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
)
|
||||
|
||||
func printUsage() {
|
||||
fmt.Println("Usage: walletcli \"<NWC connection URL>\" <method> [<args...>]")
|
||||
fmt.Println("\nAvailable methods:")
|
||||
fmt.Println(" get_wallet_service_info - Get wallet service information")
|
||||
fmt.Println(" get_info - Get wallet information")
|
||||
fmt.Println(" get_balance - Get wallet balance")
|
||||
fmt.Println(" get_budget - Get wallet budget")
|
||||
fmt.Println(" make_invoice - Create an invoice")
|
||||
fmt.Println(" Args: <amount> [<description>] [<description_hash>] [<expiry>]")
|
||||
fmt.Println(" pay_invoice - Pay an invoice")
|
||||
fmt.Println(" Args: <invoice> [<amount>] [<comment>]")
|
||||
fmt.Println(" pay_keysend - Pay to a node using keysend")
|
||||
fmt.Println(" Args: <pubkey> <amount> [<preimage>] [<tlv_type> <tlv_value>...]")
|
||||
fmt.Println(" lookup_invoice - Look up an invoice")
|
||||
fmt.Println(" Args: <payment_hash or invoice>")
|
||||
fmt.Println(" list_transactions - List transactions")
|
||||
fmt.Println(" Args: [<limit>] [<offset>] [<from>] [<until>]")
|
||||
fmt.Println(" make_hold_invoice - Create a hold invoice")
|
||||
fmt.Println(" Args: <amount> <payment_hash> [<description>] [<description_hash>] [<expiry>]")
|
||||
fmt.Println(" settle_hold_invoice - Settle a hold invoice")
|
||||
fmt.Println(" Args: <preimage>")
|
||||
fmt.Println(" cancel_hold_invoice - Cancel a hold invoice")
|
||||
fmt.Println(" Args: <payment_hash>")
|
||||
fmt.Println(" sign_message - Sign a message")
|
||||
fmt.Println(" Args: <message>")
|
||||
fmt.Println(" create_connection - Create a connection")
|
||||
fmt.Println(" Args: <pubkey> <name> <methods> [<notification_types>] [<max_amount>] [<budget_renewal>] [<expires_at>]")
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 3 {
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
connectionURL := os.Args[1]
|
||||
method := os.Args[2]
|
||||
args := os.Args[3:]
|
||||
// Create context
|
||||
// ctx, cancel := context.Cancel(context.Bg())
|
||||
ctx := context.Bg()
|
||||
// defer cancel()
|
||||
// Create NWC client
|
||||
client, err := nwc.NewClient(ctx, connectionURL)
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating client: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Execute the requested method
|
||||
switch method {
|
||||
case "get_wallet_service_info":
|
||||
handleGetWalletServiceInfo(ctx, client)
|
||||
case "get_info":
|
||||
handleGetInfo(ctx, client)
|
||||
case "get_balance":
|
||||
handleGetBalance(ctx, client)
|
||||
case "get_budget":
|
||||
handleGetBudget(ctx, client)
|
||||
case "make_invoice":
|
||||
handleMakeInvoice(ctx, client, args)
|
||||
case "pay_invoice":
|
||||
handlePayInvoice(ctx, client, args)
|
||||
case "pay_keysend":
|
||||
handlePayKeysend(ctx, client, args)
|
||||
case "lookup_invoice":
|
||||
handleLookupInvoice(ctx, client, args)
|
||||
case "list_transactions":
|
||||
handleListTransactions(ctx, client, args)
|
||||
case "make_hold_invoice":
|
||||
handleMakeHoldInvoice(ctx, client, args)
|
||||
case "settle_hold_invoice":
|
||||
handleSettleHoldInvoice(ctx, client, args)
|
||||
case "cancel_hold_invoice":
|
||||
handleCancelHoldInvoice(ctx, client, args)
|
||||
case "sign_message":
|
||||
handleSignMessage(ctx, client, args)
|
||||
case "create_connection":
|
||||
handleCreateConnection(ctx, client, args)
|
||||
default:
|
||||
fmt.Printf("Unknown method: %s\n", method)
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func handleGetWalletServiceInfo(ctx context.T, client *nwc.Client) {
|
||||
if _, raw, err := client.GetWalletServiceInfo(ctx, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleGetInfo(ctx context.T, client *nwc.Client) {
|
||||
if _, raw, err := client.GetInfo(ctx, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleGetBalance(ctx context.T, client *nwc.Client) {
|
||||
if _, raw, err := client.GetBalance(ctx, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleGetBudget(ctx context.T, client *nwc.Client) {
|
||||
if _, raw, err := client.GetBudget(ctx, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleMakeInvoice(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 1 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> make_invoice <amount> [<description>] [<description_hash>] [<expiry>]")
|
||||
return
|
||||
}
|
||||
amount, err := strconv.ParseUint(args[0], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing amount: %v\n", err)
|
||||
return
|
||||
}
|
||||
params := &nwc.MakeInvoiceParams{
|
||||
Amount: amount,
|
||||
}
|
||||
if len(args) > 1 {
|
||||
params.Description = args[1]
|
||||
}
|
||||
if len(args) > 2 {
|
||||
params.DescriptionHash = args[2]
|
||||
}
|
||||
if len(args) > 3 {
|
||||
expiry, err := strconv.ParseInt(args[3], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing expiry: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.Expiry = &expiry
|
||||
}
|
||||
var raw []byte
|
||||
if _, raw, err = client.MakeInvoice(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handlePayInvoice(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 1 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> pay_invoice <invoice> [<amount>] [<comment>]")
|
||||
return
|
||||
}
|
||||
params := &nwc.PayInvoiceParams{
|
||||
Invoice: args[0],
|
||||
}
|
||||
if len(args) > 1 {
|
||||
amount, err := strconv.ParseUint(args[1], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing amount: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.Amount = &amount
|
||||
}
|
||||
if len(args) > 2 {
|
||||
comment := args[2]
|
||||
params.Metadata = &nwc.PayInvoiceMetadata{
|
||||
Comment: &comment,
|
||||
}
|
||||
}
|
||||
if _, raw, err := client.PayInvoice(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleLookupInvoice(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 1 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> lookup_invoice <payment_hash or invoice>")
|
||||
return
|
||||
}
|
||||
params := &nwc.LookupInvoiceParams{}
|
||||
// Determine if the argument is a payment hash or an invoice
|
||||
if strings.HasPrefix(args[0], "ln") {
|
||||
invoice := args[0]
|
||||
params.Invoice = &invoice
|
||||
} else {
|
||||
paymentHash := args[0]
|
||||
params.PaymentHash = &paymentHash
|
||||
}
|
||||
var err error
|
||||
var raw []byte
|
||||
if _, raw, err = client.LookupInvoice(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleListTransactions(ctx context.T, client *nwc.Client, args []string) {
|
||||
params := &nwc.ListTransactionsParams{}
|
||||
if len(args) > 0 {
|
||||
limit, err := strconv.ParseUint(args[0], 10, 16)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing limit: %v\n", err)
|
||||
return
|
||||
}
|
||||
limitUint16 := uint16(limit)
|
||||
params.Limit = &limitUint16
|
||||
}
|
||||
if len(args) > 1 {
|
||||
offset, err := strconv.ParseUint(args[1], 10, 32)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing offset: %v\n", err)
|
||||
return
|
||||
}
|
||||
offsetUint32 := uint32(offset)
|
||||
params.Offset = &offsetUint32
|
||||
}
|
||||
if len(args) > 2 {
|
||||
from, err := strconv.ParseInt(args[2], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing from: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.From = &from
|
||||
}
|
||||
if len(args) > 3 {
|
||||
until, err := strconv.ParseInt(args[3], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing until: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.Until = &until
|
||||
}
|
||||
var raw []byte
|
||||
var err error
|
||||
if _, raw, err = client.ListTransactions(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleMakeHoldInvoice(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 2 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> make_hold_invoice <amount> <payment_hash> [<description>] [<description_hash>] [<expiry>]")
|
||||
return
|
||||
}
|
||||
amount, err := strconv.ParseUint(args[0], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing amount: %v\n", err)
|
||||
return
|
||||
}
|
||||
params := &nwc.MakeHoldInvoiceParams{
|
||||
Amount: amount,
|
||||
PaymentHash: args[1],
|
||||
}
|
||||
if len(args) > 2 {
|
||||
params.Description = args[2]
|
||||
}
|
||||
if len(args) > 3 {
|
||||
params.DescriptionHash = args[3]
|
||||
}
|
||||
if len(args) > 4 {
|
||||
expiry, err := strconv.ParseInt(args[4], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing expiry: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.Expiry = &expiry
|
||||
}
|
||||
var raw []byte
|
||||
if _, raw, err = client.MakeHoldInvoice(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleSettleHoldInvoice(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 1 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> settle_hold_invoice <preimage>")
|
||||
return
|
||||
}
|
||||
params := &nwc.SettleHoldInvoiceParams{
|
||||
Preimage: args[0],
|
||||
}
|
||||
var raw []byte
|
||||
var err error
|
||||
if raw, err = client.SettleHoldInvoice(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleCancelHoldInvoice(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 1 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> cancel_hold_invoice <payment_hash>")
|
||||
return
|
||||
}
|
||||
|
||||
params := &nwc.CancelHoldInvoiceParams{
|
||||
PaymentHash: args[0],
|
||||
}
|
||||
var err error
|
||||
var raw []byte
|
||||
if raw, err = client.CancelHoldInvoice(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleSignMessage(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 1 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> sign_message <message>")
|
||||
return
|
||||
}
|
||||
|
||||
params := &nwc.SignMessageParams{
|
||||
Message: args[0],
|
||||
}
|
||||
var raw []byte
|
||||
var err error
|
||||
if _, raw, err = client.SignMessage(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handlePayKeysend(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 2 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> pay_keysend <pubkey> <amount> [<preimage>] [<tlv_type> <tlv_value>...]")
|
||||
return
|
||||
}
|
||||
pubkey := args[0]
|
||||
amount, err := strconv.ParseUint(args[1], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing amount: %v\n", err)
|
||||
return
|
||||
}
|
||||
params := &nwc.PayKeysendParams{
|
||||
Pubkey: pubkey,
|
||||
Amount: amount,
|
||||
}
|
||||
// Optional preimage
|
||||
if len(args) > 2 {
|
||||
preimage := args[2]
|
||||
params.Preimage = &preimage
|
||||
}
|
||||
// Optional TLV records (must come in pairs)
|
||||
if len(args) > 3 {
|
||||
// Start from index 3 and process pairs of arguments
|
||||
for i := 3; i < len(args)-1; i += 2 {
|
||||
tlvType, err := strconv.ParseUint(args[i], 10, 32)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing TLV type: %v\n", err)
|
||||
return
|
||||
}
|
||||
tlvValue := args[i+1]
|
||||
params.TLVRecords = append(
|
||||
params.TLVRecords, nwc.PayKeysendTLVRecord{
|
||||
Type: uint32(tlvType),
|
||||
Value: tlvValue,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
var raw []byte
|
||||
if _, raw, err = client.PayKeysend(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func handleCreateConnection(ctx context.T, client *nwc.Client, args []string) {
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Error: Missing required arguments")
|
||||
fmt.Println("Usage: walletcli <NWC connection URL> create_connection <pubkey> <name> <methods> [<notification_types>] [<max_amount>] [<budget_renewal>] [<expires_at>]")
|
||||
return
|
||||
}
|
||||
params := &nwc.CreateConnectionParams{
|
||||
Pubkey: args[0],
|
||||
Name: args[1],
|
||||
RequestMethods: strings.Split(args[2], ","),
|
||||
}
|
||||
if len(args) > 3 {
|
||||
params.NotificationTypes = strings.Split(args[3], ",")
|
||||
}
|
||||
if len(args) > 4 {
|
||||
maxAmount, err := strconv.ParseUint(args[4], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing max_amount: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.MaxAmount = &maxAmount
|
||||
}
|
||||
if len(args) > 5 {
|
||||
params.BudgetRenewal = &args[5]
|
||||
}
|
||||
if len(args) > 6 {
|
||||
expiresAt, err := strconv.ParseInt(args[6], 10, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Error parsing expires_at: %v\n", err)
|
||||
return
|
||||
}
|
||||
params.ExpiresAt = &expiresAt
|
||||
}
|
||||
var raw []byte
|
||||
var err error
|
||||
if raw, err = client.CreateConnection(ctx, params, true); !chk.E(err) {
|
||||
fmt.Println(string(raw))
|
||||
}
|
||||
}
|
||||
4
go.mod
4
go.mod
@@ -5,13 +5,12 @@ go 1.24.2
|
||||
require (
|
||||
github.com/adrg/xdg v0.5.3
|
||||
github.com/alexflint/go-arg v1.6.0
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/danielgtaylor/huma/v2 v2.34.1
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/dgraph-io/badger/v4 v4.7.0
|
||||
github.com/fasthttp/websocket v1.5.12
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/gobwas/httphead v0.1.0
|
||||
github.com/gobwas/ws v1.4.0
|
||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0
|
||||
github.com/klauspost/cpuid/v2 v2.2.11
|
||||
github.com/minio/sha256-simd v1.0.1
|
||||
@@ -41,7 +40,6 @@ require (
|
||||
github.com/felixge/fgprof v0.9.5 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/google/flatbuffers v25.2.10+incompatible // indirect
|
||||
github.com/google/pprof v0.0.0-20250630185457-6e76a2b096b5 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -19,6 +19,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
|
||||
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/danielgtaylor/huma/v2 v2.34.1 h1:EmOJAbzEGfy0wAq/QMQ1YKfEMBEfE94xdBRLPBP0gwQ=
|
||||
github.com/danielgtaylor/huma/v2 v2.34.1/go.mod h1:ynwJgLk8iGVgoaipi5tgwIQ5yoFNmiu+QdhU7CEEmhk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -44,13 +46,9 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
|
||||
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
|
||||
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
|
||||
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
|
||||
github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q=
|
||||
github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
|
||||
@@ -5,12 +5,6 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"orly.dev/pkg/utils/apputil"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
env2 "orly.dev/pkg/utils/env"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/utils/lol"
|
||||
"orly.dev/pkg/version"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
@@ -18,6 +12,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/utils/apputil"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
env2 "orly.dev/pkg/utils/env"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/utils/lol"
|
||||
"orly.dev/pkg/version"
|
||||
|
||||
"github.com/adrg/xdg"
|
||||
"go-simpler.org/env"
|
||||
)
|
||||
@@ -26,24 +27,27 @@ import (
|
||||
// and default values. It defines parameters for app behaviour, storage
|
||||
// locations, logging, and network settings used across the relay service.
|
||||
type C struct {
|
||||
AppName string `env:"ORLY_APP_NAME" default:"orly"`
|
||||
Config string `env:"ORLY_CONFIG_DIR" usage:"location for configuration file, which has the name '.env' to make it harder to delete, and is a standard environment KEY=value<newline>... style" default:"~/.config/orly"`
|
||||
State string `env:"ORLY_STATE_DATA_DIR" usage:"storage location for state data affected by dynamic interactive interfaces" default:"~/.local/state/orly"`
|
||||
DataDir string `env:"ORLY_DATA_DIR" usage:"storage location for the event store" default:"~/.local/cache/orly"`
|
||||
Listen string `env:"ORLY_LISTEN" default:"0.0.0.0" usage:"network listen address"`
|
||||
Port int `env:"ORLY_PORT" default:"3334" usage:"port to listen on"`
|
||||
LogLevel string `env:"ORLY_LOG_LEVEL" default:"info" usage:"debug level: fatal error warn info debug trace"`
|
||||
DbLogLevel string `env:"ORLY_DB_LOG_LEVEL" default:"info" usage:"debug level: fatal error warn info debug trace"`
|
||||
Pprof string `env:"ORLY_PPROF" usage:"enable pprof on 127.0.0.1:6060" enum:"cpu,memory,allocation"`
|
||||
AuthRequired bool `env:"ORLY_AUTH_REQUIRED" default:"false" usage:"require authentication for all requests"`
|
||||
PublicReadable bool `env:"ORLY_PUBLIC_READABLE" default:"true" usage:"allow public read access to regardless of whether the client is authed"`
|
||||
SpiderSeeds []string `env:"ORLY_SPIDER_SEEDS" usage:"seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)" default:"wss://profiles.nostr1.com/,wss://relay.nostr.band/,wss://relay.damus.io/,wss://nostr.wine/,wss://nostr.land/,wss://theforest.nostr1.com/"`
|
||||
SpiderType string `env:"ORLY_SPIDER_TYPE" usage:"whether to spider, and what degree of spidering: none, directory, follows (follows means to the second degree of the follow graph)" default:"directory"`
|
||||
Owners []string `env:"ORLY_OWNERS" usage:"list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)"`
|
||||
Private bool `env:"ORLY_PRIVATE" usage:"do not spider for user metadata because the relay is private and this would leak relay memberships" default:"false"`
|
||||
Whitelist []string `env:"ORLY_WHITELIST" usage:"only allow connections from this list of IP addresses"`
|
||||
RelaySecret string `env:"ORLY_SECRET_KEY" usage:"secret key for relay cluster replication authentication"`
|
||||
PeerRelays []string `env:"ORLY_PEER_RELAYS" usage:"list of peer relays URLs that new events are pushed to in format <pubkey>|<url>"`
|
||||
AppName string `env:"ORLY_APP_NAME" default:"ORLY"`
|
||||
Config string `env:"ORLY_CONFIG_DIR" usage:"location for configuration file, which has the name '.env' to make it harder to delete, and is a standard environment KEY=value<newline>... style" default:"~/.config/orly"`
|
||||
State string `env:"ORLY_STATE_DATA_DIR" usage:"storage location for state data affected by dynamic interactive interfaces" default:"~/.local/state/orly"`
|
||||
DataDir string `env:"ORLY_DATA_DIR" usage:"storage location for the event store" default:"~/.local/cache/orly"`
|
||||
Listen string `env:"ORLY_LISTEN" default:"0.0.0.0" usage:"network listen address"`
|
||||
Port int `env:"ORLY_PORT" default:"3334" usage:"port to listen on"`
|
||||
LogLevel string `env:"ORLY_LOG_LEVEL" default:"info" usage:"debug level: fatal error warn info debug trace"`
|
||||
DbLogLevel string `env:"ORLY_DB_LOG_LEVEL" default:"info" usage:"debug level: fatal error warn info debug trace"`
|
||||
Pprof string `env:"ORLY_PPROF" usage:"enable pprof on 127.0.0.1:6060" enum:"cpu,memory,allocation"`
|
||||
AuthRequired bool `env:"ORLY_AUTH_REQUIRED" default:"false" usage:"require authentication for all requests"`
|
||||
PublicReadable bool `env:"ORLY_PUBLIC_READABLE" default:"true" usage:"allow public read access to regardless of whether the client is authed"`
|
||||
SpiderSeeds []string `env:"ORLY_SPIDER_SEEDS" usage:"seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)" default:"wss://profiles.nostr1.com/,wss://relay.nostr.band/,wss://relay.damus.io/,wss://nostr.wine/,wss://nostr.land/,wss://theforest.nostr1.com/,wss://profiles.nostr1.com/"`
|
||||
SpiderType string `env:"ORLY_SPIDER_TYPE" usage:"whether to spider, and what degree of spidering: none, directory, follows (follows means to the second degree of the follow graph)" default:"directory"`
|
||||
SpiderTime time.Duration `env:"ORLY_SPIDER_FREQUENCY" usage:"how often to run the spider, uses notation 0h0m0s" default:"1h"`
|
||||
SpiderSecondDegree bool `env:"ORLY_SPIDER_SECOND_DEGREE" default:"true" usage:"whether to enable spidering the second degree of follows for non-directory events if ORLY_SPIDER_TYPE is set to 'follows'"`
|
||||
Owners []string `env:"ORLY_OWNERS" usage:"list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)"`
|
||||
Private bool `env:"ORLY_PRIVATE" usage:"do not spider for user metadata because the relay is private and this would leak relay memberships" default:"false"`
|
||||
Whitelist []string `env:"ORLY_WHITELIST" usage:"only allow connections from this list of IP addresses"`
|
||||
Blacklist []string `env:"ORLY_BLACKLIST" usage:"list of pubkeys to block when auth is not required (comma separated)"`
|
||||
RelaySecret string `env:"ORLY_SECRET_KEY" usage:"secret key for relay cluster replication authentication"`
|
||||
PeerRelays []string `env:"ORLY_PEER_RELAYS" usage:"list of peer relays URLs that new events are pushed to in format <pubkey>|<url>"`
|
||||
}
|
||||
|
||||
// New creates and initializes a new configuration object for the relay
|
||||
|
||||
@@ -42,30 +42,35 @@ func (s *Server) AcceptEvent(
|
||||
remote string,
|
||||
) (accept bool, notice string, afterSave func()) {
|
||||
if !s.AuthRequired() {
|
||||
accept = true
|
||||
// Check blacklist for public relay mode
|
||||
if len(s.blacklistPubkeys) > 0 {
|
||||
for _, blockedPubkey := range s.blacklistPubkeys {
|
||||
if bytes.Equal(blockedPubkey, ev.Pubkey) {
|
||||
notice = "event author is blacklisted"
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// if auth is required and the user is not authed, reject
|
||||
if s.AuthRequired() && len(authedPubkey) == 0 {
|
||||
if len(authedPubkey) == 0 {
|
||||
notice = "client isn't authed"
|
||||
return
|
||||
}
|
||||
// check if the authed user is on the lists
|
||||
list := append(s.OwnersFollowed(), s.FollowedFollows()...)
|
||||
for _, u := range list {
|
||||
if bytes.Equal(u, authedPubkey) {
|
||||
accept = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !accept {
|
||||
return
|
||||
}
|
||||
for _, u := range s.OwnersMuted() {
|
||||
if bytes.Equal(u, authedPubkey) {
|
||||
notice = "event author is banned from this relay"
|
||||
return
|
||||
}
|
||||
}
|
||||
// check if the authed user is on the lists
|
||||
list := append(s.OwnersFollowed(), s.FollowedFollows()...)
|
||||
for _, u := range list {
|
||||
if bytes.Equal(u, authedPubkey) {
|
||||
accept = true
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
|
||||
// mockServerForEvent is a simple mock implementation of the Server struct for testing AcceptEvent
|
||||
type mockServerForEvent struct {
|
||||
authRequired bool
|
||||
ownersFollowed [][]byte
|
||||
authRequired bool
|
||||
ownersFollowed [][]byte
|
||||
followedFollows [][]byte
|
||||
}
|
||||
|
||||
@@ -203,8 +203,8 @@ func TestAcceptEventWithRealServer(t *testing.T) {
|
||||
if accept {
|
||||
t.Error("AcceptEvent() accept = true, want false")
|
||||
}
|
||||
if notice != "" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want empty string", notice)
|
||||
if notice != "client isn't authed" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want 'client isn't authed'", notice)
|
||||
}
|
||||
if afterSave != nil {
|
||||
t.Error("AcceptEvent() afterSave is not nil, but should be nil")
|
||||
@@ -234,4 +234,81 @@ func TestAcceptEventWithRealServer(t *testing.T) {
|
||||
if !accept {
|
||||
t.Error("AcceptEvent() accept = false, want true")
|
||||
}
|
||||
|
||||
// Test with muted user
|
||||
s.SetOwnersMuted([][]byte{[]byte("test-pubkey")})
|
||||
accept, notice, afterSave = s.AcceptEvent(ctx, testEvent, req, []byte("test-pubkey"), "127.0.0.1")
|
||||
if accept {
|
||||
t.Error("AcceptEvent() accept = true, want false")
|
||||
}
|
||||
if notice != "event author is banned from this relay" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want 'event author is banned from this relay'", notice)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAcceptEventWithBlacklist tests the blacklist functionality when auth is not required
|
||||
func TestAcceptEventWithBlacklist(t *testing.T) {
|
||||
// Create a context and HTTP request for testing
|
||||
ctx := context.Bg()
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
// Test pubkey bytes
|
||||
testPubkey := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}
|
||||
blockedPubkey := []byte{0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30}
|
||||
|
||||
// Test with public relay mode (auth not required) and no blacklist
|
||||
s := &Server{
|
||||
C: &config.C{
|
||||
AuthRequired: false,
|
||||
},
|
||||
Lists: new(Lists),
|
||||
}
|
||||
|
||||
// Create event with test pubkey
|
||||
testEvent := &event.E{}
|
||||
testEvent.Pubkey = testPubkey
|
||||
|
||||
// Should accept when no blacklist
|
||||
accept, notice, _ := s.AcceptEvent(ctx, testEvent, req, nil, "127.0.0.1")
|
||||
if !accept {
|
||||
t.Error("AcceptEvent() accept = false, want true")
|
||||
}
|
||||
if notice != "" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want empty string", notice)
|
||||
}
|
||||
|
||||
// Add blacklist with different pubkey
|
||||
s.blacklistPubkeys = [][]byte{blockedPubkey}
|
||||
|
||||
// Should still accept when author not in blacklist
|
||||
accept, notice, _ = s.AcceptEvent(ctx, testEvent, req, nil, "127.0.0.1")
|
||||
if !accept {
|
||||
t.Error("AcceptEvent() accept = false, want true")
|
||||
}
|
||||
if notice != "" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want empty string", notice)
|
||||
}
|
||||
|
||||
// Create event with blocked pubkey
|
||||
blockedEvent := &event.E{}
|
||||
blockedEvent.Pubkey = blockedPubkey
|
||||
|
||||
// Should reject when author is in blacklist
|
||||
accept, notice, _ = s.AcceptEvent(ctx, blockedEvent, req, nil, "127.0.0.1")
|
||||
if accept {
|
||||
t.Error("AcceptEvent() accept = true, want false")
|
||||
}
|
||||
if notice != "event author is blacklisted" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want 'event author is blacklisted'", notice)
|
||||
}
|
||||
|
||||
// Test with auth required - blacklist should not apply
|
||||
s.C.AuthRequired = true
|
||||
accept, notice, _ = s.AcceptEvent(ctx, blockedEvent, req, nil, "127.0.0.1")
|
||||
if accept {
|
||||
t.Error("AcceptEvent() accept = true, want false")
|
||||
}
|
||||
if notice != "client isn't authed" {
|
||||
t.Errorf("AcceptEvent() notice = %v, want 'client isn't authed'", notice)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package relay
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sort"
|
||||
|
||||
"orly.dev/pkg/interfaces/relay"
|
||||
"orly.dev/pkg/protocol/relayinfo"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/version"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// HandleRelayInfo generates and returns a relay information document in JSON
|
||||
@@ -44,7 +45,7 @@ func (s *Server) HandleRelayInfo(w http.ResponseWriter, r *http.Request) {
|
||||
// relayinfo.CommandResults,
|
||||
relayinfo.ParameterizedReplaceableEvents,
|
||||
// relayinfo.ExpirationTimestamp,
|
||||
// relayinfo.ProtectedEvents,
|
||||
relayinfo.ProtectedEvents,
|
||||
// relayinfo.RelayListMetadata,
|
||||
)
|
||||
sort.Sort(supportedNIPs)
|
||||
@@ -52,8 +53,9 @@ func (s *Server) HandleRelayInfo(w http.ResponseWriter, r *http.Request) {
|
||||
info = &relayinfo.T{
|
||||
Name: s.relay.Name(),
|
||||
Description: version.Description,
|
||||
Nips: supportedNIPs, Software: version.URL,
|
||||
Version: version.V,
|
||||
Nips: supportedNIPs,
|
||||
Software: version.URL,
|
||||
Version: version.V,
|
||||
Limitation: relayinfo.Limits{
|
||||
AuthRequired: s.C.AuthRequired,
|
||||
RestrictedWrites: s.C.AuthRequired,
|
||||
|
||||
@@ -8,41 +8,41 @@ import (
|
||||
func TestLists_OwnersPubkeys(t *testing.T) {
|
||||
// Create a new Lists instance
|
||||
l := &Lists{}
|
||||
|
||||
|
||||
// Test with empty list
|
||||
pks := l.OwnersPubkeys()
|
||||
if len(pks) != 0 {
|
||||
t.Errorf("Expected empty list, got %d items", len(pks))
|
||||
}
|
||||
|
||||
|
||||
// Test with some pubkeys
|
||||
testPubkeys := [][]byte{
|
||||
[]byte("pubkey1"),
|
||||
[]byte("pubkey2"),
|
||||
[]byte("pubkey3"),
|
||||
}
|
||||
|
||||
|
||||
l.SetOwnersPubkeys(testPubkeys)
|
||||
|
||||
|
||||
// Verify length
|
||||
if l.LenOwnersPubkeys() != len(testPubkeys) {
|
||||
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenOwnersPubkeys())
|
||||
}
|
||||
|
||||
|
||||
// Verify content
|
||||
pks = l.OwnersPubkeys()
|
||||
if len(pks) != len(testPubkeys) {
|
||||
t.Errorf("Expected %d pubkeys, got %d", len(testPubkeys), len(pks))
|
||||
}
|
||||
|
||||
|
||||
// Verify each pubkey
|
||||
for i, pk := range pks {
|
||||
if !bytes.Equal(pk, testPubkeys[i]) {
|
||||
t.Errorf("Pubkey at index %d doesn't match: expected %s, got %s",
|
||||
t.Errorf("Pubkey at index %d doesn't match: expected %s, got %s",
|
||||
i, testPubkeys[i], pk)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Verify that the returned slice is a copy, not a reference
|
||||
pks[0] = []byte("modified")
|
||||
newPks := l.OwnersPubkeys()
|
||||
@@ -54,37 +54,37 @@ func TestLists_OwnersPubkeys(t *testing.T) {
|
||||
func TestLists_OwnersFollowed(t *testing.T) {
|
||||
// Create a new Lists instance
|
||||
l := &Lists{}
|
||||
|
||||
|
||||
// Test with empty list
|
||||
followed := l.OwnersFollowed()
|
||||
if len(followed) != 0 {
|
||||
t.Errorf("Expected empty list, got %d items", len(followed))
|
||||
}
|
||||
|
||||
|
||||
// Test with some pubkeys
|
||||
testPubkeys := [][]byte{
|
||||
[]byte("followed1"),
|
||||
[]byte("followed2"),
|
||||
[]byte("followed3"),
|
||||
}
|
||||
|
||||
|
||||
l.SetOwnersFollowed(testPubkeys)
|
||||
|
||||
|
||||
// Verify length
|
||||
if l.LenOwnersFollowed() != len(testPubkeys) {
|
||||
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenOwnersFollowed())
|
||||
}
|
||||
|
||||
|
||||
// Verify content
|
||||
followed = l.OwnersFollowed()
|
||||
if len(followed) != len(testPubkeys) {
|
||||
t.Errorf("Expected %d followed, got %d", len(testPubkeys), len(followed))
|
||||
}
|
||||
|
||||
|
||||
// Verify each pubkey
|
||||
for i, pk := range followed {
|
||||
if !bytes.Equal(pk, testPubkeys[i]) {
|
||||
t.Errorf("Followed at index %d doesn't match: expected %s, got %s",
|
||||
t.Errorf("Followed at index %d doesn't match: expected %s, got %s",
|
||||
i, testPubkeys[i], pk)
|
||||
}
|
||||
}
|
||||
@@ -93,37 +93,37 @@ func TestLists_OwnersFollowed(t *testing.T) {
|
||||
func TestLists_FollowedFollows(t *testing.T) {
|
||||
// Create a new Lists instance
|
||||
l := &Lists{}
|
||||
|
||||
|
||||
// Test with empty list
|
||||
follows := l.FollowedFollows()
|
||||
if len(follows) != 0 {
|
||||
t.Errorf("Expected empty list, got %d items", len(follows))
|
||||
}
|
||||
|
||||
|
||||
// Test with some pubkeys
|
||||
testPubkeys := [][]byte{
|
||||
[]byte("follow1"),
|
||||
[]byte("follow2"),
|
||||
[]byte("follow3"),
|
||||
}
|
||||
|
||||
|
||||
l.SetFollowedFollows(testPubkeys)
|
||||
|
||||
|
||||
// Verify length
|
||||
if l.LenFollowedFollows() != len(testPubkeys) {
|
||||
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenFollowedFollows())
|
||||
}
|
||||
|
||||
|
||||
// Verify content
|
||||
follows = l.FollowedFollows()
|
||||
if len(follows) != len(testPubkeys) {
|
||||
t.Errorf("Expected %d follows, got %d", len(testPubkeys), len(follows))
|
||||
}
|
||||
|
||||
|
||||
// Verify each pubkey
|
||||
for i, pk := range follows {
|
||||
if !bytes.Equal(pk, testPubkeys[i]) {
|
||||
t.Errorf("Follow at index %d doesn't match: expected %s, got %s",
|
||||
t.Errorf("Follow at index %d doesn't match: expected %s, got %s",
|
||||
i, testPubkeys[i], pk)
|
||||
}
|
||||
}
|
||||
@@ -132,37 +132,37 @@ func TestLists_FollowedFollows(t *testing.T) {
|
||||
func TestLists_OwnersMuted(t *testing.T) {
|
||||
// Create a new Lists instance
|
||||
l := &Lists{}
|
||||
|
||||
|
||||
// Test with empty list
|
||||
muted := l.OwnersMuted()
|
||||
if len(muted) != 0 {
|
||||
t.Errorf("Expected empty list, got %d items", len(muted))
|
||||
}
|
||||
|
||||
|
||||
// Test with some pubkeys
|
||||
testPubkeys := [][]byte{
|
||||
[]byte("muted1"),
|
||||
[]byte("muted2"),
|
||||
[]byte("muted3"),
|
||||
}
|
||||
|
||||
|
||||
l.SetOwnersMuted(testPubkeys)
|
||||
|
||||
|
||||
// Verify length
|
||||
if l.LenOwnersMuted() != len(testPubkeys) {
|
||||
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenOwnersMuted())
|
||||
}
|
||||
|
||||
|
||||
// Verify content
|
||||
muted = l.OwnersMuted()
|
||||
if len(muted) != len(testPubkeys) {
|
||||
t.Errorf("Expected %d muted, got %d", len(testPubkeys), len(muted))
|
||||
}
|
||||
|
||||
|
||||
// Verify each pubkey
|
||||
for i, pk := range muted {
|
||||
if !bytes.Equal(pk, testPubkeys[i]) {
|
||||
t.Errorf("Muted at index %d doesn't match: expected %s, got %s",
|
||||
t.Errorf("Muted at index %d doesn't match: expected %s, got %s",
|
||||
i, testPubkeys[i], pk)
|
||||
}
|
||||
}
|
||||
@@ -171,10 +171,10 @@ func TestLists_OwnersMuted(t *testing.T) {
|
||||
func TestLists_ConcurrentAccess(t *testing.T) {
|
||||
// Create a new Lists instance
|
||||
l := &Lists{}
|
||||
|
||||
|
||||
// Test concurrent access to the lists
|
||||
done := make(chan bool)
|
||||
|
||||
|
||||
// Concurrent reads and writes
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
@@ -183,7 +183,7 @@ func TestLists_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
l.SetOwnersFollowed([][]byte{[]byte("followed1"), []byte("followed2")})
|
||||
@@ -191,7 +191,7 @@ func TestLists_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
l.SetFollowedFollows([][]byte{[]byte("follow1"), []byte("follow2")})
|
||||
@@ -199,7 +199,7 @@ func TestLists_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
l.SetOwnersMuted([][]byte{[]byte("muted1"), []byte("muted2")})
|
||||
@@ -207,11 +207,11 @@ func TestLists_ConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 4; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
|
||||
// If we got here without deadlocks or panics, the test passes
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,13 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"orly.dev/pkg/protocol/openapi"
|
||||
"orly.dev/pkg/protocol/socketapi"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/protocol/openapi"
|
||||
"orly.dev/pkg/protocol/socketapi"
|
||||
|
||||
"orly.dev/pkg/app/config"
|
||||
"orly.dev/pkg/app/relay/helpers"
|
||||
"orly.dev/pkg/app/relay/options"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"orly.dev/pkg/protocol/servemux"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/keys"
|
||||
"orly.dev/pkg/utils/log"
|
||||
|
||||
"github.com/rs/cors"
|
||||
@@ -29,14 +31,15 @@ import (
|
||||
// encapsulates various components such as context, cancel function, options,
|
||||
// relay interface, address, HTTP server, and configuration settings.
|
||||
type Server struct {
|
||||
Ctx context.T
|
||||
Cancel context.F
|
||||
options *options.T
|
||||
relay relay.I
|
||||
Addr string
|
||||
mux *servemux.S
|
||||
httpServer *http.Server
|
||||
listeners *publish.S
|
||||
Ctx context.T
|
||||
Cancel context.F
|
||||
options *options.T
|
||||
relay relay.I
|
||||
Addr string
|
||||
mux *servemux.S
|
||||
httpServer *http.Server
|
||||
listeners *publish.S
|
||||
blacklistPubkeys [][]byte
|
||||
*config.C
|
||||
*Lists
|
||||
*Peers
|
||||
@@ -105,6 +108,17 @@ func NewServer(
|
||||
Lists: new(Lists),
|
||||
Peers: new(Peers),
|
||||
}
|
||||
// Parse blacklist pubkeys
|
||||
for _, v := range s.C.Blacklist {
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
var pk []byte
|
||||
if pk, err = keys.DecodeNpubOrHex(v); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
s.blacklistPubkeys = append(s.blacklistPubkeys, pk)
|
||||
}
|
||||
chk.E(
|
||||
s.Peers.Init(sp.C.PeerRelays, sp.C.RelaySecret),
|
||||
)
|
||||
@@ -207,6 +221,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) Start(
|
||||
host string, port int, started ...chan bool,
|
||||
) (err error) {
|
||||
log.I.F("running spider every %v", s.C.SpiderTime)
|
||||
if len(s.C.Owners) > 0 {
|
||||
// start up spider
|
||||
if err = s.Spider(s.C.Private); chk.E(err) {
|
||||
@@ -216,7 +231,7 @@ func (s *Server) Start(
|
||||
}
|
||||
}
|
||||
// start up a spider run to trigger every 30 minutes
|
||||
ticker := time.NewTicker(time.Hour)
|
||||
ticker := time.NewTicker(s.C.SpiderTime)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/crypto/ec/schnorr"
|
||||
"orly.dev/pkg/database/indexes/types"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
@@ -14,8 +17,7 @@ import (
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
"orly.dev/pkg/utils/values"
|
||||
)
|
||||
|
||||
// IdPkTs is a map of event IDs to their id, pubkey, kind, and timestamp
|
||||
@@ -122,9 +124,9 @@ func (s *Server) SpiderFetch(
|
||||
l := &lim
|
||||
var since *timestamp.T
|
||||
if k == nil {
|
||||
since = timestamp.FromTime(time.Now().Add(-1 * time.Hour))
|
||||
since = timestamp.FromTime(time.Now().Add(-1 * s.C.SpiderTime * 3 / 2))
|
||||
} else {
|
||||
l = nil
|
||||
l = values.ToUintPointer(512)
|
||||
}
|
||||
batchFilter := &filter.F{
|
||||
Kinds: k,
|
||||
@@ -141,14 +143,10 @@ func (s *Server) SpiderFetch(
|
||||
var evss event.S
|
||||
var cli *ws.Client
|
||||
if cli, err = ws.RelayConnect(
|
||||
context.Bg(), seed, ws.WithSignatureChecker(
|
||||
func(e *event.E) bool {
|
||||
return true
|
||||
},
|
||||
),
|
||||
context.Bg(), seed,
|
||||
); chk.E(err) {
|
||||
err = nil
|
||||
return
|
||||
continue
|
||||
}
|
||||
if evss, err = cli.QuerySync(
|
||||
context.Bg(), batchFilter,
|
||||
|
||||
@@ -103,13 +103,32 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
|
||||
if s.C.SpiderType == "directory" {
|
||||
k = kinds.New(
|
||||
kind.ProfileMetadata, kind.RelayListMetadata,
|
||||
kind.DMRelaysList,
|
||||
kind.DMRelaysList, kind.MuteList,
|
||||
)
|
||||
}
|
||||
everyone := append(ownersFollowed, followedFollows...)
|
||||
everyone := ownersFollowed
|
||||
if s.C.SpiderSecondDegree &&
|
||||
(s.C.SpiderType == "follows" ||
|
||||
s.C.SpiderType == "directory") {
|
||||
everyone = append(ownersFollowed, followedFollows...)
|
||||
}
|
||||
_, _ = s.SpiderFetch(
|
||||
k, false, true, everyone...,
|
||||
)
|
||||
// get the directory events also for second degree if spider
|
||||
// type is directory but second degree is disabled, so all
|
||||
// directory data is available for all whitelisted users.
|
||||
if !s.C.SpiderSecondDegree && s.C.SpiderType == "directory" {
|
||||
k = kinds.New(
|
||||
kind.ProfileMetadata, kind.RelayListMetadata,
|
||||
kind.DMRelaysList, kind.MuteList,
|
||||
)
|
||||
everyone = append(ownersFollowed, followedFollows...)
|
||||
_, _ = s.SpiderFetch(
|
||||
k, false, true, everyone...,
|
||||
)
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestBech32(t *testing.T) {
|
||||
{
|
||||
"split1cheo2y9e2w",
|
||||
ErrNonCharsetChar('o'),
|
||||
}, // invalid character (o) in data part
|
||||
}, // invalid character (o) in data part
|
||||
{"split1a2y9w", ErrInvalidSeparatorIndex(5)}, // too short data part
|
||||
{
|
||||
"1checkupstagehandshakeupstreamerranterredcaperred2y9e3w",
|
||||
|
||||
@@ -87,9 +87,9 @@ type nonceAggValidCase struct {
|
||||
}
|
||||
|
||||
type nonceAggInvalidCase struct {
|
||||
Indices []int `json:"pnonce_indices"`
|
||||
Error nonceAggError `json:"error"`
|
||||
Comment string `json:"comment"`
|
||||
Indices []int `json:"pnonce_indices"`
|
||||
Error nonceAggError `json:"error"`
|
||||
Comment string `json:"comment"`
|
||||
ExpectedErr string `json:"btcec_err"`
|
||||
}
|
||||
|
||||
|
||||
@@ -158,8 +158,8 @@ func Decrypt(b64ciphertextWrapped, conversationKey []byte) (
|
||||
return
|
||||
}
|
||||
|
||||
// GenerateConversationKey performs an ECDH key generation hashed with the nip-44-v2 using hkdf.
|
||||
func GenerateConversationKey(pkh, skh string) (ck []byte, err error) {
|
||||
// GenerateConversationKeyFromHex performs an ECDH key generation hashed with the nip-44-v2 using hkdf.
|
||||
func GenerateConversationKeyFromHex(pkh, skh string) (ck []byte, err error) {
|
||||
if skh >= "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" ||
|
||||
skh == "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||
err = errorf.E(
|
||||
@@ -184,6 +184,17 @@ func GenerateConversationKey(pkh, skh string) (ck []byte, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateConversationKeyWithSigner(sign signer.I, pk []byte) (
|
||||
ck []byte, err error,
|
||||
) {
|
||||
var shared []byte
|
||||
if shared, err = sign.ECDH(pk); chk.E(err) {
|
||||
return
|
||||
}
|
||||
ck = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))
|
||||
return
|
||||
}
|
||||
|
||||
func encrypt(key, nonce, message []byte) (dst []byte, err error) {
|
||||
var cipher *chacha20.Cipher
|
||||
if cipher, err = chacha20.NewUnauthenticatedCipher(key, nonce); chk.E(err) {
|
||||
|
||||
@@ -47,7 +47,9 @@ func assertCryptPriv(
|
||||
return
|
||||
}
|
||||
expectedBytes = []byte(expected)
|
||||
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
|
||||
if ok = assert.Equalf(
|
||||
t, string(expectedBytes), string(actualBytes), "wrong encryption",
|
||||
); !ok {
|
||||
return
|
||||
}
|
||||
decrypted, err = Decrypt(expectedBytes, k1)
|
||||
@@ -62,8 +64,8 @@ func assertDecryptFail(
|
||||
) {
|
||||
var (
|
||||
k1, ciphertextBytes []byte
|
||||
ok bool
|
||||
err error
|
||||
ok bool
|
||||
err error
|
||||
)
|
||||
k1, err = hex.Dec(conversationKey)
|
||||
if ok = assert.NoErrorf(
|
||||
@@ -79,7 +81,7 @@ func assertDecryptFail(
|
||||
func assertConversationKeyFail(
|
||||
t *testing.T, priv string, pub string, msg string,
|
||||
) {
|
||||
_, err := GenerateConversationKey(pub, priv)
|
||||
_, err := GenerateConversationKeyFromHex(pub, priv)
|
||||
assert.ErrorContains(t, err, msg)
|
||||
}
|
||||
|
||||
@@ -98,7 +100,7 @@ func assertConversationKeyGeneration(
|
||||
); !ok {
|
||||
return false
|
||||
}
|
||||
actualConversationKey, err = GenerateConversationKey(pub, priv)
|
||||
actualConversationKey, err = GenerateConversationKeyFromHex(pub, priv)
|
||||
if ok = assert.NoErrorf(
|
||||
t, err, "conversation key generation failed: %v", err,
|
||||
); !ok {
|
||||
@@ -1312,7 +1314,7 @@ func TestMaxLength(t *testing.T) {
|
||||
pub2, _ := keys.GetPublicKeyHex(string(sk2))
|
||||
salt := make([]byte, 32)
|
||||
rand.Read(salt)
|
||||
conversationKey, _ := GenerateConversationKey(pub2, string(sk1))
|
||||
conversationKey, _ := GenerateConversationKeyFromHex(pub2, string(sk1))
|
||||
plaintext := strings.Repeat("a", MaxPlaintextSize)
|
||||
plaintextBytes := []byte(plaintext)
|
||||
encrypted, err := Encrypt(
|
||||
@@ -1366,7 +1368,9 @@ func assertCryptPub(
|
||||
return
|
||||
}
|
||||
expectedBytes = []byte(expected)
|
||||
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
|
||||
if ok = assert.Equalf(
|
||||
t, string(expectedBytes), string(actualBytes), "wrong encryption",
|
||||
); !ok {
|
||||
return
|
||||
}
|
||||
decrypted, err = Decrypt(expectedBytes, k1)
|
||||
|
||||
@@ -33,7 +33,7 @@ func NewPubFromHex[V []byte | string](pkh V) (sign signer.I, err error) {
|
||||
}
|
||||
|
||||
func HexToBin(hexStr string) (b []byte, err error) {
|
||||
if _, err = hex.DecBytes(b, []byte(hexStr)); chk.E(err) {
|
||||
if b, err = hex.DecAppend(b, []byte(hexStr)); chk.E(err) {
|
||||
return
|
||||
}
|
||||
return
|
||||
|
||||
@@ -2,16 +2,16 @@ package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/database/indexes"
|
||||
types2 "orly.dev/pkg/database/indexes/types"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/encoders/tags"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
@@ -26,8 +26,7 @@ func TestGetIndexesForEvent(t *testing.T) {
|
||||
// indexes
|
||||
func verifyIndexIncluded(t *testing.T, idxs [][]byte, expectedIdx *indexes.T) {
|
||||
// Marshal the expected index
|
||||
buf := codecbuf.Get()
|
||||
defer codecbuf.Put(buf)
|
||||
buf := new(bytes.Buffer)
|
||||
err := expectedIdx.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("Failed to marshal expected index: %v", err)
|
||||
|
||||
@@ -3,16 +3,16 @@ package database
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/database/indexes"
|
||||
types2 "orly.dev/pkg/database/indexes/types"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/encoders/filter"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/kinds"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
@@ -41,8 +41,7 @@ func verifyIndex(
|
||||
}
|
||||
|
||||
// Marshal the expected start index
|
||||
startBuf := codecbuf.Get()
|
||||
defer codecbuf.Put(startBuf)
|
||||
startBuf := new(bytes.Buffer)
|
||||
err := expectedStartIdx.MarshalWrite(startBuf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("Failed to marshal expected start index: %v", err)
|
||||
@@ -62,8 +61,7 @@ func verifyIndex(
|
||||
}
|
||||
|
||||
// Marshal the expected end index
|
||||
endBuf := codecbuf.Get()
|
||||
defer codecbuf.Put(endBuf)
|
||||
endBuf := new(bytes.Buffer)
|
||||
err = endIdx.MarshalWrite(endBuf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("Failed to marshal expected End index: %v", err)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"orly.dev/pkg/database/indexes/types"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
)
|
||||
@@ -49,7 +48,7 @@ func TestPrefixMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite method
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := prefix.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -209,7 +208,7 @@ func TestTStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -272,7 +271,7 @@ func TestEventFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -318,7 +317,7 @@ func TestIdFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -391,7 +390,7 @@ func TestIdPubkeyFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -452,7 +451,7 @@ func TestCreatedAtFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -516,7 +515,7 @@ func TestPubkeyFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -588,7 +587,7 @@ func TestPubkeyTagFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -660,7 +659,7 @@ func TestTagFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -724,7 +723,7 @@ func TestKindFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -789,7 +788,7 @@ func TestKindTagFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -865,7 +864,7 @@ func TestKindPubkeyFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -941,7 +940,7 @@ func TestKindPubkeyTagFunctions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test marshaling and unmarshaling
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = enc.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -84,7 +84,7 @@ func testUint16Sorting(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint16 values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint16 values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -115,7 +115,7 @@ func testUint24Sorting(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint24 values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint24 values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func testUint32Sorting(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint32 values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint32 values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func testUint40Sorting(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint40 values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint40 values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -202,7 +202,7 @@ func testUint64Sorting(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint64 values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint64 values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -233,7 +233,7 @@ func testUint16EdgeCases(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint16 edge case values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint16 edge case values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -265,7 +265,7 @@ func testUint24EdgeCases(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint24 edge case values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint24 edge case values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -294,7 +294,7 @@ func testUint32EdgeCases(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint32 edge case values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint32 edge case values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -326,7 +326,7 @@ func testUint40EdgeCases(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint40 edge case values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint40 edge case values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -355,7 +355,7 @@ func testUint64EdgeCases(t *testing.T) {
|
||||
// Check if they sort correctly with bytes.Compare
|
||||
for i := 0; i < len(marshaledValues)-1; i++ {
|
||||
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
|
||||
t.Errorf("Uint64 edge case values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("Uint64 edge case values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
|
||||
}
|
||||
@@ -390,7 +390,7 @@ func TestEndianness(t *testing.T) {
|
||||
result := bytes.Compare(bigEndianValues[i], bigEndianValues[i+1])
|
||||
t.Logf("Compare %d with %d: result = %d", values[i], values[i+1], result)
|
||||
if result >= 0 {
|
||||
t.Errorf("BigEndian values don't sort correctly: %v should be less than %v",
|
||||
t.Errorf("BigEndian values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", bigEndianValues[i], bigEndianValues[i+1])
|
||||
}
|
||||
@@ -404,7 +404,7 @@ func TestEndianness(t *testing.T) {
|
||||
t.Logf("Compare %d with %d: result = %d", values[i], values[i+1], result)
|
||||
if result >= 0 {
|
||||
correctOrder = false
|
||||
t.Logf("LittleEndian values don't sort correctly: %v should be less than %v",
|
||||
t.Logf("LittleEndian values don't sort correctly: %v should be less than %v",
|
||||
values[i], values[i+1])
|
||||
t.Logf("Bytes representation: %v vs %v", littleEndianValues[i], littleEndianValues[i+1])
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestIdMarshalWriteUnmarshalRead(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = fi1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -2,10 +2,10 @@ package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestIdent_MarshalWriteUnmarshalRead(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = i1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -3,10 +3,10 @@ package types
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/encoders/hex"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
@@ -142,7 +142,7 @@ func TestIdHashMarshalWriteUnmarshalRead(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = i1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -2,9 +2,9 @@ package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
func TestLetter_New(t *testing.T) {
|
||||
@@ -53,7 +53,7 @@ func TestLetter_MarshalWriteUnmarshalRead(t *testing.T) {
|
||||
l1 := new(Letter)
|
||||
l1.Set('A')
|
||||
// Test MarshalWrite
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := l1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -2,11 +2,11 @@ package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/crypto/ec/schnorr"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/encoders/hex"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
)
|
||||
@@ -105,7 +105,7 @@ func TestPubHash_MarshalWriteUnmarshalRead(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err = ph1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -2,10 +2,10 @@ package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
func TestTimestamp_FromInt(t *testing.T) {
|
||||
@@ -89,7 +89,7 @@ func TestTimestamp_FromBytes(t *testing.T) {
|
||||
v.Set(12345)
|
||||
|
||||
// Marshal it to bytes
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := v.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -163,7 +163,7 @@ func TestTimestamp_Bytes(t *testing.T) {
|
||||
func TestTimestamp_MarshalWriteUnmarshalRead(t *testing.T) {
|
||||
// Test with a positive value
|
||||
ts1 := &Timestamp{val: 12345}
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := ts1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -183,7 +183,7 @@ func TestTimestamp_MarshalWriteUnmarshalRead(t *testing.T) {
|
||||
|
||||
// Test with a negative value
|
||||
ts1 = &Timestamp{val: -12345}
|
||||
buf = codecbuf.Get()
|
||||
buf = new(bytes.Buffer)
|
||||
err = ts1.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
@@ -225,7 +225,7 @@ func TestTimestamp_WithCurrentTime(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite and UnmarshalRead
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
err := ts.MarshalWrite(buf)
|
||||
if chk.E(err) {
|
||||
t.Fatalf("MarshalWrite failed: %v", err)
|
||||
|
||||
@@ -3,11 +3,11 @@ package types
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
|
||||
"lukechampine.com/frand"
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestUint16(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test encoding to []byte and decoding back
|
||||
bufEnc := codecbuf.Get()
|
||||
bufEnc := new(bytes.Buffer)
|
||||
|
||||
// MarshalWrite
|
||||
err := encodedUint16.MarshalWrite(bufEnc)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
func TestUint24(t *testing.T) {
|
||||
@@ -45,7 +46,7 @@ func TestUint24(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite and UnmarshalRead
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// MarshalWrite directly to the buffer
|
||||
if err := codec.MarshalWrite(buf); chk.E(err) {
|
||||
|
||||
@@ -3,11 +3,11 @@ package types
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
|
||||
"lukechampine.com/frand"
|
||||
)
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestUint32(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test encoding to []byte and decoding back
|
||||
bufEnc := codecbuf.Get()
|
||||
bufEnc := new(bytes.Buffer)
|
||||
|
||||
// MarshalWrite
|
||||
err := codec.MarshalWrite(bufEnc)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
func TestUint40(t *testing.T) {
|
||||
@@ -48,7 +49,7 @@ func TestUint40(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test MarshalWrite and UnmarshalRead
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// Marshal to a buffer
|
||||
if err = codec.MarshalWrite(buf); chk.E(err) {
|
||||
|
||||
@@ -3,11 +3,11 @@ package types
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
|
||||
"lukechampine.com/frand"
|
||||
)
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestUint64(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test encoding to []byte and decoding back
|
||||
bufEnc := codecbuf.Get()
|
||||
bufEnc := new(bytes.Buffer)
|
||||
|
||||
// MarshalWrite
|
||||
err := codec.MarshalWrite(bufEnc)
|
||||
|
||||
@@ -4,6 +4,7 @@ package reqenvelope
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"orly.dev/pkg/encoders/envelopes"
|
||||
"orly.dev/pkg/encoders/filters"
|
||||
"orly.dev/pkg/encoders/subscription"
|
||||
@@ -37,10 +38,21 @@ func New() *T {
|
||||
|
||||
// NewFrom creates a new reqenvelope.T with a provided subscription.Id and
|
||||
// filters.T.
|
||||
func NewFrom(id *subscription.Id, filters *filters.T) *T {
|
||||
func NewFrom(id *subscription.Id, ff *filters.T) *T {
|
||||
return &T{
|
||||
Subscription: id,
|
||||
Filters: filters,
|
||||
Filters: ff,
|
||||
}
|
||||
}
|
||||
|
||||
func NewWithIdString(id string, ff *filters.T) (sub *T) {
|
||||
sid, err := subscription.NewId(id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return &T{
|
||||
Subscription: sid,
|
||||
Filters: ff,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@ package event
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/encoders/event/examples"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/encoders/event/examples"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
|
||||
@@ -19,7 +19,7 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
|
||||
var counter int
|
||||
for scanner.Scan() {
|
||||
// Create new event objects and buffer for each iteration
|
||||
buf := codecbuf.Get()
|
||||
buf := new(bytes.Buffer)
|
||||
ea, eb := New(), New()
|
||||
|
||||
chk.E(scanner.Err())
|
||||
@@ -42,7 +42,6 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
|
||||
// Create a new buffer for unmarshaling
|
||||
buf2 := bytes.NewBuffer(buf.Bytes())
|
||||
if err = eb.UnmarshalBinary(buf2); chk.E(err) {
|
||||
codecbuf.Put(buf)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -57,9 +56,6 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
// Return buffer to pool
|
||||
codecbuf.Put(buf)
|
||||
|
||||
counter++
|
||||
out = out[:0]
|
||||
}
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/crypto/p256k"
|
||||
"orly.dev/pkg/encoders/event/examples"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTMarshal_Unmarshal(t *testing.T) {
|
||||
|
||||
@@ -2,12 +2,13 @@ package event
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/encoders/tags"
|
||||
text2 "orly.dev/pkg/encoders/text"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// compareTags compares two tags and reports any differences
|
||||
@@ -96,7 +97,8 @@ func TestUnmarshalEscapedJSONInTags(t *testing.T) {
|
||||
unmarshaledTag := unmarshaledEvent.Tags.GetTagElement(0)
|
||||
if unmarshaledTag.Len() != 2 {
|
||||
t.Fatalf(
|
||||
"Expected tag with 2 elements, got %d", unmarshaledTag.Len(),
|
||||
"Expected tag with 2 elements, got %d",
|
||||
unmarshaledTag.Len(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ package event
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/encoders/hex"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/tags"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// compareEvents compares two events and reports any differences
|
||||
|
||||
@@ -7,6 +7,8 @@ package filter
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"sort"
|
||||
|
||||
"orly.dev/pkg/crypto/ec/schnorr"
|
||||
"orly.dev/pkg/crypto/ec/secp256k1"
|
||||
"orly.dev/pkg/crypto/sha256"
|
||||
@@ -21,8 +23,8 @@ import (
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/utils/pointers"
|
||||
"sort"
|
||||
|
||||
"lukechampine.com/frand"
|
||||
)
|
||||
@@ -181,12 +183,12 @@ func (f *F) Marshal(dst []byte) (b []byte) {
|
||||
dst = append(dst, '[')
|
||||
for i, value := range values {
|
||||
dst = append(dst, '"')
|
||||
if tKey[1] == 'e' || tKey[1] == 'p' {
|
||||
// event and pubkey tags are binary 32 bytes
|
||||
dst = hex.EncAppend(dst, value)
|
||||
} else {
|
||||
dst = append(dst, value...)
|
||||
}
|
||||
// if tKey[1] == 'e' || tKey[1] == 'p' {
|
||||
// // event and pubkey tags are binary 32 bytes
|
||||
// dst = hex.EncAppend(dst, value)
|
||||
// } else {
|
||||
dst = append(dst, value...)
|
||||
// }
|
||||
dst = append(dst, '"')
|
||||
if i < len(values)-1 {
|
||||
dst = append(dst, ',')
|
||||
@@ -300,29 +302,29 @@ func (f *F) Unmarshal(b []byte) (r []byte, err error) {
|
||||
}
|
||||
k := make([]byte, len(key))
|
||||
copy(k, key)
|
||||
switch key[1] {
|
||||
case 'e', 'p':
|
||||
// the tags must all be 64 character hexadecimal
|
||||
var ff [][]byte
|
||||
if ff, r, err = text2.UnmarshalHexArray(
|
||||
r,
|
||||
sha256.Size,
|
||||
); chk.E(err) {
|
||||
return
|
||||
}
|
||||
ff = append([][]byte{k}, ff...)
|
||||
f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
|
||||
// f.Tags.F = append(f.Tags.F, tag.New(ff...))
|
||||
default:
|
||||
// other types of tags can be anything
|
||||
var ff [][]byte
|
||||
if ff, r, err = text2.UnmarshalStringArray(r); chk.E(err) {
|
||||
return
|
||||
}
|
||||
ff = append([][]byte{k}, ff...)
|
||||
f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
|
||||
// f.Tags.F = append(f.Tags.F, tag.New(ff...))
|
||||
// switch key[1] {
|
||||
// case 'e', 'p':
|
||||
// // the tags must all be 64 character hexadecimal
|
||||
// var ff [][]byte
|
||||
// if ff, r, err = text2.UnmarshalHexArray(
|
||||
// r,
|
||||
// sha256.Size,
|
||||
// ); chk.E(err) {
|
||||
// return
|
||||
// }
|
||||
// ff = append([][]byte{k}, ff...)
|
||||
// f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
|
||||
// // f.Tags.F = append(f.Tags.F, tag.New(ff...))
|
||||
// default:
|
||||
// other types of tags can be anything
|
||||
var ff [][]byte
|
||||
if ff, r, err = text2.UnmarshalStringArray(r); chk.E(err) {
|
||||
return
|
||||
}
|
||||
ff = append([][]byte{k}, ff...)
|
||||
f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
|
||||
// f.Tags.F = append(f.Tags.F, tag.New(ff...))
|
||||
// }
|
||||
state = betweenKV
|
||||
case IDs[0]:
|
||||
if len(key) < len(IDs) {
|
||||
@@ -440,43 +442,55 @@ invalid:
|
||||
return
|
||||
}
|
||||
|
||||
// Matches checks a filter against an event and determines if the event matches the filter.
|
||||
func (f *F) Matches(ev *event.E) bool {
|
||||
// MatchesIgnoringTimestampConstraints checks a filter against an event and
|
||||
// determines if the event matches the filter, ignoring timestamp constraints..
|
||||
func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool {
|
||||
if ev == nil {
|
||||
// log.F.ToSliceOfBytes("nil event")
|
||||
log.I.F("nil event")
|
||||
return false
|
||||
}
|
||||
if f.Ids.Len() > 0 && !f.Ids.Contains(ev.ID) {
|
||||
// log.F.ToSliceOfBytes("no ids in filter match event\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
|
||||
log.I.F("no ids in filter match event")
|
||||
return false
|
||||
}
|
||||
if f.Kinds.Len() > 0 && !f.Kinds.Contains(ev.Kind) {
|
||||
// log.F.ToSliceOfBytes("no matching kinds in filter\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
|
||||
log.I.F(
|
||||
"no matching kinds in filter",
|
||||
)
|
||||
return false
|
||||
}
|
||||
if f.Authors.Len() > 0 && !f.Authors.Contains(ev.Pubkey) {
|
||||
// log.F.ToSliceOfBytes("no matching authors in filter\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
|
||||
log.I.F("no matching authors in filter")
|
||||
return false
|
||||
}
|
||||
if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) {
|
||||
return false
|
||||
}
|
||||
// if f.Tags.Len() > 0 {
|
||||
// for _, v := range f.Tags.ToSliceOfTags() {
|
||||
// tvs := v.ToSliceOfBytes()
|
||||
// if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) {
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
// return false
|
||||
// if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) {
|
||||
// return false
|
||||
// }
|
||||
if f.Tags.Len() > 0 {
|
||||
for _, v := range f.Tags.ToSliceOfTags() {
|
||||
tvs := v.ToSliceOfBytes()
|
||||
if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) {
|
||||
log.I.F("no matching tags in filter")
|
||||
return false
|
||||
}
|
||||
}
|
||||
// return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Matches checks a filter against an event and determines if the event matches the filter.
|
||||
func (f *F) Matches(ev *event.E) (match bool) {
|
||||
if !f.MatchesIgnoringTimestampConstraints(ev) {
|
||||
return
|
||||
}
|
||||
if f.Since.Int() != 0 && ev.CreatedAt.I64() < f.Since.I64() {
|
||||
// log.F.ToSliceOfBytes("event is older than since\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
|
||||
return false
|
||||
log.I.F("event is older than since")
|
||||
return
|
||||
}
|
||||
if f.Until.Int() != 0 && ev.CreatedAt.I64() > f.Until.I64() {
|
||||
// log.F.ToSliceOfBytes("event is newer than until\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
|
||||
return false
|
||||
log.I.F("event is newer than until")
|
||||
return
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -120,3 +120,12 @@ func GenFilters(n int) (ff *T, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f *T) MatchIgnoringTimestampConstraints(ev *event.E) bool {
|
||||
for _, ff := range f.F {
|
||||
if ff.MatchesIgnoringTimestampConstraints(ev) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package hex
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/templexxx/xhex"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
package kind
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"orly.dev/pkg/encoders/ints"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
@@ -71,6 +72,8 @@ var Privileged = []*T{
|
||||
GiftWrapWithKind4,
|
||||
JWTBinding,
|
||||
ApplicationSpecificData,
|
||||
Seal,
|
||||
PrivateDirectMessage,
|
||||
}
|
||||
|
||||
// IsPrivileged returns true if the type is the kind of message nobody else than the pubkeys in
|
||||
@@ -260,11 +263,11 @@ var (
|
||||
FileStorageServerList = &T{10096}
|
||||
// JWTBinding is an event kind that creates a link between a JWT certificate and a pubkey
|
||||
JWTBinding = &T{13004}
|
||||
// NWCWalletInfo is an event type that...
|
||||
NWCWalletInfo = &T{13194}
|
||||
WalletInfo = NWCWalletInfo
|
||||
// NWCWalletServiceInfo is an event type that...
|
||||
NWCWalletServiceInfo = &T{13194}
|
||||
WalletServiceInfo = &T{13194}
|
||||
// ReplaceableEnd is an event type that...
|
||||
ReplaceableEnd = &T{20000}
|
||||
ReplaceableEnd = &T{19999}
|
||||
// EphemeralStart is an event type that...
|
||||
EphemeralStart = &T{20000}
|
||||
LightningPubRPC = &T{21000}
|
||||
@@ -274,15 +277,16 @@ var (
|
||||
NWCWalletRequest = &T{23194}
|
||||
WalletRequest = &T{23194}
|
||||
// NWCWalletResponse is an event type that...
|
||||
NWCWalletResponse = &T{23195}
|
||||
WalletResponse = NWCWalletResponse
|
||||
NWCNotification = &T{23196}
|
||||
WalletNotification = NWCNotification
|
||||
NWCWalletResponse = &T{23195}
|
||||
WalletResponse = &T{23195}
|
||||
NWCNotification = &T{23196}
|
||||
WalletNotificationNip4 = &T{23196}
|
||||
WalletNotification = &T{23197}
|
||||
// NostrConnect is an event type that...
|
||||
NostrConnect = &T{24133}
|
||||
HTTPAuth = &T{27235}
|
||||
// EphemeralEnd is an event type that...
|
||||
EphemeralEnd = &T{30000}
|
||||
EphemeralEnd = &T{29999}
|
||||
// ParameterizedReplaceableStart is an event type that...
|
||||
ParameterizedReplaceableStart = &T{30000}
|
||||
// CategorizedPeopleList is an event type that...
|
||||
@@ -329,7 +333,7 @@ var (
|
||||
CommunityDefinition = &T{34550}
|
||||
ACLEvent = &T{39998}
|
||||
// ParameterizedReplaceableEnd is an event type that...
|
||||
ParameterizedReplaceableEnd = &T{40000}
|
||||
ParameterizedReplaceableEnd = &T{39999}
|
||||
)
|
||||
|
||||
var MapMx sync.Mutex
|
||||
@@ -380,11 +384,12 @@ var Map = map[uint16]string{
|
||||
UserEmojiList.K: "UserEmojiList",
|
||||
DMRelaysList.K: "DMRelaysList",
|
||||
FileStorageServerList.K: "FileStorageServerList",
|
||||
NWCWalletInfo.K: "NWCWalletInfo",
|
||||
NWCWalletServiceInfo.K: "NWCWalletServiceInfo",
|
||||
LightningPubRPC.K: "LightningPubRPC",
|
||||
ClientAuthentication.K: "ClientAuthentication",
|
||||
WalletRequest.K: "WalletRequest",
|
||||
WalletResponse.K: "WalletResponse",
|
||||
WalletNotificationNip4.K: "WalletNotificationNip4",
|
||||
WalletNotification.K: "WalletNotification",
|
||||
NostrConnect.K: "NostrConnect",
|
||||
HTTPAuth.K: "HTTPAuth",
|
||||
|
||||
@@ -5,6 +5,7 @@ package subscription
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"orly.dev/pkg/crypto/ec/bech32"
|
||||
"orly.dev/pkg/encoders/text"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
@@ -24,7 +25,7 @@ func (si *Id) IsValid() bool { return len(si.T) <= 64 && len(si.T) > 0 }
|
||||
|
||||
// NewId inspects a string and converts to Id if it is
|
||||
// valid. Invalid means length == 0 or length > 64.
|
||||
func NewId[V string | []byte](s V) (*Id, error) {
|
||||
func NewId[V ~string | ~[]byte](s V) (*Id, error) {
|
||||
si := &Id{T: []byte(s)}
|
||||
if si.IsValid() {
|
||||
return si, nil
|
||||
@@ -40,7 +41,7 @@ func NewId[V string | []byte](s V) (*Id, error) {
|
||||
// MustNew is the same as NewId except it doesn't check if you feed it rubbish.
|
||||
//
|
||||
// DO NOT USE WITHOUT CHECKING THE Id IS NOT NIL AND > 0 AND <= 64
|
||||
func MustNew[V string | []byte](s V) *Id {
|
||||
func MustNew[V ~string | ~[]byte](s V) *Id {
|
||||
return &Id{T: []byte(s)}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ package tag
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
text2 "orly.dev/pkg/encoders/text"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"orly.dev/pkg/utils/log"
|
||||
@@ -26,7 +27,7 @@ const (
|
||||
)
|
||||
|
||||
// BS is an abstract data type that can process strings and byte slices as byte slices.
|
||||
type BS[Z []byte | string] []byte
|
||||
type BS[Z ~[]byte | ~string] []byte
|
||||
|
||||
// T is a list of strings with a literal ordering.
|
||||
//
|
||||
@@ -36,7 +37,7 @@ type T struct {
|
||||
}
|
||||
|
||||
// New creates a new tag.T from a variadic parameter that can be either string or byte slice.
|
||||
func New[V string | []byte](fields ...V) (t *T) {
|
||||
func New[V ~string | ~[]byte](fields ...V) (t *T) {
|
||||
t = &T{field: make([]BS[[]byte], len(fields))}
|
||||
for i, field := range fields {
|
||||
t.field[i] = []byte(field)
|
||||
|
||||
@@ -7,12 +7,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/utils/lol"
|
||||
"os"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// T is a list of tag.T - which are lists of string elements with ordering and no uniqueness
|
||||
@@ -161,6 +162,15 @@ func (t *T) GetFirst(tagPrefix *tag.T) *tag.T {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *T) GetD() (d string) {
|
||||
for _, v := range t.element {
|
||||
if bytes.Equal(v.Key(), []byte("d")) {
|
||||
return string(v.Value())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetLast gets the last tag in tags that matches the prefix, see [T.StartsWith]
|
||||
func (t *T) GetLast(tagPrefix *tag.T) *tag.T {
|
||||
for i := len(t.element) - 1; i >= 0; i-- {
|
||||
@@ -299,7 +309,7 @@ func (t *T) ContainsAny(tagName []byte, values *tag.T) bool {
|
||||
continue
|
||||
}
|
||||
for _, candidate := range values.ToSliceOfBytes() {
|
||||
if bytes.Equal(v.Value(), candidate) {
|
||||
if bytes.HasPrefix(v.Value(), candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
package varint
|
||||
|
||||
import (
|
||||
"golang.org/x/exp/constraints"
|
||||
"io"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ package varint
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"orly.dev/pkg/encoders/codecbuf"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"testing"
|
||||
|
||||
"orly.dev/pkg/utils/chk"
|
||||
|
||||
"lukechampine.com/frand"
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestEncode_Decode(t *testing.T) {
|
||||
var v uint64
|
||||
for range 10000000 {
|
||||
v = uint64(frand.Intn(math.MaxInt64))
|
||||
buf1 := codecbuf.Get()
|
||||
buf1 := new(bytes.Buffer)
|
||||
Encode(buf1, v)
|
||||
buf2 := bytes.NewBuffer(buf1.Bytes())
|
||||
u, err := Decode(buf2)
|
||||
|
||||
@@ -3,7 +3,8 @@ package nwc
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/crypto/encryption"
|
||||
"orly.dev/pkg/crypto/p256k"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
@@ -19,596 +20,140 @@ import (
|
||||
"orly.dev/pkg/protocol/ws"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"orly.dev/pkg/utils/values"
|
||||
)
|
||||
|
||||
// Options represents options for a NWC client
|
||||
type Options struct {
|
||||
RelayURL string
|
||||
Secret signer.I
|
||||
WalletPubkey []byte
|
||||
Lud16 string
|
||||
}
|
||||
|
||||
// Client represents a NWC client
|
||||
type Client struct {
|
||||
options Options
|
||||
relay *ws.Client
|
||||
mu sync.Mutex
|
||||
client *ws.Client
|
||||
relay string
|
||||
clientSecretKey signer.I
|
||||
walletPublicKey []byte
|
||||
conversationKey []byte // nip44
|
||||
}
|
||||
|
||||
// ParseWalletConnectURL parses a wallet connect URL
|
||||
func ParseWalletConnectURL(walletConnectURL string) (opts *Options, err error) {
|
||||
if !strings.HasPrefix(walletConnectURL, "nostr+walletconnect://") {
|
||||
return nil, fmt.Errorf("unexpected scheme. Should be nostr+walletconnect://")
|
||||
}
|
||||
// Parse URL
|
||||
colonIndex := strings.Index(walletConnectURL, ":")
|
||||
if colonIndex == -1 {
|
||||
err = fmt.Errorf("invalid URL format")
|
||||
return
|
||||
}
|
||||
walletConnectURL = walletConnectURL[colonIndex+1:]
|
||||
if strings.HasPrefix(walletConnectURL, "//") {
|
||||
walletConnectURL = walletConnectURL[2:]
|
||||
}
|
||||
walletConnectURL = "https://" + walletConnectURL
|
||||
var u *url.URL
|
||||
if u, err = url.Parse(walletConnectURL); chk.E(err) {
|
||||
err = fmt.Errorf("failed to parse URL: %w", err)
|
||||
return
|
||||
}
|
||||
// Get wallet pubkey
|
||||
walletPubkey := u.Host
|
||||
if len(walletPubkey) != 64 {
|
||||
err = fmt.Errorf("incorrect wallet pubkey found in auth string")
|
||||
return
|
||||
}
|
||||
var pk []byte
|
||||
if pk, err = hex.Dec(walletPubkey); chk.E(err) {
|
||||
err = fmt.Errorf("failed to decode pubkey: %w", err)
|
||||
return
|
||||
}
|
||||
// Get relay URL
|
||||
relayURL := u.Query().Get("relay")
|
||||
if relayURL == "" {
|
||||
return nil, fmt.Errorf("no relay URL found in auth string")
|
||||
}
|
||||
// Get secret
|
||||
secret := u.Query().Get("secret")
|
||||
if secret == "" {
|
||||
return nil, fmt.Errorf("no secret found in auth string")
|
||||
}
|
||||
var sk []byte
|
||||
if sk, err = hex.Dec(secret); chk.E(err) {
|
||||
return
|
||||
}
|
||||
sign := &p256k.Signer{}
|
||||
if err = sign.InitSec(sk); chk.E(err) {
|
||||
return
|
||||
}
|
||||
opts = &Options{
|
||||
RelayURL: relayURL,
|
||||
Secret: sign,
|
||||
WalletPubkey: pk,
|
||||
}
|
||||
return
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params"`
|
||||
}
|
||||
|
||||
// NewNWCClient creates a new NWC client
|
||||
func NewNWCClient(options *Options) (cl *Client, err error) {
|
||||
if options.RelayURL == "" {
|
||||
err = fmt.Errorf("missing relay URL")
|
||||
return
|
||||
}
|
||||
if options.Secret == nil {
|
||||
err = fmt.Errorf("missing secret")
|
||||
return
|
||||
}
|
||||
if options.WalletPubkey == nil {
|
||||
err = fmt.Errorf("missing wallet pubkey")
|
||||
return
|
||||
}
|
||||
return &Client{
|
||||
options: Options{
|
||||
RelayURL: options.RelayURL,
|
||||
Secret: options.Secret,
|
||||
WalletPubkey: options.WalletPubkey,
|
||||
Lud16: options.Lud16,
|
||||
},
|
||||
}, nil
|
||||
type ResponseError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NostrWalletConnectURL returns the nostr wallet connect URL
|
||||
func (c *Client) NostrWalletConnectURL() string {
|
||||
return c.GetNostrWalletConnectURL(true)
|
||||
func (err *ResponseError) Error() string {
|
||||
return fmt.Sprintf("%s %s", err.Code, err.Message)
|
||||
}
|
||||
|
||||
// GetNostrWalletConnectURL returns the nostr wallet connect URL
|
||||
func (c *Client) GetNostrWalletConnectURL(includeSecret bool) string {
|
||||
params := url.Values{}
|
||||
params.Add("relay", c.options.RelayURL)
|
||||
if includeSecret {
|
||||
params.Add("secret", hex.Enc(c.options.Secret.Sec()))
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"nostr+walletconnect://%s?%s", c.options.WalletPubkey, params.Encode(),
|
||||
)
|
||||
type Response struct {
|
||||
ResultType string `json:"result_type"`
|
||||
Error *ResponseError `json:"error"`
|
||||
Result any `json:"result"`
|
||||
}
|
||||
|
||||
// Connected returns whether the client is connected to the relay
|
||||
func (c *Client) Connected() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.relay != nil && c.relay.IsConnected()
|
||||
}
|
||||
|
||||
// GetPublicKey returns the client's public key
|
||||
func (c *Client) GetPublicKey() (pubkey []byte, err error) {
|
||||
pubkey = c.options.Secret.Pub()
|
||||
return
|
||||
}
|
||||
|
||||
// Close closes the relay connection
|
||||
func (c *Client) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.relay != nil {
|
||||
c.relay.Close()
|
||||
c.relay = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt encrypts content for a pubkey
|
||||
func (c *Client) encrypt(pubkey, content []byte) (
|
||||
cipherText []byte, err error,
|
||||
) {
|
||||
var sharedSecret []byte
|
||||
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
|
||||
func NewClient(c context.T, connectionURI string) (cl *Client, err error) {
|
||||
var parts *ConnectionParams
|
||||
if parts, err = ParseConnectionURI(connectionURI); chk.E(err) {
|
||||
return
|
||||
}
|
||||
cipherText, err = encryption.EncryptNip4(content, sharedSecret)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt decrypts content from a pubkey
|
||||
func (c *Client) decrypt(pubkey, content []byte) (plaintext []byte, err error) {
|
||||
var sharedSecret []byte
|
||||
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
|
||||
clientKey := &p256k.Signer{}
|
||||
if err = clientKey.InitSec(parts.clientSecretKey); chk.E(err) {
|
||||
return
|
||||
}
|
||||
plaintext, err = encryption.DecryptNip4(content, sharedSecret)
|
||||
return
|
||||
}
|
||||
|
||||
// GetInfo gets wallet info
|
||||
func (c *Client) GetInfo() (response *GetInfoResponse, err error) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(GetInfo, nil); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &GetInfoResponse{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetBudget gets wallet budget
|
||||
func (c *Client) GetBudget() (response *GetBudgetResponse, err error) {
|
||||
var result []byte
|
||||
result, err = c.executeRequest(GetBudget, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response = &GetBudgetResponse{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetBalance gets wallet balance
|
||||
func (c *Client) GetBalance() (response *GetBalanceResponse, err error) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(GetBalance, nil); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &GetBalanceResponse{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PayInvoice pays an invoice
|
||||
func (c *Client) PayInvoice(request *PayInvoiceRequest) (
|
||||
response *PayResponse, err error,
|
||||
) {
|
||||
var result []byte
|
||||
result, err = c.executeRequest(PayInvoice, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response = &PayResponse{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PayKeysend sends a keysend payment
|
||||
func (c *Client) PayKeysend(request *PayKeysendRequest) (
|
||||
response *PayResponse, err error,
|
||||
) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(PayKeysend, request); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &PayResponse{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MakeInvoice creates an invoice
|
||||
func (c *Client) MakeInvoice(request *MakeInvoiceRequest) (
|
||||
response *Transaction, err error,
|
||||
) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(MakeInvoice, request); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &Transaction{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// LookupInvoice looks up an invoice
|
||||
func (c *Client) LookupInvoice(request *LookupInvoiceRequest) (
|
||||
response *Transaction, err error,
|
||||
) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(LookupInvoice, request); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &Transaction{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ListTransactions lists transactions
|
||||
func (c *Client) ListTransactions(request *ListTransactionsRequest) (
|
||||
response *ListTransactionsResponse, err error,
|
||||
) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(ListTransactions, request); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &ListTransactionsResponse{}
|
||||
if err = json.Unmarshal(result, response); chk.E(err) {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SignMessage signs a message
|
||||
func (c *Client) SignMessage(request *SignMessageRequest) (
|
||||
response *SignMessageResponse, err error,
|
||||
) {
|
||||
var result []byte
|
||||
if result, err = c.executeRequest(SignMessage, request); chk.E(err) {
|
||||
return
|
||||
}
|
||||
response = &SignMessageResponse{}
|
||||
if err = json.Unmarshal(result, response); err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NotificationHandler is a function that handles notifications
|
||||
type NotificationHandler func(*Notification)
|
||||
|
||||
// SubscribeNotifications subscribes to notifications
|
||||
func (c *Client) SubscribeNotifications(
|
||||
handler NotificationHandler,
|
||||
notificationTypes []NotificationType,
|
||||
) (stop func(), err error) {
|
||||
if handler == nil {
|
||||
err = fmt.Errorf("missing notification handler")
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.Cancel(context.Bg())
|
||||
doneCh := make(chan struct{})
|
||||
stop = func() {
|
||||
cancel()
|
||||
<-doneCh
|
||||
}
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Check connection
|
||||
if err := c.checkConnected(); err != nil {
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
// Get client pubkey
|
||||
var clientPubkey []byte
|
||||
if clientPubkey, err = c.GetPublicKey(); chk.E(err) {
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
// Subscribe to events
|
||||
f := &filter.F{
|
||||
Kinds: kinds.New(kind.WalletResponse),
|
||||
Authors: tag.New(c.options.WalletPubkey),
|
||||
Tags: tags.New(tag.New([]byte("#p"), clientPubkey)),
|
||||
}
|
||||
var sub *ws.Subscription
|
||||
if sub, err = c.relay.Subscribe(
|
||||
context.Bg(), &filters.T{
|
||||
F: []*filter.F{f},
|
||||
},
|
||||
); chk.E(err) {
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
// Handle events
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sub.Close()
|
||||
return
|
||||
case ev := <-sub.Events:
|
||||
// Decrypt content
|
||||
var decryptedContent []byte
|
||||
if decryptedContent, err = c.decrypt(
|
||||
c.options.WalletPubkey, ev.Content,
|
||||
); chk.E(err) {
|
||||
log.E.F(
|
||||
"Failed to decrypt event content: %v\n", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Parse notification
|
||||
notification := &Notification{}
|
||||
if err = json.Unmarshal(
|
||||
decryptedContent, notification,
|
||||
); chk.E(err) {
|
||||
log.E.F(
|
||||
"Failed to parse notification: %v\n", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Check if notification type is requested
|
||||
if len(notificationTypes) > 0 {
|
||||
found := false
|
||||
for _, t := range notificationTypes {
|
||||
if notification.NotificationType == t {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Handle notification
|
||||
handler(notification)
|
||||
case <-sub.EndOfStoredEvents:
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// executeRequest executes a NIP-47 request
|
||||
func (c *Client) executeRequest(
|
||||
method Method,
|
||||
params any,
|
||||
) (msg json.RawMessage, err error) {
|
||||
// Default timeout values
|
||||
replyTimeout := 3 * time.Second
|
||||
publishTimeout := 3 * time.Second
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.Timeout(context.Bg(), replyTimeout)
|
||||
defer cancel()
|
||||
// Create result channel
|
||||
resultCh := make(chan json.RawMessage, 1)
|
||||
errCh := make(chan error, 1)
|
||||
// Check connection
|
||||
if err = c.checkConnected(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create request
|
||||
request := struct {
|
||||
Method Method `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}{
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
// Marshal request
|
||||
requestJSON, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
// Encrypt request
|
||||
var encryptedContent []byte
|
||||
if encryptedContent, err = c.encrypt(
|
||||
c.options.WalletPubkey, requestJSON,
|
||||
var ck []byte
|
||||
if ck, err = encryption.GenerateConversationKeyWithSigner(
|
||||
clientKey,
|
||||
parts.walletPublicKey,
|
||||
); chk.E(err) {
|
||||
return nil, fmt.Errorf("failed to encrypt request: %w", err)
|
||||
return
|
||||
}
|
||||
// Create request event
|
||||
requestEvent := &event.E{
|
||||
var relay *ws.Client
|
||||
if relay, err = ws.RelayConnect(c, parts.relay); chk.E(err) {
|
||||
return
|
||||
}
|
||||
cl = &Client{
|
||||
client: relay,
|
||||
relay: parts.relay,
|
||||
clientSecretKey: clientKey,
|
||||
walletPublicKey: parts.walletPublicKey,
|
||||
conversationKey: ck,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type rpcOptions struct {
|
||||
timeout *time.Duration
|
||||
}
|
||||
|
||||
func (cl *Client) RPC(
|
||||
c context.T, method Capability, params, result any, noUnmarshal bool,
|
||||
opts *rpcOptions,
|
||||
) (raw []byte, err error) {
|
||||
var req []byte
|
||||
if req, err = json.Marshal(
|
||||
Request{
|
||||
Method: string(method),
|
||||
Params: params,
|
||||
},
|
||||
); chk.E(err) {
|
||||
return
|
||||
}
|
||||
var content []byte
|
||||
if content, err = encryption.Encrypt(req, cl.conversationKey); chk.E(err) {
|
||||
return
|
||||
}
|
||||
ev := &event.E{
|
||||
Content: content,
|
||||
CreatedAt: timestamp.Now(),
|
||||
Kind: kind.WalletRequest,
|
||||
CreatedAt: timestamp.New(time.Now().Unix()),
|
||||
Tags: tags.New(tag.New("p", hex.Enc(c.options.WalletPubkey))),
|
||||
Content: encryptedContent,
|
||||
Tags: tags.New(
|
||||
tag.New("p", hex.Enc(cl.walletPublicKey)),
|
||||
tag.New(EncryptionTag, Nip44V2),
|
||||
),
|
||||
}
|
||||
// Sign request event
|
||||
err = requestEvent.Sign(c.options.Secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign request event: %w", err)
|
||||
if err = ev.Sign(cl.clientSecretKey); chk.E(err) {
|
||||
return
|
||||
}
|
||||
// Subscribe to response events
|
||||
f := &filter.F{
|
||||
Kinds: kinds.New(kind.WalletResponse),
|
||||
Authors: tag.New(c.options.WalletPubkey),
|
||||
Tags: tags.New(tag.New([]byte("#p"), requestEvent.ID)),
|
||||
var rc *ws.Client
|
||||
if rc, err = ws.RelayConnect(c, cl.relay); chk.E(err) {
|
||||
return
|
||||
}
|
||||
log.I.F("%s", f.Marshal(nil))
|
||||
defer rc.Close()
|
||||
var sub *ws.Subscription
|
||||
if sub, err = c.relay.Subscribe(
|
||||
ctx, &filters.T{
|
||||
F: []*filter.F{f},
|
||||
},
|
||||
if sub, err = rc.Subscribe(
|
||||
c, filters.New(
|
||||
&filter.F{
|
||||
Limit: values.ToUintPointer(1),
|
||||
Kinds: kinds.New(kind.WalletResponse),
|
||||
Authors: tag.New(cl.walletPublicKey),
|
||||
Tags: tags.New(tag.New("#e", hex.Enc(ev.ID))),
|
||||
},
|
||||
),
|
||||
); chk.E(err) {
|
||||
err = fmt.Errorf(
|
||||
"failed to subscribe to response events: %w", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
// Set up reply timeout
|
||||
replyTimer := time.AfterFunc(
|
||||
replyTimeout, func() {
|
||||
errCh <- NewReplyTimeoutError(
|
||||
fmt.Sprintf("Timeout waiting for reply to %s", method),
|
||||
"TIMEOUT",
|
||||
)
|
||||
},
|
||||
)
|
||||
defer replyTimer.Stop()
|
||||
// Handle response events
|
||||
go func() {
|
||||
var resErr error
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ev := <-sub.Events:
|
||||
// Decrypt content
|
||||
var decryptedContent []byte
|
||||
decryptedContent, resErr = c.decrypt(
|
||||
c.options.WalletPubkey, ev.Content,
|
||||
)
|
||||
if chk.E(resErr) {
|
||||
errCh <- fmt.Errorf(
|
||||
"failed to decrypt response: %w",
|
||||
resErr,
|
||||
)
|
||||
return
|
||||
}
|
||||
// Parse response
|
||||
var response struct {
|
||||
ResultType string `json:"result_type"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
Error *struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if resErr = json.Unmarshal(
|
||||
decryptedContent, &response,
|
||||
); chk.E(resErr) {
|
||||
errCh <- fmt.Errorf("failed to parse response: %w", resErr)
|
||||
return
|
||||
}
|
||||
// Check for error
|
||||
if response.Error != nil {
|
||||
errCh <- NewWalletError(
|
||||
response.Error.Message,
|
||||
response.Error.Code,
|
||||
)
|
||||
return
|
||||
}
|
||||
// Send result
|
||||
resultCh <- response.Result
|
||||
return
|
||||
case <-sub.EndOfStoredEvents:
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Publish request event
|
||||
publishCtx, publishCancel := context.Timeout(
|
||||
context.Bg(), publishTimeout,
|
||||
)
|
||||
defer publishCancel()
|
||||
if err = c.relay.Publish(publishCtx, requestEvent); chk.E(err) {
|
||||
err = fmt.Errorf("failed to publish request event: %w", err)
|
||||
defer sub.Unsub()
|
||||
if err = rc.Publish(context.Bg(), ev); chk.E(err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for result or error
|
||||
select {
|
||||
case msg = <-resultCh:
|
||||
return
|
||||
case err = <-errCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
err = NewReplyTimeoutError(
|
||||
fmt.Sprintf("Timeout waiting for reply to %s", method),
|
||||
"TIMEOUT",
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// checkConnected checks if the client is connected to the relay
|
||||
func (c *Client) checkConnected() (err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.options.RelayURL == "" {
|
||||
return fmt.Errorf("missing relay URL")
|
||||
}
|
||||
|
||||
if c.relay == nil {
|
||||
if c.relay, err = ws.RelayConnect(
|
||||
context.Bg(), c.options.RelayURL,
|
||||
case <-c.Done():
|
||||
err = fmt.Errorf("context canceled waiting for response")
|
||||
case e := <-sub.Events:
|
||||
if raw, err = encryption.Decrypt(
|
||||
e.Content, cl.conversationKey,
|
||||
); chk.E(err) {
|
||||
return NewNetworkError(
|
||||
"Failed to connect to "+c.options.RelayURL,
|
||||
"OTHER",
|
||||
)
|
||||
return
|
||||
}
|
||||
} else if !c.relay.IsConnected() {
|
||||
c.relay.Close()
|
||||
if c.relay, err = ws.RelayConnect(
|
||||
context.Bg(), c.options.RelayURL,
|
||||
); chk.E(err) {
|
||||
return NewNetworkError(
|
||||
"Failed to connect to "+c.options.RelayURL,
|
||||
"OTHER",
|
||||
)
|
||||
if noUnmarshal {
|
||||
return
|
||||
}
|
||||
resp := &Response{
|
||||
Result: &result,
|
||||
}
|
||||
if err = json.Unmarshal(raw, resp); chk.E(err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
174
pkg/protocol/nwc/methods.go
Normal file
174
pkg/protocol/nwc/methods.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package nwc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/encoders/filter"
|
||||
"orly.dev/pkg/encoders/filters"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/kinds"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/protocol/ws"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/values"
|
||||
)
|
||||
|
||||
func (cl *Client) GetWalletServiceInfo(c context.T, noUnmarshal bool) (
|
||||
wsi *WalletServiceInfo, raw []byte, err error,
|
||||
) {
|
||||
ctx, cancel := context.Timeout(c, 10*time.Second)
|
||||
defer cancel()
|
||||
var rc *ws.Client
|
||||
if rc, err = ws.RelayConnect(c, cl.relay); chk.E(err) {
|
||||
return
|
||||
}
|
||||
var sub *ws.Subscription
|
||||
if sub, err = rc.Subscribe(
|
||||
ctx, filters.New(
|
||||
&filter.F{
|
||||
Limit: values.ToUintPointer(1),
|
||||
Kinds: kinds.New(kind.WalletServiceInfo),
|
||||
Authors: tag.New(cl.walletPublicKey),
|
||||
},
|
||||
),
|
||||
); chk.E(err) {
|
||||
return
|
||||
}
|
||||
defer sub.Unsub()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = fmt.Errorf("context canceled")
|
||||
return
|
||||
case e := <-sub.Events:
|
||||
raw = e.Marshal(nil)
|
||||
if noUnmarshal {
|
||||
return
|
||||
}
|
||||
wsi = &WalletServiceInfo{}
|
||||
encTag := e.Tags.GetFirst(tag.New(EncryptionTag))
|
||||
notTag := e.Tags.GetFirst(tag.New(NotificationTag))
|
||||
if encTag != nil {
|
||||
et := bytes.Split(encTag.Value(), []byte(" "))
|
||||
for _, v := range et {
|
||||
wsi.EncryptionTypes = append(wsi.EncryptionTypes, v)
|
||||
}
|
||||
}
|
||||
if notTag != nil {
|
||||
nt := bytes.Split(notTag.Value(), []byte(" "))
|
||||
for _, v := range nt {
|
||||
wsi.NotificationTypes = append(wsi.NotificationTypes, v)
|
||||
}
|
||||
}
|
||||
caps := bytes.Split(e.Content, []byte(" "))
|
||||
for _, v := range caps {
|
||||
wsi.Capabilities = append(wsi.Capabilities, v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) CancelHoldInvoice(
|
||||
c context.T, chi *CancelHoldInvoiceParams, noUnmarshal bool,
|
||||
) (raw []byte, err error) {
|
||||
return cl.RPC(c, CancelHoldInvoice, chi, nil, noUnmarshal, nil)
|
||||
}
|
||||
|
||||
func (cl *Client) CreateConnection(
|
||||
c context.T, cc *CreateConnectionParams, noUnmarshal bool,
|
||||
) (raw []byte, err error) {
|
||||
return cl.RPC(c, CreateConnection, cc, nil, noUnmarshal, nil)
|
||||
}
|
||||
|
||||
func (cl *Client) GetBalance(c context.T, noUnmarshal bool) (
|
||||
gb *GetBalanceResult, raw []byte, err error,
|
||||
) {
|
||||
gb = &GetBalanceResult{}
|
||||
raw, err = cl.RPC(c, GetBalance, nil, gb, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) GetBudget(c context.T, noUnmarshal bool) (
|
||||
gb *GetBudgetResult, raw []byte, err error,
|
||||
) {
|
||||
gb = &GetBudgetResult{}
|
||||
raw, err = cl.RPC(c, GetBudget, nil, gb, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) GetInfo(c context.T, noUnmarshal bool) (
|
||||
gi *GetInfoResult, raw []byte, err error,
|
||||
) {
|
||||
gi = &GetInfoResult{}
|
||||
raw, err = cl.RPC(c, GetInfo, nil, gi, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) ListTransactions(
|
||||
c context.T, params *ListTransactionsParams, noUnmarshal bool,
|
||||
) (lt *ListTransactionsResult, raw []byte, err error) {
|
||||
lt = &ListTransactionsResult{}
|
||||
raw, err = cl.RPC(c, ListTransactions, params, <, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) LookupInvoice(
|
||||
c context.T, params *LookupInvoiceParams, noUnmarshal bool,
|
||||
) (li *LookupInvoiceResult, raw []byte, err error) {
|
||||
li = &LookupInvoiceResult{}
|
||||
raw, err = cl.RPC(c, LookupInvoice, params, &li, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) MakeHoldInvoice(
|
||||
c context.T,
|
||||
mhi *MakeHoldInvoiceParams, noUnmarshal bool,
|
||||
) (mi *MakeInvoiceResult, raw []byte, err error) {
|
||||
mi = &MakeInvoiceResult{}
|
||||
raw, err = cl.RPC(c, MakeHoldInvoice, mhi, mi, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) MakeInvoice(
|
||||
c context.T, params *MakeInvoiceParams, noUnmarshal bool,
|
||||
) (mi *MakeInvoiceResult, raw []byte, err error) {
|
||||
mi = &MakeInvoiceResult{}
|
||||
raw, err = cl.RPC(c, MakeInvoice, params, &mi, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// MultiPayInvoice
|
||||
|
||||
// MultiPayKeysend
|
||||
|
||||
func (cl *Client) PayKeysend(
|
||||
c context.T, params *PayKeysendParams, noUnmarshal bool,
|
||||
) (pk *PayKeysendResult, raw []byte, err error) {
|
||||
pk = &PayKeysendResult{}
|
||||
raw, err = cl.RPC(c, PayKeysend, params, &pk, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) PayInvoice(
|
||||
c context.T, params *PayInvoiceParams, noUnmarshal bool,
|
||||
) (pi *PayInvoiceResult, raw []byte, err error) {
|
||||
pi = &PayInvoiceResult{}
|
||||
raw, err = cl.RPC(c, PayInvoice, params, &pi, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (cl *Client) SettleHoldInvoice(
|
||||
c context.T, shi *SettleHoldInvoiceParams, noUnmarshal bool,
|
||||
) (raw []byte, err error) {
|
||||
return cl.RPC(c, SettleHoldInvoice, shi, nil, noUnmarshal, nil)
|
||||
}
|
||||
|
||||
func (cl *Client) SignMessage(
|
||||
c context.T, sm *SignMessageParams, noUnmarshal bool,
|
||||
) (res *SignMessageResult, raw []byte, err error) {
|
||||
res = &SignMessageResult{}
|
||||
raw, err = cl.RPC(c, SignMessage, sm, &res, noUnmarshal, nil)
|
||||
return
|
||||
}
|
||||
@@ -1,473 +1,190 @@
|
||||
package nwc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
// Capability represents a NIP-47 method
|
||||
type Capability []byte
|
||||
|
||||
var (
|
||||
CancelHoldInvoice = Capability("cancel_hold_invoice")
|
||||
CreateConnection = Capability("create_connection")
|
||||
GetBalance = Capability("get_balance")
|
||||
GetBudget = Capability("get_budget")
|
||||
GetInfo = Capability("get_info")
|
||||
ListTransactions = Capability("list_transactions")
|
||||
LookupInvoice = Capability("lookup_invoice")
|
||||
MakeHoldInvoice = Capability("make_hold_invoice")
|
||||
MakeInvoice = Capability("make_invoice")
|
||||
MultiPayInvoice = Capability("multi_pay_invoice")
|
||||
MultiPayKeysend = Capability("multi_pay_keysend")
|
||||
PayInvoice = Capability("pay_invoice")
|
||||
PayKeysend = Capability("pay_keysend")
|
||||
SettleHoldInvoice = Capability("settle_hold_invoice")
|
||||
SignMessage = Capability("sign_message")
|
||||
)
|
||||
|
||||
// EncryptionType represents the encryption type used for NIP-47 messages
|
||||
type EncryptionType string
|
||||
type EncryptionType []byte
|
||||
|
||||
const (
|
||||
Nip04 EncryptionType = "nip04"
|
||||
Nip44V2 EncryptionType = "nip44_v2"
|
||||
var (
|
||||
EncryptionTag = []byte("encryption")
|
||||
Nip04 = EncryptionType("nip04")
|
||||
Nip44V2 = EncryptionType("nip44_v2")
|
||||
)
|
||||
|
||||
// AuthorizationUrlOptions represents options for creating an NWC authorization URL
|
||||
type AuthorizationUrlOptions struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
RequestMethods []Method `json:"requestMethods,omitempty"`
|
||||
NotificationTypes []NotificationType `json:"notificationTypes,omitempty"`
|
||||
ReturnTo string `json:"returnTo,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||
MaxAmount *int64 `json:"maxAmount,omitempty"`
|
||||
BudgetRenewal BudgetRenewalPeriod `json:"budgetRenewal,omitempty"`
|
||||
Isolated bool `json:"isolated,omitempty"`
|
||||
Metadata interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
type NotificationType []byte
|
||||
|
||||
// Err is the base error type for NIP-47 errors
|
||||
type Err struct {
|
||||
Message string
|
||||
Code string
|
||||
}
|
||||
|
||||
func (e *Err) Error() string {
|
||||
return fmt.Sprintf("%s (code: %s)", e.Message, e.Code)
|
||||
}
|
||||
|
||||
// NewError creates a new Error
|
||||
func NewError(message, code string) *Err {
|
||||
return &Err{
|
||||
Message: message,
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
||||
// NetworkError represents a network error in NIP-47 operations
|
||||
type NetworkError struct{ *Err }
|
||||
|
||||
// NewNetworkError creates a new NetworkError
|
||||
func NewNetworkError(message, code string) *NetworkError {
|
||||
return &NetworkError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// WalletError represents a wallet error in NIP-47 operations
|
||||
type WalletError struct {
|
||||
*Err
|
||||
}
|
||||
|
||||
// NewWalletError creates a new WalletError
|
||||
func NewWalletError(message, code string) *WalletError {
|
||||
return &WalletError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// TimeoutError represents a timeout error in NIP-47 operations
|
||||
type TimeoutError struct{ *Err }
|
||||
|
||||
// NewTimeoutError creates a new TimeoutError
|
||||
func NewTimeoutError(message, code string) *TimeoutError {
|
||||
return &TimeoutError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// PublishTimeoutError represents a publish timeout error in NIP-47 operations
|
||||
type PublishTimeoutError struct{ *TimeoutError }
|
||||
|
||||
// NewPublishTimeoutError creates a new PublishTimeoutError
|
||||
func NewPublishTimeoutError(message, code string) *PublishTimeoutError {
|
||||
return &PublishTimeoutError{
|
||||
TimeoutError: NewTimeoutError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// ReplyTimeoutError represents a reply timeout error in NIP-47 operations
|
||||
type ReplyTimeoutError struct{ *TimeoutError }
|
||||
|
||||
// NewReplyTimeoutError creates a new ReplyTimeoutError
|
||||
func NewReplyTimeoutError(message, code string) *ReplyTimeoutError {
|
||||
return &ReplyTimeoutError{
|
||||
TimeoutError: NewTimeoutError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// PublishError represents a publish error in NIP-47 operations
|
||||
type PublishError struct{ *Err }
|
||||
|
||||
// NewPublishError creates a new PublishError
|
||||
func NewPublishError(message, code string) *PublishError {
|
||||
return &PublishError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseDecodingError represents a response decoding error in NIP-47 operations
|
||||
type ResponseDecodingError struct{ *Err }
|
||||
|
||||
// NewResponseDecodingError creates a new ResponseDecodingError
|
||||
func NewResponseDecodingError(message, code string) *ResponseDecodingError {
|
||||
return &ResponseDecodingError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseValidationError represents a response validation error in NIP-47 operations
|
||||
type ResponseValidationError struct{ *Err }
|
||||
|
||||
// NewResponseValidationError creates a new ResponseValidationError
|
||||
func NewResponseValidationError(message, code string) *ResponseValidationError {
|
||||
return &ResponseValidationError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// UnexpectedResponseError represents an unexpected response error in NIP-47 operations
|
||||
type UnexpectedResponseError struct{ *Err }
|
||||
|
||||
// NewUnexpectedResponseError creates a new UnexpectedResponseError
|
||||
func NewUnexpectedResponseError(message, code string) *UnexpectedResponseError {
|
||||
return &UnexpectedResponseError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// UnsupportedEncryptionError represents an unsupported encryption error in NIP-47 operations
|
||||
type UnsupportedEncryptionError struct {
|
||||
*Err
|
||||
}
|
||||
|
||||
// NewUnsupportedEncryptionError creates a new UnsupportedEncryptionError
|
||||
func NewUnsupportedEncryptionError(message, code string) *UnsupportedEncryptionError {
|
||||
return &UnsupportedEncryptionError{
|
||||
Err: NewError(message, code),
|
||||
}
|
||||
}
|
||||
|
||||
// WithDTag represents a type with a dTag field
|
||||
type WithDTag struct {
|
||||
DTag string `json:"dTag"`
|
||||
}
|
||||
|
||||
// WithOptionalId represents a type with an optional id field
|
||||
type WithOptionalId struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// Method represents a NIP-47 method
|
||||
type Method string
|
||||
|
||||
// SingleMethod represents a single NIP-47 method
|
||||
const (
|
||||
GetInfo Method = "get_info"
|
||||
GetBalance Method = "get_balance"
|
||||
GetBudget Method = "get_budget"
|
||||
MakeInvoice Method = "make_invoice"
|
||||
PayInvoice Method = "pay_invoice"
|
||||
PayKeysend Method = "pay_keysend"
|
||||
LookupInvoice Method = "lookup_invoice"
|
||||
ListTransactions Method = "list_transactions"
|
||||
SignMessage Method = "sign_message"
|
||||
CreateConnection Method = "create_connection"
|
||||
MakeHoldInvoice Method = "make_hold_invoice"
|
||||
SettleHoldInvoice Method = "settle_hold_invoice"
|
||||
CancelHoldInvoice Method = "cancel_hold_invoice"
|
||||
var (
|
||||
NotificationTag = []byte("notification")
|
||||
PaymentReceived = NotificationType("payment_received")
|
||||
PaymentSent = NotificationType("payment_sent")
|
||||
HoldInvoiceAccepted = NotificationType("hold_invoice_accepted")
|
||||
)
|
||||
|
||||
// MultiMethod represents a multi NIP-47 method
|
||||
const (
|
||||
MultiPayInvoice Method = "multi_pay_invoice"
|
||||
MultiPayKeysend Method = "multi_pay_keysend"
|
||||
)
|
||||
|
||||
// Capability represents a NIP-47 capability
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
Notifications Capability = "notifications"
|
||||
)
|
||||
|
||||
// BudgetRenewalPeriod represents a budget renewal period
|
||||
type BudgetRenewalPeriod string
|
||||
|
||||
const (
|
||||
Daily BudgetRenewalPeriod = "daily"
|
||||
Weekly BudgetRenewalPeriod = "weekly"
|
||||
Monthly BudgetRenewalPeriod = "monthly"
|
||||
Yearly BudgetRenewalPeriod = "yearly"
|
||||
Never BudgetRenewalPeriod = "never"
|
||||
)
|
||||
|
||||
// GetInfoResponse represents a response to a get_info request
|
||||
type GetInfoResponse struct {
|
||||
Alias string `json:"alias"`
|
||||
Color string `json:"color"`
|
||||
Pubkey string `json:"pubkey"`
|
||||
Network string `json:"network"`
|
||||
BlockHeight int64 `json:"block_height"`
|
||||
BlockHash string `json:"block_hash"`
|
||||
Methods []Method `json:"methods"`
|
||||
Notifications []NotificationType `json:"notifications,omitempty"`
|
||||
Metadata interface{} `json:"metadata,omitempty"`
|
||||
Lud16 string `json:"lud16,omitempty"`
|
||||
type WalletServiceInfo struct {
|
||||
EncryptionTypes []EncryptionType
|
||||
Capabilities []Capability
|
||||
NotificationTypes []NotificationType
|
||||
}
|
||||
|
||||
// GetBudgetResponse represents a response to a get_budget request
|
||||
type GetBudgetResponse struct {
|
||||
UsedBudget int64 `json:"used_budget,omitempty"`
|
||||
TotalBudget int64 `json:"total_budget,omitempty"`
|
||||
RenewsAt *int64 `json:"renews_at,omitempty"`
|
||||
RenewalPeriod BudgetRenewalPeriod `json:"renewal_period,omitempty"`
|
||||
type GetInfoResult struct {
|
||||
Alias string `json:"alias"`
|
||||
Color string `json:"color"`
|
||||
Pubkey string `json:"pubkey"`
|
||||
Network string `json:"network"`
|
||||
BlockHeight uint64 `json:"block_height"`
|
||||
BlockHash string `json:"block_hash"`
|
||||
Methods []string `json:"methods"`
|
||||
Notifications []string `json:"notifications,omitempty"`
|
||||
Metadata any `json:"metadata,omitempty"`
|
||||
LUD16 string `json:"lud16,omitempty"`
|
||||
}
|
||||
|
||||
// GetBalanceResponse represents a response to a get_balance request
|
||||
type GetBalanceResponse struct {
|
||||
Balance int64 `json:"balance"` // msats
|
||||
type GetBudgetResult struct {
|
||||
UsedBudget int `json:"used_budget,omitempty"`
|
||||
TotalBudget int `json:"total_budget,omitempty"`
|
||||
RenewsAt int `json:"renews_at,omitempty"`
|
||||
RenewalPeriod string `json:"renewal_period,omitempty"`
|
||||
}
|
||||
|
||||
// PayResponse represents a response to a pay request
|
||||
type PayResponse struct {
|
||||
type GetBalanceResult struct {
|
||||
Balance uint64 `json:"balance"`
|
||||
}
|
||||
|
||||
type MakeInvoiceParams struct {
|
||||
Amount uint64 `json:"amount"`
|
||||
Description string `json:"description,omitempty"`
|
||||
DescriptionHash string `json:"description_hash,omitempty"`
|
||||
Expiry *int64 `json:"expiry,omitempty"`
|
||||
Metadata any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type MakeHoldInvoiceParams struct {
|
||||
Amount uint64 `json:"amount"`
|
||||
PaymentHash string `json:"payment_hash"`
|
||||
Description string `json:"description,omitempty"`
|
||||
DescriptionHash string `json:"description_hash,omitempty"`
|
||||
Expiry *int64 `json:"expiry,omitempty"`
|
||||
Metadata any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type SettleHoldInvoiceParams struct {
|
||||
Preimage string `json:"preimage"`
|
||||
FeesPaid int64 `json:"fees_paid"`
|
||||
}
|
||||
|
||||
// MultiPayInvoiceRequest represents a request to pay multiple invoices
|
||||
type MultiPayInvoiceRequest struct {
|
||||
Invoices []PayInvoiceRequestWithID `json:"invoices"`
|
||||
type CancelHoldInvoiceParams struct {
|
||||
PaymentHash string `json:"payment_hash"`
|
||||
}
|
||||
|
||||
// PayInvoiceRequestWithID combines PayInvoiceRequest with WithOptionalId
|
||||
type PayInvoiceRequestWithID struct {
|
||||
PayInvoiceRequest
|
||||
WithOptionalId
|
||||
type PayInvoicePayerData struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Pubkey string `json:"pubkey"`
|
||||
}
|
||||
|
||||
// MultiPayKeysendRequest represents a request to pay multiple keysends
|
||||
type MultiPayKeysendRequest struct {
|
||||
Keysends []PayKeysendRequestWithID `json:"keysends"`
|
||||
type PayInvoiceMetadata struct {
|
||||
Comment *string `json:"comment"`
|
||||
PayerData *PayInvoicePayerData `json:"payer_data"`
|
||||
Other any
|
||||
}
|
||||
|
||||
// PayKeysendRequestWithID combines PayKeysendRequest with WithOptionalId
|
||||
type PayKeysendRequestWithID struct {
|
||||
PayKeysendRequest
|
||||
WithOptionalId
|
||||
type PayInvoiceParams struct {
|
||||
Invoice string `json:"invoice"`
|
||||
Amount *uint64 `json:"amount,omitempty"`
|
||||
Metadata *PayInvoiceMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MultiPayInvoiceResponse represents a response to a multi_pay_invoice request
|
||||
type MultiPayInvoiceResponse struct {
|
||||
Invoices []MultiPayInvoiceResponseItem `json:"invoices"`
|
||||
Errors []interface{} `json:"errors"` // TODO: add error handling
|
||||
type PayInvoiceResult struct {
|
||||
Preimage string `json:"preimage"`
|
||||
FeesPaid uint64 `json:"fees_paid"`
|
||||
}
|
||||
|
||||
// MultiPayInvoiceResponseItem represents an item in a multi_pay_invoice response
|
||||
type MultiPayInvoiceResponseItem struct {
|
||||
Invoice PayInvoiceRequest `json:"invoice"`
|
||||
PayResponse
|
||||
WithDTag
|
||||
}
|
||||
|
||||
// MultiPayKeysendResponse represents a response to a multi_pay_keysend request
|
||||
type MultiPayKeysendResponse struct {
|
||||
Keysends []MultiPayKeysendResponseItem `json:"keysends"`
|
||||
Errors []interface{} `json:"errors"` // TODO: add error handling
|
||||
}
|
||||
|
||||
// MultiPayKeysendResponseItem represents an item in a multi_pay_keysend response
|
||||
type MultiPayKeysendResponseItem struct {
|
||||
Keysend PayKeysendRequest `json:"keysend"`
|
||||
PayResponse
|
||||
WithDTag
|
||||
}
|
||||
|
||||
// ListTransactionsRequest represents a request to list transactions
|
||||
type ListTransactionsRequest struct {
|
||||
From *int64 `json:"from,omitempty"`
|
||||
Until *int64 `json:"until,omitempty"`
|
||||
Limit *int64 `json:"limit,omitempty"`
|
||||
Offset *int64 `json:"offset,omitempty"`
|
||||
Unpaid *bool `json:"unpaid,omitempty"`
|
||||
UnpaidOutgoing *bool `json:"unpaid_outgoing,omitempty"` // NOTE: non-NIP-47 spec compliant
|
||||
UnpaidIncoming *bool `json:"unpaid_incoming,omitempty"` // NOTE: non-NIP-47 spec compliant
|
||||
Type *string `json:"type,omitempty"` // "incoming" or "outgoing"
|
||||
}
|
||||
|
||||
// ListTransactionsResponse represents a response to a list_transactions request
|
||||
type ListTransactionsResponse struct {
|
||||
Transactions []Transaction `json:"transactions"`
|
||||
TotalCount int64 `json:"total_count"` // NOTE: non-NIP-47 spec compliant
|
||||
}
|
||||
|
||||
// TransactionType represents the type of a transaction
|
||||
type TransactionType string
|
||||
|
||||
const (
|
||||
Incoming TransactionType = "incoming"
|
||||
Outgoing TransactionType = "outgoing"
|
||||
)
|
||||
|
||||
// TransactionState represents the state of a transaction
|
||||
type TransactionState string
|
||||
|
||||
const (
|
||||
Settled TransactionState = "settled"
|
||||
Pending TransactionState = "pending"
|
||||
Failed TransactionState = "failed"
|
||||
)
|
||||
|
||||
// Transaction represents a transaction
|
||||
type Transaction struct {
|
||||
Type TransactionType `json:"type"`
|
||||
State TransactionState `json:"state"` // NOTE: non-NIP-47 spec compliant
|
||||
Invoice string `json:"invoice"`
|
||||
Description string `json:"description"`
|
||||
DescriptionHash string `json:"description_hash"`
|
||||
Preimage string `json:"preimage"`
|
||||
PaymentHash string `json:"payment_hash"`
|
||||
Amount int64 `json:"amount"`
|
||||
FeesPaid int64 `json:"fees_paid"`
|
||||
SettledAt int64 `json:"settled_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
SettleDeadline *int64 `json:"settle_deadline,omitempty"` // NOTE: non-NIP-47 spec compliant
|
||||
Metadata *TransactionMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// TransactionMetadata represents metadata for a transaction
|
||||
type TransactionMetadata struct {
|
||||
Comment string `json:"comment,omitempty"` // LUD-12
|
||||
PayerData *PayerData `json:"payer_data,omitempty"` // LUD-18
|
||||
RecipientData *RecipientData `json:"recipient_data,omitempty"` // LUD-18
|
||||
Nostr *NostrData `json:"nostr,omitempty"` // NIP-57
|
||||
ExtraData map[string]interface{} `json:"-"` // For additional fields
|
||||
}
|
||||
|
||||
// PayerData represents payer data for a transaction
|
||||
type PayerData struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Pubkey string `json:"pubkey,omitempty"`
|
||||
}
|
||||
|
||||
// RecipientData represents recipient data for a transaction
|
||||
type RecipientData struct {
|
||||
Identifier string `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
// NostrData represents Nostr data for a transaction
|
||||
type NostrData struct {
|
||||
Pubkey string `json:"pubkey"`
|
||||
Tags [][]string `json:"tags"`
|
||||
}
|
||||
|
||||
// NotificationType represents a notification type
|
||||
type NotificationType string
|
||||
|
||||
const (
|
||||
PaymentReceived NotificationType = "payment_received"
|
||||
PaymentSent NotificationType = "payment_sent"
|
||||
HoldInvoiceAccepted NotificationType = "hold_invoice_accepted"
|
||||
)
|
||||
|
||||
// Notification represents a notification
|
||||
type Notification struct {
|
||||
NotificationType NotificationType `json:"notification_type"`
|
||||
Notification Transaction `json:"notification"`
|
||||
}
|
||||
|
||||
// PayInvoiceRequest represents a request to pay an invoice
|
||||
type PayInvoiceRequest struct {
|
||||
Invoice string `json:"invoice"`
|
||||
Metadata *TransactionMetadata `json:"metadata,omitempty"`
|
||||
Amount *int64 `json:"amount,omitempty"` // msats
|
||||
}
|
||||
|
||||
// PayKeysendRequest represents a request to pay a keysend
|
||||
type PayKeysendRequest struct {
|
||||
Amount int64 `json:"amount"` // msats
|
||||
Pubkey string `json:"pubkey"`
|
||||
Preimage string `json:"preimage,omitempty"`
|
||||
TlvRecords []TlvRecord `json:"tlv_records,omitempty"`
|
||||
}
|
||||
|
||||
// TlvRecord represents a TLV record
|
||||
type TlvRecord struct {
|
||||
Type int64 `json:"type"`
|
||||
type PayKeysendTLVRecord struct {
|
||||
Type uint32 `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// MakeInvoiceRequest represents a request to make an invoice
|
||||
type MakeInvoiceRequest struct {
|
||||
Amount int64 `json:"amount"` // msats
|
||||
Description string `json:"description,omitempty"`
|
||||
DescriptionHash string `json:"description_hash,omitempty"`
|
||||
Expiry *int64 `json:"expiry,omitempty"` // in seconds
|
||||
Metadata *TransactionMetadata `json:"metadata,omitempty"`
|
||||
type PayKeysendParams struct {
|
||||
Amount uint64 `json:"amount"`
|
||||
Pubkey string `json:"pubkey"`
|
||||
Preimage *string `json:"preimage,omitempty"`
|
||||
TLVRecords []PayKeysendTLVRecord `json:"tlv_records,omitempty"`
|
||||
}
|
||||
|
||||
// MakeHoldInvoiceRequest represents a request to make a hold invoice
|
||||
type MakeHoldInvoiceRequest struct {
|
||||
MakeInvoiceRequest
|
||||
PaymentHash string `json:"payment_hash"`
|
||||
type PayKeysendResult = PayInvoiceResult
|
||||
|
||||
type LookupInvoiceParams struct {
|
||||
PaymentHash *string `json:"payment_hash,omitempty"`
|
||||
Invoice *string `json:"invoice,omitempty"`
|
||||
}
|
||||
|
||||
// SettleHoldInvoiceRequest represents a request to settle a hold invoice
|
||||
type SettleHoldInvoiceRequest struct {
|
||||
Preimage string `json:"preimage"`
|
||||
type ListTransactionsParams struct {
|
||||
From *int64 `json:"from,omitempty"`
|
||||
Until *int64 `json:"until,omitempty"`
|
||||
Limit *uint16 `json:"limit,omitempty"`
|
||||
Offset *uint32 `json:"offset,omitempty"`
|
||||
Unpaid *bool `json:"unpaid,omitempty"`
|
||||
UnpaidOutgoing *bool `json:"unpaid_outgoing,omitempty"`
|
||||
UnpaidIncoming *bool `json:"unpaid_incoming,omitempty"`
|
||||
Type *string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// SettleHoldInvoiceResponse represents a response to a settle_hold_invoice request
|
||||
type SettleHoldInvoiceResponse struct{}
|
||||
|
||||
// CancelHoldInvoiceRequest represents a request to cancel a hold invoice
|
||||
type CancelHoldInvoiceRequest struct {
|
||||
PaymentHash string `json:"payment_hash"`
|
||||
type MakeInvoiceResult = Transaction
|
||||
type LookupInvoiceResult = Transaction
|
||||
type ListTransactionsResult struct {
|
||||
Transactions []Transaction `json:"transactions"`
|
||||
TotalCount uint32 `json:"total_count"`
|
||||
}
|
||||
|
||||
// CancelHoldInvoiceResponse represents a response to a cancel_hold_invoice request
|
||||
type CancelHoldInvoiceResponse struct{}
|
||||
|
||||
// LookupInvoiceRequest represents a request to lookup an invoice
|
||||
type LookupInvoiceRequest struct {
|
||||
PaymentHash string `json:"payment_hash,omitempty"`
|
||||
Invoice string `json:"invoice,omitempty"`
|
||||
type Transaction struct {
|
||||
Type string `json:"type"`
|
||||
State string `json:"state"`
|
||||
Invoice string `json:"invoice"`
|
||||
Description string `json:"description"`
|
||||
DescriptionHash string `json:"description_hash"`
|
||||
Preimage string `json:"preimage"`
|
||||
PaymentHash string `json:"payment_hash"`
|
||||
Amount uint64 `json:"amount"`
|
||||
FeesPaid uint64 `json:"fees_paid"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
SettledDeadline *uint64 `json:"settled_deadline,omitempty"`
|
||||
Metadata any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// SignMessageRequest represents a request to sign a message
|
||||
type SignMessageRequest struct {
|
||||
type SignMessageParams struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// CreateConnectionRequest represents a request to create a connection
|
||||
type CreateConnectionRequest struct {
|
||||
Pubkey string `json:"pubkey"`
|
||||
Name string `json:"name"`
|
||||
RequestMethods []Method `json:"request_methods"`
|
||||
NotificationTypes []NotificationType `json:"notification_types,omitempty"`
|
||||
MaxAmount *int64 `json:"max_amount,omitempty"`
|
||||
BudgetRenewal *BudgetRenewalPeriod `json:"budget_renewal,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
Isolated *bool `json:"isolated,omitempty"`
|
||||
Metadata any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// CreateConnectionResponse represents a response to a create_connection request
|
||||
type CreateConnectionResponse struct {
|
||||
WalletPubkey string `json:"wallet_pubkey"`
|
||||
}
|
||||
|
||||
// SignMessageResponse represents a response to a sign_message request
|
||||
type SignMessageResponse struct {
|
||||
type SignMessageResult struct {
|
||||
Message string `json:"message"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
|
||||
// TimeoutValues represents timeout values for NIP-47 requests
|
||||
type TimeoutValues struct {
|
||||
ReplyTimeout *int64 `json:"replyTimeout,omitempty"`
|
||||
PublishTimeout *int64 `json:"publishTimeout,omitempty"`
|
||||
type CreateConnectionParams struct {
|
||||
Pubkey string `json:"pubkey"`
|
||||
Name string `json:"name"`
|
||||
RequestMethods []string `json:"request_methods"`
|
||||
NotificationTypes []string `json:"notification_types"`
|
||||
MaxAmount *uint64 `json:"max_amount,omitempty"`
|
||||
BudgetRenewal *string `json:"budget_renewal,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
57
pkg/protocol/nwc/uri.go
Normal file
57
pkg/protocol/nwc/uri.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package nwc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
"orly.dev/pkg/crypto/p256k"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
)
|
||||
|
||||
type ConnectionParams struct {
|
||||
clientSecretKey []byte
|
||||
walletPublicKey []byte
|
||||
relay string
|
||||
}
|
||||
|
||||
// GetWalletPublicKey returns the wallet public key from the ConnectionParams.
|
||||
func (c *ConnectionParams) GetWalletPublicKey() []byte {
|
||||
return c.walletPublicKey
|
||||
}
|
||||
|
||||
func ParseConnectionURI(nwcUri string) (parts *ConnectionParams, err error) {
|
||||
var p *url.URL
|
||||
if p, err = url.Parse(nwcUri); chk.E(err) {
|
||||
return
|
||||
}
|
||||
parts = &ConnectionParams{}
|
||||
if p.Scheme != "nostr+walletconnect" {
|
||||
err = errors.New("incorrect scheme")
|
||||
return
|
||||
}
|
||||
if parts.walletPublicKey, err = p256k.HexToBin(p.Host); chk.E(err) {
|
||||
err = errors.New("invalid public key")
|
||||
return
|
||||
}
|
||||
query := p.Query()
|
||||
var ok bool
|
||||
var relay []string
|
||||
if relay, ok = query["relay"]; !ok {
|
||||
err = errors.New("missing relay parameter")
|
||||
return
|
||||
}
|
||||
if len(relay) == 0 {
|
||||
return nil, errors.New("no relays")
|
||||
}
|
||||
parts.relay = relay[0]
|
||||
var secret string
|
||||
if secret = query.Get("secret"); secret == "" {
|
||||
err = errors.New("missing secret parameter")
|
||||
return
|
||||
}
|
||||
if parts.clientSecretKey, err = p256k.HexToBin(secret); chk.E(err) {
|
||||
err = errors.New("invalid secret")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func (x *Operations) RegisterEvent(api huma.API) {
|
||||
err = huma.Error403Forbidden(fmt.Sprintf("Too many failed authentication attempts. Blocked until %s", blockedUntil.Format(time.RFC3339)))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
authed, pubkey, super = x.UserAuth(r, remote)
|
||||
if !authed {
|
||||
// Record the failed authentication attempt
|
||||
|
||||
@@ -580,7 +580,7 @@ type EventsInput struct {
|
||||
}
|
||||
|
||||
type EventsOutput struct {
|
||||
Body []event.J
|
||||
Body []*event.J
|
||||
}
|
||||
|
||||
// RegisterEvents is the implementation of the HTTP API Events method.
|
||||
@@ -667,11 +667,17 @@ Returns events as a JSON array of event objects.`
|
||||
}
|
||||
tmp = append(tmp, ev)
|
||||
}
|
||||
// cap the number of events to 512 to stop excessively large
|
||||
// response.
|
||||
if len(events) > 512 {
|
||||
break
|
||||
}
|
||||
events = tmp
|
||||
}
|
||||
}
|
||||
output = &EventsOutput{}
|
||||
for _, ev := range events {
|
||||
_ = ev
|
||||
output.Body = append(output.Body, ev.ToEventJ())
|
||||
}
|
||||
return
|
||||
},
|
||||
|
||||
@@ -118,7 +118,7 @@ func (p *Publisher) Receive(msg typer.T) {
|
||||
Receiver: m.Receiver,
|
||||
Pubkey: m.Pubkey,
|
||||
}
|
||||
|
||||
|
||||
// Add the filters if provided
|
||||
if m.FilterMap != nil {
|
||||
for id, f := range m.FilterMap {
|
||||
@@ -126,7 +126,7 @@ func (p *Publisher) Receive(msg typer.T) {
|
||||
log.T.F("added subscription %s for new listener %s", id, m.Id)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Add the listener to the map
|
||||
p.ListenMap[m.Id] = listener
|
||||
log.T.F("added new listener %s", m.Id)
|
||||
@@ -254,7 +254,7 @@ func CheckListenerExists(clientId string, publishers ...publisher.I) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Check if the publisher has a Publishers field of type publisher.Publishers
|
||||
// This handles the case where the publisher is a *publish.S
|
||||
val := reflect.ValueOf(p)
|
||||
@@ -290,7 +290,7 @@ func CheckSubscriptionExists(clientId string, subscriptionId string, publishers
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Check if the publisher has a Publishers field of type publisher.Publishers
|
||||
// This handles the case where the publisher is a *publish.S
|
||||
val := reflect.ValueOf(p)
|
||||
|
||||
@@ -72,7 +72,7 @@ func (a *A) HandleAuth(b []byte, srv server.I) (msg []byte) {
|
||||
env.Event.Pubkey,
|
||||
)
|
||||
a.Listener.SetAuthedPubkey(env.Event.Pubkey)
|
||||
|
||||
|
||||
// If authentication is successful, remove any blocks for this IP
|
||||
iptracker.Global.Authenticate(a.Listener.RealRemote())
|
||||
}
|
||||
|
||||
@@ -75,43 +75,43 @@ func (a *A) HandleEvent(
|
||||
if a.I.AuthRequired() && !a.Listener.IsAuthed() {
|
||||
remoteIP := a.Listener.RealRemote()
|
||||
log.I.F("requesting auth from client from %s", remoteIP)
|
||||
|
||||
|
||||
// Check if the IP is blocked due to too many failed auth attempts
|
||||
if iptracker.Global.IsBlocked(remoteIP) {
|
||||
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
|
||||
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
|
||||
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
|
||||
blockedUntil.Format(time.RFC3339))
|
||||
|
||||
|
||||
// Send a notice to the client explaining why they're blocked
|
||||
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
|
||||
err = nil
|
||||
}
|
||||
|
||||
|
||||
// Close the connection
|
||||
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
|
||||
a.Listener.Close()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Record the failed authentication attempt
|
||||
blocked := iptracker.Global.RecordFailedAttempt(remoteIP)
|
||||
if blocked {
|
||||
// If this attempt caused the IP to be blocked, close the connection
|
||||
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
|
||||
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
|
||||
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
|
||||
blockedUntil.Format(time.RFC3339))
|
||||
|
||||
|
||||
// Send a notice to the client explaining why they're blocked
|
||||
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
|
||||
err = nil
|
||||
}
|
||||
|
||||
|
||||
// Close the connection
|
||||
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
|
||||
a.Listener.Close()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Continue with normal auth flow for non-blocked IPs
|
||||
a.Listener.RequestAuth()
|
||||
if err = Ok.AuthRequired(a, env.E, "auth required"); chk.E(err) {
|
||||
|
||||
@@ -57,18 +57,18 @@ type A struct {
|
||||
func (a *A) Serve(w http.ResponseWriter, r *http.Request, s server.I) {
|
||||
c := a.Config()
|
||||
remote := helpers.GetRemoteFromReq(r)
|
||||
|
||||
|
||||
// Check if the IP is blocked due to too many failed auth attempts
|
||||
if iptracker.Global.IsBlocked(remote) {
|
||||
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
|
||||
log.I.F("rejecting websocket connection from banned IP %s (blocked until %s)",
|
||||
log.I.F("rejecting websocket connection from banned IP %s (blocked until %s)",
|
||||
remote, blockedUntil.Format(time.RFC3339))
|
||||
|
||||
// We can't send a notice to the client here because the websocket connection
|
||||
|
||||
// We can't send a notice to the client here because the websocket connection
|
||||
// hasn't been established yet, so we just reject the connection
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
var whitelisted bool
|
||||
if len(c.Whitelist) > 0 {
|
||||
for _, addr := range c.Whitelist {
|
||||
|
||||
@@ -3,11 +3,19 @@ package ws
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"orly.dev/pkg/encoders/envelopes"
|
||||
"orly.dev/pkg/encoders/envelopes/authenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/closedenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/countenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/eoseenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/eventenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/noticeenvelope"
|
||||
@@ -15,54 +23,43 @@ import (
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/filter"
|
||||
"orly.dev/pkg/encoders/filters"
|
||||
"orly.dev/pkg/encoders/hex"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/encoders/tags"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/interfaces/signer"
|
||||
"orly.dev/pkg/protocol/auth"
|
||||
"orly.dev/pkg/utils/atomic"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"orly.dev/pkg/utils/normalize"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
)
|
||||
|
||||
var subscriptionIDCounter atomic.Int32
|
||||
var subscriptionIDCounter atomic.Int64
|
||||
|
||||
// Relay represents a connection to a Nostr relay.
|
||||
type Client struct {
|
||||
closeMutex sync.Mutex
|
||||
|
||||
URL string
|
||||
URL string
|
||||
requestHeader http.Header // e.g. for origin header
|
||||
|
||||
RequestHeader http.Header // e.g. for origin header
|
||||
Connection *Connection
|
||||
Subscriptions *xsync.MapOf[int64, *Subscription]
|
||||
|
||||
Connection *Connection
|
||||
|
||||
Subscriptions *xsync.MapOf[string, *Subscription]
|
||||
|
||||
ConnectionError error
|
||||
|
||||
connectionContext context.T // will be canceled when the connection closes
|
||||
|
||||
connectionContextCancel context.F
|
||||
|
||||
challenge []byte // NIP-42 challenge, we only keep the last
|
||||
|
||||
notices chan []byte // NIP-01 NOTICEs
|
||||
|
||||
okCallbacks *xsync.MapOf[string, func(bool, string)]
|
||||
|
||||
writeQueue chan writeRequest
|
||||
ConnectionError error
|
||||
connectionContext context.T // will be canceled when the connection closes
|
||||
connectionContextCancel context.C
|
||||
|
||||
challenge []byte // NIP-42 challenge, we only keep the last
|
||||
noticeHandler func(string) // NIP-01 NOTICEs
|
||||
customHandler func(string) // nonstandard unparseable messages
|
||||
okCallbacks *xsync.MapOf[string, func(bool, string)]
|
||||
writeQueue chan writeRequest
|
||||
subscriptionChannelCloseQueue chan *Subscription
|
||||
|
||||
signatureChecker func(*event.E) bool
|
||||
|
||||
// custom things that aren't often used
|
||||
//
|
||||
AssumeValid bool // this will skip verifying signatures for events received from this relay
|
||||
}
|
||||
|
||||
@@ -71,21 +68,20 @@ type writeRequest struct {
|
||||
answer chan error
|
||||
}
|
||||
|
||||
// NewRelay returns a new relay. The relay connection will be closed when the
|
||||
// context is cancelled.
|
||||
func NewRelay(c context.T, url string, opts ...RelayOption) *Client {
|
||||
ctx, cancel := context.Cancel(c)
|
||||
// NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection.
|
||||
func NewRelay(ctx context.T, url string, opts ...RelayOption) *Client {
|
||||
ctx, cancel := context.Cause(ctx)
|
||||
r := &Client{
|
||||
URL: string(normalize.URL([]byte(url))),
|
||||
URL: string(normalize.URL(url)),
|
||||
connectionContext: ctx,
|
||||
connectionContextCancel: cancel,
|
||||
Subscriptions: xsync.NewMapOf[string, *Subscription](),
|
||||
Subscriptions: xsync.NewMapOf[int64, *Subscription](),
|
||||
okCallbacks: xsync.NewMapOf[string, func(
|
||||
bool, string,
|
||||
)](),
|
||||
writeQueue: make(chan writeRequest),
|
||||
subscriptionChannelCloseQueue: make(chan *Subscription),
|
||||
signatureChecker: func(e *event.E) bool { ok, _ := e.Verify(); return ok },
|
||||
requestHeader: nil,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -95,9 +91,12 @@ func NewRelay(c context.T, url string, opts ...RelayOption) *Client {
|
||||
return r
|
||||
}
|
||||
|
||||
// RelayConnect returns a relay object connected to url. Once successfully
|
||||
// connected, cancelling ctx has no effect. To close the connection, call
|
||||
// r.Close().
|
||||
// RelayConnect returns a relay object connected to url.
|
||||
//
|
||||
// The given subscription is only used during the connection phase. Once successfully connected, cancelling ctx has no effect.
|
||||
//
|
||||
// The ongoing relay connection uses a background context. To close the connection, call r.Close().
|
||||
// If you need fine grained long-term connection contexts, use NewRelay() instead.
|
||||
func RelayConnect(ctx context.T, url string, opts ...RelayOption) (
|
||||
*Client, error,
|
||||
) {
|
||||
@@ -106,35 +105,38 @@ func RelayConnect(ctx context.T, url string, opts ...RelayOption) (
|
||||
return r, err
|
||||
}
|
||||
|
||||
// RelayOption is the type of the argument passed for that.
|
||||
// RelayOption is the type of the argument passed when instantiating relay connections.
|
||||
type RelayOption interface {
|
||||
ApplyRelayOption(*Client)
|
||||
}
|
||||
|
||||
var (
|
||||
_ RelayOption = (WithNoticeHandler)(nil)
|
||||
_ RelayOption = (WithSignatureChecker)(nil)
|
||||
_ RelayOption = (WithCustomHandler)(nil)
|
||||
_ RelayOption = (WithRequestHeader)(nil)
|
||||
)
|
||||
|
||||
// WithNoticeHandler just takes notices and is expected to do something with
|
||||
// them. when not given, defaults to logging the notices.
|
||||
type WithNoticeHandler func(notice []byte)
|
||||
// WithNoticeHandler just takes notices and is expected to do something with them.
|
||||
// when not given, defaults to logging the notices.
|
||||
type WithNoticeHandler func(notice string)
|
||||
|
||||
func (nh WithNoticeHandler) ApplyRelayOption(r *Client) {
|
||||
r.notices = make(chan []byte)
|
||||
go func() {
|
||||
for notice := range r.notices {
|
||||
nh(notice)
|
||||
}
|
||||
}()
|
||||
r.noticeHandler = nh
|
||||
}
|
||||
|
||||
// WithSignatureChecker must be a function that checks the signature of an event
|
||||
// and returns true or false.
|
||||
type WithSignatureChecker func(*event.E) bool
|
||||
// WithCustomHandler must be a function that handles any relay message that couldn't be
|
||||
// parsed as a standard envelope.
|
||||
type WithCustomHandler func(data string)
|
||||
|
||||
func (sc WithSignatureChecker) ApplyRelayOption(r *Client) {
|
||||
r.signatureChecker = sc
|
||||
func (ch WithCustomHandler) ApplyRelayOption(r *Client) {
|
||||
r.customHandler = ch
|
||||
}
|
||||
|
||||
// WithRequestHeader sets the HTTP request header of the websocket preflight request.
|
||||
type WithRequestHeader http.Header
|
||||
|
||||
func (ch WithRequestHeader) ApplyRelayOption(r *Client) {
|
||||
r.requestHeader = http.Header(ch)
|
||||
}
|
||||
|
||||
// String just returns the relay URL.
|
||||
@@ -143,253 +145,269 @@ func (r *Client) String() string {
|
||||
}
|
||||
|
||||
// Context retrieves the context that is associated with this relay connection.
|
||||
// It will be closed when the relay is disconnected.
|
||||
func (r *Client) Context() context.T { return r.connectionContext }
|
||||
|
||||
// IsConnected returns true if the connection to this relay seems to be active.
|
||||
func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil }
|
||||
|
||||
// Connect tries to establish a websocket connection to r.URL. If the context
|
||||
// expires before the connection is complete, an error is returned. Once
|
||||
// successfully connected, context expiration has no effect: call r.Close to
|
||||
// close the connection.
|
||||
// Connect tries to establish a websocket connection to r.URL.
|
||||
// If the context expires before the connection is complete, an error is returned.
|
||||
// Once successfully connected, context expiration has no effect: call r.Close
|
||||
// to close the connection.
|
||||
//
|
||||
// The underlying relay connection will use a background context. If you want to
|
||||
// pass a custom context to the underlying relay connection, use NewRelay() and
|
||||
// then Client.Connect().
|
||||
func (r *Client) Connect(c context.T) error { return r.ConnectWithTLS(c, nil) }
|
||||
// The given context here is only used during the connection phase. The long-living
|
||||
// relay connection will be based on the context given to NewRelay().
|
||||
func (r *Client) Connect(ctx context.T) error {
|
||||
return r.ConnectWithTLS(ctx, nil)
|
||||
}
|
||||
|
||||
// ConnectWithTLS tries to establish a secured websocket connection to r.URL
|
||||
// using customized tls.Config (CA's, etc.).
|
||||
func (r *Client) ConnectWithTLS(ctx context.T, tlsConfig *tls.Config) error {
|
||||
func subIdToSerial(subId string) int64 {
|
||||
n := strings.Index(subId, ":")
|
||||
if n < 0 || n > len(subId) {
|
||||
return -1
|
||||
}
|
||||
serialId, _ := strconv.ParseInt(subId[0:n], 10, 64)
|
||||
return serialId
|
||||
}
|
||||
|
||||
// ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that.
|
||||
func (r *Client) ConnectWithTLS(
|
||||
ctx context.T, tlsConfig *tls.Config,
|
||||
) (err error) {
|
||||
if r.connectionContext == nil || r.Subscriptions == nil {
|
||||
return errorf.E("relay must be initialized with a call to NewRelay()")
|
||||
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
|
||||
}
|
||||
|
||||
if r.URL == "" {
|
||||
return errorf.E("invalid relay URL '%s'", r.URL)
|
||||
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
||||
}
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.F
|
||||
ctx, cancel = context.Timeout(ctx, 7*time.Second)
|
||||
ctx, cancel = context.TimeoutCause(
|
||||
ctx, 7*time.Second, errors.New("connection took too long"),
|
||||
)
|
||||
defer cancel()
|
||||
}
|
||||
conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig)
|
||||
if err != nil {
|
||||
return errorf.E(
|
||||
"error opening websocket to '%s': %s", r.URL, err.Error(),
|
||||
)
|
||||
var conn *Connection
|
||||
if conn, err = NewConnection(
|
||||
ctx, r.URL, r.requestHeader, tlsConfig,
|
||||
); chk.E(err) {
|
||||
err = fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
|
||||
return
|
||||
}
|
||||
r.Connection = conn
|
||||
// ping every 29 seconds (??)
|
||||
// ping every 29 seconds
|
||||
ticker := time.NewTicker(29 * time.Second)
|
||||
// to be used when the connection is closed
|
||||
go func() {
|
||||
<-r.connectionContext.Done()
|
||||
// close these things when the connection is closed
|
||||
if r.notices != nil {
|
||||
close(r.notices)
|
||||
}
|
||||
// stop the ticker
|
||||
ticker.Stop()
|
||||
// close all subscriptions
|
||||
r.Subscriptions.Range(
|
||||
func(_ string, sub *Subscription) bool {
|
||||
go sub.Unsub()
|
||||
return true
|
||||
},
|
||||
)
|
||||
}()
|
||||
// queue all write operations here so we don't do mutex spaghetti
|
||||
go func() {
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
ticker.Stop()
|
||||
r.Connection = nil
|
||||
for _, sub := range r.Subscriptions.Range {
|
||||
sub.unsub(
|
||||
fmt.Errorf(
|
||||
"relay connection closed: %w / %w",
|
||||
context.GetCause(r.connectionContext),
|
||||
r.ConnectionError,
|
||||
),
|
||||
)
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
err = wsutil.WriteClientMessage(
|
||||
r.Connection.conn, ws.OpPing, nil,
|
||||
)
|
||||
if err != nil {
|
||||
log.D.F(
|
||||
err := r.Connection.Ping(r.connectionContext)
|
||||
if err != nil && !strings.Contains(
|
||||
err.Error(), "failed to wait for pong",
|
||||
) {
|
||||
log.I.F(
|
||||
"{%s} error writing ping: %v; closing websocket", r.URL,
|
||||
err,
|
||||
)
|
||||
r.Close() // this should trigger a context cancelation
|
||||
return
|
||||
}
|
||||
case writeReq := <-r.writeQueue:
|
||||
case writeRequest := <-r.writeQueue:
|
||||
// all write requests will go through this to prevent races
|
||||
if err = r.Connection.WriteMessage(
|
||||
r.connectionContext,
|
||||
writeReq.msg,
|
||||
); chk.T(err) {
|
||||
writeReq.answer <- err
|
||||
log.D.F("{%s} sending %v\n", r.URL, string(writeRequest.msg))
|
||||
if err := r.Connection.WriteMessage(
|
||||
r.connectionContext, writeRequest.msg,
|
||||
); err != nil {
|
||||
writeRequest.answer <- err
|
||||
}
|
||||
close(writeReq.answer)
|
||||
case <-r.connectionContext.Done():
|
||||
// stop here
|
||||
return
|
||||
close(writeRequest.answer)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// general message reader loop
|
||||
go func() {
|
||||
var err error
|
||||
|
||||
for {
|
||||
buf := new(bytes.Buffer)
|
||||
if err = conn.ReadMessage(r.connectionContext, buf); err != nil {
|
||||
if err := conn.ReadMessage(r.connectionContext, buf); err != nil {
|
||||
r.ConnectionError = err
|
||||
r.Close()
|
||||
r.close(err)
|
||||
break
|
||||
}
|
||||
message := buf.Bytes()
|
||||
// log.D.F("{%s} %s\n", r.URL, message)
|
||||
|
||||
var err error
|
||||
var t string
|
||||
if t, message, err = envelopes.Identify(message); chk.E(err) {
|
||||
var rem []byte
|
||||
if t, rem, err = envelopes.Identify(buf.Bytes()); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case noticeenvelope.L:
|
||||
env := noticeenvelope.New()
|
||||
if env, message, err = noticeenvelope.Parse(message); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
env := noticeenvelope.NewFrom(rem)
|
||||
// see WithNoticeHandler
|
||||
if r.notices != nil {
|
||||
r.notices <- env.Message
|
||||
if r.noticeHandler != nil {
|
||||
r.noticeHandler(string(env.Message))
|
||||
} else {
|
||||
log.E.F("NOTICE from %s: '%s'\n", r.URL, env.Message)
|
||||
log.D.F(
|
||||
"NOTICE from %s: '%s'\n", r.URL, string(env.Message),
|
||||
)
|
||||
}
|
||||
case authenvelope.L:
|
||||
env := authenvelope.NewChallenge()
|
||||
if env, message, err = authenvelope.ParseChallenge(message); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
if len(env.Challenge) == 0 {
|
||||
env := authenvelope.NewChallengeWith(rem)
|
||||
if env.Challenge == nil {
|
||||
continue
|
||||
}
|
||||
r.challenge = env.Challenge
|
||||
case eventenvelope.L:
|
||||
// log.I.F("message: %s", message)
|
||||
env := eventenvelope.NewResult()
|
||||
if env, message, err = eventenvelope.ParseResult(message); err != nil {
|
||||
var env *eventenvelope.Result
|
||||
env = eventenvelope.NewResult()
|
||||
if _, err = env.Unmarshal(rem); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
// log.I.F("%s", env.Event.Marshal(nil))
|
||||
if len(env.Subscription.T) == 0 {
|
||||
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
|
||||
if !ok {
|
||||
log.W.F(
|
||||
"unknown subscription with id '%s'\n",
|
||||
env.Subscription.String(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if sub, ok := r.Subscriptions.Load(env.Subscription.String()); !ok {
|
||||
// log.D.F(
|
||||
// "{%s} no subscription with id '%s'\n", r.URL,
|
||||
// env.Subscription,
|
||||
// )
|
||||
if !sub.Filters.Match(env.Event) {
|
||||
log.I.F(
|
||||
"{%s} filter does not match: %v ~ %s\n", r.URL,
|
||||
sub.Filters, env.Event.Marshal(nil),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
// check if the event matches the desired filter, ignore
|
||||
// otherwise
|
||||
if !sub.Filters.Match(env.Event) {
|
||||
log.D.F(
|
||||
"{%s} filter does not match: %v ~ %v\n", r.URL,
|
||||
sub.Filters, env.Event,
|
||||
}
|
||||
if !r.AssumeValid {
|
||||
if ok, err = env.Event.Verify(); !ok || chk.E(err) {
|
||||
log.I.F(
|
||||
"{%s} bad signature on %s\n", r.URL,
|
||||
env.Event.ID,
|
||||
)
|
||||
continue
|
||||
}
|
||||
// check signature, ignore invalid, except from trusted
|
||||
// (AssumeValid) relays
|
||||
if !r.AssumeValid {
|
||||
if ok = r.signatureChecker(env.Event); !ok {
|
||||
log.E.F(
|
||||
"{%s} bad signature on %s\n", r.URL,
|
||||
env.Event.ID,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// dispatch this to the internal .events channel of the
|
||||
// subscription
|
||||
sub.dispatchEvent(env.Event)
|
||||
}
|
||||
sub.dispatchEvent(env.Event)
|
||||
case eoseenvelope.L:
|
||||
env := eoseenvelope.New()
|
||||
if env, message, err = eoseenvelope.Parse(message); chk.E(err) {
|
||||
var env *eoseenvelope.T
|
||||
if env, rem, err = eoseenvelope.Parse(rem); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok {
|
||||
subscription.dispatchEose()
|
||||
if len(rem) != 0 {
|
||||
log.W.F(
|
||||
"{%s} unexpected data after EOSE: %s\n", r.URL,
|
||||
string(rem),
|
||||
)
|
||||
}
|
||||
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
|
||||
if !ok {
|
||||
log.W.F(
|
||||
"unknown subscription with id '%s'\n",
|
||||
env.Subscription.String(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
sub.dispatchEose()
|
||||
case closedenvelope.L:
|
||||
env := closedenvelope.New()
|
||||
if env, message, err = closedenvelope.Parse(message); chk.E(err) {
|
||||
var env *closedenvelope.T
|
||||
if env, rem, err = closedenvelope.Parse(rem); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok {
|
||||
subscription.dispatchClosed(env.ReasonString())
|
||||
}
|
||||
case countenvelope.L:
|
||||
env := countenvelope.NewResponse()
|
||||
if env, message, err = countenvelope.Parse(message); chk.E(err) {
|
||||
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
|
||||
if !ok {
|
||||
log.W.F(
|
||||
"unknown subscription with id '%s'\n",
|
||||
env.Subscription.String(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if subscription, ok := r.Subscriptions.Load(env.ID.String()); ok && subscription.countResult != nil {
|
||||
subscription.countResult <- env.Count
|
||||
}
|
||||
sub.handleClosed(env.ReasonString())
|
||||
case okenvelope.L:
|
||||
env := okenvelope.New()
|
||||
if env, message, err = okenvelope.Parse(message); chk.E(err) {
|
||||
var env *okenvelope.T
|
||||
if env, rem, err = okenvelope.Parse(rem); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID.String()); exist {
|
||||
okCallback(env.OK, env.ReasonString())
|
||||
okCallback(env.OK, string(env.Reason))
|
||||
} else {
|
||||
log.I.F(
|
||||
"{%s} got an unexpected OK message for event %s", r.URL,
|
||||
env.EventID,
|
||||
)
|
||||
}
|
||||
default:
|
||||
log.W.F("unknown envelope type %s\n%s", t, rem)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// Write queues a message to be sent to the relay.
|
||||
// Write queues an arbitrary message to be sent to the relay.
|
||||
func (r *Client) Write(msg []byte) <-chan error {
|
||||
ch := make(chan error)
|
||||
select {
|
||||
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
||||
case <-r.connectionContext.Done():
|
||||
go func() { ch <- errorf.E("connection closed") }()
|
||||
go func() { ch <- fmt.Errorf("connection closed") }()
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an
|
||||
// OK response.
|
||||
func (r *Client) Publish(c context.T, ev *event.E) error {
|
||||
return r.publish(
|
||||
c, ev,
|
||||
)
|
||||
func (r *Client) Publish(ctx context.T, ev *event.E) error {
|
||||
return r.publish(ctx, ev.ID, ev)
|
||||
}
|
||||
|
||||
// Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK
|
||||
// response.
|
||||
func (r *Client) Auth(c context.T, sign signer.I) error {
|
||||
authEvent := auth.CreateUnsigned(sign.Pub(), r.challenge, r.URL)
|
||||
if err := authEvent.Sign(sign); chk.T(err) {
|
||||
return errorf.E("error signing auth event: %w", err)
|
||||
func (r *Client) Auth(
|
||||
ctx context.T, sign signer.I,
|
||||
) (err error) {
|
||||
authEvent := &event.E{
|
||||
CreatedAt: timestamp.Now(),
|
||||
Kind: kind.ClientAuthentication,
|
||||
Tags: tags.New(
|
||||
tag.New("relay", r.URL),
|
||||
tag.New([]byte("challenge"), r.challenge),
|
||||
),
|
||||
}
|
||||
return r.publish(c, authEvent)
|
||||
if err = authEvent.Sign(sign); chk.E(err) {
|
||||
err = fmt.Errorf("error signing auth event: %w", err)
|
||||
return
|
||||
}
|
||||
return r.publish(ctx, authEvent.ID, authEvent)
|
||||
}
|
||||
|
||||
// publish can be used both for EVENT and for AUTH
|
||||
func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
|
||||
func (r *Client) publish(
|
||||
ctx context.T, id []byte, ev *event.E,
|
||||
) error {
|
||||
var err error
|
||||
var cancel context.F
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
ctx, cancel = context.TimeoutCause(
|
||||
ctx, 7*time.Second,
|
||||
errorf.E("given up waiting for an OK"),
|
||||
ctx, 7*time.Second, fmt.Errorf("given up waiting for an OK"),
|
||||
)
|
||||
defer cancel()
|
||||
} else {
|
||||
@@ -398,32 +416,24 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
|
||||
ctx, cancel = context.Cancel(ctx)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// listen for an OK callback
|
||||
gotOk := false
|
||||
id := ev.IdString()
|
||||
ids := hex.Enc(id)
|
||||
r.okCallbacks.Store(
|
||||
id, func(ok bool, reason string) {
|
||||
ids, func(ok bool, reason string) {
|
||||
gotOk = true
|
||||
if !ok {
|
||||
err = errorf.E("msg: %s", reason)
|
||||
err = fmt.Errorf("msg: %s", reason)
|
||||
}
|
||||
cancel()
|
||||
},
|
||||
)
|
||||
defer r.okCallbacks.Delete(id)
|
||||
defer r.okCallbacks.Delete(ids)
|
||||
// publish event
|
||||
var b []byte
|
||||
if ev.Kind.Equal(kind.ClientAuthentication) {
|
||||
if b = authenvelope.NewResponseWith(ev).Marshal(b); chk.E(err) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if b = eventenvelope.NewSubmissionWith(ev).Marshal(b); chk.E(err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
log.T.F("{%s} sending %s\n", r.URL, b)
|
||||
if err = <-r.Write(b); chk.T(err) {
|
||||
envb := eventenvelope.NewSubmissionWith(ev).Marshal(nil)
|
||||
// envb := ev.Marshal(nil)
|
||||
if err = <-r.Write(envb); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
@@ -447,111 +457,168 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
|
||||
// context ctx is cancelled ("CLOSE" in NIP-01).
|
||||
//
|
||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or
|
||||
// ensuring their `context.Context` will be canceled at some point. Failure to
|
||||
// ensuring their `context.T` will be canceled at some point. Failure to
|
||||
// do that will result in a huge number of halted goroutines being created.
|
||||
func (r *Client) Subscribe(
|
||||
c context.T, ff *filters.T,
|
||||
opts ...SubscriptionOption,
|
||||
ctx context.T, ff *filters.T, opts ...SubscriptionOption,
|
||||
) (sub *Subscription, err error) {
|
||||
sub = r.PrepareSubscription(c, ff, opts...)
|
||||
sub = r.PrepareSubscription(ctx, ff, opts...)
|
||||
if r.Connection == nil {
|
||||
return nil, errorf.E("not connected to %s", r.URL)
|
||||
return nil, fmt.Errorf("not connected to %s", r.URL)
|
||||
}
|
||||
if err = sub.Fire(); chk.T(err) {
|
||||
return nil, errorf.E(
|
||||
"couldn't subscribe to %v at %s: %w", ff, r.URL, err,
|
||||
if err = sub.Fire(); err != nil {
|
||||
err = fmt.Errorf(
|
||||
"couldn't subscribe to %v at %s: %w", ff.Marshal(nil), r.URL, err,
|
||||
)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PrepareSubscription creates a subscription, but doesn't fire it.
|
||||
//
|
||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or
|
||||
// ensuring their `context.Context` will be canceled at some point. Failure to
|
||||
// do that will result in a huge number of halted goroutines being created.
|
||||
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.T` will be canceled at some point.
|
||||
// Failure to do that will result in a huge number of halted goroutines being created.
|
||||
func (r *Client) PrepareSubscription(
|
||||
c context.T, ff *filters.T,
|
||||
opts ...SubscriptionOption,
|
||||
ctx context.T, ff *filters.T, opts ...SubscriptionOption,
|
||||
) *Subscription {
|
||||
current := subscriptionIDCounter.Add(1)
|
||||
c, cancel := context.Cancel(c)
|
||||
ctx, cancel := context.Cause(ctx)
|
||||
sub := &Subscription{
|
||||
Relay: r,
|
||||
Context: c,
|
||||
Context: ctx,
|
||||
cancel: cancel,
|
||||
counter: int(current),
|
||||
counter: current,
|
||||
Events: make(event.C),
|
||||
EndOfStoredEvents: make(chan struct{}, 1),
|
||||
ClosedReason: make(chan string, 1),
|
||||
Filters: ff,
|
||||
match: ff.Match,
|
||||
}
|
||||
label := ""
|
||||
for _, opt := range opts {
|
||||
switch o := opt.(type) {
|
||||
case WithLabel:
|
||||
sub.label = string(o)
|
||||
label = string(o)
|
||||
// case WithCheckDuplicate:
|
||||
// sub.checkDuplicate = o
|
||||
// case WithCheckDuplicateReplaceable:
|
||||
// sub.checkDuplicateReplaceable = o
|
||||
}
|
||||
}
|
||||
id := sub.GetID()
|
||||
r.Subscriptions.Store(id.String(), sub)
|
||||
// subscription id computation
|
||||
buf := subIdPool.Get().([]byte)[:0]
|
||||
buf = strconv.AppendInt(buf, sub.counter, 10)
|
||||
buf = append(buf, ':')
|
||||
buf = append(buf, label...)
|
||||
defer subIdPool.Put(buf)
|
||||
sub.id = string(buf)
|
||||
// we track subscriptions only by their counter, no need for the full id
|
||||
r.Subscriptions.Store(int64(sub.counter), sub)
|
||||
// start handling events, eose, unsub etc:
|
||||
go sub.start()
|
||||
return sub
|
||||
}
|
||||
|
||||
// QuerySync is only used in tests. The relay query method is synchronous now
|
||||
// anyway (it ensures sort order is respected).
|
||||
func (r *Client) QuerySync(
|
||||
ctx context.T, f *filter.F,
|
||||
opts ...SubscriptionOption,
|
||||
) ([]*event.E, error) {
|
||||
// log.T.F("QuerySync:\n%s", f.Marshal(nil))
|
||||
sub, err := r.Subscribe(ctx, filters.New(f), opts...)
|
||||
// QueryEvents subscribes to events matching the given filter and returns a channel of events.
|
||||
//
|
||||
// In most cases it's better to use Pool instead of this method.
|
||||
func (r *Client) QueryEvents(ctx context.T, f *filter.F) (
|
||||
evc event.C, err error,
|
||||
) {
|
||||
var sub *Subscription
|
||||
if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-sub.ClosedReason:
|
||||
case <-sub.EndOfStoredEvents:
|
||||
case <-ctx.Done():
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
sub.unsub(errors.New("QueryEvents() ended"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
return sub.Events, nil
|
||||
}
|
||||
|
||||
// QuerySync subscribes to events matching the given filter and returns a slice
|
||||
// of events. This method blocks until all events are received or the context is
|
||||
// canceled.
|
||||
//
|
||||
// If the filter causes a subscription to open, it will stay open until the
|
||||
// limit is exceeded. So this method will return an error if the limit is nil.
|
||||
// If the query blocks, the caller needs to cancel the context to prevent the
|
||||
// thread stalling.
|
||||
func (r *Client) QuerySync(ctx context.T, f *filter.F) (
|
||||
evs event.S, err error,
|
||||
) {
|
||||
if f.Limit == nil {
|
||||
err = errors.New("limit must be set for a sync query to prevent blocking")
|
||||
return
|
||||
}
|
||||
var sub *Subscription
|
||||
if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) {
|
||||
return
|
||||
}
|
||||
defer sub.unsub(errors.New("QuerySync() ended"))
|
||||
evs = make(event.S, 0, *f.Limit)
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.F
|
||||
ctx, cancel = context.TimeoutCause(
|
||||
ctx, 7*time.Second, errors.New("QuerySync() took too long"),
|
||||
)
|
||||
defer cancel()
|
||||
}
|
||||
lim := 250
|
||||
if f.Limit != nil {
|
||||
lim = int(*f.Limit)
|
||||
}
|
||||
events := make(event.S, 0, max(lim, 250))
|
||||
ch, err := r.QueryEvents(ctx, f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer sub.Unsub()
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.F
|
||||
ctx, cancel = context.Timeout(ctx, 7*time.Second)
|
||||
defer cancel()
|
||||
for evt := range ch {
|
||||
events = append(events, evt)
|
||||
}
|
||||
|
||||
var events []*event.E
|
||||
for {
|
||||
select {
|
||||
case evt := <-sub.Events:
|
||||
if evt == nil {
|
||||
// channel is closed
|
||||
return events, nil
|
||||
}
|
||||
events = append(events, evt)
|
||||
case <-sub.EndOfStoredEvents:
|
||||
return events, nil
|
||||
case <-ctx.Done():
|
||||
return events, nil
|
||||
}
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// TODO: count is a dumb idea anyway, and nothing is using this
|
||||
// func (r *Client) Count(c context.F, ff *filters.F, opts ...SubscriptionOption) (int, error) {
|
||||
// sub := r.PrepareSubscription(c, ff, opts...)
|
||||
// sub.countResult = make(chan int)
|
||||
//
|
||||
// if err := sub.Fire(); chk.F(err) {
|
||||
// return 0, err
|
||||
// // Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
|
||||
// func (r *Relay) Count(
|
||||
// ctx context.T,
|
||||
// filters Filters,
|
||||
// opts ...SubscriptionOption,
|
||||
// ) (int64, []byte, error) {
|
||||
// v, err := r.countInternal(ctx, filters, opts...)
|
||||
// if err != nil {
|
||||
// return 0, nil, err
|
||||
// }
|
||||
//
|
||||
// defer sub.Unsub()
|
||||
// return *v.Count, v.HyperLogLog, nil
|
||||
// }
|
||||
//
|
||||
// if _, ok := c.Deadline(); !ok {
|
||||
// func (r *Relay) countInternal(ctx context.T, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
|
||||
// sub := r.PrepareSubscription(ctx, filters, opts...)
|
||||
// sub.countResult = make(chan CountEnvelope)
|
||||
//
|
||||
// if err := sub.Fire(); err != nil {
|
||||
// return CountEnvelope{}, err
|
||||
// }
|
||||
//
|
||||
// defer sub.unsub(errors.New("countInternal() ended"))
|
||||
//
|
||||
// if _, ok := ctx.Deadline(); !ok {
|
||||
// // if no timeout is set, force it to 7 seconds
|
||||
// var cancel context.F
|
||||
// c, cancel = context.Timeout(c, 7*time.Second)
|
||||
// var cancel context.CancelFunc
|
||||
// ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("countInternal took too long"))
|
||||
// defer cancel()
|
||||
// }
|
||||
//
|
||||
@@ -559,28 +626,39 @@ func (r *Client) QuerySync(
|
||||
// select {
|
||||
// case count := <-sub.countResult:
|
||||
// return count, nil
|
||||
// case <-c.Done():
|
||||
// return 0, c.Err()
|
||||
// case <-ctx.Done():
|
||||
// return CountEnvelope{}, ctx.Err()
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// Close shuts down a websocket client connection.
|
||||
// Close closes the relay connection.
|
||||
func (r *Client) Close() error {
|
||||
return r.close(errors.New("relay connection closed"))
|
||||
}
|
||||
|
||||
func (r *Client) close(reason error) error {
|
||||
r.closeMutex.Lock()
|
||||
defer r.closeMutex.Unlock()
|
||||
|
||||
if r.connectionContextCancel == nil {
|
||||
return errorf.E("relay already closed")
|
||||
return fmt.Errorf("relay already closed")
|
||||
}
|
||||
r.connectionContextCancel()
|
||||
r.connectionContextCancel(reason)
|
||||
r.connectionContextCancel = nil
|
||||
|
||||
if r.Connection == nil {
|
||||
return errorf.E("relay not connected")
|
||||
return fmt.Errorf("relay not connected")
|
||||
}
|
||||
|
||||
err := r.Connection.Close()
|
||||
r.Connection = nil
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var subIdPool = sync.Pool{
|
||||
New: func() any { return make([]byte, 0, 15) },
|
||||
}
|
||||
|
||||
@@ -1,146 +1,159 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"orly.dev/pkg/crypto/p256k"
|
||||
"orly.dev/pkg/encoders/envelopes/eventenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/okenvelope"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/tag"
|
||||
"orly.dev/pkg/encoders/tags"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/normalize"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/crypto/p256k"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/normalize"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
// test note to be sent over websocket
|
||||
var err error
|
||||
signer := &p256k.Signer{}
|
||||
if err = signer.Generate(); chk.E(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
textNote := &event.E{
|
||||
Kind: kind.TextNote,
|
||||
Content: []byte("hello"),
|
||||
CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
|
||||
Tags: tags.New(tag.New("foo", "bar")),
|
||||
Pubkey: signer.Pub(),
|
||||
}
|
||||
if err = textNote.Sign(signer); chk.E(err) {
|
||||
t.Fatalf("textNote.Sign: %v", err)
|
||||
}
|
||||
// fake relay server
|
||||
var mu sync.Mutex // guards published to satisfy `go test -race`
|
||||
var published bool
|
||||
ws := newWebsocketServer(
|
||||
func(conn *websocket.Conn) {
|
||||
mu.Lock()
|
||||
published = true
|
||||
mu.Unlock()
|
||||
// verify the client sent exactly the textNote
|
||||
var raw []json.RawMessage
|
||||
if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
|
||||
t.Errorf("websocket.JSON.Receive: %v", err)
|
||||
}
|
||||
|
||||
if string(raw[0]) != fmt.Sprintf(`"%s"`, eventenvelope.L) {
|
||||
t.Errorf("got type %s, want %s", raw[0], eventenvelope.L)
|
||||
}
|
||||
env := eventenvelope.NewSubmission()
|
||||
if raw[1], err = env.Unmarshal(raw[1]); chk.E(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(env.E.Serialize(), textNote.Serialize()) {
|
||||
t.Errorf(
|
||||
"received event:\n%s\nwant:\n%s", env.E.Serialize(),
|
||||
textNote.Serialize(),
|
||||
)
|
||||
}
|
||||
// send back an ok nip-20 command result
|
||||
var res []byte
|
||||
if res = okenvelope.NewFrom(
|
||||
textNote.ID, true, nil,
|
||||
).Marshal(res); chk.E(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := websocket.Message.Send(conn, res); chk.T(err) {
|
||||
t.Errorf("websocket.JSON.Send: %v", err)
|
||||
}
|
||||
},
|
||||
)
|
||||
defer ws.Close()
|
||||
// connect a client and send the text note
|
||||
rl := mustRelayConnect(ws.URL)
|
||||
err = rl.Publish(context.Bg(), textNote)
|
||||
if err != nil {
|
||||
t.Errorf("publish should have succeeded")
|
||||
}
|
||||
if !published {
|
||||
t.Errorf("fake relay server saw no event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishBlocked(t *testing.T) {
|
||||
// test note to be sent over websocket
|
||||
var err error
|
||||
signer := &p256k.Signer{}
|
||||
if err = signer.Generate(); chk.E(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
textNote := &event.E{
|
||||
Kind: kind.TextNote,
|
||||
Content: []byte("hello"),
|
||||
CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
|
||||
Pubkey: signer.Pub(),
|
||||
}
|
||||
if err = textNote.Sign(signer); chk.E(err) {
|
||||
t.Fatalf("textNote.Sign: %v", err)
|
||||
}
|
||||
// fake relay server
|
||||
ws := newWebsocketServer(
|
||||
func(conn *websocket.Conn) {
|
||||
// discard received message; not interested
|
||||
var raw []json.RawMessage
|
||||
if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
|
||||
t.Errorf("websocket.JSON.Receive: %v", err)
|
||||
}
|
||||
// send back a not ok nip-20 command result
|
||||
var res []byte
|
||||
if res = okenvelope.NewFrom(
|
||||
textNote.ID, false,
|
||||
normalize.Msg(normalize.Blocked, "no reason"),
|
||||
).Marshal(res); chk.E(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := websocket.Message.Send(conn, res); chk.T(err) {
|
||||
t.Errorf("websocket.JSON.Send: %v", err)
|
||||
}
|
||||
// res := []any{"OK", textNote.ID, false, "blocked"}
|
||||
chk.E(websocket.JSON.Send(conn, res))
|
||||
},
|
||||
)
|
||||
defer ws.Close()
|
||||
|
||||
// connect a client and send a text note
|
||||
rl := mustRelayConnect(ws.URL)
|
||||
if err = rl.Publish(context.Bg(), textNote); !chk.E(err) {
|
||||
t.Errorf("should have failed to publish")
|
||||
}
|
||||
}
|
||||
// func TestPublish(t *testing.T) {
|
||||
// // test note to be sent over websocket
|
||||
// var err error
|
||||
// signer := &p256k.Signer{}
|
||||
// if err = signer.Generate(); chk.E(err) {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// textNote := &event.E{
|
||||
// Kind: kind.TextNote,
|
||||
// Content: []byte("hello"),
|
||||
// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
|
||||
// Pubkey: signer.Pub(),
|
||||
// }
|
||||
// if err = textNote.Sign(signer); chk.E(err) {
|
||||
// t.Fatalf("textNote.Sign: %v", err)
|
||||
// }
|
||||
// // fake relay server
|
||||
// var published bool
|
||||
// ws := newWebsocketServer(
|
||||
// func(conn *websocket.Conn) {
|
||||
// // receive message
|
||||
// var raw []json.RawMessage
|
||||
// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
|
||||
// t.Errorf("websocket.JSON.Receive: %v", err)
|
||||
// }
|
||||
// // check that it's an EVENT message
|
||||
// if len(raw) < 2 {
|
||||
// t.Errorf("message too short: %v", raw)
|
||||
// }
|
||||
// var msgType string
|
||||
// if err := json.Unmarshal(raw[0], &msgType); chk.T(err) {
|
||||
// t.Errorf("json.Unmarshal: %v", err)
|
||||
// }
|
||||
// if msgType != "EVENT" {
|
||||
// t.Errorf("expected EVENT message, got %q", msgType)
|
||||
// }
|
||||
// // check that the event is the one we sent
|
||||
// var ev event.E
|
||||
// if err := json.Unmarshal(raw[1], &ev); chk.T(err) {
|
||||
// t.Errorf("json.Unmarshal: %v", err)
|
||||
// }
|
||||
// published = true
|
||||
// if !bytes.Equal(ev.ID, textNote.ID) {
|
||||
// t.Errorf(
|
||||
// "event ID mismatch: got %x, want %x",
|
||||
// ev.ID, textNote.ID,
|
||||
// )
|
||||
// }
|
||||
// if !bytes.Equal(ev.Pubkey, textNote.Pubkey) {
|
||||
// t.Errorf(
|
||||
// "event pubkey mismatch: got %x, want %x",
|
||||
// ev.Pubkey, textNote.Pubkey,
|
||||
// )
|
||||
// }
|
||||
// if !bytes.Equal(ev.Content, textNote.Content) {
|
||||
// t.Errorf(
|
||||
// "event content mismatch: got %q, want %q",
|
||||
// ev.Content, textNote.Content,
|
||||
// )
|
||||
// }
|
||||
// fmt.Printf(
|
||||
// "received event: %s\n",
|
||||
// textNote.Serialize(),
|
||||
// )
|
||||
// // send back an ok nip-20 command result
|
||||
// var res []byte
|
||||
// if res = okenvelope.NewFrom(
|
||||
// textNote.ID, true, nil,
|
||||
// ).Marshal(res); chk.E(err) {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// if err := websocket.Message.Send(conn, res); chk.T(err) {
|
||||
// t.Errorf("websocket.Message.Send: %v", err)
|
||||
// }
|
||||
// },
|
||||
// )
|
||||
// defer ws.Close()
|
||||
// // connect a client and send the text note
|
||||
// rl := mustRelayConnect(ws.URL)
|
||||
// err = rl.Publish(context.Background(), textNote)
|
||||
// if err != nil {
|
||||
// t.Errorf("publish should have succeeded")
|
||||
// }
|
||||
// if !published {
|
||||
// t.Errorf("fake relay server saw no event")
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func TestPublishBlocked(t *testing.T) {
|
||||
// // test note to be sent over websocket
|
||||
// var err error
|
||||
// signer := &p256k.Signer{}
|
||||
// if err = signer.Generate(); chk.E(err) {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// textNote := &event.E{
|
||||
// Kind: kind.TextNote,
|
||||
// Content: []byte("hello"),
|
||||
// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
|
||||
// Pubkey: signer.Pub(),
|
||||
// }
|
||||
// if err = textNote.Sign(signer); chk.E(err) {
|
||||
// t.Fatalf("textNote.Sign: %v", err)
|
||||
// }
|
||||
// // fake relay server
|
||||
// ws := newWebsocketServer(
|
||||
// func(conn *websocket.Conn) {
|
||||
// // discard received message; not interested
|
||||
// var raw []json.RawMessage
|
||||
// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
|
||||
// t.Errorf("websocket.JSON.Receive: %v", err)
|
||||
// }
|
||||
// // send back a not ok nip-20 command result
|
||||
// var res []byte
|
||||
// if res = okenvelope.NewFrom(
|
||||
// textNote.ID, false,
|
||||
// normalize.Msg(normalize.Blocked, "no reason"),
|
||||
// ).Marshal(res); chk.E(err) {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// if err := websocket.Message.Send(conn, res); chk.T(err) {
|
||||
// t.Errorf("websocket.Message.Send: %v", err)
|
||||
// }
|
||||
// // res := []any{"OK", textNote.ID, false, "blocked"}
|
||||
// },
|
||||
// )
|
||||
// defer ws.Close()
|
||||
//
|
||||
// // connect a client and send a text note
|
||||
// rl := mustRelayConnect(ws.URL)
|
||||
// if err = rl.Publish(context.Background(), textNote); !chk.E(err) {
|
||||
// t.Errorf("should have failed to publish")
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestPublishWriteFailed(t *testing.T) {
|
||||
// test note to be sent over websocket
|
||||
@@ -171,7 +184,7 @@ func TestPublishWriteFailed(t *testing.T) {
|
||||
rl := mustRelayConnect(ws.URL)
|
||||
// Force brief period of time so that publish always fails on closed socket.
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
err = rl.Publish(context.Bg(), textNote)
|
||||
err = rl.Publish(context.Background(), textNote)
|
||||
if err == nil {
|
||||
t.Errorf("should have failed to publish")
|
||||
}
|
||||
@@ -192,7 +205,7 @@ func TestConnectContext(t *testing.T) {
|
||||
defer ws.Close()
|
||||
|
||||
// relay client
|
||||
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
r, err := RelayConnect(ctx, ws.URL)
|
||||
if err != nil {
|
||||
@@ -213,7 +226,7 @@ func TestConnectContextCanceled(t *testing.T) {
|
||||
defer ws.Close()
|
||||
|
||||
// relay client
|
||||
ctx, cancel := context.Cancel(context.Bg())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // make ctx expired
|
||||
_, err := RelayConnect(ctx, ws.URL)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
@@ -230,9 +243,9 @@ func TestConnectWithOrigin(t *testing.T) {
|
||||
defer ws.Close()
|
||||
|
||||
// relay client
|
||||
r := NewRelay(context.Bg(), string(normalize.URL(ws.URL)))
|
||||
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
|
||||
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
|
||||
r := NewRelay(context.Background(), string(normalize.URL(ws.URL)))
|
||||
r.requestHeader = http.Header{"origin": {"https://example.com"}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
err := r.Connect(ctx)
|
||||
if err != nil {
|
||||
@@ -263,7 +276,7 @@ var anyOriginHandshake = func(
|
||||
}
|
||||
|
||||
func mustRelayConnect(url string) (client *Client) {
|
||||
rl, err := RelayConnect(context.Bg(), url)
|
||||
rl, err := RelayConnect(context.Background(), url)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
@@ -1,224 +1,98 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gobwas/httphead"
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsflate"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"net/textproto"
|
||||
"time"
|
||||
|
||||
ws "github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// Connection is an outbound client -> relay connection.
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
enableCompression bool
|
||||
controlHandler wsutil.FrameHandlerFunc
|
||||
flateReader *wsflate.Reader
|
||||
reader *wsutil.Reader
|
||||
flateWriter *wsflate.Writer
|
||||
writer *wsutil.Writer
|
||||
msgStateR *wsflate.MessageState
|
||||
msgStateW *wsflate.MessageState
|
||||
var defaultConnectionOptions = &ws.DialOptions{
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPHeader: http.Header{
|
||||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
|
||||
},
|
||||
}
|
||||
|
||||
// NewConnection creates a new Connection.
|
||||
func NewConnection(
|
||||
c context.T, url string, requestHeader http.Header,
|
||||
tlsConfig *tls.Config,
|
||||
) (connection *Connection, errResult error) {
|
||||
dialer := ws.Dialer{
|
||||
Header: ws.HandshakeHeaderHTTP(requestHeader),
|
||||
Extensions: []httphead.Option{
|
||||
wsflate.DefaultParameters.Option(),
|
||||
},
|
||||
TLSConfig: tlsConfig,
|
||||
func getConnectionOptions(
|
||||
requestHeader http.Header, tlsConfig *tls.Config,
|
||||
) *ws.DialOptions {
|
||||
if requestHeader == nil && tlsConfig == nil {
|
||||
return defaultConnectionOptions
|
||||
}
|
||||
conn, _, hs, err := dialer.Dial(c, url)
|
||||
|
||||
return &ws.DialOptions{
|
||||
HTTPHeader: requestHeader,
|
||||
CompressionMode: ws.CompressionContextTakeover,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connection represents a websocket connection to a Nostr relay.
|
||||
type Connection struct {
|
||||
conn *ws.Conn
|
||||
}
|
||||
|
||||
// NewConnection creates a new websocket connection to a Nostr relay.
|
||||
func NewConnection(
|
||||
ctx context.Context, url string, requestHeader http.Header,
|
||||
tlsConfig *tls.Config,
|
||||
) (*Connection, error) {
|
||||
c, _, err := ws.Dial(
|
||||
ctx, url, getConnectionOptions(requestHeader, tlsConfig),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enableCompression := false
|
||||
state := ws.StateClientSide
|
||||
for _, extension := range hs.Extensions {
|
||||
if string(extension.Name) == wsflate.ExtensionName {
|
||||
enableCompression = true
|
||||
state |= ws.StateExtended
|
||||
break
|
||||
}
|
||||
}
|
||||
// reader
|
||||
var flateReader *wsflate.Reader
|
||||
var msgStateR wsflate.MessageState
|
||||
if enableCompression {
|
||||
msgStateR.SetCompressed(true)
|
||||
|
||||
flateReader = wsflate.NewReader(
|
||||
nil, func(r io.Reader) wsflate.Decompressor {
|
||||
return flate.NewReader(r)
|
||||
},
|
||||
)
|
||||
}
|
||||
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
|
||||
reader := &wsutil.Reader{
|
||||
Source: conn,
|
||||
State: state,
|
||||
OnIntermediate: controlHandler,
|
||||
CheckUTF8: false,
|
||||
Extensions: []wsutil.RecvExtension{
|
||||
&msgStateR,
|
||||
},
|
||||
}
|
||||
// writer
|
||||
var flateWriter *wsflate.Writer
|
||||
var msgStateW wsflate.MessageState
|
||||
if enableCompression {
|
||||
msgStateW.SetCompressed(true)
|
||||
c.SetReadLimit(2 << 24) // 33MB
|
||||
|
||||
flateWriter = wsflate.NewWriter(
|
||||
nil, func(w io.Writer) wsflate.Compressor {
|
||||
fw, err := flate.NewWriter(w, 4)
|
||||
if err != nil {
|
||||
log.E.F("Failed to create flate writer: %v", err)
|
||||
}
|
||||
return fw
|
||||
},
|
||||
)
|
||||
}
|
||||
writer := wsutil.NewWriter(conn, state, ws.OpText)
|
||||
writer.SetExtensions(&msgStateW)
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
enableCompression: enableCompression,
|
||||
controlHandler: controlHandler,
|
||||
flateReader: flateReader,
|
||||
reader: reader,
|
||||
msgStateR: &msgStateR,
|
||||
flateWriter: flateWriter,
|
||||
writer: writer,
|
||||
msgStateW: &msgStateW,
|
||||
conn: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WriteMessage dispatches a message through the Connection.
|
||||
func (cn *Connection) WriteMessage(c context.T, data []byte) (err error) {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return errorf.E(
|
||||
"%s context canceled",
|
||||
cn.conn.RemoteAddr(),
|
||||
)
|
||||
default:
|
||||
// WriteMessage writes arbitrary bytes to the websocket connection.
|
||||
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
|
||||
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
if cn.msgStateW.IsCompressed() && cn.enableCompression {
|
||||
cn.flateWriter.Reset(cn.writer)
|
||||
if _, err := io.Copy(
|
||||
cn.flateWriter, bytes.NewReader(data),
|
||||
); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to write message: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if err := cn.flateWriter.Close(); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to close flate writer: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(cn.writer, bytes.NewReader(data)); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to write message: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMessage reads arbitrary bytes from the websocket connection into the provided buffer.
|
||||
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
|
||||
_, reader, err := c.conn.Reader(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get reader: %w", err)
|
||||
}
|
||||
if err := cn.writer.Flush(); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to flush writer: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
if _, err := io.Copy(buf, reader); err != nil {
|
||||
return fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMessage picks up the next incoming message on a Connection.
|
||||
func (cn *Connection) ReadMessage(c context.T, buf io.Writer) (err error) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return errorf.D(
|
||||
"%s context canceled",
|
||||
cn.conn.RemoteAddr(),
|
||||
)
|
||||
default:
|
||||
}
|
||||
h, err := cn.reader.NextFrame()
|
||||
if err != nil {
|
||||
cn.conn.Close()
|
||||
return fmt.Errorf(
|
||||
"%s failed to advance frame: %s",
|
||||
cn.conn.RemoteAddr(),
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if h.OpCode.IsControl() {
|
||||
if err := cn.controlHandler(h, cn.reader); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to handle control frame: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
} else if h.OpCode == ws.OpBinary ||
|
||||
h.OpCode == ws.OpText {
|
||||
break
|
||||
}
|
||||
if err := cn.reader.Discard(); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to discard: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
if cn.msgStateR.IsCompressed() && cn.enableCompression {
|
||||
cn.flateReader.Reset(cn.reader)
|
||||
if _, err := io.Copy(buf, cn.flateReader); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to read message: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(buf, cn.reader); chk.T(err) {
|
||||
return errorf.E(
|
||||
"%s failed to read message: %w",
|
||||
cn.conn.RemoteAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
// Close closes the websocket connection.
|
||||
func (c *Connection) Close() error {
|
||||
return c.conn.Close(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
// Close the Connection.
|
||||
func (cn *Connection) Close() (err error) {
|
||||
return cn.conn.Close()
|
||||
// Ping sends a ping message to the websocket connection.
|
||||
func (c *Connection) Ping(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeoutCause(
|
||||
ctx, time.Millisecond*800, errors.New("ping took too long"),
|
||||
)
|
||||
defer cancel()
|
||||
return c.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package ws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"orly.dev/pkg/app/relay/helpers"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/protocol/auth"
|
||||
atomic2 "orly.dev/pkg/utils/atomic"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fasthttp/websocket"
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
216
pkg/protocol/ws/pool_test.go
Normal file
216
pkg/protocol/ws/pool_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/filter"
|
||||
"orly.dev/pkg/encoders/kind"
|
||||
"orly.dev/pkg/encoders/timestamp"
|
||||
"orly.dev/pkg/interfaces/signer"
|
||||
)
|
||||
|
||||
// mockSigner implements signer.I for testing
|
||||
type mockSigner struct {
|
||||
pubkey []byte
|
||||
}
|
||||
|
||||
func (m *mockSigner) Pub() []byte { return m.pubkey }
|
||||
func (m *mockSigner) Sign([]byte) (
|
||||
[]byte, error,
|
||||
) {
|
||||
return []byte("mock-signature"), nil
|
||||
}
|
||||
func (m *mockSigner) Generate() error { return nil }
|
||||
func (m *mockSigner) InitSec([]byte) error { return nil }
|
||||
func (m *mockSigner) InitPub([]byte) error { return nil }
|
||||
func (m *mockSigner) Sec() []byte { return []byte("mock-secret") }
|
||||
func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil }
|
||||
func (m *mockSigner) Zero() {}
|
||||
func (m *mockSigner) ECDH([]byte) (
|
||||
[]byte, error,
|
||||
) {
|
||||
return []byte("mock-shared-secret"), nil
|
||||
}
|
||||
|
||||
func TestNewPool(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pool := NewPool(ctx)
|
||||
|
||||
if pool == nil {
|
||||
t.Fatal("NewPool returned nil")
|
||||
}
|
||||
|
||||
if pool.Relays == nil {
|
||||
t.Error("Pool should have initialized Relays map")
|
||||
}
|
||||
|
||||
if pool.Context == nil {
|
||||
t.Error("Pool should have a context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithAuthHandler(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
authHandler := WithAuthHandler(
|
||||
func() signer.I {
|
||||
return &mockSigner{pubkey: []byte("test-pubkey")}
|
||||
},
|
||||
)
|
||||
|
||||
pool := NewPool(ctx, authHandler)
|
||||
|
||||
if pool.authHandler == nil {
|
||||
t.Error("Pool should have auth handler set")
|
||||
}
|
||||
|
||||
// Test that auth handler returns the expected signer
|
||||
signer := pool.authHandler()
|
||||
if string(signer.Pub()) != "test-pubkey" {
|
||||
t.Errorf(
|
||||
"Expected pubkey 'test-pubkey', got '%s'", string(signer.Pub()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithEventMiddleware(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var middlewareCalled bool
|
||||
middleware := WithEventMiddleware(
|
||||
func(ie RelayEvent) {
|
||||
middlewareCalled = true
|
||||
},
|
||||
)
|
||||
|
||||
pool := NewPool(ctx, middleware)
|
||||
|
||||
// Test that middleware is called
|
||||
testEvent := &event.E{
|
||||
Kind: kind.TextNote,
|
||||
Content: []byte("test"),
|
||||
CreatedAt: timestamp.Now(),
|
||||
}
|
||||
|
||||
ie := RelayEvent{E: testEvent, Relay: nil}
|
||||
pool.eventMiddleware(ie)
|
||||
|
||||
if !middlewareCalled {
|
||||
t.Error("Expected middleware to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayEventString(t *testing.T) {
|
||||
testEvent := &event.E{
|
||||
Kind: kind.TextNote,
|
||||
Content: []byte("test content"),
|
||||
CreatedAt: timestamp.Now(),
|
||||
}
|
||||
|
||||
client := &Client{URL: "wss://test.relay"}
|
||||
ie := RelayEvent{E: testEvent, Relay: client}
|
||||
|
||||
str := ie.String()
|
||||
if !contains(str, "wss://test.relay") {
|
||||
t.Errorf("Expected string to contain relay URL, got: %s", str)
|
||||
}
|
||||
|
||||
if !contains(str, "test content") {
|
||||
t.Errorf("Expected string to contain event content, got: %s", str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamedLock(t *testing.T) {
|
||||
// Test that named locks work correctly
|
||||
var wg sync.WaitGroup
|
||||
var counter int
|
||||
var mu sync.Mutex
|
||||
|
||||
lockName := "test-lock"
|
||||
|
||||
// Start multiple goroutines that try to increment counter
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
unlock := namedLock(lockName)
|
||||
defer unlock()
|
||||
|
||||
// Critical section
|
||||
mu.Lock()
|
||||
temp := counter
|
||||
time.Sleep(1 * time.Millisecond) // Simulate work
|
||||
counter = temp + 1
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if counter != 10 {
|
||||
t.Errorf("Expected counter to be 10, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolEnsureRelayInvalidURL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pool := NewPool(ctx)
|
||||
|
||||
// Test with invalid URL
|
||||
_, err := pool.EnsureRelay("invalid-url")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolQuerySingle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pool := NewPool(ctx)
|
||||
|
||||
// Test with empty URLs slice
|
||||
result := pool.QuerySingle(ctx, []string{}, &filter.F{})
|
||||
if result != nil {
|
||||
t.Error("Expected nil result for empty URLs")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
|
||||
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
|
||||
containsSubstring(s, substr)))
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func uintPtr(u uint) *uint {
|
||||
return &u
|
||||
}
|
||||
|
||||
// Test pool context cancellation
|
||||
func TestPoolContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := NewPool(ctx)
|
||||
|
||||
// Cancel the context
|
||||
cancel()
|
||||
|
||||
// Check that pool context is cancelled
|
||||
select {
|
||||
case <-pool.Context.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Expected pool context to be cancelled")
|
||||
}
|
||||
}
|
||||
@@ -1,100 +1,105 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"orly.dev/pkg/encoders/envelopes/closeenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/countenvelope"
|
||||
"orly.dev/pkg/encoders/envelopes/reqenvelope"
|
||||
"orly.dev/pkg/encoders/event"
|
||||
"orly.dev/pkg/encoders/filters"
|
||||
"orly.dev/pkg/encoders/subscription"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/errorf"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Subscription is a client interface for a subscription (what REQ turns into
|
||||
// after EOSE).
|
||||
// Subscription represents a subscription to a relay.
|
||||
type Subscription struct {
|
||||
label string
|
||||
counter int
|
||||
counter int64
|
||||
id string
|
||||
|
||||
Relay *Client
|
||||
Filters *filters.T
|
||||
|
||||
// for this to be treated as a COUNT and not a REQ this must be set
|
||||
countResult chan int
|
||||
// // for this to be treated as a COUNT and not a REQ this must be set
|
||||
// countResult chan CountEnvelope
|
||||
|
||||
// The Events channel emits all EVENTs that come in a Subscription will be
|
||||
// closed when the subscription ends
|
||||
// the Events channel emits all EVENTs that come in a Subscription
|
||||
// will be closed when the subscription ends
|
||||
Events event.C
|
||||
mu sync.Mutex
|
||||
|
||||
// The EndOfStoredEvents channel is closed when an EOSE comes for that
|
||||
// subscription
|
||||
// the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription
|
||||
EndOfStoredEvents chan struct{}
|
||||
|
||||
// The ClosedReason channel emits the reason when a CLOSED message is
|
||||
// received
|
||||
// the ClosedReason channel emits the reason when a CLOSED message is received
|
||||
ClosedReason chan string
|
||||
|
||||
// Context will be .Done() when the subscription ends
|
||||
Context context.T
|
||||
Context context.Context
|
||||
|
||||
// // if it is not nil, checkDuplicate will be called for every event received
|
||||
// // if it returns true that event will not be processed further.
|
||||
// checkDuplicate func(id string, relay string) bool
|
||||
//
|
||||
// // if it is not nil, checkDuplicateReplaceable will be called for every event received
|
||||
// // if it returns true that event will not be processed further.
|
||||
// checkDuplicateReplaceable func(rk ReplaceableKey, ts Timestamp) bool
|
||||
|
||||
match func(*event.E) bool // this will be either Filters.Match or Filters.MatchIgnoringTimestampConstraints
|
||||
live atomic.Bool
|
||||
eosed atomic.Bool
|
||||
closed atomic.Bool
|
||||
cancel context.F
|
||||
cancel context.CancelCauseFunc
|
||||
|
||||
// This keeps track of the events we've received before the EOSE that we
|
||||
// must dispatch before closing the EndOfStoredEvents channel
|
||||
// this keeps track of the events we've received before the EOSE that we must dispatch before
|
||||
// closing the EndOfStoredEvents channel
|
||||
storedwg sync.WaitGroup
|
||||
}
|
||||
|
||||
// EventMessage is an event, with the associated relay URL attached.
|
||||
type EventMessage struct {
|
||||
Event event.E
|
||||
Relay string
|
||||
}
|
||||
|
||||
// SubscriptionOption is the type of the argument passed for that. Some examples
|
||||
// are WithLabel.
|
||||
// SubscriptionOption is the type of the argument passed when instantiating relay connections.
|
||||
// Some examples are WithLabel.
|
||||
type SubscriptionOption interface {
|
||||
IsSubscriptionOption()
|
||||
}
|
||||
|
||||
// WithLabel puts a label on the subscription (it is prepended to the automatic
|
||||
// id) that is sent to relays.
|
||||
// WithLabel puts a label on the subscription (it is prepended to the automatic id) that is sent to relays.
|
||||
type WithLabel string
|
||||
|
||||
func (_ WithLabel) IsSubscriptionOption() {}
|
||||
|
||||
var _ SubscriptionOption = (WithLabel)("")
|
||||
// // WithCheckDuplicate sets checkDuplicate on the subscription
|
||||
// type WithCheckDuplicate func(id, relay string) bool
|
||||
//
|
||||
// func (_ WithCheckDuplicate) IsSubscriptionOption() {}
|
||||
//
|
||||
// // WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription
|
||||
// type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool
|
||||
//
|
||||
// func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {}
|
||||
|
||||
// GetID return the Nostr subscription ID as given to the Client it is a
|
||||
// concatenation of the label and a serial number.
|
||||
func (sub *Subscription) GetID() (id *subscription.Id) {
|
||||
var err error
|
||||
if id, err = subscription.NewId(sub.label + ":" + strconv.Itoa(sub.counter)); chk.E(err) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
var (
|
||||
_ SubscriptionOption = (WithLabel)("")
|
||||
// _ SubscriptionOption = (WithCheckDuplicate)(nil)
|
||||
// _ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil)
|
||||
)
|
||||
|
||||
func (sub *Subscription) start() {
|
||||
<-sub.Context.Done()
|
||||
// the subscription ends once the context is canceled (if not already)
|
||||
sub.Unsub() // this will set sub.live to false
|
||||
|
||||
// do this so we don't have the possibility of closing the Events channel
|
||||
// and then trying to send to it
|
||||
// the subscription ends once the context is canceled (if not already)
|
||||
sub.unsub(errors.New("context done on start()")) // this will set sub.live to false
|
||||
|
||||
// do this so we don't have the possibility of closing the Events channel and then trying to send to it
|
||||
sub.mu.Lock()
|
||||
close(sub.Events)
|
||||
sub.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetID returns the subscription ID.
|
||||
func (sub *Subscription) GetID() string { return sub.id }
|
||||
|
||||
func (sub *Subscription) dispatchEvent(evt *event.E) {
|
||||
added := false
|
||||
if !sub.eosed.Load() {
|
||||
@@ -112,7 +117,6 @@ func (sub *Subscription) dispatchEvent(evt *event.E) {
|
||||
case <-sub.Context.Done():
|
||||
}
|
||||
}
|
||||
|
||||
if added {
|
||||
sub.storedwg.Done()
|
||||
}
|
||||
@@ -121,6 +125,7 @@ func (sub *Subscription) dispatchEvent(evt *event.E) {
|
||||
|
||||
func (sub *Subscription) dispatchEose() {
|
||||
if sub.eosed.CompareAndSwap(false, true) {
|
||||
sub.match = sub.Filters.MatchIgnoringTimestampConstraints
|
||||
go func() {
|
||||
sub.storedwg.Wait()
|
||||
sub.EndOfStoredEvents <- struct{}{}
|
||||
@@ -128,62 +133,72 @@ func (sub *Subscription) dispatchEose() {
|
||||
}
|
||||
}
|
||||
|
||||
func (sub *Subscription) dispatchClosed(reason string) {
|
||||
if sub.closed.CompareAndSwap(false, true) {
|
||||
go func() {
|
||||
sub.ClosedReason <- reason
|
||||
}()
|
||||
}
|
||||
// handleClosed handles the CLOSED message from a relay.
|
||||
func (sub *Subscription) handleClosed(reason string) {
|
||||
go func() {
|
||||
sub.ClosedReason <- reason
|
||||
sub.live.Store(false) // set this so we don't send an unnecessary CLOSE to the relay
|
||||
sub.unsub(fmt.Errorf("CLOSED received: %s", reason))
|
||||
}()
|
||||
}
|
||||
|
||||
// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. Unsub()
|
||||
// also closes the channel sub.Events and makes a new one.
|
||||
// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01.
|
||||
// Unsub() also closes the channel sub.Events and makes a new one.
|
||||
func (sub *Subscription) Unsub() {
|
||||
sub.unsub(errors.New("Unsub() called"))
|
||||
}
|
||||
|
||||
// unsub is the internal implementation of Unsub.
|
||||
func (sub *Subscription) unsub(err error) {
|
||||
// cancel the context (if it's not canceled already)
|
||||
sub.cancel()
|
||||
// mark the subscription as closed and send a CLOSE to the relay (naïve
|
||||
// sync.Once implementation)
|
||||
sub.cancel(err)
|
||||
|
||||
// mark subscription as closed and send a CLOSE to the relay (naïve sync.Once implementation)
|
||||
if sub.live.CompareAndSwap(true, false) {
|
||||
sub.Close()
|
||||
}
|
||||
|
||||
// remove subscription from our map
|
||||
sub.Relay.Subscriptions.Delete(sub.GetID().String())
|
||||
sub.Relay.Subscriptions.Delete(sub.counter)
|
||||
}
|
||||
|
||||
// Close just sends a CLOSE message. You probably want Unsub() instead.
|
||||
func (sub *Subscription) Close() {
|
||||
if sub.Relay.IsConnected() {
|
||||
id := sub.GetID()
|
||||
id, err := subscription.NewId(sub.id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
closeMsg := closeenvelope.NewFrom(id)
|
||||
var b []byte
|
||||
b = closeMsg.Marshal(nil)
|
||||
<-sub.Relay.Write(b)
|
||||
closeb := closeMsg.Marshal(nil)
|
||||
<-sub.Relay.Write(closeb)
|
||||
}
|
||||
}
|
||||
|
||||
// Sub sets sub.Filters and then calls sub.Fire(ctx). The subscription will be
|
||||
// closed if the context expires.
|
||||
func (sub *Subscription) Sub(_ context.T, ff *filters.T) {
|
||||
// Sub sets sub.Filters and then calls sub.Fire(ctx).
|
||||
// The subscription will be closed if the context expires.
|
||||
func (sub *Subscription) Sub(_ context.Context, ff *filters.T) {
|
||||
sub.Filters = ff
|
||||
sub.Fire()
|
||||
}
|
||||
|
||||
// Fire sends the "REQ" command to the relay.
|
||||
func (sub *Subscription) Fire() (err error) {
|
||||
id := sub.GetID()
|
||||
|
||||
var b []byte
|
||||
if sub.countResult == nil {
|
||||
b = reqenvelope.NewFrom(id, sub.Filters).Marshal(b)
|
||||
} else {
|
||||
b = countenvelope.NewRequest(id, sub.Filters).Marshal(b)
|
||||
// if sub.countResult == nil {
|
||||
req := reqenvelope.NewWithIdString(sub.id, sub.Filters)
|
||||
if req == nil {
|
||||
return fmt.Errorf("invalid ID or filters")
|
||||
}
|
||||
// log.T.F("{%s} sending %s", sub.Relay.URL, b)
|
||||
reqb := req.Marshal(nil)
|
||||
// } else
|
||||
// if len(sub.Filters) == 1 {
|
||||
// reqb, _ = CountEnvelope{sub.id, sub.Filters[0], nil, nil}.MarshalJSON()
|
||||
// } else {
|
||||
// return fmt.Errorf("unexpected sub configuration")
|
||||
sub.live.Store(true)
|
||||
if err = <-sub.Relay.Write(b); chk.T(err) {
|
||||
sub.cancel()
|
||||
return errorf.E("failed to write: %w", err)
|
||||
if err = <-sub.Relay.Write(reqb); chk.E(err) {
|
||||
err = fmt.Errorf("failed to write: %w", err)
|
||||
sub.cancel(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestBytesConcurrentAccess(t *testing.T) {
|
||||
loaded := atom.Load()
|
||||
|
||||
// Verify the loaded data is valid (either our data or another goroutine's data)
|
||||
require.LessOrEqual(t, len(loaded), parallelism,
|
||||
require.LessOrEqual(t, len(loaded), parallelism,
|
||||
"Loaded data length should not exceed parallelism")
|
||||
|
||||
// If it's our data, verify it's correct
|
||||
|
||||
@@ -28,8 +28,10 @@ var (
|
||||
TODO = context.TODO
|
||||
// Value - context.WithValue
|
||||
Value = context.WithValue
|
||||
// CancelCause - context.WithCancelCause
|
||||
CancelCause = context.WithCancelCause
|
||||
// Cause - context.WithCancelCause
|
||||
Cause = context.WithCancelCause
|
||||
|
||||
GetCause = context.Cause
|
||||
// Canceled - context.Canceled
|
||||
Canceled = context.Canceled
|
||||
)
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PointerToValue is a generic interface to refer to any pointer to almost any kind of common
|
||||
// type of value.
|
||||
// PointerToValue is a generic interface (type constraint) to refer to any
|
||||
// pointer to almost any kind of common type of value.
|
||||
//
|
||||
// see the utils/values package for a set of methods to accept these values and
|
||||
// return the correct type pointer to them.
|
||||
type PointerToValue interface {
|
||||
~*uint | ~*int | ~*uint8 | ~*uint16 | ~*uint32 | ~*uint64 | ~*int8 | ~*int16 | ~*int32 |
|
||||
~*int64 | ~*float32 | ~*float64 | ~*string | ~*[]string | ~*time.Time | ~*time.Duration |
|
||||
|
||||
101
pkg/utils/values/values.go
Normal file
101
pkg/utils/values/values.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package values
|
||||
|
||||
import (
|
||||
"orly.dev/pkg/encoders/unix"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ToUintPointer returns a pointer to the uint value passed in.
|
||||
func ToUintPointer(v uint) *uint {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToIntPointer returns a pointer to the int value passed in.
|
||||
func ToIntPointer(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToUint8Pointer returns a pointer to the uint8 value passed in.
|
||||
func ToUint8Pointer(v uint8) *uint8 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToUint16Pointer returns a pointer to the uint16 value passed in.
|
||||
func ToUint16Pointer(v uint16) *uint16 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToUint32Pointer returns a pointer to the uint32 value passed in.
|
||||
func ToUint32Pointer(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToUint64Pointer returns a pointer to the uint64 value passed in.
|
||||
func ToUint64Pointer(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToInt8Pointer returns a pointer to the int8 value passed in.
|
||||
func ToInt8Pointer(v int8) *int8 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToInt16Pointer returns a pointer to the int16 value passed in.
|
||||
func ToInt16Pointer(v int16) *int16 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToInt32Pointer returns a pointer to the int32 value passed in.
|
||||
func ToInt32Pointer(v int32) *int32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToInt64Pointer returns a pointer to the int64 value passed in.
|
||||
func ToInt64Pointer(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToFloat32Pointer returns a pointer to the float32 value passed in.
|
||||
func ToFloat32Pointer(v float32) *float32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToFloat64Pointer returns a pointer to the float64 value passed in.
|
||||
func ToFloat64Pointer(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToStringPointer returns a pointer to the string value passed in.
|
||||
func ToStringPointer(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToStringSlicePointer returns a pointer to the []string value passed in.
|
||||
func ToStringSlicePointer(v []string) *[]string {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToTimePointer returns a pointer to the time.Time value passed in.
|
||||
func ToTimePointer(v time.Time) *time.Time {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToDurationPointer returns a pointer to the time.Duration value passed in.
|
||||
func ToDurationPointer(v time.Duration) *time.Duration {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToBytesPointer returns a pointer to the []byte value passed in.
|
||||
func ToBytesPointer(v []byte) *[]byte {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToByteSlicesPointer returns a pointer to the [][]byte value passed in.
|
||||
func ToByteSlicesPointer(v [][]byte) *[][]byte {
|
||||
return &v
|
||||
}
|
||||
|
||||
// ToUnixTimePointer returns a pointer to the unix.Time value passed in.
|
||||
func ToUnixTimePointer(v unix.Time) *unix.Time {
|
||||
return &v
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
v0.4.7
|
||||
v0.5.4
|
||||
@@ -7,6 +7,6 @@ import (
|
||||
//go:embed version
|
||||
var V string
|
||||
|
||||
var Description = "relay powered by the orly framework"
|
||||
var Description = "relay powered by the orly framework https://orly.dev"
|
||||
|
||||
var URL = "https://orly.dev"
|
||||
|
||||
50
readme.adoc
50
readme.adoc
@@ -12,12 +12,12 @@ and https://github.com/fiatjaf/relayer[fiatjaf/relayer] aimed at maximum perform
|
||||
|
||||
== Features
|
||||
|
||||
* a lot of bits and pieces accumulated from nearly 8 years of working with Go, logging and run control, XDG user data directories (windows, mac, linux, android) (todo: this is mostly built and designed but not currently available)
|
||||
* a cleaned up and unified fork of the btcd/dcred BIP-340 signatures, including the use of bitcoin core's BIP-340 implementation (more than 4x faster than btcd) (todo: ECDH from the C library tbd). (todo: HTTP API not in this repo yet but coming soon TM)
|
||||
* a lot of bits and pieces accumulated from nearly 8 years of working with Go, logging and run control, XDG user data directories (windows, mac, linux, android)
|
||||
* a cleaned up and unified fork of the btcd/dcred BIP-340 signatures, including the use of bitcoin core's BIP-340 implementation (more than 4x faster than btcd) (todo: ECDH from the C library tbd).
|
||||
* AVX/AVX2 optimized SHA256 and SIMD hex encoder
|
||||
* https://github.com/bitcoin/secp256k1[libsecp256k1]-enabled signature and signature verification (see link:p256k/README.md[here]).
|
||||
* efficient, mutable byte slice-based hash/pubkey/signature encoding in memory (zero allocation decode from wire, can tolerate whitespace, at a speed penalty)
|
||||
* custom badger-based event store with an optional garbage collector that uses fast binary encoder for storage of events.
|
||||
* efficient, mutable byte slice-based hash/pubkey/signature encoding in memory (zero allocation decode from wire of all but id/pubkey/signature, can tolerate whitespace, at a speed penalty)
|
||||
* custom badger-based event store that uses fast binary encoder for storage of events, and has a complete set of indexes so it doesn't need to decode events for any query until delivering them.
|
||||
* link:cmd/vainstr[vainstr] vanity npub generator that can mine a 5-letter suffix in around 15 minutes on a 6 core Ryzen 5 processor using the CGO bitcoin core signature library.
|
||||
* reverse proxy tool link:cmd/lerproxy[lerproxy] with support for Go vanity imports and https://github.com/nostr-protocol/nips/blob/master/05.md[nip-05] npub DNS verification and own TLS certificates
|
||||
* link:https://github.com/nostr-protocol/nips/blob/master/98.md[nip-98] implementation with new expiring variant for vanilla HTTP tools and browsers.
|
||||
@@ -85,6 +85,46 @@ To see the current active configuration:
|
||||
orly env
|
||||
----
|
||||
|
||||
To see the help information:
|
||||
|
||||
----
|
||||
orly help
|
||||
----
|
||||
|
||||
Environment variables that configure orly:
|
||||
|
||||
[cols="4"]
|
||||
|===
|
||||
| Environment variable | type | default | description
|
||||
| ORLY_APP_NAME | string | orly |
|
||||
| ORLY_CONFIG_DIR | string | ~/.config/orly | location for configuration file, which has the name '.env' to make it harder to delete, and is a standard environment KEY=value<newline>... style
|
||||
| ORLY_STATE_DATA_DIR | string | ~/.local/state/orly | storage location for state data affected by dynamic interactive interfaces
|
||||
| ORLY_DATA_DIR | string | ~/.local/cache/orly | storage location for the event store
|
||||
| ORLY_LISTEN | string | 0.0.0.0 | network listen address
|
||||
| ORLY_PORT | int | 3334 | port to listen on
|
||||
| ORLY_LOG_LEVEL | string | info | debug level: fatal error warn info debug trace
|
||||
| ORLY_DB_LOG_LEVEL | string | info | debug level: fatal error warn info debug trace
|
||||
| ORLY_PPROF | string | <empty> | enable pprof on 127.0.0.1:6060
|
||||
| ORLY_AUTH_REQUIRED | bool | false | require authentication for all requests
|
||||
| ORLY_PUBLIC_READABLE | bool | true | allow public read access to regardless of whether the client is authed
|
||||
| ORLY_SPIDER_SEEDS | []string | wss://profiles.nostr1.com/,
|
||||
wss://relay.nostr.band/,
|
||||
wss://relay.damus.io/,
|
||||
wss://nostr.wine/,
|
||||
wss://nostr.land/,
|
||||
wss://theforest.nostr1.com/,
|
||||
wss://profiles.nostr1.com
|
||||
| seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)
|
||||
| ORLY_SPIDER_TYPE | string | directory | whether to spider, and what degree of spidering: none, directory, follows (follows means to the second degree of the follow graph)
|
||||
| ORLY_SPIDER_FREQUENCY | time.Duration | 1h | how often to run the spider, uses notation 0h0m0s
|
||||
| ORLY_SPIDER_SECOND_DEGREE | bool | true | whether to enable spidering the second degree of follows for non-directory events if ORLY_SPIDER_TYPE is set to 'follows'
|
||||
| ORLY_OWNERS | []string | [] | list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)
|
||||
| ORLY_PRIVATE | bool | false | do not spider for user metadata because the relay is private and this would leak relay memberships
|
||||
| ORLY_WHITELIST | []string | [] | only allow connections from this list of IP addresses
|
||||
| ORLY_SECRET_KEY | string | <empty> | secret key for relay cluster replication authentication
|
||||
| ORLY_PEER_RELAYS | []string | [] | list of peer relays URLs that new events are pushed to in format <pubkey>\|<url>
|
||||
|===
|
||||
|
||||
=== Create Persistent Configuration
|
||||
|
||||
This output can be directed to the profile location to make the settings editable without manually setting them on the
|
||||
@@ -122,8 +162,6 @@ messages and it uses and parses relay lists, and all that other stuff.
|
||||
[#_simplified_nostr]
|
||||
=== Simplified Nostr
|
||||
|
||||
NOTE: this is not currently implemented. coming soon TM
|
||||
|
||||
Rather than write a text that will likely fall out of date very quickly, simply run `orly` and visit its listener
|
||||
address (eg link:http://localhost:3334/api[http://localhost:3334/api]) to see the full documentation.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user