Compare commits

...

14 Commits

Author SHA1 Message Date
96eab2270d actually implement returning events, and cap at 512 2025-08-04 21:53:16 +01:00
c0bd7d8da3 skip init of relay secret key if there isn't one 2025-08-04 21:27:26 +01:00
1ffb7afb01 implement first draft of nwc client and cli 2025-08-04 14:54:15 +01:00
ffa9d85ba5 Fix error handling and logging in wallet service
- pkg/protocol/nwc/wallet_service.go
  - Added import for "orly.dev/pkg/utils/log"
  - Replaced fmt.Printf calls with log.E.F for consistent error logging
  - Simplified error handling using chk.E instead of explicit if err != nil checks
2025-08-03 11:39:19 +01:00
1223b1b20e implement nwc based on translation from js-sdk, refactor encryption to use bytes 2025-08-03 11:20:02 +01:00
deb56664e2 Update spider seeds and bump version to v0.4.6
- pkg/app/config/config.go
  - Updated default SpiderSeeds list with new relay URL
- pkg/version/version
  - Bumped version number from v0.4.5 to v0.4.6
2025-08-02 14:03:15 +01:00
1641d18993 Add IP blocking with authentication-based unblocking and offense tracking
- pkg/protocol/socketapi/handleAuth.go
  - Added import for "orly.dev/pkg/utils/iptracker"
  - Added logic to call iptracker.Global.Authenticate() on successful authentication
- pkg/protocol/openapi/event.go
  - Added logic to call iptracker.Global.Authenticate() on successful authentication
- pkg/utils/iptracker/iptracker.go
  - Introduced offense tracking and block duration calculation based on offense count
  - Added Authenticate method to remove blocks upon successful authentication
  - Modified isBlockedNoLock to always return true when an IP is blocked
  - Added HasBlockDurationPassed, GetBlockDuration, and Reset methods
- pkg/version/version
  - Bumped version number from v0.4.4 to v0.4.5
2025-08-02 12:30:47 +01:00
eab5d236db Add IP blocking based on failed authentication attempts
- pkg/protocol/socketapi/socketapi.go
  - Added import for "orly.dev/pkg/utils/iptracker"
  - Added logic to check if an IP is blocked and reject the connection if it is
- pkg/protocol/socketapi/handleEvent.go
  - Added imports for "orly.dev/pkg/utils/iptracker" and "time"
  - Added logic to check if an IP is blocked, send a notice to the client, and close the connection if it is
  - Added logic to record failed authentication attempts and block IPs that exceed the threshold
- pkg/protocol/openapi/event.go
  - Added imports for "orly.dev/pkg/utils/iptracker" and "time"
  - Added logic to check if an IP is blocked and return a forbidden error if it is
  - Added logic to record failed authentication attempts and return appropriate errors based on whether the IP is blocked or not
- pkg/utils/iptracker/iptracker.go
  - Created new package with functionality to track and block IPs based on failed authentication attempts
- pkg/version/version
  - Bumped version number from v0.4.3 to v0.4.4
2025-08-02 12:01:37 +01:00
f3e7188816 Bump version to v0.4.3
- pkg/version/version
  - Updated version number from v0.4.2 to v0.4.3
2025-08-02 11:44:05 +01:00
39957c2ebf implement correct handling when saving events to search for owner/peer deletes 2025-08-02 11:43:33 +01:00
4528d44fc7 Add owner check for deletion events
- pkg/protocol/socketapi/handleEvent.go
  - Added variable to track if the delete event's author is an owner
  - Modified checks to allow owners to delete events from other owners
  - Updated error messages for clarity and consistency

- bumped to version v0.4.2
2025-08-01 19:54:40 +01:00
7b19db5806 Remove redundant log statements and imports from spider-fetch.go
- pkg/app/relay/spider-fetch.go
  - Removed import of "orly.dev/pkg/utils/lol"
  - Removed log.I.S(ev) statement
  - Removed log.T.C(...) and log.I.F("saved event: %0x", ev.ID) statements
  - Removed redundant comments and whitespace
2025-08-01 09:35:06 +01:00
14d4417aec Bump version and add spider type configuration
pkg/version/version
- Updated version number from v0.4.0 to v0.4.1

pkg/app/config/config.go
- Added new config field `SpiderType` with default value "directory"

pkg/app/relay/peers.go
- Added check to skip empty addresses before processing peer information

pkg/app/relay/spider.go
- Modified spider fetch logic to conditionally execute based on spider type
- Added support for different kinds of events based on spider type
2025-08-01 08:52:22 +01:00
bdda37732c Remove redundant log statements from addEvent.go
- pkg/app/relay/addEvent.go
  - Removed log.I.S(pubkeys) statement
  - Removed log.I.F("sending to replica %s", a) statement
2025-07-31 22:37:34 +01:00
48 changed files with 2078 additions and 201 deletions

162
cmd/nwcclient/README.md Normal file
View File

@@ -0,0 +1,162 @@
# NWC Client CLI Tool
A command-line interface tool for making calls to Nostr Wallet Connect (NWC) services.
## Overview
This CLI tool allows you to interact with NWC wallet services using the methods defined in the NIP-47 specification. It provides a simple interface for executing wallet operations and displays the JSON response from the wallet service.
## Usage
```
nwcclient <connection URL> <method> [parameters...]
```
### Connection URL
The connection URL should be in the Nostr Wallet Connect format:
```
nostr+walletconnect://<wallet_pubkey>?relay=<relay_url>&secret=<secret>
```
### Supported Methods
The following methods are supported by this CLI tool:
- `get_info` - Get wallet information
- `get_balance` - Get wallet balance
- `get_budget` - Get wallet budget
- `make_invoice` - Create an invoice
- `pay_invoice` - Pay an invoice
- `pay_keysend` - Send a keysend payment
- `lookup_invoice` - Look up an invoice
- `list_transactions` - List transactions
- `sign_message` - Sign a message
### Unsupported Methods
The following methods are defined in the NIP-47 specification but are not directly supported by this CLI tool due to limitations in the underlying nwc package:
- `create_connection` - Create a connection
- `make_hold_invoice` - Create a hold invoice
- `settle_hold_invoice` - Settle a hold invoice
- `cancel_hold_invoice` - Cancel a hold invoice
- `multi_pay_invoice` - Pay multiple invoices
- `multi_pay_keysend` - Send multiple keysend payments
## Method Parameters
### Methods with No Parameters
- `get_info`
- `get_balance`
- `get_budget`
Example:
```
nwcclient <connection URL> get_info
```
### Methods with Parameters
#### make_invoice
```
nwcclient <connection URL> make_invoice <amount> <description> [description_hash] [expiry]
```
- `amount` - Amount in millisatoshis (msats)
- `description` - Invoice description
- `description_hash` (optional) - Hash of the description
- `expiry` (optional) - Expiry time in seconds
Example:
```
nwcclient <connection URL> make_invoice 1000000 "Test invoice" "" 3600
```
#### pay_invoice
```
nwcclient <connection URL> pay_invoice <invoice> [amount]
```
- `invoice` - BOLT11 invoice
- `amount` (optional) - Amount in millisatoshis (msats)
Example:
```
nwcclient <connection URL> pay_invoice lnbc1...
```
#### pay_keysend
```
nwcclient <connection URL> pay_keysend <amount> <pubkey> [preimage]
```
- `amount` - Amount in millisatoshis (msats)
- `pubkey` - Recipient's public key
- `preimage` (optional) - Payment preimage
Example:
```
nwcclient <connection URL> pay_keysend 1000000 03...
```
#### lookup_invoice
```
nwcclient <connection URL> lookup_invoice <payment_hash_or_invoice>
```
- `payment_hash_or_invoice` - Payment hash or BOLT11 invoice
Example:
```
nwcclient <connection URL> lookup_invoice 3d...
```
#### list_transactions
```
nwcclient <connection URL> list_transactions [from <timestamp>] [until <timestamp>] [limit <count>] [offset <count>] [unpaid <true|false>] [type <incoming|outgoing>]
```
Parameters are specified as name-value pairs:
- `from` - Start timestamp
- `until` - End timestamp
- `limit` - Maximum number of transactions to return
- `offset` - Number of transactions to skip
- `unpaid` - Whether to include unpaid transactions
- `type` - Transaction type (incoming or outgoing)
Example:
```
nwcclient <connection URL> list_transactions limit 10 type incoming
```
#### sign_message
```
nwcclient <connection URL> sign_message <message>
```
- `message` - Message to sign
Example:
```
nwcclient <connection URL> sign_message "Hello, world!"
```
## Output
The tool prints the JSON response from the wallet service to stdout. If an error occurs, an error message is printed to stderr.
## Limitations
- The tool only supports methods that have direct client methods in the nwc package.
- Complex parameters like metadata are not supported.
- The tool does not support interactive authentication or authorization.

285
cmd/nwcclient/main.go Normal file
View File

