Compare commits

...

13 Commits

Author SHA1 Message Date
b351d0fb78 fix bugs in tag comparison code
nwc walletcli now works!

bumped to v0.5.0 because NWC client now in and available
2025-08-07 09:32:53 +01:00
9c8ff2976d backporting relay client and pool from latest go-nostr 2025-08-06 22:18:26 +01:00
a7dd958585 Renamed NWC client methods and added RPCRaw wrappers
*   Renamed `NWCClient` to `nwc.NewNWCClient(opts)` in `cmd/nwcclient/main.go`
*   Added `RPCRaw` wrappers for NWC client methods in `pkg/protocol/nwc/methods.go`

**Updated walletcli main function**

*   Updated the main function in `cmd/walletcli/main.go` to use new NWC client and RPCRaw wrappers

**Added new methods for walletcli**

*   Added new methods for handling NWC client RPC calls, such as:
    *   `handleGetWalletServiceInfo`
    *   `handleMakeHoldInvoice`
    *   `handleSettleHoldInvoice`
    *   `handleCancelHoldInvoice`

**Code formatting and style changes**

*   Formatted code according to Go standard
*   Used consistent naming conventions and coding styles

**Other updates**

*   Updated dependencies and imported packages accordingly
2025-08-06 10:03:16 +01:00
8eb5b839b0 add all methods except multi
- added extra types and corrected struct tags to conform with js-sdk
- implement all unimplemented RPC call method wrappers except the multi methods
2025-08-06 00:23:03 +01:00
e57169eeae add blacklist and add to accept-event.go; Bump version to v0.4.15 2025-08-05 23:15:13 +01:00
109326dfa3 Merge remote-tracking branch 'origin/main' 2025-08-05 23:06:24 +01:00
52911354a7 Merge pull request #5 from kwsantiago/kwsantiago/1-public-relay-with-blacklist
feat: Add blacklist support for public relays

adds a simple explicit blacklist configuration and exclude in the event handling
2025-08-05 23:05:22 +01:00
b74f4757e7 refactor: Simplify NWC protocol structures and update method handling
- cmd/lerproxy/app/bufpool.go
  - Removed bufferPool-related code and `Pool` struct

- cmd/nwcclient/main.go
  - Renamed `Method` to `Capability` for clarity in method handling

- pkg/utils/values/values.go
  - Added utility functions to return pointers for various types

- pkg/utils/pointers/pointers.go
  - Revised documentation to reference `utils/values` package for pointer utilities

- pkg/protocol/nwc/types.go
  - Replaced redundant types and structures with simplified versions
  - Introduced dedicated structs for `MakeInvoice`, `PayInvoice`, and related results
  - Refactored `Transaction` and its fields for consistent type usage

- pkg/protocol/nwc/uri.go
  - Added `ParseConnectionURI` function for URI parsing and validation

- pkg/protocol/nwc/client.go
  - Refactored `Client` struct to improve key management and relay handling
  - Introduced `Request` struct for generic method invocation payloads
2025-08-05 20:18:32 +01:00
2d0ebfe032 Merge remote-tracking branch 'upstream/main' into kwsantiago/1-public-relay-with-blacklist 2025-08-05 14:15:02 -04:00
fff61ceca1 fmt 2025-08-05 14:09:15 -04:00
b7b7dc7353 feat: Add blacklist support for public relays 2025-08-05 14:09:01 -04:00
996fb3aeb7 Merge pull request #4 from kwsantiago/kwsantiago/benchmark
feat: Add Relay Performance Benchmark Tool
2025-08-05 18:19:28 +01:00
b9a713d81d simple performance benchmark tool 2025-08-05 10:54:53 -04:00
74 changed files with 4103 additions and 2538 deletions

View File

@@ -0,0 +1,173 @@
# Orly Relay Benchmark Results
## Test Environment
- **Date**: August 5, 2025
- **Relay**: Orly v0.4.14
- **Port**: 3334 (WebSocket)
- **System**: Linux 5.15.0-151-generic
- **Storage**: BadgerDB v4
## Benchmark Test Results
### Test 1: Basic Performance (1,000 events, 1KB each)
**Parameters:**
- Events: 1,000
- Event size: 1,024 bytes
- Concurrent publishers: 5
- Queries: 50
**Results:**
```
Publish Performance:
Events Published: 1,000
Total Data: 4.01 MB
Duration: 1.769s
Rate: 565.42 events/second
Bandwidth: 2.26 MB/second
Query Performance:
Queries Executed: 50
Events Returned: 2,000
Duration: 3.058s
Rate: 16.35 queries/second
Avg Events/Query: 40.00
```
### Test 2: Medium Load (10,000 events, 2KB each)
**Parameters:**
- Events: 10,000
- Event size: 2,048 bytes
- Concurrent publishers: 10
- Queries: 100
**Results:**
```
Publish Performance:
Events Published: 10,000
Total Data: 76.81 MB
Duration: 598.301ms
Rate: 16,714.00 events/second
Bandwidth: 128.38 MB/second
Query Performance:
Queries Executed: 100
Events Returned: 4,000
Duration: 8.923s
Rate: 11.21 queries/second
Avg Events/Query: 40.00
```
### Test 3: High Concurrency (50,000 events, 512 bytes each)
**Parameters:**
- Events: 50,000
- Event size: 512 bytes
- Concurrent publishers: 50
- Queries: 200
**Results:**
```
Publish Performance:
Events Published: 50,000
Total Data: 108.63 MB
Duration: 2.368s
Rate: 21,118.66 events/second
Bandwidth: 45.88 MB/second
Query Performance:
Queries Executed: 200
Events Returned: 8,000
Duration: 36.146s
Rate: 5.53 queries/second
Avg Events/Query: 40.00
```
### Test 4: Large Events (5,000 events, 10KB each)
**Parameters:**
- Events: 5,000
- Event size: 10,240 bytes
- Concurrent publishers: 10
- Queries: 50
**Results:**
```
Publish Performance:
Events Published: 5,000
Total Data: 185.26 MB
Duration: 934.328ms
Rate: 5,351.44 events/second
Bandwidth: 198.28 MB/second
Query Performance:
Queries Executed: 50
Events Returned: 2,000
Duration: 9.982s
Rate: 5.01 queries/second
Avg Events/Query: 40.00
```
### Test 5: Query-Only Performance (500 queries)
**Parameters:**
- Skip publishing phase
- Queries: 500
- Query limit: 100
**Results:**
```
Query Performance:
Queries Executed: 500
Events Returned: 20,000
Duration: 1m14.384s
Rate: 6.72 queries/second
Avg Events/Query: 40.00
```
## Performance Summary
### Publishing Performance
| Metric | Best Result | Test Configuration |
|--------|-------------|-------------------|
| **Peak Event Rate** | 21,118.66 events/sec | 50 concurrent publishers, 512-byte events |
| **Peak Bandwidth** | 198.28 MB/sec | 10 concurrent publishers, 10KB events |
| **Optimal Balance** | 16,714.00 events/sec @ 128.38 MB/sec | 10 concurrent publishers, 2KB events |
### Query Performance
| Query Type | Avg Rate | Notes |
|------------|----------|--------|
| **Light Load** | 16.35 queries/sec | 50 queries after 1K events |
| **Medium Load** | 11.21 queries/sec | 100 queries after 10K events |
| **Heavy Load** | 5.53 queries/sec | 200 queries after 50K events |
| **Sustained** | 6.72 queries/sec | 500 continuous queries |
## Key Findings
1. **Optimal Concurrency**: The relay performs best with 10-50 concurrent publishers, achieving rates of 16,000-21,000 events/second.
2. **Event Size Impact**:
- Smaller events (512B-2KB) achieve higher event rates
- Larger events (10KB) achieve higher bandwidth utilization but lower event rates
3. **Query Performance**: Query performance varies with database size:
- Fresh database: ~16 queries/second
- After 50K events: ~6 queries/second
4. **Scalability**: The relay maintains consistent performance up to 50 concurrent connections and can sustain 21,000+ events/second under optimal conditions.
## Query Filter Distribution
The benchmark tested 5 different query patterns in rotation:
1. Query by kind (20%)
2. Query by time range (20%)
3. Query by tag (20%)
4. Query by author (20%)
5. Complex queries with multiple conditions (20%)
All query types showed similar performance characteristics, indicating well-balanced indexing.

112
cmd/benchmark/README.md Normal file
View File

@@ -0,0 +1,112 @@
# Orly Relay Benchmark Tool
A performance benchmarking tool for Nostr relays that tests both event ingestion speed and query performance.
## Quick Start (Simple Version)
The repository includes a simple standalone benchmark tool that doesn't require the full Orly dependencies:
```bash
# Build the simple benchmark
go build -o benchmark-simple ./benchmark_simple.go
# Run with default settings
./benchmark-simple
# Or use the convenience script
chmod +x run_benchmark.sh
./run_benchmark.sh --relay ws://localhost:7447 --events 10000
```
## Features
- **Event Publishing Benchmark**: Tests how fast a relay can accept and store events
- **Query Performance Benchmark**: Tests various filter types and query speeds
- **Concurrent Publishing**: Supports multiple concurrent publishers to stress test the relay
- **Detailed Metrics**: Reports events/second, bandwidth usage, and query performance
## Usage
```bash
# Build the tool
go build -o benchmark ./cmd/benchmark
# Run a full benchmark (publish and query)
./benchmark -relay ws://localhost:7447 -events 10000 -queries 100
# Benchmark only publishing
./benchmark -relay ws://localhost:7447 -events 50000 -concurrency 20 -skip-query
# Benchmark only querying
./benchmark -relay ws://localhost:7447 -queries 500 -skip-publish
# Use custom event sizes
./benchmark -relay ws://localhost:7447 -events 10000 -size 2048
```
## Options
- `-relay`: Relay URL to benchmark (default: ws://localhost:7447)
- `-events`: Number of events to publish (default: 10000)
- `-size`: Average size of event content in bytes (default: 1024)
- `-concurrency`: Number of concurrent publishers (default: 10)
- `-queries`: Number of queries to execute (default: 100)
- `-query-limit`: Limit for each query (default: 100)
- `-skip-publish`: Skip the publishing phase
- `-skip-query`: Skip the query phase
- `-v`: Enable verbose output
## Query Types Tested
The benchmark tests various query patterns:
1. Query by kind
2. Query by time range (last hour)
3. Query by tag (p tags)
4. Query by author
5. Complex queries with multiple conditions
## Output
The tool provides detailed metrics including:
**Publish Performance:**
- Total events published
- Total data transferred
- Publishing rate (events/second)
- Bandwidth usage (MB/second)
**Query Performance:**
- Total queries executed
- Total events returned
- Query rate (queries/second)
- Average events per query
## Example Output
```
Publishing 10000 events to ws://localhost:7447...
Published 1000 events...
Published 2000 events...
...
Querying events from ws://localhost:7447...
Executed 20 queries...
Executed 40 queries...
...
=== Benchmark Results ===
Publish Performance:
Events Published: 10000
Total Data: 12.34 MB
Duration: 5.2s
Rate: 1923.08 events/second
Bandwidth: 2.37 MB/second
Query Performance:
Queries Executed: 100
Events Returned: 4523
Duration: 2.1s
Rate: 47.62 queries/second
Avg Events/Query: 45.23
```

View File

@@ -0,0 +1,304 @@
// +build ignore
package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"log"
"math/rand"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
// Simple event structure for benchmarking
type Event struct {
ID string `json:"id"`
Pubkey string `json:"pubkey"`
CreatedAt int64 `json:"created_at"`
Kind int `json:"kind"`
Tags [][]string `json:"tags"`
Content string `json:"content"`
Sig string `json:"sig"`
}
// Generate a test event
func generateTestEvent(size int) *Event {
content := make([]byte, size)
rand.Read(content)
// Generate random pubkey and sig
pubkey := make([]byte, 32)
sig := make([]byte, 64)
rand.Read(pubkey)
rand.Read(sig)
ev := &Event{
Pubkey: hex.EncodeToString(pubkey),
CreatedAt: time.Now().Unix(),
Kind: 1,
Tags: [][]string{},
Content: string(content),
Sig: hex.EncodeToString(sig),
}
// Generate ID (simplified)
serialized, _ := json.Marshal([]interface{}{
0,
ev.Pubkey,
ev.CreatedAt,
ev.Kind,
ev.Tags,
ev.Content,
})
hash := sha256.Sum256(serialized)
ev.ID = hex.EncodeToString(hash[:])
return ev
}
func publishEvents(relayURL string, count int, size int, concurrency int) (int64, int64, time.Duration, error) {
u, err := url.Parse(relayURL)
if err != nil {
return 0, 0, 0, err
}
var publishedEvents atomic.Int64
var publishedBytes atomic.Int64
var wg sync.WaitGroup
eventsPerWorker := count / concurrency
extraEvents := count % concurrency
start := time.Now()
for i := 0; i < concurrency; i++ {
wg.Add(1)
eventsToPublish := eventsPerWorker
if i < extraEvents {
eventsToPublish++
}
go func(workerID int, eventCount int) {
defer wg.Done()
// Connect to relay
ctx := context.Background()
conn, _, _, err := ws.Dial(ctx, u.String())
if err != nil {
log.Printf("Worker %d: connection error: %v", workerID, err)
return
}
defer conn.Close()
// Publish events
for j := 0; j < eventCount; j++ {
ev := generateTestEvent(size)
// Create EVENT message
msg, _ := json.Marshal([]interface{}{"EVENT", ev})
err := wsutil.WriteClientMessage(conn, ws.OpText, msg)
if err != nil {
log.Printf("Worker %d: write error: %v", workerID, err)
continue
}
publishedEvents.Add(1)
publishedBytes.Add(int64(len(msg)))
// Read response (OK or error)
_, _, err = wsutil.ReadServerData(conn)
if err != nil {
log.Printf("Worker %d: read error: %v", workerID, err)
}
}
}(i, eventsToPublish)
}
wg.Wait()
duration := time.Since(start)
return publishedEvents.Load(), publishedBytes.Load(), duration, nil
}
func queryEvents(relayURL string, queries int, limit int) (int64, int64, time.Duration, error) {
u, err := url.Parse(relayURL)
if err != nil {
return 0, 0, 0, err
}
ctx := context.Background()
conn, _, _, err := ws.Dial(ctx, u.String())
if err != nil {
return 0, 0, 0, err
}
defer conn.Close()
var totalQueries int64
var totalEvents int64
start := time.Now()
for i := 0; i < queries; i++ {
// Generate various filter types
var filter map[string]interface{}
switch i % 5 {
case 0:
// Query by kind
filter = map[string]interface{}{
"kinds": []int{1},
"limit": limit,
}
case 1:
// Query by time range
now := time.Now().Unix()
filter = map[string]interface{}{
"since": now - 3600,
"until": now,
"limit": limit,
}
case 2:
// Query by tag
filter = map[string]interface{}{
"#p": []string{hex.EncodeToString(randBytes(32))},
"limit": limit,
}
case 3:
// Query by author
filter = map[string]interface{}{
"authors": []string{hex.EncodeToString(randBytes(32))},
"limit": limit,
}
case 4:
// Complex query
now := time.Now().Unix()
filter = map[string]interface{}{
"kinds": []int{1, 6},
"authors": []string{hex.EncodeToString(randBytes(32))},
"since": now - 7200,
"limit": limit,
}
}
// Send REQ
subID := fmt.Sprintf("bench-%d", i)
msg, _ := json.Marshal([]interface{}{"REQ", subID, filter})
err := wsutil.WriteClientMessage(conn, ws.OpText, msg)
if err != nil {
log.Printf("Query %d: write error: %v", i, err)
continue
}
// Read events until EOSE
eventCount := 0
for {
data, err := wsutil.ReadServerText(conn)
if err != nil {
log.Printf("Query %d: read error: %v", i, err)
break
}
var msg []interface{}
if err := json.Unmarshal(data, &msg); err != nil {
continue
}
if len(msg) < 2 {
continue
}
msgType, ok := msg[0].(string)
if !ok {
continue
}
switch msgType {
case "EVENT":
eventCount++
case "EOSE":
goto done
}
}
done:
// Send CLOSE
closeMsg, _ := json.Marshal([]interface{}{"CLOSE", subID})
wsutil.WriteClientMessage(conn, ws.OpText, closeMsg)
totalQueries++
totalEvents += int64(eventCount)
if totalQueries%20 == 0 {
fmt.Printf(" Executed %d queries...\n", totalQueries)
}
}
duration := time.Since(start)
return totalQueries, totalEvents, duration, nil
}
func randBytes(n int) []byte {
b := make([]byte, n)
rand.Read(b)
return b
}
func main() {
var (
relayURL = flag.String("relay", "ws://localhost:7447", "Relay URL to benchmark")
eventCount = flag.Int("events", 10000, "Number of events to publish")
eventSize = flag.Int("size", 1024, "Average size of event content in bytes")
concurrency = flag.Int("concurrency", 10, "Number of concurrent publishers")
queryCount = flag.Int("queries", 100, "Number of queries to execute")
queryLimit = flag.Int("query-limit", 100, "Limit for each query")
skipPublish = flag.Bool("skip-publish", false, "Skip publishing phase")
skipQuery = flag.Bool("skip-query", false, "Skip query phase")
)
flag.Parse()
fmt.Printf("=== Nostr Relay Benchmark ===\n\n")
// Phase 1: Publish events
if !*skipPublish {
fmt.Printf("Publishing %d events to %s...\n", *eventCount, *relayURL)
published, bytes, duration, err := publishEvents(*relayURL, *eventCount, *eventSize, *concurrency)
if err != nil {
log.Fatalf("Publishing failed: %v", err)
}
fmt.Printf("\nPublish Performance:\n")
fmt.Printf(" Events Published: %d\n", published)
fmt.Printf(" Total Data: %.2f MB\n", float64(bytes)/1024/1024)
fmt.Printf(" Duration: %s\n", duration)
fmt.Printf(" Rate: %.2f events/second\n", float64(published)/duration.Seconds())
fmt.Printf(" Bandwidth: %.2f MB/second\n", float64(bytes)/duration.Seconds()/1024/1024)
}
// Phase 2: Query events
if !*skipQuery {
fmt.Printf("\nQuerying events from %s...\n", *relayURL)
queries, events, duration, err := queryEvents(*relayURL, *queryCount, *queryLimit)
if err != nil {
log.Fatalf("Querying failed: %v", err)
}
fmt.Printf("\nQuery Performance:\n")
fmt.Printf(" Queries Executed: %d\n", queries)
fmt.Printf(" Events Returned: %d\n", events)
fmt.Printf(" Duration: %s\n", duration)
fmt.Printf(" Rate: %.2f queries/second\n", float64(queries)/duration.Seconds())
fmt.Printf(" Avg Events/Query: %.2f\n", float64(events)/float64(queries))
}
}

320
cmd/benchmark/main.go Normal file
View File

@@ -0,0 +1,320 @@
package main
import (
"flag"
"fmt"
"lukechampine.com/frand"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/text"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/protocol/ws"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/lol"
"os"
"sync"
"sync/atomic"
"time"
)
type BenchmarkResults struct {
EventsPublished int64
EventsPublishedBytes int64
PublishDuration time.Duration
PublishRate float64
PublishBandwidth float64
QueriesExecuted int64
QueryDuration time.Duration
QueryRate float64
EventsReturned int64
}
func main() {
var (
relayURL = flag.String("relay", "ws://localhost:7447", "Relay URL to benchmark")
eventCount = flag.Int("events", 10000, "Number of events to publish")
eventSize = flag.Int("size", 1024, "Average size of event content in bytes")
concurrency = flag.Int("concurrency", 10, "Number of concurrent publishers")
queryCount = flag.Int("queries", 100, "Number of queries to execute")
queryLimit = flag.Int("query-limit", 100, "Limit for each query")
skipPublish = flag.Bool("skip-publish", false, "Skip publishing phase")
skipQuery = flag.Bool("skip-query", false, "Skip query phase")
verbose = flag.Bool("v", false, "Verbose output")
)
flag.Parse()
if *verbose {
lol.SetLogLevel("trace")
}
c := context.Bg()
results := &BenchmarkResults{}
// Phase 1: Publish events
if !*skipPublish {
fmt.Printf("Publishing %d events to %s...\n", *eventCount, *relayURL)
if err := benchmarkPublish(c, *relayURL, *eventCount, *eventSize, *concurrency, results); chk.E(err) {
fmt.Fprintf(os.Stderr, "Error during publish benchmark: %v\n", err)
os.Exit(1)
}
}
// Phase 2: Query events
if !*skipQuery {
fmt.Printf("\nQuerying events from %s...\n", *relayURL)
if err := benchmarkQuery(c, *relayURL, *queryCount, *queryLimit, results); chk.E(err) {
fmt.Fprintf(os.Stderr, "Error during query benchmark: %v\n", err)
os.Exit(1)
}
}
// Print results
printResults(results)
}
func benchmarkPublish(c context.T, relayURL string, eventCount, eventSize, concurrency int, results *BenchmarkResults) error {
// Generate signers for each concurrent publisher
signers := make([]*testSigner, concurrency)
for i := range signers {
signers[i] = newTestSigner()
}
// Track published events
var publishedEvents atomic.Int64
var publishedBytes atomic.Int64
var errors atomic.Int64
// Create wait group for concurrent publishers
var wg sync.WaitGroup
eventsPerPublisher := eventCount / concurrency
extraEvents := eventCount % concurrency
startTime := time.Now()
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(publisherID int) {
defer wg.Done()
// Connect to relay
relay, err := ws.RelayConnect(c, relayURL)
if err != nil {
log.E.F("Publisher %d failed to connect: %v", publisherID, err)
errors.Add(1)
return
}
defer relay.Close()
// Calculate events for this publisher
eventsToPublish := eventsPerPublisher
if publisherID < extraEvents {
eventsToPublish++
}
signer := signers[publisherID]
// Publish events
for j := 0; j < eventsToPublish; j++ {
ev := generateEvent(signer, eventSize)
if err := relay.Publish(c, ev); err != nil {
log.E.F("Publisher %d failed to publish event: %v", publisherID, err)
errors.Add(1)
continue
}
evBytes := ev.Marshal(nil)
publishedEvents.Add(1)
publishedBytes.Add(int64(len(evBytes)))
if publishedEvents.Load()%1000 == 0 {
fmt.Printf(" Published %d events...\n", publishedEvents.Load())
}
}
}(i)
}
wg.Wait()
duration := time.Since(startTime)
results.EventsPublished = publishedEvents.Load()
results.EventsPublishedBytes = publishedBytes.Load()
results.PublishDuration = duration
results.PublishRate = float64(results.EventsPublished) / duration.Seconds()
results.PublishBandwidth = float64(results.EventsPublishedBytes) / duration.Seconds() / 1024 / 1024 // MB/s
if errors.Load() > 0 {
fmt.Printf(" Warning: %d errors occurred during publishing\n", errors.Load())
}
return nil
}
func benchmarkQuery(c context.T, relayURL string, queryCount, queryLimit int, results *BenchmarkResults) error {
relay, err := ws.RelayConnect(c, relayURL)
if err != nil {
return fmt.Errorf("failed to connect to relay: %w", err)
}
defer relay.Close()
var totalEvents atomic.Int64
var totalQueries atomic.Int64
startTime := time.Now()
for i := 0; i < queryCount; i++ {
// Generate various filter types
var f *filter.F
switch i % 5 {
case 0:
// Query by kind
limit := uint(queryLimit)
f = &filter.F{
Kinds: kinds.New(kind.TextNote),
Limit: &limit,
}
case 1:
// Query by time range
now := timestamp.Now()
since := timestamp.New(now.I64() - 3600) // last hour
limit := uint(queryLimit)
f = &filter.F{
Since: since,
Until: now,
Limit: &limit,
}
case 2:
// Query by tag
limit := uint(queryLimit)
f = &filter.F{
Tags: tags.New(tag.New([]byte("p"), generateRandomPubkey())),
Limit: &limit,
}
case 3:
// Query by author
limit := uint(queryLimit)
f = &filter.F{
Authors: tag.New(generateRandomPubkey()),
Limit: &limit,
}
case 4:
// Complex query with multiple conditions
now := timestamp.Now()
since := timestamp.New(now.I64() - 7200)
limit := uint(queryLimit)
f = &filter.F{
Kinds: kinds.New(kind.TextNote, kind.Repost),
Authors: tag.New(generateRandomPubkey()),
Since: since,
Limit: &limit,
}
}
// Execute query
events, err := relay.QuerySync(c, f, ws.WithLabel("benchmark"))
if err != nil {
log.E.F("Query %d failed: %v", i, err)
continue
}
totalEvents.Add(int64(len(events)))
totalQueries.Add(1)
if totalQueries.Load()%20 == 0 {
fmt.Printf(" Executed %d queries...\n", totalQueries.Load())
}
}
duration := time.Since(startTime)
results.QueriesExecuted = totalQueries.Load()
results.QueryDuration = duration
results.QueryRate = float64(results.QueriesExecuted) / duration.Seconds()
results.EventsReturned = totalEvents.Load()
return nil
}
func generateEvent(signer *testSigner, contentSize int) *event.E {
// Generate content with some variation
size := contentSize + frand.Intn(contentSize/2) - contentSize/4
if size < 10 {
size = 10
}
content := text.NostrEscape(nil, frand.Bytes(size))
ev := &event.E{
Pubkey: signer.Pub(),
Kind: kind.TextNote,
CreatedAt: timestamp.Now(),
Content: content,
Tags: generateRandomTags(),
}
if err := ev.Sign(signer); chk.E(err) {
panic(fmt.Sprintf("failed to sign event: %v", err))
}
return ev
}
func generateRandomTags() *tags.T {
t := tags.New()
// Add some random tags
numTags := frand.Intn(5)
for i := 0; i < numTags; i++ {
switch frand.Intn(3) {
case 0:
// p tag
t.AppendUnique(tag.New([]byte("p"), generateRandomPubkey()))
case 1:
// e tag
t.AppendUnique(tag.New([]byte("e"), generateRandomEventID()))
case 2:
// t tag
t.AppendUnique(tag.New([]byte("t"), []byte(fmt.Sprintf("topic%d", frand.Intn(100)))))
}
}
return t
}
func generateRandomPubkey() []byte {
return frand.Bytes(32)
}
func generateRandomEventID() []byte {
return frand.Bytes(32)
}
func printResults(results *BenchmarkResults) {
fmt.Println("\n=== Benchmark Results ===")
if results.EventsPublished > 0 {
fmt.Println("\nPublish Performance:")
fmt.Printf(" Events Published: %d\n", results.EventsPublished)
fmt.Printf(" Total Data: %.2f MB\n", float64(results.EventsPublishedBytes)/1024/1024)
fmt.Printf(" Duration: %s\n", results.PublishDuration)
fmt.Printf(" Rate: %.2f events/second\n", results.PublishRate)
fmt.Printf(" Bandwidth: %.2f MB/second\n", results.PublishBandwidth)
}
if results.QueriesExecuted > 0 {
fmt.Println("\nQuery Performance:")
fmt.Printf(" Queries Executed: %d\n", results.QueriesExecuted)
fmt.Printf(" Events Returned: %d\n", results.EventsReturned)
fmt.Printf(" Duration: %s\n", results.QueryDuration)
fmt.Printf(" Rate: %.2f queries/second\n", results.QueryRate)
avgEventsPerQuery := float64(results.EventsReturned) / float64(results.QueriesExecuted)
fmt.Printf(" Avg Events/Query: %.2f\n", avgEventsPerQuery)
}
}

