From 9c8ff2976db5de9a01d612d9548864adb7603d5a Mon Sep 17 00:00:00 2001
From: mleku
Date: Wed, 6 Aug 2025 22:18:26 +0100
Subject: [PATCH] backporting relay client and pool from latest go-nostr

---
 cmd/nurl/main.go | 3 +-
 cmd/walletcli/main.go | 186 +--
 go.mod | 4 +-
 go.sum | 6 +-
 pkg/app/relay/spider-fetch.go | 15 +-
 pkg/database/get-indexes-for-event_test.go | 7 +-
 pkg/database/get-indexes-from-filter_test.go | 10 +-
 pkg/database/indexes/keys_test.go | 27 +-
 pkg/database/indexes/types/fullid_test.go | 6 +-
 pkg/database/indexes/types/identhash_test.go | 6 +-
 pkg/database/indexes/types/idhash_test.go | 6 +-
 pkg/database/indexes/types/letter_test.go | 6 +-
 pkg/database/indexes/types/pubhash_test.go | 6 +-
 pkg/database/indexes/types/timestamp_test.go | 12 +-
 pkg/database/indexes/types/uint16_test.go | 6 +-
 pkg/database/indexes/types/uint24_test.go | 7 +-
 pkg/database/indexes/types/uint32_test.go | 6 +-
 pkg/database/indexes/types/uint40_test.go | 7 +-
 pkg/database/indexes/types/uint64_test.go | 6 +-
 .../envelopes/reqenvelope/reqenvelope.go | 16 +-
 pkg/encoders/event/binary_test.go | 12 +-
 pkg/encoders/event/event_test.go | 3 +-
 pkg/encoders/event/json_tags_test.go | 6 +-
 pkg/encoders/event/json_whitespace_test.go | 3 +-
 pkg/encoders/filter/filter.go | 38 +-
 pkg/encoders/filters/filters.go | 9 +
 pkg/encoders/subscription/subscriptionid.go | 5 +-
 pkg/encoders/tags/tags.go | 14 +-
 pkg/encoders/varint/varint.go | 3 +-
 pkg/encoders/varint/varint_test.go | 6 +-
 pkg/protocol/nwc/client.go | 188 +--
 pkg/protocol/nwc/methods.go | 290 ++---
 pkg/protocol/nwc/uri.go | 8 +-
 pkg/protocol/ws/client.go | 656 ++++++-----
 pkg/protocol/ws/client_test.go | 285 ++---
 pkg/protocol/ws/connection.go | 260 ++---
 pkg/protocol/ws/listener.go | 5 +-
 pkg/protocol/ws/pool.go | 1038 ++++++++++++-----
 pkg/protocol/ws/pool_test.go | 216 ++++
 pkg/protocol/ws/subscription.go | 178 +--
 pkg/utils/context/context.go | 6 +-
 41 files changed, 1954 insertions(+), 1623 deletions(-)
 create mode 100644 pkg/protocol/ws/pool_test.go

diff --git a/cmd/nurl/main.go b/cmd/nurl/main.go
index cfe066a..a865e8e 100644
--- a/cmd/nurl/main.go
+++ b/cmd/nurl/main.go
@@ -8,6 +8,8 @@ import (
 	"io"
 	"net/http"
 	"net/url"
+	"os"
+
 	"orly.dev/pkg/crypto/p256k"
 	"orly.dev/pkg/crypto/sha256"
 	"orly.dev/pkg/encoders/bech32encoding"
@@ -18,7 +20,6 @@ import (
 	"orly.dev/pkg/utils/errorf"
 	"orly.dev/pkg/utils/log"
 	realy_lol "orly.dev/pkg/version"
-	"os"
 )
 
 const secEnv = "NOSTR_SECRET_KEY"
diff --git a/cmd/walletcli/main.go b/cmd/walletcli/main.go
index 0660ac6..59fc6ac 100644
--- a/cmd/walletcli/main.go
+++ b/cmd/walletcli/main.go
@@ -1,13 +1,13 @@
 package main
 
 import (
-	"encoding/json"
 	"fmt"
 	"os"
 	"strconv"
 	"strings"
 
 	"orly.dev/pkg/protocol/nwc"
+	"orly.dev/pkg/utils/chk"
 	"orly.dev/pkg/utils/context"
 )
 
@@ -45,21 +45,19 @@ func main() {
 		printUsage()
 		os.Exit(1)
 	}
-
 	connectionURL := os.Args[1]
 	method := os.Args[2]
 	args := os.Args[3:]
 
-	// Create context
+	// ctx, cancel := context.Cancel(context.Bg())
 	ctx := context.Bg()
-
+	// defer cancel()
 	// Create NWC client
 	client, err := nwc.NewClient(ctx, connectionURL)
 	if err != nil {
 		fmt.Printf("Error creating client: %v\n", err)
 		os.Exit(1)
 	}
-
 	// Execute the requested method
 	switch method {
 	case "get_wallet_service_info":
@@ -98,43 +96,27 @@ func main() {
 }
 
 func handleGetWalletServiceInfo(ctx context.T, client *nwc.Client) {
-	raw, err := client.GetWalletServiceInfoRaw(ctx)
-	if err != nil {
-		fmt.Printf("Error: %v\n", err)
-		return
+	if _, raw, err := client.GetWalletServiceInfo(ctx, true); !chk.E(err) {
+
fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleGetInfo(ctx context.T, client *nwc.Client) { - raw, err := client.GetInfoRaw(ctx) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + if _, raw, err := client.GetInfo(ctx, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleGetBalance(ctx context.T, client *nwc.Client) { - raw, err := client.GetBalanceRaw(ctx) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + if _, raw, err := client.GetBalance(ctx, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleGetBudget(ctx context.T, client *nwc.Client) { - raw, err := client.GetBudgetRaw(ctx) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + if _, raw, err := client.GetBudget(ctx, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleMakeInvoice(ctx context.T, client *nwc.Client, args []string) { @@ -143,25 +125,20 @@ func handleMakeInvoice(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli make_invoice [] [] []") return } - amount, err := strconv.ParseUint(args[0], 10, 64) if err != nil { fmt.Printf("Error parsing amount: %v\n", err) return } - params := &nwc.MakeInvoiceParams{ Amount: amount, } - if len(args) > 1 { params.Description = args[1] } - if len(args) > 2 { params.DescriptionHash = args[2] } - if len(args) > 3 { expiry, err := strconv.ParseInt(args[3], 10, 64) if err != nil { @@ -170,14 +147,10 @@ func handleMakeInvoice(ctx context.T, client *nwc.Client, args []string) { } params.Expiry = &expiry } - - raw, err := client.MakeInvoiceRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + if _, raw, err = client.MakeInvoice(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handlePayInvoice(ctx context.T, client *nwc.Client, args []string) { @@ -186,11 +159,9 @@ func handlePayInvoice(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli pay_invoice [] []") return } - params := &nwc.PayInvoiceParams{ Invoice: args[0], } - if len(args) > 1 { amount, err := strconv.ParseUint(args[1], 10, 64) if err != nil { @@ -199,21 +170,15 @@ func handlePayInvoice(ctx context.T, client *nwc.Client, args []string) { } params.Amount = &amount } - if len(args) > 2 { comment := args[2] params.Metadata = &nwc.PayInvoiceMetadata{ Comment: &comment, } } - - raw, err := client.PayInvoiceRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + if _, raw, err := client.PayInvoice(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleLookupInvoice(ctx context.T, client *nwc.Client, args []string) { @@ -222,9 +187,7 @@ func handleLookupInvoice(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli lookup_invoice ") return } - params := &nwc.LookupInvoiceParams{} - // Determine if the argument is a payment hash or an invoice if strings.HasPrefix(args[0], "ln") { invoice := args[0] @@ -233,19 +196,15 @@ func handleLookupInvoice(ctx context.T, client *nwc.Client, args []string) { paymentHash := args[0] params.PaymentHash = &paymentHash } - - raw, err := client.LookupInvoiceRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var err error + var raw []byte + if _, raw, err = client.LookupInvoice(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) 
} func handleListTransactions(ctx context.T, client *nwc.Client, args []string) { params := &nwc.ListTransactionsParams{} - if len(args) > 0 { limit, err := strconv.ParseUint(args[0], 10, 16) if err != nil { @@ -255,7 +214,6 @@ func handleListTransactions(ctx context.T, client *nwc.Client, args []string) { limitUint16 := uint16(limit) params.Limit = &limitUint16 } - if len(args) > 1 { offset, err := strconv.ParseUint(args[1], 10, 32) if err != nil { @@ -265,7 +223,6 @@ func handleListTransactions(ctx context.T, client *nwc.Client, args []string) { offsetUint32 := uint32(offset) params.Offset = &offsetUint32 } - if len(args) > 2 { from, err := strconv.ParseInt(args[2], 10, 64) if err != nil { @@ -274,7 +231,6 @@ func handleListTransactions(ctx context.T, client *nwc.Client, args []string) { } params.From = &from } - if len(args) > 3 { until, err := strconv.ParseInt(args[3], 10, 64) if err != nil { @@ -283,14 +239,11 @@ func handleListTransactions(ctx context.T, client *nwc.Client, args []string) { } params.Until = &until } - - raw, err := client.ListTransactionsRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + var err error + if _, raw, err = client.ListTransactions(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleMakeHoldInvoice(ctx context.T, client *nwc.Client, args []string) { @@ -299,26 +252,21 @@ func handleMakeHoldInvoice(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli make_hold_invoice [] [] []") return } - amount, err := strconv.ParseUint(args[0], 10, 64) if err != nil { fmt.Printf("Error parsing amount: %v\n", err) return } - params := &nwc.MakeHoldInvoiceParams{ Amount: amount, PaymentHash: args[1], } - if len(args) > 2 { params.Description = args[2] } - if len(args) > 3 { params.DescriptionHash = args[3] } - if len(args) > 4 { expiry, err := strconv.ParseInt(args[4], 10, 64) if err != nil { @@ -327,14 +275,10 @@ func handleMakeHoldInvoice(ctx context.T, client *nwc.Client, args []string) { } params.Expiry = &expiry } - - raw, err := client.MakeHoldInvoiceRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + if _, raw, err = client.MakeHoldInvoice(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleSettleHoldInvoice(ctx context.T, client *nwc.Client, args []string) { @@ -343,18 +287,14 @@ func handleSettleHoldInvoice(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli settle_hold_invoice ") return } - params := &nwc.SettleHoldInvoiceParams{ Preimage: args[0], } - - raw, err := client.SettleHoldInvoiceRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + var err error + if raw, err = client.SettleHoldInvoice(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleCancelHoldInvoice(ctx context.T, client *nwc.Client, args []string) { @@ -367,14 +307,11 @@ func handleCancelHoldInvoice(ctx context.T, client *nwc.Client, args []string) { params := &nwc.CancelHoldInvoiceParams{ PaymentHash: args[0], } - - raw, err := client.CancelHoldInvoiceRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var err error + var raw []byte + if raw, err = client.CancelHoldInvoice(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleSignMessage(ctx context.T, client *nwc.Client, 
args []string) { @@ -387,14 +324,11 @@ func handleSignMessage(ctx context.T, client *nwc.Client, args []string) { params := &nwc.SignMessageParams{ Message: args[0], } - - raw, err := client.SignMessageRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + var err error + if _, raw, err = client.SignMessage(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handlePayKeysend(ctx context.T, client *nwc.Client, args []string) { @@ -403,26 +337,21 @@ func handlePayKeysend(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli pay_keysend [] [ ...]") return } - pubkey := args[0] - amount, err := strconv.ParseUint(args[1], 10, 64) if err != nil { fmt.Printf("Error parsing amount: %v\n", err) return } - params := &nwc.PayKeysendParams{ Pubkey: pubkey, Amount: amount, } - // Optional preimage if len(args) > 2 { preimage := args[2] params.Preimage = &preimage } - // Optional TLV records (must come in pairs) if len(args) > 3 { // Start from index 3 and process pairs of arguments @@ -432,9 +361,7 @@ func handlePayKeysend(ctx context.T, client *nwc.Client, args []string) { fmt.Printf("Error parsing TLV type: %v\n", err) return } - tlvValue := args[i+1] - params.TLVRecords = append( params.TLVRecords, nwc.PayKeysendTLVRecord{ Type: uint32(tlvType), @@ -443,14 +370,10 @@ func handlePayKeysend(ctx context.T, client *nwc.Client, args []string) { ) } } - - raw, err := client.PayKeysendRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + if _, raw, err = client.PayKeysend(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) } func handleCreateConnection(ctx context.T, client *nwc.Client, args []string) { @@ -459,17 +382,14 @@ func handleCreateConnection(ctx context.T, client *nwc.Client, args []string) { fmt.Println("Usage: walletcli create_connection [] [] [] []") return } - params := &nwc.CreateConnectionParams{ Pubkey: args[0], Name: args[1], RequestMethods: strings.Split(args[2], ","), } - if len(args) > 3 { params.NotificationTypes = strings.Split(args[3], ",") } - if len(args) > 4 { maxAmount, err := strconv.ParseUint(args[4], 10, 64) if err != nil { @@ -478,11 +398,9 @@ func handleCreateConnection(ctx context.T, client *nwc.Client, args []string) { } params.MaxAmount = &maxAmount } - if len(args) > 5 { params.BudgetRenewal = &args[5] } - if len(args) > 6 { expiresAt, err := strconv.ParseInt(args[6], 10, 64) if err != nil { @@ -491,21 +409,9 @@ func handleCreateConnection(ctx context.T, client *nwc.Client, args []string) { } params.ExpiresAt = &expiresAt } - - raw, err := client.CreateConnectionRaw(ctx, params) - if err != nil { - fmt.Printf("Error: %v\n", err) - return + var raw []byte + var err error + if raw, err = client.CreateConnection(ctx, params, true); !chk.E(err) { + fmt.Println(string(raw)) } - - fmt.Println(string(raw)) -} - -func printJSON(v interface{}) { - data, err := json.MarshalIndent(v, "", " ") - if err != nil { - fmt.Printf("Error marshaling JSON: %v\n", err) - return - } - fmt.Println(string(data)) } diff --git a/go.mod b/go.mod index 5c913f2..06b223c 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,12 @@ go 1.24.2 require ( github.com/adrg/xdg v0.5.3 github.com/alexflint/go-arg v1.6.0 + github.com/coder/websocket v1.8.13 github.com/danielgtaylor/huma/v2 v2.34.1 github.com/davecgh/go-spew v1.1.1 github.com/dgraph-io/badger/v4 v4.7.0 github.com/fasthttp/websocket v1.5.12 
github.com/fatih/color v1.18.0 - github.com/gobwas/httphead v0.1.0 - github.com/gobwas/ws v1.4.0 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 github.com/klauspost/cpuid/v2 v2.2.11 github.com/minio/sha256-simd v1.0.1 @@ -41,7 +40,6 @@ require ( github.com/felixge/fgprof v0.9.5 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/gobwas/pool v0.2.1 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/google/pprof v0.0.0-20250630185457-6e76a2b096b5 // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/go.sum b/go.sum index 743cffa..a7d3a1d 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/danielgtaylor/huma/v2 v2.34.1 h1:EmOJAbzEGfy0wAq/QMQ1YKfEMBEfE94xdBRLPBP0gwQ= github.com/danielgtaylor/huma/v2 v2.34.1/go.mod h1:ynwJgLk8iGVgoaipi5tgwIQ5yoFNmiu+QdhU7CEEmhk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -44,13 +46,9 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= -github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= -github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= -github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= diff --git a/pkg/app/relay/spider-fetch.go b/pkg/app/relay/spider-fetch.go index e18eb3d..da2d4a0 100644 --- a/pkg/app/relay/spider-fetch.go +++ b/pkg/app/relay/spider-fetch.go @@ -1,6 +1,9 @@ package relay import ( + "runtime/debug" + "time" + "orly.dev/pkg/crypto/ec/schnorr" "orly.dev/pkg/database/indexes/types" "orly.dev/pkg/encoders/event" @@ -14,8 +17,6 @@ import ( "orly.dev/pkg/utils/context" "orly.dev/pkg/utils/errorf" "orly.dev/pkg/utils/log" - "runtime/debug" - "time" ) // IdPkTs is a map of event IDs to their id, pubkey, kind, and timestamp @@ -139,16 +140,12 @@ func (s *Server) SpiderFetch( default: } var evss event.S - var cli *ws.Client + var cli *ws.Relay if cli, err = ws.RelayConnect( - context.Bg(), seed, ws.WithSignatureChecker( - func(e *event.E) bool { - return true - 
}, - ), + context.Bg(), seed, ); chk.E(err) { err = nil - return + continue } if evss, err = cli.QuerySync( context.Bg(), batchFilter, diff --git a/pkg/database/get-indexes-for-event_test.go b/pkg/database/get-indexes-for-event_test.go index 530c73d..7d647b7 100644 --- a/pkg/database/get-indexes-for-event_test.go +++ b/pkg/database/get-indexes-for-event_test.go @@ -2,16 +2,16 @@ package database import ( "bytes" + "testing" + "orly.dev/pkg/database/indexes" types2 "orly.dev/pkg/database/indexes/types" - "orly.dev/pkg/encoders/codecbuf" "orly.dev/pkg/encoders/event" "orly.dev/pkg/encoders/kind" "orly.dev/pkg/encoders/tag" "orly.dev/pkg/encoders/tags" "orly.dev/pkg/encoders/timestamp" "orly.dev/pkg/utils/chk" - "testing" "github.com/minio/sha256-simd" ) @@ -26,8 +26,7 @@ func TestGetIndexesForEvent(t *testing.T) { // indexes func verifyIndexIncluded(t *testing.T, idxs [][]byte, expectedIdx *indexes.T) { // Marshal the expected index - buf := codecbuf.Get() - defer codecbuf.Put(buf) + buf := new(bytes.Buffer) err := expectedIdx.MarshalWrite(buf) if chk.E(err) { t.Fatalf("Failed to marshal expected index: %v", err) diff --git a/pkg/database/get-indexes-from-filter_test.go b/pkg/database/get-indexes-from-filter_test.go index 02452c2..66d3ced 100644 --- a/pkg/database/get-indexes-from-filter_test.go +++ b/pkg/database/get-indexes-from-filter_test.go @@ -3,16 +3,16 @@ package database import ( "bytes" "math" + "testing" + "orly.dev/pkg/database/indexes" types2 "orly.dev/pkg/database/indexes/types" - "orly.dev/pkg/encoders/codecbuf" "orly.dev/pkg/encoders/filter" "orly.dev/pkg/encoders/kind" "orly.dev/pkg/encoders/kinds" "orly.dev/pkg/encoders/tag" "orly.dev/pkg/encoders/timestamp" "orly.dev/pkg/utils/chk" - "testing" "github.com/minio/sha256-simd" ) @@ -41,8 +41,7 @@ func verifyIndex( } // Marshal the expected start index - startBuf := codecbuf.Get() - defer codecbuf.Put(startBuf) + startBuf := new(bytes.Buffer) err := expectedStartIdx.MarshalWrite(startBuf) if chk.E(err) { t.Fatalf("Failed to marshal expected start index: %v", err) @@ -62,8 +61,7 @@ func verifyIndex( } // Marshal the expected end index - endBuf := codecbuf.Get() - defer codecbuf.Put(endBuf) + endBuf := new(bytes.Buffer) err = endIdx.MarshalWrite(endBuf) if chk.E(err) { t.Fatalf("Failed to marshal expected End index: %v", err) diff --git a/pkg/database/indexes/keys_test.go b/pkg/database/indexes/keys_test.go index 8367052..7679270 100644 --- a/pkg/database/indexes/keys_test.go +++ b/pkg/database/indexes/keys_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "orly.dev/pkg/database/indexes/types" - "orly.dev/pkg/encoders/codecbuf" "orly.dev/pkg/utils/chk" "testing" ) @@ -49,7 +48,7 @@ func TestPrefixMethods(t *testing.T) { } // Test MarshalWrite method - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := prefix.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -209,7 +208,7 @@ func TestTStruct(t *testing.T) { } // Test MarshalWrite - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -272,7 +271,7 @@ func TestEventFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -318,7 +317,7 @@ func TestIdFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := enc.MarshalWrite(buf) if chk.E(err) { 
t.Fatalf("MarshalWrite failed: %v", err) @@ -391,7 +390,7 @@ func TestIdPubkeyFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -452,7 +451,7 @@ func TestCreatedAtFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -516,7 +515,7 @@ func TestPubkeyFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -588,7 +587,7 @@ func TestPubkeyTagFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -660,7 +659,7 @@ func TestTagFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -724,7 +723,7 @@ func TestKindFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -789,7 +788,7 @@ func TestKindTagFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -865,7 +864,7 @@ func TestKindPubkeyFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -941,7 +940,7 @@ func TestKindPubkeyTagFunctions(t *testing.T) { } // Test marshaling and unmarshaling - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = enc.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git a/pkg/database/indexes/types/fullid_test.go b/pkg/database/indexes/types/fullid_test.go index 244d380..4ee2dab 100644 --- a/pkg/database/indexes/types/fullid_test.go +++ b/pkg/database/indexes/types/fullid_test.go @@ -2,10 +2,10 @@ package types import ( "bytes" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "testing" + "orly.dev/pkg/utils/chk" + "github.com/minio/sha256-simd" ) @@ -55,7 +55,7 @@ func TestIdMarshalWriteUnmarshalRead(t *testing.T) { } // Test MarshalWrite - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = fi1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git a/pkg/database/indexes/types/identhash_test.go b/pkg/database/indexes/types/identhash_test.go index c95b241..58a9b7d 100644 --- a/pkg/database/indexes/types/identhash_test.go +++ b/pkg/database/indexes/types/identhash_test.go @@ -2,10 +2,10 @@ package types import ( "bytes" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "testing" + "orly.dev/pkg/utils/chk" + "github.com/minio/sha256-simd" ) @@ -45,7 +45,7 @@ func TestIdent_MarshalWriteUnmarshalRead(t *testing.T) { } // Test MarshalWrite - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = i1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git a/pkg/database/indexes/types/idhash_test.go 
b/pkg/database/indexes/types/idhash_test.go index e5436ec..7066155 100644 --- a/pkg/database/indexes/types/idhash_test.go +++ b/pkg/database/indexes/types/idhash_test.go @@ -3,10 +3,10 @@ package types import ( "bytes" "encoding/base64" - "orly.dev/pkg/encoders/codecbuf" + "testing" + "orly.dev/pkg/encoders/hex" "orly.dev/pkg/utils/chk" - "testing" "github.com/minio/sha256-simd" ) @@ -142,7 +142,7 @@ func TestIdHashMarshalWriteUnmarshalRead(t *testing.T) { } // Test MarshalWrite - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = i1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git a/pkg/database/indexes/types/letter_test.go b/pkg/database/indexes/types/letter_test.go index 99c8389..acaac39 100644 --- a/pkg/database/indexes/types/letter_test.go +++ b/pkg/database/indexes/types/letter_test.go @@ -2,9 +2,9 @@ package types import ( "bytes" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "testing" + + "orly.dev/pkg/utils/chk" ) func TestLetter_New(t *testing.T) { @@ -53,7 +53,7 @@ func TestLetter_MarshalWriteUnmarshalRead(t *testing.T) { l1 := new(Letter) l1.Set('A') // Test MarshalWrite - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := l1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git a/pkg/database/indexes/types/pubhash_test.go b/pkg/database/indexes/types/pubhash_test.go index e789227..0b5c95e 100644 --- a/pkg/database/indexes/types/pubhash_test.go +++ b/pkg/database/indexes/types/pubhash_test.go @@ -2,11 +2,11 @@ package types import ( "bytes" + "testing" + "orly.dev/pkg/crypto/ec/schnorr" - "orly.dev/pkg/encoders/codecbuf" "orly.dev/pkg/encoders/hex" "orly.dev/pkg/utils/chk" - "testing" "github.com/minio/sha256-simd" ) @@ -105,7 +105,7 @@ func TestPubHash_MarshalWriteUnmarshalRead(t *testing.T) { } // Test MarshalWrite - buf := codecbuf.Get() + buf := new(bytes.Buffer) err = ph1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git a/pkg/database/indexes/types/timestamp_test.go b/pkg/database/indexes/types/timestamp_test.go index 787f3c0..b84954e 100644 --- a/pkg/database/indexes/types/timestamp_test.go +++ b/pkg/database/indexes/types/timestamp_test.go @@ -2,10 +2,10 @@ package types import ( "bytes" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "testing" "time" + + "orly.dev/pkg/utils/chk" ) func TestTimestamp_FromInt(t *testing.T) { @@ -89,7 +89,7 @@ func TestTimestamp_FromBytes(t *testing.T) { v.Set(12345) // Marshal it to bytes - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := v.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -163,7 +163,7 @@ func TestTimestamp_Bytes(t *testing.T) { func TestTimestamp_MarshalWriteUnmarshalRead(t *testing.T) { // Test with a positive value ts1 := &Timestamp{val: 12345} - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := ts1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -183,7 +183,7 @@ func TestTimestamp_MarshalWriteUnmarshalRead(t *testing.T) { // Test with a negative value ts1 = &Timestamp{val: -12345} - buf = codecbuf.Get() + buf = new(bytes.Buffer) err = ts1.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) @@ -225,7 +225,7 @@ func TestTimestamp_WithCurrentTime(t *testing.T) { } // Test MarshalWrite and UnmarshalRead - buf := codecbuf.Get() + buf := new(bytes.Buffer) err := ts.MarshalWrite(buf) if chk.E(err) { t.Fatalf("MarshalWrite failed: %v", err) diff --git 
a/pkg/database/indexes/types/uint16_test.go b/pkg/database/indexes/types/uint16_test.go index 85b746a..0e071c6 100644 --- a/pkg/database/indexes/types/uint16_test.go +++ b/pkg/database/indexes/types/uint16_test.go @@ -3,11 +3,11 @@ package types import ( "bytes" "math" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "reflect" "testing" + "orly.dev/pkg/utils/chk" + "lukechampine.com/frand" ) @@ -44,7 +44,7 @@ func TestUint16(t *testing.T) { } // Test encoding to []byte and decoding back - bufEnc := codecbuf.Get() + bufEnc := new(bytes.Buffer) // MarshalWrite err := encodedUint16.MarshalWrite(bufEnc) diff --git a/pkg/database/indexes/types/uint24_test.go b/pkg/database/indexes/types/uint24_test.go index 5db14fa..173a1a7 100644 --- a/pkg/database/indexes/types/uint24_test.go +++ b/pkg/database/indexes/types/uint24_test.go @@ -1,10 +1,11 @@ package types import ( - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" + "bytes" "reflect" "testing" + + "orly.dev/pkg/utils/chk" ) func TestUint24(t *testing.T) { @@ -45,7 +46,7 @@ func TestUint24(t *testing.T) { } // Test MarshalWrite and UnmarshalRead - buf := codecbuf.Get() + buf := new(bytes.Buffer) // MarshalWrite directly to the buffer if err := codec.MarshalWrite(buf); chk.E(err) { diff --git a/pkg/database/indexes/types/uint32_test.go b/pkg/database/indexes/types/uint32_test.go index 68e72b8..fc5a62d 100644 --- a/pkg/database/indexes/types/uint32_test.go +++ b/pkg/database/indexes/types/uint32_test.go @@ -3,11 +3,11 @@ package types import ( "bytes" "math" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "reflect" "testing" + "orly.dev/pkg/utils/chk" + "lukechampine.com/frand" ) @@ -43,7 +43,7 @@ func TestUint32(t *testing.T) { } // Test encoding to []byte and decoding back - bufEnc := codecbuf.Get() + bufEnc := new(bytes.Buffer) // MarshalWrite err := codec.MarshalWrite(bufEnc) diff --git a/pkg/database/indexes/types/uint40_test.go b/pkg/database/indexes/types/uint40_test.go index 46c7843..52cd575 100644 --- a/pkg/database/indexes/types/uint40_test.go +++ b/pkg/database/indexes/types/uint40_test.go @@ -1,10 +1,11 @@ package types import ( - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" + "bytes" "reflect" "testing" + + "orly.dev/pkg/utils/chk" ) func TestUint40(t *testing.T) { @@ -48,7 +49,7 @@ func TestUint40(t *testing.T) { } // Test MarshalWrite and UnmarshalRead - buf := codecbuf.Get() + buf := new(bytes.Buffer) // Marshal to a buffer if err = codec.MarshalWrite(buf); chk.E(err) { diff --git a/pkg/database/indexes/types/uint64_test.go b/pkg/database/indexes/types/uint64_test.go index 214ca86..f7e5c27 100644 --- a/pkg/database/indexes/types/uint64_test.go +++ b/pkg/database/indexes/types/uint64_test.go @@ -3,11 +3,11 @@ package types import ( "bytes" "math" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "reflect" "testing" + "orly.dev/pkg/utils/chk" + "lukechampine.com/frand" ) @@ -43,7 +43,7 @@ func TestUint64(t *testing.T) { } // Test encoding to []byte and decoding back - bufEnc := codecbuf.Get() + bufEnc := new(bytes.Buffer) // MarshalWrite err := codec.MarshalWrite(bufEnc) diff --git a/pkg/encoders/envelopes/reqenvelope/reqenvelope.go b/pkg/encoders/envelopes/reqenvelope/reqenvelope.go index d4baeb1..c026434 100644 --- a/pkg/encoders/envelopes/reqenvelope/reqenvelope.go +++ b/pkg/encoders/envelopes/reqenvelope/reqenvelope.go @@ -4,6 +4,7 @@ package reqenvelope import ( "io" + "orly.dev/pkg/encoders/envelopes" "orly.dev/pkg/encoders/filters" 
"orly.dev/pkg/encoders/subscription" @@ -37,10 +38,21 @@ func New() *T { // NewFrom creates a new reqenvelope.T with a provided subscription.Id and // filters.T. -func NewFrom(id *subscription.Id, filters *filters.T) *T { +func NewFrom(id *subscription.Id, ff *filters.T) *T { return &T{ Subscription: id, - Filters: filters, + Filters: ff, + } +} + +func NewWithIdString(id string, ff *filters.T) (sub *T) { + sid, err := subscription.NewId(id) + if err != nil { + return + } + return &T{ + Subscription: sid, + Filters: ff, } } diff --git a/pkg/encoders/event/binary_test.go b/pkg/encoders/event/binary_test.go index 7d48ed9..69e3bd6 100644 --- a/pkg/encoders/event/binary_test.go +++ b/pkg/encoders/event/binary_test.go @@ -3,11 +3,11 @@ package event import ( "bufio" "bytes" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/encoders/event/examples" - "orly.dev/pkg/utils/chk" "testing" "time" + + "orly.dev/pkg/encoders/event/examples" + "orly.dev/pkg/utils/chk" ) func TestTMarshalBinary_UnmarshalBinary(t *testing.T) { @@ -19,7 +19,7 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) { var counter int for scanner.Scan() { // Create new event objects and buffer for each iteration - buf := codecbuf.Get() + buf := new(bytes.Buffer) ea, eb := New(), New() chk.E(scanner.Err()) @@ -42,7 +42,6 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) { // Create a new buffer for unmarshaling buf2 := bytes.NewBuffer(buf.Bytes()) if err = eb.UnmarshalBinary(buf2); chk.E(err) { - codecbuf.Put(buf) t.Fatal(err) } @@ -57,9 +56,6 @@ func TestTMarshalBinary_UnmarshalBinary(t *testing.T) { ) } - // Return buffer to pool - codecbuf.Put(buf) - counter++ out = out[:0] } diff --git a/pkg/encoders/event/event_test.go b/pkg/encoders/event/event_test.go index fc063f6..1c3147e 100644 --- a/pkg/encoders/event/event_test.go +++ b/pkg/encoders/event/event_test.go @@ -4,10 +4,11 @@ import ( "bufio" "bytes" _ "embed" + "testing" + "orly.dev/pkg/crypto/p256k" "orly.dev/pkg/encoders/event/examples" "orly.dev/pkg/utils/chk" - "testing" ) func TestTMarshal_Unmarshal(t *testing.T) { diff --git a/pkg/encoders/event/json_tags_test.go b/pkg/encoders/event/json_tags_test.go index 56bc689..4942f86 100644 --- a/pkg/encoders/event/json_tags_test.go +++ b/pkg/encoders/event/json_tags_test.go @@ -2,12 +2,13 @@ package event import ( "bytes" + "testing" + "orly.dev/pkg/encoders/kind" "orly.dev/pkg/encoders/tag" "orly.dev/pkg/encoders/tags" text2 "orly.dev/pkg/encoders/text" "orly.dev/pkg/encoders/timestamp" - "testing" ) // compareTags compares two tags and reports any differences @@ -96,7 +97,8 @@ func TestUnmarshalEscapedJSONInTags(t *testing.T) { unmarshaledTag := unmarshaledEvent.Tags.GetTagElement(0) if unmarshaledTag.Len() != 2 { t.Fatalf( - "Expected tag with 2 elements, got %d", unmarshaledTag.Len(), + "Expected tag with 2 elements, got %d", + unmarshaledTag.Len(), ) } diff --git a/pkg/encoders/event/json_whitespace_test.go b/pkg/encoders/event/json_whitespace_test.go index fae167f..0a7ba12 100644 --- a/pkg/encoders/event/json_whitespace_test.go +++ b/pkg/encoders/event/json_whitespace_test.go @@ -2,11 +2,12 @@ package event import ( "bytes" + "testing" + "orly.dev/pkg/encoders/hex" "orly.dev/pkg/encoders/kind" "orly.dev/pkg/encoders/tags" "orly.dev/pkg/encoders/timestamp" - "testing" ) // compareEvents compares two events and reports any differences diff --git a/pkg/encoders/filter/filter.go b/pkg/encoders/filter/filter.go index ac796b5..8a85ef2 100644 --- a/pkg/encoders/filter/filter.go +++ b/pkg/encoders/filter/filter.go 
@@ -7,6 +7,8 @@ package filter import ( "bytes" "encoding/binary" + "sort" + "orly.dev/pkg/crypto/ec/schnorr" "orly.dev/pkg/crypto/ec/secp256k1" "orly.dev/pkg/crypto/sha256" @@ -22,7 +24,6 @@ import ( "orly.dev/pkg/utils/chk" "orly.dev/pkg/utils/errorf" "orly.dev/pkg/utils/pointers" - "sort" "lukechampine.com/frand" ) @@ -440,8 +441,9 @@ invalid: return } -// Matches checks a filter against an event and determines if the event matches the filter. -func (f *F) Matches(ev *event.E) bool { +// MatchesIgnoringTimestampConstraints checks a filter against an event and +// determines if the event matches the filter, ignoring timestamp constraints.. +func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool { if ev == nil { // log.F.ToSliceOfBytes("nil event") return false @@ -461,22 +463,30 @@ func (f *F) Matches(ev *event.E) bool { if f.Tags.Len() > 0 && !ev.Tags.Intersects(f.Tags) { return false } - // if f.Tags.Len() > 0 { - // for _, v := range f.Tags.ToSliceOfTags() { - // tvs := v.ToSliceOfBytes() - // if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) { - // return false - // } - // } - // return false - // } + if f.Tags.Len() > 0 { + for _, v := range f.Tags.ToSliceOfTags() { + tvs := v.ToSliceOfBytes() + if !ev.Tags.ContainsAny(v.FilterKey(), tag.New(tvs...)) { + return false + } + } + return false + } + return true +} + +// Matches checks a filter against an event and determines if the event matches the filter. +func (f *F) Matches(ev *event.E) (match bool) { + if !f.MatchesIgnoringTimestampConstraints(ev) { + return + } if f.Since.Int() != 0 && ev.CreatedAt.I64() < f.Since.I64() { // log.F.ToSliceOfBytes("event is older than since\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String()) - return false + return } if f.Until.Int() != 0 && ev.CreatedAt.I64() > f.Until.I64() { // log.F.ToSliceOfBytes("event is newer than until\nEVENT %s\nFILTER %s", ev.ToObject().String(), f.ToObject().String()) - return false + return } return true } diff --git a/pkg/encoders/filters/filters.go b/pkg/encoders/filters/filters.go index 9f15528..5ee1bae 100644 --- a/pkg/encoders/filters/filters.go +++ b/pkg/encoders/filters/filters.go @@ -120,3 +120,12 @@ func GenFilters(n int) (ff *T, err error) { } return } + +func (f *T) MatchIgnoringTimestampConstraints(ev *event.E) bool { + for _, ff := range f.F { + if ff.MatchesIgnoringTimestampConstraints(ev) { + return true + } + } + return false +} diff --git a/pkg/encoders/subscription/subscriptionid.go b/pkg/encoders/subscription/subscriptionid.go index 989363b..767c958 100644 --- a/pkg/encoders/subscription/subscriptionid.go +++ b/pkg/encoders/subscription/subscriptionid.go @@ -5,6 +5,7 @@ package subscription import ( "crypto/rand" + "orly.dev/pkg/crypto/ec/bech32" "orly.dev/pkg/encoders/text" "orly.dev/pkg/utils/chk" @@ -24,7 +25,7 @@ func (si *Id) IsValid() bool { return len(si.T) <= 64 && len(si.T) > 0 } // NewId inspects a string and converts to Id if it is // valid. Invalid means length == 0 or length > 64. -func NewId[V string | []byte](s V) (*Id, error) { +func NewId[V ~string | ~[]byte](s V) (*Id, error) { si := &Id{T: []byte(s)} if si.IsValid() { return si, nil @@ -40,7 +41,7 @@ func NewId[V string | []byte](s V) (*Id, error) { // MustNew is the same as NewId except it doesn't check if you feed it rubbish. 
// // DO NOT USE WITHOUT CHECKING THE Id IS NOT NIL AND > 0 AND <= 64 -func MustNew[V string | []byte](s V) *Id { +func MustNew[V ~string | ~[]byte](s V) *Id { return &Id{T: []byte(s)} } diff --git a/pkg/encoders/tags/tags.go b/pkg/encoders/tags/tags.go index 308be2f..a85644d 100644 --- a/pkg/encoders/tags/tags.go +++ b/pkg/encoders/tags/tags.go @@ -7,12 +7,13 @@ import ( "encoding/json" "errors" "fmt" + "os" + "sort" + "orly.dev/pkg/encoders/tag" "orly.dev/pkg/utils/chk" "orly.dev/pkg/utils/log" "orly.dev/pkg/utils/lol" - "os" - "sort" ) // T is a list of tag.T - which are lists of string elements with ordering and no uniqueness @@ -161,6 +162,15 @@ func (t *T) GetFirst(tagPrefix *tag.T) *tag.T { return nil } +func (t *T) GetD() (d string) { + for _, v := range t.element { + if bytes.Equal(v.Key(), []byte("d")) { + return string(v.Value()) + } + } + return +} + // GetLast gets the last tag in tags that matches the prefix, see [T.StartsWith] func (t *T) GetLast(tagPrefix *tag.T) *tag.T { for i := len(t.element) - 1; i >= 0; i-- { diff --git a/pkg/encoders/varint/varint.go b/pkg/encoders/varint/varint.go index eaf1de8..2fd4ab9 100644 --- a/pkg/encoders/varint/varint.go +++ b/pkg/encoders/varint/varint.go @@ -5,8 +5,9 @@ package varint import ( - "golang.org/x/exp/constraints" "io" + + "golang.org/x/exp/constraints" "orly.dev/pkg/utils/chk" ) diff --git a/pkg/encoders/varint/varint_test.go b/pkg/encoders/varint/varint_test.go index b35e021..9b82838 100644 --- a/pkg/encoders/varint/varint_test.go +++ b/pkg/encoders/varint/varint_test.go @@ -3,10 +3,10 @@ package varint import ( "bytes" "math" - "orly.dev/pkg/encoders/codecbuf" - "orly.dev/pkg/utils/chk" "testing" + "orly.dev/pkg/utils/chk" + "lukechampine.com/frand" ) @@ -14,7 +14,7 @@ func TestEncode_Decode(t *testing.T) { var v uint64 for range 10000000 { v = uint64(frand.Intn(math.MaxInt64)) - buf1 := codecbuf.Get() + buf1 := new(bytes.Buffer) Encode(buf1, v) buf2 := bytes.NewBuffer(buf1.Bytes()) u, err := Decode(buf2) diff --git a/pkg/protocol/nwc/client.go b/pkg/protocol/nwc/client.go index 93bc3de..c1681b3 100644 --- a/pkg/protocol/nwc/client.go +++ b/pkg/protocol/nwc/client.go @@ -10,6 +10,7 @@ import ( "orly.dev/pkg/encoders/event" "orly.dev/pkg/encoders/filter" "orly.dev/pkg/encoders/filters" + "orly.dev/pkg/encoders/hex" "orly.dev/pkg/encoders/kind" "orly.dev/pkg/encoders/kinds" "orly.dev/pkg/encoders/tag" @@ -19,12 +20,13 @@ import ( "orly.dev/pkg/protocol/ws" "orly.dev/pkg/utils/chk" "orly.dev/pkg/utils/context" + "orly.dev/pkg/utils/log" "orly.dev/pkg/utils/values" ) type Client struct { - pool *ws.Pool - relays []string + client *ws.Client + relay string clientSecretKey signer.I walletPublicKey []byte conversationKey []byte // nip44 @@ -66,9 +68,13 @@ func NewClient(c context.T, connectionURI string) (cl *Client, err error) { ); chk.E(err) { return } + var relay *ws.Client + if relay, err = ws.RelayConnect(c, parts.relay); chk.E(err) { + return + } cl = &Client{ - pool: ws.NewPool(c), - relays: parts.relays, + client: relay, + relay: parts.relay, clientSecretKey: clientKey, walletPublicKey: parts.walletPublicKey, conversationKey: ck, @@ -81,14 +87,9 @@ type rpcOptions struct { } func (cl *Client) RPC( - c context.T, method Capability, params, result any, opts *rpcOptions, -) (err error) { - timeout := time.Duration(10) - if opts != nil && opts.timeout != nil { - timeout = *opts.timeout - } - ctx, cancel := context.Timeout(c, timeout) - defer cancel() + c context.T, method Capability, params, result any, noUnmarshal bool, + opts 
*rpcOptions, +) (raw []byte, err error) { var req []byte if req, err = json.Marshal( Request{ @@ -107,155 +108,58 @@ func (cl *Client) RPC( CreatedAt: timestamp.Now(), Kind: kind.WalletRequest, Tags: tags.New( - tag.New([]byte("p"), cl.walletPublicKey), + tag.New("p", hex.Enc(cl.walletPublicKey)), tag.New(EncryptionTag, Nip44V2), ), } if err = ev.Sign(cl.clientSecretKey); chk.E(err) { return } - hasWorked := make(chan struct{}) - evs := cl.pool.SubMany( - c, cl.relays, &filters.T{ - F: []*filter.F{ - { - Limit: values.ToUintPointer(1), - Kinds: kinds.New(kind.WalletRequest), - Authors: tag.New(cl.walletPublicKey), - Tags: tags.New(tag.New([]byte("#e"), ev.ID)), - }, - }, - }, - ) - for _, u := range cl.relays { - go func(u string) { - var relay *ws.Client - if relay, err = cl.pool.EnsureRelay(u); chk.E(err) { - return - } - if err = relay.Publish(c, ev); chk.E(err) { - return - } - select { - case hasWorked <- struct{}{}: - case <-ctx.Done(): - err = fmt.Errorf("context canceled waiting for request send") - return - default: - } - }(u) + var ok bool + if ok, err = ev.Verify(); chk.E(err) { } - select { - case <-hasWorked: - // continue - case <-ctx.Done(): - err = fmt.Errorf("timed out waiting for relays") + log.I.F("verify: %v", ok) + var rc *ws.Client + if rc, err = ws.RelayConnect(c, cl.relay); chk.E(err) { return } + defer rc.Close() + var sub *ws.Subscription + if sub, err = rc.Subscribe( + c, filters.New( + &filter.F{ + Limit: values.ToUintPointer(1), + Kinds: kinds.New(kind.WalletRequest), + Authors: tag.New(cl.clientSecretKey.Pub()), + Tags: tags.New(tag.New([]byte("#e"), ev.ID)), + }, + ), + ); chk.E(err) { + return + } + defer sub.Unsub() + if err = rc.Publish(context.Bg(), ev); chk.E(err) { + return + } + log.I.F("published event %s", ev.Marshal(nil)) select { - case <-ctx.Done(): + case <-c.Done(): err = fmt.Errorf("context canceled waiting for response") - case e := <-evs: - var plain []byte - if plain, err = encryption.Decrypt( - e.Event.Content, cl.conversationKey, + case e := <-sub.Events: + if raw, err = encryption.Decrypt( + e.Content, cl.conversationKey, ); chk.E(err) { return } + if noUnmarshal { + return + } resp := &Response{ Result: &result, } - if err = json.Unmarshal(plain, resp); chk.E(err) { + if err = json.Unmarshal(raw, resp); chk.E(err) { return } - return - } - return -} - -// RPCRaw performs an RPC call and returns the raw JSON response -func (cl *Client) RPCRaw( - c context.T, method Capability, params any, opts *rpcOptions, -) (rawResponse []byte, err error) { - timeout := time.Duration(10) - if opts != nil && opts.timeout != nil { - timeout = *opts.timeout - } - ctx, cancel := context.Timeout(c, timeout) - defer cancel() - var req []byte - if req, err = json.Marshal( - Request{ - Method: string(method), - Params: params, - }, - ); chk.E(err) { - return - } - var content []byte - if content, err = encryption.Encrypt(req, cl.conversationKey); chk.E(err) { - return - } - ev := &event.E{ - Content: content, - CreatedAt: timestamp.Now(), - Kind: kind.WalletRequest, - Tags: tags.New( - tag.New([]byte("p"), cl.walletPublicKey), - tag.New(EncryptionTag, Nip44V2), - ), - } - if err = ev.Sign(cl.clientSecretKey); chk.E(err) { - return - } - hasWorked := make(chan struct{}) - evs := cl.pool.SubMany( - c, cl.relays, &filters.T{ - F: []*filter.F{ - { - Limit: values.ToUintPointer(1), - Kinds: kinds.New(kind.WalletRequest), - Authors: tag.New(cl.walletPublicKey), - Tags: tags.New(tag.New([]byte("#e"), ev.ID)), - }, - }, - }, - ) - for _, u := range cl.relays { - go 
func(u string) { - var relay *ws.Client - if relay, err = cl.pool.EnsureRelay(u); chk.E(err) { - return - } - if err = relay.Publish(c, ev); chk.E(err) { - return - } - select { - case hasWorked <- struct{}{}: - case <-ctx.Done(): - err = fmt.Errorf("context canceled waiting for request send") - return - default: - } - }(u) - } - select { - case <-hasWorked: - // continue - case <-ctx.Done(): - err = fmt.Errorf("timed out waiting for relays") - return - } - select { - case <-ctx.Done(): - err = fmt.Errorf("context canceled waiting for response") - case e := <-evs: - if rawResponse, err = encryption.Decrypt( - e.Event.Content, cl.conversationKey, - ); chk.E(err) { - return - } - return } return } diff --git a/pkg/protocol/nwc/methods.go b/pkg/protocol/nwc/methods.go index 3374f8d..5ef2905 100644 --- a/pkg/protocol/nwc/methods.go +++ b/pkg/protocol/nwc/methods.go @@ -2,42 +2,55 @@ package nwc import ( "bytes" - "encoding/json" "fmt" + "time" "orly.dev/pkg/encoders/filter" "orly.dev/pkg/encoders/filters" "orly.dev/pkg/encoders/kind" "orly.dev/pkg/encoders/kinds" "orly.dev/pkg/encoders/tag" + "orly.dev/pkg/protocol/ws" "orly.dev/pkg/utils/chk" "orly.dev/pkg/utils/context" + "orly.dev/pkg/utils/values" ) -func (cl *Client) GetWalletServiceInfo(c context.T) ( - wsi *WalletServiceInfo, err error, +func (cl *Client) GetWalletServiceInfo(c context.T, noUnmarshal bool) ( + wsi *WalletServiceInfo, raw []byte, err error, ) { - lim := uint(1) - evc := cl.pool.SubMany( - c, cl.relays, &filters.T{ - F: []*filter.F{ - { - Limit: &lim, - Kinds: kinds.New(kind.WalletInfo), - Authors: tag.New(cl.walletPublicKey), - }, + timeout := 10 * time.Second + ctx, cancel := context.Timeout(c, timeout) + defer cancel() + var rc *ws.Client + if rc, err = ws.RelayConnect(c, cl.relay); chk.E(err) { + return + } + if err = rc.Connect(c); chk.E(err) { + return + } + var sub *ws.Subscription + if sub, err = rc.Subscribe( + ctx, filters.New( + &filter.F{ + Limit: values.ToUintPointer(1), + Kinds: kinds.New(kind.WalletRequest), + Authors: tag.New(cl.walletPublicKey), }, - }, - ) + ), + ); chk.E(err) { + return + } + defer sub.Unsub() select { case <-c.Done(): err = fmt.Errorf("GetWalletServiceInfo canceled") return - case ev := <-evc: + case ev := <-sub.Events: var encryptionTypes []EncryptionType var notificationTypes []NotificationType - encryptionTag := ev.Event.Tags.GetFirst(tag.New("encryption")) - notificationsTag := ev.Event.Tags.GetFirst(tag.New("notifications")) + encryptionTag := ev.Tags.GetFirst(tag.New("encryption")) + notificationsTag := ev.Tags.GetFirst(tag.New("notifications")) if encryptionTag != nil { et := encryptionTag.ToSliceOfBytes() encType := bytes.Split(et[0], []byte(" ")) @@ -52,7 +65,7 @@ func (cl *Client) GetWalletServiceInfo(c context.T) ( notificationTypes = append(notificationTypes, e) } } - cp := bytes.Split(ev.Event.Content, []byte(" ")) + cp := bytes.Split(ev.Content, []byte(" ")) var capabilities []Capability for _, capability := range cp { capabilities = append(capabilities, capability) @@ -66,190 +79,72 @@ func (cl *Client) GetWalletServiceInfo(c context.T) ( return } -func (cl *Client) GetWalletServiceInfoRaw(c context.T) ( - raw []byte, err error, -) { - lim := uint(1) - evc := cl.pool.SubMany( - c, cl.relays, &filters.T{ - F: []*filter.F{ - { - Limit: &lim, - Kinds: kinds.New(kind.WalletInfo), - Authors: tag.New(cl.walletPublicKey), - }, - }, - }, - ) - select { - case <-c.Done(): - err = fmt.Errorf("GetWalletServiceInfoRaw canceled") - return - case ev := <-evc: - // Marshal the event to 
JSON - if raw, err = json.Marshal(ev.Event); chk.E(err) { - return - } - } - return -} - func (cl *Client) CancelHoldInvoice( - c context.T, chi *CancelHoldInvoiceParams, -) (err error) { - if err = cl.RPC(c, CancelHoldInvoice, chi, nil, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) CancelHoldInvoiceRaw( - c context.T, chi *CancelHoldInvoiceParams, + c context.T, chi *CancelHoldInvoiceParams, noUnmarshal bool, ) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, CancelHoldInvoice, chi, nil); chk.E(err) { - return - } - return + return cl.RPC(c, CancelHoldInvoice, chi, nil, noUnmarshal, nil) } func (cl *Client) CreateConnection( - c context.T, cc *CreateConnectionParams, -) (err error) { - if err = cl.RPC(c, CreateConnection, cc, nil, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) CreateConnectionRaw( - c context.T, cc *CreateConnectionParams, + c context.T, cc *CreateConnectionParams, noUnmarshal bool, ) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, CreateConnection, cc, nil); chk.E(err) { - return - } - return + return cl.RPC(c, CreateConnection, cc, nil, noUnmarshal, nil) } -func (cl *Client) GetBalance(c context.T) (gb *GetBalanceResult, err error) { +func (cl *Client) GetBalance(c context.T, noUnmarshal bool) ( + gb *GetBalanceResult, raw []byte, err error, +) { gb = &GetBalanceResult{} - if err = cl.RPC(c, GetBalance, nil, gb, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, GetBalance, nil, gb, noUnmarshal, nil) return } -func (cl *Client) GetBalanceRaw(c context.T) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, GetBalance, nil, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) GetBudget(c context.T) (gb *GetBudgetResult, err error) { +func (cl *Client) GetBudget(c context.T, noUnmarshal bool) ( + gb *GetBudgetResult, raw []byte, err error, +) { gb = &GetBudgetResult{} - if err = cl.RPC(c, GetBudget, nil, gb, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, GetBudget, nil, gb, noUnmarshal, nil) return } -func (cl *Client) GetBudgetRaw(c context.T) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, GetBudget, nil, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) GetInfo(c context.T) (gi *GetInfoResult, err error) { +func (cl *Client) GetInfo(c context.T, noUnmarshal bool) ( + gi *GetInfoResult, raw []byte, err error, +) { gi = &GetInfoResult{} - if err = cl.RPC(c, GetInfo, nil, gi, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) GetInfoRaw(c context.T) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, GetInfo, nil, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, GetInfo, nil, gi, noUnmarshal, nil) return } func (cl *Client) ListTransactions( - c context.T, params *ListTransactionsParams, -) (lt *ListTransactionsResult, err error) { + c context.T, params *ListTransactionsParams, noUnmarshal bool, +) (lt *ListTransactionsResult, raw []byte, err error) { lt = &ListTransactionsResult{} - if err = cl.RPC(c, ListTransactions, params, <, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) ListTransactionsRaw( - c context.T, params *ListTransactionsParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, ListTransactions, params, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, ListTransactions, params, <, noUnmarshal, nil) return } func (cl *Client) LookupInvoice( - c context.T, params *LookupInvoiceParams, -) (li *LookupInvoiceResult, err error) { + c context.T, params *LookupInvoiceParams, noUnmarshal bool, +) (li 
*LookupInvoiceResult, raw []byte, err error) { li = &LookupInvoiceResult{} - if err = cl.RPC(c, LookupInvoice, params, &li, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) LookupInvoiceRaw( - c context.T, params *LookupInvoiceParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, LookupInvoice, params, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, LookupInvoice, params, &li, noUnmarshal, nil) return } func (cl *Client) MakeHoldInvoice( c context.T, - mhi *MakeHoldInvoiceParams, -) (mi *MakeInvoiceResult, err error) { + mhi *MakeHoldInvoiceParams, noUnmarshal bool, +) (mi *MakeInvoiceResult, raw []byte, err error) { mi = &MakeInvoiceResult{} - if err = cl.RPC(c, MakeHoldInvoice, mhi, mi, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) MakeHoldInvoiceRaw( - c context.T, - mhi *MakeHoldInvoiceParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, MakeHoldInvoice, mhi, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, MakeHoldInvoice, mhi, mi, noUnmarshal, nil) return } func (cl *Client) MakeInvoice( - c context.T, params *MakeInvoiceParams, -) (mi *MakeInvoiceResult, err error) { + c context.T, params *MakeInvoiceParams, noUnmarshal bool, +) (mi *MakeInvoiceResult, raw []byte, err error) { mi = &MakeInvoiceResult{} - if err = cl.RPC(c, MakeInvoice, params, &mi, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) MakeInvoiceRaw( - c context.T, params *MakeInvoiceParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, MakeInvoice, params, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, MakeInvoice, params, &mi, noUnmarshal, nil) return } @@ -258,76 +153,31 @@ func (cl *Client) MakeInvoiceRaw( // MultiPayKeysend func (cl *Client) PayKeysend( - c context.T, params *PayKeysendParams, -) (pk *PayKeysendResult, err error) { + c context.T, params *PayKeysendParams, noUnmarshal bool, +) (pk *PayKeysendResult, raw []byte, err error) { pk = &PayKeysendResult{} - if err = cl.RPC(c, PayKeysend, params, &pk, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) PayKeysendRaw( - c context.T, params *PayKeysendParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, PayKeysend, params, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, PayKeysend, params, &pk, noUnmarshal, nil) return } func (cl *Client) PayInvoice( - c context.T, params *PayInvoiceParams, -) (pi *PayInvoiceResult, err error) { + c context.T, params *PayInvoiceParams, noUnmarshal bool, +) (pi *PayInvoiceResult, raw []byte, err error) { pi = &PayInvoiceResult{} - if err = cl.RPC(c, PayInvoice, params, &pi, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) PayInvoiceRaw( - c context.T, params *PayInvoiceParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, PayInvoice, params, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, PayInvoice, params, &pi, noUnmarshal, nil) return } func (cl *Client) SettleHoldInvoice( - c context.T, shi *SettleHoldInvoiceParams, -) (err error) { - if err = cl.RPC(c, SettleHoldInvoice, shi, nil, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) SettleHoldInvoiceRaw( - c context.T, shi *SettleHoldInvoiceParams, + c context.T, shi *SettleHoldInvoiceParams, noUnmarshal bool, ) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, SettleHoldInvoice, shi, nil); chk.E(err) { - return - } - return + return cl.RPC(c, SettleHoldInvoice, shi, nil, noUnmarshal, nil) } func (cl *Client) SignMessage( - c context.T, sm *SignMessageParams, -) (res 
*SignMessageResult, err error) { + c context.T, sm *SignMessageParams, noUnmarshal bool, +) (res *SignMessageResult, raw []byte, err error) { res = &SignMessageResult{} - if err = cl.RPC(c, SignMessage, sm, &res, nil); chk.E(err) { - return - } - return -} - -func (cl *Client) SignMessageRaw( - c context.T, sm *SignMessageParams, -) (raw []byte, err error) { - if raw, err = cl.RPCRaw(c, SignMessage, sm, nil); chk.E(err) { - return - } + raw, err = cl.RPC(c, SignMessage, sm, &res, noUnmarshal, nil) return } diff --git a/pkg/protocol/nwc/uri.go b/pkg/protocol/nwc/uri.go index 4571fce..9fa38bf 100644 --- a/pkg/protocol/nwc/uri.go +++ b/pkg/protocol/nwc/uri.go @@ -11,7 +11,7 @@ import ( type ConnectionParams struct { clientSecretKey []byte walletPublicKey []byte - relays []string + relay string } // GetWalletPublicKey returns the wallet public key from the ConnectionParams. @@ -35,13 +35,15 @@ func ParseConnectionURI(nwcUri string) (parts *ConnectionParams, err error) { } query := p.Query() var ok bool - if parts.relays, ok = query["relay"]; !ok { + var relay []string + if relay, ok = query["relay"]; !ok { err = errors.New("missing relay parameter") return } - if len(parts.relays) == 0 { + if len(relay) == 0 { return nil, errors.New("no relays") } + parts.relay = relay[0] var secret string if secret = query.Get("secret"); secret == "" { err = errors.New("missing secret parameter") diff --git a/pkg/protocol/ws/client.go b/pkg/protocol/ws/client.go index f451866..a100686 100644 --- a/pkg/protocol/ws/client.go +++ b/pkg/protocol/ws/client.go @@ -3,11 +3,19 @@ package ws import ( "bytes" "crypto/tls" + "errors" + "fmt" "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/puzpuzpuz/xsync/v3" "orly.dev/pkg/encoders/envelopes" "orly.dev/pkg/encoders/envelopes/authenvelope" "orly.dev/pkg/encoders/envelopes/closedenvelope" - "orly.dev/pkg/encoders/envelopes/countenvelope" "orly.dev/pkg/encoders/envelopes/eoseenvelope" "orly.dev/pkg/encoders/envelopes/eventenvelope" "orly.dev/pkg/encoders/envelopes/noticeenvelope" @@ -16,53 +24,41 @@ import ( "orly.dev/pkg/encoders/filter" "orly.dev/pkg/encoders/filters" "orly.dev/pkg/encoders/kind" + "orly.dev/pkg/encoders/tag" + "orly.dev/pkg/encoders/tags" + "orly.dev/pkg/encoders/timestamp" "orly.dev/pkg/interfaces/signer" - "orly.dev/pkg/protocol/auth" - "orly.dev/pkg/utils/atomic" "orly.dev/pkg/utils/chk" "orly.dev/pkg/utils/context" - "orly.dev/pkg/utils/errorf" "orly.dev/pkg/utils/log" "orly.dev/pkg/utils/normalize" - "sync" - "time" - - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" - "github.com/puzpuzpuz/xsync/v3" ) -var subscriptionIDCounter atomic.Int32 +var subscriptionIDCounter atomic.Int64 -type Client struct { +// Relay represents a connection to a Nostr relay. +type Relay struct { closeMutex sync.Mutex - URL string + URL string + requestHeader http.Header // e.g. for origin header - RequestHeader http.Header // e.g. 
for origin header + Connection *Connection + Subscriptions *xsync.MapOf[int64, *Subscription] - Connection *Connection - - Subscriptions *xsync.MapOf[string, *Subscription] - - ConnectionError error - - connectionContext context.T // will be canceled when the connection closes - - connectionContextCancel context.F - - challenge []byte // NIP-42 challenge, we only keep the last - - notices chan []byte // NIP-01 NOTICEs - - okCallbacks *xsync.MapOf[string, func(bool, string)] - - writeQueue chan writeRequest + ConnectionError error + connectionContext context.T // will be canceled when the connection closes + connectionContextCancel context.C + challenge []byte // NIP-42 challenge, we only keep the last + noticeHandler func(string) // NIP-01 NOTICEs + customHandler func(string) // nonstandard unparseable messages + okCallbacks *xsync.MapOf[string, func(bool, string)] + writeQueue chan writeRequest subscriptionChannelCloseQueue chan *Subscription - signatureChecker func(*event.E) bool - + // custom things that aren't often used + // AssumeValid bool // this will skip verifying signatures for events received from this relay } @@ -71,21 +67,20 @@ type writeRequest struct { answer chan error } -// NewRelay returns a new relay. The relay connection will be closed when the -// context is cancelled. -func NewRelay(c context.T, url string, opts ...RelayOption) *Client { - ctx, cancel := context.Cancel(c) - r := &Client{ - URL: string(normalize.URL([]byte(url))), +// NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection. +func NewRelay(ctx context.T, url string, opts ...RelayOption) *Relay { + ctx, cancel := context.Cause(ctx) + r := &Relay{ + URL: string(normalize.URL(url)), connectionContext: ctx, connectionContextCancel: cancel, - Subscriptions: xsync.NewMapOf[string, *Subscription](), + Subscriptions: xsync.NewMapOf[int64, *Subscription](), okCallbacks: xsync.NewMapOf[string, func( bool, string, )](), writeQueue: make(chan writeRequest), subscriptionChannelCloseQueue: make(chan *Subscription), - signatureChecker: func(e *event.E) bool { ok, _ := e.Verify(); return ok }, + requestHeader: nil, } for _, opt := range opts { @@ -95,301 +90,323 @@ func NewRelay(c context.T, url string, opts ...RelayOption) *Client { return r } -// RelayConnect returns a relay object connected to url. Once successfully -// connected, cancelling ctx has no effect. To close the connection, call -// r.Close(). +// RelayConnect returns a relay object connected to url. +// +// The given subscription is only used during the connection phase. Once successfully connected, cancelling ctx has no effect. +// +// The ongoing relay connection uses a background context. To close the connection, call r.Close(). +// If you need fine grained long-term connection contexts, use NewRelay() instead. func RelayConnect(ctx context.T, url string, opts ...RelayOption) ( - *Client, error, + *Relay, error, ) { r := NewRelay(context.Bg(), url, opts...) err := r.Connect(ctx) return r, err } -// RelayOption is the type of the argument passed for that. +// RelayOption is the type of the argument passed when instantiating relay connections. type RelayOption interface { - ApplyRelayOption(*Client) + ApplyRelayOption(*Relay) } var ( _ RelayOption = (WithNoticeHandler)(nil) - _ RelayOption = (WithSignatureChecker)(nil) + _ RelayOption = (WithCustomHandler)(nil) + _ RelayOption = (WithRequestHeader)(nil) ) -// WithNoticeHandler just takes notices and is expected to do something with -// them. 
when not given, defaults to logging the notices. -type WithNoticeHandler func(notice []byte) +// WithNoticeHandler just takes notices and is expected to do something with them. +// when not given, defaults to logging the notices. +type WithNoticeHandler func(notice string) -func (nh WithNoticeHandler) ApplyRelayOption(r *Client) { - r.notices = make(chan []byte) - go func() { - for notice := range r.notices { - nh(notice) - } - }() +func (nh WithNoticeHandler) ApplyRelayOption(r *Relay) { + r.noticeHandler = nh } -// WithSignatureChecker must be a function that checks the signature of an event -// and returns true or false. -type WithSignatureChecker func(*event.E) bool +// WithCustomHandler must be a function that handles any relay message that couldn't be +// parsed as a standard envelope. +type WithCustomHandler func(data string) -func (sc WithSignatureChecker) ApplyRelayOption(r *Client) { - r.signatureChecker = sc +func (ch WithCustomHandler) ApplyRelayOption(r *Relay) { + r.customHandler = ch +} + +// WithRequestHeader sets the HTTP request header of the websocket preflight request. +type WithRequestHeader http.Header + +func (ch WithRequestHeader) ApplyRelayOption(r *Relay) { + r.requestHeader = http.Header(ch) } // String just returns the relay URL. -func (r *Client) String() string { +func (r *Relay) String() string { return r.URL } // Context retrieves the context that is associated with this relay connection. -func (r *Client) Context() context.T { return r.connectionContext } +// It will be closed when the relay is disconnected. +func (r *Relay) Context() context.T { return r.connectionContext } // IsConnected returns true if the connection to this relay seems to be active. -func (r *Client) IsConnected() bool { return r.connectionContext.Err() == nil } +func (r *Relay) IsConnected() bool { return r.connectionContext.Err() == nil } -// Connect tries to establish a websocket connection to r.URL. If the context -// expires before the connection is complete, an error is returned. Once -// successfully connected, context expiration has no effect: call r.Close to -// close the connection. +// Connect tries to establish a websocket connection to r.URL. +// If the context expires before the connection is complete, an error is returned. +// Once successfully connected, context expiration has no effect: call r.Close +// to close the connection. // -// The underlying relay connection will use a background context. If you want to -// pass a custom context to the underlying relay connection, use NewRelay() and -// then Client.Connect(). -func (r *Client) Connect(c context.T) error { return r.ConnectWithTLS(c, nil) } +// The given context here is only used during the connection phase. The long-living +// relay connection will be based on the context given to NewRelay(). +func (r *Relay) Connect(ctx context.T) error { + return r.ConnectWithTLS(ctx, nil) +} -// ConnectWithTLS tries to establish a secured websocket connection to r.URL -// using customized tls.Config (CA's, etc.). -func (r *Client) ConnectWithTLS(ctx context.T, tlsConfig *tls.Config) error { +func subIdToSerial(subId string) int64 { + n := strings.Index(subId, ":") + if n < 0 || n > len(subId) { + return -1 + } + serialId, _ := strconv.ParseInt(subId[0:n], 10, 64) + return serialId +} + +// ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that. 
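+// As with Connect, if ctx carries no deadline a 7-second timeout is imposed on the dial.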
+func (r *Relay) ConnectWithTLS( + ctx context.T, tlsConfig *tls.Config, +) (err error) { if r.connectionContext == nil || r.Subscriptions == nil { - return errorf.E("relay must be initialized with a call to NewRelay()") + return fmt.Errorf("relay must be initialized with a call to NewRelay()") } + if r.URL == "" { - return errorf.E("invalid relay URL '%s'", r.URL) + return fmt.Errorf("invalid relay URL '%s'", r.URL) } + if _, ok := ctx.Deadline(); !ok { // if no timeout is set, force it to 7 seconds var cancel context.F - ctx, cancel = context.Timeout(ctx, 7*time.Second) + ctx, cancel = context.TimeoutCause( + ctx, 7*time.Second, errors.New("connection took too long"), + ) defer cancel() } - conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig) - if err != nil { - return errorf.E( - "error opening websocket to '%s': %s", r.URL, err.Error(), - ) + var conn *Connection + if conn, err = NewConnection( + ctx, r.URL, r.requestHeader, tlsConfig, + ); chk.E(err) { + err = fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) + return } r.Connection = conn - // ping every 29 seconds (??) + // ping every 29 seconds ticker := time.NewTicker(29 * time.Second) - // to be used when the connection is closed - go func() { - <-r.connectionContext.Done() - // close these things when the connection is closed - if r.notices != nil { - close(r.notices) - } - // stop the ticker - ticker.Stop() - // close all subscriptions - r.Subscriptions.Range( - func(_ string, sub *Subscription) bool { - go sub.Unsub() - return true - }, - ) - }() // queue all write operations here so we don't do mutex spaghetti go func() { - var err error for { select { + case <-r.connectionContext.Done(): + ticker.Stop() + r.Connection = nil + for _, sub := range r.Subscriptions.Range { + sub.unsub( + fmt.Errorf( + "relay connection closed: %w / %w", + context.GetCause(r.connectionContext), + r.ConnectionError, + ), + ) + } + return case <-ticker.C: - err = wsutil.WriteClientMessage( - r.Connection.conn, ws.OpPing, nil, - ) - if err != nil { - log.D.F( + err := r.Connection.Ping(r.connectionContext) + if err != nil && !strings.Contains( + err.Error(), "failed to wait for pong", + ) { + log.I.F( "{%s} error writing ping: %v; closing websocket", r.URL, err, ) r.Close() // this should trigger a context cancelation return } - case writeReq := <-r.writeQueue: + case writeRequest := <-r.writeQueue: // all write requests will go through this to prevent races - if err = r.Connection.WriteMessage( - r.connectionContext, - writeReq.msg, - ); chk.T(err) { - writeReq.answer <- err + log.D.F("{%s} sending %v\n", r.URL, string(writeRequest.msg)) + if err := r.Connection.WriteMessage( + r.connectionContext, writeRequest.msg, + ); err != nil { + writeRequest.answer <- err } - close(writeReq.answer) - case <-r.connectionContext.Done(): - // stop here - return + close(writeRequest.answer) } } }() // general message reader loop go func() { - var err error + for { buf := new(bytes.Buffer) - if err = conn.ReadMessage(r.connectionContext, buf); err != nil { + if err := conn.ReadMessage(r.connectionContext, buf); err != nil { r.ConnectionError = err - r.Close() + r.close(err) break } - message := buf.Bytes() - // log.D.F("{%s} %s\n", r.URL, message) - + var err error var t string - if t, message, err = envelopes.Identify(message); chk.E(err) { + var rem []byte + if t, rem, err = envelopes.Identify(buf.Bytes()); chk.E(err) { continue } switch t { case noticeenvelope.L: - env := noticeenvelope.New() - if env, message, err = 
noticeenvelope.Parse(message); chk.E(err) { - continue - } + env := noticeenvelope.NewFrom(rem) // see WithNoticeHandler - if r.notices != nil { - r.notices <- env.Message + if r.noticeHandler != nil { + r.noticeHandler(string(env.Message)) } else { - log.E.F("NOTICE from %s: '%s'\n", r.URL, env.Message) + log.D.F( + "NOTICE from %s: '%s'\n", r.URL, string(env.Message), + ) } case authenvelope.L: - env := authenvelope.NewChallenge() - if env, message, err = authenvelope.ParseChallenge(message); chk.E(err) { - continue - } - if len(env.Challenge) == 0 { + env := authenvelope.NewChallengeWith(rem) + if env.Challenge == nil { continue } r.challenge = env.Challenge case eventenvelope.L: - // log.I.F("message: %s", message) - env := eventenvelope.NewResult() - if env, message, err = eventenvelope.ParseResult(message); err != nil { + var env *eventenvelope.Result + var ev *event.E + if env, err = eventenvelope.NewResultWith(rem, ev); chk.E(err) { continue } - // log.I.F("%s", env.Event.Marshal(nil)) - if len(env.Subscription.T) == 0 { + sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String())) + if !ok { + log.W.F( + "unknown subscription with id '%s'\n", + env.Subscription.String(), + ) continue } - if sub, ok := r.Subscriptions.Load(env.Subscription.String()); !ok { - // log.D.F( - // "{%s} no subscription with id '%s'\n", r.URL, - // env.Subscription, - // ) + if !sub.Filters.Match(env.Event) { + log.I.F( + "{%s} filter does not match: %v ~ %v\n", r.URL, + sub.Filters, env.Event, + ) continue - } else { - // check if the event matches the desired filter, ignore - // otherwise - if !sub.Filters.Match(env.Event) { - log.D.F( - "{%s} filter does not match: %v ~ %v\n", r.URL, - sub.Filters, env.Event, + } + if !r.AssumeValid { + if ok, err = env.Event.Verify(); !ok || chk.E(err) { + log.I.F( + "{%s} bad signature on %s\n", r.URL, + env.Event.ID, ) continue } - // check signature, ignore invalid, except from trusted - // (AssumeValid) relays - if !r.AssumeValid { - if ok = r.signatureChecker(env.Event); !ok { - log.E.F( - "{%s} bad signature on %s\n", r.URL, - env.Event.ID, - ) - continue - } - } - // dispatch this to the internal .events channel of the - // subscription - sub.dispatchEvent(env.Event) } + sub.dispatchEvent(env.Event) case eoseenvelope.L: - env := eoseenvelope.New() - if env, message, err = eoseenvelope.Parse(message); chk.E(err) { + var env *eoseenvelope.T + if env, rem, err = eoseenvelope.Parse(rem); chk.E(err) { continue } - if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok { - subscription.dispatchEose() + if len(rem) != 0 { + log.W.F( + "{%s} unexpected data after EOSE: %s\n", r.URL, + string(rem), + ) } + sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String())) + if !ok { + log.W.F( + "unknown subscription with id '%s'\n", + env.Subscription.String(), + ) + continue + } + sub.dispatchEose() case closedenvelope.L: - env := closedenvelope.New() - if env, message, err = closedenvelope.Parse(message); chk.E(err) { + var env *closedenvelope.T + if env, rem, err = closedenvelope.Parse(rem); chk.E(err) { continue } - if subscription, ok := r.Subscriptions.Load(env.Subscription.String()); ok { - subscription.dispatchClosed(env.ReasonString()) - } - case countenvelope.L: - env := countenvelope.NewResponse() - if env, message, err = countenvelope.Parse(message); chk.E(err) { + sub, ok := r.Subscriptions.Load(subIdToSerial(env.Subscription.String())) + if !ok { + log.W.F( + "unknown subscription with id '%s'\n", + 
env.Subscription.String(), + ) continue } - if subscription, ok := r.Subscriptions.Load(env.ID.String()); ok && subscription.countResult != nil { - subscription.countResult <- env.Count - } + sub.handleClosed(env.ReasonString()) case okenvelope.L: - env := okenvelope.New() - if env, message, err = okenvelope.Parse(message); chk.E(err) { + var env *okenvelope.T + if env, rem, err = okenvelope.Parse(rem); chk.E(err) { continue } if okCallback, exist := r.okCallbacks.Load(env.EventID.String()); exist { - okCallback(env.OK, env.ReasonString()) + okCallback(env.OK, string(env.Reason)) } else { log.I.F( "{%s} got an unexpected OK message for event %s", r.URL, env.EventID, ) } + default: + log.W.F("unknown envelope type %s\n%s", t, rem) + continue } } }() - return nil + return } -// Write queues a message to be sent to the relay. -func (r *Client) Write(msg []byte) <-chan error { +// Write queues an arbitrary message to be sent to the relay. +func (r *Relay) Write(msg []byte) <-chan error { ch := make(chan error) select { case r.writeQueue <- writeRequest{msg: msg, answer: ch}: case <-r.connectionContext.Done(): - go func() { ch <- errorf.E("connection closed") }() + go func() { ch <- fmt.Errorf("connection closed") }() } return ch } // Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an // OK response. -func (r *Client) Publish(c context.T, ev *event.E) error { - return r.publish( - c, ev, - ) +func (r *Relay) Publish(ctx context.T, ev *event.E) error { + return r.publish(ctx, ev.ID, ev) } // Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK // response. -func (r *Client) Auth(c context.T, sign signer.I) error { - authEvent := auth.CreateUnsigned(sign.Pub(), r.challenge, r.URL) - if err := authEvent.Sign(sign); chk.T(err) { - return errorf.E("error signing auth event: %w", err) +func (r *Relay) Auth( + ctx context.T, sign signer.I, +) (err error) { + authEvent := &event.E{ + CreatedAt: timestamp.Now(), + Kind: kind.ClientAuthentication, + Tags: tags.New( + tag.New("relay", r.URL), + tag.New([]byte("challenge"), r.challenge), + ), } - return r.publish(c, authEvent) + if err = authEvent.Sign(sign); chk.E(err) { + err = fmt.Errorf("error signing auth event: %w", err) + return + } + return r.publish(ctx, authEvent.ID, authEvent) } -// publish can be used both for EVENT and for AUTH -func (r *Client) publish(ctx context.T, ev *event.E) (err error) { +func (r *Relay) publish( + ctx context.T, id []byte, env *event.E, +) error { + var err error var cancel context.F if _, ok := ctx.Deadline(); !ok { // if no timeout is set, force it to 7 seconds ctx, cancel = context.TimeoutCause( - ctx, 7*time.Second, - errorf.E("given up waiting for an OK"), + ctx, 7*time.Second, fmt.Errorf("given up waiting for an OK"), ) defer cancel() } else { @@ -398,32 +415,23 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) { ctx, cancel = context.Cancel(ctx) defer cancel() } + // listen for an OK callback gotOk := false - id := ev.IdString() + ids := string(id) r.okCallbacks.Store( - id, func(ok bool, reason string) { + ids, func(ok bool, reason string) { gotOk = true if !ok { - err = errorf.E("msg: %s", reason) + err = fmt.Errorf("msg: %s", reason) } cancel() }, ) - defer r.okCallbacks.Delete(id) + defer r.okCallbacks.Delete(ids) // publish event - var b []byte - if ev.Kind.Equal(kind.ClientAuthentication) { - if b = authenvelope.NewResponseWith(ev).Marshal(b); chk.E(err) { - return - } - } else { - if b = eventenvelope.NewSubmissionWith(ev).Marshal(b); 
chk.E(err) { - return - } - } - log.T.F("{%s} sending %s\n", r.URL, b) - if err = <-r.Write(b); chk.T(err) { + envb := env.Marshal(nil) + if err = <-r.Write(envb); err != nil { return err } for { @@ -447,111 +455,168 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) { // context ctx is cancelled ("CLOSE" in NIP-01). // // Remember to cancel subscriptions, either by calling `.Unsub()` on them or -// ensuring their `context.Context` will be canceled at some point. Failure to +// ensuring their `context.T` will be canceled at some point. Failure to // do that will result in a huge number of halted goroutines being created. -func (r *Client) Subscribe( - c context.T, ff *filters.T, - opts ...SubscriptionOption, +func (r *Relay) Subscribe( + ctx context.T, ff *filters.T, opts ...SubscriptionOption, ) (sub *Subscription, err error) { - sub = r.PrepareSubscription(c, ff, opts...) + sub = r.PrepareSubscription(ctx, ff, opts...) if r.Connection == nil { - return nil, errorf.E("not connected to %s", r.URL) + return nil, fmt.Errorf("not connected to %s", r.URL) } - if err = sub.Fire(); chk.T(err) { - return nil, errorf.E( - "couldn't subscribe to %v at %s: %w", ff, r.URL, err, + if err = sub.Fire(); err != nil { + err = fmt.Errorf( + "couldn't subscribe to %v at %s: %w", ff.Marshal(nil), r.URL, err, ) + return } return } // PrepareSubscription creates a subscription, but doesn't fire it. // -// Remember to cancel subscriptions, either by calling `.Unsub()` on them or -// ensuring their `context.Context` will be canceled at some point. Failure to -// do that will result in a huge number of halted goroutines being created. -func (r *Client) PrepareSubscription( - c context.T, ff *filters.T, - opts ...SubscriptionOption, +// Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.T` will be canceled at some point. +// Failure to do that will result in a huge number of halted goroutines being created. +func (r *Relay) PrepareSubscription( + ctx context.T, ff *filters.T, opts ...SubscriptionOption, ) *Subscription { current := subscriptionIDCounter.Add(1) - c, cancel := context.Cancel(c) + ctx, cancel := context.Cause(ctx) sub := &Subscription{ Relay: r, - Context: c, + Context: ctx, cancel: cancel, - counter: int(current), + counter: current, Events: make(event.C), EndOfStoredEvents: make(chan struct{}, 1), ClosedReason: make(chan string, 1), Filters: ff, + match: ff.Match, } + label := "" for _, opt := range opts { switch o := opt.(type) { case WithLabel: - sub.label = string(o) + label = string(o) + // case WithCheckDuplicate: + // sub.checkDuplicate = o + // case WithCheckDuplicateReplaceable: + // sub.checkDuplicateReplaceable = o } } - id := sub.GetID() - r.Subscriptions.Store(id.String(), sub) + // subscription id computation + buf := subIdPool.Get().([]byte)[:0] + buf = strconv.AppendInt(buf, sub.counter, 10) + buf = append(buf, ':') + buf = append(buf, label...) + defer subIdPool.Put(buf) + sub.id = string(buf) + // we track subscriptions only by their counter, no need for the full id + r.Subscriptions.Store(int64(sub.counter), sub) // start handling events, eose, unsub etc: go sub.start() return sub } -// QuerySync is only used in tests. The relay query method is synchronous now -// anyway (it ensures sort order is respected). 
-func (r *Client) QuerySync( - ctx context.T, f *filter.F, - opts ...SubscriptionOption, -) ([]*event.E, error) { - // log.T.F("QuerySync:\n%s", f.Marshal(nil)) - sub, err := r.Subscribe(ctx, filters.New(f), opts...) +// QueryEvents subscribes to events matching the given filter and returns a channel of events. +// +// In most cases it's better to use Pool instead of this method. +func (r *Relay) QueryEvents(ctx context.T, f *filter.F) ( + evc event.C, err error, +) { + var sub *Subscription + if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) { + return + } + go func() { + for { + select { + case <-sub.ClosedReason: + case <-sub.EndOfStoredEvents: + case <-ctx.Done(): + case <-r.Context().Done(): + } + sub.unsub(errors.New("QueryEvents() ended")) + return + } + }() + return sub.Events, nil +} + +// QuerySync subscribes to events matching the given filter and returns a slice +// of events. This method blocks until all events are received or the context is +// canceled. +// +// If the filter causes a subscription to open, it will stay open until the +// limit is exceeded. So this method will return an error if the limit is nil. +// If the query blocks, the caller needs to cancel the context to prevent the +// thread stalling. +func (r *Relay) QuerySync(ctx context.T, f *filter.F) ( + evs event.S, err error, +) { + if f.Limit == nil { + err = errors.New("limit must be set for a sync query to prevent blocking") + return + } + var sub *Subscription + if sub, err = r.Subscribe(ctx, filters.New(f)); chk.E(err) { + return + } + defer sub.unsub(errors.New("QuerySync() ended")) + evs = make(event.S, 0, *f.Limit) + if _, ok := ctx.Deadline(); !ok { + // if no timeout is set, force it to 7 seconds + var cancel context.F + ctx, cancel = context.TimeoutCause( + ctx, 7*time.Second, errors.New("QuerySync() took too long"), + ) + defer cancel() + } + lim := 250 + if f.Limit != nil { + lim = int(*f.Limit) + } + events := make(event.S, 0, max(lim, 250)) + ch, err := r.QueryEvents(ctx, f) if err != nil { return nil, err } - defer sub.Unsub() - - if _, ok := ctx.Deadline(); !ok { - // if no timeout is set, force it to 7 seconds - var cancel context.F - ctx, cancel = context.Timeout(ctx, 7*time.Second) - defer cancel() + for evt := range ch { + events = append(events, evt) } - var events []*event.E - for { - select { - case evt := <-sub.Events: - if evt == nil { - // channel is closed - return events, nil - } - events = append(events, evt) - case <-sub.EndOfStoredEvents: - return events, nil - case <-ctx.Done(): - return events, nil - } - } + return events, nil } -// TODO: count is a dumb idea anyway, and nothing is using this -// func (r *Client) Count(c context.F, ff *filters.F, opts ...SubscriptionOption) (int, error) { -// sub := r.PrepareSubscription(c, ff, opts...) -// sub.countResult = make(chan int) -// -// if err := sub.Fire(); chk.F(err) { -// return 0, err +// // Count sends a "COUNT" command to the relay and returns the count of events matching the filters. +// func (r *Relay) Count( +// ctx context.T, +// filters Filters, +// opts ...SubscriptionOption, +// ) (int64, []byte, error) { +// v, err := r.countInternal(ctx, filters, opts...) +// if err != nil { +// return 0, nil, err // } // -// defer sub.Unsub() +// return *v.Count, v.HyperLogLog, nil +// } // -// if _, ok := c.Deadline(); !ok { +// func (r *Relay) countInternal(ctx context.T, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) { +// sub := r.PrepareSubscription(ctx, filters, opts...) 
+// sub.countResult = make(chan CountEnvelope) +// +// if err := sub.Fire(); err != nil { +// return CountEnvelope{}, err +// } +// +// defer sub.unsub(errors.New("countInternal() ended")) +// +// if _, ok := ctx.Deadline(); !ok { // // if no timeout is set, force it to 7 seconds -// var cancel context.F -// c, cancel = context.Timeout(c, 7*time.Second) +// var cancel context.CancelFunc +// ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("countInternal took too long")) // defer cancel() // } // @@ -559,28 +624,39 @@ func (r *Client) QuerySync( // select { // case count := <-sub.countResult: // return count, nil -// case <-c.Done(): -// return 0, c.Err() +// case <-ctx.Done(): +// return CountEnvelope{}, ctx.Err() // } // } // } -// Close shuts down a websocket client connection. -func (r *Client) Close() error { +// Close closes the relay connection. +func (r *Relay) Close() error { + return r.close(errors.New("relay connection closed")) +} + +func (r *Relay) close(reason error) error { r.closeMutex.Lock() defer r.closeMutex.Unlock() + if r.connectionContextCancel == nil { - return errorf.E("relay already closed") + return fmt.Errorf("relay already closed") } - r.connectionContextCancel() + r.connectionContextCancel(reason) r.connectionContextCancel = nil + if r.Connection == nil { - return errorf.E("relay not connected") + return fmt.Errorf("relay not connected") } + err := r.Connection.Close() - r.Connection = nil if err != nil { return err } + return nil } + +var subIdPool = sync.Pool{ + New: func() any { return make([]byte, 0, 15) }, +} diff --git a/pkg/protocol/ws/client_test.go b/pkg/protocol/ws/client_test.go index 92e89a3..a25d563 100644 --- a/pkg/protocol/ws/client_test.go +++ b/pkg/protocol/ws/client_test.go @@ -1,146 +1,159 @@ package ws import ( - "bytes" "context" - "encoding/json" "errors" - "fmt" "io" "net/http" "net/http/httptest" - "orly.dev/pkg/crypto/p256k" - "orly.dev/pkg/encoders/envelopes/eventenvelope" - "orly.dev/pkg/encoders/envelopes/okenvelope" - "orly.dev/pkg/encoders/event" - "orly.dev/pkg/encoders/kind" - "orly.dev/pkg/encoders/tag" - "orly.dev/pkg/encoders/tags" - "orly.dev/pkg/encoders/timestamp" - "orly.dev/pkg/utils/chk" - "orly.dev/pkg/utils/normalize" "sync" "testing" "time" + "orly.dev/pkg/crypto/p256k" + "orly.dev/pkg/encoders/event" + "orly.dev/pkg/encoders/kind" + "orly.dev/pkg/encoders/timestamp" + "orly.dev/pkg/utils/chk" + "orly.dev/pkg/utils/normalize" + "golang.org/x/net/websocket" ) -func TestPublish(t *testing.T) { - // test note to be sent over websocket - var err error - signer := &p256k.Signer{} - if err = signer.Generate(); chk.E(err) { - t.Fatal(err) - } - textNote := &event.E{ - Kind: kind.TextNote, - Content: []byte("hello"), - CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp - Tags: tags.New(tag.New("foo", "bar")), - Pubkey: signer.Pub(), - } - if err = textNote.Sign(signer); chk.E(err) { - t.Fatalf("textNote.Sign: %v", err) - } - // fake relay server - var mu sync.Mutex // guards published to satisfy `go test -race` - var published bool - ws := newWebsocketServer( - func(conn *websocket.Conn) { - mu.Lock() - published = true - mu.Unlock() - // verify the client sent exactly the textNote - var raw []json.RawMessage - if err := websocket.JSON.Receive(conn, &raw); chk.T(err) { - t.Errorf("websocket.JSON.Receive: %v", err) - } - - if string(raw[0]) != fmt.Sprintf(`"%s"`, eventenvelope.L) { - t.Errorf("got type %s, want %s", raw[0], eventenvelope.L) - } - env := eventenvelope.NewSubmission() - if 
raw[1], err = env.Unmarshal(raw[1]); chk.E(err) { - t.Fatal(err) - } - if !bytes.Equal(env.E.Serialize(), textNote.Serialize()) { - t.Errorf( - "received event:\n%s\nwant:\n%s", env.E.Serialize(), - textNote.Serialize(), - ) - } - // send back an ok nip-20 command result - var res []byte - if res = okenvelope.NewFrom( - textNote.ID, true, nil, - ).Marshal(res); chk.E(err) { - t.Fatal(err) - } - if err := websocket.Message.Send(conn, res); chk.T(err) { - t.Errorf("websocket.JSON.Send: %v", err) - } - }, - ) - defer ws.Close() - // connect a client and send the text note - rl := mustRelayConnect(ws.URL) - err = rl.Publish(context.Bg(), textNote) - if err != nil { - t.Errorf("publish should have succeeded") - } - if !published { - t.Errorf("fake relay server saw no event") - } -} - -func TestPublishBlocked(t *testing.T) { - // test note to be sent over websocket - var err error - signer := &p256k.Signer{} - if err = signer.Generate(); chk.E(err) { - t.Fatal(err) - } - textNote := &event.E{ - Kind: kind.TextNote, - Content: []byte("hello"), - CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp - Pubkey: signer.Pub(), - } - if err = textNote.Sign(signer); chk.E(err) { - t.Fatalf("textNote.Sign: %v", err) - } - // fake relay server - ws := newWebsocketServer( - func(conn *websocket.Conn) { - // discard received message; not interested - var raw []json.RawMessage - if err := websocket.JSON.Receive(conn, &raw); chk.T(err) { - t.Errorf("websocket.JSON.Receive: %v", err) - } - // send back a not ok nip-20 command result - var res []byte - if res = okenvelope.NewFrom( - textNote.ID, false, - normalize.Msg(normalize.Blocked, "no reason"), - ).Marshal(res); chk.E(err) { - t.Fatal(err) - } - if err := websocket.Message.Send(conn, res); chk.T(err) { - t.Errorf("websocket.JSON.Send: %v", err) - } - // res := []any{"OK", textNote.ID, false, "blocked"} - chk.E(websocket.JSON.Send(conn, res)) - }, - ) - defer ws.Close() - - // connect a client and send a text note - rl := mustRelayConnect(ws.URL) - if err = rl.Publish(context.Bg(), textNote); !chk.E(err) { - t.Errorf("should have failed to publish") - } -} +// func TestPublish(t *testing.T) { +// // test note to be sent over websocket +// var err error +// signer := &p256k.Signer{} +// if err = signer.Generate(); chk.E(err) { +// t.Fatal(err) +// } +// textNote := &event.E{ +// Kind: kind.TextNote, +// Content: []byte("hello"), +// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp +// Pubkey: signer.Pub(), +// } +// if err = textNote.Sign(signer); chk.E(err) { +// t.Fatalf("textNote.Sign: %v", err) +// } +// // fake relay server +// var published bool +// ws := newWebsocketServer( +// func(conn *websocket.Conn) { +// // receive message +// var raw []json.RawMessage +// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) { +// t.Errorf("websocket.JSON.Receive: %v", err) +// } +// // check that it's an EVENT message +// if len(raw) < 2 { +// t.Errorf("message too short: %v", raw) +// } +// var msgType string +// if err := json.Unmarshal(raw[0], &msgType); chk.T(err) { +// t.Errorf("json.Unmarshal: %v", err) +// } +// if msgType != "EVENT" { +// t.Errorf("expected EVENT message, got %q", msgType) +// } +// // check that the event is the one we sent +// var ev event.E +// if err := json.Unmarshal(raw[1], &ev); chk.T(err) { +// t.Errorf("json.Unmarshal: %v", err) +// } +// published = true +// if !bytes.Equal(ev.ID, textNote.ID) { +// t.Errorf( +// "event ID mismatch: got %x, want %x", +// ev.ID, textNote.ID, +// ) +// } 
+// if !bytes.Equal(ev.Pubkey, textNote.Pubkey) { +// t.Errorf( +// "event pubkey mismatch: got %x, want %x", +// ev.Pubkey, textNote.Pubkey, +// ) +// } +// if !bytes.Equal(ev.Content, textNote.Content) { +// t.Errorf( +// "event content mismatch: got %q, want %q", +// ev.Content, textNote.Content, +// ) +// } +// fmt.Printf( +// "received event: %s\n", +// textNote.Serialize(), +// ) +// // send back an ok nip-20 command result +// var res []byte +// if res = okenvelope.NewFrom( +// textNote.ID, true, nil, +// ).Marshal(res); chk.E(err) { +// t.Fatal(err) +// } +// if err := websocket.Message.Send(conn, res); chk.T(err) { +// t.Errorf("websocket.Message.Send: %v", err) +// } +// }, +// ) +// defer ws.Close() +// // connect a client and send the text note +// rl := mustRelayConnect(ws.URL) +// err = rl.Publish(context.Background(), textNote) +// if err != nil { +// t.Errorf("publish should have succeeded") +// } +// if !published { +// t.Errorf("fake relay server saw no event") +// } +// } +// +// func TestPublishBlocked(t *testing.T) { +// // test note to be sent over websocket +// var err error +// signer := &p256k.Signer{} +// if err = signer.Generate(); chk.E(err) { +// t.Fatal(err) +// } +// textNote := &event.E{ +// Kind: kind.TextNote, +// Content: []byte("hello"), +// CreatedAt: timestamp.FromUnix(1672068534), // random fixed timestamp +// Pubkey: signer.Pub(), +// } +// if err = textNote.Sign(signer); chk.E(err) { +// t.Fatalf("textNote.Sign: %v", err) +// } +// // fake relay server +// ws := newWebsocketServer( +// func(conn *websocket.Conn) { +// // discard received message; not interested +// var raw []json.RawMessage +// if err := websocket.JSON.Receive(conn, &raw); chk.T(err) { +// t.Errorf("websocket.JSON.Receive: %v", err) +// } +// // send back a not ok nip-20 command result +// var res []byte +// if res = okenvelope.NewFrom( +// textNote.ID, false, +// normalize.Msg(normalize.Blocked, "no reason"), +// ).Marshal(res); chk.E(err) { +// t.Fatal(err) +// } +// if err := websocket.Message.Send(conn, res); chk.T(err) { +// t.Errorf("websocket.Message.Send: %v", err) +// } +// // res := []any{"OK", textNote.ID, false, "blocked"} +// }, +// ) +// defer ws.Close() +// +// // connect a client and send a text note +// rl := mustRelayConnect(ws.URL) +// if err = rl.Publish(context.Background(), textNote); !chk.E(err) { +// t.Errorf("should have failed to publish") +// } +// } func TestPublishWriteFailed(t *testing.T) { // test note to be sent over websocket @@ -171,7 +184,7 @@ func TestPublishWriteFailed(t *testing.T) { rl := mustRelayConnect(ws.URL) // Force brief period of time so that publish always fails on closed socket. 
time.Sleep(1 * time.Millisecond) - err = rl.Publish(context.Bg(), textNote) + err = rl.Publish(context.Background(), textNote) if err == nil { t.Errorf("should have failed to publish") } @@ -192,7 +205,7 @@ func TestConnectContext(t *testing.T) { defer ws.Close() // relay client - ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() r, err := RelayConnect(ctx, ws.URL) if err != nil { @@ -213,7 +226,7 @@ func TestConnectContextCanceled(t *testing.T) { defer ws.Close() // relay client - ctx, cancel := context.Cancel(context.Bg()) + ctx, cancel := context.WithCancel(context.Background()) cancel() // make ctx expired _, err := RelayConnect(ctx, ws.URL) if !errors.Is(err, context.Canceled) { @@ -230,9 +243,9 @@ func TestConnectWithOrigin(t *testing.T) { defer ws.Close() // relay client - r := NewRelay(context.Bg(), string(normalize.URL(ws.URL))) - r.RequestHeader = http.Header{"origin": {"https://example.com"}} - ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second) + r := NewRelay(context.Background(), string(normalize.URL(ws.URL))) + r.requestHeader = http.Header{"origin": {"https://example.com"}} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := r.Connect(ctx) if err != nil { @@ -262,8 +275,8 @@ var anyOriginHandshake = func( return nil } -func mustRelayConnect(url string) (client *Client) { - rl, err := RelayConnect(context.Bg(), url) +func mustRelayConnect(url string) (client *Relay) { + rl, err := RelayConnect(context.Background(), url) if err != nil { panic(err.Error()) } diff --git a/pkg/protocol/ws/connection.go b/pkg/protocol/ws/connection.go index b3040d5..c9eeee0 100644 --- a/pkg/protocol/ws/connection.go +++ b/pkg/protocol/ws/connection.go @@ -1,224 +1,98 @@ package ws import ( - "bytes" - "compress/flate" + "context" "crypto/tls" + "errors" "fmt" - "github.com/gobwas/httphead" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsflate" - "github.com/gobwas/ws/wsutil" "io" - "net" "net/http" - "orly.dev/pkg/utils/chk" - "orly.dev/pkg/utils/context" - "orly.dev/pkg/utils/errorf" - "orly.dev/pkg/utils/log" + "net/textproto" + "time" + + ws "github.com/coder/websocket" ) -// Connection is an outbound client -> relay connection. -type Connection struct { - conn net.Conn - enableCompression bool - controlHandler wsutil.FrameHandlerFunc - flateReader *wsflate.Reader - reader *wsutil.Reader - flateWriter *wsflate.Writer - writer *wsutil.Writer - msgStateR *wsflate.MessageState - msgStateW *wsflate.MessageState +var defaultConnectionOptions = &ws.DialOptions{ + CompressionMode: ws.CompressionContextTakeover, + HTTPHeader: http.Header{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"}, + }, } -// NewConnection creates a new Connection. 
-func NewConnection( - c context.T, url string, requestHeader http.Header, - tlsConfig *tls.Config, -) (connection *Connection, errResult error) { - dialer := ws.Dialer{ - Header: ws.HandshakeHeaderHTTP(requestHeader), - Extensions: []httphead.Option{ - wsflate.DefaultParameters.Option(), - }, - TLSConfig: tlsConfig, +func getConnectionOptions( + requestHeader http.Header, tlsConfig *tls.Config, +) *ws.DialOptions { + if requestHeader == nil && tlsConfig == nil { + return defaultConnectionOptions } - conn, _, hs, err := dialer.Dial(c, url) + + return &ws.DialOptions{ + HTTPHeader: requestHeader, + CompressionMode: ws.CompressionContextTakeover, + HTTPClient: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + }, + } +} + +// Connection represents a websocket connection to a Nostr relay. +type Connection struct { + conn *ws.Conn +} + +// NewConnection creates a new websocket connection to a Nostr relay. +func NewConnection( + ctx context.Context, url string, requestHeader http.Header, + tlsConfig *tls.Config, +) (*Connection, error) { + c, _, err := ws.Dial( + ctx, url, getConnectionOptions(requestHeader, tlsConfig), + ) if err != nil { return nil, err } - enableCompression := false - state := ws.StateClientSide - for _, extension := range hs.Extensions { - if string(extension.Name) == wsflate.ExtensionName { - enableCompression = true - state |= ws.StateExtended - break - } - } - // reader - var flateReader *wsflate.Reader - var msgStateR wsflate.MessageState - if enableCompression { - msgStateR.SetCompressed(true) - flateReader = wsflate.NewReader( - nil, func(r io.Reader) wsflate.Decompressor { - return flate.NewReader(r) - }, - ) - } - controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide) - reader := &wsutil.Reader{ - Source: conn, - State: state, - OnIntermediate: controlHandler, - CheckUTF8: false, - Extensions: []wsutil.RecvExtension{ - &msgStateR, - }, - } - // writer - var flateWriter *wsflate.Writer - var msgStateW wsflate.MessageState - if enableCompression { - msgStateW.SetCompressed(true) + c.SetReadLimit(2 << 24) // 33MB - flateWriter = wsflate.NewWriter( - nil, func(w io.Writer) wsflate.Compressor { - fw, err := flate.NewWriter(w, 4) - if err != nil { - log.E.F("Failed to create flate writer: %v", err) - } - return fw - }, - ) - } - writer := wsutil.NewWriter(conn, state, ws.OpText) - writer.SetExtensions(&msgStateW) return &Connection{ - conn: conn, - enableCompression: enableCompression, - controlHandler: controlHandler, - flateReader: flateReader, - reader: reader, - msgStateR: &msgStateR, - flateWriter: flateWriter, - writer: writer, - msgStateW: &msgStateW, + conn: c, }, nil } -// WriteMessage dispatches a message through the Connection. -func (cn *Connection) WriteMessage(c context.T, data []byte) (err error) { - select { - case <-c.Done(): - return errorf.E( - "%s context canceled", - cn.conn.RemoteAddr(), - ) - default: +// WriteMessage writes arbitrary bytes to the websocket connection. 
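+// The payload is sent as a websocket text message; errors from the underlying connection are wrapped and returned.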
+func (c *Connection) WriteMessage(ctx context.Context, data []byte) error { + if err := c.conn.Write(ctx, ws.MessageText, data); err != nil { + return fmt.Errorf("failed to write message: %w", err) } - if cn.msgStateW.IsCompressed() && cn.enableCompression { - cn.flateWriter.Reset(cn.writer) - if _, err := io.Copy( - cn.flateWriter, bytes.NewReader(data), - ); chk.T(err) { - return errorf.E( - "%s failed to write message: %w", - cn.conn.RemoteAddr(), - err, - ) - } - if err := cn.flateWriter.Close(); chk.T(err) { - return errorf.E( - "%s failed to close flate writer: %w", - cn.conn.RemoteAddr(), - err, - ) - } - } else { - if _, err := io.Copy(cn.writer, bytes.NewReader(data)); chk.T(err) { - return errorf.E( - "%s failed to write message: %w", - cn.conn.RemoteAddr(), - err, - ) - } + return nil +} + +// ReadMessage reads arbitrary bytes from the websocket connection into the provided buffer. +func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { + _, reader, err := c.conn.Reader(ctx) + if err != nil { + return fmt.Errorf("failed to get reader: %w", err) } - if err := cn.writer.Flush(); chk.T(err) { - return errorf.E( - "%s failed to flush writer: %w", - cn.conn.RemoteAddr(), - err, - ) + if _, err := io.Copy(buf, reader); err != nil { + return fmt.Errorf("failed to read message: %w", err) } return nil } -// ReadMessage picks up the next incoming message on a Connection. -func (cn *Connection) ReadMessage(c context.T, buf io.Writer) (err error) { - for { - select { - case <-c.Done(): - return errorf.D( - "%s context canceled", - cn.conn.RemoteAddr(), - ) - default: - } - h, err := cn.reader.NextFrame() - if err != nil { - cn.conn.Close() - return fmt.Errorf( - "%s failed to advance frame: %s", - cn.conn.RemoteAddr(), - err.Error(), - ) - } - if h.OpCode.IsControl() { - if err := cn.controlHandler(h, cn.reader); chk.T(err) { - return errorf.E( - "%s failed to handle control frame: %w", - cn.conn.RemoteAddr(), - err, - ) - } - } else if h.OpCode == ws.OpBinary || - h.OpCode == ws.OpText { - break - } - if err := cn.reader.Discard(); chk.T(err) { - return errorf.E( - "%s failed to discard: %w", - cn.conn.RemoteAddr(), - err, - ) - } - } - if cn.msgStateR.IsCompressed() && cn.enableCompression { - cn.flateReader.Reset(cn.reader) - if _, err := io.Copy(buf, cn.flateReader); chk.T(err) { - return errorf.E( - "%s failed to read message: %w", - cn.conn.RemoteAddr(), - err, - ) - } - } else { - if _, err := io.Copy(buf, cn.reader); chk.T(err) { - return errorf.E( - "%s failed to read message: %w", - cn.conn.RemoteAddr(), - err, - ) - } - } - return nil +// Close closes the websocket connection. +func (c *Connection) Close() error { + return c.conn.Close(ws.StatusNormalClosure, "") } -// Close the Connection. -func (cn *Connection) Close() (err error) { - return cn.conn.Close() +// Ping sends a ping message to the websocket connection. 
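+// The ping is allowed at most 800ms to complete, so an error normally means the peer did not answer with a pong in time.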
+func (c *Connection) Ping(ctx context.Context) error { + ctx, cancel := context.WithTimeoutCause( + ctx, time.Millisecond*800, errors.New("ping took too long"), + ) + defer cancel() + return c.conn.Ping(ctx) } diff --git a/pkg/protocol/ws/listener.go b/pkg/protocol/ws/listener.go index f9824c4..8b83c0f 100644 --- a/pkg/protocol/ws/listener.go +++ b/pkg/protocol/ws/listener.go @@ -3,12 +3,13 @@ package ws import ( "net/http" + "strings" + "sync" + "orly.dev/pkg/app/relay/helpers" "orly.dev/pkg/encoders/event" "orly.dev/pkg/protocol/auth" atomic2 "orly.dev/pkg/utils/atomic" - "strings" - "sync" "github.com/fasthttp/websocket" ) diff --git a/pkg/protocol/ws/pool.go b/pkg/protocol/ws/pool.go index bd732d3..247cb94 100644 --- a/pkg/protocol/ws/pool.go +++ b/pkg/protocol/ws/pool.go @@ -1,348 +1,425 @@ package ws import ( + "errors" "fmt" - "orly.dev/pkg/encoders/event" - "orly.dev/pkg/encoders/filter" - "orly.dev/pkg/encoders/filters" - "orly.dev/pkg/encoders/timestamp" - "orly.dev/pkg/interfaces/signer" - "orly.dev/pkg/utils/chk" - "orly.dev/pkg/utils/context" - "orly.dev/pkg/utils/errorf" - "orly.dev/pkg/utils/log" - "orly.dev/pkg/utils/normalize" + "math" + "net/http" "slices" "strings" "sync" + "sync/atomic" "time" "unsafe" "github.com/puzpuzpuz/xsync/v3" + "orly.dev/pkg/encoders/event" + "orly.dev/pkg/encoders/filter" + "orly.dev/pkg/encoders/filters" + "orly.dev/pkg/encoders/hex" + "orly.dev/pkg/encoders/timestamp" + "orly.dev/pkg/interfaces/signer" + "orly.dev/pkg/utils/context" + "orly.dev/pkg/utils/log" + "orly.dev/pkg/utils/normalize" ) -var ( - seenAlreadyDropTick = 60 +const ( + seenAlreadyDropTick = time.Minute ) +// Pool manages connections to multiple relays, ensures they are reopened when necessary and not duplicated. type Pool struct { - Relays *xsync.MapOf[string, *Client] - Context context.T - authHandler func() signer.I - cancel context.F - eventMiddleware []func(IncomingEvent) + Relays *xsync.MapOf[string, *Relay] + Context context.T + + authHandler func() signer.I + cancel context.C + + eventMiddleware func(RelayEvent) + duplicateMiddleware func(relay, id string) + queryMiddleware func(relay, pubkey string, kind uint16) + // custom things not often used - SignatureChecker func(*event.E) bool + penaltyBoxMu sync.Mutex + penaltyBox map[string][2]float64 + relayOptions []RelayOption } -type DirectedFilters struct { - Filters *filters.T - Client string +// DirectedFilter combines a Filter with a specific relay URL. +type DirectedFilter struct { + *filter.F + Relay string } -type IncomingEvent struct { - Event *event.E - Client *Client +// RelayEvent represents an event received from a specific relay. +type RelayEvent struct { + *event.E + Relay *Relay } -func (ie IncomingEvent) String() string { - return fmt.Sprintf("[%s] >> %s", ie.Client.URL, ie.Event.Serialize()) +func (ie RelayEvent) String() string { + return fmt.Sprintf( + "[%s] >> %s", ie.Relay.URL, ie.E.Marshal(nil), + ) } +// PoolOption is an interface for options that can be applied to a Pool. type PoolOption interface { ApplyPoolOption(*Pool) } -func NewPool(c context.T, opts ...PoolOption) *Pool { - ctx, cancel := context.Cancel(c) - - pool := &Pool{ - Relays: xsync.NewMapOf[string, *Client](), - +// NewPool creates a new Pool with the given context and options. 
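+// Options such as WithAuthHandler, WithPenaltyBox, WithEventMiddleware and WithRelayOptions are applied in the order given.
+//
+// Minimal usage sketch (ctx, relayURLs, f and mySigner are assumed to be defined by the caller):
+//
+//	pool := NewPool(ctx, WithAuthHandler(func() signer.I { return mySigner }))
+//	for ev := range pool.SubscribeMany(ctx, relayURLs, f) {
+//		log.I.F("got event from %s: %s", ev.Relay.URL, ev.Marshal(nil))
+//	}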
+func NewPool(c context.T, opts ...PoolOption) (pool *Pool) { + ctx, cancel := context.Cause(c) + pool = &Pool{ + Relays: xsync.NewMapOf[string, *Relay](), Context: ctx, cancel: cancel, } - for _, opt := range opts { opt.ApplyPoolOption(pool) } - return pool } -// WithAuthHandler must be a function that signs the auth event when called. it -// will be called whenever any relay in the pool returns a `CLOSED` message with -// the "auth-required:" prefix, only once for each relay +// WithRelayOptions sets options that will be used on every relay instance created by this pool. +func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt { + return ropts +} + +type withRelayOptionsOpt []RelayOption + +func (h withRelayOptionsOpt) ApplyPoolOption(pool *Pool) { + pool.relayOptions = h +} + +// WithAuthHandler must be a function that signs the auth event when called. +// it will be called whenever any relay in the pool returns a `CLOSED` message +// with the "auth-required:" prefix, only once for each relay type WithAuthHandler func() signer.I func (h WithAuthHandler) ApplyPoolOption(pool *Pool) { pool.authHandler = h } -// WithEventMiddleware is a function that will be called with all events -// received. more than one can be passed at a time. -type WithEventMiddleware func(IncomingEvent) +// WithPenaltyBox just sets the penalty box mechanism so relays that fail to connect +// or that disconnect will be ignored for a while and we won't attempt to connect again. +func WithPenaltyBox() withPenaltyBoxOpt { return withPenaltyBoxOpt{} } + +type withPenaltyBoxOpt struct{} + +func (h withPenaltyBoxOpt) ApplyPoolOption(pool *Pool) { + pool.penaltyBox = make(map[string][2]float64) + go func() { + sleep := 30.0 + for { + time.Sleep(time.Duration(sleep) * time.Second) + + pool.penaltyBoxMu.Lock() + nextSleep := 300.0 + for url, v := range pool.penaltyBox { + remainingSeconds := v[1] + remainingSeconds -= sleep + if remainingSeconds <= 0 { + pool.penaltyBox[url] = [2]float64{v[0], 0} + continue + } else { + pool.penaltyBox[url] = [2]float64{v[0], remainingSeconds} + } + + if remainingSeconds < nextSleep { + nextSleep = remainingSeconds + } + } + + sleep = nextSleep + pool.penaltyBoxMu.Unlock() + } + }() +} + +// WithEventMiddleware is a function that will be called with all events received. +type WithEventMiddleware func(RelayEvent) func (h WithEventMiddleware) ApplyPoolOption(pool *Pool) { - pool.eventMiddleware = append(pool.eventMiddleware, h) + pool.eventMiddleware = h +} + +// WithDuplicateMiddleware is a function that will be called with all duplicate ids received. +type WithDuplicateMiddleware func(relay string, id string) + +func (h WithDuplicateMiddleware) ApplyPoolOption(pool *Pool) { + pool.duplicateMiddleware = h +} + +// WithAuthorKindQueryMiddleware is a function that will be called with every combination of relay+pubkey+kind queried +// in a .SubMany*() call -- when applicable (i.e. when the query contains a pubkey and a kind). +type WithAuthorKindQueryMiddleware func(relay, pubkey string, kind uint16) + +func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *Pool) { + pool.queryMiddleware = h } var ( _ PoolOption = (WithAuthHandler)(nil) _ PoolOption = (WithEventMiddleware)(nil) + _ PoolOption = WithPenaltyBox() + _ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{})) ) -// MaxLocks is the maximum number of sync.Mutex locks used in a pool -// todo: is this too few? 
-const MaxLocks = 50 +const MAX_LOCKS = 50 -var namedMutexPool = make([]sync.Mutex, MaxLocks) +var namedMutexPool = make([]sync.Mutex, MAX_LOCKS) //go:noescape //go:linkname memhash runtime.memhash func memhash(p unsafe.Pointer, h, s uintptr) uintptr -func namedLock(name string) (unlock func()) { - sptr := unsafe.StringData(name) +func namedLock[V ~[]byte | ~string](name V) (unlock func()) { + sptr := unsafe.StringData(string(name)) idx := uint64( memhash( unsafe.Pointer(sptr), 0, uintptr(len(name)), ), - ) % MaxLocks + ) % MAX_LOCKS namedMutexPool[idx].Lock() return namedMutexPool[idx].Unlock } -// EnsureRelay connects a pool to a relay or fails. -func (pool *Pool) EnsureRelay(url string) (*Client, error) { +// EnsureRelay ensures that a relay connection exists and is active. +// If the relay is not connected, it attempts to connect. +func (p *Pool) EnsureRelay(url string) (*Relay, error) { nm := string(normalize.URL(url)) defer namedLock(nm)() - - relay, ok := pool.Relays.Load(nm) - if ok && relay.IsConnected() { + relay, ok := p.Relays.Load(nm) + if ok && relay == nil { + if p.penaltyBox != nil { + p.penaltyBoxMu.Lock() + defer p.penaltyBoxMu.Unlock() + v, _ := p.penaltyBox[nm] + if v[1] > 0 { + return nil, fmt.Errorf("in penalty box, %fs remaining", v[1]) + } + } + } else if ok && relay.IsConnected() { // already connected, unlock and return return relay, nil - } else { - var err error - // we use this ctx here, so when the pool dies everything dies - ctx, cancel := context.Timeout(pool.Context, time.Second*15) - defer cancel() - - opts := make([]RelayOption, 0, 1+len(pool.eventMiddleware)) - if pool.SignatureChecker != nil { - opts = append(opts, WithSignatureChecker(pool.SignatureChecker)) - } - - if relay, err = RelayConnect(ctx, nm, opts...); chk.T(err) { - return nil, errorf.E("failed to connect: %w", err) - } - - pool.Relays.Store(nm, relay) - return relay, nil } -} - -// SubMany opens a subscription with the given filters to multiple relays the -// subscriptions only end when the context is canceled -func (pool *Pool) SubMany( - c context.T, urls []string, ff *filters.T, -) chan IncomingEvent { - return pool.subMany(c, urls, ff, true) -} - -// SubManyNonUnique is like SubMany, but returns duplicate events if they come -// from different relays -func (pool *Pool) SubManyNonUnique( - c context.T, urls []string, - ff *filters.T, -) chan IncomingEvent { - return pool.subMany(c, urls, ff, false) -} - -func (pool *Pool) subMany( - c context.T, urls []string, ff *filters.T, - unique bool, -) chan IncomingEvent { - ctx, cancel := context.Cancel(c) - _ = cancel // do this so `go vet` will stop complaining - events := make(chan IncomingEvent) - seenAlready := xsync.NewMapOf[string, *timestamp.T]() - ticker := time.NewTicker(time.Duration(seenAlreadyDropTick) * time.Second) - eose := false - pending := xsync.NewCounter() - pending.Add(int64(len(urls))) - for u, url := range urls { - url = string(normalize.URL(url)) - urls[u] = url - if idx := slices.Index(urls, url); idx != u { - // skip duplicate relays in the list - continue - } - - go func(nm string) { - var err error - defer func() { - pending.Dec() - if pending.Value() == 0 { - close(events) - } - cancel() - }() - hasAuthed := false - interval := 3 * time.Second - for { - select { - case <-ctx.Done(): - return - default: - } - var sub *Subscription - var relay *Client - if relay, err = pool.EnsureRelay(nm); chk.T(err) { - goto reconnect - } - hasAuthed = false - subscribe: - if sub, err = relay.Subscribe(ctx, ff); chk.T(err) { - goto 
reconnect - } - go func() { - <-sub.EndOfStoredEvents - eose = true - }() - // reset interval when we get a good subscription - interval = 3 * time.Second - - for { - select { - case evt, more := <-sub.Events: - if !more { - // this means the connection was closed for weird - // reasons, like the server shutdown, so we will - // update the filters here to include only events - // seem from now on and try to reconnect until we - // succeed - now := timestamp.Now() - for i := range ff.F { - ff.F[i].Since = now - } - goto reconnect - } - ie := IncomingEvent{Event: evt, Client: relay} - for _, mh := range pool.eventMiddleware { - mh(ie) - } - if unique { - if _, seen := seenAlready.LoadOrStore( - evt.EventId().String(), - evt.CreatedAt, - ); seen { - continue - } - } - select { - case events <- ie: - case <-ctx.Done(): - } - case <-ticker.C: - if eose { - old := ×tamp.T{int64(timestamp.Now().Int() - seenAlreadyDropTick)} - seenAlready.Range( - func(id string, value *timestamp.T) bool { - if value.I64() < old.I64() { - seenAlready.Delete(id) - } - return true - }, - ) - } - case reason := <-sub.ClosedReason: - if strings.HasPrefix( - reason, - "auth-required:", - ) && pool.authHandler != nil && !hasAuthed { - // relay is requesting auth. if we can, we will - // perform auth and try again - if err = relay.Auth( - ctx, pool.authHandler(), - ); err == nil { - // so we don't keep doing AUTH again and again - hasAuthed = true - goto subscribe - } - } else { - log.I.F("CLOSED from %s: '%s'\n", nm, reason) - } - return - case <-ctx.Done(): - return - } - } - reconnect: - // we will go back to the beginning of the loop and try to - // connect again and again until the context is canceled - time.Sleep(interval) - interval = interval * 17 / 10 // the next time we try, we will wait longer + // try to connect + // we use this ctx here so when the p dies everything dies + ctx, cancel := context.TimeoutCause( + p.Context, + time.Second*15, + errors.New("connecting to the relay took too long"), + ) + defer cancel() + relay = NewRelay(context.Bg(), url, p.relayOptions...) + if err := relay.Connect(ctx); err != nil { + if p.penaltyBox != nil { + // putting relay in penalty box + p.penaltyBoxMu.Lock() + defer p.penaltyBoxMu.Unlock() + v, _ := p.penaltyBox[nm] + p.penaltyBox[nm] = [2]float64{ + v[0] + 1, 30.0 + math.Pow(2, v[0]+1), } - }(url) + } + return nil, fmt.Errorf("failed to connect: %w", err) } - return events + p.Relays.Store(nm, relay) + return relay, nil } -// SubManyEose is like SubMany, but it stops subscriptions and closes the -// channel when gets a EOSE -func (pool *Pool) SubManyEose( - c context.T, urls []string, ff *filters.T, -) chan IncomingEvent { - return pool.subManyEose(c, urls, ff, true) +// PublishResult represents the result of publishing an event to a relay. +type PublishResult struct { + Error error + RelayURL string + Relay *Relay } -// SubManyEoseNonUnique is like SubManyEose, but returns duplicate events if -// they come from different relays -func (pool *Pool) SubManyEoseNonUnique( - c context.T, urls []string, - ff *filters.T, -) chan IncomingEvent { - return pool.subManyEose(c, urls, ff, false) +// todo: this didn't used to be in this package... probably don't want to add it +// either. +// +// PublishMany publishes an event to multiple relays and returns a +// channel of results emitted as they're received. 
+ +// func (pool *Pool) PublishMany( +// ctx context.T, urls []string, evt *event.E, +// ) chan PublishResult { +// ch := make(chan PublishResult, len(urls)) +// wg := sync.WaitGroup{} +// wg.Add(len(urls)) +// go func() { +// for _, url := range urls { +// go func() { +// defer wg.Done() +// relay, err := pool.EnsureRelay(url) +// if err != nil { +// ch <- PublishResult{err, url, nil} +// return +// } +// if err = relay.Publish(ctx, evt); err == nil { +// // success with no auth required +// ch <- PublishResult{nil, url, relay} +// } else if strings.HasPrefix( +// err.Error(), "msg: auth-required:", +// ) && pool.authHandler != nil { +// // try to authenticate if we can +// if authErr := relay.Auth( +// ctx, pool.authHandler(), +// ); authErr == nil { +// if err := relay.Publish(ctx, evt); err == nil { +// // success after auth +// ch <- PublishResult{nil, url, relay} +// } else { +// // failure after auth +// ch <- PublishResult{err, url, relay} +// } +// } else { +// // failure to auth +// ch <- PublishResult{ +// fmt.Errorf( +// "failed to auth: %w", authErr, +// ), url, relay, +// } +// } +// } else { +// // direct failure +// ch <- PublishResult{err, url, relay} +// } +// }() +// } +// +// wg.Wait() +// close(ch) +// }() +// +// return ch +// } + +// SubscribeMany opens a subscription with the given filter to multiple relays +// the subscriptions ends when the context is canceled or when all relays return a CLOSED. +func (p *Pool) SubscribeMany( + ctx context.T, + urls []string, + filter *filter.F, + opts ...SubscriptionOption, +) chan RelayEvent { + return p.subMany(ctx, urls, filters.New(filter), nil, opts...) } -func (pool *Pool) subManyEose( - c context.T, urls []string, ff *filters.T, - unique bool, -) chan IncomingEvent { - ctx, cancel := context.Cancel(c) +// FetchMany opens a subscription, much like SubscribeMany, but it ends as soon as all Relays +// return an EOSE message. +func (p *Pool) FetchMany( + ctx context.T, + urls []string, + filter *filter.F, + opts ...SubscriptionOption, +) chan RelayEvent { + return p.SubManyEose(ctx, urls, filters.New(filter), opts...) +} - events := make(chan IncomingEvent) - seenAlready := xsync.NewMapOf[string, bool]() +// Deprecated: SubMany is deprecated: use SubscribeMany instead. +func (p *Pool) SubMany( + ctx context.T, + urls []string, + filters *filters.T, + opts ...SubscriptionOption, +) chan RelayEvent { + return p.subMany(ctx, urls, filters, nil, opts...) +} + +// SubscribeManyNotifyEOSE is like SubscribeMany, but takes a channel that is closed when +// all subscriptions have received an EOSE +func (p *Pool) SubscribeManyNotifyEOSE( + ctx context.T, + urls []string, + filter *filter.F, + eoseChan chan struct{}, + opts ...SubscriptionOption, +) chan RelayEvent { + return p.subMany(ctx, urls, filters.New(filter), eoseChan, opts...) +} + +type ReplaceableKey struct { + PubKey string + D string +} + +// FetchManyReplaceable is like FetchMany, but deduplicates replaceable and addressable events and returns +// only the latest for each "d" tag. 
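+// Results are keyed by author pubkey and the event's "d" tag; a later-received event replaces an earlier one for the same key.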
+func (p *Pool) FetchManyReplaceable( + ctx context.T, + urls []string, + f *filter.F, + opts ...SubscriptionOption, +) *xsync.MapOf[ReplaceableKey, *event.E] { + ctx, cancel := context.Cause(ctx) + results := xsync.NewMapOf[ReplaceableKey, *event.E]() wg := sync.WaitGroup{} wg.Add(len(urls)) - - go func() { - // this will happen when all subscriptions get an eose (or when they - // die) - wg.Wait() - cancel() - close(events) - }() + // todo: this is a hack for compensating for retarded relays that don't + // filter replaceable events because it streams them back over a channel. + // this is out of spec anyway so should not be handled. replaceable events + // are supposed to delete old versions. the end. this is for the incorrect + // behaviour of fiatjaf's database code, which he obviously thinks is clever + // for using channels, and not sorting results before dispatching them + // before EOSE. + _ = 0 + // seenAlreadyLatest := xsync.NewMapOf[ReplaceableKey, + // *timestamp.T]() opts = append( + // opts, WithCheckDuplicateReplaceable( + // func(rk ReplaceableKey, ts Timestamp) bool { + // updated := false + // seenAlreadyLatest.Compute( + // rk, func(latest Timestamp, _ bool) ( + // newValue Timestamp, delete bool, + // ) { + // if ts > latest { + // updated = true // we are updating the most recent + // return ts, false + // } + // return latest, false // the one we had was already more recent + // }, + // ) + // return updated + // }, + // ), + // ) for _, url := range urls { - go func(nm []byte) { - var err error + go func(nm string) { defer wg.Done() - var client *Client - if client, err = pool.EnsureRelay(string(nm)); chk.E(err) { + + if mh := p.queryMiddleware; mh != nil { + if f.Kinds != nil && f.Authors != nil { + for _, kind := range f.Kinds.K { + for _, author := range f.Authors.ToStringSlice() { + mh(nm, author, kind.K) + } + } + } + } + + relay, err := p.EnsureRelay(nm) + if err != nil { + log.D.F("error connecting to %s with %v: %s", nm, f, err) return } hasAuthed := false subscribe: - var sub *Subscription - if sub, err = client.Subscribe(ctx, ff); chk.E(err) || sub == nil { - log.E.F("error subscribing to %s with %v: %s", client, ff, err) + sub, err := relay.Subscribe(ctx, filters.New(f), opts...) + if err != nil { + log.D.F( + "error subscribing to %s with %v: %s", relay, f, err, + ) return } + for { select { case <-ctx.Done(): @@ -351,38 +428,344 @@ func (pool *Pool) subManyEose( return case reason := <-sub.ClosedReason: if strings.HasPrefix( - reason, - "auth-required:", - ) && pool.authHandler != nil && !hasAuthed { - // client is requesting auth. if we can, we will perform - // auth and try again - err = client.Auth(ctx, pool.authHandler()) + reason, "auth-required:", + ) && p.authHandler != nil && !hasAuthed { + // relay is requesting auth. 
if we can we will perform auth and try again + err = relay.Auth( + ctx, p.authHandler(), + ) if err == nil { - // so we don't keep doing AUTH again and again - hasAuthed = true + hasAuthed = true // so we don't keep doing AUTH again and again goto subscribe } } - log.I.F("CLOSED from %s: '%s'\n", nm, reason) + log.D.F("CLOSED from %s: '%s'\n", nm, reason) return case evt, more := <-sub.Events: if !more { return } - ie := IncomingEvent{Event: evt, Client: client} - for _, mh := range pool.eventMiddleware { + ie := RelayEvent{E: evt, Relay: relay} + if mh := p.eventMiddleware; mh != nil { mh(ie) } - if unique { - if _, seen := seenAlready.LoadOrStore( - evt.EventId().String(), - true, - ); seen { - continue + results.Store( + ReplaceableKey{hex.Enc(evt.Pubkey), evt.Tags.GetD()}, + evt, + ) + } + } + }(string(normalize.URL(url))) + } + + // this will happen when all subscriptions get an eose (or when they die) + wg.Wait() + cancel(errors.New("all subscriptions ended")) + + return results +} + +func (p *Pool) subMany( + ctx context.T, + urls []string, + ff *filters.T, + eoseChan chan struct{}, + opts ...SubscriptionOption, +) chan RelayEvent { + ctx, cancel := context.Cause(ctx) + _ = cancel // do this so `go vet` will stop complaining + events := make(chan RelayEvent) + seenAlready := xsync.NewMapOf[string, *timestamp.T]() + ticker := time.NewTicker(seenAlreadyDropTick) + + eoseWg := sync.WaitGroup{} + eoseWg.Add(len(urls)) + if eoseChan != nil { + go func() { + eoseWg.Wait() + close(eoseChan) + }() + } + + pending := xsync.NewCounter() + pending.Add(int64(len(urls))) + for i, url := range urls { + url = string(normalize.URL(url)) + urls[i] = url + if idx := slices.Index(urls, url); idx != i { + // skip duplicate relays in the list + eoseWg.Done() + continue + } + + eosed := atomic.Bool{} + firstConnection := true + + go func(nm string) { + defer func() { + pending.Dec() + if pending.Value() == 0 { + close(events) + cancel(fmt.Errorf("aborted: %w", context.GetCause(ctx))) + } + if eosed.CompareAndSwap(false, true) { + eoseWg.Done() + } + }() + + hasAuthed := false + interval := 3 * time.Second + for { + select { + case <-ctx.Done(): + return + default: + } + + var sub *Subscription + + if mh := p.queryMiddleware; mh != nil { + for _, f := range ff.F { + if f.Kinds != nil && f.Authors != nil { + for _, k := range f.Kinds.K { + for _, author := range f.Authors.ToSliceOfBytes() { + mh(nm, hex.Enc(author), k.K) + } + } } } + } + + relay, err := p.EnsureRelay(nm) + if err != nil { + // if we never connected to this just fail + if firstConnection { + return + } + + // otherwise (if we were connected and got disconnected) keep trying to reconnect + log.D.F("%s reconnecting because connection failed\n", nm) + goto reconnect + } + firstConnection = false + hasAuthed = false + + subscribe: + sub, err = relay.Subscribe( + ctx, ff, + // append( + opts..., + // WithCheckDuplicate( + // func(id, relay string) bool { + // _, exists := seenAlready.Load(id) + // if exists && p.duplicateMiddleware != nil { + // p.duplicateMiddleware(relay, id) + // } + // return exists + // }, + // ), + // )..., + ) + if err != nil { + log.D.F("%s reconnecting because subscription died\n", nm) + goto reconnect + } + + go func() { + <-sub.EndOfStoredEvents + + // guard here otherwise a resubscription will trigger a duplicate call to eoseWg.Done() + if eosed.CompareAndSwap(false, true) { + eoseWg.Done() + } + }() + + // reset interval when we get a good subscription + interval = 3 * time.Second + + for { + select { + case evt, 
more := <-sub.Events: + if !more { + // this means the connection was closed for weird reasons, like the server shut down + // so we will update the filters here to include only events seem from now on + // and try to reconnect until we succeed + now := timestamp.Now() + for i := range ff.F { + ff.F[i].Since = now + } + log.D.F( + "%s reconnecting because sub.Events is broken\n", + nm, + ) + goto reconnect + } + + ie := RelayEvent{E: evt, Relay: relay} + if mh := p.eventMiddleware; mh != nil { + mh(ie) + } + + select { + case events <- ie: + case <-ctx.Done(): + return + } + case <-ticker.C: + if eosed.Load() { + old := timestamp.New(time.Now().Add(-seenAlreadyDropTick).Unix()) + for id, value := range seenAlready.Range { + if value.I64() < old.I64() { + seenAlready.Delete(id) + } + } + } + case reason := <-sub.ClosedReason: + if strings.HasPrefix( + reason, "auth-required:", + ) && p.authHandler != nil && !hasAuthed { + // relay is requesting auth. if we can we will perform auth and try again + err = relay.Auth( + ctx, p.authHandler(), + ) + if err == nil { + hasAuthed = true // so we don't keep doing AUTH again and again + goto subscribe + } + } else { + log.D.F("CLOSED from %s: '%s'\n", nm, reason) + } + + return + case <-ctx.Done(): + return + } + } + + reconnect: + // we will go back to the beginning of the loop and try to connect again and again + // until the context is canceled + time.Sleep(interval) + interval = interval * 17 / 10 // the next time we try we will wait longer + } + }(url) + } + + return events +} + +// Deprecated: SubManyEose is deprecated: use FetchMany instead. +func (p *Pool) SubManyEose( + ctx context.T, + urls []string, + filters *filters.T, + opts ...SubscriptionOption, +) chan RelayEvent { + // seenAlready := xsync.NewMapOf[string, struct{}]() + return p.subManyEoseNonOverwriteCheckDuplicate( + ctx, urls, filters, + // WithCheckDuplicate( + // func(id, relay string) bool { + // _, exists := seenAlready.LoadOrStore(id, struct{}{}) + // if exists && p.duplicateMiddleware != nil { + // p.duplicateMiddleware(relay, id) + // } + // return exists + // }, + // ), + opts..., + ) +} + +func (p *Pool) subManyEoseNonOverwriteCheckDuplicate( + ctx context.T, + urls []string, + filters *filters.T, + // wcd WithCheckDuplicate, + opts ...SubscriptionOption, +) chan RelayEvent { + ctx, cancel := context.Cause(ctx) + + events := make(chan RelayEvent) + wg := sync.WaitGroup{} + wg.Add(len(urls)) + + // opts = append(opts, wcd) + + go func() { + // this will happen when all subscriptions get an eose (or when they die) + wg.Wait() + cancel(errors.New("all subscriptions ended")) + close(events) + }() + + for _, url := range urls { + go func(nm string) { + defer wg.Done() + + if mh := p.queryMiddleware; mh != nil { + for _, filter := range filters.F { + if filter.Kinds != nil && filter.Authors != nil { + for _, k := range filter.Kinds.K { + for _, author := range filter.Authors.ToSliceOfBytes() { + mh(nm, hex.Enc(author), k.K) + } + } + } + } + } + + relay, err := p.EnsureRelay(nm) + if err != nil { + log.D.F( + "error connecting to %s with %v: %s", nm, filters, err, + ) + return + } + + hasAuthed := false + + subscribe: + sub, err := relay.Subscribe(ctx, filters, opts...) 
+ if err != nil { + log.D.F( + "error subscribing to %s with %v: %s", relay, filters, err, + ) + return + } + + for { + select { + case <-ctx.Done(): + return + case <-sub.EndOfStoredEvents: + return + case reason := <-sub.ClosedReason: + if strings.HasPrefix( + reason, "auth-required:", + ) && p.authHandler != nil && !hasAuthed { + // relay is requesting auth. if we can we will perform auth and try again + err = relay.Auth( + ctx, p.authHandler(), + ) + if err == nil { + hasAuthed = true // so we don't keep doing AUTH again and again + goto subscribe + } + } + log.D.F("CLOSED from %s: '%s'\n", nm, reason) + return + case evt, more := <-sub.Events: + if !more { + return + } + + ie := RelayEvent{E: evt, Relay: relay} + if mh := p.eventMiddleware; mh != nil { + mh(ie) + } select { case events <- ie: @@ -391,55 +774,110 @@ func (pool *Pool) subManyEose( } } } - }(normalize.URL(url)) + }(string(normalize.URL(url))) } return events } -// QuerySingle returns the first event returned by the first relay, cancels -// everything else. -func (pool *Pool) QuerySingle( - c context.T, urls []string, f *filter.F, -) *IncomingEvent { - ctx, cancel := context.Cancel(c) - defer cancel() - for ievt := range pool.SubManyEose(ctx, urls, filters.New(f)) { +// // CountMany aggregates count results from multiple relays using NIP-45 HyperLogLog +// func (pool *Pool) CountMany( +// ctx context.T, +// urls []string, +// filter *filter.F, +// opts []SubscriptionOption, +// ) int { +// hll := hyperloglog.New(0) // offset is irrelevant here +// +// wg := sync.WaitGroup{} +// wg.Add(len(urls)) +// for _, url := range urls { +// go func(nm string) { +// defer wg.Done() +// relay, err := pool.EnsureRelay(url) +// if err != nil { +// return +// } +// ce, err := relay.countInternal(ctx, Filters{filter}, opts...) +// if err != nil { +// return +// } +// if len(ce.HyperLogLog) != 256 { +// return +// } +// hll.MergeRegisters(ce.HyperLogLog) +// }(NormalizeURL(url)) +// } +// +// wg.Wait() +// return int(hll.Count()) +// } + +// QuerySingle returns the first event returned by the first relay, cancels everything else. +func (p *Pool) QuerySingle( + ctx context.T, + urls []string, + filter *filter.F, + opts ...SubscriptionOption, +) *RelayEvent { + ctx, cancel := context.Cause(ctx) + for ievt := range p.SubManyEose( + ctx, urls, filters.New(filter), opts..., + ) { + cancel(errors.New("got the first event and ended successfully")) return &ievt } + cancel(errors.New("SubManyEose() didn't get yield events")) return nil } -func (pool *Pool) batchedSubMany( - c context.T, - dfs []DirectedFilters, - subFn func(context.T, []string, *filters.T, bool) chan IncomingEvent, -) chan IncomingEvent { - res := make(chan IncomingEvent) - +// BatchedSubManyEose performs batched subscriptions to multiple relays with different filters. 
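+//
+// A usage sketch, assuming DirectedFilter pairs a relay URL (Relay) with a
+// single filter (F) as consumed below; f1 and f2 stand in for *filter.F
+// values:
+//
+//	dfs := []DirectedFilter{
+//		{Relay: "wss://relay.example.com", F: f1},
+//		{Relay: "wss://other.example.com", F: f2},
+//	}
+//	for ie := range pool.BatchedSubManyEose(ctx, dfs) {
+//		// ie.E is the event, ie.Relay the connection it arrived from.
+//	}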
+func (p *Pool) BatchedSubManyEose( + ctx context.T, + dfs []DirectedFilter, + opts ...SubscriptionOption, +) chan RelayEvent { + res := make(chan RelayEvent) + wg := sync.WaitGroup{} + wg.Add(len(dfs)) + // seenAlready := xsync.NewMapOf[string, struct{}]() for _, df := range dfs { - go func(df DirectedFilters) { - for ie := range subFn(c, []string{df.Client}, df.Filters, true) { - res <- ie + go func(df DirectedFilter) { + for ie := range p.subManyEoseNonOverwriteCheckDuplicate( + ctx, + []string{df.Relay}, + filters.New(df.F), + // WithCheckDuplicate( + // func(id, relay string) bool { + // _, exists := seenAlready.LoadOrStore(id, struct{}{}) + // if exists && p.duplicateMiddleware != nil { + // p.duplicateMiddleware(relay, id) + // } + // return exists + // }, + // ), + opts..., + ) { + select { + case res <- ie: + case <-ctx.Done(): + wg.Done() + return + } } + wg.Done() }(df) } + go func() { + wg.Wait() + close(res) + }() + return res } -// BatchedSubMany fires subscriptions only to specific relays but batches them -// when they are the same. -func (pool *Pool) BatchedSubMany( - c context.T, dfs []DirectedFilters, -) chan IncomingEvent { - return pool.batchedSubMany(c, dfs, pool.subMany) -} - -// BatchedSubManyEose is like BatchedSubMany, but ends upon receiving EOSE from -// all relays. -func (pool *Pool) BatchedSubManyEose( - c context.T, dfs []DirectedFilters, -) chan IncomingEvent { - return pool.batchedSubMany(c, dfs, pool.subManyEose) +// Close closes the pool with the given reason. +func (p *Pool) Close(reason string) { + p.cancel(fmt.Errorf("pool closed with reason: '%s'", reason)) } diff --git a/pkg/protocol/ws/pool_test.go b/pkg/protocol/ws/pool_test.go new file mode 100644 index 0000000..3bcb004 --- /dev/null +++ b/pkg/protocol/ws/pool_test.go @@ -0,0 +1,216 @@ +package ws + +import ( + "context" + "sync" + "testing" + "time" + + "orly.dev/pkg/encoders/event" + "orly.dev/pkg/encoders/filter" + "orly.dev/pkg/encoders/kind" + "orly.dev/pkg/encoders/timestamp" + "orly.dev/pkg/interfaces/signer" +) + +// mockSigner implements signer.I for testing +type mockSigner struct { + pubkey []byte +} + +func (m *mockSigner) Pub() []byte { return m.pubkey } +func (m *mockSigner) Sign([]byte) ( + []byte, error, +) { + return []byte("mock-signature"), nil +} +func (m *mockSigner) Generate() error { return nil } +func (m *mockSigner) InitSec([]byte) error { return nil } +func (m *mockSigner) InitPub([]byte) error { return nil } +func (m *mockSigner) Sec() []byte { return []byte("mock-secret") } +func (m *mockSigner) Verify([]byte, []byte) (bool, error) { return true, nil } +func (m *mockSigner) Zero() {} +func (m *mockSigner) ECDH([]byte) ( + []byte, error, +) { + return []byte("mock-shared-secret"), nil +} + +func TestNewPool(t *testing.T) { + ctx := context.Background() + pool := NewPool(ctx) + + if pool == nil { + t.Fatal("NewPool returned nil") + } + + if pool.Relays == nil { + t.Error("Pool should have initialized Relays map") + } + + if pool.Context == nil { + t.Error("Pool should have a context") + } +} + +func TestPoolWithAuthHandler(t *testing.T) { + ctx := context.Background() + + authHandler := WithAuthHandler( + func() signer.I { + return &mockSigner{pubkey: []byte("test-pubkey")} + }, + ) + + pool := NewPool(ctx, authHandler) + + if pool.authHandler == nil { + t.Error("Pool should have auth handler set") + } + + // Test that auth handler returns the expected signer + signer := pool.authHandler() + if string(signer.Pub()) != "test-pubkey" { + t.Errorf( + "Expected pubkey 
'test-pubkey', got '%s'", string(signer.Pub()), + ) + } +} + +func TestPoolWithEventMiddleware(t *testing.T) { + ctx := context.Background() + + var middlewareCalled bool + middleware := WithEventMiddleware( + func(ie RelayEvent) { + middlewareCalled = true + }, + ) + + pool := NewPool(ctx, middleware) + + // Test that middleware is called + testEvent := &event.E{ + Kind: kind.TextNote, + Content: []byte("test"), + CreatedAt: timestamp.Now(), + } + + ie := RelayEvent{E: testEvent, Relay: nil} + pool.eventMiddleware(ie) + + if !middlewareCalled { + t.Error("Expected middleware to be called") + } +} + +func TestRelayEventString(t *testing.T) { + testEvent := &event.E{ + Kind: kind.TextNote, + Content: []byte("test content"), + CreatedAt: timestamp.Now(), + } + + client := &Relay{URL: "wss://test.relay"} + ie := RelayEvent{E: testEvent, Relay: client} + + str := ie.String() + if !contains(str, "wss://test.relay") { + t.Errorf("Expected string to contain relay URL, got: %s", str) + } + + if !contains(str, "test content") { + t.Errorf("Expected string to contain event content, got: %s", str) + } +} + +func TestNamedLock(t *testing.T) { + // Test that named locks work correctly + var wg sync.WaitGroup + var counter int + var mu sync.Mutex + + lockName := "test-lock" + + // Start multiple goroutines that try to increment counter + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + unlock := namedLock(lockName) + defer unlock() + + // Critical section + mu.Lock() + temp := counter + time.Sleep(1 * time.Millisecond) // Simulate work + counter = temp + 1 + mu.Unlock() + }() + } + + wg.Wait() + + if counter != 10 { + t.Errorf("Expected counter to be 10, got %d", counter) + } +} + +func TestPoolEnsureRelayInvalidURL(t *testing.T) { + ctx := context.Background() + pool := NewPool(ctx) + + // Test with invalid URL + _, err := pool.EnsureRelay("invalid-url") + if err == nil { + t.Error("Expected error for invalid URL") + } +} + +func TestPoolQuerySingle(t *testing.T) { + ctx := context.Background() + pool := NewPool(ctx) + + // Test with empty URLs slice + result := pool.QuerySingle(ctx, []string{}, &filter.F{}) + if result != nil { + t.Error("Expected nil result for empty URLs") + } +} + +// Helper functions +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && + (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + containsSubstring(s, substr))) +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func uintPtr(u uint) *uint { + return &u +} + +// Test pool context cancellation +func TestPoolContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := NewPool(ctx) + + // Cancel the context + cancel() + + // Check that pool context is cancelled + select { + case <-pool.Context.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Error("Expected pool context to be cancelled") + } +} diff --git a/pkg/protocol/ws/subscription.go b/pkg/protocol/ws/subscription.go index 8af238d..971e5a0 100644 --- a/pkg/protocol/ws/subscription.go +++ b/pkg/protocol/ws/subscription.go @@ -1,101 +1,105 @@ package ws import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "orly.dev/pkg/encoders/envelopes/closeenvelope" - "orly.dev/pkg/encoders/envelopes/countenvelope" "orly.dev/pkg/encoders/envelopes/reqenvelope" 
"orly.dev/pkg/encoders/event" "orly.dev/pkg/encoders/filters" "orly.dev/pkg/encoders/subscription" "orly.dev/pkg/utils/chk" - "orly.dev/pkg/utils/context" - "orly.dev/pkg/utils/errorf" - "orly.dev/pkg/utils/log" - "strconv" - "sync" - "sync/atomic" ) -// Subscription is a client interface for a subscription (what REQ turns into -// after EOSE). +// Subscription represents a subscription to a relay. type Subscription struct { - label string - counter int + counter int64 + id string - Relay *Client + Relay *Relay Filters *filters.T - // for this to be treated as a COUNT and not a REQ this must be set - countResult chan int + // // for this to be treated as a COUNT and not a REQ this must be set + // countResult chan CountEnvelope - // The Events channel emits all EVENTs that come in a Subscription will be - // closed when the subscription ends + // the Events channel emits all EVENTs that come in a Subscription + // will be closed when the subscription ends Events event.C mu sync.Mutex - // The EndOfStoredEvents channel is closed when an EOSE comes for that - // subscription + // the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription EndOfStoredEvents chan struct{} - // The ClosedReason channel emits the reason when a CLOSED message is - // received + // the ClosedReason channel emits the reason when a CLOSED message is received ClosedReason chan string // Context will be .Done() when the subscription ends - Context context.T + Context context.Context + // // if it is not nil, checkDuplicate will be called for every event received + // // if it returns true that event will not be processed further. + // checkDuplicate func(id string, relay string) bool + // + // // if it is not nil, checkDuplicateReplaceable will be called for every event received + // // if it returns true that event will not be processed further. + // checkDuplicateReplaceable func(rk ReplaceableKey, ts Timestamp) bool + + match func(*event.E) bool // this will be either Filters.Match or Filters.MatchIgnoringTimestampConstraints live atomic.Bool eosed atomic.Bool - closed atomic.Bool - cancel context.F + cancel context.CancelCauseFunc - // This keeps track of the events we've received before the EOSE that we - // must dispatch before closing the EndOfStoredEvents channel + // this keeps track of the events we've received before the EOSE that we must dispatch before + // closing the EndOfStoredEvents channel storedwg sync.WaitGroup } -// EventMessage is an event, with the associated relay URL attached. -type EventMessage struct { - Event event.E - Relay string -} - -// SubscriptionOption is the type of the argument passed for that. Some examples -// are WithLabel. +// SubscriptionOption is the type of the argument passed when instantiating relay connections. +// Some examples are WithLabel. type SubscriptionOption interface { IsSubscriptionOption() } -// WithLabel puts a label on the subscription (it is prepended to the automatic -// id) that is sent to relays. +// WithLabel puts a label on the subscription (it is prepended to the automatic id) that is sent to relays. 
type WithLabel string func (_ WithLabel) IsSubscriptionOption() {} -var _ SubscriptionOption = (WithLabel)("") +// // WithCheckDuplicate sets checkDuplicate on the subscription +// type WithCheckDuplicate func(id, relay string) bool +// +// func (_ WithCheckDuplicate) IsSubscriptionOption() {} +// +// // WithCheckDuplicateReplaceable sets checkDuplicateReplaceable on the subscription +// type WithCheckDuplicateReplaceable func(rk ReplaceableKey, ts *timestamp.T) bool +// +// func (_ WithCheckDuplicateReplaceable) IsSubscriptionOption() {} -// GetID return the Nostr subscription ID as given to the Client it is a -// concatenation of the label and a serial number. -func (sub *Subscription) GetID() (id *subscription.Id) { - var err error - if id, err = subscription.NewId(sub.label + ":" + strconv.Itoa(sub.counter)); chk.E(err) { - return - } - return -} +var ( + _ SubscriptionOption = (WithLabel)("") + // _ SubscriptionOption = (WithCheckDuplicate)(nil) + // _ SubscriptionOption = (WithCheckDuplicateReplaceable)(nil) +) func (sub *Subscription) start() { <-sub.Context.Done() - // the subscription ends once the context is canceled (if not already) - sub.Unsub() // this will set sub.live to false - // do this so we don't have the possibility of closing the Events channel - // and then trying to send to it + // the subscription ends once the context is canceled (if not already) + sub.unsub(errors.New("context done on start()")) // this will set sub.live to false + + // do this so we don't have the possibility of closing the Events channel and then trying to send to it sub.mu.Lock() close(sub.Events) sub.mu.Unlock() } +// GetID returns the subscription ID. +func (sub *Subscription) GetID() string { return sub.id } + func (sub *Subscription) dispatchEvent(evt *event.E) { added := false if !sub.eosed.Load() { @@ -113,7 +117,6 @@ func (sub *Subscription) dispatchEvent(evt *event.E) { case <-sub.Context.Done(): } } - if added { sub.storedwg.Done() } @@ -122,6 +125,7 @@ func (sub *Subscription) dispatchEvent(evt *event.E) { func (sub *Subscription) dispatchEose() { if sub.eosed.CompareAndSwap(false, true) { + sub.match = sub.Filters.MatchIgnoringTimestampConstraints go func() { sub.storedwg.Wait() sub.EndOfStoredEvents <- struct{}{} @@ -129,62 +133,72 @@ func (sub *Subscription) dispatchEose() { } } -func (sub *Subscription) dispatchClosed(reason string) { - if sub.closed.CompareAndSwap(false, true) { - go func() { - sub.ClosedReason <- reason - }() - } +// handleClosed handles the CLOSED message from a relay. +func (sub *Subscription) handleClosed(reason string) { + go func() { + sub.ClosedReason <- reason + sub.live.Store(false) // set this so we don't send an unnecessary CLOSE to the relay + sub.unsub(fmt.Errorf("CLOSED received: %s", reason)) + }() } -// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. Unsub() -// also closes the channel sub.Events and makes a new one. +// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. +// Unsub() also closes the channel sub.Events and makes a new one. func (sub *Subscription) Unsub() { + sub.unsub(errors.New("Unsub() called")) +} + +// unsub is the internal implementation of Unsub. 
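+// It cancels the subscription context with err as the cancellation cause
+// (retrievable via context.Cause), sends CLOSE to the relay if the
+// subscription is still marked live, and removes it from the relay's
+// subscription map.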
+func (sub *Subscription) unsub(err error) { // cancel the context (if it's not canceled already) - sub.cancel() - // mark the subscription as closed and send a CLOSE to the relay (naïve - // sync.Once implementation) + sub.cancel(err) + + // mark subscription as closed and send a CLOSE to the relay (naïve sync.Once implementation) if sub.live.CompareAndSwap(true, false) { sub.Close() } + // remove subscription from our map - sub.Relay.Subscriptions.Delete(sub.GetID().String()) + sub.Relay.Subscriptions.Delete(sub.counter) } // Close just sends a CLOSE message. You probably want Unsub() instead. func (sub *Subscription) Close() { if sub.Relay.IsConnected() { - id := sub.GetID() + id, err := subscription.NewId(sub.id) + if err != nil { + return + } closeMsg := closeenvelope.NewFrom(id) - var b []byte - b = closeMsg.Marshal(nil) - <-sub.Relay.Write(b) + closeb := closeMsg.Marshal(nil) + <-sub.Relay.Write(closeb) } } -// Sub sets sub.Filters and then calls sub.Fire(ctx). The subscription will be -// closed if the context expires. -func (sub *Subscription) Sub(_ context.T, ff *filters.T) { +// Sub sets sub.Filters and then calls sub.Fire(ctx). +// The subscription will be closed if the context expires. +func (sub *Subscription) Sub(_ context.Context, ff *filters.T) { sub.Filters = ff sub.Fire() } // Fire sends the "REQ" command to the relay. func (sub *Subscription) Fire() (err error) { - id := sub.GetID() - - var b []byte - if sub.countResult == nil { - b = reqenvelope.NewFrom(id, sub.Filters).Marshal(b) - } else { - b = countenvelope.NewRequest(id, sub.Filters).Marshal(b) + // if sub.countResult == nil { + req := reqenvelope.NewWithIdString(sub.id, sub.Filters) + if req == nil { + return fmt.Errorf("invalid ID or filters") } - log.T.F("{%s} sending %s", sub.Relay.URL, b) + reqb := req.Marshal(nil) + // } else + // if len(sub.Filters) == 1 { + // reqb, _ = CountEnvelope{sub.id, sub.Filters[0], nil, nil}.MarshalJSON() + // } else { + // return fmt.Errorf("unexpected sub configuration") sub.live.Store(true) - if err = <-sub.Relay.Write(b); chk.T(err) { - sub.cancel() - return errorf.E("failed to write: %w", err) + if err = <-sub.Relay.Write(reqb); chk.E(err) { + err = fmt.Errorf("failed to write: %w", err) + sub.cancel(err) } - - return nil + return } diff --git a/pkg/utils/context/context.go b/pkg/utils/context/context.go index 1eb3253..a24b477 100644 --- a/pkg/utils/context/context.go +++ b/pkg/utils/context/context.go @@ -28,8 +28,10 @@ var ( TODO = context.TODO // Value - context.WithValue Value = context.WithValue - // CancelCause - context.WithCancelCause - CancelCause = context.WithCancelCause + // Cause - context.WithCancelCause + Cause = context.WithCancelCause + + GetCause = context.Cause // Canceled - context.Canceled Canceled = context.Canceled )