@@ -0,0 +1,285 @@
package main
import (
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"orly.dev/pkg/protocol/nwc"
)
func printUsage() {
fmt.Println("Usage: nwcclient \"<connection URL>\" <method> [parameters...]")
fmt.Println("\nSupported methods:")
fmt.Println(" get_info - Get wallet information")
fmt.Println(" get_balance - Get wallet balance")
fmt.Println(" get_budget - Get wallet budget")
fmt.Println(" make_invoice - Create an invoice (amount, description, [description_hash], [expiry])")
fmt.Println(" pay_invoice - Pay an invoice (invoice, [amount])")
fmt.Println(" pay_keysend - Send a keysend payment (amount, pubkey, [preimage])")
fmt.Println(" lookup_invoice - Look up an invoice (payment_hash or invoice)")
fmt.Println(" list_transactions - List transactions ([from], [until], [limit], [offset], [unpaid], [type])")
fmt.Println(" sign_message - Sign a message (message)")
fmt.Println("\nUnsupported methods (due to limitations in the nwc package):")
fmt.Println(" create_connection - Create a connection")
fmt.Println(" make_hold_invoice - Create a hold invoice")
fmt.Println(" settle_hold_invoice - Settle a hold invoice")
fmt.Println(" cancel_hold_invoice - Cancel a hold invoice")
fmt.Println(" multi_pay_invoice - Pay multiple invoices")
fmt.Println(" multi_pay_keysend - Send multiple keysend payments")
fmt.Println("\nParameters format:")
fmt.Println(" - Positional parameters are used for required fields")
fmt.Println(" - For list_transactions, named parameters are used: 'from', 'until', 'limit', 'offset', 'unpaid', 'type'")
fmt.Println(" Example: nwcclient <url> list_transactions limit 10 type incoming")
os.Exit(1)
}
func main() {
// Check if we have enough arguments
if len(os.Args) < 3 {
printUsage()
}
// Parse connection URL and method
connectionURL := os.Args[1]
methodStr := os.Args[2]
method := nwc.Method(methodStr)
// Parse the wallet connect URL
opts, err := nwc.ParseWalletConnectURL(connectionURL)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing connection URL: %v\n", err)
os.Exit(1)
}
// Create a new NWC client
client, err := nwc.NewNWCClient(opts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating NWC client: %v\n", err)
os.Exit(1)
}
defer client.Close()
// Execute the requested method
var result interface{}
switch method {
case nwc.GetInfo:
result, err = client.GetInfo()
case nwc.GetBalance:
result, err = client.GetBalance()
case nwc.GetBudget:
result, err = client.GetBudget()
case nwc.MakeInvoice:
if len(os.Args) < 5 {
fmt.Fprintf(
os.Stderr,
"Error: make_invoice requires at least amount and description\n",
)
printUsage()
}
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
os.Exit(1)
}
description := os.Args[4]
req := &nwc.MakeInvoiceRequest{
Amount: amount,
Description: description,
}
// Optional parameters
if len(os.Args) > 5 {
req.DescriptionHash = os.Args[5]
}
if len(os.Args) > 6 {
expiry, err := strconv.ParseInt(os.Args[6], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing expiry: %v\n", err)
os.Exit(1)
}
req.Expiry = &expiry
}
result, err = client.MakeInvoice(req)
case nwc.PayInvoice:
if len(os.Args) < 4 {
fmt.Fprintf(os.Stderr, "Error: pay_invoice requires an invoice\n")
printUsage()
}
req := &nwc.PayInvoiceRequest{
Invoice: os.Args[3],
}
// Optional amount parameter
if len(os.Args) > 4 {
amount, err := strconv.ParseInt(os.Args[4], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
os.Exit(1)
}
req.Amount = &amount
}
result, err = client.PayInvoice(req)
case nwc.PayKeysend:
if len(os.Args) < 5 {
fmt.Fprintf(
os.Stderr, "Error: pay_keysend requires amount and pubkey\n",
)
printUsage()
}
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
os.Exit(1)
}
req := &nwc.PayKeysendRequest{
Amount: amount,
Pubkey: os.Args[4],
}
// Optional preimage
if len(os.Args) > 5 {
req.Preimage = os.Args[5]
}
result, err = client.PayKeysend(req)
case nwc.LookupInvoice:
if len(os.Args) < 4 {
fmt.Fprintf(
os.Stderr,
"Error: lookup_invoice requires a payment_hash or invoice\n",
)
printUsage()
}
param := os.Args[3]
req := &nwc.LookupInvoiceRequest{}
// Determine if the parameter is a payment hash or an invoice
if strings.HasPrefix(param, "ln") {
req.Invoice = param
} else {
req.PaymentHash = param
}
result, err = client.LookupInvoice(req)
case nwc.ListTransactions:
req := &nwc.ListTransactionsRequest{}
// Parse optional parameters
paramIndex := 3
for paramIndex < len(os.Args) {
if paramIndex+1 >= len(os.Args) {
break
}
paramName := os.Args[paramIndex]
paramValue := os.Args[paramIndex+1]
switch paramName {
case "from":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing from: %v\n", err)
os.Exit(1)
}
req.From = &val
case "until":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing until: %v\n", err)
os.Exit(1)
}
req.Until = &val
case "limit":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing limit: %v\n", err)
os.Exit(1)
}
req.Limit = &val
case "offset":
val, err := strconv.ParseInt(paramValue, 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing offset: %v\n", err)
os.Exit(1)
}
req.Offset = &val
case "unpaid":
val := paramValue == "true"
req.Unpaid = &val
case "type":
req.Type = &paramValue
default:
fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
os.Exit(1)
}
paramIndex += 2
}
result, err = client.ListTransactions(req)
case nwc.SignMessage:
if len(os.Args) < 4 {
fmt.Fprintf(os.Stderr, "Error: sign_message requires a message\n")
printUsage()
}
req := &nwc.SignMessageRequest{
Message: os.Args[3],
}
result, err = client.SignMessage(req)
case nwc.CreateConnection, nwc.MakeHoldInvoice, nwc.SettleHoldInvoice, nwc.CancelHoldInvoice, nwc.MultiPayInvoice, nwc.MultiPayKeysend:
fmt.Fprintf(
os.Stderr,
"Error: Method %s is not directly supported by the CLI tool.\n",
methodStr,
)
fmt.Fprintf(
os.Stderr,
"This is because these methods don't have exported client methods in the nwc package.\n",
)
fmt.Fprintf(
os.Stderr,
"Only the following methods are currently supported: get_info, get_balance, get_budget, make_invoice, pay_invoice, pay_keysend, lookup_invoice, list_transactions, sign_message\n",
)
os.Exit(1)
default:
fmt.Fprintf(os.Stderr, "Error: Unsupported method: %s\n", methodStr)
printUsage()
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error executing method: %v\n", err)
os.Exit(1)
}
// Print the result as JSON
jsonData, err := json.MarshalIndent(result, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "Error marshaling result to JSON: %v\n", err)
os.Exit(1)
}
fmt.Println(string(jsonData))
}

View File