82
cmd/benchmark/run_benchmark.sh Executable file
View File

@@ -0,0 +1,82 @@
#!/bin/bash
# Simple Nostr Relay Benchmark Script
# Default values
RELAY_URL="ws://localhost:7447"
EVENTS=10000
SIZE=1024
CONCURRENCY=10
QUERIES=100
QUERY_LIMIT=100
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--relay)
RELAY_URL="$2"
shift 2
;;
--events)
EVENTS="$2"
shift 2
;;
--size)
SIZE="$2"
shift 2
;;
--concurrency)
CONCURRENCY="$2"
shift 2
;;
--queries)
QUERIES="$2"
shift 2
;;
--query-limit)
QUERY_LIMIT="$2"
shift 2
;;
--skip-publish)
SKIP_PUBLISH="-skip-publish"
shift
;;
--skip-query)
SKIP_QUERY="-skip-query"
shift
;;
*)
echo "Unknown option: $1"
echo "Usage: $0 [--relay URL] [--events N] [--size N] [--concurrency N] [--queries N] [--query-limit N] [--skip-publish] [--skip-query]"
exit 1
;;
esac
done
# Build the benchmark tool if it doesn't exist
if [ ! -f benchmark-simple ]; then
echo "Building benchmark tool..."
go build -o benchmark-simple ./benchmark_simple.go
if [ $? -ne 0 ]; then
echo "Failed to build benchmark tool"
exit 1
fi
fi
# Run the benchmark
echo "Running Nostr relay benchmark..."
echo "Relay: $RELAY_URL"
echo "Events: $EVENTS (size: $SIZE bytes)"
echo "Concurrency: $CONCURRENCY"
echo "Queries: $QUERIES (limit: $QUERY_LIMIT)"
echo ""
./benchmark-simple \
-relay "$RELAY_URL" \
-events $EVENTS \
-size $SIZE \
-concurrency $CONCURRENCY \
-queries $QUERIES \
-query-limit $QUERY_LIMIT \
$SKIP_PUBLISH \
$SKIP_QUERY

View File

@@ -0,0 +1,63 @@
package main
import (
"lukechampine.com/frand"
"orly.dev/pkg/interfaces/signer"
)
// testSigner is a simple signer implementation for benchmarking
type testSigner struct {
pub []byte
sec []byte
}
func newTestSigner() *testSigner {
return &testSigner{
pub: frand.Bytes(32),
sec: frand.Bytes(32),
}
}
func (s *testSigner) Pub() []byte {
return s.pub
}
func (s *testSigner) Sec() []byte {
return s.sec
}
func (s *testSigner) Sign(msg []byte) ([]byte, error) {
return frand.Bytes(64), nil
}
func (s *testSigner) Verify(msg, sig []byte) (bool, error) {
return true, nil
}
func (s *testSigner) InitSec(sec []byte) error {
s.sec = sec
s.pub = frand.Bytes(32)
return nil
}
func (s *testSigner) InitPub(pub []byte) error {
s.pub = pub
return nil
}
func (s *testSigner) Zero() {
for i := range s.sec {
s.sec[i] = 0
}
}
func (s *testSigner) ECDH(pubkey []byte) ([]byte, error) {
return frand.Bytes(32), nil
}
func (s *testSigner) Generate() error {
return nil
}
var _ signer.I = (*testSigner)(nil)

View File

@@ -1,15 +0,0 @@
package app
import "sync"
var bufferPool = &sync.Pool{
New: func() interface{} {
buf := make([]byte, 32*1024)
return &buf
},
}
type Pool struct{}
func (bp Pool) Get() []byte { return *(bufferPool.Get().(*[]byte)) }
func (bp Pool) Put(b []byte) { bufferPool.Put(&b) }

View File

@@ -8,6 +8,8 @@ import (
"io"
"net/http"
"net/url"
"os"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/crypto/sha256"
"orly.dev/pkg/encoders/bech32encoding"
@@ -18,7 +20,6 @@ import (
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
realy_lol "orly.dev/pkg/version"
"os"
)
const secEnv = "NOSTR_SECRET_KEY"

View File

@@ -1,285 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"orly.dev/pkg/protocol/nwc"
)
func printUsage() {
fmt.Println("Usage: nwcclient \"<connection URL>\" <method> [parameters...]")
fmt.Println("\nSupported methods:")
fmt.Println(" get_info - Get wallet information")
fmt.Println(" get_balance - Get wallet balance")
fmt.Println(" get_budget - Get wallet budget")
fmt.Println(" make_invoice - Create an invoice (amount, description, [description_hash], [expiry])")
fmt.Println(" pay_invoice - Pay an invoice (invoice, [amount])")
fmt.Println(" pay_keysend - Send a keysend payment (amount, pubkey, [preimage])")
fmt.Println(" lookup_invoice - Look up an invoice (payment_hash or invoice)")
fmt.Println(" list_transactions - List transactions ([from], [until], [limit], [offset], [unpaid], [type])")
fmt.Println(" sign_message - Sign a message (message)")
fmt.Println("\nUnsupported methods (due to limitations in the nwc package):")
fmt.Println(" create_connection - Create a connection")
fmt.Println(" make_hold_invoice - Create a hold invoice")
fmt.Println(" settle_hold_invoice - Settle a hold invoice")
fmt.Println(" cancel_hold_invoice - Cancel a hold invoice")
fmt.Println(" multi_pay_invoice - Pay multiple invoices")
fmt.Println(" multi_pay_keysend - Send multiple keysend payments")
fmt.Println("\nParameters format:")
fmt.Println(" - Positional parameters are used for required fields")
fmt.Println(" - For list_transactions, named parameters are used: 'from', 'until', 'limit', 'offset', 'unpaid', 'type'")
fmt.Println(" Example: nwcclient <url> list_transactions limit 10 type incoming")
os.Exit(1)
}
func main() {
// Check if we have enough arguments
if len(os.Args) < 3 {
printUsage()
}
// Parse connection URL and method
connectionURL := os.Args[1]
methodStr := os.Args[2]
method := nwc.Method(methodStr)
// Parse the wallet connect URL
opts, err := nwc.ParseWalletConnectURL(connectionURL)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing connection URL: %v\n", err)
os.Exit(1)
}
// Create a new NWC client
client, err := nwc.NewNWCClient(opts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating NWC client: %v\n", err)
os.Exit(1)
}
defer client.Close()
// Execute the requested method
var result interface{}
switch method {
case nwc.GetInfo:
result, err = client.GetInfo()
case nwc.GetBalance:
result, err = client.GetBalance()
case nwc.GetBudget:
result, err = client.GetBudget()
case nwc.MakeInvoice:
if len(os.Args) < 5 {
fmt.Fprintf(
os.Stderr,
"Error: make_invoice requires at least amount and description\n",
)
printUsage()
}
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
os.Exit(1)
}
description := os.Args[4]
req := &nwc.MakeInvoiceRequest{
Amount: amount,
Description: description,
}
// Optional parameters
if len(os.Args) > 5 {
req.DescriptionHash = os.Args[5]
}
if len(os.Args) > 6 {
expiry, err := strconv.ParseInt(os.Args[6], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing expiry: %v\n", err)
os.Exit(1)
}
req.Expiry = &expiry
}
result, err = client.MakeInvoice(req)
case nwc.PayInvoice:
if len(os.Args) < 4 {
fmt.Fprintf(os.Stderr, "Error: pay_invoice requires an invoice\n")
printUsage()
}
req := &nwc.PayInvoiceRequest{
Invoice: os.Args[3],
}
// Optional amount parameter
if len(os.Args) > 4 {
amount, err := strconv.ParseInt(os.Args[4], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
os.Exit(1)
}
req.Amount = &amount
}
result, err = client.PayInvoice(req)
case nwc.PayKeysend:
if len(os.Args) < 5 {
fmt.Fprintf(
os.Stderr, "Error: pay_keysend requires amount and pubkey\n",
)
printUsage()
}
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
os.Exit(1)
}
req := &nwc.PayKeysendRequest{
Amount: amount,
Pubkey: os.Args[4],
}
// Optional preimage
if len(os.Args) > 5 {
req.Preimage = os.Args[5]
}
result, err = client.PayKeysend(req)
case nwc.LookupInvoice:
if len(os.Args) < 4 {
fmt.Fprintf(
os.Stderr,
"Error: lookup_invoice requires a payment_hash or invoice\n",
)
printUsage()
}
param := os.Args[3]
req := &nwc.LookupInvoiceRequest{}
// Determine if the parameter is a payment hash or an invoice
if strings.HasPrefix(param, "ln") {
req.Invoice = param
} else {
req.PaymentHash = param
}
result, err = client.LookupInvoice(req)
case nwc.ListTransactions:
req := &nwc.ListTransactionsRequest{}
// Parse optional parameters
paramIndex := 3
for paramIndex < len(os.Args) {
if paramIndex+1 >= len(os.Args) {
break
}
paramName := os.Args[paramIndex]
paramValue := os.Args[paramIndex+1]
switch paramName {
case "from":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing from: %v\n", err)
os.Exit(1)
}
req.From = &val
case "until":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing until: %v\n", err)
os.Exit(1)
}
req.Until = &val
case "limit":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing limit: %v\n", err)
os.Exit(1)
}
req.Limit = &val
case "offset":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing offset: %v\n", err)
os.Exit(1)
}
req.Offset = &val
case "unpaid":
val := paramValue == "true"
req.Unpaid = &val
case "type":
req.Type = &paramValue
default:
fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
os.Exit(1)
}
paramIndex += 2
}
result, err = client.ListTransactions(req)
case nwc.SignMessage:
if len(os.Args) < 4 {
fmt.Fprintf(os.Stderr, "Error: sign_message requires a message\n")
printUsage()
}
req := &nwc.SignMessageRequest{
Message: os.Args[3],
}
result, err = client.SignMessage(req)
case nwc.CreateConnection, nwc.MakeHoldInvoice, nwc.SettleHoldInvoice, nwc.CancelHoldInvoice, nwc.MultiPayInvoice, nwc.MultiPayKeysend:
fmt.Fprintf(
os.Stderr,
"Error: Method %s is not directly supported by the CLI tool.\n",
methodStr,
)
fmt.Fprintf(
os.Stderr,
"This is because these methods don't have exported client methods in the nwc package.\n",
)
fmt.Fprintf(
os.Stderr,
"Only the following methods are currently supported: get_info, get_balance, get_budget, make_invoice, pay_invoice, pay_keysend, lookup_invoice, list_transactions, sign_message\n",
)
os.Exit(1)
default:
fmt.Fprintf(os.Stderr, "Error: Unsupported method: %s\n", methodStr)
printUsage()
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error executing method: %v\n", err)
os.Exit(1)
}
// Print the result as JSON
jsonData, err := json.MarshalIndent(result, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "Error marshaling result to JSON: %v\n", err)
os.Exit(1)
}
fmt.Println(string(jsonData))
}

417
cmd/walletcli/main.go Normal file
View File