@@ -37,7 +37,8 @@ type C struct {
Pprof string `env:"ORLY_PPROF" usage:"enable pprof on 127.0.0.1:6060" enum:"cpu,memory,allocation"`
AuthRequired bool `env:"ORLY_AUTH_REQUIRED" default:"false" usage:"require authentication for all requests"`
PublicReadable bool `env:"ORLY_PUBLIC_READABLE" default:"true" usage:"allow public read access to regardless of whether the client is authed"`
SpiderSeeds []string `env:"ORLY_SPIDER_SEEDS" usage:"seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)" default:"wss://relay.nostr.band/,wss://relay.damus.io/,wss://nostr.wine/,wss://nostr.land/,wss://theforest.nostr1.com/"`
SpiderSeeds []string `env:"ORLY_SPIDER_SEEDS" usage:"seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)" default:"wss://profiles.nostr1.com/,wss://relay.nostr.band/,wss://relay.damus.io/,wss://nostr.wine/,wss://nostr.land/,wss://theforest.nostr1.com/"`
SpiderType string `env:"ORLY_SPIDER_TYPE" usage:"whether to spider, and what degree of spidering: none, directory, follows (follows means to the second degree of the follow graph)" default:"directory"`
Owners []string `env:"ORLY_OWNERS" usage:"list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)"`
Private bool `env:"ORLY_PRIVATE" usage:"do not spider for user metadata because the relay is private and this would leak relay memberships" default:"false"`
Whitelist []string `env:"ORLY_WHITELIST" usage:"only allow connections from this list of IP addresses"`

View File

@@ -125,7 +125,6 @@ func (s *Server) AddEvent(
// (they're unpacked from a string containing both, appended at the
// same time), so if the pubkeys from the http event endpoint sent
// us here matches the index of this address, we can skip it.
log.I.S(pubkeys)
for _, pk := range pubkeys {
if bytes.Equal(s.Peers.Pubkeys[i], pk) {
log.I.F(
@@ -135,7 +134,6 @@ func (s *Server) AddEvent(
continue replica
}
}
log.I.F("sending to replica %s", a)
var ur *url.URL
if ur, err = url.Parse(a + "/api/event"); chk.E(err) {
continue

View File

@@ -33,6 +33,9 @@ func (p *Peers) Init(
addresses []string, sec string,
) (err error) {
for _, address := range addresses {
if len(address) == 0 {
continue
}
split := strings.Split(address, "@")
if len(split) != 2 {
log.E.F("invalid peer address: %s", address)
@@ -46,6 +49,9 @@ func (p *Peers) Init(
p.Pubkeys = append(p.Pubkeys, pk)
log.I.F("peer %s added; pubkey: %0x", split[1], pk)
}
if sec == "" {
return
}
p.I = &p256k.Signer{}
var s []byte
if s, err = keys.DecodeNsecOrHex(sec); chk.E(err) {

View File

@@ -102,7 +102,7 @@ func (s *Server) Publish(c context.T, evt *event.E) (err error) {
}
if isFollowed {
if _, _, err = sto.SaveEvent(
c, evt, false,
c, evt, false, nil,
); err != nil && !errors.Is(
err, store.ErrDupEvent,
) {
@@ -124,7 +124,7 @@ func (s *Server) Publish(c context.T, evt *event.E) (err error) {
for _, pk := range owners {
if bytes.Equal(evt.Pubkey, pk) {
if _, _, err = sto.SaveEvent(
c, evt, false,
c, evt, false, nil,
); err != nil && !errors.Is(
err, store.ErrDupEvent,
) {
@@ -236,7 +236,9 @@ func (s *Server) Publish(c context.T, evt *event.E) (err error) {
}
}
}
if _, _, err = sto.SaveEvent(c, evt, false); err != nil && !errors.Is(
if _, _, err = sto.SaveEvent(
c, evt, false, append(s.Peers.Pubkeys, s.ownersPubkeys...),
); err != nil && !errors.Is(
err, store.ErrDupEvent,
) {
return

View File

@@ -1,7 +1,6 @@
package relay
import (
"fmt"
"orly.dev/pkg/crypto/ec/schnorr"
"orly.dev/pkg/database/indexes/types"
"orly.dev/pkg/encoders/event"
@@ -15,7 +14,6 @@ import (
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/errorf"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/lol"
"runtime/debug"
"time"
)
@@ -158,16 +156,13 @@ func (s *Server) SpiderFetch(
err = nil
return
}
// Process each event immediately
for i, ev := range evss {
// log.I.S(ev)
// Create a key based on pubkey and kind for deduplication
pkKindKey := string(ev.Pubkey) + string(ev.Kind.Marshal(nil))
// Check if we already have an event with this pubkey and kind
existing, exists := pkKindMap[pkKindKey]
// If it doesn't exist or the new event is newer, store it and save to database
if !exists || ev.CreatedAtInt64() > existing.Timestamp {
var ser *types.Uint40
@@ -180,28 +175,14 @@ func (s *Server) SpiderFetch(
if valid, err = ev.Verify(); chk.E(err) || !valid {
continue
}
log.I.F("event %0x is valid", ev.ID)
}
// Save the event to the database
if _, _, err = s.Storage().SaveEvent(
s.Ctx, ev, true, // already verified
s.Ctx, ev, true, nil,
); chk.E(err) {
err = nil
continue
}
if lol.Level.Load() == lol.Trace {
log.T.C(
func() string {
return fmt.Sprintf(
"saved event:\n%s", ev.Marshal(nil),
)
},
)
} else {
log.I.F("saved event: %0x", ev.ID)
}
// Store the essential information
pkKindMap[pkKindKey] = &IdPkTs{
Id: ev.ID,
@@ -209,7 +190,6 @@ func (s *Server) SpiderFetch(
Kind: ev.Kind.ToU16(),
Timestamp: ev.CreatedAtInt64(),
}
// Extract p tags if not in noExtract mode
if !noExtract {
t := ev.Tags.GetAll(tag.New("p"))
@@ -227,7 +207,6 @@ func (s *Server) SpiderFetch(
}
}
}
// Nil the event in the slice to free memory
evss[i] = nil
}
@@ -236,17 +215,14 @@ func (s *Server) SpiderFetch(
}
chk.E(s.Storage().Sync())
debug.FreeOSMemory()
// If we're in noExtract mode, just return
if noExtract {
return
}
// Convert the collected pubkeys to the return format
for pk := range pkMap {
pks = append(pks, []byte(pk))
}
log.I.F("found %d pks", len(pks))
return
}

View File

@@ -97,16 +97,18 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
s.SetFollowedFollows(followedFollows)
s.SetOwnersMuted(ownersMuted)
// lastly, update all followed users new events in the background
if !dontFetch {
if !dontFetch && s.C.SpiderType != "none" {
go func() {
var k *kinds.T
if s.C.SpiderType == "directory" {
k = kinds.New(
kind.ProfileMetadata, kind.RelayListMetadata,
kind.DMRelaysList,
)
}
everyone := append(ownersFollowed, followedFollows...)
s.SpiderFetch(
// kinds.New(
// kind.ProfileMetadata, kind.RelayListMetadata,
// kind.DMRelaysList,
// ),
nil,
false, true, everyone...,
_, _ = s.SpiderFetch(
k, false, true, everyone...,
)
}()
}

View File

@@ -5,42 +5,16 @@ import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/hex"
"lukechampine.com/frand"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/errorf"
"strings"
"lukechampine.com/frand"
)
// ComputeSharedSecret returns a shared secret key used to encrypt messages. The private and public keys should be hex
// encoded. Uses the Diffie-Hellman key exchange (ECDH) (RFC 4753).
func ComputeSharedSecret(pkh, skh string) (sharedSecret []byte, err error) {
var skb, pkb []byte
if skb, err = hex.Dec(skh); chk.E(err) {
return
}
if pkb, err = hex.Dec(pkh); chk.E(err) {
return
}
signer := new(p256k.Signer)
if err = signer.InitSec(skb); chk.E(err) {
return
}
if sharedSecret, err = signer.ECDH(pkb); chk.E(err) {
return
}
return
}
// EncryptNip4 encrypts message with key using aes-256-cbc. key should be the shared secret generated by
// ComputeSharedSecret.
//
// Returns: base64(encrypted_bytes) + "?iv=" + base64(initialization_vector).
//
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
func EncryptNip4(msg string, key []byte) (ct []byte, err error) {
func EncryptNip4(msg, key []byte) (ct []byte, err error) {
// block size is 16 bytes
iv := make([]byte, 16)
if _, err = frand.Read(iv); chk.E(err) {
@@ -71,22 +45,20 @@ func EncryptNip4(msg string, key []byte) (ct []byte, err error) {
// DecryptNip4 decrypts a content string using the shared secret key. The inverse operation to message ->
// EncryptNip4(message, key).
//
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
func DecryptNip4(content string, key []byte) (msg []byte, err error) {
parts := strings.Split(content, "?iv=")
func DecryptNip4(content, key []byte) (msg []byte, err error) {
parts := bytes.Split(content, []byte("?iv="))
if len(parts) < 2 {
return nil, errorf.E(
"error parsing encrypted message: no initialization vector",
)
}
var ciphertext []byte
if ciphertext, err = base64.StdEncoding.DecodeString(parts[0]); chk.E(err) {
ciphertext := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0])))
if _, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) {
err = errorf.E("error decoding ciphertext from base64: %w", err)
return
}
var iv []byte
if iv, err = base64.StdEncoding.DecodeString(parts[1]); chk.E(err) {
iv := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1])))
if _, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) {
err = errorf.E("error decoding iv from base64: %w", err)
return
}

View File

@@ -10,7 +10,9 @@ import (
"golang.org/x/crypto/hkdf"
"io"
"math"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/crypto/sha256"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/errorf"
)
@@ -43,11 +45,9 @@ func WithCustomNonce(salt []byte) func(opts *Opts) {
// Encrypt data using a provided symmetric conversation key using NIP-44
// encryption (chacha20 cipher stream and sha256 HMAC).
func Encrypt(
plaintext string, conversationKey []byte,
applyOptions ...func(opts *Opts),
plaintext, conversationKey []byte, applyOptions ...func(opts *Opts),
) (
cipherString string,
err error,
cipherString []byte, err error,
) {
var o Opts
@@ -70,7 +70,7 @@ func Encrypt(
); chk.E(err) {
return
}
plain := []byte(plaintext)
plain := plaintext
size := len(plain)
if size < MinPlaintextSize || size > MaxPlaintextSize {
err = errorf.E("plaintext should be between 1b and 64kB")
@@ -93,14 +93,15 @@ func Encrypt(
ct = append(ct, o.nonce...)
ct = append(ct, cipher...)
ct = append(ct, mac...)
cipherString = base64.StdEncoding.EncodeToString(ct)
cipherString = make([]byte, base64.StdEncoding.EncodedLen(len(ct)))
base64.StdEncoding.Encode(cipherString, ct)
return
}
// Decrypt data that has been encoded using a provided symmetric conversation
// key using NIP-44 encryption (chacha20 cipher stream and sha256 HMAC).
func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (
plaintext string,
func Decrypt(b64ciphertextWrapped, conversationKey []byte) (
plaintext []byte,
err error,
) {
cLen := len(b64ciphertextWrapped)
@@ -108,12 +109,12 @@ func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (
err = errorf.E("invalid payload length: %d", cLen)
return
}
if b64ciphertextWrapped[:1] == "#" {
if len(b64ciphertextWrapped) > 0 && b64ciphertextWrapped[0] == '#' {
err = errorf.E("unknown version")
return
}
var decoded []byte
if decoded, err = base64.StdEncoding.DecodeString(b64ciphertextWrapped); chk.E(err) {
if decoded, err = base64.StdEncoding.DecodeString(string(b64ciphertextWrapped)); chk.E(err) {
return
}
if decoded[0] != version {
@@ -153,7 +154,7 @@ func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (
err = errorf.E("invalid padding")
return
}
plaintext = string(unpadded)
plaintext = unpadded
return
}
@@ -167,8 +168,16 @@ func GenerateConversationKey(pkh, skh string) (ck []byte, err error) {
)
return
}
var sign signer.I
if sign, err = p256k.NewSecFromHex(skh); chk.E(err) {
return
}
var pk []byte
if pk, err = p256k.HexToBin(pkh); chk.E(err) {
return
}
var shared []byte
if shared, err = ComputeSharedSecret(pkh, skh); chk.E(err) {
if shared, err = sign.ECDH(pk); chk.E(err) {
return
}
ck = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))

View File

@@ -19,10 +19,10 @@ func assertCryptPriv(
sk1, sk2, conversationKey, salt, plaintext, expected string,
) {
var (
k1, s []byte
actual, decrypted string
ok bool
err error
k1, s, plaintextBytes, actualBytes,
expectedBytes, decrypted []byte
ok bool
err error
)
k1, err = hex.Dec(conversationKey)
if ok = assert.NoErrorf(
@@ -41,25 +41,27 @@ func assertCryptPriv(
); !ok {
return
}
actual, err = Encrypt(plaintext, k1, WithCustomNonce(s))
plaintextBytes = []byte(plaintext)
actualBytes, err = Encrypt(plaintextBytes, k1, WithCustomNonce(s))
if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok {
return
}
if ok = assert.Equalf(t, expected, actual, "wrong encryption"); !ok {
expectedBytes = []byte(expected)
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
return
}
decrypted, err = Decrypt(expected, k1)
decrypted, err = Decrypt(expectedBytes, k1)
if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok {
return
}
assert.Equal(t, decrypted, plaintext, "wrong decryption")
assert.Equal(t, decrypted, plaintextBytes, "wrong decryption")
}
func assertDecryptFail(
t *testing.T, conversationKey, plaintext, ciphertext, msg string,
) {
var (
k1 []byte
k1, ciphertextBytes []byte
ok bool
err error
)
@@ -69,7 +71,8 @@ func assertDecryptFail(
); !ok {
return
}
_, err = Decrypt(ciphertext, k1)
ciphertextBytes = []byte(ciphertext)
_, err = Decrypt(ciphertextBytes, k1)
assert.ErrorContains(t, err, msg)
}
@@ -196,15 +199,15 @@ func assertMessageKeyGeneration(
}
func assertCryptLong(
t *testing.T, conversationKey, salt, pattern string, repeat int,
t *testing.T, conversationKey, salt string, pattern []byte, repeat int,
plaintextSha256, payloadSha256 string,
) {
var (
convKey, convSalt []byte
plaintext, actualPlaintextSha256, actualPayload, actualPayloadSha256 string
h hash.Hash
ok bool
err error
convKey, convSalt, plaintext, payloadBytes []byte
actualPlaintextSha256, actualPayloadSha256 string
h hash.Hash
ok bool
err error
)
convKey, err = hex.Dec(conversationKey)
if ok = assert.NoErrorf(
@@ -218,12 +221,12 @@ func assertCryptLong(
); !ok {
return
}
plaintext = ""
plaintext = make([]byte, 0, len(pattern)*repeat)
for i := 0; i < repeat; i++ {
plaintext += pattern
plaintext = append(plaintext, pattern...)
}
h = sha256.New()
h.Write([]byte(plaintext))
h.Write(plaintext)
actualPlaintextSha256 = hex.Enc(h.Sum(nil))
if ok = assert.Equalf(
t, plaintextSha256, actualPlaintextSha256,
@@ -231,12 +234,14 @@ func assertCryptLong(
); !ok {
return
}
actualPayload, err = Encrypt(plaintext, convKey, WithCustomNonce(convSalt))
payloadBytes, err = Encrypt(
plaintext, convKey, WithCustomNonce(convSalt),
)
if ok = assert.NoErrorf(t, err, "encryption failed: %v", err); !ok {
return
}
h.Reset()
h.Write([]byte(actualPayload))
h.Write(payloadBytes)
actualPayloadSha256 = hex.Enc(h.Sum(nil))
if ok = assert.Equalf(
t, payloadSha256, actualPayloadSha256,
@@ -383,7 +388,7 @@ func TestCryptLong001(t *testing.T) {
t,
"8fc262099ce0d0bb9b89bac05bb9e04f9bc0090acc181fef6840ccee470371ed",
"326bcb2c943cd6bb717588c9e5a7e738edf6ed14ec5f5344caa6ef56f0b9cff7",
"x",
[]byte("x"),
65535,
"09ab7495d3e61a76f0deb12cb0306f0696cbb17ffc12131368c7a939f12f56d3",
"90714492225faba06310bff2f249ebdc2a5e609d65a629f1c87f2d4ffc55330a",
@@ -395,7 +400,7 @@ func TestCryptLong002(t *testing.T) {
t,
"56adbe3720339363ab9c3b8526ffce9fd77600927488bfc4b59f7a68ffe5eae0",
"ad68da81833c2a8ff609c3d2c0335fd44fe5954f85bb580c6a8d467aa9fc5dd0",
"!",
[]byte("!"),
65535,
"6af297793b72ae092c422e552c3bb3cbc310da274bd1cf9e31023a7fe4a2d75e",
"8013e45a109fad3362133132b460a2d5bce235fe71c8b8f4014793fb52a49844",
@@ -407,7 +412,7 @@ func TestCryptLong003(t *testing.T) {
t,
"7fc540779979e472bb8d12480b443d1e5eb1098eae546ef2390bee499bbf46be",
"34905e82105c20de9a2f6cd385a0d541e6bcc10601d12481ff3a7575dc622033",
"🦄",
[]byte("🦄"),
16383,
"a249558d161b77297bc0cb311dde7d77190f6571b25c7e4429cd19044634a61f",
"b3348422471da1f3c59d79acfe2fe103f3cd24488109e5b18734cdb5953afd15",
@@ -1309,7 +1314,10 @@ func TestMaxLength(t *testing.T) {
rand.Read(salt)
conversationKey, _ := GenerateConversationKey(pub2, string(sk1))
plaintext := strings.Repeat("a", MaxPlaintextSize)
encrypted, err := Encrypt(plaintext, conversationKey, WithCustomNonce(salt))
plaintextBytes := []byte(plaintext)
encrypted, err := Encrypt(
plaintextBytes, conversationKey, WithCustomNonce(salt),
)
if chk.E(err) {
t.Error(err)
}
@@ -1321,7 +1329,7 @@ func TestMaxLength(t *testing.T) {
fmt.Sprintf("%x", conversationKey),
fmt.Sprintf("%x", salt),
plaintext,
encrypted,
string(encrypted),
)
}
@@ -1330,10 +1338,10 @@ func assertCryptPub(
sk1, pub2, conversationKey, salt, plaintext, expected string,
) {
var (
k1, s []byte
actual, decrypted string
ok bool
err error
k1, s, plaintextBytes,
actualBytes, expectedBytes, decrypted []byte
ok bool
err error
)
k1, err = hex.Dec(conversationKey)
if ok = assert.NoErrorf(
@@ -1352,16 +1360,18 @@ func assertCryptPub(
); !ok {
return
}
actual, err = Encrypt(plaintext, k1, WithCustomNonce(s))
plaintextBytes = []byte(plaintext)
actualBytes, err = Encrypt(plaintextBytes, k1, WithCustomNonce(s))
if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok {
return
}
if ok = assert.Equalf(t, expected, actual, "wrong encryption"); !ok {
expectedBytes = []byte(expected)
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
return
}
decrypted, err = Decrypt(expected, k1)
decrypted, err = Decrypt(expectedBytes, k1)
if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok {
return
}
assert.Equal(t, decrypted, plaintext, "wrong decryption")
assert.Equal(t, decrypted, plaintextBytes, "wrong decryption")
}

View File

@@ -18,3 +18,7 @@ type Signer = btcec.Signer
type Keygen = btcec.Keygen
func NewKeygen() (k *Keygen) { return new(Keygen) }
var NewSecFromHex = btcec.NewSecFromHex
var NewPubFromHex = btcec.NewPubFromHex
var HexToBin = btcec.HexToBin

View File

@@ -55,10 +55,20 @@ func (s *Signer) InitPub(pub []byte) (err error) {
}
// Sec returns the raw secret key bytes.
func (s *Signer) Sec() (b []byte) { return s.skb }
func (s *Signer) Sec() (b []byte) {
if s == nil {
return nil
}
return s.skb
}
// Pub returns the raw BIP-340 schnorr public key bytes.
func (s *Signer) Pub() (b []byte) { return s.pkb }
func (s *Signer) Pub() (b []byte) {
if s == nil {
return nil
}
return s.pkb
}
// Sign a message with the Signer. Requires an initialised secret key.
func (s *Signer) Sign(msg []byte) (sig []byte, err error) {

View File

@@ -0,0 +1,40 @@
//go:build !cgo
package btcec
import (
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/utils/chk"
)
func NewSecFromHex[V []byte | string](skh V) (sign signer.I, err error) {
var sk []byte
if _, err = hex.DecBytes(sk, []byte(skh)); chk.E(err) {
return
}
sign = &Signer{}
if err = sign.InitSec(sk); chk.E(err) {
return
}
return
}
func NewPubFromHex[V []byte | string](pkh V) (sign signer.I, err error) {
var sk []byte
if _, err = hex.DecBytes(sk, []byte(pkh)); chk.E(err) {
return
}
sign = &Signer{}
if err = sign.InitPub(sk); chk.E(err) {
return
}
return
}
func HexToBin(hexStr string) (b []byte, err error) {
if _, err = hex.DecBytes(b, []byte(hexStr)); chk.E(err) {
return
}
return
}

View File

@@ -0,0 +1,40 @@
//go:build cgo
package p256k
import (
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/utils/chk"
)
func NewSecFromHex[V []byte | string](skh V) (sign signer.I, err error) {
var sk []byte
if _, err = hex.DecBytes(sk, []byte(skh)); chk.E(err) {
return
}
sign = &Signer{}
if err = sign.InitSec(sk); chk.E(err) {
return
}
return
}
func NewPubFromHex[V []byte | string](pkh V) (sign signer.I, err error) {
var sk []byte
if _, err = hex.DecBytes(sk, []byte(pkh)); chk.E(err) {
return
}
sign = &Signer{}
if err = sign.InitPub(sk); chk.E(err) {
return
}
return
}
func HexToBin(hexStr string) (b []byte, err error) {
if _, err = hex.DecBytes(b, []byte(hexStr)); chk.E(err) {
return
}
return
}

View File

@@ -77,8 +77,18 @@ func (s *Signer) InitPub(pub []byte) (err error) {
return
}
func (s *Signer) Sec() (b []byte) { return s.skb }
func (s *Signer) Pub() (b []byte) { return s.pkb }
func (s *Signer) Sec() (b []byte) {
if s == nil {
return nil
}
return s.skb
}
func (s *Signer) Pub() (b []byte) {
if s == nil {
return nil
}
return s.pkb
}
// func (s *Signer) ECPub() (b []byte) { return s.pkb }

View File

@@ -55,7 +55,7 @@ func TestExport(t *testing.T) {
}
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event: %v", err)
}

View File

@@ -56,7 +56,7 @@ func TestFetchEventBySerial(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -53,7 +53,7 @@ func TestGetSerialById(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -60,7 +60,7 @@ func TestGetSerialsByRange(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -56,7 +56,7 @@ func (d *D) Import(rr io.Reader) {
continue
}
if _, _, err = d.SaveEvent(d.ctx, ev, false); err != nil {
if _, _, err = d.SaveEvent(d.ctx, ev, false, nil); err != nil {
continue
}

View File

@@ -45,7 +45,7 @@ func TestMultipleParameterizedReplaceableEvents(t *testing.T) {
baseEvent.Sign(sign)
// Save the base parameterized replaceable event
if _, _, err := db.SaveEvent(ctx, baseEvent, false); err != nil {
if _, _, err := db.SaveEvent(ctx, baseEvent, false, nil); err != nil {
t.Fatalf("Failed to save base parameterized replaceable event: %v", err)
}
@@ -63,7 +63,7 @@ func TestMultipleParameterizedReplaceableEvents(t *testing.T) {
newerEvent.Sign(sign)
// Save the newer parameterized replaceable event
if _, _, err := db.SaveEvent(ctx, newerEvent, false); err != nil {
if _, _, err := db.SaveEvent(ctx, newerEvent, false, nil); err != nil {
t.Fatalf(
"Failed to save newer parameterized replaceable event: %v", err,
)
@@ -83,7 +83,7 @@ func TestMultipleParameterizedReplaceableEvents(t *testing.T) {
newestEvent.Sign(sign)
// Save the newest parameterized replaceable event
if _, _, err := db.SaveEvent(ctx, newestEvent, false); err != nil {
if _, _, err := db.SaveEvent(ctx, newestEvent, false, nil); err != nil {
t.Fatalf(
"Failed to save newest parameterized replaceable event: %v", err,
)

View File

@@ -16,11 +16,6 @@ import (
"strconv"
)
// QueryEvents retrieves events based on the provided filter. If the filter
// contains Ids, it fetches events by those Ids directly, overriding other
// filter criteria. Otherwise, it queries by other filter criteria and fetches
// matching events. Results are returned in reverse chronological order of their
// creation timestamps.
func (d *D) QueryEvents(c context.T, f *filter.F) (evs event.S, err error) {
// if there is Ids in the query, this overrides anything else
if f.Ids != nil && f.Ids.Len() > 0 {

View File

@@ -62,7 +62,7 @@ func setupTestDB(t *testing.T) (
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}
@@ -202,7 +202,9 @@ func TestReplaceableEventsAndDeletion(t *testing.T) {
replaceableEvent.Tags = tags.New()
replaceableEvent.Sign(sign)
// Save the replaceable event
if _, _, err := db.SaveEvent(ctx, replaceableEvent, false); err != nil {
if _, _, err := db.SaveEvent(
ctx, replaceableEvent, false, nil,
); err != nil {
t.Fatalf("Failed to save replaceable event: %v", err)
}
@@ -216,7 +218,7 @@ func TestReplaceableEventsAndDeletion(t *testing.T) {
newerEvent.Tags = tags.New()
newerEvent.Sign(sign)
// Save the newer event
if _, _, err := db.SaveEvent(ctx, newerEvent, false); err != nil {
if _, _, err := db.SaveEvent(ctx, newerEvent, false, nil); err != nil {
t.Fatalf("Failed to save newer event: %v", err)
}
@@ -293,7 +295,7 @@ func TestReplaceableEventsAndDeletion(t *testing.T) {
)
// Save the deletion event
if _, _, err = db.SaveEvent(ctx, deletionEvent, false); err != nil {
if _, _, err = db.SaveEvent(ctx, deletionEvent, false, nil); err != nil {
t.Fatalf("Failed to save deletion event: %v", err)
}
@@ -379,7 +381,7 @@ func TestParameterizedReplaceableEventsAndDeletion(t *testing.T) {
paramEvent.Sign(sign)
// Save the parameterized replaceable event
if _, _, err := db.SaveEvent(ctx, paramEvent, false); err != nil {
if _, _, err := db.SaveEvent(ctx, paramEvent, false, nil); err != nil {
t.Fatalf("Failed to save parameterized replaceable event: %v", err)
}
@@ -405,7 +407,9 @@ func TestParameterizedReplaceableEventsAndDeletion(t *testing.T) {
paramDeletionEvent.Sign(sign)
// Save the parameterized deletion event
if _, _, err := db.SaveEvent(ctx, paramDeletionEvent, false); err != nil {
if _, _, err := db.SaveEvent(
ctx, paramDeletionEvent, false, nil,
); err != nil {
t.Fatalf("Failed to save parameterized deletion event: %v", err)
}
@@ -438,7 +442,9 @@ func TestParameterizedReplaceableEventsAndDeletion(t *testing.T) {
paramDeletionEvent2.Sign(sign)
// Save the parameterized deletion event with e-tag
if _, _, err := db.SaveEvent(ctx, paramDeletionEvent2, false); err != nil {
if _, _, err := db.SaveEvent(
ctx, paramDeletionEvent2, false, nil,
); err != nil {
t.Fatalf(
"Failed to save parameterized deletion event with e-tag: %v", err,
)

View File

@@ -57,7 +57,7 @@ func TestQueryForAuthorsTags(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -56,7 +56,7 @@ func TestQueryForCreatedAt(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -60,7 +60,7 @@ func TestQueryForIds(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -58,7 +58,7 @@ func TestQueryForKindsAuthorsTags(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -58,7 +58,7 @@ func TestQueryForKindsAuthors(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -58,7 +58,7 @@ func TestQueryForKindsTags(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -57,7 +57,7 @@ func TestQueryForKinds(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -60,7 +60,7 @@ func TestQueryForSerials(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -57,7 +57,7 @@ func TestQueryForTags(t *testing.T) {
events = append(events, ev)
// Save the event to the database
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}

View File

@@ -20,9 +20,9 @@ import (
)
// SaveEvent saves an event to the database, generating all the necessary indexes.
func (d *D) SaveEvent(c context.T, ev *event.E, noVerify bool) (
kc, vc int, err error,
) {
func (d *D) SaveEvent(
c context.T, ev *event.E, noVerify bool, owners [][]byte,
) (kc, vc int, err error) {
if !noVerify {
// check if the event already exists
var ser *types.Uint40
@@ -94,9 +94,13 @@ func (d *D) SaveEvent(c context.T, ev *event.E, noVerify bool) (
}
} else {
var idxs []Range
keys := [][]byte{ev.Pubkey}
for _, owner := range owners {
keys = append(keys, owner)
}
if idxs, err = GetIndexesFromFilter(
&filter.F{
Authors: tag.New(ev.Pubkey),
Authors: tag.New(keys...),
Kinds: kinds.New(kind.Deletion),
Tags: tags.New(tag.New([]byte("#e"), ev.ID)),
},
@@ -115,7 +119,7 @@ func (d *D) SaveEvent(c context.T, ev *event.E, noVerify bool) (
// really there can only be one of these; the chances of an idhash
// collision are basically zero in practice, at least, one in a
// billion or more anyway, more than a human is going to create.
err = errorf.E("blocked: %0x was deleted by event ID", ev.ID)
err = errorf.E("blocked: event %0x deleted by event ID", ev.ID)
return
}
}

View File

@@ -64,7 +64,7 @@ func TestSaveEvents(t *testing.T) {
// Save the event to the database
var k, v int
if k, v, err = db.SaveEvent(ctx, ev, false); err != nil {
if k, v, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
}
kc += k
@@ -125,7 +125,7 @@ func TestDeletionEventWithETagRejection(t *testing.T) {
regularEvent.Sign(sign)
// Save the regular event
if _, _, err := db.SaveEvent(ctx, regularEvent, false); err != nil {
if _, _, err := db.SaveEvent(ctx, regularEvent, false, nil); err != nil {
t.Fatalf("Failed to save regular event: %v", err)
}
@@ -146,7 +146,7 @@ func TestDeletionEventWithETagRejection(t *testing.T) {
deletionEvent.Sign(sign)
// Try to save the deletion event, it should be rejected
_, _, err = db.SaveEvent(ctx, deletionEvent, false)
_, _, err = db.SaveEvent(ctx, deletionEvent, false, nil)
if err == nil {
t.Fatal("Expected deletion event with e-tag to be rejected, but it was accepted")
}
@@ -198,12 +198,12 @@ func TestSaveExistingEvent(t *testing.T) {
ev.Sign(sign)
// Save the event for the first time
if _, _, err := db.SaveEvent(ctx, ev, false); err != nil {
if _, _, err := db.SaveEvent(ctx, ev, false, nil); err != nil {
t.Fatalf("Failed to save event: %v", err)
}
// Try to save the same event again, it should be rejected
_, _, err = db.SaveEvent(ctx, ev, false)
_, _, err = db.SaveEvent(ctx, ev, false, nil)
if err == nil {
t.Fatal("Expected error when saving an existing event, but got nil")
}

View File

@@ -82,8 +82,12 @@ type Deleter interface {
}
type Saver interface {
// SaveEvent is called once relay.AcceptEvent reports true.
SaveEvent(c context.T, ev *event.E, noVerify bool) (kc, vc int, err error)
// SaveEvent is called once relay.AcceptEvent reports true. The owners
// parameter is for designating admins whose delete by e tag events apply
// the same as author's own.
SaveEvent(
c context.T, ev *event.E, noVerify bool, owners [][]byte,
) (kc, vc int, err error)
}
type Importer interface {

View File

@@ -46,7 +46,7 @@ func TestQuery(t *testing.T) {
var err error
var pp *pointers.Profile
acct := "fiatjaf.com"
if pp, err = QueryIdentifier(context.Background(), acct); chk.E(err) {
if pp, err = QueryIdentifier(context.Bg(), acct); chk.E(err) {
t.Fatal(err)
}
if pkb, err = keys.HexPubkeyToBytes(
@@ -58,7 +58,7 @@ func TestQuery(t *testing.T) {
t.Fatalf("invalid query for fiatjaf.com")
}
pp, err = QueryIdentifier(context.Background(), "htlc@fiatjaf.com")
pp, err = QueryIdentifier(context.Bg(), "htlc@fiatjaf.com")
if pkb, err = keys.HexPubkeyToBytes(
"f9dd6a762506260b38a2d3e5b464213c2e47fa3877429fe9ee60e071a31a07d7",
); chk.E(err) {

614
pkg/protocol/nwc/client.go Normal file
View File

@@ -0,0 +1,614 @@
package nwc
import (
"encoding/json"
"fmt"
"net/url"
"orly.dev/pkg/crypto/encryption"
"orly.dev/pkg/crypto/p256k"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/filters"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/encoders/tags"
"orly.dev/pkg/encoders/timestamp"
"orly.dev/pkg/interfaces/signer"
"orly.dev/pkg/protocol/ws"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/log"
"strings"
"sync"
"time"
)
// Options represents options for a NWC client
type Options struct {
RelayURL string
Secret signer.I
WalletPubkey []byte
Lud16 string
}
// Client represents a NWC client
type Client struct {
options Options
relay *ws.Client
mu sync.Mutex
}
// ParseWalletConnectURL parses a wallet connect URL
func ParseWalletConnectURL(walletConnectURL string) (opts *Options, err error) {
if !strings.HasPrefix(walletConnectURL, "nostr+walletconnect://") {
return nil, fmt.Errorf("unexpected scheme. Should be nostr+walletconnect://")
}
// Parse URL
colonIndex := strings.Index(walletConnectURL, ":")
if colonIndex == -1 {
err = fmt.Errorf("invalid URL format")
return
}
walletConnectURL = walletConnectURL[colonIndex+1:]
if strings.HasPrefix(walletConnectURL, "//") {
walletConnectURL = walletConnectURL[2:]
}
walletConnectURL = "https://" + walletConnectURL
var u *url.URL
if u, err = url.Parse(walletConnectURL); chk.E(err) {
err = fmt.Errorf("failed to parse URL: %w", err)
return
}
// Get wallet pubkey
walletPubkey := u.Host
if len(walletPubkey) != 64 {
err = fmt.Errorf("incorrect wallet pubkey found in auth string")
return
}
var pk []byte
if pk, err = hex.Dec(walletPubkey); chk.E(err) {
err = fmt.Errorf("failed to decode pubkey: %w", err)
return
}
// Get relay URL
relayURL := u.Query().Get("relay")
if relayURL == "" {
return nil, fmt.Errorf("no relay URL found in auth string")
}
// Get secret
secret := u.Query().Get("secret")
if secret == "" {
return nil, fmt.Errorf("no secret found in auth string")
}
var sk []byte
if sk, err = hex.Dec(secret); chk.E(err) {
return
}
sign := &p256k.Signer{}
if err = sign.InitSec(sk); chk.E(err) {
return
}
opts = &Options{
RelayURL: relayURL,
Secret: sign,
WalletPubkey: pk,
}
return
}
// NewNWCClient creates a new NWC client
func NewNWCClient(options *Options) (cl *Client, err error) {
if options.RelayURL == "" {
err = fmt.Errorf("missing relay URL")
return
}
if options.Secret == nil {
err = fmt.Errorf("missing secret")
return
}
if options.WalletPubkey == nil {
err = fmt.Errorf("missing wallet pubkey")
return
}
return &Client{
options: Options{
RelayURL: options.RelayURL,
Secret: options.Secret,
WalletPubkey: options.WalletPubkey,
Lud16: options.Lud16,
},
}, nil
}
// NostrWalletConnectURL returns the nostr wallet connect URL
func (c *Client) NostrWalletConnectURL() string {
return c.GetNostrWalletConnectURL(true)
}
// GetNostrWalletConnectURL returns the nostr wallet connect URL
func (c *Client) GetNostrWalletConnectURL(includeSecret bool) string {
params := url.Values{}
params.Add("relay", c.options.RelayURL)
if includeSecret {
params.Add("secret", hex.Enc(c.options.Secret.Sec()))
}
return fmt.Sprintf(
"nostr+walletconnect://%s?%s", c.options.WalletPubkey, params.Encode(),
)
}
// Connected returns whether the client is connected to the relay
func (c *Client) Connected() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.relay != nil && c.relay.IsConnected()
}
// GetPublicKey returns the client's public key
func (c *Client) GetPublicKey() (pubkey []byte, err error) {
pubkey = c.options.Secret.Pub()
return
}
// Close closes the relay connection
func (c *Client) Close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.relay != nil {
c.relay.Close()
c.relay = nil
}
}
// Encrypt encrypts content for a pubkey
func (c *Client) encrypt(pubkey, content []byte) (
cipherText []byte, err error,
) {
var sharedSecret []byte
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
return
}
cipherText, err = encryption.EncryptNip4(content, sharedSecret)
return
}
// Decrypt decrypts content from a pubkey
func (c *Client) decrypt(pubkey, content []byte) (plaintext []byte, err error) {
var sharedSecret []byte
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
return
}
plaintext, err = encryption.DecryptNip4(content, sharedSecret)
return
}
// GetInfo gets wallet info
func (c *Client) GetInfo() (response *GetInfoResponse, err error) {
var result []byte
if result, err = c.executeRequest(GetInfo, nil); chk.E(err) {
return
}
response = &GetInfoResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// GetBudget gets wallet budget
func (c *Client) GetBudget() (response *GetBudgetResponse, err error) {
var result []byte
result, err = c.executeRequest(GetBudget, nil)
if err != nil {
return nil, err
}
response = &GetBudgetResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// GetBalance gets wallet balance
func (c *Client) GetBalance() (response *GetBalanceResponse, err error) {
var result []byte
if result, err = c.executeRequest(GetBalance, nil); chk.E(err) {
return
}
response = &GetBalanceResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// PayInvoice pays an invoice
func (c *Client) PayInvoice(request *PayInvoiceRequest) (
response *PayResponse, err error,
) {
var result []byte
result, err = c.executeRequest(PayInvoice, request)
if err != nil {
return nil, err
}
response = &PayResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// PayKeysend sends a keysend payment
func (c *Client) PayKeysend(request *PayKeysendRequest) (
response *PayResponse, err error,
) {
var result []byte
if result, err = c.executeRequest(PayKeysend, request); chk.E(err) {
return
}
response = &PayResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// MakeInvoice creates an invoice
func (c *Client) MakeInvoice(request *MakeInvoiceRequest) (
response *Transaction, err error,
) {
var result []byte
if result, err = c.executeRequest(MakeInvoice, request); chk.E(err) {
return
}
response = &Transaction{}
if err = json.Unmarshal(result, response); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return
}
// LookupInvoice looks up an invoice
func (c *Client) LookupInvoice(request *LookupInvoiceRequest) (
response *Transaction, err error,
) {
var result []byte
if result, err = c.executeRequest(LookupInvoice, request); chk.E(err) {
return
}
response = &Transaction{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// ListTransactions lists transactions
func (c *Client) ListTransactions(request *ListTransactionsRequest) (
response *ListTransactionsResponse, err error,
) {
var result []byte
if result, err = c.executeRequest(ListTransactions, request); chk.E(err) {
return
}
response = &ListTransactionsResponse{}
if err = json.Unmarshal(result, response); chk.E(err) {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// SignMessage signs a message
func (c *Client) SignMessage(request *SignMessageRequest) (
response *SignMessageResponse, err error,
) {
var result []byte
if result, err = c.executeRequest(SignMessage, request); chk.E(err) {
return
}
response = &SignMessageResponse{}
if err = json.Unmarshal(result, response); err != nil {
err = fmt.Errorf("failed to unmarshal response: %w", err)
return
}
return
}
// NotificationHandler is a function that handles notifications
type NotificationHandler func(*Notification)
// SubscribeNotifications subscribes to notifications
func (c *Client) SubscribeNotifications(
handler NotificationHandler,
notificationTypes []NotificationType,
) (stop func(), err error) {
if handler == nil {
err = fmt.Errorf("missing notification handler")
return
}
ctx, cancel := context.Cancel(context.Bg())
doneCh := make(chan struct{})
stop = func() {
cancel()
<-doneCh
}
go func() {
defer close(doneCh)
for {
select {
case <-ctx.Done():
return
default:
// Check connection
if err := c.checkConnected(); err != nil {
time.Sleep(1 * time.Second)
continue
}
// Get client pubkey
var clientPubkey []byte
if clientPubkey, err = c.GetPublicKey(); chk.E(err) {
time.Sleep(1 * time.Second)
continue
}
// Subscribe to events
f := &filter.F{
Kinds: kinds.New(kind.WalletResponse),
Authors: tag.New(c.options.WalletPubkey),
Tags: tags.New(tag.New([]byte("#p"), clientPubkey)),
}
var sub *ws.Subscription
if sub, err = c.relay.Subscribe(
context.Bg(), &filters.T{
F: []*filter.F{f},
},
); chk.E(err) {
time.Sleep(1 * time.Second)
continue
}
// Handle events
for {
select {
case <-ctx.Done():
sub.Close()
return
case ev := <-sub.Events:
// Decrypt content
var decryptedContent []byte
if decryptedContent, err = c.decrypt(
c.options.WalletPubkey, ev.Content,
); chk.E(err) {
log.E.F(
"Failed to decrypt event content: %v\n", err,
)
continue
}
// Parse notification
notification := &Notification{}
if err = json.Unmarshal(
decryptedContent, notification,
); chk.E(err) {
log.E.F(
"Failed to parse notification: %v\n", err,
)
continue
}
// Check if notification type is requested
if len(notificationTypes) > 0 {
found := false
for _, t := range notificationTypes {
if notification.NotificationType == t {
found = true
break
}
}
if !found {
continue
}
}
// Handle notification
handler(notification)
case <-sub.EndOfStoredEvents:
// Ignore
}
}
}
}
}()
return
}
// executeRequest executes a NIP-47 request
func (c *Client) executeRequest(
method Method,
params any,
) (msg json.RawMessage, err error) {
// Default timeout values
replyTimeout := 3 * time.Second
publishTimeout := 3 * time.Second
// Create context with timeout
ctx, cancel := context.Timeout(context.Bg(), replyTimeout)
defer cancel()
// Create result channel
resultCh := make(chan json.RawMessage, 1)
errCh := make(chan error, 1)
// Check connection
if err = c.checkConnected(); err != nil {
return nil, err
}
// Create request
request := struct {
Method Method `json:"method"`
Params any `json:"params,omitempty"`
}{
Method: method,
Params: params,
}
// Marshal request
requestJSON, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Encrypt request
var encryptedContent []byte
if encryptedContent, err = c.encrypt(
c.options.WalletPubkey, requestJSON,
); chk.E(err) {
return nil, fmt.Errorf("failed to encrypt request: %w", err)
}
// Create request event
requestEvent := &event.E{
Kind: kind.WalletRequest,
CreatedAt: timestamp.New(time.Now().Unix()),
Tags: tags.New(tag.New("p", hex.Enc(c.options.WalletPubkey))),
Content: encryptedContent,
}
// Sign request event
err = requestEvent.Sign(c.options.Secret)
if err != nil {
return nil, fmt.Errorf("failed to sign request event: %w", err)
}
// Subscribe to response events
f := &filter.F{
Kinds: kinds.New(kind.WalletResponse),
Authors: tag.New(c.options.WalletPubkey),
Tags: tags.New(tag.New([]byte("#p"), requestEvent.ID)),
}
log.I.F("%s", f.Marshal(nil))
var sub *ws.Subscription
if sub, err = c.relay.Subscribe(
ctx, &filters.T{
F: []*filter.F{f},
},
); chk.E(err) {
err = fmt.Errorf(
"failed to subscribe to response events: %w", err,
)
return
}
defer sub.Close()
// Set up reply timeout
replyTimer := time.AfterFunc(
replyTimeout, func() {
errCh <- NewReplyTimeoutError(
fmt.Sprintf("Timeout waiting for reply to %s", method),
"TIMEOUT",
)
},
)
defer replyTimer.Stop()
// Handle response events
go func() {
var resErr error
for {
select {
case <-ctx.Done():
return
case ev := <-sub.Events:
// Decrypt content
var decryptedContent []byte
decryptedContent, resErr = c.decrypt(
c.options.WalletPubkey, ev.Content,
)
if chk.E(resErr) {
errCh <- fmt.Errorf(
"failed to decrypt response: %w",
resErr,
)
return
}
// Parse response
var response struct {
ResultType string `json:"result_type"`
Result json.RawMessage `json:"result"`
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
if resErr = json.Unmarshal(
decryptedContent, &response,
); chk.E(resErr) {
errCh <- fmt.Errorf("failed to parse response: %w", resErr)
return
}
// Check for error
if response.Error != nil {
errCh <- NewWalletError(
response.Error.Message,
response.Error.Code,
)
return
}
// Send result
resultCh <- response.Result
return
case <-sub.EndOfStoredEvents:
// Ignore
}
}
}()
// Publish request event
publishCtx, publishCancel := context.Timeout(
context.Bg(), publishTimeout,
)
defer publishCancel()
if err = c.relay.Publish(publishCtx, requestEvent); chk.E(err) {
err = fmt.Errorf("failed to publish request event: %w", err)
return
}
// Wait for result or error
select {
case msg = <-resultCh:
return
case err = <-errCh:
return
case <-ctx.Done():
err = NewReplyTimeoutError(
fmt.Sprintf("Timeout waiting for reply to %s", method),
"TIMEOUT",
)
return
}
}
// checkConnected checks if the client is connected to the relay
func (c *Client) checkConnected() (err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.options.RelayURL == "" {
return fmt.Errorf("missing relay URL")
}
if c.relay == nil {
if c.relay, err = ws.RelayConnect(
context.Bg(), c.options.RelayURL,
); chk.E(err) {
return NewNetworkError(
"Failed to connect to "+c.options.RelayURL,
"OTHER",
)
}
} else if !c.relay.IsConnected() {
c.relay.Close()
if c.relay, err = ws.RelayConnect(
context.Bg(), c.options.RelayURL,
); chk.E(err) {
return NewNetworkError(
"Failed to connect to "+c.options.RelayURL,
"OTHER",
)
}
}
return nil
}

473
pkg/protocol/nwc/types.go Normal file
View File

@@ -0,0 +1,473 @@
package nwc
import (
"fmt"
"time"
)
// EncryptionType represents the encryption type used for NIP-47 messages
type EncryptionType string
const (
Nip04 EncryptionType = "nip04"
Nip44V2 EncryptionType = "nip44_v2"
)
// AuthorizationUrlOptions represents options for creating an NWC authorization URL
type AuthorizationUrlOptions struct {
Name string `json:"name,omitempty"`
Icon string `json:"icon,omitempty"`
RequestMethods []Method `json:"requestMethods,omitempty"`
NotificationTypes []NotificationType `json:"notificationTypes,omitempty"`
ReturnTo string `json:"returnTo,omitempty"`
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
MaxAmount *int64 `json:"maxAmount,omitempty"`
BudgetRenewal BudgetRenewalPeriod `json:"budgetRenewal,omitempty"`
Isolated bool `json:"isolated,omitempty"`
Metadata interface{} `json:"metadata,omitempty"`
}
// Err is the base error type for NIP-47 errors
type Err struct {
Message string
Code string
}
func (e *Err) Error() string {
return fmt.Sprintf("%s (code: %s)", e.Message, e.Code)
}
// NewError creates a new Error
func NewError(message, code string) *Err {
return &Err{
Message: message,
Code: code,
}
}
// NetworkError represents a network error in NIP-47 operations
type NetworkError struct{ *Err }
// NewNetworkError creates a new NetworkError
func NewNetworkError(message, code string) *NetworkError {
return &NetworkError{
Err: NewError(message, code),
}
}
// WalletError represents a wallet error in NIP-47 operations
type WalletError struct {
*Err
}
// NewWalletError creates a new WalletError
func NewWalletError(message, code string) *WalletError {
return &WalletError{
Err: NewError(message, code),
}
}
// TimeoutError represents a timeout error in NIP-47 operations
type TimeoutError struct{ *Err }
// NewTimeoutError creates a new TimeoutError
func NewTimeoutError(message, code string) *TimeoutError {
return &TimeoutError{
Err: NewError(message, code),
}
}
// PublishTimeoutError represents a publish timeout error in NIP-47 operations
type PublishTimeoutError struct{ *TimeoutError }
// NewPublishTimeoutError creates a new PublishTimeoutError
func NewPublishTimeoutError(message, code string) *PublishTimeoutError {
return &PublishTimeoutError{
TimeoutError: NewTimeoutError(message, code),
}
}
// ReplyTimeoutError represents a reply timeout error in NIP-47 operations
type ReplyTimeoutError struct{ *TimeoutError }
// NewReplyTimeoutError creates a new ReplyTimeoutError
func NewReplyTimeoutError(message, code string) *ReplyTimeoutError {
return &ReplyTimeoutError{
TimeoutError: NewTimeoutError(message, code),
}
}
// PublishError represents a publish error in NIP-47 operations
type PublishError struct{ *Err }
// NewPublishError creates a new PublishError
func NewPublishError(message, code string) *PublishError {
return &PublishError{
Err: NewError(message, code),
}
}
// ResponseDecodingError represents a response decoding error in NIP-47 operations
type ResponseDecodingError struct{ *Err }
// NewResponseDecodingError creates a new ResponseDecodingError
func NewResponseDecodingError(message, code string) *ResponseDecodingError {
return &ResponseDecodingError{
Err: NewError(message, code),
}
}
// ResponseValidationError represents a response validation error in NIP-47 operations
type ResponseValidationError struct{ *Err }
// NewResponseValidationError creates a new ResponseValidationError
func NewResponseValidationError(message, code string) *ResponseValidationError {
return &ResponseValidationError{
Err: NewError(message, code),
}
}
// UnexpectedResponseError represents an unexpected response error in NIP-47 operations
type UnexpectedResponseError struct{ *Err }
// NewUnexpectedResponseError creates a new UnexpectedResponseError
func NewUnexpectedResponseError(message, code string) *UnexpectedResponseError {
return &UnexpectedResponseError{
Err: NewError(message, code),
}
}
// UnsupportedEncryptionError represents an unsupported encryption error in NIP-47 operations
type UnsupportedEncryptionError struct {
*Err
}
// NewUnsupportedEncryptionError creates a new UnsupportedEncryptionError
func NewUnsupportedEncryptionError(message, code string) *UnsupportedEncryptionError {
return &UnsupportedEncryptionError{
Err: NewError(message, code),
}
}
// WithDTag represents a type with a dTag field
type WithDTag struct {
DTag string `json:"dTag"`
}
// WithOptionalId represents a type with an optional id field
type WithOptionalId struct {
ID string `json:"id,omitempty"`
}
// Method represents a NIP-47 method
type Method string
// SingleMethod represents a single NIP-47 method
const (
GetInfo Method = "get_info"
GetBalance Method = "get_balance"
GetBudget Method = "get_budget"
MakeInvoice Method = "make_invoice"
PayInvoice Method = "pay_invoice"
PayKeysend Method = "pay_keysend"
LookupInvoice Method = "lookup_invoice"
ListTransactions Method = "list_transactions"
SignMessage Method = "sign_message"
CreateConnection Method = "create_connection"
MakeHoldInvoice Method = "make_hold_invoice"
SettleHoldInvoice Method = "settle_hold_invoice"
CancelHoldInvoice Method = "cancel_hold_invoice"
)
// MultiMethod represents a multi NIP-47 method
const (
MultiPayInvoice Method = "multi_pay_invoice"
MultiPayKeysend Method = "multi_pay_keysend"
)
// Capability represents a NIP-47 capability
type Capability string
const (
Notifications Capability = "notifications"
)
// BudgetRenewalPeriod represents a budget renewal period
type BudgetRenewalPeriod string
const (
Daily BudgetRenewalPeriod = "daily"
Weekly BudgetRenewalPeriod = "weekly"
Monthly BudgetRenewalPeriod = "monthly"
Yearly BudgetRenewalPeriod = "yearly"
Never BudgetRenewalPeriod = "never"
)
// GetInfoResponse represents a response to a get_info request
type GetInfoResponse struct {
Alias string `json:"alias"`
Color string `json:"color"`
Pubkey string `json:"pubkey"`
Network string `json:"network"`
BlockHeight int64 `json:"block_height"`
BlockHash string `json:"block_hash"`
Methods []Method `json:"methods"`
Notifications []NotificationType `json:"notifications,omitempty"`
Metadata interface{} `json:"metadata,omitempty"`
Lud16 string `json:"lud16,omitempty"`
}
// GetBudgetResponse represents a response to a get_budget request
type GetBudgetResponse struct {
UsedBudget int64 `json:"used_budget,omitempty"`
TotalBudget int64 `json:"total_budget,omitempty"`
RenewsAt *int64 `json:"renews_at,omitempty"`
RenewalPeriod BudgetRenewalPeriod `json:"renewal_period,omitempty"`
}
// GetBalanceResponse represents a response to a get_balance request
type GetBalanceResponse struct {
Balance int64 `json:"balance"` // msats
}
// PayResponse represents a response to a pay request
type PayResponse struct {
Preimage string `json:"preimage"`
FeesPaid int64 `json:"fees_paid"`
}
// MultiPayInvoiceRequest represents a request to pay multiple invoices
type MultiPayInvoiceRequest struct {
Invoices []PayInvoiceRequestWithID `json:"invoices"`
}
// PayInvoiceRequestWithID combines PayInvoiceRequest with WithOptionalId
type PayInvoiceRequestWithID struct {
PayInvoiceRequest
WithOptionalId
}
// MultiPayKeysendRequest represents a request to pay multiple keysends
type MultiPayKeysendRequest struct {
Keysends []PayKeysendRequestWithID `json:"keysends"`
}
// PayKeysendRequestWithID combines PayKeysendRequest with WithOptionalId
type PayKeysendRequestWithID struct {
PayKeysendRequest
WithOptionalId
}
// MultiPayInvoiceResponse represents a response to a multi_pay_invoice request
type MultiPayInvoiceResponse struct {
Invoices []MultiPayInvoiceResponseItem `json:"invoices"`
Errors []interface{} `json:"errors"` // TODO: add error handling
}
// MultiPayInvoiceResponseItem represents an item in a multi_pay_invoice response
type MultiPayInvoiceResponseItem struct {
Invoice PayInvoiceRequest `json:"invoice"`
PayResponse
WithDTag
}
// MultiPayKeysendResponse represents a response to a multi_pay_keysend request
type MultiPayKeysendResponse struct {
Keysends []MultiPayKeysendResponseItem `json:"keysends"`
Errors []interface{} `json:"errors"` // TODO: add error handling
}
// MultiPayKeysendResponseItem represents an item in a multi_pay_keysend response
type MultiPayKeysendResponseItem struct {
Keysend PayKeysendRequest `json:"keysend"`
PayResponse
WithDTag
}
// ListTransactionsRequest represents a request to list transactions
type ListTransactionsRequest struct {
From *int64 `json:"from,omitempty"`
Until *int64 `json:"until,omitempty"`
Limit *int64 `json:"limit,omitempty"`
Offset *int64 `json:"offset,omitempty"`
Unpaid *bool `json:"unpaid,omitempty"`
UnpaidOutgoing *bool `json:"unpaid_outgoing,omitempty"` // NOTE: non-NIP-47 spec compliant
UnpaidIncoming *bool `json:"unpaid_incoming,omitempty"` // NOTE: non-NIP-47 spec compliant
Type *string `json:"type,omitempty"` // "incoming" or "outgoing"
}
// ListTransactionsResponse represents a response to a list_transactions request
type ListTransactionsResponse struct {
Transactions []Transaction `json:"transactions"`
TotalCount int64 `json:"total_count"` // NOTE: non-NIP-47 spec compliant
}
// TransactionType represents the type of a transaction
type TransactionType string
const (
Incoming TransactionType = "incoming"
Outgoing TransactionType = "outgoing"
)
// TransactionState represents the state of a transaction
type TransactionState string
const (
Settled TransactionState = "settled"
Pending TransactionState = "pending"
Failed TransactionState = "failed"
)
// Transaction represents a transaction
type Transaction struct {
Type TransactionType `json:"type"`
State TransactionState `json:"state"` // NOTE: non-NIP-47 spec compliant
Invoice string `json:"invoice"`
Description string `json:"description"`
DescriptionHash string `json:"description_hash"`
Preimage string `json:"preimage"`
PaymentHash string `json:"payment_hash"`
Amount int64 `json:"amount"`
FeesPaid int64 `json:"fees_paid"`
SettledAt int64 `json:"settled_at"`
CreatedAt int64 `json:"created_at"`
ExpiresAt int64 `json:"expires_at"`
SettleDeadline *int64 `json:"settle_deadline,omitempty"` // NOTE: non-NIP-47 spec compliant
Metadata *TransactionMetadata `json:"metadata,omitempty"`
}
// TransactionMetadata represents metadata for a transaction
type TransactionMetadata struct {
Comment string `json:"comment,omitempty"` // LUD-12
PayerData *PayerData `json:"payer_data,omitempty"` // LUD-18
RecipientData *RecipientData `json:"recipient_data,omitempty"` // LUD-18
Nostr *NostrData `json:"nostr,omitempty"` // NIP-57
ExtraData map[string]interface{} `json:"-"` // For additional fields
}
// PayerData represents payer data for a transaction
type PayerData struct {
Email string `json:"email,omitempty"`
Name string `json:"name,omitempty"`
Pubkey string `json:"pubkey,omitempty"`
}
// RecipientData represents recipient data for a transaction
type RecipientData struct {
Identifier string `json:"identifier,omitempty"`
}
// NostrData represents Nostr data for a transaction
type NostrData struct {
Pubkey string `json:"pubkey"`
Tags [][]string `json:"tags"`
}
// NotificationType represents a notification type
type NotificationType string
const (
PaymentReceived NotificationType = "payment_received"
PaymentSent NotificationType = "payment_sent"
HoldInvoiceAccepted NotificationType = "hold_invoice_accepted"
)
// Notification represents a notification
type Notification struct {
NotificationType NotificationType `json:"notification_type"`
Notification Transaction `json:"notification"`
}
// PayInvoiceRequest represents a request to pay an invoice
type PayInvoiceRequest struct {
Invoice string `json:"invoice"`
Metadata *TransactionMetadata `json:"metadata,omitempty"`
Amount *int64 `json:"amount,omitempty"` // msats
}
// PayKeysendRequest represents a request to pay a keysend
type PayKeysendRequest struct {
Amount int64 `json:"amount"` // msats
Pubkey string `json:"pubkey"`
Preimage string `json:"preimage,omitempty"`
TlvRecords []TlvRecord `json:"tlv_records,omitempty"`
}
// TlvRecord represents a TLV record
type TlvRecord struct {
Type int64 `json:"type"`
Value string `json:"value"`
}
// MakeInvoiceRequest represents a request to make an invoice
type MakeInvoiceRequest struct {
Amount int64 `json:"amount"` // msats
Description string `json:"description,omitempty"`
DescriptionHash string `json:"description_hash,omitempty"`
Expiry *int64 `json:"expiry,omitempty"` // in seconds
Metadata *TransactionMetadata `json:"metadata,omitempty"`
}
// MakeHoldInvoiceRequest represents a request to make a hold invoice
type MakeHoldInvoiceRequest struct {
MakeInvoiceRequest
PaymentHash string `json:"payment_hash"`
}
// SettleHoldInvoiceRequest represents a request to settle a hold invoice
type SettleHoldInvoiceRequest struct {
Preimage string `json:"preimage"`
}
// SettleHoldInvoiceResponse represents a response to a settle_hold_invoice request
type SettleHoldInvoiceResponse struct{}
// CancelHoldInvoiceRequest represents a request to cancel a hold invoice
type CancelHoldInvoiceRequest struct {
PaymentHash string `json:"payment_hash"`
}
// CancelHoldInvoiceResponse represents a response to a cancel_hold_invoice request
type CancelHoldInvoiceResponse struct{}
// LookupInvoiceRequest represents a request to lookup an invoice
type LookupInvoiceRequest struct {
PaymentHash string `json:"payment_hash,omitempty"`
Invoice string `json:"invoice,omitempty"`
}
// SignMessageRequest represents a request to sign a message
type SignMessageRequest struct {
Message string `json:"message"`
}
// CreateConnectionRequest represents a request to create a connection
type CreateConnectionRequest struct {
Pubkey string `json:"pubkey"`
Name string `json:"name"`
RequestMethods []Method `json:"request_methods"`
NotificationTypes []NotificationType `json:"notification_types,omitempty"`
MaxAmount *int64 `json:"max_amount,omitempty"`
BudgetRenewal *BudgetRenewalPeriod `json:"budget_renewal,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
Isolated *bool `json:"isolated,omitempty"`
Metadata any `json:"metadata,omitempty"`
}
// CreateConnectionResponse represents a response to a create_connection request
type CreateConnectionResponse struct {
WalletPubkey string `json:"wallet_pubkey"`
}
// SignMessageResponse represents a response to a sign_message request
type SignMessageResponse struct {
Message string `json:"message"`
Signature string `json:"signature"`
}
// TimeoutValues represents timeout values for NIP-47 requests
type TimeoutValues struct {
ReplyTimeout *int64 `json:"replyTimeout,omitempty"`
PublishTimeout *int64 `json:"publishTimeout,omitempty"`
}

View File

@@ -16,8 +16,10 @@ import (
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/iptracker"
"orly.dev/pkg/utils/log"
"strings"
"time"
)
var EventBody = &huma.RequestBody{
@@ -104,10 +106,27 @@ func (x *Operations) RegisterEvent(api huma.API) {
var pubkey []byte
if x.I.AuthRequired() {
// Check if the IP is blocked due to too many failed auth attempts
if iptracker.Global.IsBlocked(remote) {
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
err = huma.Error403Forbidden(fmt.Sprintf("Too many failed authentication attempts. Blocked until %s", blockedUntil.Format(time.RFC3339)))
return
}
authed, pubkey, super = x.UserAuth(r, remote)
if !authed {
err = huma.Error401Unauthorized("Not Authorized")
// Record the failed authentication attempt
blocked := iptracker.Global.RecordFailedAttempt(remote)
if blocked {
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
err = huma.Error403Forbidden(fmt.Sprintf("Too many failed authentication attempts. Blocked until %s", blockedUntil.Format(time.RFC3339)))
} else {
err = huma.Error401Unauthorized("Not Authorized")
}
return
} else {
// If authentication is successful, remove any blocks for this IP
iptracker.Global.Authenticate(remote)
}
}
// get the other pubkeys from the header that will be sent forward

View File

@@ -580,7 +580,7 @@ type EventsInput struct {
}
type EventsOutput struct {
Body []event.J
Body []*event.J
}
// RegisterEvents is the implementation of the HTTP API Events method.
@@ -667,11 +667,17 @@ Returns events as a JSON array of event objects.`
}
tmp = append(tmp, ev)
}
// cap the number of events to 512 to stop excessively large
// response.
if len(events) > 512 {
break
}
events = tmp
}
}
output = &EventsOutput{}
for _, ev := range events {
_ = ev
output.Body = append(output.Body, ev.ToEventJ())
}
return
},

View File

@@ -7,6 +7,7 @@ import (
"orly.dev/pkg/interfaces/server"
"orly.dev/pkg/protocol/auth"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/iptracker"
"orly.dev/pkg/utils/log"
)
@@ -71,16 +72,9 @@ func (a *A) HandleAuth(b []byte, srv server.I) (msg []byte) {
env.Event.Pubkey,
)
a.Listener.SetAuthedPubkey(env.Event.Pubkey)
// ev := a.Listener.GetPendingEvent()
// if ev != nil {
// var accepted bool
// if accepted, msg = a.I.AddEvent(
// context.Bg(), srv.Relay(), ev, a.Listener.Request,
// a.Listener.RealRemote(),
// ); accepted {
// log.W.F("saved event %0x", ev.Id)
// }
// }
// If authentication is successful, remove any blocks for this IP
iptracker.Global.Authenticate(a.Listener.RealRemote())
}
}
return

View File

@@ -19,8 +19,10 @@ import (
"orly.dev/pkg/interfaces/server"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/iptracker"
"orly.dev/pkg/utils/log"
"strings"
"time"
)
// HandleEvent processes an incoming event by validating its signature, verifying
@@ -71,7 +73,46 @@ func (a *A) HandleEvent(
log.I.F("extra '%s'", rem)
}
if a.I.AuthRequired() && !a.Listener.IsAuthed() {
log.I.F("requesting auth from client from %s", a.Listener.RealRemote())
remoteIP := a.Listener.RealRemote()
log.I.F("requesting auth from client from %s", remoteIP)
// Check if the IP is blocked due to too many failed auth attempts
if iptracker.Global.IsBlocked(remoteIP) {
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
blockedUntil.Format(time.RFC3339))
// Send a notice to the client explaining why they're blocked
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
err = nil
}
// Close the connection
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
a.Listener.Close()
return
}
// Record the failed authentication attempt
blocked := iptracker.Global.RecordFailedAttempt(remoteIP)
if blocked {
// If this attempt caused the IP to be blocked, close the connection
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
blockedUntil.Format(time.RFC3339))
// Send a notice to the client explaining why they're blocked
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
err = nil
}
// Close the connection
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
a.Listener.Close()
return
}
// Continue with normal auth flow for non-blocked IPs
a.Listener.RequestAuth()
if err = Ok.AuthRequired(a, env.E, "auth required"); chk.E(err) {
return
@@ -163,6 +204,12 @@ func (a *A) HandleEvent(
// check and process delete
if env.E.Kind.K == kind.Deletion.K {
log.I.F("delete event\n%s", env.E.Serialize())
var ownerDelete bool
for _, pk := range a.OwnersPubkeys() {
if bytes.Equal(pk, env.Pubkey) {
ownerDelete = true
}
}
for _, t := range env.Tags.ToSliceOfTags() {
var res []*event.E
if t.Len() >= 2 {
@@ -196,15 +243,17 @@ func (a *A) HandleEvent(
referencedEvent := referencedEvents[0]
// Check if the author of the deletion event matches the
// author of the referenced event
if !bytes.Equal(referencedEvent.Pubkey, env.Pubkey) {
// author of the referenced event. Owners can delete
// anything.
if !bytes.Equal(
referencedEvent.Pubkey, env.Pubkey,
) && !ownerDelete {
if err = Ok.Blocked(
a, env,
"blocked: cannot delete events from other authors",
"blocked: can't delete events from other authors",
); chk.E(err) {
return
}
return
}
// Create eventid.T from the event ID bytes
@@ -278,7 +327,7 @@ func (a *A) HandleEvent(
}
return
}
if !bytes.Equal(pk, env.E.Pubkey) {
if !bytes.Equal(pk, env.E.Pubkey) && !ownerDelete {
if err = Ok.Blocked(
a, env,
"can't delete other users' events (delete by a tag)",
@@ -327,7 +376,7 @@ func (a *A) HandleEvent(
)
continue
}
if !bytes.Equal(target.Pubkey, env.Pubkey) {
if !bytes.Equal(target.Pubkey, env.Pubkey) && !ownerDelete {
if err = Ok.Error(
a, env, "only author can delete event",
); chk.E(err) {

View File

@@ -8,6 +8,7 @@ import (
"orly.dev/pkg/protocol/ws"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/iptracker"
"orly.dev/pkg/utils/log"
"orly.dev/pkg/utils/units"
"strings"
@@ -56,6 +57,18 @@ type A struct {
func (a *A) Serve(w http.ResponseWriter, r *http.Request, s server.I) {
c := a.Config()
remote := helpers.GetRemoteFromReq(r)
// Check if the IP is blocked due to too many failed auth attempts
if iptracker.Global.IsBlocked(remote) {
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
log.I.F("rejecting websocket connection from banned IP %s (blocked until %s)",
remote, blockedUntil.Format(time.RFC3339))
// We can't send a notice to the client here because the websocket connection
// hasn't been established yet, so we just reject the connection
return
}
var whitelisted bool
if len(c.Whitelist) > 0 {
for _, addr := range c.Whitelist {

View File

@@ -452,17 +452,17 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
func (r *Client) Subscribe(
c context.T, ff *filters.T,
opts ...SubscriptionOption,
) (*Subscription, error) {
sub := r.PrepareSubscription(c, ff, opts...)
) (sub *Subscription, err error) {
sub = r.PrepareSubscription(c, ff, opts...)
if r.Connection == nil {
return nil, errorf.E("not connected to %s", r.URL)
}
if err := sub.Fire(); chk.T(err) {
if err = sub.Fire(); chk.T(err) {
return nil, errorf.E(
"couldn't subscribe to %v at %s: %w", ff, r.URL, err,
)
}
return sub, nil
return
}
// PrepareSubscription creates a subscription, but doesn't fire it.

View File

@@ -85,7 +85,7 @@ func TestPublish(t *testing.T) {
defer ws.Close()
// connect a client and send the text note
rl := mustRelayConnect(ws.URL)
err = rl.Publish(context.Background(), textNote)
err = rl.Publish(context.Bg(), textNote)
if err != nil {
t.Errorf("publish should have succeeded")
}
@@ -137,7 +137,7 @@ func TestPublishBlocked(t *testing.T) {
// connect a client and send a text note
rl := mustRelayConnect(ws.URL)
if err = rl.Publish(context.Background(), textNote); !chk.E(err) {
if err = rl.Publish(context.Bg(), textNote); !chk.E(err) {
t.Errorf("should have failed to publish")
}
}
@@ -171,7 +171,7 @@ func TestPublishWriteFailed(t *testing.T) {
rl := mustRelayConnect(ws.URL)
// Force brief period of time so that publish always fails on closed socket.
time.Sleep(1 * time.Millisecond)
err = rl.Publish(context.Background(), textNote)
err = rl.Publish(context.Bg(), textNote)
if err == nil {
t.Errorf("should have failed to publish")
}
@@ -192,7 +192,7 @@ func TestConnectContext(t *testing.T) {
defer ws.Close()
// relay client
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
defer cancel()
r, err := RelayConnect(ctx, ws.URL)
if err != nil {
@@ -213,7 +213,7 @@ func TestConnectContextCanceled(t *testing.T) {
defer ws.Close()
// relay client
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.Cancel(context.Bg())
cancel() // make ctx expired
_, err := RelayConnect(ctx, ws.URL)
if !errors.Is(err, context.Canceled) {
@@ -230,9 +230,9 @@ func TestConnectWithOrigin(t *testing.T) {
defer ws.Close()
// relay client
r := NewRelay(context.Background(), string(normalize.URL(ws.URL)))
r := NewRelay(context.Bg(), string(normalize.URL(ws.URL)))
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
defer cancel()
err := r.Connect(ctx)
if err != nil {
@@ -263,7 +263,7 @@ var anyOriginHandshake = func(
}
func mustRelayConnect(url string) (client *Client) {
rl, err := RelayConnect(context.Background(), url)
rl, err := RelayConnect(context.Bg(), url)
if err != nil {
panic(err.Error())
}

View File

@@ -0,0 +1,173 @@
// Package iptracker provides functionality to track and block IP addresses
// based on failed authentication attempts.
package iptracker
import (
"sync"
"time"
)
const (
// BlockDuration is the duration for which an IP will be blocked after
// exceeding the maximum number of failed attempts.
BlockDuration = 10 * time.Minute
)
// IPTracker tracks failed authentication attempts by IP address and provides
// functionality to block IPs that exceed a threshold.
type IPTracker struct {
mu sync.RWMutex
failedAttempts map[string]int
blockedUntil map[string]time.Time
offenseCount map[string]int // Tracks the number of times an IP has been blocked
blockDurations map[string]time.Duration // Stores the current block duration for each IP
}
// NewIPTracker creates a new IPTracker instance.
func NewIPTracker() *IPTracker {
return &IPTracker{
failedAttempts: make(map[string]int),
blockedUntil: make(map[string]time.Time),
offenseCount: make(map[string]int),
blockDurations: make(map[string]time.Duration),
}
}
// RecordFailedAttempt records a failed authentication attempt for the given IP address.
// If the number of failed attempts exceeds the threshold, the IP is blocked.
// For repeat offenders, the block duration doubles with each offense.
// Returns true if the IP is now blocked, false otherwise.
func (t *IPTracker) RecordFailedAttempt(ip string) bool {
t.mu.Lock()
defer t.mu.Unlock()
// Check if the IP is already blocked
if t.isBlockedNoLock(ip) {
return true
}
// Increment the failed attempts counter
t.failedAttempts[ip]++
// If the number of failed attempts exceeds the threshold, block the IP
if t.failedAttempts[ip] >= 3 { // Threshold of 3 failed attempts
// Increment the offense count
t.offenseCount[ip]++
// Calculate block duration based on offense count
// First offense: 10 minutes, then doubles for each subsequent offense
duration := BlockDuration
if t.offenseCount[ip] > 1 {
// For repeat offenses, double the duration for each previous offense
// 10 min, then 20, then 40, then 80, etc.
for i := 1; i < t.offenseCount[ip]; i++ {
duration *= 2
}
}
// Store the calculated duration
t.blockDurations[ip] = duration
// Set the block time
t.blockedUntil[ip] = time.Now().Add(duration)
return true
}
return false
}
// IsBlocked checks if the given IP address is currently blocked.
func (t *IPTracker) IsBlocked(ip string) bool {
t.mu.RLock()
defer t.mu.RUnlock()
return t.isBlockedNoLock(ip)
}
// isBlockedNoLock is a helper method that checks if an IP is blocked without
// acquiring the lock. It should only be called when the lock is already held.
// Blocks persist until authentication, even after the block duration has passed.
func (t *IPTracker) isBlockedNoLock(ip string) bool {
_, exists := t.blockedUntil[ip]
if !exists {
return false
}
// IP is blocked until authenticated, regardless of time elapsed
return true
}
// HasBlockDurationPassed checks if the block duration for an IP has passed,
// even though the IP remains blocked until authentication.
// This is useful for display purposes.
func (t *IPTracker) HasBlockDurationPassed(ip string) bool {
t.mu.RLock()
defer t.mu.RUnlock()
blockedUntil, exists := t.blockedUntil[ip]
if !exists {
return false
}
return time.Now().After(blockedUntil)
}
// GetBlockedUntil returns the time until which the given IP address is blocked.
// If the IP is not blocked, it returns the zero time.
// Note: With the new blocking behavior, an IP remains blocked even after this time
// until it successfully authenticates.
func (t *IPTracker) GetBlockedUntil(ip string) time.Time {
t.mu.RLock()
defer t.mu.RUnlock()
blockedUntil, exists := t.blockedUntil[ip]
if !exists {
return time.Time{}
}
return blockedUntil
}
// GetBlockDuration returns the current block duration for the given IP address.
// This is useful for displaying how long the IP would have been blocked before
// requiring authentication.
func (t *IPTracker) GetBlockDuration(ip string) time.Duration {
t.mu.RLock()
defer t.mu.RUnlock()
duration, exists := t.blockDurations[ip]
if !exists {
return 0
}
return duration
}
// Authenticate records a successful authentication for an IP address.
// If the IP was blocked, it removes the block but preserves the offense count.
// This allows the IP to access the system again, but if it offends in the future,
// the penalty will still be doubled based on past offenses.
func (t *IPTracker) Authenticate(ip string) {
t.mu.Lock()
defer t.mu.Unlock()
// Remove the block but keep the offense count
delete(t.failedAttempts, ip)
delete(t.blockedUntil, ip)
// Note: We intentionally don't delete from offenseCount or blockDurations
// so that repeat offenses can be tracked
}
// Reset completely resets all tracking for the given IP address.
// This is different from Authenticate as it also resets the offense count.
func (t *IPTracker) Reset(ip string) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.failedAttempts, ip)
delete(t.blockedUntil, ip)
delete(t.offenseCount, ip)
delete(t.blockDurations, ip)
}
// Global instance of IPTracker for use across the application
var Global = NewIPTracker()

View File

@@ -1 +1 @@
v0.4.0
v0.4.8