@@ -0,0 +1,417 @@
package main
import (
"fmt"
"os"
"strconv"
"strings"
"orly.dev/pkg/protocol/nwc"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
)
func printUsage() {
fmt.Println("Usage: walletcli '<NWC connection URL>' <method> [<args...>]")
fmt.Println("\nAvailable methods:")
fmt.Println(" get_wallet_service_info - Get wallet service information")
fmt.Println(" get_info - Get wallet information")
fmt.Println(" get_balance - Get wallet balance")
fmt.Println(" get_budget - Get wallet budget")
fmt.Println(" make_invoice - Create an invoice")
fmt.Println(" Args: <amount> [<description>] [<description_hash>] [<expiry>]")
fmt.Println(" pay_invoice - Pay an invoice")
fmt.Println(" Args: <invoice> [<amount>] [<comment>]")
fmt.Println(" pay_keysend - Pay to a node using keysend")
fmt.Println(" Args: <pubkey> <amount> [<preimage>] [<tlv_type> <tlv_value>...]")
fmt.Println(" lookup_invoice - Look up an invoice")
fmt.Println(" Args: <payment_hash or invoice>")
fmt.Println(" list_transactions - List transactions")
fmt.Println(" Args: [<limit>] [<offset>] [<from>] [<until>]")
fmt.Println(" make_hold_invoice - Create a hold invoice")
fmt.Println(" Args: <amount> <payment_hash> [<description>] [<description_hash>] [<expiry>]")
fmt.Println(" settle_hold_invoice - Settle a hold invoice")
fmt.Println(" Args: <preimage>")
fmt.Println(" cancel_hold_invoice - Cancel a hold invoice")
fmt.Println(" Args: <payment_hash>")
fmt.Println(" sign_message - Sign a message")
fmt.Println(" Args: <message>")
fmt.Println(" create_connection - Create a connection")
fmt.Println(" Args: <pubkey> <name> <methods> [<notification_types>] [<max_amount>] [<budget_renewal>] [<expires_at>]")
}
func main() {
if len(os.Args) < 3 {
printUsage()
os.Exit(1)
}
connectionURL := os.Args[1]
method := os.Args[2]
args := os.Args[3:]
// Create context
// ctx, cancel := context.Cancel(context.Bg())
ctx := context.Bg()
// defer cancel()
// Create NWC client
client, err := nwc.NewClient(ctx, connectionURL)
if err != nil {
fmt.Printf("Error creating client: %v\n", err)
os.Exit(1)
}
// Execute the requested method
switch method {
case "get_wallet_service_info":
handleGetWalletServiceInfo(ctx, client)
case "get_info":
handleGetInfo(ctx, client)
case "get_balance":
handleGetBalance(ctx, client)
case "get_budget":
handleGetBudget(ctx, client)
case "make_invoice":
handleMakeInvoice(ctx, client, args)
case "pay_invoice":
handlePayInvoice(ctx, client, args)
case "pay_keysend":
handlePayKeysend(ctx, client, args)
case "lookup_invoice":
handleLookupInvoice(ctx, client, args)
case "list_transactions":
handleListTransactions(ctx, client, args)
case "make_hold_invoice":
handleMakeHoldInvoice(ctx, client, args)
case "settle_hold_invoice":
handleSettleHoldInvoice(ctx, client, args)
case "cancel_hold_invoice":
handleCancelHoldInvoice(ctx, client, args)
case "sign_message":
handleSignMessage(ctx, client, args)
case "create_connection":
handleCreateConnection(ctx, client, args)
default:
fmt.Printf("Unknown method: %s\n", method)
printUsage()
os.Exit(1)
}
}
func handleGetWalletServiceInfo(ctx context.T, client *nwc.Client) {
if _, raw, err := client.GetWalletServiceInfo(ctx, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleGetInfo(ctx context.T, client *nwc.Client) {
if _, raw, err := client.GetInfo(ctx, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleGetBalance(ctx context.T, client *nwc.Client) {
if _, raw, err := client.GetBalance(ctx, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleGetBudget(ctx context.T, client *nwc.Client) {
if _, raw, err := client.GetBudget(ctx, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleMakeInvoice(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 1 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> make_invoice <amount> [<description>] [<description_hash>] [<expiry>]")
return
}
amount, err := strconv.ParseUint(args[0], 10, 64)
if err != nil {
fmt.Printf("Error parsing amount: %v\n", err)
return
}
params := &nwc.MakeInvoiceParams{
Amount: amount,
}
if len(args) > 1 {
params.Description = args[1]
}
if len(args) > 2 {
params.DescriptionHash = args[2]
}
if len(args) > 3 {
expiry, err := strconv.ParseInt(args[3], 10, 64)
if err != nil {
fmt.Printf("Error parsing expiry: %v\n", err)
return
}
params.Expiry = &expiry
}
var raw []byte
if _, raw, err = client.MakeInvoice(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handlePayInvoice(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 1 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> pay_invoice <invoice> [<amount>] [<comment>]")
return
}
params := &nwc.PayInvoiceParams{
Invoice: args[0],
}
if len(args) > 1 {
amount, err := strconv.ParseUint(args[1], 10, 64)
if err != nil {
fmt.Printf("Error parsing amount: %v\n", err)
return
}
params.Amount = &amount
}
if len(args) > 2 {
comment := args[2]
params.Metadata = &nwc.PayInvoiceMetadata{
Comment: &comment,
}
}
if _, raw, err := client.PayInvoice(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleLookupInvoice(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 1 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> lookup_invoice <payment_hash or invoice>")
return
}
params := &nwc.LookupInvoiceParams{}
// Determine if the argument is a payment hash or an invoice
if strings.HasPrefix(args[0], "ln") {
invoice := args[0]
params.Invoice = &invoice
} else {
paymentHash := args[0]
params.PaymentHash = &paymentHash
}
var err error
var raw []byte
if _, raw, err = client.LookupInvoice(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleListTransactions(ctx context.T, client *nwc.Client, args []string) {
params := &nwc.ListTransactionsParams{}
if len(args) > 0 {
limit, err := strconv.ParseUint(args[0], 10, 16)
if err != nil {
fmt.Printf("Error parsing limit: %v\n", err)
return
}
limitUint16 := uint16(limit)
params.Limit = &limitUint16
}
if len(args) > 1 {
offset, err := strconv.ParseUint(args[1], 10, 32)
if err != nil {
fmt.Printf("Error parsing offset: %v\n", err)
return
}
offsetUint32 := uint32(offset)
params.Offset = &offsetUint32
}
if len(args) > 2 {
from, err := strconv.ParseInt(args[2], 10, 64)
if err != nil {
fmt.Printf("Error parsing from: %v\n", err)
return
}
params.From = &from
}
if len(args) > 3 {
until, err := strconv.ParseInt(args[3], 10, 64)
if err != nil {
fmt.Printf("Error parsing until: %v\n", err)
return
}
params.Until = &until
}
var raw []byte
var err error
if _, raw, err = client.ListTransactions(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleMakeHoldInvoice(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 2 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> make_hold_invoice <amount> <payment_hash> [<description>] [<description_hash>] [<expiry>]")
return
}
amount, err := strconv.ParseUint(args[0], 10, 64)
if err != nil {
fmt.Printf("Error parsing amount: %v\n", err)
return
}
params := &nwc.MakeHoldInvoiceParams{
Amount: amount,
PaymentHash: args[1],
}
if len(args) > 2 {
params.Description = args[2]
}
if len(args) > 3 {
params.DescriptionHash = args[3]
}
if len(args) > 4 {
expiry, err := strconv.ParseInt(args[4], 10, 64)
if err != nil {
fmt.Printf("Error parsing expiry: %v\n", err)
return
}
params.Expiry = &expiry
}
var raw []byte
if _, raw, err = client.MakeHoldInvoice(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleSettleHoldInvoice(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 1 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> settle_hold_invoice <preimage>")
return
}
params := &nwc.SettleHoldInvoiceParams{
Preimage: args[0],
}
var raw []byte
var err error
if raw, err = client.SettleHoldInvoice(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleCancelHoldInvoice(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 1 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> cancel_hold_invoice <payment_hash>")
return
}
params := &nwc.CancelHoldInvoiceParams{
PaymentHash: args[0],
}
var err error
var raw []byte
if raw, err = client.CancelHoldInvoice(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleSignMessage(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 1 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> sign_message <message>")
return
}
params := &nwc.SignMessageParams{
Message: args[0],
}
var raw []byte
var err error
if _, raw, err = client.SignMessage(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handlePayKeysend(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 2 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> pay_keysend <pubkey> <amount> [<preimage>] [<tlv_type> <tlv_value>...]")
return
}
pubkey := args[0]
amount, err := strconv.ParseUint(args[1], 10, 64)
if err != nil {
fmt.Printf("Error parsing amount: %v\n", err)
return
}
params := &nwc.PayKeysendParams{
Pubkey: pubkey,
Amount: amount,
}
// Optional preimage
if len(args) > 2 {
preimage := args[2]
params.Preimage = &preimage
}
// Optional TLV records (must come in pairs)
if len(args) > 3 {
// Start from index 3 and process pairs of arguments
for i := 3; i < len(args)-1; i += 2 {
tlvType, err := strconv.ParseUint(args[i], 10, 32)
if err != nil {
fmt.Printf("Error parsing TLV type: %v\n", err)
return
}
tlvValue := args[i+1]
params.TLVRecords = append(
params.TLVRecords, nwc.PayKeysendTLVRecord{
Type: uint32(tlvType),
Value: tlvValue,
},
)
}
}
var raw []byte
if _, raw, err = client.PayKeysend(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}
func handleCreateConnection(ctx context.T, client *nwc.Client, args []string) {
if len(args) < 3 {
fmt.Println("Error: Missing required arguments")
fmt.Println("Usage: walletcli <NWC connection URL> create_connection <pubkey> <name> <methods> [<notification_types>] [<max_amount>] [<budget_renewal>] [<expires_at>]")
return
}
params := &nwc.CreateConnectionParams{
Pubkey: args[0],
Name: args[1],
RequestMethods: strings.Split(args[2], ","),
}
if len(args) > 3 {
params.NotificationTypes = strings.Split(args[3], ",")
}
if len(args) > 4 {
maxAmount, err := strconv.ParseUint(args[4], 10, 64)
if err != nil {
fmt.Printf("Error parsing max_amount: %v\n", err)
return
}
params.MaxAmount = &maxAmount
}
if len(args) > 5 {
params.BudgetRenewal = &args[5]
}
if len(args) > 6 {
expiresAt, err := strconv.ParseInt(args[6], 10, 64)
if err != nil {
fmt.Printf("Error parsing expires_at: %v\n", err)
return
}
params.ExpiresAt = &expiresAt
}
var raw []byte
var err error
if raw, err = client.CreateConnection(ctx, params, true); !chk.E(err) {
fmt.Println(string(raw))
}
}

4
go.mod
View File

@@ -5,13 +5,12 @@ go 1.24.2
require (
github.com/adrg/xdg v0.5.3
github.com/alexflint/go-arg v1.6.0
github.com/coder/websocket v1.8.13
github.com/danielgtaylor/huma/v2 v2.34.1
github.com/davecgh/go-spew v1.1.1
github.com/dgraph-io/badger/v4 v4.7.0
github.com/fasthttp/websocket v1.5.12
github.com/fatih/color v1.18.0
github.com/gobwas/httphead v0.1.0
github.com/gobwas/ws v1.4.0
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0
github.com/klauspost/cpuid/v2 v2.2.11
github.com/minio/sha256-simd v1.0.1
@@ -41,7 +40,6 @@ require (
github.com/felixge/fgprof v0.9.5 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/google/flatbuffers v25.2.10+incompatible // indirect
github.com/google/pprof v0.0.0-20250630185457-6e76a2b096b5 // indirect
github.com/klauspost/compress v1.18.0 // indirect

6
go.sum
View File

@@ -19,6 +19,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/danielgtaylor/huma/v2 v2.34.1 h1:EmOJAbzEGfy0wAq/QMQ1YKfEMBEfE94xdBRLPBP0gwQ=
github.com/danielgtaylor/huma/v2 v2.34.1/go.mod h1:ynwJgLk8iGVgoaipi5tgwIQ5yoFNmiu+QdhU7CEEmhk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -44,13 +46,9 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q=
github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=

View File

@@ -44,6 +44,7 @@ type C struct {
Owners []string `env:"ORLY_OWNERS" usage:"list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)"`
Private bool `env:"ORLY_PRIVATE" usage:"do not spider for user metadata because the relay is private and this would leak relay memberships" default:"false"`
Whitelist []string `env:"ORLY_WHITELIST" usage:"only allow connections from this list of IP addresses"`
Blacklist []string `env:"ORLY_BLACKLIST" usage:"list of pubkeys to block when auth is not required (comma separated)"`
RelaySecret string `env:"ORLY_SECRET_KEY" usage:"secret key for relay cluster replication authentication"`
PeerRelays []string `env:"ORLY_PEER_RELAYS" usage:"list of peer relays URLs that new events are pushed to in format <pubkey>|<url>"`
}

View File

@@ -42,30 +42,35 @@ func (s *Server) AcceptEvent(
remote string,
) (accept bool, notice string, afterSave func()) {
if !s.AuthRequired() {
accept = true
// Check blacklist for public relay mode
if len(s.blacklistPubkeys) > 0 {
for _, blockedPubkey := range s.blacklistPubkeys {
if bytes.Equal(blockedPubkey, ev.Pubkey) {
notice = "event author is blacklisted"
return
}
}
}
return
}
// if auth is required and the user is not authed, reject
if s.AuthRequired() && len(authedPubkey) == 0 {
if len(authedPubkey) == 0 {
notice = "client isn't authed"
return
}
// check if the authed user is on the lists
list := append(s.OwnersFollowed(), s.FollowedFollows()...)
for _, u := range list {
if bytes.Equal(u, authedPubkey) {
accept = true
break
}
}
if !accept {
return
}
for _, u := range s.OwnersMuted() {
if bytes.Equal(u, authedPubkey) {
notice = "event author is banned from this relay"
return
}
}
// check if the authed user is on the lists
list := append(s.OwnersFollowed(), s.FollowedFollows()...)
for _, u := range list {
if bytes.Equal(u, authedPubkey) {
accept = true
return
}
}
return
}

View File

@@ -12,8 +12,8 @@ import (
// mockServerForEvent is a simple mock implementation of the Server struct for testing AcceptEvent
type mockServerForEvent struct {
authRequired bool
ownersFollowed [][]byte
authRequired bool
ownersFollowed [][]byte
followedFollows [][]byte
}
@@ -203,8 +203,8 @@ func TestAcceptEventWithRealServer(t *testing.T) {
if accept {
t.Error("AcceptEvent() accept = true, want false")
}
if notice != "" {
t.Errorf("AcceptEvent() notice = %v, want empty string", notice)
if notice != "client isn't authed" {
t.Errorf("AcceptEvent() notice = %v, want 'client isn't authed'", notice)
}
if afterSave != nil {
t.Error("AcceptEvent() afterSave is not nil, but should be nil")
@@ -234,4 +234,81 @@ func TestAcceptEventWithRealServer(t *testing.T) {
if !accept {
t.Error("AcceptEvent() accept = false, want true")
}
// Test with muted user
s.SetOwnersMuted([][]byte{[]byte("test-pubkey")})
accept, notice, afterSave = s.AcceptEvent(ctx, testEvent, req, []byte("test-pubkey"), "127.0.0.1")
if accept {
t.Error("AcceptEvent() accept = true, want false")
}
if notice != "event author is banned from this relay" {
t.Errorf("AcceptEvent() notice = %v, want 'event author is banned from this relay'", notice)
}
}
// TestAcceptEventWithBlacklist tests the blacklist functionality when auth is not required
func TestAcceptEventWithBlacklist(t *testing.T) {
// Create a context and HTTP request for testing
ctx := context.Bg()
req, _ := http.NewRequest("GET", "http://example.com", nil)
// Test pubkey bytes
testPubkey := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}
blockedPubkey := []byte{0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30}
// Test with public relay mode (auth not required) and no blacklist
s := &Server{
C: &config.C{
AuthRequired: false,
},
Lists: new(Lists),
}
// Create event with test pubkey
testEvent := &event.E{}
testEvent.Pubkey = testPubkey
// Should accept when no blacklist
accept, notice, _ := s.AcceptEvent(ctx, testEvent, req, nil, "127.0.0.1")
if !accept {
t.Error("AcceptEvent() accept = false, want true")
}
if notice != "" {
t.Errorf("AcceptEvent() notice = %v, want empty string", notice)
}
// Add blacklist with different pubkey
s.blacklistPubkeys = [][]byte{blockedPubkey}
// Should still accept when author not in blacklist
accept, notice, _ = s.AcceptEvent(ctx, testEvent, req, nil, "127.0.0.1")
if !accept {
t.Error("AcceptEvent() accept = false, want true")
}
if notice != "" {
t.Errorf("AcceptEvent() notice = %v, want empty string", notice)
}
// Create event with blocked pubkey
blockedEvent := &event.E{}
blockedEvent.Pubkey = blockedPubkey
// Should reject when author is in blacklist
accept, notice, _ = s.AcceptEvent(ctx, blockedEvent, req, nil, "127.0.0.1")
if accept {
t.Error("AcceptEvent() accept = true, want false")
}
if notice != "event author is blacklisted" {
t.Errorf("AcceptEvent() notice = %v, want 'event author is blacklisted'", notice)
}
// Test with auth required - blacklist should not apply
s.C.AuthRequired = true
accept, notice, _ = s.AcceptEvent(ctx, blockedEvent, req, nil, "127.0.0.1")
if accept {
t.Error("AcceptEvent() accept = true, want false")
}
if notice != "client isn't authed" {
t.Errorf("AcceptEvent() notice = %v, want 'client isn't authed'", notice)
}
}

View File

@@ -8,41 +8,41 @@ import (
func TestLists_OwnersPubkeys(t *testing.T) {
// Create a new Lists instance
l := &Lists{}
// Test with empty list
pks := l.OwnersPubkeys()
if len(pks) != 0 {
t.Errorf("Expected empty list, got %d items", len(pks))
}
// Test with some pubkeys
testPubkeys := [][]byte{
[]byte("pubkey1"),
[]byte("pubkey2"),
[]byte("pubkey3"),
}
l.SetOwnersPubkeys(testPubkeys)
// Verify length
if l.LenOwnersPubkeys() != len(testPubkeys) {
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenOwnersPubkeys())
}
// Verify content
pks = l.OwnersPubkeys()
if len(pks) != len(testPubkeys) {
t.Errorf("Expected %d pubkeys, got %d", len(testPubkeys), len(pks))
}
// Verify each pubkey
for i, pk := range pks {
if !bytes.Equal(pk, testPubkeys[i]) {
t.Errorf("Pubkey at index %d doesn't match: expected %s, got %s",
t.Errorf("Pubkey at index %d doesn't match: expected %s, got %s",
i, testPubkeys[i], pk)
}
}
// Verify that the returned slice is a copy, not a reference
pks[0] = []byte("modified")
newPks := l.OwnersPubkeys()
@@ -54,37 +54,37 @@ func TestLists_OwnersPubkeys(t *testing.T) {
func TestLists_OwnersFollowed(t *testing.T) {
// Create a new Lists instance
l := &Lists{}
// Test with empty list
followed := l.OwnersFollowed()
if len(followed) != 0 {
t.Errorf("Expected empty list, got %d items", len(followed))
}
// Test with some pubkeys
testPubkeys := [][]byte{
[]byte("followed1"),
[]byte("followed2"),
[]byte("followed3"),
}
l.SetOwnersFollowed(testPubkeys)
// Verify length
if l.LenOwnersFollowed() != len(testPubkeys) {
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenOwnersFollowed())
}
// Verify content
followed = l.OwnersFollowed()
if len(followed) != len(testPubkeys) {
t.Errorf("Expected %d followed, got %d", len(testPubkeys), len(followed))
}
// Verify each pubkey
for i, pk := range followed {
if !bytes.Equal(pk, testPubkeys[i]) {
t.Errorf("Followed at index %d doesn't match: expected %s, got %s",
t.Errorf("Followed at index %d doesn't match: expected %s, got %s",
i, testPubkeys[i], pk)
}
}
@@ -93,37 +93,37 @@ func TestLists_OwnersFollowed(t *testing.T) {
func TestLists_FollowedFollows(t *testing.T) {
// Create a new Lists instance
l := &Lists{}
// Test with empty list
follows := l.FollowedFollows()
if len(follows) != 0 {
t.Errorf("Expected empty list, got %d items", len(follows))
}
// Test with some pubkeys
testPubkeys := [][]byte{
[]byte("follow1"),
[]byte("follow2"),
[]byte("follow3"),
}
l.SetFollowedFollows(testPubkeys)
// Verify length
if l.LenFollowedFollows() != len(testPubkeys) {
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenFollowedFollows())
}
// Verify content
follows = l.FollowedFollows()
if len(follows) != len(testPubkeys) {
t.Errorf("Expected %d follows, got %d", len(testPubkeys), len(follows))
}
// Verify each pubkey
for i, pk := range follows {
if !bytes.Equal(pk, testPubkeys[i]) {
t.Errorf("Follow at index %d doesn't match: expected %s, got %s",
t.Errorf("Follow at index %d doesn't match: expected %s, got %s",
i, testPubkeys[i], pk)
}
}
@@ -132,37 +132,37 @@ func TestLists_FollowedFollows(t *testing.T) {
func TestLists_OwnersMuted(t *testing.T) {
// Create a new Lists instance
l := &Lists{}
// Test with empty list
muted := l.OwnersMuted()
if len(muted) != 0 {
t.Errorf("Expected empty list, got %d items", len(muted))
}
// Test with some pubkeys
testPubkeys := [][]byte{
[]byte("muted1"),
[]byte("muted2"),
[]byte("muted3"),
}
l.SetOwnersMuted(testPubkeys)
// Verify length
if l.LenOwnersMuted() != len(testPubkeys) {
t.Errorf("Expected length %d, got %d", len(testPubkeys), l.LenOwnersMuted())
}
// Verify content
muted = l.OwnersMuted()
if len(muted) != len(testPubkeys) {
t.Errorf("Expected %d muted, got %d", len(testPubkeys), len(muted))
}
// Verify each pubkey
for i, pk := range muted {
if !bytes.Equal(pk, testPubkeys[i]) {
t.Errorf("Muted at index %d doesn't match: expected %s, got %s",
t.Errorf("Muted at index %d doesn't match: expected %s, got %s",
i, testPubkeys[i], pk)
}
}
@@ -171,10 +171,10 @@ func TestLists_OwnersMuted(t *testing.T) {
func TestLists_ConcurrentAccess(t *testing.T) {
// Create a new Lists instance
l := &Lists{}
// Test concurrent access to the lists
done := make(chan bool)
// Concurrent reads and writes
go func() {
for i := 0; i < 100; i++ {
@@ -183,7 +183,7 @@ func TestLists_ConcurrentAccess(t *testing.T) {
}
done <- true
}()
go func() {
for i := 0; i < 100; i++ {
l.SetOwnersFollowed([][]byte{[]byte("followed1"), []byte("followed2")})
@@ -191,7 +191,7 @@ func TestLists_ConcurrentAccess(t *testing.T) {
}
done <- true
}()
go func() {
for i := 0; i < 100; i++ {
l.SetFollowedFollows([][]byte{[]byte("follow1"), []byte("follow2")})
@@ -199,7 +199,7 @@ func TestLists_ConcurrentAccess(t *testing.T) {
}
done <- true
}()
go func() {
for i := 0; i < 100; i++ {
l.SetOwnersMuted([][]byte{[]byte("muted1"), []byte("muted2")})
@@ -207,11 +207,11 @@ func TestLists_ConcurrentAccess(t *testing.T) {
}
done <- true
}()
// Wait for all goroutines to complete
for i := 0; i < 4; i++ {
<-done
}
// If we got here without deadlocks or panics, the test passes
}
}

View File

@@ -6,12 +6,13 @@ import (
"fmt"
"net"
"net/http"
"orly.dev/pkg/protocol/openapi"
"orly.dev/pkg/protocol/socketapi"
"strconv"
"strings"
"time"
"orly.dev/pkg/protocol/openapi"
"orly.dev/pkg/protocol/socketapi"
"orly.dev/pkg/app/config"
"orly.dev/pkg/app/relay/helpers"
"orly.dev/pkg/app/relay/options"
@@ -20,6 +21,7 @@ import (
"orly.dev/pkg/protocol/servemux"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/keys"
"orly.dev/pkg/utils/log"
"github.com/rs/cors"
@@ -29,14 +31,15 @@ import (
// encapsulates various components such as context, cancel function, options,
// relay interface, address, HTTP server, and configuration settings.
type Server struct {
Ctx context.T
Cancel context.F
options *options.T
relay relay.I
Addr string
mux *servemux.S
httpServer *http.Server
listeners *publish.S
Ctx context.T
Cancel context.F
options *options.T
relay relay.I
Addr string
mux *servemux.S
httpServer *http.Server
listeners *publish.S
blacklistPubkeys [][]byte
*config.C
*Lists
*Peers
@@ -105,6 +108,17 @@ func NewServer(
Lists: new(Lists),
Peers: new(Peers),
}
// Parse blacklist pubkeys
for _, v := range s.C.Blacklist {
if len(v) == 0 {
continue
}
var pk []byte
if pk, err = keys.DecodeNpubOrHex(v); chk.E(err) {
continue
}
s.blacklistPubkeys = append(s.blacklistPubkeys, pk)
}
chk.E(
s.Peers.Init(sp.C.PeerRelays, sp.C.RelaySecret),
)

View File

@@ -1,6 +1,9 @@
package relay
import (
"runtime/debug"
"time"
"orly.dev/pkg/crypto/ec/schnorr"
"orly.dev/pkg/database/indexes/types"
"orly.dev/pkg/encoders/event"
@@ -14,8 +17,6 @@ import (
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
"runtime/debug"
"time"
)
// IdPkTs is a map of event IDs to their id, pubkey, kind, and timestamp
@@ -141,14 +142,10 @@ func (s *Server) SpiderFetch(
var evss event.S
var cli *ws.Client
if cli, err = ws.RelayConnect(
context.Bg(), seed, ws.WithSignatureChecker(
func(e *event.E) bool {
return true
},
),
context.Bg(), seed,
); chk.E(err) {
err = nil
return
continue
}
if evss, err = cli.QuerySync(
context.Bg(), batchFilter,

View File

@@ -52,7 +52,7 @@ func TestBech32(t *testing.T) {
{
"split1cheo2y9e2w",
ErrNonCharsetChar('o'),
}, // invalid character (o) in data part
}, // invalid character (o) in data part
{"split1a2y9w", ErrInvalidSeparatorIndex(5)}, // too short data part
{
"1checkupstagehandshakeupstreamerranterredcaperred2y9e3w",

View File

@@ -87,9 +87,9 @@ type nonceAggValidCase struct {
}
type nonceAggInvalidCase struct {
Indices []int `json:"pnonce_indices"`
Error nonceAggError `json:"error"`
Comment string `json:"comment"`
Indices []int `json:"pnonce_indices"`
Error nonceAggError `json:"error"`
Comment string `json:"comment"`
ExpectedErr string `json:"btcec_err"`
}

View File

@@ -158,8 +158,8 @@ func Decrypt(b64ciphertextWrapped, conversationKey []byte) (
return
}
// GenerateConversationKey performs an ECDH key generation hashed with the nip-44-v2 using hkdf.
func GenerateConversationKey(pkh, skh string) (ck []byte, err error) {
// GenerateConversationKeyFromHex performs an ECDH key generation hashed with the nip-44-v2 using hkdf.
func GenerateConversationKeyFromHex(pkh, skh string) (ck []byte, err error) {
if skh >= "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" ||
skh == "0000000000000000000000000000000000000000000000000000000000000000" {
err = errorf.E(
@@ -184,6 +184,17 @@ func GenerateConversationKey(pkh, skh string) (ck []byte, err error) {
return
}
func GenerateConversationKeyWithSigner(sign signer.I, pk []byte) (
ck []byte, err error,
) {
var shared []byte
if shared, err = sign.ECDH(pk); chk.E(err) {
return
}
ck = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))
return
}
func encrypt(key, nonce, message []byte) (dst []byte, err error) {
var cipher *chacha20.Cipher
if cipher, err = chacha20.NewUnauthenticatedCipher(key, nonce); chk.E(err) {

View File

@@ -47,7 +47,9 @@ func assertCryptPriv(
return
}
expectedBytes = []byte(expected)
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
if ok = assert.Equalf(
t, string(expectedBytes), string(actualBytes), "wrong encryption",
); !ok {
return
}
decrypted, err = Decrypt(expectedBytes, k1)
@@ -62,8 +64,8 @@ func assertDecryptFail(
) {
var (
k1, ciphertextBytes []byte
ok bool
err error
ok bool
err error
)
k1, err = hex.Dec(conversationKey)
if ok = assert.NoErrorf(
@@ -79,7 +81,7 @@ func assertDecryptFail(
func assertConversationKeyFail(
t *testing.T, priv string, pub string, msg string,
) {
_, err := GenerateConversationKey(pub, priv)
_, err := GenerateConversationKeyFromHex(pub, priv)
assert.ErrorContains(t, err, msg)
}
@@ -98,7 +100,7 @@ func assertConversationKeyGeneration(
); !ok {
return false
}
actualConversationKey, err = GenerateConversationKey(pub, priv)
actualConversationKey, err = GenerateConversationKeyFromHex(pub, priv)
if ok = assert.NoErrorf(
t, err, "conversation key generation failed: %v", err,
); !ok {
@@ -1312,7 +1314,7 @@ func TestMaxLength(t *testing.T) {
pub2, _ := keys.GetPublicKeyHex(string(sk2))
salt := make([]byte, 32)
rand.Read(salt)
conversationKey, _ := GenerateConversationKey(pub2, string(sk1))
conversationKey, _ := GenerateConversationKeyFromHex(pub2, string(sk1))
plaintext := strings.Repeat("a", MaxPlaintextSize)
plaintextBytes := []byte(plaintext)
encrypted, err := Encrypt(
@@ -1366,7 +1368,9 @@ func assertCryptPub(
return
}
expectedBytes = []byte(expected)
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
if ok = assert.Equalf(
t, string(expectedBytes), string(actualBytes), "wrong encryption",
); !ok {
return
}
decrypted, err = Decrypt(expectedBytes, k1)

View File

@@ -33,7 +33,7 @@ func NewPubFromHex[V []byte | string](pkh V) (sign signer.I, err error) {
}
func HexToBin(hexStr string) (b []byte, err error) {
if _, err = hex.DecBytes(b, []byte(hexStr)); chk.E(err) {
if b, err = hex.DecAppend(b, []byte(hexStr)); chk.E(err) {
return
}
return

View File

@@ -2,16 +2,16 @@ package database
import (
"bytes"
"testing"
"orly.dev/pkg/database/indexes"
types2 "orly.dev/pkg/database/indexes/types"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"testing"
"github.com/minio/sha256-simd"
)
@@ -26,8 +26,7 @@ func TestGetIndexesForEvent(t *testing.T) {
// indexes
func verifyIndexIncluded(t *testing.T, idxs [][]byte, expectedIdx *indexes.T) {
// Marshal the expected index
buf := codecbuf.Get()
defer codecbuf.Put(buf)
buf := new(bytes.Buffer)
err := expectedIdx.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("Failed to marshal expected index: %v", err)

View File

@@ -3,16 +3,16 @@ package database
import (
"bytes"
"math"
"testing"
"orly.dev/pkg/database/indexes"
types2 "orly.dev/pkg/database/indexes/types"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"testing"
"github.com/minio/sha256-simd"
)
@@ -41,8 +41,7 @@ func verifyIndex(
}
// Marshal the expected start index
startBuf := codecbuf.Get()
defer codecbuf.Put(startBuf)
startBuf := new(bytes.Buffer)
err := expectedStartIdx.MarshalWrite(startBuf)
if chk.E(err) {
t.Fatalf("Failed to marshal expected start index: %v", err)
@@ -62,8 +61,7 @@ func verifyIndex(
}
// Marshal the expected end index
endBuf := codecbuf.Get()
defer codecbuf.Put(endBuf)
endBuf := new(bytes.Buffer)
err = endIdx.MarshalWrite(endBuf)
if chk.E(err) {
t.Fatalf("Failed to marshal expected End index: %v", err)

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"io"
"orly.dev/pkg/database/indexes/types"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"testing"
)
@@ -49,7 +48,7 @@ func TestPrefixMethods(t *testing.T) {
}
// Test MarshalWrite method
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := prefix.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -209,7 +208,7 @@ func TestTStruct(t *testing.T) {
}
// Test MarshalWrite
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -272,7 +271,7 @@ func TestEventFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -318,7 +317,7 @@ func TestIdFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -391,7 +390,7 @@ func TestIdPubkeyFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -452,7 +451,7 @@ func TestCreatedAtFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -516,7 +515,7 @@ func TestPubkeyFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -588,7 +587,7 @@ func TestPubkeyTagFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -660,7 +659,7 @@ func TestTagFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -724,7 +723,7 @@ func TestKindFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -789,7 +788,7 @@ func TestKindTagFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -865,7 +864,7 @@ func TestKindPubkeyFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -941,7 +940,7 @@ func TestKindPubkeyTagFunctions(t *testing.T) {
}
// Test marshaling and unmarshaling
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = enc.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -84,7 +84,7 @@ func testUint16Sorting(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint16 values don't sort correctly: %v should be less than %v",
t.Errorf("Uint16 values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -115,7 +115,7 @@ func testUint24Sorting(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint24 values don't sort correctly: %v should be less than %v",
t.Errorf("Uint24 values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -143,7 +143,7 @@ func testUint32Sorting(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint32 values don't sort correctly: %v should be less than %v",
t.Errorf("Uint32 values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -174,7 +174,7 @@ func testUint40Sorting(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint40 values don't sort correctly: %v should be less than %v",
t.Errorf("Uint40 values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -202,7 +202,7 @@ func testUint64Sorting(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint64 values don't sort correctly: %v should be less than %v",
t.Errorf("Uint64 values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -233,7 +233,7 @@ func testUint16EdgeCases(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint16 edge case values don't sort correctly: %v should be less than %v",
t.Errorf("Uint16 edge case values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -265,7 +265,7 @@ func testUint24EdgeCases(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint24 edge case values don't sort correctly: %v should be less than %v",
t.Errorf("Uint24 edge case values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -294,7 +294,7 @@ func testUint32EdgeCases(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint32 edge case values don't sort correctly: %v should be less than %v",
t.Errorf("Uint32 edge case values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -326,7 +326,7 @@ func testUint40EdgeCases(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint40 edge case values don't sort correctly: %v should be less than %v",
t.Errorf("Uint40 edge case values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -355,7 +355,7 @@ func testUint64EdgeCases(t *testing.T) {
// Check if they sort correctly with bytes.Compare
for i := 0; i < len(marshaledValues)-1; i++ {
if bytes.Compare(marshaledValues[i], marshaledValues[i+1]) >= 0 {
t.Errorf("Uint64 edge case values don't sort correctly: %v should be less than %v",
t.Errorf("Uint64 edge case values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", marshaledValues[i], marshaledValues[i+1])
}
@@ -390,7 +390,7 @@ func TestEndianness(t *testing.T) {
result := bytes.Compare(bigEndianValues[i], bigEndianValues[i+1])
t.Logf("Compare %d with %d: result = %d", values[i], values[i+1], result)
if result >= 0 {
t.Errorf("BigEndian values don't sort correctly: %v should be less than %v",
t.Errorf("BigEndian values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", bigEndianValues[i], bigEndianValues[i+1])
}
@@ -404,7 +404,7 @@ func TestEndianness(t *testing.T) {
t.Logf("Compare %d with %d: result = %d", values[i], values[i+1], result)
if result >= 0 {
correctOrder = false
t.Logf("LittleEndian values don't sort correctly: %v should be less than %v",
t.Logf("LittleEndian values don't sort correctly: %v should be less than %v",
values[i], values[i+1])
t.Logf("Bytes representation: %v vs %v", littleEndianValues[i], littleEndianValues[i+1])
}

View File

@@ -2,10 +2,10 @@ package types
import (
"bytes"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"testing"
"orly.dev/pkg/utils/chk"
"github.com/minio/sha256-simd"
)
@@ -55,7 +55,7 @@ func TestIdMarshalWriteUnmarshalRead(t *testing.T) {
}
// Test MarshalWrite
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = fi1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -2,10 +2,10 @@ package types
import (
"bytes"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"testing"
"orly.dev/pkg/utils/chk"
"github.com/minio/sha256-simd"
)
@@ -45,7 +45,7 @@ func TestIdent_MarshalWriteUnmarshalRead(t *testing.T) {
}
// Test MarshalWrite
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = i1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -3,10 +3,10 @@ package types
import (
"bytes"
"encoding/base64"
"orly.dev/pkg/encoders/codecbuf"
"testing"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/utils/chk"
"testing"
"github.com/minio/sha256-simd"
)
@@ -142,7 +142,7 @@ func TestIdHashMarshalWriteUnmarshalRead(t *testing.T) {
}
// Test MarshalWrite
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = i1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -2,9 +2,9 @@ package types
import (
"bytes"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"testing"
"orly.dev/pkg/utils/chk"
)
func TestLetter_New(t *testing.T) {
@@ -53,7 +53,7 @@ func TestLetter_MarshalWriteUnmarshalRead(t *testing.T) {
l1 := new(Letter)
l1.Set('A')
// Test MarshalWrite
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := l1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -2,11 +2,11 @@ package types
import (
"bytes"
"testing"
"orly.dev/pkg/crypto/ec/schnorr"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/utils/chk"
"testing"
"github.com/minio/sha256-simd"
)
@@ -105,7 +105,7 @@ func TestPubHash_MarshalWriteUnmarshalRead(t *testing.T) {
}
// Test MarshalWrite
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err = ph1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -2,10 +2,10 @@ package types
import (
"bytes"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"testing"
"time"
"orly.dev/pkg/utils/chk"
)
func TestTimestamp_FromInt(t *testing.T) {
@@ -89,7 +89,7 @@ func TestTimestamp_FromBytes(t *testing.T) {
v.Set(12345)
// Marshal it to bytes
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := v.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -163,7 +163,7 @@ func TestTimestamp_Bytes(t *testing.T) {
func TestTimestamp_MarshalWriteUnmarshalRead(t *testing.T) {
// Test with a positive value
ts1 := &Timestamp{val: 12345}
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := ts1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -183,7 +183,7 @@ func TestTimestamp_MarshalWriteUnmarshalRead(t *testing.T) {
// Test with a negative value
ts1 = &Timestamp{val: -12345}
buf = codecbuf.Get()
buf = new(bytes.Buffer)
err = ts1.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)
@@ -225,7 +225,7 @@ func TestTimestamp_WithCurrentTime(t *testing.T) {
}
// Test MarshalWrite and UnmarshalRead
buf := codecbuf.Get()
buf := new(bytes.Buffer)
err := ts.MarshalWrite(buf)
if chk.E(err) {
t.Fatalf("MarshalWrite failed: %v", err)

View File

@@ -3,11 +3,11 @@ package types
import (
"bytes"
"math"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"reflect"
"testing"
"orly.dev/pkg/utils/chk"
"lukechampine.com/frand"
)
@@ -44,7 +44,7 @@ func TestUint16(t *testing.T) {
}
// Test encoding to []byte and decoding back
bufEnc := codecbuf.Get()
bufEnc := new(bytes.Buffer)
// MarshalWrite
err := encodedUint16.MarshalWrite(bufEnc)

View File

@@ -1,10 +1,11 @@
package types
import (
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"bytes"
"reflect"
"testing"
"orly.dev/pkg/utils/chk"
)
func TestUint24(t *testing.T) {
@@ -45,7 +46,7 @@ func TestUint24(t *testing.T) {
}
// Test MarshalWrite and UnmarshalRead
buf := codecbuf.Get()
buf := new(bytes.Buffer)
// MarshalWrite directly to the buffer
if err := codec.MarshalWrite(buf); chk.E(err) {

View File

@@ -3,11 +3,11 @@ package types
import (
"bytes"
"math"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"reflect"
"testing"
"orly.dev/pkg/utils/chk"
"lukechampine.com/frand"
)
@@ -43,7 +43,7 @@ func TestUint32(t *testing.T) {
}
// Test encoding to []byte and decoding back
bufEnc := codecbuf.Get()
bufEnc := new(bytes.Buffer)
// MarshalWrite
err := codec.MarshalWrite(bufEnc)

View File

@@ -1,10 +1,11 @@
package types
import (
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"bytes"
"reflect"
"testing"
"orly.dev/pkg/utils/chk"
)
func TestUint40(t *testing.T) {
@@ -48,7 +49,7 @@ func TestUint40(t *testing.T) {
}
// Test MarshalWrite and UnmarshalRead
buf := codecbuf.Get()
buf := new(bytes.Buffer)
// Marshal to a buffer
if err = codec.MarshalWrite(buf); chk.E(err) {

View File

@@ -3,11 +3,11 @@ package types
import (
"bytes"
"math"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"reflect"
"testing"
"orly.dev/pkg/utils/chk"
"lukechampine.com/frand"
)
@@ -43,7 +43,7 @@ func TestUint64(t *testing.T) {
}
// Test encoding to []byte and decoding back
bufEnc := codecbuf.Get()
bufEnc := new(bytes.Buffer)
// MarshalWrite
err := codec.MarshalWrite(bufEnc)

View File

@@ -4,6 +4,7 @@ package reqenvelope
import (
"io"
"orly.dev/pkg/encoders/envelopes"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/subscription"
@@ -37,10 +38,21 @@ func New() *T {
// NewFrom creates a new reqenvelope.T with a provided subscription.Id and
// filters.T.
func NewFrom(id *subscription.Id, filters *filters.T) *T {
func NewFrom(id *subscription.Id, ff *filters.T) *T {
return &T{
Subscription: id,
Filters: filters,
Filters: ff,
}
}
func NewWithIdString(id string, ff *filters.T) (sub *T) {
sid, err := subscription.NewId(id)
if err != nil {
return
}
return &T{
Subscription: sid,
Filters: ff,
}
}

View File

@@ -3,11 +3,11 @@ package event
import (
"bufio"
"bytes"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/encoders/event/examples"
"orly.dev/pkg/utils/chk"
"testing"
"time"
"orly.dev/pkg/encoders/event/examples"
"orly.dev/pkg/utils/chk"
)
func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
@@ -19,7 +19,7 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
var counter int
for scanner.Scan() {
// Create new event objects and buffer for each iteration
buf := codecbuf.Get()
buf := new(bytes.Buffer)
ea, eb := New(), New()
chk.E(scanner.Err())
@@ -42,7 +42,6 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
// Create a new buffer for unmarshaling
buf2 := bytes.NewBuffer(buf.Bytes())
if err = eb.UnmarshalBinary(buf2); chk.E(err) {
codecbuf.Put(buf)
t.Fatal(err)
}
@@ -57,9 +56,6 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) {
)
}
// Return buffer to pool
codecbuf.Put(buf)
counter++
out = out[:0]
}

View File

@@ -4,10 +4,11 @@ import (
"bufio"
"bytes"
_ "embed"
"testing"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/event/examples"
"orly.dev/pkg/utils/chk"
"testing"
)
func TestTMarshal_Unmarshal(t *testing.T) {

View File

@@ -2,12 +2,13 @@ package event
import (
"bytes"
"testing"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
text2 "orly.dev/pkg/encoders/text"
"orly.dev/pkg/encoders/timestamp"
"testing"
)
// compareTags compares two tags and reports any differences
@@ -96,7 +97,8 @@ func TestUnmarshalEscapedJSONInTags(t *testing.T) {
unmarshaledTag := unmarshaledEvent.Tags.GetTagElement(0)
if unmarshaledTag.Len() != 2 {
t.Fatalf(
"Expected tag with 2 elements, got %d", unmarshaledTag.Len(),
"Expected tag with 2 elements, got %d",
unmarshaledTag.Len(),
)
}

View File

@@ -2,11 +2,12 @@ package event
import (
"bytes"
"testing"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"testing"
)
// compareEvents compares two events and reports any differences

View File

@@ -7,6 +7,8 @@ package filter
import (
"bytes"
"encoding/binary"
"sort"
"orly.dev/pkg/crypto/ec/schnorr"
"orly.dev/pkg/crypto/ec/secp256k1"
"orly.dev/pkg/crypto/sha256"
@@ -21,8 +23,8 @@ import (
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/pointers"
"sort"
"lukechampine.com/frand"
)
@@ -181,12 +183,12 @@ func (f *F) Marshal(dst []byte) (b []byte) {
dst = append(dst, '[')
for i, value := range values {
dst = append(dst, '"')
if tKey[1] == 'e' || tKey[1] == 'p' {
// event and pubkey tags are binary 32 bytes
dst = hex.EncAppend(dst, value)
} else {
dst = append(dst, value...)
}
// if tKey[1] == 'e' || tKey[1] == 'p' {
// // event and pubkey tags are binary 32 bytes
// dst = hex.EncAppend(dst, value)
// } else {
dst = append(dst, value...)
// }
dst = append(dst, '"')
if i < len(values)-1 {
dst = append(dst, ',')
@@ -300,29 +302,29 @@ func (f *F) Unmarshal(b []byte) (r []byte, err error) {
}
k := make([]byte, len(key))
copy(k, key)
switch key[1] {
case 'e', 'p':
// the tags must all be 64 character hexadecimal
var ff [][]byte
if ff, r, err = text2.UnmarshalHexArray(
r,
sha256.Size,
); chk.E(err) {
return
}
ff = append([][]byte{k}, ff...)
f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
// f.Tags.F = append(f.Tags.F, tag.New(ff...))
default:
// other types of tags can be anything
var ff [][]byte
if ff, r, err = text2.UnmarshalStringArray(r); chk.E(err) {
return
}
ff = append([][]byte{k}, ff...)
f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
// f.Tags.F = append(f.Tags.F, tag.New(ff...))
// switch key[1] {
// case 'e', 'p':
// // the tags must all be 64 character hexadecimal
// var ff [][]byte
// if ff, r, err = text2.UnmarshalHexArray(
// r,
// sha256.Size,
// ); chk.E(err) {
// return
// }
// ff = append([][]byte{k}, ff...)
// f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
// // f.Tags.F = append(f.Tags.F, tag.New(ff...))
// default:
// other types of tags can be anything
var ff [][]byte
if ff, r, err = text2.UnmarshalStringArray(r); chk.E(err) {
return
}
ff = append([][]byte{k}, ff...)
f.Tags = f.Tags.AppendTags(tag.FromBytesSlice(ff...))
// f.Tags.F = append(f.Tags.F, tag.New(ff...))
// }
state = betweenKV
case IDs[0]:
if len(key) < len(IDs) {
@@ -440,43 +442,55 @@ invalid:
return
}
// Matches checks a filter against an event and determines if the event matches the filter.
func (f *F) Matches(ev *event.E) bool {
// MatchesIgnoringTimestampConstraints checks a filter against an event and
// determines if the event matches the filter, ignoring timestamp constraints..
func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool {
if ev == nil {
// log.F.ToSliceOfBytes("nil event")
log.I.F("nil event")
return false
}
if f.Ids.Len() > 0 && !f.Ids.Contains(ev.ID) {
// log.F.ToSliceOfBytes("no ids in filter match event\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
log.I.F("no ids in filter match event")
return false
}
if f.Kinds.Len() > 0 && !f.Kinds.Contains(ev.Kind) {
// log.F.ToSliceOfBytes("no matching kinds in filter\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
log.I.F(
"no matching kinds in filter",
)
return false
}
if f.Authors.Len() > 0 && !f.Authors.Contains(ev.Pubkey) {
// log.F.ToSliceOfBytes("no matching authors in filter\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
log.I.F("no matching authors in filter")
return false
}
if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) {
return false
}
// if f.Tags.Len() > 0 {
// for _, v := range f.Tags.ToSliceOfTags() {
// tvs := v.ToSliceOfBytes()
// if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) {
// return false
// }
// }
// return false
// if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) {
// return false
// }
if f.Tags.Len() > 0 {
for _, v := range f.Tags.ToSliceOfTags() {
tvs := v.ToSliceOfBytes()
if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) {
log.I.F("no matching tags in filter")
return false
}
}
// return false
}
return true
}
// Matches checks a filter against an event and determines if the event matches the filter.
func (f *F) Matches(ev *event.E) (match bool) {
if !f.MatchesIgnoringTimestampConstraints(ev) {
return
}
if f.Since.Int() != 0 && ev.CreatedAt.I64() < f.Since.I64() {
// log.F.ToSliceOfBytes("event is older than since\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
return false
log.I.F("event is older than since")
return
}
if f.Until.Int() != 0 && ev.CreatedAt.I64() > f.Until.I64() {
// log.F.ToSliceOfBytes("event is newer than until\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String())
return false
log.I.F("event is newer than until")
return
}
return true
}

View File

@@ -120,3 +120,12 @@ func GenFilters(n int) (ff *T, err error) {
}
return
}
func (f *T) MatchIgnoringTimestampConstraints(ev *event.E) bool {
for _, ff := range f.F {
if ff.MatchesIgnoringTimestampConstraints(ev) {
return true
}
}
return false
}

View File

@@ -4,6 +4,7 @@ package hex
import (
"encoding/hex"
"github.com/templexxx/xhex"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/errorf"

View File

@@ -275,9 +275,9 @@ var (
WalletRequest = &T{23194}
// NWCWalletResponse is an event type that...
NWCWalletResponse = &T{23195}
WalletResponse = NWCWalletResponse
WalletResponse = &T{23195}
NWCNotification = &T{23196}
WalletNotification = NWCNotification
WalletNotification = &T{23196}
// NostrConnect is an event type that...
NostrConnect = &T{24133}
HTTPAuth = &T{27235}

View File

@@ -5,6 +5,7 @@ package subscription
import (
"crypto/rand"
"orly.dev/pkg/crypto/ec/bech32"
"orly.dev/pkg/encoders/text"
"orly.dev/pkg/utils/chk"
@@ -24,7 +25,7 @@ func (si *Id) IsValid() bool { return len(si.T) <= 64 && len(si.T) > 0 }
// NewId inspects a string and converts to Id if it is
// valid. Invalid means length == 0 or length > 64.
func NewId[V string | []byte](s V) (*Id, error) {
func NewId[V ~string | ~[]byte](s V) (*Id, error) {
si := &Id{T: []byte(s)}
if si.IsValid() {
return si, nil
@@ -40,7 +41,7 @@ func NewId[V string | []byte](s V) (*Id, error) {
// MustNew is the same as NewId except it doesn't check if you feed it rubbish.
//
// DO NOT USE WITHOUT CHECKING THE Id IS NOT NIL AND > 0 AND <= 64
func MustNew[V string | []byte](s V) *Id {
func MustNew[V ~string | ~[]byte](s V) *Id {
return &Id{T: []byte(s)}
}

View File

@@ -5,6 +5,7 @@ package tag
import (
"bytes"
text2 "orly.dev/pkg/encoders/text"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
@@ -26,7 +27,7 @@ const (
)
// BS is an abstract data type that can process strings and byte slices as byte slices.
type BS[Z []byte | string] []byte
type BS[Z ~[]byte | ~string] []byte
// T is a list of strings with a literal ordering.
//
@@ -36,7 +37,7 @@ type T struct {
}
// New creates a new tag.T from a variadic parameter that can be either string or byte slice.
func New[V string | []byte](fields ...V) (t *T) {
func New[V ~string | ~[]byte](fields ...V) (t *T) {
t = &T{field: make([]BS[[]byte], len(fields))}
for i, field := range fields {
t.field[i] = []byte(field)

View File

@@ -7,12 +7,13 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"sort"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/lol"
"os"
"sort"
)
// T is a list of tag.T - which are lists of string elements with ordering and no uniqueness
@@ -161,6 +162,15 @@ func (t *T) GetFirst(tagPrefix *tag.T) *tag.T {
return nil
}
func (t *T) GetD() (d string) {
for _, v := range t.element {
if bytes.Equal(v.Key(), []byte("d")) {
return string(v.Value())
}
}
return
}
// GetLast gets the last tag in tags that matches the prefix, see [T.StartsWith]
func (t *T) GetLast(tagPrefix *tag.T) *tag.T {
for i := len(t.element) - 1; i >= 0; i-- {
@@ -299,7 +309,7 @@ func (t *T) ContainsAny(tagName []byte, values *tag.T) bool {
continue
}
for _, candidate := range values.ToSliceOfBytes() {
if bytes.Equal(v.Value(), candidate) {
if bytes.HasPrefix(v.Value(), candidate) {
return true
}
}

View File

@@ -5,8 +5,9 @@
package varint
import (
"golang.org/x/exp/constraints"
"io"
"golang.org/x/exp/constraints"
"orly.dev/pkg/utils/chk"
)

View File

@@ -3,10 +3,10 @@ package varint
import (
"bytes"
"math"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/utils/chk"
"testing"
"orly.dev/pkg/utils/chk"
"lukechampine.com/frand"
)
@@ -14,7 +14,7 @@ func TestEncode_Decode(t *testing.T) {
var v uint64
for range 10000000 {
v = uint64(frand.Intn(math.MaxInt64))
buf1 := codecbuf.Get()
buf1 := new(bytes.Buffer)
Encode(buf1, v)
buf2 := bytes.NewBuffer(buf1.Bytes())
u, err := Decode(buf2)

View File

@@ -3,7 +3,8 @@ package nwc
import (
"encoding/json"
"fmt"
"net/url"
"time"
"orly.dev/pkg/crypto/encryption"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/event"
@@ -19,596 +20,140 @@ import (
"orly.dev/pkg/protocol/ws"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/log"
"strings"
"sync"
"time"
"orly.dev/pkg/utils/values"
)
// Options represents options for a NWC client
type Options struct {
RelayURL string
Secret signer.I
WalletPubkey []byte
Lud16 string
}
// Client represents a NWC client
type Client struct {
options Options
relay *ws.Client
mu sync.Mutex
client *ws.Client
relay string
clientSecretKey signer.I
walletPublicKey []byte
conversationKey []byte // nip44
}
// ParseWalletConnectURL parses a wallet connect URL
func ParseWalletConnectURL(walletConnectURL string) (opts *Options, err error) {
if !strings.HasPrefix(walletConnectURL, "nostr+walletconnect://") {
return nil, fmt.Errorf("unexpected scheme. Should be nostr+walletconnect://")
}
// Parse URL
colonIndex := strings.Index(walletConnectURL, ":")
if colonIndex == -1 {
err = fmt.Errorf("invalid URL format")
return
}
walletConnectURL = walletConnectURL[colonIndex+1:]
if strings.HasPrefix(walletConnectURL, "//") {
walletConnectURL = walletConnectURL[2:]
}
walletConnectURL = "https://" + walletConnectURL
var u *url.URL
if u, err = url.Parse(walletConnectURL); chk.E(err) {
err = fmt.Errorf("failed to parse URL: %w", err)
return
}
// Get wallet pubkey
walletPubkey := u.Host
if len(walletPubkey) != 64 {
err = fmt.Errorf("incorrect wallet pubkey found in auth string")
return
}
var pk []byte
if pk, err = hex.Dec(walletPubkey); chk.E(err) {
err = fmt.Errorf("failed to decode pubkey: %w", err)
return
}
// Get relay URL
relayURL := u.Query().Get("relay")
if relayURL == "" {
return nil, fmt.Errorf("no relay URL found in auth string")
}
// Get secret
secret := u.Query().Get("secret")
if secret == "" {
return nil, fmt.Errorf("no secret found in auth string")
}
var sk []byte
if sk, err = hex.Dec(secret); chk.E(err) {
return
}
sign := &p256k.Signer{}
if err = sign.InitSec(sk); chk.E(err) {
return
}
opts = &Options{
RelayURL: relayURL,
Secret: sign,
WalletPubkey: pk,
}
return
type Request struct {
Method string `json:"method"`
Params any `json:"params"`
}
// NewNWCClient creates a new NWC client
func NewNWCClient(options *Options) (cl *Client, err error) {
if options.RelayURL == "" {
err = fmt.Errorf("missing relay URL")
return
}
if options.Secret == nil {
err = fmt.Errorf("missing secret")
return
}
if options.WalletPubkey == nil {
err = fmt.Errorf("missing wallet pubkey")
return
}
return &Client{
options: Options{
RelayURL: options.RelayURL,
Secret: options.Secret,
WalletPubkey: options.WalletPubkey,
Lud16: options.Lud16,
},
}, nil
type ResponseError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NostrWalletConnectURL returns the nostr wallet connect URL
func (c *Client) NostrWalletConnectURL() string {
return c.GetNostrWalletConnectURL(true)
func (err *ResponseError) Error() string {
return fmt.Sprintf("%s %s", err.Code, err.Message)
}
// GetNostrWalletConnectURL returns the nostr wallet connect URL
func (c *Client) GetNostrWalletConnectURL(includeSecret bool) string {
params := url.Values{}
params.Add("relay", c.options.RelayURL)
if includeSecret {
params.Add("secret", hex.Enc(c.options.Secret.Sec()))
}
return fmt.Sprintf(
"nostr+walletconnect://%s?%s", c.options.WalletPubkey, params.Encode(),
)
type Response struct {
ResultType string `json:"result_type"`
Error *ResponseError `json:"error"`
Result any `json:"result"`
}
// Connected returns whether the client is connected to the relay
func (c *Client) Connected() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.relay != nil && c.relay.IsConnected()
}
// GetPublicKey returns the client's public key
func (c *Client) GetPublicKey() (pubkey []byte, err error) {
pubkey = c.options.Secret.Pub()
return
}
// Close closes the relay connection
func (c *Client) Close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.relay != nil {
c.relay.Close()
c.relay = nil
}
}
// Encrypt encrypts content for a pubkey
func (c *Client) encrypt(pubkey, content []byte) (
cipherText []byte, err error,
) {
var sharedSecret []byte
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
func NewClient(c context.T, connectionURI string) (cl *Client, err error) {
var parts *ConnectionParams
if parts, err = ParseConnectionURI(connectionURI); chk.E(err) {
return
}
cipherText, err = encryption.EncryptNip4(content, sharedSecret)
return
}
// Decrypt decrypts content from a pubkey
func (c *Client) decrypt(pubkey, content []byte) (plaintext []byte, err error) {
var sharedSecret []byte
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
clientKey := &p256k.Signer{}
if err = clientKey.InitSec(parts.clientSecretKey); chk.E(err) {
return
}
plaintext, err = encryption.DecryptNip4(content, sharedSecret)
return
}
// GetInfo gets wallet info
func (c *Client) GetInfo() (response *GetInfoResponse, err error) {
var result []byte
if result, err = c.executeRequest(GetInfo, nil); chk.E(err) {
return
}
response = &GetInfoResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// GetBudget gets wallet budget
func (c *Client) GetBudget() (response *GetBudgetResponse, err error) {
var result []byte
result, err = c.executeRequest(GetBudget, nil)
if err != nil {
return nil, err
}
response = &GetBudgetResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// GetBalance gets wallet balance
func (c *Client) GetBalance() (response *GetBalanceResponse, err error) {
var result []byte
if result, err = c.executeRequest(GetBalance, nil); chk.E(err) {
return
}
response = &GetBalanceResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// PayInvoice pays an invoice
func (c *Client) PayInvoice(request *PayInvoiceRequest) (
response *PayResponse, err error,
) {
var result []byte
result, err = c.executeRequest(PayInvoice, request)
if err != nil {
return nil, err
}
response = &PayResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// PayKeysend sends a keysend payment
func (c *Client) PayKeysend(request *PayKeysendRequest) (
response *PayResponse, err error,
) {
var result []byte
if result, err = c.executeRequest(PayKeysend, request); chk.E(err) {
return
}
response = &PayResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// MakeInvoice creates an invoice
func (c *Client) MakeInvoice(request *MakeInvoiceRequest) (
response *Transaction, err error,
) {
var result []byte
if result, err = c.executeRequest(MakeInvoice, request); chk.E(err) {
return
}
response = &Transaction{}
if err = json.Unmarshal(result, response); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return
}
// LookupInvoice looks up an invoice
func (c *Client) LookupInvoice(request *LookupInvoiceRequest) (
response *Transaction, err error,
) {
var result []byte
if result, err = c.executeRequest(LookupInvoice, request); chk.E(err) {
return
}
response = &Transaction{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// ListTransactions lists transactions
func (c *Client) ListTransactions(request *ListTransactionsRequest) (
response *ListTransactionsResponse, err error,
) {
var result []byte
if result, err = c.executeRequest(ListTransactions, request); chk.E(err) {
return
}
response = &ListTransactionsResponse{}
if err = json.Unmarshal(result, response); chk.E(err) {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// SignMessage signs a message
func (c *Client) SignMessage(request *SignMessageRequest) (
response *SignMessageResponse, err error,
) {
var result []byte
if result, err = c.executeRequest(SignMessage, request); chk.E(err) {
return
}
response = &SignMessageResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// NotificationHandler is a function that handles notifications
type NotificationHandler func(*Notification)
// SubscribeNotifications subscribes to notifications
func (c *Client) SubscribeNotifications(
handler NotificationHandler,
notificationTypes []NotificationType,
) (stop func(), err error) {
if handler == nil {
err = fmt.Errorf("missing notification handler")
return
}
ctx, cancel := context.Cancel(context.Bg())
doneCh := make(chan struct{})
stop = func() {
cancel()
<-doneCh
}
go func() {
defer close(doneCh)
for {
select {
case <-ctx.Done():
return
default:
// Check connection
if err := c.checkConnected(); err != nil {
time.Sleep(1 * time.Second)
continue
}
// Get client pubkey
var clientPubkey []byte
if clientPubkey, err = c.GetPublicKey(); chk.E(err) {
time.Sleep(1 * time.Second)
continue
}
// Subscribe to events
f := &filter.F{
Kinds: kinds.New(kind.WalletResponse),
Authors: tag.New(c.options.WalletPubkey),
Tags: tags.New(tag.New([]byte("#p"), clientPubkey)),
}
var sub *ws.Subscription
if sub, err = c.relay.Subscribe(
context.Bg(), &filters.T{
F: []*filter.F{f},
},
); chk.E(err) {
time.Sleep(1 * time.Second)
continue
}
// Handle events
for {
select {
case <-ctx.Done():
sub.Close()
return
case ev := <-sub.Events:
// Decrypt content
var decryptedContent []byte
if decryptedContent, err = c.decrypt(
c.options.WalletPubkey, ev.Content,
); chk.E(err) {
log.E.F(
"Failed to decrypt event content: %v\n", err,
)
continue
}
// Parse notification
notification := &Notification{}
if err = json.Unmarshal(
decryptedContent, notification,
); chk.E(err) {
log.E.F(
"Failed to parse notification: %v\n", err,
)
continue
}
// Check if notification type is requested
if len(notificationTypes) > 0 {
found := false
for _, t := range notificationTypes {
if notification.NotificationType == t {
found = true
break
}
}
if !found {
continue
}
}
// Handle notification
handler(notification)
case <-sub.EndOfStoredEvents:
// Ignore
}
}
}
}
}()
return
}
// executeRequest executes a NIP-47 request
func (c *Client) executeRequest(
method Method,
params any,
) (msg json.RawMessage, err error) {
// Default timeout values
replyTimeout := 3 * time.Second
publishTimeout := 3 * time.Second
// Create context with timeout
ctx, cancel := context.Timeout(context.Bg(), replyTimeout)
defer cancel()
// Create result channel
resultCh := make(chan json.RawMessage, 1)
errCh := make(chan error, 1)
// Check connection
if err = c.checkConnected(); err != nil {
return nil, err
}
// Create request
request := struct {
Method Method `json:"method"`
Params any `json:"params,omitempty"`
}{
Method: method,
Params: params,
}
// Marshal request
requestJSON, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Encrypt request
var encryptedContent []byte
if encryptedContent, err = c.encrypt(
c.options.WalletPubkey, requestJSON,
var ck []byte
if ck, err = encryption.GenerateConversationKeyWithSigner(
clientKey,
parts.walletPublicKey,
); chk.E(err) {
return nil, fmt.Errorf("failed to encrypt request: %w", err)
return
}
// Create request event
requestEvent := &event.E{
var relay *ws.Client
if relay, err = ws.RelayConnect(c, parts.relay); chk.E(err) {
return
}
cl = &Client{
client: relay,
relay: parts.relay,
clientSecretKey: clientKey,
walletPublicKey: parts.walletPublicKey,
conversationKey: ck,
}
return
}
type rpcOptions struct {
timeout *time.Duration
}
func (cl *Client) RPC(
c context.T, method Capability, params, result any, noUnmarshal bool,
opts *rpcOptions,
) (raw []byte, err error) {
var req []byte
if req, err = json.Marshal(
Request{
Method: string(method),
Params: params,
},
); chk.E(err) {
return
}
var content []byte
if content, err = encryption.Encrypt(req, cl.conversationKey); chk.E(err) {
return
}
ev := &event.E{
Content: content,
CreatedAt: timestamp.Now(),
Kind: kind.WalletRequest,
CreatedAt: timestamp.New(time.Now().Unix()),
Tags: tags.New(tag.New("p", hex.Enc(c.options.WalletPubkey))),
Content: encryptedContent,
Tags: tags.New(
tag.New("p", hex.Enc(cl.walletPublicKey)),
tag.New(EncryptionTag, Nip44V2),
),
}
// Sign request event
err = requestEvent.Sign(c.options.Secret)
if err != nil {
return nil, fmt.Errorf("failed to sign request event: %w", err)
if err = ev.Sign(cl.clientSecretKey); chk.E(err) {
return
}
// Subscribe to response events
f := &filter.F{
Kinds: kinds.New(kind.WalletResponse),
Authors: tag.New(c.options.WalletPubkey),
Tags: tags.New(tag.New([]byte("#p"), requestEvent.ID)),
var rc *ws.Client
if rc, err = ws.RelayConnect(c, cl.relay); chk.E(err) {
return
}
log.I.F("%s", f.Marshal(nil))
defer rc.Close()
var sub *ws.Subscription
if sub, err = c.relay.Subscribe(
ctx, &filters.T{
F: []*filter.F{f},
},
if sub, err = rc.Subscribe(
c, filters.New(
&filter.F{
Limit: values.ToUintPointer(1),
Kinds: kinds.New(kind.WalletResponse),
Authors: tag.New(cl.walletPublicKey),
Tags: tags.New(tag.New("#e", hex.Enc(ev.ID))),
},
),
); chk.E(err) {
err = fmt.Errorf(
"failed to subscribe to response events: %w", err,
)
return
}
defer sub.Close()
// Set up reply timeout
replyTimer := time.AfterFunc(
replyTimeout, func() {
errCh <- NewReplyTimeoutError(
fmt.Sprintf("Timeout waiting for reply to %s", method),
"TIMEOUT",
)
},
)
defer replyTimer.Stop()
// Handle response events
go func() {
var resErr error
for {
select {
case <-ctx.Done():
return
case ev := <-sub.Events:
// Decrypt content
var decryptedContent []byte
decryptedContent, resErr = c.decrypt(
c.options.WalletPubkey, ev.Content,
)
if chk.E(resErr) {
errCh <- fmt.Errorf(
"failed to decrypt response: %w",
resErr,
)
return
}
// Parse response
var response struct {
ResultType string `json:"result_type"`
Result json.RawMessage `json:"result"`
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
if resErr = json.Unmarshal(
decryptedContent, &response,
); chk.E(resErr) {
errCh <- fmt.Errorf("failed to parse response: %w", resErr)
return
}
// Check for error
if response.Error != nil {
errCh <- NewWalletError(
response.Error.Message,
response.Error.Code,
)
return
}
// Send result
resultCh <- response.Result
return
case <-sub.EndOfStoredEvents:
// Ignore
}
}
}()
// Publish request event
publishCtx, publishCancel := context.Timeout(
context.Bg(), publishTimeout,
)
defer publishCancel()
if err = c.relay.Publish(publishCtx, requestEvent); chk.E(err) {
err = fmt.Errorf("failed to publish request event: %w", err)
defer sub.Unsub()
if err = rc.Publish(context.Bg(), ev); chk.E(err) {
return
}
// Wait for result or error
select {
case msg = <-resultCh:
return
case err = <-errCh:
return
case <-ctx.Done():
err = NewReplyTimeoutError(
fmt.Sprintf("Timeout waiting for reply to %s", method),
"TIMEOUT",
)
return
}
}
// checkConnected checks if the client is connected to the relay
func (c *Client) checkConnected() (err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.options.RelayURL == "" {
return fmt.Errorf("missing relay URL")
}
if c.relay == nil {
if c.relay, err = ws.RelayConnect(
context.Bg(), c.options.RelayURL,
case <-c.Done():
err = fmt.Errorf("context canceled waiting for response")
case e := <-sub.Events:
if raw, err = encryption.Decrypt(
e.Content, cl.conversationKey,
); chk.E(err) {
return NewNetworkError(
"Failed to connect to "+c.options.RelayURL,
"OTHER",
)
return
}
} else if !c.relay.IsConnected() {
c.relay.Close()
if c.relay, err = ws.RelayConnect(
context.Bg(), c.options.RelayURL,
); chk.E(err) {
return NewNetworkError(
"Failed to connect to "+c.options.RelayURL,
"OTHER",
)
if noUnmarshal {
return
}
resp := &Response{
Result: &result,
}
if err = json.Unmarshal(raw, resp); chk.E(err) {
return
}
}
return nil
return
}

183
pkg/protocol/nwc/methods.go Normal file
View File

@@ -0,0 +1,183 @@
package nwc
import (
"bytes"
"fmt"
"time"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/protocol/ws"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/values"
)
func (cl *Client) GetWalletServiceInfo(c context.T, noUnmarshal bool) (
wsi *WalletServiceInfo, raw []byte, err error,
) {
timeout := 10 * time.Second
ctx, cancel := context.Timeout(c, timeout)
defer cancel()
var rc *ws.Client
if rc, err = ws.RelayConnect(c, cl.relay); chk.E(err) {
return
}
if err = rc.Connect(c); chk.E(err) {
return
}
var sub *ws.Subscription
if sub, err = rc.Subscribe(
ctx, filters.New(
&filter.F{
Limit: values.ToUintPointer(1),
Kinds: kinds.New(kind.WalletRequest),
Authors: tag.New(cl.walletPublicKey),
},
),
); chk.E(err) {
return
}
defer sub.Unsub()
select {
case <-c.Done():
err = fmt.Errorf("GetWalletServiceInfo canceled")
return
case ev := <-sub.Events:
var encryptionTypes []EncryptionType
var notificationTypes []NotificationType
encryptionTag := ev.Tags.GetFirst(tag.New("encryption"))
notificationsTag := ev.Tags.GetFirst(tag.New("notifications"))
if encryptionTag != nil {
et := encryptionTag.ToSliceOfBytes()
encType := bytes.Split(et[0], []byte(" "))
for _, e := range encType {
encryptionTypes = append(encryptionTypes, e)
}
}
if notificationsTag != nil {
nt := notificationsTag.ToSliceOfBytes()
notifs := bytes.Split(nt[0], []byte(" "))
for _, e := range notifs {
notificationTypes = append(notificationTypes, e)
}
}
cp := bytes.Split(ev.Content, []byte(" "))
var capabilities []Capability
for _, capability := range cp {
capabilities = append(capabilities, capability)
}
wsi = &WalletServiceInfo{
EncryptionTypes: encryptionTypes,
NotificationTypes: notificationTypes,
Capabilities: capabilities,
}
}
return
}
func (cl *Client) CancelHoldInvoice(
c context.T, chi *CancelHoldInvoiceParams, noUnmarshal bool,
) (raw []byte, err error) {
return cl.RPC(c, CancelHoldInvoice, chi, nil, noUnmarshal, nil)
}
func (cl *Client) CreateConnection(
c context.T, cc *CreateConnectionParams, noUnmarshal bool,
) (raw []byte, err error) {
return cl.RPC(c, CreateConnection, cc, nil, noUnmarshal, nil)
}
func (cl *Client) GetBalance(c context.T, noUnmarshal bool) (
gb *GetBalanceResult, raw []byte, err error,
) {
gb = &GetBalanceResult{}
raw, err = cl.RPC(c, GetBalance, nil, gb, noUnmarshal, nil)
return
}
func (cl *Client) GetBudget(c context.T, noUnmarshal bool) (
gb *GetBudgetResult, raw []byte, err error,
) {
gb = &GetBudgetResult{}
raw, err = cl.RPC(c, GetBudget, nil, gb, noUnmarshal, nil)
return
}
func (cl *Client) GetInfo(c context.T, noUnmarshal bool) (
gi *GetInfoResult, raw []byte, err error,
) {
gi = &GetInfoResult{}
raw, err = cl.RPC(c, GetInfo, nil, gi, noUnmarshal, nil)
return
}
func (cl *Client) ListTransactions(
c context.T, params *ListTransactionsParams, noUnmarshal bool,
) (lt *ListTransactionsResult, raw []byte, err error) {
lt = &ListTransactionsResult{}
raw, err = cl.RPC(c, ListTransactions, params, &lt, noUnmarshal, nil)
return
}
func (cl *Client) LookupInvoice(
c context.T, params *LookupInvoiceParams, noUnmarshal bool,
) (li *LookupInvoiceResult, raw []byte, err error) {
li = &LookupInvoiceResult{}
raw, err = cl.RPC(c, LookupInvoice, params, &li, noUnmarshal, nil)
return
}
func (cl *Client) MakeHoldInvoice(
c context.T,
mhi *MakeHoldInvoiceParams, noUnmarshal bool,
) (mi *MakeInvoiceResult, raw []byte, err error) {
mi = &MakeInvoiceResult{}
raw, err = cl.RPC(c, MakeHoldInvoice, mhi, mi, noUnmarshal, nil)
return
}
func (cl *Client) MakeInvoice(
c context.T, params *MakeInvoiceParams, noUnmarshal bool,
) (mi *MakeInvoiceResult, raw []byte, err error) {
mi = &MakeInvoiceResult{}
raw, err = cl.RPC(c, MakeInvoice, params, &mi, noUnmarshal, nil)
return
}
// MultiPayInvoice
// MultiPayKeysend
func (cl *Client) PayKeysend(
c context.T, params *PayKeysendParams, noUnmarshal bool,
) (pk *PayKeysendResult, raw []byte, err error) {
pk = &PayKeysendResult{}
raw, err = cl.RPC(c, PayKeysend, params, &pk, noUnmarshal, nil)
return
}
func (cl *Client) PayInvoice(
c context.T, params *PayInvoiceParams, noUnmarshal bool,
) (pi *PayInvoiceResult, raw []byte, err error) {
pi = &PayInvoiceResult{}
raw, err = cl.RPC(c, PayInvoice, params, &pi, noUnmarshal, nil)
return
}
func (cl *Client) SettleHoldInvoice(
c context.T, shi *SettleHoldInvoiceParams, noUnmarshal bool,
) (raw []byte, err error) {
return cl.RPC(c, SettleHoldInvoice, shi, nil, noUnmarshal, nil)
}
func (cl *Client) SignMessage(
c context.T, sm *SignMessageParams, noUnmarshal bool,
) (res *SignMessageResult, raw []byte, err error) {
res = &SignMessageResult{}
raw, err = cl.RPC(c, SignMessage, sm, &res, noUnmarshal, nil)
return
}

View File

@@ -1,473 +1,188 @@
package nwc
import (
"fmt"
"time"
// Capability represents a NIP-47 method
type Capability []byte
var (
CancelHoldInvoice = Capability("cancel_hold_invoice")
CreateConnection = Capability("create_connection")
GetBalance = Capability("get_balance")
GetBudget = Capability("get_budget")
GetInfo = Capability("get_info")
ListTransactions = Capability("list_transactions")
LookupInvoice = Capability("lookup_invoice")
MakeHoldInvoice = Capability("make_hold_invoice")
MakeInvoice = Capability("make_invoice")
MultiPayInvoice = Capability("multi_pay_invoice")
MultiPayKeysend = Capability("multi_pay_keysend")
PayInvoice = Capability("pay_invoice")
PayKeysend = Capability("pay_keysend")
SettleHoldInvoice = Capability("settle_hold_invoice")
SignMessage = Capability("sign_message")
)
// EncryptionType represents the encryption type used for NIP-47 messages
type EncryptionType string
type EncryptionType []byte
const (
Nip04 EncryptionType = "nip04"
Nip44V2 EncryptionType = "nip44_v2"
var (
EncryptionTag = []byte("encryption")
Nip04 = EncryptionType("nip04")
Nip44V2 = EncryptionType("nip44_v2")
)
// AuthorizationUrlOptions represents options for creating an NWC authorization URL
type AuthorizationUrlOptions struct {
Name string `json:"name,omitempty"`
Icon string `json:"icon,omitempty"`
RequestMethods []Method `json:"requestMethods,omitempty"`
NotificationTypes []NotificationType `json:"notificationTypes,omitempty"`
ReturnTo string `json:"returnTo,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
MaxAmount *int64 `json:"maxAmount,omitempty"`
BudgetRenewal BudgetRenewalPeriod `json:"budgetRenewal,omitempty"`
Isolated bool `json:"isolated,omitempty"`
Metadata interface{} `json:"metadata,omitempty"`
}
type NotificationType []byte
// Err is the base error type for NIP-47 errors
type Err struct {
Message string
Code string
}
func (e *Err) Error() string {
return fmt.Sprintf("%s (code: %s)", e.Message, e.Code)
}
// NewError creates a new Error
func NewError(message, code string) *Err {
return &Err{
Message: message,
Code: code,
}
}
// NetworkError represents a network error in NIP-47 operations
type NetworkError struct{ *Err }
// NewNetworkError creates a new NetworkError
func NewNetworkError(message, code string) *NetworkError {
return &NetworkError{
Err: NewError(message, code),
}
}
// WalletError represents a wallet error in NIP-47 operations
type WalletError struct {
*Err
}
// NewWalletError creates a new WalletError
func NewWalletError(message, code string) *WalletError {
return &WalletError{
Err: NewError(message, code),
}
}
// TimeoutError represents a timeout error in NIP-47 operations
type TimeoutError struct{ *Err }
// NewTimeoutError creates a new TimeoutError
func NewTimeoutError(message, code string) *TimeoutError {
return &TimeoutError{
Err: NewError(message, code),
}
}
// PublishTimeoutError represents a publish timeout error in NIP-47 operations
type PublishTimeoutError struct{ *TimeoutError }
// NewPublishTimeoutError creates a new PublishTimeoutError
func NewPublishTimeoutError(message, code string) *PublishTimeoutError {
return &PublishTimeoutError{
TimeoutError: NewTimeoutError(message, code),
}
}
// ReplyTimeoutError represents a reply timeout error in NIP-47 operations
type ReplyTimeoutError struct{ *TimeoutError }
// NewReplyTimeoutError creates a new ReplyTimeoutError
func NewReplyTimeoutError(message, code string) *ReplyTimeoutError {
return &ReplyTimeoutError{
TimeoutError: NewTimeoutError(message, code),
}
}
// PublishError represents a publish error in NIP-47 operations
type PublishError struct{ *Err }
// NewPublishError creates a new PublishError
func NewPublishError(message, code string) *PublishError {
return &PublishError{
Err: NewError(message, code),
}
}
// ResponseDecodingError represents a response decoding error in NIP-47 operations
type ResponseDecodingError struct{ *Err }
// NewResponseDecodingError creates a new ResponseDecodingError
func NewResponseDecodingError(message, code string) *ResponseDecodingError {
return &ResponseDecodingError{
Err: NewError(message, code),
}
}
// ResponseValidationError represents a response validation error in NIP-47 operations
type ResponseValidationError struct{ *Err }
// NewResponseValidationError creates a new ResponseValidationError
func NewResponseValidationError(message, code string) *ResponseValidationError {
return &ResponseValidationError{
Err: NewError(message, code),
}
}
// UnexpectedResponseError represents an unexpected response error in NIP-47 operations
type UnexpectedResponseError struct{ *Err }
// NewUnexpectedResponseError creates a new UnexpectedResponseError
func NewUnexpectedResponseError(message, code string) *UnexpectedResponseError {
return &UnexpectedResponseError{
Err: NewError(message, code),
}
}
// UnsupportedEncryptionError represents an unsupported encryption error in NIP-47 operations
type UnsupportedEncryptionError struct {
*Err
}
// NewUnsupportedEncryptionError creates a new UnsupportedEncryptionError
func NewUnsupportedEncryptionError(message, code string) *UnsupportedEncryptionError {
return &UnsupportedEncryptionError{
Err: NewError(message, code),
}
}
// WithDTag represents a type with a dTag field
type WithDTag struct {
DTag string `json:"dTag"`
}
// WithOptionalId represents a type with an optional id field
type WithOptionalId struct {
ID string `json:"id,omitempty"`
}
// Method represents a NIP-47 method
type Method string
// SingleMethod represents a single NIP-47 method
const (
GetInfo Method = "get_info"
GetBalance Method = "get_balance"
GetBudget Method = "get_budget"
MakeInvoice Method = "make_invoice"
PayInvoice Method = "pay_invoice"
PayKeysend Method = "pay_keysend"
LookupInvoice Method = "lookup_invoice"
ListTransactions Method = "list_transactions"
SignMessage Method = "sign_message"
CreateConnection Method = "create_connection"
MakeHoldInvoice Method = "make_hold_invoice"
SettleHoldInvoice Method = "settle_hold_invoice"
CancelHoldInvoice Method = "cancel_hold_invoice"
var (
PaymentReceived = NotificationType("payment_received")
PaymentSent = NotificationType("payment_sent")
)
// MultiMethod represents a multi NIP-47 method
const (
MultiPayInvoice Method = "multi_pay_invoice"
MultiPayKeysend Method = "multi_pay_keysend"
)
// Capability represents a NIP-47 capability
type Capability string
const (
Notifications Capability = "notifications"
)
// BudgetRenewalPeriod represents a budget renewal period
type BudgetRenewalPeriod string
const (
Daily BudgetRenewalPeriod = "daily"
Weekly BudgetRenewalPeriod = "weekly"
Monthly BudgetRenewalPeriod = "monthly"
Yearly BudgetRenewalPeriod = "yearly"
Never BudgetRenewalPeriod = "never"
)
// GetInfoResponse represents a response to a get_info request
type GetInfoResponse struct {
Alias string `json:"alias"`
Color string `json:"color"`
Pubkey string `json:"pubkey"`
Network string `json:"network"`
BlockHeight int64 `json:"block_height"`
BlockHash string `json:"block_hash"`
Methods []Method `json:"methods"`
Notifications []NotificationType `json:"notifications,omitempty"`
Metadata interface{} `json:"metadata,omitempty"`
Lud16 string `json:"lud16,omitempty"`
type WalletServiceInfo struct {
EncryptionTypes []EncryptionType
Capabilities []Capability
NotificationTypes []NotificationType
}
// GetBudgetResponse represents a response to a get_budget request
type GetBudgetResponse struct {
UsedBudget int64 `json:"used_budget,omitempty"`
TotalBudget int64 `json:"total_budget,omitempty"`
RenewsAt *int64 `json:"renews_at,omitempty"`
RenewalPeriod BudgetRenewalPeriod `json:"renewal_period,omitempty"`
type GetInfoResult struct {
Alias string `json:"alias"`
Color string `json:"color"`
Pubkey string `json:"pubkey"`
Network string `json:"network"`
BlockHeight uint64 `json:"block_height"`
BlockHash string `json:"block_hash"`
Methods []string `json:"methods"`
Notifications []string `json:"notifications,omitempty"`
Metadata any `json:"metadata,omitempty"`
LUD16 string `json:"lud16,omitempty"`
}
// GetBalanceResponse represents a response to a get_balance request
type GetBalanceResponse struct {
Balance int64 `json:"balance"` // msats
type GetBudgetResult struct {
UsedBudget int `json:"used_budget,omitempty"`
TotalBudget int `json:"total_budget,omitempty"`
RenewsAt int `json:"renews_at,omitempty"`
RenewalPeriod string `json:"renewal_period,omitempty"`
}
// PayResponse represents a response to a pay request
type PayResponse struct {
type GetBalanceResult struct {
Balance uint64 `json:"balance"`
}
type MakeInvoiceParams struct {
Amount uint64 `json:"amount"`
Description string `json:"description,omitempty"`
DescriptionHash string `json:"description_hash,omitempty"`
Expiry *int64 `json:"expiry,omitempty"`
Metadata any `json:"metadata,omitempty"`
}
type MakeHoldInvoiceParams struct {
Amount uint64 `json:"amount"`
PaymentHash string `json:"payment_hash"`
Description string `json:"description,omitempty"`
DescriptionHash string `json:"description_hash,omitempty"`
Expiry *int64 `json:"expiry,omitempty"`
Metadata any `json:"metadata,omitempty"`
}
type SettleHoldInvoiceParams struct {
Preimage string `json:"preimage"`
FeesPaid int64 `json:"fees_paid"`
}
// MultiPayInvoiceRequest represents a request to pay multiple invoices
type MultiPayInvoiceRequest struct {
Invoices []PayInvoiceRequestWithID `json:"invoices"`
type CancelHoldInvoiceParams struct {
PaymentHash string `json:"payment_hash"`
}
// PayInvoiceRequestWithID combines PayInvoiceRequest with WithOptionalId
type PayInvoiceRequestWithID struct {
PayInvoiceRequest
WithOptionalId
type PayInvoicePayerData struct {
Email string `json:"email"`
Name string `json:"name"`
Pubkey string `json:"pubkey"`
}
// MultiPayKeysendRequest represents a request to pay multiple keysends
type MultiPayKeysendRequest struct {
Keysends []PayKeysendRequestWithID `json:"keysends"`
type PayInvoiceMetadata struct {
Comment *string `json:"comment"`
PayerData *PayInvoicePayerData `json:"payer_data"`
Other any
}
// PayKeysendRequestWithID combines PayKeysendRequest with WithOptionalId
type PayKeysendRequestWithID struct {
PayKeysendRequest
WithOptionalId
type PayInvoiceParams struct {
Invoice string `json:"invoice"`
Amount *uint64 `json:"amount,omitempty"`
Metadata *PayInvoiceMetadata `json:"metadata,omitempty"`
}
// MultiPayInvoiceResponse represents a response to a multi_pay_invoice request
type MultiPayInvoiceResponse struct {
Invoices []MultiPayInvoiceResponseItem `json:"invoices"`
Errors []interface{} `json:"errors"` // TODO: add error handling
type PayInvoiceResult struct {
Preimage string `json:"preimage"`
FeesPaid uint64 `json:"fees_paid"`
}
// MultiPayInvoiceResponseItem represents an item in a multi_pay_invoice response
type MultiPayInvoiceResponseItem struct {
Invoice PayInvoiceRequest `json:"invoice"`
PayResponse
WithDTag
}
// MultiPayKeysendResponse represents a response to a multi_pay_keysend request
type MultiPayKeysendResponse struct {
Keysends []MultiPayKeysendResponseItem `json:"keysends"`
Errors []interface{} `json:"errors"` // TODO: add error handling
}
// MultiPayKeysendResponseItem represents an item in a multi_pay_keysend response
type MultiPayKeysendResponseItem struct {
Keysend PayKeysendRequest `json:"keysend"`
PayResponse
WithDTag
}
// ListTransactionsRequest represents a request to list transactions
type ListTransactionsRequest struct {
From *int64 `json:"from,omitempty"`
Until *int64 `json:"until,omitempty"`
Limit *int64 `json:"limit,omitempty"`
Offset *int64 `json:"offset,omitempty"`
Unpaid *bool `json:"unpaid,omitempty"`
UnpaidOutgoing *bool `json:"unpaid_outgoing,omitempty"` // NOTE: non-NIP-47 spec compliant
UnpaidIncoming *bool `json:"unpaid_incoming,omitempty"` // NOTE: non-NIP-47 spec compliant
Type *string `json:"type,omitempty"` // "incoming" or "outgoing"
}
// ListTransactionsResponse represents a response to a list_transactions request
type ListTransactionsResponse struct {
Transactions []Transaction `json:"transactions"`
TotalCount int64 `json:"total_count"` // NOTE: non-NIP-47 spec compliant
}
// TransactionType represents the type of a transaction
type TransactionType string
const (
Incoming TransactionType = "incoming"
Outgoing TransactionType = "outgoing"
)
// TransactionState represents the state of a transaction
type TransactionState string
const (
Settled TransactionState = "settled"
Pending TransactionState = "pending"
Failed TransactionState = "failed"
)
// Transaction represents a transaction
type Transaction struct {
Type TransactionType `json:"type"`
State TransactionState `json:"state"` // NOTE: non-NIP-47 spec compliant
Invoice string `json:"invoice"`
Description string `json:"description"`
DescriptionHash string `json:"description_hash"`
Preimage string `json:"preimage"`
PaymentHash string `json:"payment_hash"`
Amount int64 `json:"amount"`
FeesPaid int64 `json:"fees_paid"`
SettledAt int64 `json:"settled_at"`
CreatedAt int64 `json:"created_at"`
ExpiresAt int64 `json:"expires_at"`
SettleDeadline *int64 `json:"settle_deadline,omitempty"` // NOTE: non-NIP-47 spec compliant
Metadata *TransactionMetadata `json:"metadata,omitempty"`
}
// TransactionMetadata represents metadata for a transaction
type TransactionMetadata struct {
Comment string `json:"comment,omitempty"` // LUD-12
PayerData *PayerData `json:"payer_data,omitempty"` // LUD-18
RecipientData *RecipientData `json:"recipient_data,omitempty"` // LUD-18
Nostr *NostrData `json:"nostr,omitempty"` // NIP-57
ExtraData map[string]interface{} `json:"-"` // For additional fields
}
// PayerData represents payer data for a transaction
type PayerData struct {
Email string `json:"email,omitempty"`
Name string `json:"name,omitempty"`
Pubkey string `json:"pubkey,omitempty"`
}
// RecipientData represents recipient data for a transaction
type RecipientData struct {
Identifier string `json:"identifier,omitempty"`
}
// NostrData represents Nostr data for a transaction
type NostrData struct {
Pubkey string `json:"pubkey"`
Tags [][]string `json:"tags"`
}
// NotificationType represents a notification type
type NotificationType string
const (
PaymentReceived NotificationType = "payment_received"
PaymentSent NotificationType = "payment_sent"
HoldInvoiceAccepted NotificationType = "hold_invoice_accepted"
)
// Notification represents a notification
type Notification struct {
NotificationType NotificationType `json:"notification_type"`
Notification Transaction `json:"notification"`
}
// PayInvoiceRequest represents a request to pay an invoice
type PayInvoiceRequest struct {
Invoice string `json:"invoice"`
Metadata *TransactionMetadata `json:"metadata,omitempty"`
Amount *int64 `json:"amount,omitempty"` // msats
}
// PayKeysendRequest represents a request to pay a keysend
type PayKeysendRequest struct {
Amount int64 `json:"amount"` // msats
Pubkey string `json:"pubkey"`
Preimage string `json:"preimage,omitempty"`
TlvRecords []TlvRecord `json:"tlv_records,omitempty"`
}
// TlvRecord represents a TLV record
type TlvRecord struct {
Type int64 `json:"type"`
type PayKeysendTLVRecord struct {
Type uint32 `json:"type"`
Value string `json:"value"`
}
// MakeInvoiceRequest represents a request to make an invoice
type MakeInvoiceRequest struct {
Amount int64 `json:"amount"` // msats
Description string `json:"description,omitempty"`
DescriptionHash string `json:"description_hash,omitempty"`
Expiry *int64 `json:"expiry,omitempty"` // in seconds
Metadata *TransactionMetadata `json:"metadata,omitempty"`
type PayKeysendParams struct {
Amount uint64 `json:"amount"`
Pubkey string `json:"pubkey"`
Preimage *string `json:"preimage,omitempty"`
TLVRecords []PayKeysendTLVRecord `json:"tlv_records,omitempty"`
}
// MakeHoldInvoiceRequest represents a request to make a hold invoice
type MakeHoldInvoiceRequest struct {
MakeInvoiceRequest
PaymentHash string `json:"payment_hash"`
type PayKeysendResult = PayInvoiceResult
type LookupInvoiceParams struct {
PaymentHash *string `json:"payment_hash,omitempty"`
Invoice *string `json:"invoice,omitempty"`
}
// SettleHoldInvoiceRequest represents a request to settle a hold invoice
type SettleHoldInvoiceRequest struct {
Preimage string `json:"preimage"`
type ListTransactionsParams struct {
From *int64 `json:"from,omitempty"`
Until *int64 `json:"until,omitempty"`
Limit *uint16 `json:"limit,omitempty"`
Offset *uint32 `json:"offset,omitempty"`
Unpaid *bool `json:"unpaid,omitempty"`
UnpaidOutgoing *bool `json:"unpaid_outgoing,omitempty"`
UnpaidIncoming *bool `json:"unpaid_incoming,omitempty"`
Type *string `json:"type,omitempty"`
}
// SettleHoldInvoiceResponse represents a response to a settle_hold_invoice request
type SettleHoldInvoiceResponse struct{}
// CancelHoldInvoiceRequest represents a request to cancel a hold invoice
type CancelHoldInvoiceRequest struct {
PaymentHash string `json:"payment_hash"`
type MakeInvoiceResult = Transaction
type LookupInvoiceResult = Transaction
type ListTransactionsResult struct {
Transactions []Transaction `json:"transactions"`
TotalCount uint32 `json:"total_count"`
}
// CancelHoldInvoiceResponse represents a response to a cancel_hold_invoice request
type CancelHoldInvoiceResponse struct{}
// LookupInvoiceRequest represents a request to lookup an invoice
type LookupInvoiceRequest struct {
PaymentHash string `json:"payment_hash,omitempty"`
Invoice string `json:"invoice,omitempty"`
type Transaction struct {
Type string `json:"type"`
State string `json:"state"`
Invoice string `json:"invoice"`
Description string `json:"description"`
DescriptionHash string `json:"description_hash"`
Preimage string `json:"preimage"`
PaymentHash string `json:"payment_hash"`
Amount uint64 `json:"amount"`
FeesPaid uint64 `json:"fees_paid"`
CreatedAt int64 `json:"created_at"`
ExpiresAt int64 `json:"expires_at"`
SettledDeadline *uint64 `json:"settled_deadline,omitempty"`
Metadata any `json:"metadata,omitempty"`
}
// SignMessageRequest represents a request to sign a message
type SignMessageRequest struct {
type SignMessageParams struct {
Message string `json:"message"`
}
// CreateConnectionRequest represents a request to create a connection
type CreateConnectionRequest struct {
Pubkey string `json:"pubkey"`
Name string `json:"name"`
RequestMethods []Method `json:"request_methods"`
NotificationTypes []NotificationType `json:"notification_types,omitempty"`
MaxAmount *int64 `json:"max_amount,omitempty"`
BudgetRenewal *BudgetRenewalPeriod `json:"budget_renewal,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
Isolated *bool `json:"isolated,omitempty"`
Metadata any `json:"metadata,omitempty"`
}
// CreateConnectionResponse represents a response to a create_connection request
type CreateConnectionResponse struct {
WalletPubkey string `json:"wallet_pubkey"`
}
// SignMessageResponse represents a response to a sign_message request
type SignMessageResponse struct {
type SignMessageResult struct {
Message string `json:"message"`
Signature string `json:"signature"`
}
// TimeoutValues represents timeout values for NIP-47 requests
type TimeoutValues struct {
ReplyTimeout *int64 `json:"replyTimeout,omitempty"`
PublishTimeout *int64 `json:"publishTimeout,omitempty"`
type CreateConnectionParams struct {
Pubkey string `json:"pubkey"`
Name string `json:"name"`
RequestMethods []string `json:"request_methods"`
NotificationTypes []string `json:"notification_types"`
MaxAmount *uint64 `json:"max_amount,omitempty"`
BudgetRenewal *string `json:"budget_renewal,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
}

57
pkg/protocol/nwc/uri.go Normal file
View File

@@ -0,0 +1,57 @@
package nwc
import (
"errors"
"net/url"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/utils/chk"
)
type ConnectionParams struct {
clientSecretKey []byte
walletPublicKey []byte
relay string
}
// GetWalletPublicKey returns the wallet public key from the ConnectionParams.
func (c *ConnectionParams) GetWalletPublicKey() []byte {
return c.walletPublicKey
}
func ParseConnectionURI(nwcUri string) (parts *ConnectionParams, err error) {
var p *url.URL
if p, err = url.Parse(nwcUri); chk.E(err) {
return
}
parts = &ConnectionParams{}
if p.Scheme != "nostr+walletconnect" {
err = errors.New("incorrect scheme")
return
}
if parts.walletPublicKey, err = p256k.HexToBin(p.Host); chk.E(err) {
err = errors.New("invalid public key")
return
}
query := p.Query()
var ok bool
var relay []string
if relay, ok = query["relay"]; !ok {
err = errors.New("missing relay parameter")
return
}
if len(relay) == 0 {
return nil, errors.New("no relays")
}
parts.relay = relay[0]
var secret string
if secret = query.Get("secret"); secret == "" {
err = errors.New("missing secret parameter")
return
}
if parts.clientSecretKey, err = p256k.HexToBin(secret); chk.E(err) {
err = errors.New("invalid secret")
return
}
return
}

View File

@@ -112,7 +112,7 @@ func (x *Operations) RegisterEvent(api huma.API) {
err = huma.Error403Forbidden(fmt.Sprintf("Too many failed authentication attempts. Blocked until %s", blockedUntil.Format(time.RFC3339)))
return
}
authed, pubkey, super = x.UserAuth(r, remote)
if !authed {
// Record the failed authentication attempt

View File

@@ -118,7 +118,7 @@ func (p *Publisher) Receive(msg typer.T) {
Receiver: m.Receiver,
Pubkey: m.Pubkey,
}
// Add the filters if provided
if m.FilterMap != nil {
for id, f := range m.FilterMap {
@@ -126,7 +126,7 @@ func (p *Publisher) Receive(msg typer.T) {
log.T.F("added subscription %s for new listener %s", id, m.Id)
}
}
// Add the listener to the map
p.ListenMap[m.Id] = listener
log.T.F("added new listener %s", m.Id)
@@ -254,7 +254,7 @@ func CheckListenerExists(clientId string, publishers ...publisher.I) bool {
return true
}
}
// Check if the publisher has a Publishers field of type publisher.Publishers
// This handles the case where the publisher is a *publish.S
val := reflect.ValueOf(p)
@@ -290,7 +290,7 @@ func CheckSubscriptionExists(clientId string, subscriptionId string, publishers
return true
}
}
// Check if the publisher has a Publishers field of type publisher.Publishers
// This handles the case where the publisher is a *publish.S
val := reflect.ValueOf(p)

View File

@@ -72,7 +72,7 @@ func (a *A) HandleAuth(b []byte, srv server.I) (msg []byte) {
env.Event.Pubkey,
)
a.Listener.SetAuthedPubkey(env.Event.Pubkey)
// If authentication is successful, remove any blocks for this IP
iptracker.Global.Authenticate(a.Listener.RealRemote())
}

View File

@@ -75,43 +75,43 @@ func (a *A) HandleEvent(
if a.I.AuthRequired() && !a.Listener.IsAuthed() {
remoteIP := a.Listener.RealRemote()
log.I.F("requesting auth from client from %s", remoteIP)
// Check if the IP is blocked due to too many failed auth attempts
if iptracker.Global.IsBlocked(remoteIP) {
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
blockedUntil.Format(time.RFC3339))
// Send a notice to the client explaining why they're blocked
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
err = nil
}
// Close the connection
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
a.Listener.Close()
return
}
// Record the failed authentication attempt
blocked := iptracker.Global.RecordFailedAttempt(remoteIP)
if blocked {
// If this attempt caused the IP to be blocked, close the connection
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
blockedUntil.Format(time.RFC3339))
// Send a notice to the client explaining why they're blocked
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
err = nil
}
// Close the connection
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
a.Listener.Close()
return
}
// Continue with normal auth flow for non-blocked IPs
a.Listener.RequestAuth()
if err = Ok.AuthRequired(a, env.E, "auth required"); chk.E(err) {

View File

@@ -57,18 +57,18 @@ type A struct {
func (a *A) Serve(w http.ResponseWriter, r *http.Request, s server.I) {
c := a.Config()
remote := helpers.GetRemoteFromReq(r)
// Check if the IP is blocked due to too many failed auth attempts
if iptracker.Global.IsBlocked(remote) {
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
log.I.F("rejecting websocket connection from banned IP %s (blocked until %s)",
log.I.F("rejecting websocket connection from banned IP %s (blocked until %s)",
remote, blockedUntil.Format(time.RFC3339))
// We can't send a notice to the client here because the websocket connection
// We can't send a notice to the client here because the websocket connection
// hasn't been established yet, so we just reject the connection
return
}
var whitelisted bool
if len(c.Whitelist) > 0 {
for _, addr := range c.Whitelist {

View File

@@ -3,11 +3,19 @@ package ws
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/puzpuzpuz/xsync/v3"
"orly.dev/pkg/encoders/envelopes"
"orly.dev/pkg/encoders/envelopes/authenvelope"
"orly.dev/pkg/encoders/envelopes/closedenvelope"
"orly.dev/pkg/encoders/envelopes/countenvelope"
"orly.dev/pkg/encoders/envelopes/eoseenvelope"
"orly.dev/pkg/encoders/envelopes/eventenvelope"
"orly.dev/pkg/encoders/envelopes/noticeenvelope"
@@ -15,54 +23,43 @@ import (
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/protocol/auth"
"orly.dev/pkg/utils/atomic"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/normalize"
"sync"
"time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/puzpuzpuz/xsync/v3"
)
var subscriptionIDCounter atomic.Int32
var subscriptionIDCounter atomic.Int64
// Relay represents a connection to a Nostr relay.
type Client struct {
closeMutex sync.Mutex
URL string
URL string
requestHeader http.Header // e.g. for origin header
RequestHeader http.Header // e.g. for origin header
Connection *Connection
Subscriptions *xsync.MapOf[int64, *Subscription]
Connection *Connection
Subscriptions *xsync.MapOf[string, *Subscription]
ConnectionError error
connectionContext context.T // will be canceled when the connection closes
connectionContextCancel context.F
challenge []byte // NIP-42 challenge, we only keep the last
notices chan []byte // NIP-01 NOTICEs
okCallbacks *xsync.MapOf[string, func(bool, string)]
writeQueue chan writeRequest
ConnectionError error
connectionContext context.T // will be canceled when the connection closes
connectionContextCancel context.C
challenge []byte // NIP-42 challenge, we only keep the last
noticeHandler func(string) // NIP-01 NOTICEs
customHandler func(string) // nonstandard unparseable messages
okCallbacks *xsync.MapOf[string, func(bool, string)]
writeQueue chan writeRequest
subscriptionChannelCloseQueue chan *Subscription
signatureChecker func(*event.E) bool
// custom things that aren't often used
//
AssumeValid bool // this will skip verifying signatures for events received from this relay
}
@@ -71,21 +68,20 @@ type writeRequest struct {
answer chan error
}
// NewRelay returns a new relay. The relay connection will be closed when the
// context is cancelled.
func NewRelay(c context.T, url string, opts ...RelayOption) *Client {
ctx, cancel := context.Cancel(c)
// NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection.
func NewRelay(ctx context.T, url string, opts ...RelayOption) *Client {
ctx, cancel := context.Cause(ctx)
r := &Client{
URL: string(normalize.URL([]byte(url))),
URL: string(normalize.URL(url)),
connectionContext: ctx,
connectionContextCancel: cancel,
Subscriptions: xsync.NewMapOf[string, *Subscription](),
Subscriptions: xsync.NewMapOf[int64, *Subscription](),
okCallbacks: xsync.NewMapOf[string, func(
bool, string,
)](),
writeQueue: make(chan writeRequest),
subscriptionChannelCloseQueue: make(chan *Subscription),
signatureChecker: func(e *event.E) bool { ok, _ := e.Verify(); return ok },
requestHeader: nil,
}
for _, opt := range opts {
@@ -95,9 +91,12 @@ func NewRelay(c context.T, url string, opts ...RelayOption) *Client {
return r
}
// RelayConnect returns a relay object connected to url. Once successfully
// connected, cancelling ctx has no effect. To close the connection, call
// r.Close().
// RelayConnect returns a relay object connected to url.
//
// The given subscription is only used during the connection phase. Once successfully connected, cancelling ctx has no effect.
//
// The ongoing relay connection uses a background context. To close the connection, call r.Close().
// If you need fine grained long-term connection contexts, use NewRelay() instead.
func RelayConnect(ctx context.T, url string, opts ...RelayOption) (
*Client, error,
) {
@@ -106,35 +105,38 @@ func RelayConnect(ctx context.T, url string, opts ...RelayOption) (
return r, err
}
// RelayOption is the type of the argument passed for that.
// RelayOption is the type of the argument passed when instantiating relay connections.
type RelayOption interface {
ApplyRelayOption(*Client)
}
var (
_ RelayOption = (WithNoticeHandler)(nil)
_ RelayOption = (WithSignatureChecker)(nil)
_ RelayOption = (WithCustomHandler)(nil)
_ RelayOption = (WithRequestHeader)(nil)
)
// WithNoticeHandler just takes notices and is expected to do something with
// them. when not given, defaults to logging the notices.
type WithNoticeHandler func(notice []byte)
// WithNoticeHandler just takes notices and is expected to do something with them.
// when not given, defaults to logging the notices.
type WithNoticeHandler func(notice string)
func (nh WithNoticeHandler) ApplyRelayOption(r *Client) {
r.notices = make(chan []byte)
go func() {
for notice := range r.notices {
nh(notice)
}
}()
r.noticeHandler = nh
}
// WithSignatureChecker must be a function that checks the signature of an event
// and returns true or false.
type WithSignatureChecker func(*event.E) bool
// WithCustomHandler must be a function that handles any relay message that couldn't be
// parsed as a standard envelope.
type WithCustomHandler func(data string)
func (sc WithSignatureChecker) ApplyRelayOption(r *Client) {
r.signatureChecker = sc
func (ch WithCustomHandler) ApplyRelayOption(r *Client) {
r.customHandler = ch
}
// WithRequestHeader sets the HTTP request header of the websocket preflight request.
type WithRequestHeader http.Header
func (ch WithRequestHeader) ApplyRelayOption(r *Client) {
r.requestHeader = http.Header(ch)
}
// String just returns the relay URL.
@@ -143,253 +145,269 @@ func (r *Client) String() string {
}
// Context retrieves the context that is associated with this relay connection.
// It will be closed when the relay is disconnected.
func (r *Client) Context() context.T { return r.connectionContext }
// IsConnected returns true if the connection to this relay seems to be active.
func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil }
// Connect tries to establish a websocket connection to r.URL. If the context
// expires before the connection is complete, an error is returned. Once
// successfully connected, context expiration has no effect: call r.Close to
// close the connection.
// Connect tries to establish a websocket connection to r.URL.
// If the context expires before the connection is complete, an error is returned.
// Once successfully connected, context expiration has no effect: call r.Close
// to close the connection.
//
// The underlying relay connection will use a background context. If you want to
// pass a custom context to the underlying relay connection, use NewRelay() and
// then Client.Connect().
func (r *Client) Connect(c context.T) error { return r.ConnectWithTLS(c, nil) }
// The given context here is only used during the connection phase. The long-living
// relay connection will be based on the context given to NewRelay().
func (r *Client) Connect(ctx context.T) error {
return r.ConnectWithTLS(ctx, nil)
}
// ConnectWithTLS tries to establish a secured websocket connection to r.URL
// using customized tls.Config (CA's, etc.).
func (r *Client) ConnectWithTLS(ctx context.T, tlsConfig *tls.Config) error {
func subIdToSerial(subId string) int64 {
n := strings.Index(subId, ":")
if n < 0 || n > len(subId) {
return -1
}
serialId, _ := strconv.ParseInt(subId[0:n], 10, 64)
return serialId
}
// ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that.
func (r *Client) ConnectWithTLS(
ctx context.T, tlsConfig *tls.Config,
) (err error) {
if r.connectionContext == nil || r.Subscriptions == nil {
return errorf.E("relay must be initialized with a call to NewRelay()")
return fmt.Errorf("relay must be initialized with a call to NewRelay()")
}
if r.URL == "" {
return errorf.E("invalid relay URL '%s'", r.URL)
return fmt.Errorf("invalid relay URL '%s'", r.URL)
}
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.F
ctx, cancel = context.Timeout(ctx, 7*time.Second)
ctx, cancel = context.TimeoutCause(
ctx, 7*time.Second, errors.New("connection took too long"),
)
defer cancel()
}
conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig)
if err != nil {
return errorf.E(
"error opening websocket to '%s': %s", r.URL, err.Error(),
)
var conn *Connection
if conn, err = NewConnection(
ctx, r.URL, r.requestHeader, tlsConfig,
); chk.E(err) {
err = fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
return
}
r.Connection = conn
// ping every 29 seconds (??)
// ping every 29 seconds
ticker := time.NewTicker(29 * time.Second)
// to be used when the connection is closed
go func() {
<-r.connectionContext.Done()
// close these things when the connection is closed
if r.notices != nil {
close(r.notices)
}
// stop the ticker
ticker.Stop()
// close all subscriptions
r.Subscriptions.Range(
func(_ string, sub *Subscription) bool {
go sub.Unsub()
return true
},
)
}()
// queue all write operations here so we don't do mutex spaghetti
go func() {
var err error
for {
select {
case <-r.connectionContext.Done():
ticker.Stop()
r.Connection = nil
for _, sub := range r.Subscriptions.Range {
sub.unsub(
fmt.Errorf(
"relay connection closed: %w / %w",
context.GetCause(r.connectionContext),
r.ConnectionError,
),
)
}
return
case <-ticker.C:
err = wsutil.WriteClientMessage(
r.Connection.conn, ws.OpPing, nil,
)
if err != nil {
log.D.F(
err := r.Connection.Ping(r.connectionContext)
if err != nil && !strings.Contains(
err.Error(), "failed to wait for pong",
) {
log.I.F(
"{%s} error writing ping: %v; closing websocket", r.URL,
err,
)
r.Close() // this should trigger a context cancelation
return
}
case writeReq := <-r.writeQueue:
case writeRequest := <-r.writeQueue:
// all write requests will go through this to prevent races
if err = r.Connection.WriteMessage(
r.connectionContext,
writeReq.msg,
); chk.T(err) {
writeReq.answer <- err
log.D.F("{%s} sending %v\n", r.URL, string(writeRequest.msg))
if err := r.Connection.WriteMessage(
r.connectionContext, writeRequest.msg,
); err != nil {
writeRequest.answer <- err
}
close(writeReq.answer)
case <-r.connectionContext.Done():
// stop here
return
close(writeRequest.answer)
}
}
}()
// general message reader loop
go func() {
var err error
for {
buf := new(bytes.Buffer)
if err = conn.ReadMessage(r.connectionContext, buf); err != nil {
if err := conn.ReadMessage(r.connectionContext, buf); err != nil {
r.ConnectionError = err
r.Close()
r.close(err)
break
}
message := buf.Bytes()
// log.D.F("{%s} %s\n", r.URL, message)
var err error
var t string
if t, message, err = envelopes.Identify(message); chk.E(err) {
var rem []byte
if t, rem, err = envelopes.Identify(buf.Bytes()); chk.E(err) {
continue
}
switch t {
case noticeenvelope.L:
env := noticeenvelope.New()
if env, message, err = noticeenvelope.Parse(message); chk.E(err) {
continue
}
env := noticeenvelope.NewFrom(rem)
// see WithNoticeHandler
if r.notices != nil {
r.notices <- env.Message
if r.noticeHandler != nil {
r.noticeHandler(string(env.Message))
} else {
log.E.F("NOTICE from %s: '%s'\n", r.URL, env.Message)
log.D.F(
"NOTICE from %s: '%s'\n", r.URL, string(env.Message),
)
}
case authenvelope.L:
env := authenvelope.NewChallenge()
if env, message, err = authenvelope.ParseChallenge(message); chk.E(err) {
continue
}
if len(env.Challenge) == 0 {
env := authenvelope.NewChallengeWith(rem)
if env.Challenge == nil {
continue
}
r.challenge = env.Challenge
case eventenvelope.L:
// log.I.F("message: %s", message)
env := eventenvelope.NewResult()
if env, message, err = eventenvelope.ParseResult(message); err != nil {
var env *eventenvelope.Result
env = eventenvelope.NewResult()
if _, err = env.Unmarshal(rem); chk.E(err) {
continue
}
// log.I.F("%s", env.Event.Marshal(nil))
if len(env.Subscription.T) == 0 {
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
if !ok {
log.W.F(
"unknown subscription with id '%s'\n",
env.Subscription.String(),
)
continue
}
if sub, ok := r.Subscriptions.Load(env.Subscription.String()); !ok {
// log.D.F(
// "{%s} no subscription with id '%s'\n", r.URL,
// env.Subscription,
// )
if !sub.Filters.Match(env.Event) {
log.I.F(
"{%s} filter does not match: %v ~ %s\n", r.URL,
sub.Filters, env.Event.Marshal(nil),
)
continue
} else {
// check if the event matches the desired filter, ignore
// otherwise
if !sub.Filters.Match(env.Event) {
log.D.F(
"{%s} filter does not match: %v ~ %v\n", r.URL,
sub.Filters, env.Event,
}
if !r.AssumeValid {
if ok, err = env.Event.Verify(); !ok || chk.E(err) {
log.I.F(
"{%s} bad signature on %s\n", r.URL,
env.Event.ID,
)
continue
}
// check signature, ignore invalid, except from trusted
// (AssumeValid) relays
if !r.AssumeValid {
if ok = r.signatureChecker(env.Event); !ok {
log.E.F(
"{%s} bad signature on %s\n", r.URL,
env.Event.ID,
)
continue
}
}
// dispatch this to the internal .events channel of the
// subscription
sub.dispatchEvent(env.Event)
}
sub.dispatchEvent(env.Event)
case eoseenvelope.L:
env := eoseenvelope.New()
if env, message, err = eoseenvelope.Parse(message); chk.E(err) {
var env *eoseenvelope.T
if env, rem, err = eoseenvelope.Parse(rem); chk.E(err) {
continue
}
if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok {
subscription.dispatchEose()
if len(rem) != 0 {
log.W.F(
"{%s} unexpected data after EOSE: %s\n", r.URL,
string(rem),
)
}
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
if !ok {
log.W.F(
"unknown subscription with id '%s'\n",
env.Subscription.String(),
)
continue
}
sub.dispatchEose()
case closedenvelope.L:
env := closedenvelope.New()
if env, message, err = closedenvelope.Parse(message); chk.E(err) {
var env *closedenvelope.T
if env, rem, err = closedenvelope.Parse(rem); chk.E(err) {
continue
}
if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok {
subscription.dispatchClosed(env.ReasonString())
}
case countenvelope.L:
env := countenvelope.NewResponse()
if env, message, err = countenvelope.Parse(message); chk.E(err) {
sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String()))
if !ok {
log.W.F(
"unknown subscription with id '%s'\n",
env.Subscription.String(),
)
continue
}
if subscription, ok := r.Subscriptions.Load(env.ID.String()); ok && subscription.countResult != nil {
subscription.countResult <- env.Count
}
sub.handleClosed(env.ReasonString())
case okenvelope.L:
env := okenvelope.New()
if env, message, err = okenvelope.Parse(message); chk.E(err) {
var env *okenvelope.T
if env, rem, err = okenvelope.Parse(rem); chk.E(err) {
continue
}
if okCallback, exist := r.okCallbacks.Load(env.EventID.String()); exist {
okCallback(env.OK, env.ReasonString())
okCallback(env.OK, string(env.Reason))
} else {
log.I.F(
"{%s} got an unexpected OK message for event %s", r.URL,
env.EventID,
)
}
default:
log.W.F("unknown envelope type %s\n%s", t, rem)
continue
}
}
}()
return nil
return
}
// Write queues a message to be sent to the relay.
// Write queues an arbitrary message to be sent to the relay.
func (r *Client) Write(msg []byte) <-chan error {
ch := make(chan error)
select {
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
case <-r.connectionContext.Done():
go func() { ch <- errorf.E("connection closed") }()
go func() { ch <- fmt.Errorf("connection closed") }()
}
return ch
}
// Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an
// OK response.
func (r *Client) Publish(c context.T, ev *event.E) error {
return r.publish(
c, ev,
)
func (r *Client) Publish(ctx context.T, ev *event.E) error {
return r.publish(ctx, ev.ID, ev)
}
// Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK
// response.
func (r *Client) Auth(c context.T, sign signer.I) error {
authEvent := auth.CreateUnsigned(sign.Pub(), r.challenge, r.URL)
if err := authEvent.Sign(sign); chk.T(err) {
return errorf.E("error signing auth event: %w", err)
func (r *Client) Auth(
ctx context.T, sign signer.I,
) (err error) {
authEvent := &event.E{
CreatedAt: timestamp.Now(),
Kind: kind.ClientAuthentication,
Tags: tags.New(
tag.New("relay", r.URL),
tag.New([]byte("challenge"), r.challenge),
),
}
return r.publish(c, authEvent)
if err = authEvent.Sign(sign); chk.E(err) {
err = fmt.Errorf("error signing auth event: %w", err)
return
}
return r.publish(ctx, authEvent.ID, authEvent)
}
// publish can be used both for EVENT and for AUTH
func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
func (r *Client) publish(
ctx context.T, id []byte, ev *event.E,
) error {
var err error
var cancel context.F
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
ctx, cancel = context.TimeoutCause(
ctx, 7*time.Second,
errorf.E("given up waiting for an OK"),
ctx, 7*time.Second, fmt.Errorf("given up waiting for an OK"),
)
defer cancel()
} else {
@@ -398,32 +416,24 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
ctx, cancel = context.Cancel(ctx)
defer cancel()
}
// listen for an OK callback
gotOk := false
id := ev.IdString()
ids := hex.Enc(id)
r.okCallbacks.Store(
id, func(ok bool, reason string) {
ids, func(ok bool, reason string) {
gotOk = true
if !ok {
err = errorf.E("msg: %s", reason)
err = fmt.Errorf("msg: %s", reason)
}
cancel()
},
)
defer r.okCallbacks.Delete(id)
defer r.okCallbacks.Delete(ids)
// publish event
var b []byte
if ev.Kind.Equal(kind.ClientAuthentication) {
if b = authenvelope.NewResponseWith(ev).Marshal(b); chk.E(err) {
return
}
} else {
if b = eventenvelope.NewSubmissionWith(ev).Marshal(b); chk.E(err) {
return
}
}
log.T.F("{%s} sending %s\n", r.URL, b)
if err = <-r.Write(b); chk.T(err) {
envb := eventenvelope.NewSubmissionWith(ev).Marshal(nil)
// envb := ev.Marshal(nil)
if err = <-r.Write(envb); err != nil {
return err
}
for {
@@ -447,111 +457,168 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
// context ctx is cancelled ("CLOSE" in NIP-01).
//
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or
// ensuring their `context.Context` will be canceled at some point. Failure to
// ensuring their `context.T` will be canceled at some point. Failure to
// do that will result in a huge number of halted goroutines being created.
func (r *Client) Subscribe(
c context.T, ff *filters.T,
opts ...SubscriptionOption,
ctx context.T, ff *filters.T, opts ...SubscriptionOption,
) (sub *Subscription, err error) {
sub = r.PrepareSubscription(c, ff, opts...)
sub = r.PrepareSubscription(ctx, ff, opts...)
if r.Connection == nil {
return nil, errorf.E("not connected to %s", r.URL)
return nil, fmt.Errorf("not connected to %s", r.URL)
}
if err = sub.Fire(); chk.T(err) {
return nil, errorf.E(
"couldn't subscribe to %v at %s: %w", ff, r.URL, err,
if err = sub.Fire(); err != nil {
err = fmt.Errorf(
"couldn't subscribe to %v at %s: %w", ff.Marshal(nil), r.URL, err,
)
return
}
return
}
// PrepareSubscription creates a subscription, but doesn't fire it.
//
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or
// ensuring their `context.Context` will be canceled at some point. Failure to
// do that will result in a huge number of halted goroutines being created.
// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.T` will be canceled at some point.
// Failure to do that will result in a huge number of halted goroutines being created.
func (r *Client) PrepareSubscription(
c context.T, ff *filters.T,
opts ...SubscriptionOption,
ctx context.T, ff *filters.T, opts ...SubscriptionOption,
) *Subscription {
current := subscriptionIDCounter.Add(1)
c, cancel := context.Cancel(c)
ctx, cancel := context.Cause(ctx)
sub := &Subscription{
Relay: r,
Context: c,
Context: ctx,
cancel: cancel,
counter: int(current),
counter: current,
Events: make(event.C),
EndOfStoredEvents: make(chan struct{}, 1),
ClosedReason: make(chan string, 1),
Filters: ff,
match: ff.Match,
}
label := ""
for _, opt := range opts {
switch o := opt.(type) {
case WithLabel:
sub.label = string(o)
label = string(o)
// case WithCheckDuplicate:
// sub.checkDuplicate = o
// case WithCheckDuplicateReplaceable:
// sub.checkDuplicateReplaceable = o
}
}
id := sub.GetID()
r.Subscriptions.Store(id.String(), sub)
// subscription id computation
buf := subIdPool.Get().([]byte)[:0]
buf = strconv.AppendInt(buf, sub.counter, 10)
buf = append(buf, ':')
buf = append(buf, label...)
defer subIdPool.Put(buf)
sub.id = string(buf)
// we track subscriptions only by their counter, no need for the full id
r.Subscriptions.Store(int64(sub.counter), sub)
// start handling events, eose, unsub etc:
go sub.start()
return sub
}
// QuerySync is only used in tests. The relay query method is synchronous now
// anyway (it ensures sort order is respected).
func (r *Client) QuerySync(
ctx context.T, f *filter.F,
opts ...SubscriptionOption,
) ([]*event.E, error) {
// log.T.F("QuerySync:\n%s", f.Marshal(nil))
sub, err := r.Subscribe(ctx, filters.New(f), opts...)
// QueryEvents subscribes to events matching the given filter and returns a channel of events.
//
// In most cases it's better to use Pool instead of this method.
func (r *Client) QueryEvents(ctx context.T, f *filter.F) (
evc event.C, err error,
) {
var sub *Subscription
if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) {
return
}
go func() {
for {
select {
case <-sub.ClosedReason:
case <-sub.EndOfStoredEvents:
case <-ctx.Done():
case <-r.Context().Done():
}
sub.unsub(errors.New("QueryEvents() ended"))
return
}
}()
return sub.Events, nil
}
// QuerySync subscribes to events matching the given filter and returns a slice
// of events. This method blocks until all events are received or the context is
// canceled.
//
// If the filter causes a subscription to open, it will stay open until the
// limit is exceeded. So this method will return an error if the limit is nil.
// If the query blocks, the caller needs to cancel the context to prevent the
// thread stalling.
func (r *Client) QuerySync(ctx context.T, f *filter.F) (
evs event.S, err error,
) {
if f.Limit == nil {
err = errors.New("limit must be set for a sync query to prevent blocking")
return
}
var sub *Subscription
if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) {
return
}
defer sub.unsub(errors.New("QuerySync() ended"))
evs = make(event.S, 0, *f.Limit)
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.F
ctx, cancel = context.TimeoutCause(
ctx, 7*time.Second, errors.New("QuerySync() took too long"),
)
defer cancel()
}
lim := 250
if f.Limit != nil {
lim = int(*f.Limit)
}
events := make(event.S, 0, max(lim, 250))
ch, err := r.QueryEvents(ctx, f)
if err != nil {
return nil, err
}
defer sub.Unsub()
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.F
ctx, cancel = context.Timeout(ctx, 7*time.Second)
defer cancel()
for evt := range ch {
events = append(events, evt)
}
var events []*event.E
for {
select {
case evt := <-sub.Events:
if evt == nil {
// channel is closed
return events, nil
}
events = append(events, evt)
case <-sub.EndOfStoredEvents:
return events, nil
case <-ctx.Done():
return events, nil
}
}
return events, nil
}
// TODO: count is a dumb idea anyway, and nothing is using this
// func (r *Client) Count(c context.F, ff *filters.F, opts ...SubscriptionOption) (int, error) {
// sub := r.PrepareSubscription(c, ff, opts...)
// sub.countResult = make(chan int)
//
// if err := sub.Fire(); chk.F(err) {
// return 0, err
// // Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
// func (r *Relay) Count(
// ctx context.T,
// filters Filters,
// opts ...SubscriptionOption,
// ) (int64, []byte, error) {
// v, err := r.countInternal(ctx, filters, opts...)
// if err != nil {
// return 0, nil, err
// }
//
// defer sub.Unsub()
// return *v.Count, v.HyperLogLog, nil
// }
//
// if _, ok := c.Deadline(); !ok {
// func (r *Relay) countInternal(ctx context.T, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
// sub := r.PrepareSubscription(ctx, filters, opts...)
// sub.countResult = make(chan CountEnvelope)
//
// if err := sub.Fire(); err != nil {
// return CountEnvelope{}, err
// }
//
// defer sub.unsub(errors.New("countInternal() ended"))
//
// if _, ok := ctx.Deadline(); !ok {
// // if no timeout is set, force it to 7 seconds
// var cancel context.F
// c, cancel = context.Timeout(c, 7*time.Second)
// var cancel context.CancelFunc
// ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("countInternal took too long"))
// defer cancel()
// }
//
@@ -559,28 +626,39 @@ func (r *Client) QuerySync(
// select {
// case count := <-sub.countResult:
// return count, nil
// case <-c.Done():
// return 0, c.Err()
// case <-ctx.Done():
// return CountEnvelope{}, ctx.Err()
// }
// }
// }
// Close shuts down a websocket client connection.
// Close closes the relay connection.
func (r *Client) Close() error {
return r.close(errors.New("relay connection closed"))
}
func (r *Client) close(reason error) error {
r.closeMutex.Lock()
defer r.closeMutex.Unlock()
if r.connectionContextCancel == nil {
return errorf.E("relay already closed")
return fmt.Errorf("relay already closed")
}
r.connectionContextCancel()
r.connectionContextCancel(reason)
r.connectionContextCancel = nil
if r.Connection == nil {
return errorf.E("relay not connected")
return fmt.Errorf("relay not connected")
}
err := r.Connection.Close()
r.Connection = nil
if err != nil {
return err
}
return nil
}
var subIdPool = sync.Pool{
New: func() any { return make([]byte, 0, 15) },
}

View File

@@ -1,146 +1,159 @@
package ws
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/envelopes/eventenvelope"
"orly.dev/pkg/encoders/envelopes/okenvelope"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/normalize"
"sync"
"testing"
"time"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/normalize"
"golang.org/x/net/websocket"
)
func TestPublish(t *testing.T) {
// test note to be sent over websocket
var err error
signer := &p256k.Signer{}
if err = signer.Generate(); chk.E(err) {
t.Fatal(err)
}
textNote := &event.E{
Kind: kind.TextNote,
Content: []byte("hello"),
CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
Tags: tags.New(tag.New("foo", "bar")),
Pubkey: signer.Pub(),
}
if err = textNote.Sign(signer); chk.E(err) {
t.Fatalf("textNote.Sign: %v", err)
}
// fake relay server
var mu sync.Mutex // guards published to satisfy `go test -race`
var published bool
ws := newWebsocketServer(
func(conn *websocket.Conn) {
mu.Lock()
published = true
mu.Unlock()
// verify the client sent exactly the textNote
var raw []json.RawMessage
if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
t.Errorf("websocket.JSON.Receive: %v", err)
}
if string(raw[0]) != fmt.Sprintf(`"%s"`, eventenvelope.L) {
t.Errorf("got type %s, want %s", raw[0], eventenvelope.L)
}
env := eventenvelope.NewSubmission()
if raw[1], err = env.Unmarshal(raw[1]); chk.E(err) {
t.Fatal(err)
}
if !bytes.Equal(env.E.Serialize(), textNote.Serialize()) {
t.Errorf(
"received event:\n%s\nwant:\n%s", env.E.Serialize(),
textNote.Serialize(),
)
}
// send back an ok nip-20 command result
var res []byte
if res = okenvelope.NewFrom(
textNote.ID, true, nil,
).Marshal(res); chk.E(err) {
t.Fatal(err)
}
if err := websocket.Message.Send(conn, res); chk.T(err) {
t.Errorf("websocket.JSON.Send: %v", err)
}
},
)
defer ws.Close()
// connect a client and send the text note
rl := mustRelayConnect(ws.URL)
err = rl.Publish(context.Bg(), textNote)
if err != nil {
t.Errorf("publish should have succeeded")
}
if !published {
t.Errorf("fake relay server saw no event")
}
}
func TestPublishBlocked(t *testing.T) {
// test note to be sent over websocket
var err error
signer := &p256k.Signer{}
if err = signer.Generate(); chk.E(err) {
t.Fatal(err)
}
textNote := &event.E{
Kind: kind.TextNote,
Content: []byte("hello"),
CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
Pubkey: signer.Pub(),
}
if err = textNote.Sign(signer); chk.E(err) {
t.Fatalf("textNote.Sign: %v", err)
}
// fake relay server
ws := newWebsocketServer(
func(conn *websocket.Conn) {
// discard received message; not interested
var raw []json.RawMessage
if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
t.Errorf("websocket.JSON.Receive: %v", err)
}
// send back a not ok nip-20 command result
var res []byte
if res = okenvelope.NewFrom(
textNote.ID, false,
normalize.Msg(normalize.Blocked, "no reason"),
).Marshal(res); chk.E(err) {
t.Fatal(err)
}
if err := websocket.Message.Send(conn, res); chk.T(err) {
t.Errorf("websocket.JSON.Send: %v", err)
}
// res := []any{"OK", textNote.ID, false, "blocked"}
chk.E(websocket.JSON.Send(conn, res))
},
)
defer ws.Close()
// connect a client and send a text note
rl := mustRelayConnect(ws.URL)
if err = rl.Publish(context.Bg(), textNote); !chk.E(err) {
t.Errorf("should have failed to publish")
}
}
// func TestPublish(t *testing.T) {
// // test note to be sent over websocket
// var err error
// signer := &p256k.Signer{}
// if err = signer.Generate(); chk.E(err) {
// t.Fatal(err)
// }
// textNote := &event.E{
// Kind: kind.TextNote,
// Content: []byte("hello"),
// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
// Pubkey: signer.Pub(),
// }
// if err = textNote.Sign(signer); chk.E(err) {
// t.Fatalf("textNote.Sign: %v", err)
// }
// // fake relay server
// var published bool
// ws := newWebsocketServer(
// func(conn *websocket.Conn) {
// // receive message
// var raw []json.RawMessage
// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
// t.Errorf("websocket.JSON.Receive: %v", err)
// }
// // check that it's an EVENT message
// if len(raw) < 2 {
// t.Errorf("message too short: %v", raw)
// }
// var msgType string
// if err := json.Unmarshal(raw[0], &msgType); chk.T(err) {
// t.Errorf("json.Unmarshal: %v", err)
// }
// if msgType != "EVENT" {
// t.Errorf("expected EVENT message, got %q", msgType)
// }
// // check that the event is the one we sent
// var ev event.E
// if err := json.Unmarshal(raw[1], &ev); chk.T(err) {
// t.Errorf("json.Unmarshal: %v", err)
// }
// published = true
// if !bytes.Equal(ev.ID, textNote.ID) {
// t.Errorf(
// "event ID mismatch: got %x, want %x",
// ev.ID, textNote.ID,
// )
// }
// if !bytes.Equal(ev.Pubkey, textNote.Pubkey) {
// t.Errorf(
// "event pubkey mismatch: got %x, want %x",
// ev.Pubkey, textNote.Pubkey,
// )
// }
// if !bytes.Equal(ev.Content, textNote.Content) {
// t.Errorf(
// "event content mismatch: got %q, want %q",
// ev.Content, textNote.Content,
// )
// }
// fmt.Printf(
// "received event: %s\n",
// textNote.Serialize(),
// )
// // send back an ok nip-20 command result
// var res []byte
// if res = okenvelope.NewFrom(
// textNote.ID, true, nil,
// ).Marshal(res); chk.E(err) {
// t.Fatal(err)
// }
// if err := websocket.Message.Send(conn, res); chk.T(err) {
// t.Errorf("websocket.Message.Send: %v", err)
// }
// },
// )
// defer ws.Close()
// // connect a client and send the text note
// rl := mustRelayConnect(ws.URL)
// err = rl.Publish(context.Background(), textNote)
// if err != nil {
// t.Errorf("publish should have succeeded")
// }
// if !published {
// t.Errorf("fake relay server saw no event")
// }
// }
//
// func TestPublishBlocked(t *testing.T) {
// // test note to be sent over websocket
// var err error
// signer := &p256k.Signer{}
// if err = signer.Generate(); chk.E(err) {
// t.Fatal(err)
// }
// textNote := &event.E{
// Kind: kind.TextNote,
// Content: []byte("hello"),
// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp
// Pubkey: signer.Pub(),
// }
// if err = textNote.Sign(signer); chk.E(err) {
// t.Fatalf("textNote.Sign: %v", err)
// }
// // fake relay server
// ws := newWebsocketServer(
// func(conn *websocket.Conn) {
// // discard received message; not interested
// var raw []json.RawMessage
// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) {
// t.Errorf("websocket.JSON.Receive: %v", err)
// }
// // send back a not ok nip-20 command result
// var res []byte
// if res = okenvelope.NewFrom(
// textNote.ID, false,
// normalize.Msg(normalize.Blocked, "no reason"),
// ).Marshal(res); chk.E(err) {
// t.Fatal(err)
// }
// if err := websocket.Message.Send(conn, res); chk.T(err) {
// t.Errorf("websocket.Message.Send: %v", err)
// }
// // res := []any{"OK", textNote.ID, false, "blocked"}
// },
// )
// defer ws.Close()
//
// // connect a client and send a text note
// rl := mustRelayConnect(ws.URL)
// if err = rl.Publish(context.Background(), textNote); !chk.E(err) {
// t.Errorf("should have failed to publish")
// }
// }
func TestPublishWriteFailed(t *testing.T) {
// test note to be sent over websocket
@@ -171,7 +184,7 @@ func TestPublishWriteFailed(t *testing.T) {
rl := mustRelayConnect(ws.URL)
// Force brief period of time so that publish always fails on closed socket.
time.Sleep(1 * time.Millisecond)
err = rl.Publish(context.Bg(), textNote)
err = rl.Publish(context.Background(), textNote)
if err == nil {
t.Errorf("should have failed to publish")
}
@@ -192,7 +205,7 @@ func TestConnectContext(t *testing.T) {
defer ws.Close()
// relay client
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
r, err := RelayConnect(ctx, ws.URL)
if err != nil {
@@ -213,7 +226,7 @@ func TestConnectContextCanceled(t *testing.T) {
defer ws.Close()
// relay client
ctx, cancel := context.Cancel(context.Bg())
ctx, cancel := context.WithCancel(context.Background())
cancel() // make ctx expired
_, err := RelayConnect(ctx, ws.URL)
if !errors.Is(err, context.Canceled) {
@@ -230,9 +243,9 @@ func TestConnectWithOrigin(t *testing.T) {
defer ws.Close()
// relay client
r := NewRelay(context.Bg(), string(normalize.URL(ws.URL)))
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
r := NewRelay(context.Background(), string(normalize.URL(ws.URL)))
r.requestHeader = http.Header{"origin": {"https://example.com"}}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := r.Connect(ctx)
if err != nil {
@@ -263,7 +276,7 @@ var anyOriginHandshake = func(
}
func mustRelayConnect(url string) (client *Client) {
rl, err := RelayConnect(context.Bg(), url)
rl, err := RelayConnect(context.Background(), url)
if err != nil {
panic(err.Error())
}

View File

@@ -1,224 +1,98 @@
package ws
import (
"bytes"
"compress/flate"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
"io"
"net"
"net/http"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
"net/textproto"
"time"
ws "github.com/coder/websocket"
)
// Connection is an outbound client -> relay connection.
type Connection struct {
conn net.Conn
enableCompression bool
controlHandler wsutil.FrameHandlerFunc
flateReader *wsflate.Reader
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
// NewConnection creates a new Connection.
func NewConnection(
c context.T, url string, requestHeader http.Header,
tlsConfig *tls.Config,
) (connection *Connection, errResult error) {
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP(requestHeader),
Extensions: []httphead.Option{
wsflate.DefaultParameters.Option(),
},
TLSConfig: tlsConfig,
func getConnectionOptions(
requestHeader http.Header, tlsConfig *tls.Config,
) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
conn, _, hs, err := dialer.Dial(c, url)
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}
// Connection represents a websocket connection to a Nostr relay.
type Connection struct {
conn *ws.Conn
}
// NewConnection creates a new websocket connection to a Nostr relay.
func NewConnection(
ctx context.Context, url string, requestHeader http.Header,
tlsConfig *tls.Config,
) (*Connection, error) {
c, _, err := ws.Dial(
ctx, url, getConnectionOptions(requestHeader, tlsConfig),
)
if err != nil {
return nil, err
}
enableCompression := false
state := ws.StateClientSide
for _, extension := range hs.Extensions {
if string(extension.Name) == wsflate.ExtensionName {
enableCompression = true
state |= ws.StateExtended
break
}
}
// reader
var flateReader *wsflate.Reader
var msgStateR wsflate.MessageState
if enableCompression {
msgStateR.SetCompressed(true)
flateReader = wsflate.NewReader(
nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
},
)
}
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
reader := &wsutil.Reader{
Source: conn,
State: state,
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgStateR,
},
}
// writer
var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression {
msgStateW.SetCompressed(true)
c.SetReadLimit(2 << 24) // 33MB
flateWriter = wsflate.NewWriter(
nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4)
if err != nil {
log.E.F("Failed to create flate writer: %v", err)
}
return fw
},
)
}
writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgStateW)
return &Connection{
conn: conn,
enableCompression: enableCompression,
controlHandler: controlHandler,
flateReader: flateReader,
reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter,
writer: writer,
msgStateW: &msgStateW,
conn: c,
}, nil
}
// WriteMessage dispatches a message through the Connection.
func (cn *Connection) WriteMessage(c context.T, data []byte) (err error) {
select {
case <-c.Done():
return errorf.E(
"%s context canceled",
cn.conn.RemoteAddr(),
)
default:
// WriteMessage writes arbitrary bytes to the websocket connection.
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
if cn.msgStateW.IsCompressed() && cn.enableCompression {
cn.flateWriter.Reset(cn.writer)
if _, err := io.Copy(
cn.flateWriter, bytes.NewReader(data),
); chk.T(err) {
return errorf.E(
"%s failed to write message: %w",
cn.conn.RemoteAddr(),
err,
)
}
if err := cn.flateWriter.Close(); chk.T(err) {
return errorf.E(
"%s failed to close flate writer: %w",
cn.conn.RemoteAddr(),
err,
)
}
} else {
if _, err := io.Copy(cn.writer, bytes.NewReader(data)); chk.T(err) {
return errorf.E(
"%s failed to write message: %w",
cn.conn.RemoteAddr(),
err,
)
}
return nil
}
// ReadMessage reads arbitrary bytes from the websocket connection into the provided buffer.
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
_, reader, err := c.conn.Reader(ctx)
if err != nil {
return fmt.Errorf("failed to get reader: %w", err)
}
if err := cn.writer.Flush(); chk.T(err) {
return errorf.E(
"%s failed to flush writer: %w",
cn.conn.RemoteAddr(),
err,
)
if _, err := io.Copy(buf, reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
return nil
}
// ReadMessage picks up the next incoming message on a Connection.
func (cn *Connection) ReadMessage(c context.T, buf io.Writer) (err error) {
for {
select {
case <-c.Done():
return errorf.D(
"%s context canceled",
cn.conn.RemoteAddr(),
)
default:
}
h, err := cn.reader.NextFrame()
if err != nil {
cn.conn.Close()
return fmt.Errorf(
"%s failed to advance frame: %s",
cn.conn.RemoteAddr(),
err.Error(),
)
}
if h.OpCode.IsControl() {
if err := cn.controlHandler(h, cn.reader); chk.T(err) {
return errorf.E(
"%s failed to handle control frame: %w",
cn.conn.RemoteAddr(),
err,
)
}
} else if h.OpCode == ws.OpBinary ||
h.OpCode == ws.OpText {
break
}
if err := cn.reader.Discard(); chk.T(err) {
return errorf.E(
"%s failed to discard: %w",
cn.conn.RemoteAddr(),
err,
)
}
}
if cn.msgStateR.IsCompressed() && cn.enableCompression {
cn.flateReader.Reset(cn.reader)
if _, err := io.Copy(buf, cn.flateReader); chk.T(err) {
return errorf.E(
"%s failed to read message: %w",
cn.conn.RemoteAddr(),
err,
)
}
} else {
if _, err := io.Copy(buf, cn.reader); chk.T(err) {
return errorf.E(
"%s failed to read message: %w",
cn.conn.RemoteAddr(),
err,
)
}
}
return nil
// Close closes the websocket connection.
func (c *Connection) Close() error {
return c.conn.Close(ws.StatusNormalClosure, "")
}
// Close the Connection.
func (cn *Connection) Close() (err error) {
return cn.conn.Close()
// Ping sends a ping message to the websocket connection.
func (c *Connection) Ping(ctx context.Context) error {
ctx, cancel := context.WithTimeoutCause(
ctx, time.Millisecond*800, errors.New("ping took too long"),
)
defer cancel()
return c.conn.Ping(ctx)
}

View File

@@ -3,12 +3,13 @@ package ws
import (
"net/http"
"strings"
"sync"
"orly.dev/pkg/app/relay/helpers"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/protocol/auth"
atomic2 "orly.dev/pkg/utils/atomic"
"strings"
"sync"
"github.com/fasthttp/websocket"
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,216 @@
package ws
import (
"context"
"sync"
"testing"
"time"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/interfaces/signer"
)
// mockSigner implements signer.I for testing
type mockSigner struct {
pubkey []byte
}
func (m *mockSigner) Pub() []byte { return m.pubkey }
func (m *mockSigner) Sign([]byte) (
[]byte, error,
) {
return []byte("mock-signature"), nil
}
func (m *mockSigner) Generate() error { return nil }
func (m *mockSigner) InitSec([]byte) error { return nil }
func (m *mockSigner) InitPub([]byte) error { return nil }
func (m *mockSigner) Sec() []byte { return []byte("mock-secret") }
func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil }
func (m *mockSigner) Zero() {}
func (m *mockSigner) ECDH([]byte) (
[]byte, error,
) {
return []byte("mock-shared-secret"), nil
}
func TestNewPool(t *testing.T) {
ctx := context.Background()
pool := NewPool(ctx)
if pool == nil {
t.Fatal("NewPool returned nil")
}
if pool.Relays == nil {
t.Error("Pool should have initialized Relays map")
}
if pool.Context == nil {
t.Error("Pool should have a context")
}
}
func TestPoolWithAuthHandler(t *testing.T) {
ctx := context.Background()
authHandler := WithAuthHandler(
func() signer.I {
return &mockSigner{pubkey: []byte("test-pubkey")}
},
)
pool := NewPool(ctx, authHandler)
if pool.authHandler == nil {
t.Error("Pool should have auth handler set")
}
// Test that auth handler returns the expected signer
signer := pool.authHandler()
if string(signer.Pub()) != "test-pubkey" {
t.Errorf(
"Expected pubkey 'test-pubkey', got '%s'", string(signer.Pub()),
)
}
}
func TestPoolWithEventMiddleware(t *testing.T) {
ctx := context.Background()
var middlewareCalled bool
middleware := WithEventMiddleware(
func(ie RelayEvent) {
middlewareCalled = true
},
)
pool := NewPool(ctx, middleware)
// Test that middleware is called
testEvent := &event.E{
Kind: kind.TextNote,
Content: []byte("test"),
CreatedAt: timestamp.Now(),
}
ie := RelayEvent{E: testEvent, Relay: nil}
pool.eventMiddleware(ie)
if !middlewareCalled {
t.Error("Expected middleware to be called")
}
}
func TestRelayEventString(t *testing.T) {
testEvent := &event.E{
Kind: kind.TextNote,
Content: []byte("test content"),
CreatedAt: timestamp.Now(),
}
client := &Client{URL: "wss://test.relay"}
ie := RelayEvent{E: testEvent, Relay: client}
str := ie.String()
if !contains(str, "wss://test.relay") {
t.Errorf("Expected string to contain relay URL, got: %s", str)
}
if !contains(str, "test content") {
t.Errorf("Expected string to contain event content, got: %s", str)
}
}
func TestNamedLock(t *testing.T) {
// Test that named locks work correctly
var wg sync.WaitGroup
var counter int
var mu sync.Mutex
lockName := "test-lock"
// Start multiple goroutines that try to increment counter
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
unlock := namedLock(lockName)
defer unlock()
// Critical section
mu.Lock()
temp := counter
time.Sleep(1 * time.Millisecond) // Simulate work
counter = temp + 1
mu.Unlock()
}()
}
wg.Wait()
if counter != 10 {
t.Errorf("Expected counter to be 10, got %d", counter)
}
}
func TestPoolEnsureRelayInvalidURL(t *testing.T) {
ctx := context.Background()
pool := NewPool(ctx)
// Test with invalid URL
_, err := pool.EnsureRelay("invalid-url")
if err == nil {
t.Error("Expected error for invalid URL")
}
}
func TestPoolQuerySingle(t *testing.T) {
ctx := context.Background()
pool := NewPool(ctx)
// Test with empty URLs slice
result := pool.QuerySingle(ctx, []string{}, &filter.F{})
if result != nil {
t.Error("Expected nil result for empty URLs")
}
}
// Helper functions
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
containsSubstring(s, substr)))
}
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func uintPtr(u uint) *uint {
return &u
}
// Test pool context cancellation
func TestPoolContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := NewPool(ctx)
// Cancel the context
cancel()
// Check that pool context is cancelled
select {
case <-pool.Context.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("Expected pool context to be cancelled")
}
}

View File

@@ -1,100 +1,105 @@
package ws
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"orly.dev/pkg/encoders/envelopes/closeenvelope"
"orly.dev/pkg/encoders/envelopes/countenvelope"
"orly.dev/pkg/encoders/envelopes/reqenvelope"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/subscription"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/errorf"
"strconv"
"sync"
"sync/atomic"
)
// Subscription is a client interface for a subscription (what REQ turns into
// after EOSE).
// Subscription represents a subscription to a relay.
type Subscription struct {
label string
counter int
counter int64
id string
Relay *Client
Filters *filters.T
// for this to be treated as a COUNT and not a REQ this must be set
countResult chan int
// // for this to be treated as a COUNT and not a REQ this must be set
// countResult chan CountEnvelope
// The Events channel emits all EVENTs that come in a Subscription will be
// closed when the subscription ends
// the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends
Events event.C
mu sync.Mutex
// The EndOfStoredEvents channel is closed when an EOSE comes for that
// subscription
// the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription
EndOfStoredEvents chan struct{}
// The ClosedReason channel emits the reason when a CLOSED message is
// received
// the ClosedReason channel emits the reason when a CLOSED message is received
ClosedReason chan string
// Context will be .Done() when the subscription ends
Context context.T
Context context.Context
// // if it is not nil, checkDuplicate will be called for every event received
// // if it returns true that event will not be processed further.
// checkDuplicate func(id string, relay string) bool
//
// // if it is not nil, checkDuplicateReplaceable will be called for every event received
// // if it returns true that event will not be processed further.
// checkDuplicateReplaceable func(rk ReplaceableKey, ts Timestamp) bool
match func(*event.E) bool // this will be either Filters.Match or Filters.MatchIgnoringTimestampConstraints
live atomic.Bool
eosed atomic.Bool
closed atomic.Bool
cancel context.F
cancel context.CancelCauseFunc
// This keeps track of the events we've received before the EOSE that we
// must dispatch before closing the EndOfStoredEvents channel
// this keeps track of the events we've received before the EOSE that we must dispatch before
// closing the EndOfStoredEvents channel
storedwg sync.WaitGroup
}
// EventMessage is an event, with the associated relay URL attached.
type EventMessage struct {
Event event.E
Relay string
}
// SubscriptionOption is the type of the argument passed for that. Some examples
// are WithLabel.
// SubscriptionOption is the type of the argument passed when instantiating relay connections.
// Some examples are WithLabel.
type SubscriptionOption interface {
IsSubscriptionOption()
}
// WithLabel puts a label on the subscription (it is prepended to the automatic
// id) that is sent to relays.
// WithLabel puts a label on the subscription (it is prepended to the automatic id) that is sent to relays.
type WithLabel string
func (_ WithLabel) IsSubscriptionOption() {}
var _ SubscriptionOption = (WithLabel)("")
// // WithCheckDuplicate sets checkDuplicate on the subscription
// type WithCheckDuplicate func(id, relay string) bool
//
// func (_ WithCheckDuplicate) IsSubscriptionOption() {}
//
// // WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription
// type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool
//
// func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {}
// GetID return the Nostr subscription ID as given to the Client it is a
// concatenation of the label and a serial number.
func (sub *Subscription) GetID() (id *subscription.Id) {
var err error
if id, err = subscription.NewId(sub.label + ":" + strconv.Itoa(sub.counter)); chk.E(err) {
return
}
return
}
var (
_ SubscriptionOption = (WithLabel)("")
// _ SubscriptionOption = (WithCheckDuplicate)(nil)
// _ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil)
)
func (sub *Subscription) start() {
<-sub.Context.Done()
// the subscription ends once the context is canceled (if not already)
sub.Unsub() // this will set sub.live to false
// do this so we don't have the possibility of closing the Events channel
// and then trying to send to it
// the subscription ends once the context is canceled (if not already)
sub.unsub(errors.New("context done on start()")) // this will set sub.live to false
// do this so we don't have the possibility of closing the Events channel and then trying to send to it
sub.mu.Lock()
close(sub.Events)
sub.mu.Unlock()
}
// GetID returns the subscription ID.
func (sub *Subscription) GetID() string { return sub.id }
func (sub *Subscription) dispatchEvent(evt *event.E) {
added := false
if !sub.eosed.Load() {
@@ -112,7 +117,6 @@ func (sub *Subscription) dispatchEvent(evt *event.E) {
case <-sub.Context.Done():
}
}
if added {
sub.storedwg.Done()
}
@@ -121,6 +125,7 @@ func (sub *Subscription) dispatchEvent(evt *event.E) {
func (sub *Subscription) dispatchEose() {
if sub.eosed.CompareAndSwap(false, true) {
sub.match = sub.Filters.MatchIgnoringTimestampConstraints
go func() {
sub.storedwg.Wait()
sub.EndOfStoredEvents <- struct{}{}
@@ -128,62 +133,72 @@ func (sub *Subscription) dispatchEose() {
}
}
func (sub *Subscription) dispatchClosed(reason string) {
if sub.closed.CompareAndSwap(false, true) {
go func() {
sub.ClosedReason <- reason
}()
}
// handleClosed handles the CLOSED message from a relay.
func (sub *Subscription) handleClosed(reason string) {
go func() {
sub.ClosedReason <- reason
sub.live.Store(false) // set this so we don't send an unnecessary CLOSE to the relay
sub.unsub(fmt.Errorf("CLOSED received: %s", reason))
}()
}
// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. Unsub()
// also closes the channel sub.Events and makes a new one.
// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01.
// Unsub() also closes the channel sub.Events and makes a new one.
func (sub *Subscription) Unsub() {
sub.unsub(errors.New("Unsub() called"))
}
// unsub is the internal implementation of Unsub.
func (sub *Subscription) unsub(err error) {
// cancel the context (if it's not canceled already)
sub.cancel()
// mark the subscription as closed and send a CLOSE to the relay (naïve
// sync.Once implementation)
sub.cancel(err)
// mark subscription as closed and send a CLOSE to the relay (naïve sync.Once implementation)
if sub.live.CompareAndSwap(true, false) {
sub.Close()
}
// remove subscription from our map
sub.Relay.Subscriptions.Delete(sub.GetID().String())
sub.Relay.Subscriptions.Delete(sub.counter)
}
// Close just sends a CLOSE message. You probably want Unsub() instead.
func (sub *Subscription) Close() {
if sub.Relay.IsConnected() {
id := sub.GetID()
id, err := subscription.NewId(sub.id)
if err != nil {
return
}
closeMsg := closeenvelope.NewFrom(id)
var b []byte
b = closeMsg.Marshal(nil)
<-sub.Relay.Write(b)
closeb := closeMsg.Marshal(nil)
<-sub.Relay.Write(closeb)
}
}
// Sub sets sub.Filters and then calls sub.Fire(ctx). The subscription will be
// closed if the context expires.
func (sub *Subscription) Sub(_ context.T, ff *filters.T) {
// Sub sets sub.Filters and then calls sub.Fire(ctx).
// The subscription will be closed if the context expires.
func (sub *Subscription) Sub(_ context.Context, ff *filters.T) {
sub.Filters = ff
sub.Fire()
}
// Fire sends the "REQ" command to the relay.
func (sub *Subscription) Fire() (err error) {
id := sub.GetID()
var b []byte
if sub.countResult == nil {
b = reqenvelope.NewFrom(id, sub.Filters).Marshal(b)
} else {
b = countenvelope.NewRequest(id, sub.Filters).Marshal(b)
// if sub.countResult == nil {
req := reqenvelope.NewWithIdString(sub.id, sub.Filters)
if req == nil {
return fmt.Errorf("invalid ID or filters")
}
// log.T.F("{%s} sending %s", sub.Relay.URL, b)
reqb := req.Marshal(nil)
// } else
// if len(sub.Filters) == 1 {
// reqb, _ = CountEnvelope{sub.id, sub.Filters[0], nil, nil}.MarshalJSON()
// } else {
// return fmt.Errorf("unexpected sub configuration")
sub.live.Store(true)
if err = <-sub.Relay.Write(b); chk.T(err) {
sub.cancel()
return errorf.E("failed to write: %w", err)
if err = <-sub.Relay.Write(reqb); chk.E(err) {
err = fmt.Errorf("failed to write: %w", err)
sub.cancel(err)
}
return nil
return
}

View File

@@ -108,7 +108,7 @@ func TestBytesConcurrentAccess(t *testing.T) {
loaded := atom.Load()
// Verify the loaded data is valid (either our data or another goroutine's data)
require.LessOrEqual(t, len(loaded), parallelism,
require.LessOrEqual(t, len(loaded), parallelism,
"Loaded data length should not exceed parallelism")
// If it's our data, verify it's correct

View File

@@ -28,8 +28,10 @@ var (
TODO = context.TODO
// Value - context.WithValue
Value = context.WithValue
// CancelCause - context.WithCancelCause
CancelCause = context.WithCancelCause
// Cause - context.WithCancelCause
Cause = context.WithCancelCause
GetCause = context.Cause
// Canceled - context.Canceled
Canceled = context.Canceled
)

View File

@@ -5,8 +5,11 @@ import (
"time"
)
// PointerToValue is a generic interface to refer to any pointer to almost any kind of common
// type of value.
// PointerToValue is a generic interface (type constraint) to refer to any
// pointer to almost any kind of common type of value.
//
// see the utils/values package for a set of methods to accept these values and
// return the correct type pointer to them.
type PointerToValue interface {
~*uint | ~*int | ~*uint8 | ~*uint16 | ~*uint32 | ~*uint64 | ~*int8 | ~*int16 | ~*int32 |
~*int64 | ~*float32 | ~*float64 | ~*string | ~*[]string | ~*time.Time | ~*time.Duration |

101
pkg/utils/values/values.go Normal file
View File

@@ -0,0 +1,101 @@
package values
import (
"orly.dev/pkg/encoders/unix"
"time"
)
// ToUintPointer returns a pointer to the uint value passed in.
func ToUintPointer(v uint) *uint {
return &v
}
// ToIntPointer returns a pointer to the int value passed in.
func ToIntPointer(v int) *int {
return &v
}
// ToUint8Pointer returns a pointer to the uint8 value passed in.
func ToUint8Pointer(v uint8) *uint8 {
return &v
}
// ToUint16Pointer returns a pointer to the uint16 value passed in.
func ToUint16Pointer(v uint16) *uint16 {
return &v
}
// ToUint32Pointer returns a pointer to the uint32 value passed in.
func ToUint32Pointer(v uint32) *uint32 {
return &v
}
// ToUint64Pointer returns a pointer to the uint64 value passed in.
func ToUint64Pointer(v uint64) *uint64 {
return &v
}
// ToInt8Pointer returns a pointer to the int8 value passed in.
func ToInt8Pointer(v int8) *int8 {
return &v
}
// ToInt16Pointer returns a pointer to the int16 value passed in.
func ToInt16Pointer(v int16) *int16 {
return &v
}
// ToInt32Pointer returns a pointer to the int32 value passed in.
func ToInt32Pointer(v int32) *int32 {
return &v
}
// ToInt64Pointer returns a pointer to the int64 value passed in.
func ToInt64Pointer(v int64) *int64 {
return &v
}
// ToFloat32Pointer returns a pointer to the float32 value passed in.
func ToFloat32Pointer(v float32) *float32 {
return &v
}
// ToFloat64Pointer returns a pointer to the float64 value passed in.
func ToFloat64Pointer(v float64) *float64 {
return &v
}
// ToStringPointer returns a pointer to the string value passed in.
func ToStringPointer(v string) *string {
return &v
}
// ToStringSlicePointer returns a pointer to the []string value passed in.
func ToStringSlicePointer(v []string) *[]string {
return &v
}
// ToTimePointer returns a pointer to the time.Time value passed in.
func ToTimePointer(v time.Time) *time.Time {
return &v
}
// ToDurationPointer returns a pointer to the time.Duration value passed in.
func ToDurationPointer(v time.Duration) *time.Duration {
return &v
}
// ToBytesPointer returns a pointer to the []byte value passed in.
func ToBytesPointer(v []byte) *[]byte {
return &v
}
// ToByteSlicesPointer returns a pointer to the [][]byte value passed in.
func ToByteSlicesPointer(v [][]byte) *[][]byte {
return &v
}
// ToUnixTimePointer returns a pointer to the unix.Time value passed in.
func ToUnixTimePointer(v unix.Time) *unix.Time {
return &v
}

View File

@@ -1 +1 @@
v0.4.14
v0.5.0