Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
96eab2270d
|
|||
|
c0bd7d8da3
|
|||
|
1ffb7afb01
|
|||
|
ffa9d85ba5
|
|||
|
1223b1b20e
|
|||
|
deb56664e2
|
|||
|
1641d18993
|
|||
|
eab5d236db
|
|||
|
f3e7188816
|
|||
|
39957c2ebf
|
|||
|
4528d44fc7
|
|||
|
7b19db5806
|
|||
|
14d4417aec
|
|||
|
bdda37732c
|
162
cmd/nwcclient/README.md
Normal file
162
cmd/nwcclient/README.md
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
# NWC Client CLI Tool
|
||||||
|
|
||||||
|
A command-line interface tool for making calls to Nostr Wallet Connect (NWC) services.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This CLI tool allows you to interact with NWC wallet services using the methods defined in the NIP-47 specification. It provides a simple interface for executing wallet operations and displays the JSON response from the wallet service.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> <method> [parameters...]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection URL
|
||||||
|
|
||||||
|
The connection URL should be in the Nostr Wallet Connect format:
|
||||||
|
|
||||||
|
```
|
||||||
|
nostr+walletconnect://<wallet_pubkey>?relay=<relay_url>&secret=<secret>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Methods
|
||||||
|
|
||||||
|
The following methods are supported by this CLI tool:
|
||||||
|
|
||||||
|
- `get_info` - Get wallet information
|
||||||
|
- `get_balance` - Get wallet balance
|
||||||
|
- `get_budget` - Get wallet budget
|
||||||
|
- `make_invoice` - Create an invoice
|
||||||
|
- `pay_invoice` - Pay an invoice
|
||||||
|
- `pay_keysend` - Send a keysend payment
|
||||||
|
- `lookup_invoice` - Look up an invoice
|
||||||
|
- `list_transactions` - List transactions
|
||||||
|
- `sign_message` - Sign a message
|
||||||
|
|
||||||
|
### Unsupported Methods
|
||||||
|
|
||||||
|
The following methods are defined in the NIP-47 specification but are not directly supported by this CLI tool due to limitations in the underlying nwc package:
|
||||||
|
|
||||||
|
- `create_connection` - Create a connection
|
||||||
|
- `make_hold_invoice` - Create a hold invoice
|
||||||
|
- `settle_hold_invoice` - Settle a hold invoice
|
||||||
|
- `cancel_hold_invoice` - Cancel a hold invoice
|
||||||
|
- `multi_pay_invoice` - Pay multiple invoices
|
||||||
|
- `multi_pay_keysend` - Send multiple keysend payments
|
||||||
|
|
||||||
|
## Method Parameters
|
||||||
|
|
||||||
|
### Methods with No Parameters
|
||||||
|
|
||||||
|
- `get_info`
|
||||||
|
- `get_balance`
|
||||||
|
- `get_budget`
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> get_info
|
||||||
|
```
|
||||||
|
|
||||||
|
### Methods with Parameters
|
||||||
|
|
||||||
|
#### make_invoice
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> make_invoice <amount> <description> [description_hash] [expiry]
|
||||||
|
```
|
||||||
|
|
||||||
|
- `amount` - Amount in millisatoshis (msats)
|
||||||
|
- `description` - Invoice description
|
||||||
|
- `description_hash` (optional) - Hash of the description
|
||||||
|
- `expiry` (optional) - Expiry time in seconds
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> make_invoice 1000000 "Test invoice" "" 3600
|
||||||
|
```
|
||||||
|
|
||||||
|
#### pay_invoice
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> pay_invoice <invoice> [amount]
|
||||||
|
```
|
||||||
|
|
||||||
|
- `invoice` - BOLT11 invoice
|
||||||
|
- `amount` (optional) - Amount in millisatoshis (msats)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> pay_invoice lnbc1...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### pay_keysend
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> pay_keysend <amount> <pubkey> [preimage]
|
||||||
|
```
|
||||||
|
|
||||||
|
- `amount` - Amount in millisatoshis (msats)
|
||||||
|
- `pubkey` - Recipient's public key
|
||||||
|
- `preimage` (optional) - Payment preimage
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> pay_keysend 1000000 03...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### lookup_invoice
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> lookup_invoice <payment_hash_or_invoice>
|
||||||
|
```
|
||||||
|
|
||||||
|
- `payment_hash_or_invoice` - Payment hash or BOLT11 invoice
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> lookup_invoice 3d...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### list_transactions
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> list_transactions [from <timestamp>] [until <timestamp>] [limit <count>] [offset <count>] [unpaid <true|false>] [type <incoming|outgoing>]
|
||||||
|
```
|
||||||
|
|
||||||
|
Parameters are specified as name-value pairs:
|
||||||
|
|
||||||
|
- `from` - Start timestamp
|
||||||
|
- `until` - End timestamp
|
||||||
|
- `limit` - Maximum number of transactions to return
|
||||||
|
- `offset` - Number of transactions to skip
|
||||||
|
- `unpaid` - Whether to include unpaid transactions
|
||||||
|
- `type` - Transaction type (incoming or outgoing)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> list_transactions limit 10 type incoming
|
||||||
|
```
|
||||||
|
|
||||||
|
#### sign_message
|
||||||
|
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> sign_message <message>
|
||||||
|
```
|
||||||
|
|
||||||
|
- `message` - Message to sign
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
nwcclient <connection URL> sign_message "Hello, world!"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
The tool prints the JSON response from the wallet service to stdout. If an error occurs, an error message is printed to stderr.
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- The tool only supports methods that have direct client methods in the nwc package.
|
||||||
|
- Complex parameters like metadata are not supported.
|
||||||
|
- The tool does not support interactive authentication or authorization.
|
||||||
285
cmd/nwcclient/main.go
Normal file
285
cmd/nwcclient/main.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"orly.dev/pkg/protocol/nwc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func printUsage() {
|
||||||
|
fmt.Println("Usage: nwcclient \"<connection URL>\" <method> [parameters...]")
|
||||||
|
fmt.Println("\nSupported methods:")
|
||||||
|
fmt.Println(" get_info - Get wallet information")
|
||||||
|
fmt.Println(" get_balance - Get wallet balance")
|
||||||
|
fmt.Println(" get_budget - Get wallet budget")
|
||||||
|
fmt.Println(" make_invoice - Create an invoice (amount, description, [description_hash], [expiry])")
|
||||||
|
fmt.Println(" pay_invoice - Pay an invoice (invoice, [amount])")
|
||||||
|
fmt.Println(" pay_keysend - Send a keysend payment (amount, pubkey, [preimage])")
|
||||||
|
fmt.Println(" lookup_invoice - Look up an invoice (payment_hash or invoice)")
|
||||||
|
fmt.Println(" list_transactions - List transactions ([from], [until], [limit], [offset], [unpaid], [type])")
|
||||||
|
fmt.Println(" sign_message - Sign a message (message)")
|
||||||
|
fmt.Println("\nUnsupported methods (due to limitations in the nwc package):")
|
||||||
|
fmt.Println(" create_connection - Create a connection")
|
||||||
|
fmt.Println(" make_hold_invoice - Create a hold invoice")
|
||||||
|
fmt.Println(" settle_hold_invoice - Settle a hold invoice")
|
||||||
|
fmt.Println(" cancel_hold_invoice - Cancel a hold invoice")
|
||||||
|
fmt.Println(" multi_pay_invoice - Pay multiple invoices")
|
||||||
|
fmt.Println(" multi_pay_keysend - Send multiple keysend payments")
|
||||||
|
fmt.Println("\nParameters format:")
|
||||||
|
fmt.Println(" - Positional parameters are used for required fields")
|
||||||
|
fmt.Println(" - For list_transactions, named parameters are used: 'from', 'until', 'limit', 'offset', 'unpaid', 'type'")
|
||||||
|
fmt.Println(" Example: nwcclient <url> list_transactions limit 10 type incoming")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Check if we have enough arguments
|
||||||
|
if len(os.Args) < 3 {
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse connection URL and method
|
||||||
|
connectionURL := os.Args[1]
|
||||||
|
methodStr := os.Args[2]
|
||||||
|
method := nwc.Method(methodStr)
|
||||||
|
|
||||||
|
// Parse the wallet connect URL
|
||||||
|
opts, err := nwc.ParseWalletConnectURL(connectionURL)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing connection URL: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new NWC client
|
||||||
|
client, err := nwc.NewNWCClient(opts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error creating NWC client: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
// Execute the requested method
|
||||||
|
var result interface{}
|
||||||
|
|
||||||
|
switch method {
|
||||||
|
case nwc.GetInfo:
|
||||||
|
result, err = client.GetInfo()
|
||||||
|
|
||||||
|
case nwc.GetBalance:
|
||||||
|
result, err = client.GetBalance()
|
||||||
|
|
||||||
|
case nwc.GetBudget:
|
||||||
|
result, err = client.GetBudget()
|
||||||
|
|
||||||
|
case nwc.MakeInvoice:
|
||||||
|
if len(os.Args) < 5 {
|
||||||
|
fmt.Fprintf(
|
||||||
|
os.Stderr,
|
||||||
|
"Error: make_invoice requires at least amount and description\n",
|
||||||
|
)
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
description := os.Args[4]
|
||||||
|
|
||||||
|
req := &nwc.MakeInvoiceRequest{
|
||||||
|
Amount: amount,
|
||||||
|
Description: description,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional parameters
|
||||||
|
if len(os.Args) > 5 {
|
||||||
|
req.DescriptionHash = os.Args[5]
|
||||||
|
}
|
||||||
|
if len(os.Args) > 6 {
|
||||||
|
expiry, err := strconv.ParseInt(os.Args[6], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing expiry: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Expiry = &expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.MakeInvoice(req)
|
||||||
|
|
||||||
|
case nwc.PayInvoice:
|
||||||
|
if len(os.Args) < 4 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: pay_invoice requires an invoice\n")
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &nwc.PayInvoiceRequest{
|
||||||
|
Invoice: os.Args[3],
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional amount parameter
|
||||||
|
if len(os.Args) > 4 {
|
||||||
|
amount, err := strconv.ParseInt(os.Args[4], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Amount = &amount
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.PayInvoice(req)
|
||||||
|
|
||||||
|
case nwc.PayKeysend:
|
||||||
|
if len(os.Args) < 5 {
|
||||||
|
fmt.Fprintf(
|
||||||
|
os.Stderr, "Error: pay_keysend requires amount and pubkey\n",
|
||||||
|
)
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
amount, err := strconv.ParseInt(os.Args[3], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing amount: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &nwc.PayKeysendRequest{
|
||||||
|
Amount: amount,
|
||||||
|
Pubkey: os.Args[4],
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional preimage
|
||||||
|
if len(os.Args) > 5 {
|
||||||
|
req.Preimage = os.Args[5]
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.PayKeysend(req)
|
||||||
|
|
||||||
|
case nwc.LookupInvoice:
|
||||||
|
if len(os.Args) < 4 {
|
||||||
|
fmt.Fprintf(
|
||||||
|
os.Stderr,
|
||||||
|
"Error: lookup_invoice requires a payment_hash or invoice\n",
|
||||||
|
)
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
param := os.Args[3]
|
||||||
|
req := &nwc.LookupInvoiceRequest{}
|
||||||
|
|
||||||
|
// Determine if the parameter is a payment hash or an invoice
|
||||||
|
if strings.HasPrefix(param, "ln") {
|
||||||
|
req.Invoice = param
|
||||||
|
} else {
|
||||||
|
req.PaymentHash = param
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.LookupInvoice(req)
|
||||||
|
|
||||||
|
case nwc.ListTransactions:
|
||||||
|
req := &nwc.ListTransactionsRequest{}
|
||||||
|
|
||||||
|
// Parse optional parameters
|
||||||
|
paramIndex := 3
|
||||||
|
for paramIndex < len(os.Args) {
|
||||||
|
if paramIndex+1 >= len(os.Args) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
paramName := os.Args[paramIndex]
|
||||||
|
paramValue := os.Args[paramIndex+1]
|
||||||
|
|
||||||
|
switch paramName {
|
||||||
|
case "from":
|
||||||
|
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing from: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.From = &val
|
||||||
|
case "until":
|
||||||
|
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing until: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Until = &val
|
||||||
|
case "limit":
|
||||||
|
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing limit: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Limit = &val
|
||||||
|
case "offset":
|
||||||
|
val, err := strconv.ParseInt(paramValue, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing offset: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Offset = &val
|
||||||
|
case "unpaid":
|
||||||
|
val := paramValue == "true"
|
||||||
|
req.Unpaid = &val
|
||||||
|
case "type":
|
||||||
|
req.Type = ¶mValue
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
paramIndex += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.ListTransactions(req)
|
||||||
|
|
||||||
|
case nwc.SignMessage:
|
||||||
|
if len(os.Args) < 4 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: sign_message requires a message\n")
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &nwc.SignMessageRequest{
|
||||||
|
Message: os.Args[3],
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.SignMessage(req)
|
||||||
|
|
||||||
|
case nwc.CreateConnection, nwc.MakeHoldInvoice, nwc.SettleHoldInvoice, nwc.CancelHoldInvoice, nwc.MultiPayInvoice, nwc.MultiPayKeysend:
|
||||||
|
fmt.Fprintf(
|
||||||
|
os.Stderr,
|
||||||
|
"Error: Method %s is not directly supported by the CLI tool.\n",
|
||||||
|
methodStr,
|
||||||
|
)
|
||||||
|
fmt.Fprintf(
|
||||||
|
os.Stderr,
|
||||||
|
"This is because these methods don't have exported client methods in the nwc package.\n",
|
||||||
|
)
|
||||||
|
fmt.Fprintf(
|
||||||
|
os.Stderr,
|
||||||
|
"Only the following methods are currently supported: get_info, get_balance, get_budget, make_invoice, pay_invoice, pay_keysend, lookup_invoice, list_transactions, sign_message\n",
|
||||||
|
)
|
||||||
|
os.Exit(1)
|
||||||
|
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: Unsupported method: %s\n", methodStr)
|
||||||
|
printUsage()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error executing method: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the result as JSON
|
||||||
|
jsonData, err := json.MarshalIndent(result, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marshaling result to JSON: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(string(jsonData))
|
||||||
|
}
|
||||||
@@ -37,7 +37,8 @@ type C struct {
|
|||||||
Pprof string `env:"ORLY_PPROF" usage:"enable pprof on 127.0.0.1:6060" enum:"cpu,memory,allocation"`
|
Pprof string `env:"ORLY_PPROF" usage:"enable pprof on 127.0.0.1:6060" enum:"cpu,memory,allocation"`
|
||||||
AuthRequired bool `env:"ORLY_AUTH_REQUIRED" default:"false" usage:"require authentication for all requests"`
|
AuthRequired bool `env:"ORLY_AUTH_REQUIRED" default:"false" usage:"require authentication for all requests"`
|
||||||
PublicReadable bool `env:"ORLY_PUBLIC_READABLE" default:"true" usage:"allow public read access to regardless of whether the client is authed"`
|
PublicReadable bool `env:"ORLY_PUBLIC_READABLE" default:"true" usage:"allow public read access to regardless of whether the client is authed"`
|
||||||
SpiderSeeds []string `env:"ORLY_SPIDER_SEEDS" usage:"seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)" default:"wss://relay.nostr.band/,wss://relay.damus.io/,wss://nostr.wine/,wss://nostr.land/,wss://theforest.nostr1.com/"`
|
SpiderSeeds []string `env:"ORLY_SPIDER_SEEDS" usage:"seeds to use for the spider (relays that are looked up initially to find owner relay lists) (comma separated)" default:"wss://profiles.nostr1.com/,wss://relay.nostr.band/,wss://relay.damus.io/,wss://nostr.wine/,wss://nostr.land/,wss://theforest.nostr1.com/"`
|
||||||
|
SpiderType string `env:"ORLY_SPIDER_TYPE" usage:"whether to spider, and what degree of spidering: none, directory, follows (follows means to the second degree of the follow graph)" default:"directory"`
|
||||||
Owners []string `env:"ORLY_OWNERS" usage:"list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)"`
|
Owners []string `env:"ORLY_OWNERS" usage:"list of users whose follow lists designate whitelisted users who can publish events, and who can read if public readable is false (comma separated)"`
|
||||||
Private bool `env:"ORLY_PRIVATE" usage:"do not spider for user metadata because the relay is private and this would leak relay memberships" default:"false"`
|
Private bool `env:"ORLY_PRIVATE" usage:"do not spider for user metadata because the relay is private and this would leak relay memberships" default:"false"`
|
||||||
Whitelist []string `env:"ORLY_WHITELIST" usage:"only allow connections from this list of IP addresses"`
|
Whitelist []string `env:"ORLY_WHITELIST" usage:"only allow connections from this list of IP addresses"`
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ func (s *Server) AddEvent(
|
|||||||
// (they're unpacked from a string containing both, appended at the
|
// (they're unpacked from a string containing both, appended at the
|
||||||
// same time), so if the pubkeys from the http event endpoint sent
|
// same time), so if the pubkeys from the http event endpoint sent
|
||||||
// us here matches the index of this address, we can skip it.
|
// us here matches the index of this address, we can skip it.
|
||||||
log.I.S(pubkeys)
|
|
||||||
for _, pk := range pubkeys {
|
for _, pk := range pubkeys {
|
||||||
if bytes.Equal(s.Peers.Pubkeys[i], pk) {
|
if bytes.Equal(s.Peers.Pubkeys[i], pk) {
|
||||||
log.I.F(
|
log.I.F(
|
||||||
@@ -135,7 +134,6 @@ func (s *Server) AddEvent(
|
|||||||
continue replica
|
continue replica
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.I.F("sending to replica %s", a)
|
|
||||||
var ur *url.URL
|
var ur *url.URL
|
||||||
if ur, err = url.Parse(a + "/api/event"); chk.E(err) {
|
if ur, err = url.Parse(a + "/api/event"); chk.E(err) {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ func (p *Peers) Init(
|
|||||||
addresses []string, sec string,
|
addresses []string, sec string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
|
if len(address) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
split := strings.Split(address, "@")
|
split := strings.Split(address, "@")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
log.E.F("invalid peer address: %s", address)
|
log.E.F("invalid peer address: %s", address)
|
||||||
@@ -46,6 +49,9 @@ func (p *Peers) Init(
|
|||||||
p.Pubkeys = append(p.Pubkeys, pk)
|
p.Pubkeys = append(p.Pubkeys, pk)
|
||||||
log.I.F("peer %s added; pubkey: %0x", split[1], pk)
|
log.I.F("peer %s added; pubkey: %0x", split[1], pk)
|
||||||
}
|
}
|
||||||
|
if sec == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
p.I = &p256k.Signer{}
|
p.I = &p256k.Signer{}
|
||||||
var s []byte
|
var s []byte
|
||||||
if s, err = keys.DecodeNsecOrHex(sec); chk.E(err) {
|
if s, err = keys.DecodeNsecOrHex(sec); chk.E(err) {
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func (s *Server) Publish(c context.T, evt *event.E) (err error) {
|
|||||||
}
|
}
|
||||||
if isFollowed {
|
if isFollowed {
|
||||||
if _, _, err = sto.SaveEvent(
|
if _, _, err = sto.SaveEvent(
|
||||||
c, evt, false,
|
c, evt, false, nil,
|
||||||
); err != nil && !errors.Is(
|
); err != nil && !errors.Is(
|
||||||
err, store.ErrDupEvent,
|
err, store.ErrDupEvent,
|
||||||
) {
|
) {
|
||||||
@@ -124,7 +124,7 @@ func (s *Server) Publish(c context.T, evt *event.E) (err error) {
|
|||||||
for _, pk := range owners {
|
for _, pk := range owners {
|
||||||
if bytes.Equal(evt.Pubkey, pk) {
|
if bytes.Equal(evt.Pubkey, pk) {
|
||||||
if _, _, err = sto.SaveEvent(
|
if _, _, err = sto.SaveEvent(
|
||||||
c, evt, false,
|
c, evt, false, nil,
|
||||||
); err != nil && !errors.Is(
|
); err != nil && !errors.Is(
|
||||||
err, store.ErrDupEvent,
|
err, store.ErrDupEvent,
|
||||||
) {
|
) {
|
||||||
@@ -236,7 +236,9 @@ func (s *Server) Publish(c context.T, evt *event.E) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, _, err = sto.SaveEvent(c, evt, false); err != nil && !errors.Is(
|
if _, _, err = sto.SaveEvent(
|
||||||
|
c, evt, false, append(s.Peers.Pubkeys, s.ownersPubkeys...),
|
||||||
|
); err != nil && !errors.Is(
|
||||||
err, store.ErrDupEvent,
|
err, store.ErrDupEvent,
|
||||||
) {
|
) {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package relay
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"orly.dev/pkg/crypto/ec/schnorr"
|
"orly.dev/pkg/crypto/ec/schnorr"
|
||||||
"orly.dev/pkg/database/indexes/types"
|
"orly.dev/pkg/database/indexes/types"
|
||||||
"orly.dev/pkg/encoders/event"
|
"orly.dev/pkg/encoders/event"
|
||||||
@@ -15,7 +14,6 @@ import (
|
|||||||
"orly.dev/pkg/utils/context"
|
"orly.dev/pkg/utils/context"
|
||||||
"orly.dev/pkg/utils/errorf"
|
"orly.dev/pkg/utils/errorf"
|
||||||
"orly.dev/pkg/utils/log"
|
"orly.dev/pkg/utils/log"
|
||||||
"orly.dev/pkg/utils/lol"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -158,16 +156,13 @@ func (s *Server) SpiderFetch(
|
|||||||
err = nil
|
err = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process each event immediately
|
// Process each event immediately
|
||||||
for i, ev := range evss {
|
for i, ev := range evss {
|
||||||
// log.I.S(ev)
|
// log.I.S(ev)
|
||||||
// Create a key based on pubkey and kind for deduplication
|
// Create a key based on pubkey and kind for deduplication
|
||||||
pkKindKey := string(ev.Pubkey) + string(ev.Kind.Marshal(nil))
|
pkKindKey := string(ev.Pubkey) + string(ev.Kind.Marshal(nil))
|
||||||
|
|
||||||
// Check if we already have an event with this pubkey and kind
|
// Check if we already have an event with this pubkey and kind
|
||||||
existing, exists := pkKindMap[pkKindKey]
|
existing, exists := pkKindMap[pkKindKey]
|
||||||
|
|
||||||
// If it doesn't exist or the new event is newer, store it and save to database
|
// If it doesn't exist or the new event is newer, store it and save to database
|
||||||
if !exists || ev.CreatedAtInt64() > existing.Timestamp {
|
if !exists || ev.CreatedAtInt64() > existing.Timestamp {
|
||||||
var ser *types.Uint40
|
var ser *types.Uint40
|
||||||
@@ -180,28 +175,14 @@ func (s *Server) SpiderFetch(
|
|||||||
if valid, err = ev.Verify(); chk.E(err) || !valid {
|
if valid, err = ev.Verify(); chk.E(err) || !valid {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.I.F("event %0x is valid", ev.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = s.Storage().SaveEvent(
|
if _, _, err = s.Storage().SaveEvent(
|
||||||
s.Ctx, ev, true, // already verified
|
s.Ctx, ev, true, nil,
|
||||||
); chk.E(err) {
|
); chk.E(err) {
|
||||||
err = nil
|
err = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if lol.Level.Load() == lol.Trace {
|
|
||||||
log.T.C(
|
|
||||||
func() string {
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"saved event:\n%s", ev.Marshal(nil),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
log.I.F("saved event: %0x", ev.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the essential information
|
// Store the essential information
|
||||||
pkKindMap[pkKindKey] = &IdPkTs{
|
pkKindMap[pkKindKey] = &IdPkTs{
|
||||||
Id: ev.ID,
|
Id: ev.ID,
|
||||||
@@ -209,7 +190,6 @@ func (s *Server) SpiderFetch(
|
|||||||
Kind: ev.Kind.ToU16(),
|
Kind: ev.Kind.ToU16(),
|
||||||
Timestamp: ev.CreatedAtInt64(),
|
Timestamp: ev.CreatedAtInt64(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract p tags if not in noExtract mode
|
// Extract p tags if not in noExtract mode
|
||||||
if !noExtract {
|
if !noExtract {
|
||||||
t := ev.Tags.GetAll(tag.New("p"))
|
t := ev.Tags.GetAll(tag.New("p"))
|
||||||
@@ -227,7 +207,6 @@ func (s *Server) SpiderFetch(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nil the event in the slice to free memory
|
// Nil the event in the slice to free memory
|
||||||
evss[i] = nil
|
evss[i] = nil
|
||||||
}
|
}
|
||||||
@@ -236,17 +215,14 @@ func (s *Server) SpiderFetch(
|
|||||||
}
|
}
|
||||||
chk.E(s.Storage().Sync())
|
chk.E(s.Storage().Sync())
|
||||||
debug.FreeOSMemory()
|
debug.FreeOSMemory()
|
||||||
|
|
||||||
// If we're in noExtract mode, just return
|
// If we're in noExtract mode, just return
|
||||||
if noExtract {
|
if noExtract {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert the collected pubkeys to the return format
|
// Convert the collected pubkeys to the return format
|
||||||
for pk := range pkMap {
|
for pk := range pkMap {
|
||||||
pks = append(pks, []byte(pk))
|
pks = append(pks, []byte(pk))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.I.F("found %d pks", len(pks))
|
log.I.F("found %d pks", len(pks))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -97,16 +97,18 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
|
|||||||
s.SetFollowedFollows(followedFollows)
|
s.SetFollowedFollows(followedFollows)
|
||||||
s.SetOwnersMuted(ownersMuted)
|
s.SetOwnersMuted(ownersMuted)
|
||||||
// lastly, update all followed users new events in the background
|
// lastly, update all followed users new events in the background
|
||||||
if !dontFetch {
|
if !dontFetch && s.C.SpiderType != "none" {
|
||||||
go func() {
|
go func() {
|
||||||
|
var k *kinds.T
|
||||||
|
if s.C.SpiderType == "directory" {
|
||||||
|
k = kinds.New(
|
||||||
|
kind.ProfileMetadata, kind.RelayListMetadata,
|
||||||
|
kind.DMRelaysList,
|
||||||
|
)
|
||||||
|
}
|
||||||
everyone := append(ownersFollowed, followedFollows...)
|
everyone := append(ownersFollowed, followedFollows...)
|
||||||
s.SpiderFetch(
|
_, _ = s.SpiderFetch(
|
||||||
// kinds.New(
|
k, false, true, everyone...,
|
||||||
// kind.ProfileMetadata, kind.RelayListMetadata,
|
|
||||||
// kind.DMRelaysList,
|
|
||||||
// ),
|
|
||||||
nil,
|
|
||||||
false, true, everyone...,
|
|
||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,42 +5,16 @@ import (
|
|||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"orly.dev/pkg/crypto/p256k"
|
"lukechampine.com/frand"
|
||||||
"orly.dev/pkg/encoders/hex"
|
|
||||||
"orly.dev/pkg/utils/chk"
|
"orly.dev/pkg/utils/chk"
|
||||||
"orly.dev/pkg/utils/errorf"
|
"orly.dev/pkg/utils/errorf"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"lukechampine.com/frand"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ComputeSharedSecret returns a shared secret key used to encrypt messages. The private and public keys should be hex
|
|
||||||
// encoded. Uses the Diffie-Hellman key exchange (ECDH) (RFC 4753).
|
|
||||||
func ComputeSharedSecret(pkh, skh string) (sharedSecret []byte, err error) {
|
|
||||||
var skb, pkb []byte
|
|
||||||
if skb, err = hex.Dec(skh); chk.E(err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if pkb, err = hex.Dec(pkh); chk.E(err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
signer := new(p256k.Signer)
|
|
||||||
if err = signer.InitSec(skb); chk.E(err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if sharedSecret, err = signer.ECDH(pkb); chk.E(err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// EncryptNip4 encrypts message with key using aes-256-cbc. key should be the shared secret generated by
|
// EncryptNip4 encrypts message with key using aes-256-cbc. key should be the shared secret generated by
|
||||||
// ComputeSharedSecret.
|
// ComputeSharedSecret.
|
||||||
//
|
//
|
||||||
// Returns: base64(encrypted_bytes) + "?iv=" + base64(initialization_vector).
|
// Returns: base64(encrypted_bytes) + "?iv=" + base64(initialization_vector).
|
||||||
//
|
func EncryptNip4(msg, key []byte) (ct []byte, err error) {
|
||||||
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
|
|
||||||
func EncryptNip4(msg string, key []byte) (ct []byte, err error) {
|
|
||||||
// block size is 16 bytes
|
// block size is 16 bytes
|
||||||
iv := make([]byte, 16)
|
iv := make([]byte, 16)
|
||||||
if _, err = frand.Read(iv); chk.E(err) {
|
if _, err = frand.Read(iv); chk.E(err) {
|
||||||
@@ -71,22 +45,20 @@ func EncryptNip4(msg string, key []byte) (ct []byte, err error) {
|
|||||||
|
|
||||||
// DecryptNip4 decrypts a content string using the shared secret key. The inverse operation to message ->
|
// DecryptNip4 decrypts a content string using the shared secret key. The inverse operation to message ->
|
||||||
// EncryptNip4(message, key).
|
// EncryptNip4(message, key).
|
||||||
//
|
func DecryptNip4(content, key []byte) (msg []byte, err error) {
|
||||||
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
|
parts := bytes.Split(content, []byte("?iv="))
|
||||||
func DecryptNip4(content string, key []byte) (msg []byte, err error) {
|
|
||||||
parts := strings.Split(content, "?iv=")
|
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return nil, errorf.E(
|
return nil, errorf.E(
|
||||||
"error parsing encrypted message: no initialization vector",
|
"error parsing encrypted message: no initialization vector",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
var ciphertext []byte
|
ciphertext := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0])))
|
||||||
if ciphertext, err = base64.StdEncoding.DecodeString(parts[0]); chk.E(err) {
|
if _, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) {
|
||||||
err = errorf.E("error decoding ciphertext from base64: %w", err)
|
err = errorf.E("error decoding ciphertext from base64: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var iv []byte
|
iv := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1])))
|
||||||
if iv, err = base64.StdEncoding.DecodeString(parts[1]); chk.E(err) {
|
if _, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) {
|
||||||
err = errorf.E("error decoding iv from base64: %w", err)
|
err = errorf.E("error decoding iv from base64: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
"golang.org/x/crypto/hkdf"
|
"golang.org/x/crypto/hkdf"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
"orly.dev/pkg/crypto/p256k"
|
||||||
"orly.dev/pkg/crypto/sha256"
|
"orly.dev/pkg/crypto/sha256"
|
||||||
|
"orly.dev/pkg/interfaces/signer"
|
||||||
"orly.dev/pkg/utils/chk"
|
"orly.dev/pkg/utils/chk"
|
||||||
"orly.dev/pkg/utils/errorf"
|
"orly.dev/pkg/utils/errorf"
|
||||||
)
|
)
|
||||||
@@ -43,11 +45,9 @@ func WithCustomNonce(salt []byte) func(opts *Opts) {
|
|||||||
// Encrypt data using a provided symmetric conversation key using NIP-44
|
// Encrypt data using a provided symmetric conversation key using NIP-44
|
||||||
// encryption (chacha20 cipher stream and sha256 HMAC).
|
// encryption (chacha20 cipher stream and sha256 HMAC).
|
||||||
func Encrypt(
|
func Encrypt(
|
||||||
plaintext string, conversationKey []byte,
|
plaintext, conversationKey []byte, applyOptions ...func(opts *Opts),
|
||||||
applyOptions ...func(opts *Opts),
|
|
||||||
) (
|
) (
|
||||||
cipherString string,
|
cipherString []byte, err error,
|
||||||
err error,
|
|
||||||
) {
|
) {
|
||||||
|
|
||||||
var o Opts
|
var o Opts
|
||||||
@@ -70,7 +70,7 @@ func Encrypt(
|
|||||||
); chk.E(err) {
|
); chk.E(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
plain := []byte(plaintext)
|
plain := plaintext
|
||||||
size := len(plain)
|
size := len(plain)
|
||||||
if size < MinPlaintextSize || size > MaxPlaintextSize {
|
if size < MinPlaintextSize || size > MaxPlaintextSize {
|
||||||
err = errorf.E("plaintext should be between 1b and 64kB")
|
err = errorf.E("plaintext should be between 1b and 64kB")
|
||||||
@@ -93,14 +93,15 @@ func Encrypt(
|
|||||||
ct = append(ct, o.nonce...)
|
ct = append(ct, o.nonce...)
|
||||||
ct = append(ct, cipher...)
|
ct = append(ct, cipher...)
|
||||||
ct = append(ct, mac...)
|
ct = append(ct, mac...)
|
||||||
cipherString = base64.StdEncoding.EncodeToString(ct)
|
cipherString = make([]byte, base64.StdEncoding.EncodedLen(len(ct)))
|
||||||
|
base64.StdEncoding.Encode(cipherString, ct)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt data that has been encoded using a provided symmetric conversation
|
// Decrypt data that has been encoded using a provided symmetric conversation
|
||||||
// key using NIP-44 encryption (chacha20 cipher stream and sha256 HMAC).
|
// key using NIP-44 encryption (chacha20 cipher stream and sha256 HMAC).
|
||||||
func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (
|
func Decrypt(b64ciphertextWrapped, conversationKey []byte) (
|
||||||
plaintext string,
|
plaintext []byte,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
cLen := len(b64ciphertextWrapped)
|
cLen := len(b64ciphertextWrapped)
|
||||||
@@ -108,12 +109,12 @@ func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (
|
|||||||
err = errorf.E("invalid payload length: %d", cLen)
|
err = errorf.E("invalid payload length: %d", cLen)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if b64ciphertextWrapped[:1] == "#" {
|
if len(b64ciphertextWrapped) > 0 && b64ciphertextWrapped[0] == '#' {
|
||||||
err = errorf.E("unknown version")
|
err = errorf.E("unknown version")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var decoded []byte
|
var decoded []byte
|
||||||
if decoded, err = base64.StdEncoding.DecodeString(b64ciphertextWrapped); chk.E(err) {
|
if decoded, err = base64.StdEncoding.DecodeString(string(b64ciphertextWrapped)); chk.E(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if decoded[0] != version {
|
if decoded[0] != version {
|
||||||
@@ -153,7 +154,7 @@ func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (
|
|||||||
err = errorf.E("invalid padding")
|
err = errorf.E("invalid padding")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
plaintext = string(unpadded)
|
plaintext = unpadded
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,8 +168,16 @@ func GenerateConversationKey(pkh, skh string) (ck []byte, err error) {
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var sign signer.I
|
||||||
|
if sign, err = p256k.NewSecFromHex(skh); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pk []byte
|
||||||
|
if pk, err = p256k.HexToBin(pkh); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
var shared []byte
|
var shared []byte
|
||||||
if shared, err = ComputeSharedSecret(pkh, skh); chk.E(err) {
|
if shared, err = sign.ECDH(pk); chk.E(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ck = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))
|
ck = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ func assertCryptPriv(
|
|||||||
sk1, sk2, conversationKey, salt, plaintext, expected string,
|
sk1, sk2, conversationKey, salt, plaintext, expected string,
|
||||||
) {
|
) {
|
||||||
var (
|
var (
|
||||||
k1, s []byte
|
k1, s, plaintextBytes, actualBytes,
|
||||||
actual, decrypted string
|
expectedBytes, decrypted []byte
|
||||||
ok bool
|
ok bool
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
@@ -41,25 +41,27 @@ func assertCryptPriv(
|
|||||||
); !ok {
|
); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
actual, err = Encrypt(plaintext, k1, WithCustomNonce(s))
|
plaintextBytes = []byte(plaintext)
|
||||||
|
actualBytes, err = Encrypt(plaintextBytes, k1, WithCustomNonce(s))
|
||||||
if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok {
|
if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ok = assert.Equalf(t, expected, actual, "wrong encryption"); !ok {
|
expectedBytes = []byte(expected)
|
||||||
|
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decrypted, err = Decrypt(expected, k1)
|
decrypted, err = Decrypt(expectedBytes, k1)
|
||||||
if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok {
|
if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
assert.Equal(t, decrypted, plaintext, "wrong decryption")
|
assert.Equal(t, decrypted, plaintextBytes, "wrong decryption")
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertDecryptFail(
|
func assertDecryptFail(
|
||||||
t *testing.T, conversationKey, plaintext, ciphertext, msg string,
|
t *testing.T, conversationKey, plaintext, ciphertext, msg string,
|
||||||
) {
|
) {
|
||||||
var (
|
var (
|
||||||
k1 []byte
|
k1, ciphertextBytes []byte
|
||||||
ok bool
|
ok bool
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
@@ -69,7 +71,8 @@ func assertDecryptFail(
|
|||||||
); !ok {
|
); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = Decrypt(ciphertext, k1)
|
ciphertextBytes = []byte(ciphertext)
|
||||||
|
_, err = Decrypt(ciphertextBytes, k1)
|
||||||
assert.ErrorContains(t, err, msg)
|
assert.ErrorContains(t, err, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,12 +199,12 @@ func assertMessageKeyGeneration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func assertCryptLong(
|
func assertCryptLong(
|
||||||
t *testing.T, conversationKey, salt, pattern string, repeat int,
|
t *testing.T, conversationKey, salt string, pattern []byte, repeat int,
|
||||||
plaintextSha256, payloadSha256 string,
|
plaintextSha256, payloadSha256 string,
|
||||||
) {
|
) {
|
||||||
var (
|
var (
|
||||||
convKey, convSalt []byte
|
convKey, convSalt, plaintext, payloadBytes []byte
|
||||||
plaintext, actualPlaintextSha256, actualPayload, actualPayloadSha256 string
|
actualPlaintextSha256, actualPayloadSha256 string
|
||||||
h hash.Hash
|
h hash.Hash
|
||||||
ok bool
|
ok bool
|
||||||
err error
|
err error
|
||||||
@@ -218,12 +221,12 @@ func assertCryptLong(
|
|||||||
); !ok {
|
); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
plaintext = ""
|
plaintext = make([]byte, 0, len(pattern)*repeat)
|
||||||
for i := 0; i < repeat; i++ {
|
for i := 0; i < repeat; i++ {
|
||||||
plaintext += pattern
|
plaintext = append(plaintext, pattern...)
|
||||||
}
|
}
|
||||||
h = sha256.New()
|
h = sha256.New()
|
||||||
h.Write([]byte(plaintext))
|
h.Write(plaintext)
|
||||||
actualPlaintextSha256 = hex.Enc(h.Sum(nil))
|
actualPlaintextSha256 = hex.Enc(h.Sum(nil))
|
||||||
if ok = assert.Equalf(
|
if ok = assert.Equalf(
|
||||||
t, plaintextSha256, actualPlaintextSha256,
|
t, plaintextSha256, actualPlaintextSha256,
|
||||||
@@ -231,12 +234,14 @@ func assertCryptLong(
|
|||||||
); !ok {
|
); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
actualPayload, err = Encrypt(plaintext, convKey, WithCustomNonce(convSalt))
|
payloadBytes, err = Encrypt(
|
||||||
|
plaintext, convKey, WithCustomNonce(convSalt),
|
||||||
|
)
|
||||||
if ok = assert.NoErrorf(t, err, "encryption failed: %v", err); !ok {
|
if ok = assert.NoErrorf(t, err, "encryption failed: %v", err); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.Reset()
|
h.Reset()
|
||||||
h.Write([]byte(actualPayload))
|
h.Write(payloadBytes)
|
||||||
actualPayloadSha256 = hex.Enc(h.Sum(nil))
|
actualPayloadSha256 = hex.Enc(h.Sum(nil))
|
||||||
if ok = assert.Equalf(
|
if ok = assert.Equalf(
|
||||||
t, payloadSha256, actualPayloadSha256,
|
t, payloadSha256, actualPayloadSha256,
|
||||||
@@ -383,7 +388,7 @@ func TestCryptLong001(t *testing.T) {
|
|||||||
t,
|
t,
|
||||||
"8fc262099ce0d0bb9b89bac05bb9e04f9bc0090acc181fef6840ccee470371ed",
|
"8fc262099ce0d0bb9b89bac05bb9e04f9bc0090acc181fef6840ccee470371ed",
|
||||||
"326bcb2c943cd6bb717588c9e5a7e738edf6ed14ec5f5344caa6ef56f0b9cff7",
|
"326bcb2c943cd6bb717588c9e5a7e738edf6ed14ec5f5344caa6ef56f0b9cff7",
|
||||||
"x",
|
[]byte("x"),
|
||||||
65535,
|
65535,
|
||||||
"09ab7495d3e61a76f0deb12cb0306f0696cbb17ffc12131368c7a939f12f56d3",
|
"09ab7495d3e61a76f0deb12cb0306f0696cbb17ffc12131368c7a939f12f56d3",
|
||||||
"90714492225faba06310bff2f249ebdc2a5e609d65a629f1c87f2d4ffc55330a",
|
"90714492225faba06310bff2f249ebdc2a5e609d65a629f1c87f2d4ffc55330a",
|
||||||
@@ -395,7 +400,7 @@ func TestCryptLong002(t *testing.T) {
|
|||||||
t,
|
t,
|
||||||
"56adbe3720339363ab9c3b8526ffce9fd77600927488bfc4b59f7a68ffe5eae0",
|
"56adbe3720339363ab9c3b8526ffce9fd77600927488bfc4b59f7a68ffe5eae0",
|
||||||
"ad68da81833c2a8ff609c3d2c0335fd44fe5954f85bb580c6a8d467aa9fc5dd0",
|
"ad68da81833c2a8ff609c3d2c0335fd44fe5954f85bb580c6a8d467aa9fc5dd0",
|
||||||
"!",
|
[]byte("!"),
|
||||||
65535,
|
65535,
|
||||||
"6af297793b72ae092c422e552c3bb3cbc310da274bd1cf9e31023a7fe4a2d75e",
|
"6af297793b72ae092c422e552c3bb3cbc310da274bd1cf9e31023a7fe4a2d75e",
|
||||||
"8013e45a109fad3362133132b460a2d5bce235fe71c8b8f4014793fb52a49844",
|
"8013e45a109fad3362133132b460a2d5bce235fe71c8b8f4014793fb52a49844",
|
||||||
@@ -407,7 +412,7 @@ func TestCryptLong003(t *testing.T) {
|
|||||||
t,
|
t,
|
||||||
"7fc540779979e472bb8d12480b443d1e5eb1098eae546ef2390bee499bbf46be",
|
"7fc540779979e472bb8d12480b443d1e5eb1098eae546ef2390bee499bbf46be",
|
||||||
"34905e82105c20de9a2f6cd385a0d541e6bcc10601d12481ff3a7575dc622033",
|
"34905e82105c20de9a2f6cd385a0d541e6bcc10601d12481ff3a7575dc622033",
|
||||||
"🦄",
|
[]byte("🦄"),
|
||||||
16383,
|
16383,
|
||||||
"a249558d161b77297bc0cb311dde7d77190f6571b25c7e4429cd19044634a61f",
|
"a249558d161b77297bc0cb311dde7d77190f6571b25c7e4429cd19044634a61f",
|
||||||
"b3348422471da1f3c59d79acfe2fe103f3cd24488109e5b18734cdb5953afd15",
|
"b3348422471da1f3c59d79acfe2fe103f3cd24488109e5b18734cdb5953afd15",
|
||||||
@@ -1309,7 +1314,10 @@ func TestMaxLength(t *testing.T) {
|
|||||||
rand.Read(salt)
|
rand.Read(salt)
|
||||||
conversationKey, _ := GenerateConversationKey(pub2, string(sk1))
|
conversationKey, _ := GenerateConversationKey(pub2, string(sk1))
|
||||||
plaintext := strings.Repeat("a", MaxPlaintextSize)
|
plaintext := strings.Repeat("a", MaxPlaintextSize)
|
||||||
encrypted, err := Encrypt(plaintext, conversationKey, WithCustomNonce(salt))
|
plaintextBytes := []byte(plaintext)
|
||||||
|
encrypted, err := Encrypt(
|
||||||
|
plaintextBytes, conversationKey, WithCustomNonce(salt),
|
||||||
|
)
|
||||||
if chk.E(err) {
|
if chk.E(err) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -1321,7 +1329,7 @@ func TestMaxLength(t *testing.T) {
|
|||||||
fmt.Sprintf("%x", conversationKey),
|
fmt.Sprintf("%x", conversationKey),
|
||||||
fmt.Sprintf("%x", salt),
|
fmt.Sprintf("%x", salt),
|
||||||
plaintext,
|
plaintext,
|
||||||
encrypted,
|
string(encrypted),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1330,8 +1338,8 @@ func assertCryptPub(
|
|||||||
sk1, pub2, conversationKey, salt, plaintext, expected string,
|
sk1, pub2, conversationKey, salt, plaintext, expected string,
|
||||||
) {
|
) {
|
||||||
var (
|
var (
|
||||||
k1, s []byte
|
k1, s, plaintextBytes,
|
||||||
actual, decrypted string
|
actualBytes, expectedBytes, decrypted []byte
|
||||||
ok bool
|
ok bool
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
@@ -1352,16 +1360,18 @@ func assertCryptPub(
|
|||||||
); !ok {
|
); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
actual, err = Encrypt(plaintext, k1, WithCustomNonce(s))
|
plaintextBytes = []byte(plaintext)
|
||||||
|
actualBytes, err = Encrypt(plaintextBytes, k1, WithCustomNonce(s))
|
||||||
if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok {
|
if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ok = assert.Equalf(t, expected, actual, "wrong encryption"); !ok {
|
expectedBytes = []byte(expected)
|
||||||
|
if ok = assert.Equalf(t, string(expectedBytes), string(actualBytes), "wrong encryption"); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decrypted, err = Decrypt(expected, k1)
|
decrypted, err = Decrypt(expectedBytes, k1)
|
||||||
if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok {
|
if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
assert.Equal(t, decrypted, plaintext, "wrong decryption")
|
assert.Equal(t, decrypted, plaintextBytes, "wrong decryption")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,3 +18,7 @@ type Signer = btcec.Signer
|
|||||||
type Keygen = btcec.Keygen
|
type Keygen = btcec.Keygen
|
||||||
|
|
||||||
func NewKeygen() (k *Keygen) { return new(Keygen) }
|
func NewKeygen() (k *Keygen) { return new(Keygen) }
|
||||||
|
|
||||||
|
var NewSecFromHex = btcec.NewSecFromHex
|
||||||
|
var NewPubFromHex = btcec.NewPubFromHex
|
||||||
|
var HexToBin = btcec.HexToBin
|
||||||
|
|||||||
@@ -55,10 +55,20 @@ func (s *Signer) InitPub(pub []byte) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sec returns the raw secret key bytes.
|
// Sec returns the raw secret key bytes.
|
||||||
func (s *Signer) Sec() (b []byte) { return s.skb }
|
func (s *Signer) Sec() (b []byte) {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.skb
|
||||||
|
}
|
||||||
|
|
||||||
// Pub returns the raw BIP-340 schnorr public key bytes.
|
// Pub returns the raw BIP-340 schnorr public key bytes.
|
||||||
func (s *Signer) Pub() (b []byte) { return s.pkb }
|
func (s *Signer) Pub() (b []byte) {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.pkb
|
||||||
|
}
|
||||||
|
|
||||||
// Sign a message with the Signer. Requires an initialised secret key.
|
// Sign a message with the Signer. Requires an initialised secret key.
|
||||||
func (s *Signer) Sign(msg []byte) (sig []byte, err error) {
|
func (s *Signer) Sign(msg []byte) (sig []byte, err error) {
|
||||||
|
|||||||
40
pkg/crypto/p256k/btcec/helpers-btcec.go
Normal file
40
pkg/crypto/p256k/btcec/helpers-btcec.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
//go:build !cgo
|
||||||
|
|
||||||
|
package btcec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"orly.dev/pkg/encoders/hex"
|
||||||
|
"orly.dev/pkg/interfaces/signer"
|
||||||
|
"orly.dev/pkg/utils/chk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewSecFromHex[V []byte | string](skh V) (sign signer.I, err error) {
|
||||||
|
var sk []byte
|
||||||
|
if _, err = hex.DecBytes(sk, []byte(skh)); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sign = &Signer{}
|
||||||
|
if err = sign.InitSec(sk); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPubFromHex[V []byte | string](pkh V) (sign signer.I, err error) {
|
||||||
|
var sk []byte
|
||||||
|
if _, err = hex.DecBytes(sk, []byte(pkh)); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sign = &Signer{}
|
||||||
|
if err = sign.InitPub(sk); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func HexToBin(hexStr string) (b []byte, err error) {
|
||||||
|
if _, err = hex.DecBytes(b, []byte(hexStr)); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
40
pkg/crypto/p256k/helpers.go
Normal file
40
pkg/crypto/p256k/helpers.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
//go:build cgo
|
||||||
|
|
||||||
|
package p256k
|
||||||
|
|
||||||
|
import (
|
||||||
|
"orly.dev/pkg/encoders/hex"
|
||||||
|
"orly.dev/pkg/interfaces/signer"
|
||||||
|
"orly.dev/pkg/utils/chk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewSecFromHex[V []byte | string](skh V) (sign signer.I, err error) {
|
||||||
|
var sk []byte
|
||||||
|
if _, err = hex.DecBytes(sk, []byte(skh)); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sign = &Signer{}
|
||||||
|
if err = sign.InitSec(sk); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPubFromHex[V []byte | string](pkh V) (sign signer.I, err error) {
|
||||||
|
var sk []byte
|
||||||
|
if _, err = hex.DecBytes(sk, []byte(pkh)); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sign = &Signer{}
|
||||||
|
if err = sign.InitPub(sk); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func HexToBin(hexStr string) (b []byte, err error) {
|
||||||
|
if _, err = hex.DecBytes(b, []byte(hexStr)); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -77,8 +77,18 @@ func (s *Signer) InitPub(pub []byte) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Signer) Sec() (b []byte) { return s.skb }
|
func (s *Signer) Sec() (b []byte) {
|
||||||
func (s *Signer) Pub() (b []byte) { return s.pkb }
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.skb
|
||||||
|
}
|
||||||
|
func (s *Signer) Pub() (b []byte) {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.pkb
|
||||||
|
}
|
||||||
|
|
||||||
// func (s *Signer) ECPub() (b []byte) { return s.pkb }
|
// func (s *Signer) ECPub() (b []byte) { return s.pkb }
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func TestExport(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event: %v", err)
|
t.Fatalf("Failed to save event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func TestFetchEventBySerial(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func TestGetSerialById(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func TestGetSerialsByRange(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (d *D) Import(rr io.Reader) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, _, err = d.SaveEvent(d.ctx, ev, false); err != nil {
|
if _, _, err = d.SaveEvent(d.ctx, ev, false, nil); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestMultipleParameterizedReplaceableEvents(t *testing.T) {
|
|||||||
baseEvent.Sign(sign)
|
baseEvent.Sign(sign)
|
||||||
|
|
||||||
// Save the base parameterized replaceable event
|
// Save the base parameterized replaceable event
|
||||||
if _, _, err := db.SaveEvent(ctx, baseEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, baseEvent, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save base parameterized replaceable event: %v", err)
|
t.Fatalf("Failed to save base parameterized replaceable event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,7 +63,7 @@ func TestMultipleParameterizedReplaceableEvents(t *testing.T) {
|
|||||||
newerEvent.Sign(sign)
|
newerEvent.Sign(sign)
|
||||||
|
|
||||||
// Save the newer parameterized replaceable event
|
// Save the newer parameterized replaceable event
|
||||||
if _, _, err := db.SaveEvent(ctx, newerEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, newerEvent, false, nil); err != nil {
|
||||||
t.Fatalf(
|
t.Fatalf(
|
||||||
"Failed to save newer parameterized replaceable event: %v", err,
|
"Failed to save newer parameterized replaceable event: %v", err,
|
||||||
)
|
)
|
||||||
@@ -83,7 +83,7 @@ func TestMultipleParameterizedReplaceableEvents(t *testing.T) {
|
|||||||
newestEvent.Sign(sign)
|
newestEvent.Sign(sign)
|
||||||
|
|
||||||
// Save the newest parameterized replaceable event
|
// Save the newest parameterized replaceable event
|
||||||
if _, _, err := db.SaveEvent(ctx, newestEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, newestEvent, false, nil); err != nil {
|
||||||
t.Fatalf(
|
t.Fatalf(
|
||||||
"Failed to save newest parameterized replaceable event: %v", err,
|
"Failed to save newest parameterized replaceable event: %v", err,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,11 +16,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueryEvents retrieves events based on the provided filter. If the filter
|
|
||||||
// contains Ids, it fetches events by those Ids directly, overriding other
|
|
||||||
// filter criteria. Otherwise, it queries by other filter criteria and fetches
|
|
||||||
// matching events. Results are returned in reverse chronological order of their
|
|
||||||
// creation timestamps.
|
|
||||||
func (d *D) QueryEvents(c context.T, f *filter.F) (evs event.S, err error) {
|
func (d *D) QueryEvents(c context.T, f *filter.F) (evs event.S, err error) {
|
||||||
// if there is Ids in the query, this overrides anything else
|
// if there is Ids in the query, this overrides anything else
|
||||||
if f.Ids != nil && f.Ids.Len() > 0 {
|
if f.Ids != nil && f.Ids.Len() > 0 {
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func setupTestDB(t *testing.T) (
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,7 +202,9 @@ func TestReplaceableEventsAndDeletion(t *testing.T) {
|
|||||||
replaceableEvent.Tags = tags.New()
|
replaceableEvent.Tags = tags.New()
|
||||||
replaceableEvent.Sign(sign)
|
replaceableEvent.Sign(sign)
|
||||||
// Save the replaceable event
|
// Save the replaceable event
|
||||||
if _, _, err := db.SaveEvent(ctx, replaceableEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(
|
||||||
|
ctx, replaceableEvent, false, nil,
|
||||||
|
); err != nil {
|
||||||
t.Fatalf("Failed to save replaceable event: %v", err)
|
t.Fatalf("Failed to save replaceable event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,7 +218,7 @@ func TestReplaceableEventsAndDeletion(t *testing.T) {
|
|||||||
newerEvent.Tags = tags.New()
|
newerEvent.Tags = tags.New()
|
||||||
newerEvent.Sign(sign)
|
newerEvent.Sign(sign)
|
||||||
// Save the newer event
|
// Save the newer event
|
||||||
if _, _, err := db.SaveEvent(ctx, newerEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, newerEvent, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save newer event: %v", err)
|
t.Fatalf("Failed to save newer event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,7 +295,7 @@ func TestReplaceableEventsAndDeletion(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Save the deletion event
|
// Save the deletion event
|
||||||
if _, _, err = db.SaveEvent(ctx, deletionEvent, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, deletionEvent, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save deletion event: %v", err)
|
t.Fatalf("Failed to save deletion event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,7 +381,7 @@ func TestParameterizedReplaceableEventsAndDeletion(t *testing.T) {
|
|||||||
paramEvent.Sign(sign)
|
paramEvent.Sign(sign)
|
||||||
|
|
||||||
// Save the parameterized replaceable event
|
// Save the parameterized replaceable event
|
||||||
if _, _, err := db.SaveEvent(ctx, paramEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, paramEvent, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save parameterized replaceable event: %v", err)
|
t.Fatalf("Failed to save parameterized replaceable event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -405,7 +407,9 @@ func TestParameterizedReplaceableEventsAndDeletion(t *testing.T) {
|
|||||||
paramDeletionEvent.Sign(sign)
|
paramDeletionEvent.Sign(sign)
|
||||||
|
|
||||||
// Save the parameterized deletion event
|
// Save the parameterized deletion event
|
||||||
if _, _, err := db.SaveEvent(ctx, paramDeletionEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(
|
||||||
|
ctx, paramDeletionEvent, false, nil,
|
||||||
|
); err != nil {
|
||||||
t.Fatalf("Failed to save parameterized deletion event: %v", err)
|
t.Fatalf("Failed to save parameterized deletion event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,7 +442,9 @@ func TestParameterizedReplaceableEventsAndDeletion(t *testing.T) {
|
|||||||
paramDeletionEvent2.Sign(sign)
|
paramDeletionEvent2.Sign(sign)
|
||||||
|
|
||||||
// Save the parameterized deletion event with e-tag
|
// Save the parameterized deletion event with e-tag
|
||||||
if _, _, err := db.SaveEvent(ctx, paramDeletionEvent2, false); err != nil {
|
if _, _, err := db.SaveEvent(
|
||||||
|
ctx, paramDeletionEvent2, false, nil,
|
||||||
|
); err != nil {
|
||||||
t.Fatalf(
|
t.Fatalf(
|
||||||
"Failed to save parameterized deletion event with e-tag: %v", err,
|
"Failed to save parameterized deletion event with e-tag: %v", err,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func TestQueryForAuthorsTags(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func TestQueryForCreatedAt(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func TestQueryForIds(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func TestQueryForKindsAuthorsTags(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func TestQueryForKindsAuthors(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func TestQueryForKindsTags(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func TestQueryForKinds(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func TestQueryForSerials(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func TestQueryForTags(t *testing.T) {
|
|||||||
events = append(events, ev)
|
events = append(events, ev)
|
||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
if _, _, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// SaveEvent saves an event to the database, generating all the necessary indexes.
|
// SaveEvent saves an event to the database, generating all the necessary indexes.
|
||||||
func (d *D) SaveEvent(c context.T, ev *event.E, noVerify bool) (
|
func (d *D) SaveEvent(
|
||||||
kc, vc int, err error,
|
c context.T, ev *event.E, noVerify bool, owners [][]byte,
|
||||||
) {
|
) (kc, vc int, err error) {
|
||||||
if !noVerify {
|
if !noVerify {
|
||||||
// check if the event already exists
|
// check if the event already exists
|
||||||
var ser *types.Uint40
|
var ser *types.Uint40
|
||||||
@@ -94,9 +94,13 @@ func (d *D) SaveEvent(c context.T, ev *event.E, noVerify bool) (
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
var idxs []Range
|
var idxs []Range
|
||||||
|
keys := [][]byte{ev.Pubkey}
|
||||||
|
for _, owner := range owners {
|
||||||
|
keys = append(keys, owner)
|
||||||
|
}
|
||||||
if idxs, err = GetIndexesFromFilter(
|
if idxs, err = GetIndexesFromFilter(
|
||||||
&filter.F{
|
&filter.F{
|
||||||
Authors: tag.New(ev.Pubkey),
|
Authors: tag.New(keys...),
|
||||||
Kinds: kinds.New(kind.Deletion),
|
Kinds: kinds.New(kind.Deletion),
|
||||||
Tags: tags.New(tag.New([]byte("#e"), ev.ID)),
|
Tags: tags.New(tag.New([]byte("#e"), ev.ID)),
|
||||||
},
|
},
|
||||||
@@ -115,7 +119,7 @@ func (d *D) SaveEvent(c context.T, ev *event.E, noVerify bool) (
|
|||||||
// really there can only be one of these; the chances of an idhash
|
// really there can only be one of these; the chances of an idhash
|
||||||
// collision are basically zero in practice, at least, one in a
|
// collision are basically zero in practice, at least, one in a
|
||||||
// billion or more anyway, more than a human is going to create.
|
// billion or more anyway, more than a human is going to create.
|
||||||
err = errorf.E("blocked: %0x was deleted by event ID", ev.ID)
|
err = errorf.E("blocked: event %0x deleted by event ID", ev.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func TestSaveEvents(t *testing.T) {
|
|||||||
|
|
||||||
// Save the event to the database
|
// Save the event to the database
|
||||||
var k, v int
|
var k, v int
|
||||||
if k, v, err = db.SaveEvent(ctx, ev, false); err != nil {
|
if k, v, err = db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
t.Fatalf("Failed to save event #%d: %v", eventCount+1, err)
|
||||||
}
|
}
|
||||||
kc += k
|
kc += k
|
||||||
@@ -125,7 +125,7 @@ func TestDeletionEventWithETagRejection(t *testing.T) {
|
|||||||
regularEvent.Sign(sign)
|
regularEvent.Sign(sign)
|
||||||
|
|
||||||
// Save the regular event
|
// Save the regular event
|
||||||
if _, _, err := db.SaveEvent(ctx, regularEvent, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, regularEvent, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save regular event: %v", err)
|
t.Fatalf("Failed to save regular event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +146,7 @@ func TestDeletionEventWithETagRejection(t *testing.T) {
|
|||||||
deletionEvent.Sign(sign)
|
deletionEvent.Sign(sign)
|
||||||
|
|
||||||
// Try to save the deletion event, it should be rejected
|
// Try to save the deletion event, it should be rejected
|
||||||
_, _, err = db.SaveEvent(ctx, deletionEvent, false)
|
_, _, err = db.SaveEvent(ctx, deletionEvent, false, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected deletion event with e-tag to be rejected, but it was accepted")
|
t.Fatal("Expected deletion event with e-tag to be rejected, but it was accepted")
|
||||||
}
|
}
|
||||||
@@ -198,12 +198,12 @@ func TestSaveExistingEvent(t *testing.T) {
|
|||||||
ev.Sign(sign)
|
ev.Sign(sign)
|
||||||
|
|
||||||
// Save the event for the first time
|
// Save the event for the first time
|
||||||
if _, _, err := db.SaveEvent(ctx, ev, false); err != nil {
|
if _, _, err := db.SaveEvent(ctx, ev, false, nil); err != nil {
|
||||||
t.Fatalf("Failed to save event: %v", err)
|
t.Fatalf("Failed to save event: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to save the same event again, it should be rejected
|
// Try to save the same event again, it should be rejected
|
||||||
_, _, err = db.SaveEvent(ctx, ev, false)
|
_, _, err = db.SaveEvent(ctx, ev, false, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error when saving an existing event, but got nil")
|
t.Fatal("Expected error when saving an existing event, but got nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,8 +82,12 @@ type Deleter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Saver interface {
|
type Saver interface {
|
||||||
// SaveEvent is called once relay.AcceptEvent reports true.
|
// SaveEvent is called once relay.AcceptEvent reports true. The owners
|
||||||
SaveEvent(c context.T, ev *event.E, noVerify bool) (kc, vc int, err error)
|
// parameter is for designating admins whose delete by e tag events apply
|
||||||
|
// the same as author's own.
|
||||||
|
SaveEvent(
|
||||||
|
c context.T, ev *event.E, noVerify bool, owners [][]byte,
|
||||||
|
) (kc, vc int, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Importer interface {
|
type Importer interface {
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestQuery(t *testing.T) {
|
|||||||
var err error
|
var err error
|
||||||
var pp *pointers.Profile
|
var pp *pointers.Profile
|
||||||
acct := "fiatjaf.com"
|
acct := "fiatjaf.com"
|
||||||
if pp, err = QueryIdentifier(context.Background(), acct); chk.E(err) {
|
if pp, err = QueryIdentifier(context.Bg(), acct); chk.E(err) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if pkb, err = keys.HexPubkeyToBytes(
|
if pkb, err = keys.HexPubkeyToBytes(
|
||||||
@@ -58,7 +58,7 @@ func TestQuery(t *testing.T) {
|
|||||||
t.Fatalf("invalid query for fiatjaf.com")
|
t.Fatalf("invalid query for fiatjaf.com")
|
||||||
}
|
}
|
||||||
|
|
||||||
pp, err = QueryIdentifier(context.Background(), "htlc@fiatjaf.com")
|
pp, err = QueryIdentifier(context.Bg(), "htlc@fiatjaf.com")
|
||||||
if pkb, err = keys.HexPubkeyToBytes(
|
if pkb, err = keys.HexPubkeyToBytes(
|
||||||
"f9dd6a762506260b38a2d3e5b464213c2e47fa3877429fe9ee60e071a31a07d7",
|
"f9dd6a762506260b38a2d3e5b464213c2e47fa3877429fe9ee60e071a31a07d7",
|
||||||
); chk.E(err) {
|
); chk.E(err) {
|
||||||
|
|||||||
614
pkg/protocol/nwc/client.go
Normal file
614
pkg/protocol/nwc/client.go
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
package nwc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"orly.dev/pkg/crypto/encryption"
|
||||||
|
"orly.dev/pkg/crypto/p256k"
|
||||||
|
"orly.dev/pkg/encoders/event"
|
||||||
|
"orly.dev/pkg/encoders/filter"
|
||||||
|
"orly.dev/pkg/encoders/filters"
|
||||||
|
"orly.dev/pkg/encoders/hex"
|
||||||
|
"orly.dev/pkg/encoders/kind"
|
||||||
|
"orly.dev/pkg/encoders/kinds"
|
||||||
|
"orly.dev/pkg/encoders/tag"
|
||||||
|
"orly.dev/pkg/encoders/tags"
|
||||||
|
"orly.dev/pkg/encoders/timestamp"
|
||||||
|
"orly.dev/pkg/interfaces/signer"
|
||||||
|
"orly.dev/pkg/protocol/ws"
|
||||||
|
"orly.dev/pkg/utils/chk"
|
||||||
|
"orly.dev/pkg/utils/context"
|
||||||
|
"orly.dev/pkg/utils/log"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options represents options for a NWC client
|
||||||
|
type Options struct {
|
||||||
|
RelayURL string
|
||||||
|
Secret signer.I
|
||||||
|
WalletPubkey []byte
|
||||||
|
Lud16 string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client represents a NWC client
|
||||||
|
type Client struct {
|
||||||
|
options Options
|
||||||
|
relay *ws.Client
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseWalletConnectURL parses a wallet connect URL
|
||||||
|
func ParseWalletConnectURL(walletConnectURL string) (opts *Options, err error) {
|
||||||
|
if !strings.HasPrefix(walletConnectURL, "nostr+walletconnect://") {
|
||||||
|
return nil, fmt.Errorf("unexpected scheme. Should be nostr+walletconnect://")
|
||||||
|
}
|
||||||
|
// Parse URL
|
||||||
|
colonIndex := strings.Index(walletConnectURL, ":")
|
||||||
|
if colonIndex == -1 {
|
||||||
|
err = fmt.Errorf("invalid URL format")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
walletConnectURL = walletConnectURL[colonIndex+1:]
|
||||||
|
if strings.HasPrefix(walletConnectURL, "//") {
|
||||||
|
walletConnectURL = walletConnectURL[2:]
|
||||||
|
}
|
||||||
|
walletConnectURL = "https://" + walletConnectURL
|
||||||
|
var u *url.URL
|
||||||
|
if u, err = url.Parse(walletConnectURL); chk.E(err) {
|
||||||
|
err = fmt.Errorf("failed to parse URL: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Get wallet pubkey
|
||||||
|
walletPubkey := u.Host
|
||||||
|
if len(walletPubkey) != 64 {
|
||||||
|
err = fmt.Errorf("incorrect wallet pubkey found in auth string")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pk []byte
|
||||||
|
if pk, err = hex.Dec(walletPubkey); chk.E(err) {
|
||||||
|
err = fmt.Errorf("failed to decode pubkey: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Get relay URL
|
||||||
|
relayURL := u.Query().Get("relay")
|
||||||
|
if relayURL == "" {
|
||||||
|
return nil, fmt.Errorf("no relay URL found in auth string")
|
||||||
|
}
|
||||||
|
// Get secret
|
||||||
|
secret := u.Query().Get("secret")
|
||||||
|
if secret == "" {
|
||||||
|
return nil, fmt.Errorf("no secret found in auth string")
|
||||||
|
}
|
||||||
|
var sk []byte
|
||||||
|
if sk, err = hex.Dec(secret); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sign := &p256k.Signer{}
|
||||||
|
if err = sign.InitSec(sk); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opts = &Options{
|
||||||
|
RelayURL: relayURL,
|
||||||
|
Secret: sign,
|
||||||
|
WalletPubkey: pk,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNWCClient creates a new NWC client
|
||||||
|
func NewNWCClient(options *Options) (cl *Client, err error) {
|
||||||
|
if options.RelayURL == "" {
|
||||||
|
err = fmt.Errorf("missing relay URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.Secret == nil {
|
||||||
|
err = fmt.Errorf("missing secret")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.WalletPubkey == nil {
|
||||||
|
err = fmt.Errorf("missing wallet pubkey")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
options: Options{
|
||||||
|
RelayURL: options.RelayURL,
|
||||||
|
Secret: options.Secret,
|
||||||
|
WalletPubkey: options.WalletPubkey,
|
||||||
|
Lud16: options.Lud16,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NostrWalletConnectURL returns the nostr wallet connect URL
|
||||||
|
func (c *Client) NostrWalletConnectURL() string {
|
||||||
|
return c.GetNostrWalletConnectURL(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNostrWalletConnectURL returns the nostr wallet connect URL
|
||||||
|
func (c *Client) GetNostrWalletConnectURL(includeSecret bool) string {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("relay", c.options.RelayURL)
|
||||||
|
if includeSecret {
|
||||||
|
params.Add("secret", hex.Enc(c.options.Secret.Sec()))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"nostr+walletconnect://%s?%s", c.options.WalletPubkey, params.Encode(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connected returns whether the client is connected to the relay
|
||||||
|
func (c *Client) Connected() bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.relay != nil && c.relay.IsConnected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPublicKey returns the client's public key
|
||||||
|
func (c *Client) GetPublicKey() (pubkey []byte, err error) {
|
||||||
|
pubkey = c.options.Secret.Pub()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the relay connection
|
||||||
|
func (c *Client) Close() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.relay != nil {
|
||||||
|
c.relay.Close()
|
||||||
|
c.relay = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt encrypts content for a pubkey
|
||||||
|
func (c *Client) encrypt(pubkey, content []byte) (
|
||||||
|
cipherText []byte, err error,
|
||||||
|
) {
|
||||||
|
var sharedSecret []byte
|
||||||
|
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cipherText, err = encryption.EncryptNip4(content, sharedSecret)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts content from a pubkey
|
||||||
|
func (c *Client) decrypt(pubkey, content []byte) (plaintext []byte, err error) {
|
||||||
|
var sharedSecret []byte
|
||||||
|
if sharedSecret, err = c.options.Secret.ECDH(pubkey); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
plaintext, err = encryption.DecryptNip4(content, sharedSecret)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInfo gets wallet info
|
||||||
|
func (c *Client) GetInfo() (response *GetInfoResponse, err error) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(GetInfo, nil); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &GetInfoResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBudget gets wallet budget
|
||||||
|
func (c *Client) GetBudget() (response *GetBudgetResponse, err error) {
|
||||||
|
var result []byte
|
||||||
|
result, err = c.executeRequest(GetBudget, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response = &GetBudgetResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBalance gets wallet balance
|
||||||
|
func (c *Client) GetBalance() (response *GetBalanceResponse, err error) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(GetBalance, nil); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &GetBalanceResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayInvoice pays an invoice
|
||||||
|
func (c *Client) PayInvoice(request *PayInvoiceRequest) (
|
||||||
|
response *PayResponse, err error,
|
||||||
|
) {
|
||||||
|
var result []byte
|
||||||
|
result, err = c.executeRequest(PayInvoice, request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response = &PayResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayKeysend sends a keysend payment
|
||||||
|
func (c *Client) PayKeysend(request *PayKeysendRequest) (
|
||||||
|
response *PayResponse, err error,
|
||||||
|
) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(PayKeysend, request); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &PayResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeInvoice creates an invoice
|
||||||
|
func (c *Client) MakeInvoice(request *MakeInvoiceRequest) (
|
||||||
|
response *Transaction, err error,
|
||||||
|
) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(MakeInvoice, request); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &Transaction{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupInvoice looks up an invoice
|
||||||
|
func (c *Client) LookupInvoice(request *LookupInvoiceRequest) (
|
||||||
|
response *Transaction, err error,
|
||||||
|
) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(LookupInvoice, request); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &Transaction{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTransactions lists transactions
|
||||||
|
func (c *Client) ListTransactions(request *ListTransactionsRequest) (
|
||||||
|
response *ListTransactionsResponse, err error,
|
||||||
|
) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(ListTransactions, request); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &ListTransactionsResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); chk.E(err) {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignMessage signs a message
|
||||||
|
func (c *Client) SignMessage(request *SignMessageRequest) (
|
||||||
|
response *SignMessageResponse, err error,
|
||||||
|
) {
|
||||||
|
var result []byte
|
||||||
|
if result, err = c.executeRequest(SignMessage, request); chk.E(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response = &SignMessageResponse{}
|
||||||
|
if err = json.Unmarshal(result, response); err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotificationHandler is a function that handles notifications
|
||||||
|
type NotificationHandler func(*Notification)
|
||||||
|
|
||||||
|
// SubscribeNotifications subscribes to notifications
|
||||||
|
func (c *Client) SubscribeNotifications(
|
||||||
|
handler NotificationHandler,
|
||||||
|
notificationTypes []NotificationType,
|
||||||
|
) (stop func(), err error) {
|
||||||
|
if handler == nil {
|
||||||
|
err = fmt.Errorf("missing notification handler")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx, cancel := context.Cancel(context.Bg())
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
stop = func() {
|
||||||
|
cancel()
|
||||||
|
<-doneCh
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer close(doneCh)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Check connection
|
||||||
|
if err := c.checkConnected(); err != nil {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Get client pubkey
|
||||||
|
var clientPubkey []byte
|
||||||
|
if clientPubkey, err = c.GetPublicKey(); chk.E(err) {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Subscribe to events
|
||||||
|
f := &filter.F{
|
||||||
|
Kinds: kinds.New(kind.WalletResponse),
|
||||||
|
Authors: tag.New(c.options.WalletPubkey),
|
||||||
|
Tags: tags.New(tag.New([]byte("#p"), clientPubkey)),
|
||||||
|
}
|
||||||
|
var sub *ws.Subscription
|
||||||
|
if sub, err = c.relay.Subscribe(
|
||||||
|
context.Bg(), &filters.T{
|
||||||
|
F: []*filter.F{f},
|
||||||
|
},
|
||||||
|
); chk.E(err) {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle events
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
sub.Close()
|
||||||
|
return
|
||||||
|
case ev := <-sub.Events:
|
||||||
|
// Decrypt content
|
||||||
|
var decryptedContent []byte
|
||||||
|
if decryptedContent, err = c.decrypt(
|
||||||
|
c.options.WalletPubkey, ev.Content,
|
||||||
|
); chk.E(err) {
|
||||||
|
log.E.F(
|
||||||
|
"Failed to decrypt event content: %v\n", err,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Parse notification
|
||||||
|
notification := &Notification{}
|
||||||
|
if err = json.Unmarshal(
|
||||||
|
decryptedContent, notification,
|
||||||
|
); chk.E(err) {
|
||||||
|
log.E.F(
|
||||||
|
"Failed to parse notification: %v\n", err,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Check if notification type is requested
|
||||||
|
if len(notificationTypes) > 0 {
|
||||||
|
found := false
|
||||||
|
for _, t := range notificationTypes {
|
||||||
|
if notification.NotificationType == t {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle notification
|
||||||
|
handler(notification)
|
||||||
|
case <-sub.EndOfStoredEvents:
|
||||||
|
// Ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRequest executes a NIP-47 request
|
||||||
|
func (c *Client) executeRequest(
|
||||||
|
method Method,
|
||||||
|
params any,
|
||||||
|
) (msg json.RawMessage, err error) {
|
||||||
|
// Default timeout values
|
||||||
|
replyTimeout := 3 * time.Second
|
||||||
|
publishTimeout := 3 * time.Second
|
||||||
|
// Create context with timeout
|
||||||
|
ctx, cancel := context.Timeout(context.Bg(), replyTimeout)
|
||||||
|
defer cancel()
|
||||||
|
// Create result channel
|
||||||
|
resultCh := make(chan json.RawMessage, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
// Check connection
|
||||||
|
if err = c.checkConnected(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Create request
|
||||||
|
request := struct {
|
||||||
|
Method Method `json:"method"`
|
||||||
|
Params any `json:"params,omitempty"`
|
||||||
|
}{
|
||||||
|
Method: method,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
// Marshal request
|
||||||
|
requestJSON, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
// Encrypt request
|
||||||
|
var encryptedContent []byte
|
||||||
|
if encryptedContent, err = c.encrypt(
|
||||||
|
c.options.WalletPubkey, requestJSON,
|
||||||
|
); chk.E(err) {
|
||||||
|
return nil, fmt.Errorf("failed to encrypt request: %w", err)
|
||||||
|
}
|
||||||
|
// Create request event
|
||||||
|
requestEvent := &event.E{
|
||||||
|
Kind: kind.WalletRequest,
|
||||||
|
CreatedAt: timestamp.New(time.Now().Unix()),
|
||||||
|
Tags: tags.New(tag.New("p", hex.Enc(c.options.WalletPubkey))),
|
||||||
|
Content: encryptedContent,
|
||||||
|
}
|
||||||
|
// Sign request event
|
||||||
|
err = requestEvent.Sign(c.options.Secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to sign request event: %w", err)
|
||||||
|
}
|
||||||
|
// Subscribe to response events
|
||||||
|
f := &filter.F{
|
||||||
|
Kinds: kinds.New(kind.WalletResponse),
|
||||||
|
Authors: tag.New(c.options.WalletPubkey),
|
||||||
|
Tags: tags.New(tag.New([]byte("#p"), requestEvent.ID)),
|
||||||
|
}
|
||||||
|
log.I.F("%s", f.Marshal(nil))
|
||||||
|
var sub *ws.Subscription
|
||||||
|
if sub, err = c.relay.Subscribe(
|
||||||
|
ctx, &filters.T{
|
||||||
|
F: []*filter.F{f},
|
||||||
|
},
|
||||||
|
); chk.E(err) {
|
||||||
|
err = fmt.Errorf(
|
||||||
|
"failed to subscribe to response events: %w", err,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sub.Close()
|
||||||
|
// Set up reply timeout
|
||||||
|
replyTimer := time.AfterFunc(
|
||||||
|
replyTimeout, func() {
|
||||||
|
errCh <- NewReplyTimeoutError(
|
||||||
|
fmt.Sprintf("Timeout waiting for reply to %s", method),
|
||||||
|
"TIMEOUT",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
defer replyTimer.Stop()
|
||||||
|
// Handle response events
|
||||||
|
go func() {
|
||||||
|
var resErr error
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case ev := <-sub.Events:
|
||||||
|
// Decrypt content
|
||||||
|
var decryptedContent []byte
|
||||||
|
decryptedContent, resErr = c.decrypt(
|
||||||
|
c.options.WalletPubkey, ev.Content,
|
||||||
|
)
|
||||||
|
if chk.E(resErr) {
|
||||||
|
errCh <- fmt.Errorf(
|
||||||
|
"failed to decrypt response: %w",
|
||||||
|
resErr,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Parse response
|
||||||
|
var response struct {
|
||||||
|
ResultType string `json:"result_type"`
|
||||||
|
Result json.RawMessage `json:"result"`
|
||||||
|
Error *struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
if resErr = json.Unmarshal(
|
||||||
|
decryptedContent, &response,
|
||||||
|
); chk.E(resErr) {
|
||||||
|
errCh <- fmt.Errorf("failed to parse response: %w", resErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check for error
|
||||||
|
if response.Error != nil {
|
||||||
|
errCh <- NewWalletError(
|
||||||
|
response.Error.Message,
|
||||||
|
response.Error.Code,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send result
|
||||||
|
resultCh <- response.Result
|
||||||
|
return
|
||||||
|
case <-sub.EndOfStoredEvents:
|
||||||
|
// Ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// Publish request event
|
||||||
|
publishCtx, publishCancel := context.Timeout(
|
||||||
|
context.Bg(), publishTimeout,
|
||||||
|
)
|
||||||
|
defer publishCancel()
|
||||||
|
if err = c.relay.Publish(publishCtx, requestEvent); chk.E(err) {
|
||||||
|
err = fmt.Errorf("failed to publish request event: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for result or error
|
||||||
|
select {
|
||||||
|
case msg = <-resultCh:
|
||||||
|
return
|
||||||
|
case err = <-errCh:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = NewReplyTimeoutError(
|
||||||
|
fmt.Sprintf("Timeout waiting for reply to %s", method),
|
||||||
|
"TIMEOUT",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkConnected checks if the client is connected to the relay
|
||||||
|
func (c *Client) checkConnected() (err error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.options.RelayURL == "" {
|
||||||
|
return fmt.Errorf("missing relay URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.relay == nil {
|
||||||
|
if c.relay, err = ws.RelayConnect(
|
||||||
|
context.Bg(), c.options.RelayURL,
|
||||||
|
); chk.E(err) {
|
||||||
|
return NewNetworkError(
|
||||||
|
"Failed to connect to "+c.options.RelayURL,
|
||||||
|
"OTHER",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else if !c.relay.IsConnected() {
|
||||||
|
c.relay.Close()
|
||||||
|
if c.relay, err = ws.RelayConnect(
|
||||||
|
context.Bg(), c.options.RelayURL,
|
||||||
|
); chk.E(err) {
|
||||||
|
return NewNetworkError(
|
||||||
|
"Failed to connect to "+c.options.RelayURL,
|
||||||
|
"OTHER",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
473
pkg/protocol/nwc/types.go
Normal file
473
pkg/protocol/nwc/types.go
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
package nwc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncryptionType represents the encryption type used for NIP-47 messages
|
||||||
|
type EncryptionType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Nip04 EncryptionType = "nip04"
|
||||||
|
Nip44V2 EncryptionType = "nip44_v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthorizationUrlOptions represents options for creating an NWC authorization URL
|
||||||
|
type AuthorizationUrlOptions struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
RequestMethods []Method `json:"requestMethods,omitempty"`
|
||||||
|
NotificationTypes []NotificationType `json:"notificationTypes,omitempty"`
|
||||||
|
ReturnTo string `json:"returnTo,omitempty"`
|
||||||
|
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||||
|
MaxAmount *int64 `json:"maxAmount,omitempty"`
|
||||||
|
BudgetRenewal BudgetRenewalPeriod `json:"budgetRenewal,omitempty"`
|
||||||
|
Isolated bool `json:"isolated,omitempty"`
|
||||||
|
Metadata interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err is the base error type for NIP-47 errors
|
||||||
|
type Err struct {
|
||||||
|
Message string
|
||||||
|
Code string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Err) Error() string {
|
||||||
|
return fmt.Sprintf("%s (code: %s)", e.Message, e.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewError creates a new Error
|
||||||
|
func NewError(message, code string) *Err {
|
||||||
|
return &Err{
|
||||||
|
Message: message,
|
||||||
|
Code: code,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkError represents a network error in NIP-47 operations
|
||||||
|
type NetworkError struct{ *Err }
|
||||||
|
|
||||||
|
// NewNetworkError creates a new NetworkError
|
||||||
|
func NewNetworkError(message, code string) *NetworkError {
|
||||||
|
return &NetworkError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalletError represents a wallet error in NIP-47 operations
|
||||||
|
type WalletError struct {
|
||||||
|
*Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWalletError creates a new WalletError
|
||||||
|
func NewWalletError(message, code string) *WalletError {
|
||||||
|
return &WalletError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeoutError represents a timeout error in NIP-47 operations
|
||||||
|
type TimeoutError struct{ *Err }
|
||||||
|
|
||||||
|
// NewTimeoutError creates a new TimeoutError
|
||||||
|
func NewTimeoutError(message, code string) *TimeoutError {
|
||||||
|
return &TimeoutError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishTimeoutError represents a publish timeout error in NIP-47 operations
|
||||||
|
type PublishTimeoutError struct{ *TimeoutError }
|
||||||
|
|
||||||
|
// NewPublishTimeoutError creates a new PublishTimeoutError
|
||||||
|
func NewPublishTimeoutError(message, code string) *PublishTimeoutError {
|
||||||
|
return &PublishTimeoutError{
|
||||||
|
TimeoutError: NewTimeoutError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplyTimeoutError represents a reply timeout error in NIP-47 operations
|
||||||
|
type ReplyTimeoutError struct{ *TimeoutError }
|
||||||
|
|
||||||
|
// NewReplyTimeoutError creates a new ReplyTimeoutError
|
||||||
|
func NewReplyTimeoutError(message, code string) *ReplyTimeoutError {
|
||||||
|
return &ReplyTimeoutError{
|
||||||
|
TimeoutError: NewTimeoutError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishError represents a publish error in NIP-47 operations
|
||||||
|
type PublishError struct{ *Err }
|
||||||
|
|
||||||
|
// NewPublishError creates a new PublishError
|
||||||
|
func NewPublishError(message, code string) *PublishError {
|
||||||
|
return &PublishError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseDecodingError represents a response decoding error in NIP-47 operations
|
||||||
|
type ResponseDecodingError struct{ *Err }
|
||||||
|
|
||||||
|
// NewResponseDecodingError creates a new ResponseDecodingError
|
||||||
|
func NewResponseDecodingError(message, code string) *ResponseDecodingError {
|
||||||
|
return &ResponseDecodingError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseValidationError represents a response validation error in NIP-47 operations
|
||||||
|
type ResponseValidationError struct{ *Err }
|
||||||
|
|
||||||
|
// NewResponseValidationError creates a new ResponseValidationError
|
||||||
|
func NewResponseValidationError(message, code string) *ResponseValidationError {
|
||||||
|
return &ResponseValidationError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnexpectedResponseError represents an unexpected response error in NIP-47 operations
|
||||||
|
type UnexpectedResponseError struct{ *Err }
|
||||||
|
|
||||||
|
// NewUnexpectedResponseError creates a new UnexpectedResponseError
|
||||||
|
func NewUnexpectedResponseError(message, code string) *UnexpectedResponseError {
|
||||||
|
return &UnexpectedResponseError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsupportedEncryptionError represents an unsupported encryption error in NIP-47 operations
|
||||||
|
type UnsupportedEncryptionError struct {
|
||||||
|
*Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnsupportedEncryptionError creates a new UnsupportedEncryptionError
|
||||||
|
func NewUnsupportedEncryptionError(message, code string) *UnsupportedEncryptionError {
|
||||||
|
return &UnsupportedEncryptionError{
|
||||||
|
Err: NewError(message, code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDTag represents a type with a dTag field
|
||||||
|
type WithDTag struct {
|
||||||
|
DTag string `json:"dTag"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOptionalId represents a type with an optional id field
|
||||||
|
type WithOptionalId struct {
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method represents a NIP-47 method
|
||||||
|
type Method string
|
||||||
|
|
||||||
|
// SingleMethod represents a single NIP-47 method
|
||||||
|
const (
|
||||||
|
GetInfo Method = "get_info"
|
||||||
|
GetBalance Method = "get_balance"
|
||||||
|
GetBudget Method = "get_budget"
|
||||||
|
MakeInvoice Method = "make_invoice"
|
||||||
|
PayInvoice Method = "pay_invoice"
|
||||||
|
PayKeysend Method = "pay_keysend"
|
||||||
|
LookupInvoice Method = "lookup_invoice"
|
||||||
|
ListTransactions Method = "list_transactions"
|
||||||
|
SignMessage Method = "sign_message"
|
||||||
|
CreateConnection Method = "create_connection"
|
||||||
|
MakeHoldInvoice Method = "make_hold_invoice"
|
||||||
|
SettleHoldInvoice Method = "settle_hold_invoice"
|
||||||
|
CancelHoldInvoice Method = "cancel_hold_invoice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MultiMethod represents a multi NIP-47 method
|
||||||
|
const (
|
||||||
|
MultiPayInvoice Method = "multi_pay_invoice"
|
||||||
|
MultiPayKeysend Method = "multi_pay_keysend"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Capability represents a NIP-47 capability
|
||||||
|
type Capability string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Notifications Capability = "notifications"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BudgetRenewalPeriod represents a budget renewal period
|
||||||
|
type BudgetRenewalPeriod string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Daily BudgetRenewalPeriod = "daily"
|
||||||
|
Weekly BudgetRenewalPeriod = "weekly"
|
||||||
|
Monthly BudgetRenewalPeriod = "monthly"
|
||||||
|
Yearly BudgetRenewalPeriod = "yearly"
|
||||||
|
Never BudgetRenewalPeriod = "never"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetInfoResponse represents a response to a get_info request
|
||||||
|
type GetInfoResponse struct {
|
||||||
|
Alias string `json:"alias"`
|
||||||
|
Color string `json:"color"`
|
||||||
|
Pubkey string `json:"pubkey"`
|
||||||
|
Network string `json:"network"`
|
||||||
|
BlockHeight int64 `json:"block_height"`
|
||||||
|
BlockHash string `json:"block_hash"`
|
||||||
|
Methods []Method `json:"methods"`
|
||||||
|
Notifications []NotificationType `json:"notifications,omitempty"`
|
||||||
|
Metadata interface{} `json:"metadata,omitempty"`
|
||||||
|
Lud16 string `json:"lud16,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBudgetResponse represents a response to a get_budget request
|
||||||
|
type GetBudgetResponse struct {
|
||||||
|
UsedBudget int64 `json:"used_budget,omitempty"`
|
||||||
|
TotalBudget int64 `json:"total_budget,omitempty"`
|
||||||
|
RenewsAt *int64 `json:"renews_at,omitempty"`
|
||||||
|
RenewalPeriod BudgetRenewalPeriod `json:"renewal_period,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBalanceResponse represents a response to a get_balance request
|
||||||
|
type GetBalanceResponse struct {
|
||||||
|
Balance int64 `json:"balance"` // msats
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayResponse represents a response to a pay request
|
||||||
|
type PayResponse struct {
|
||||||
|
Preimage string `json:"preimage"`
|
||||||
|
FeesPaid int64 `json:"fees_paid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiPayInvoiceRequest represents a request to pay multiple invoices
|
||||||
|
type MultiPayInvoiceRequest struct {
|
||||||
|
Invoices []PayInvoiceRequestWithID `json:"invoices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayInvoiceRequestWithID combines PayInvoiceRequest with WithOptionalId
|
||||||
|
type PayInvoiceRequestWithID struct {
|
||||||
|
PayInvoiceRequest
|
||||||
|
WithOptionalId
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiPayKeysendRequest represents a request to pay multiple keysends
|
||||||
|
type MultiPayKeysendRequest struct {
|
||||||
|
Keysends []PayKeysendRequestWithID `json:"keysends"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayKeysendRequestWithID combines PayKeysendRequest with WithOptionalId
|
||||||
|
type PayKeysendRequestWithID struct {
|
||||||
|
PayKeysendRequest
|
||||||
|
WithOptionalId
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiPayInvoiceResponse represents a response to a multi_pay_invoice request
|
||||||
|
type MultiPayInvoiceResponse struct {
|
||||||
|
Invoices []MultiPayInvoiceResponseItem `json:"invoices"`
|
||||||
|
Errors []interface{} `json:"errors"` // TODO: add error handling
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiPayInvoiceResponseItem represents an item in a multi_pay_invoice response
|
||||||
|
type MultiPayInvoiceResponseItem struct {
|
||||||
|
Invoice PayInvoiceRequest `json:"invoice"`
|
||||||
|
PayResponse
|
||||||
|
WithDTag
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiPayKeysendResponse represents a response to a multi_pay_keysend request
|
||||||
|
type MultiPayKeysendResponse struct {
|
||||||
|
Keysends []MultiPayKeysendResponseItem `json:"keysends"`
|
||||||
|
Errors []interface{} `json:"errors"` // TODO: add error handling
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiPayKeysendResponseItem represents an item in a multi_pay_keysend response
|
||||||
|
type MultiPayKeysendResponseItem struct {
|
||||||
|
Keysend PayKeysendRequest `json:"keysend"`
|
||||||
|
PayResponse
|
||||||
|
WithDTag
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTransactionsRequest represents a request to list transactions
|
||||||
|
type ListTransactionsRequest struct {
|
||||||
|
From *int64 `json:"from,omitempty"`
|
||||||
|
Until *int64 `json:"until,omitempty"`
|
||||||
|
Limit *int64 `json:"limit,omitempty"`
|
||||||
|
Offset *int64 `json:"offset,omitempty"`
|
||||||
|
Unpaid *bool `json:"unpaid,omitempty"`
|
||||||
|
UnpaidOutgoing *bool `json:"unpaid_outgoing,omitempty"` // NOTE: non-NIP-47 spec compliant
|
||||||
|
UnpaidIncoming *bool `json:"unpaid_incoming,omitempty"` // NOTE: non-NIP-47 spec compliant
|
||||||
|
Type *string `json:"type,omitempty"` // "incoming" or "outgoing"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTransactionsResponse represents a response to a list_transactions request
|
||||||
|
type ListTransactionsResponse struct {
|
||||||
|
Transactions []Transaction `json:"transactions"`
|
||||||
|
TotalCount int64 `json:"total_count"` // NOTE: non-NIP-47 spec compliant
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransactionType represents the type of a transaction
|
||||||
|
type TransactionType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Incoming TransactionType = "incoming"
|
||||||
|
Outgoing TransactionType = "outgoing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransactionState represents the state of a transaction
|
||||||
|
type TransactionState string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Settled TransactionState = "settled"
|
||||||
|
Pending TransactionState = "pending"
|
||||||
|
Failed TransactionState = "failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transaction represents a transaction
|
||||||
|
type Transaction struct {
|
||||||
|
Type TransactionType `json:"type"`
|
||||||
|
State TransactionState `json:"state"` // NOTE: non-NIP-47 spec compliant
|
||||||
|
Invoice string `json:"invoice"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
DescriptionHash string `json:"description_hash"`
|
||||||
|
Preimage string `json:"preimage"`
|
||||||
|
PaymentHash string `json:"payment_hash"`
|
||||||
|
Amount int64 `json:"amount"`
|
||||||
|
FeesPaid int64 `json:"fees_paid"`
|
||||||
|
SettledAt int64 `json:"settled_at"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
SettleDeadline *int64 `json:"settle_deadline,omitempty"` // NOTE: non-NIP-47 spec compliant
|
||||||
|
Metadata *TransactionMetadata `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransactionMetadata represents metadata for a transaction
|
||||||
|
type TransactionMetadata struct {
|
||||||
|
Comment string `json:"comment,omitempty"` // LUD-12
|
||||||
|
PayerData *PayerData `json:"payer_data,omitempty"` // LUD-18
|
||||||
|
RecipientData *RecipientData `json:"recipient_data,omitempty"` // LUD-18
|
||||||
|
Nostr *NostrData `json:"nostr,omitempty"` // NIP-57
|
||||||
|
ExtraData map[string]interface{} `json:"-"` // For additional fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayerData represents payer data for a transaction
|
||||||
|
type PayerData struct {
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Pubkey string `json:"pubkey,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecipientData represents recipient data for a transaction
|
||||||
|
type RecipientData struct {
|
||||||
|
Identifier string `json:"identifier,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NostrData represents Nostr data for a transaction
|
||||||
|
type NostrData struct {
|
||||||
|
Pubkey string `json:"pubkey"`
|
||||||
|
Tags [][]string `json:"tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotificationType represents a notification type
|
||||||
|
type NotificationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PaymentReceived NotificationType = "payment_received"
|
||||||
|
PaymentSent NotificationType = "payment_sent"
|
||||||
|
HoldInvoiceAccepted NotificationType = "hold_invoice_accepted"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Notification represents a notification
|
||||||
|
type Notification struct {
|
||||||
|
NotificationType NotificationType `json:"notification_type"`
|
||||||
|
Notification Transaction `json:"notification"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayInvoiceRequest represents a request to pay an invoice
|
||||||
|
type PayInvoiceRequest struct {
|
||||||
|
Invoice string `json:"invoice"`
|
||||||
|
Metadata *TransactionMetadata `json:"metadata,omitempty"`
|
||||||
|
Amount *int64 `json:"amount,omitempty"` // msats
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayKeysendRequest represents a request to pay a keysend
|
||||||
|
type PayKeysendRequest struct {
|
||||||
|
Amount int64 `json:"amount"` // msats
|
||||||
|
Pubkey string `json:"pubkey"`
|
||||||
|
Preimage string `json:"preimage,omitempty"`
|
||||||
|
TlvRecords []TlvRecord `json:"tlv_records,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TlvRecord represents a TLV record
|
||||||
|
type TlvRecord struct {
|
||||||
|
Type int64 `json:"type"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeInvoiceRequest represents a request to make an invoice
|
||||||
|
type MakeInvoiceRequest struct {
|
||||||
|
Amount int64 `json:"amount"` // msats
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
DescriptionHash string `json:"description_hash,omitempty"`
|
||||||
|
Expiry *int64 `json:"expiry,omitempty"` // in seconds
|
||||||
|
Metadata *TransactionMetadata `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeHoldInvoiceRequest represents a request to make a hold invoice
|
||||||
|
type MakeHoldInvoiceRequest struct {
|
||||||
|
MakeInvoiceRequest
|
||||||
|
PaymentHash string `json:"payment_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SettleHoldInvoiceRequest represents a request to settle a hold invoice
|
||||||
|
type SettleHoldInvoiceRequest struct {
|
||||||
|
Preimage string `json:"preimage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SettleHoldInvoiceResponse represents a response to a settle_hold_invoice request
|
||||||
|
type SettleHoldInvoiceResponse struct{}
|
||||||
|
|
||||||
|
// CancelHoldInvoiceRequest represents a request to cancel a hold invoice
|
||||||
|
type CancelHoldInvoiceRequest struct {
|
||||||
|
PaymentHash string `json:"payment_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelHoldInvoiceResponse represents a response to a cancel_hold_invoice request
|
||||||
|
type CancelHoldInvoiceResponse struct{}
|
||||||
|
|
||||||
|
// LookupInvoiceRequest represents a request to lookup an invoice
|
||||||
|
type LookupInvoiceRequest struct {
|
||||||
|
PaymentHash string `json:"payment_hash,omitempty"`
|
||||||
|
Invoice string `json:"invoice,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignMessageRequest represents a request to sign a message
|
||||||
|
type SignMessageRequest struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateConnectionRequest represents a request to create a connection
|
||||||
|
type CreateConnectionRequest struct {
|
||||||
|
Pubkey string `json:"pubkey"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
RequestMethods []Method `json:"request_methods"`
|
||||||
|
NotificationTypes []NotificationType `json:"notification_types,omitempty"`
|
||||||
|
MaxAmount *int64 `json:"max_amount,omitempty"`
|
||||||
|
BudgetRenewal *BudgetRenewalPeriod `json:"budget_renewal,omitempty"`
|
||||||
|
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||||
|
Isolated *bool `json:"isolated,omitempty"`
|
||||||
|
Metadata any `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateConnectionResponse represents a response to a create_connection request
|
||||||
|
type CreateConnectionResponse struct {
|
||||||
|
WalletPubkey string `json:"wallet_pubkey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignMessageResponse represents a response to a sign_message request
|
||||||
|
type SignMessageResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Signature string `json:"signature"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeoutValues represents timeout values for NIP-47 requests
|
||||||
|
type TimeoutValues struct {
|
||||||
|
ReplyTimeout *int64 `json:"replyTimeout,omitempty"`
|
||||||
|
PublishTimeout *int64 `json:"publishTimeout,omitempty"`
|
||||||
|
}
|
||||||
@@ -16,8 +16,10 @@ import (
|
|||||||
"orly.dev/pkg/encoders/tag"
|
"orly.dev/pkg/encoders/tag"
|
||||||
"orly.dev/pkg/utils/chk"
|
"orly.dev/pkg/utils/chk"
|
||||||
"orly.dev/pkg/utils/context"
|
"orly.dev/pkg/utils/context"
|
||||||
|
"orly.dev/pkg/utils/iptracker"
|
||||||
"orly.dev/pkg/utils/log"
|
"orly.dev/pkg/utils/log"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var EventBody = &huma.RequestBody{
|
var EventBody = &huma.RequestBody{
|
||||||
@@ -104,10 +106,27 @@ func (x *Operations) RegisterEvent(api huma.API) {
|
|||||||
|
|
||||||
var pubkey []byte
|
var pubkey []byte
|
||||||
if x.I.AuthRequired() {
|
if x.I.AuthRequired() {
|
||||||
|
// Check if the IP is blocked due to too many failed auth attempts
|
||||||
|
if iptracker.Global.IsBlocked(remote) {
|
||||||
|
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
|
||||||
|
err = huma.Error403Forbidden(fmt.Sprintf("Too many failed authentication attempts. Blocked until %s", blockedUntil.Format(time.RFC3339)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
authed, pubkey, super = x.UserAuth(r, remote)
|
authed, pubkey, super = x.UserAuth(r, remote)
|
||||||
if !authed {
|
if !authed {
|
||||||
|
// Record the failed authentication attempt
|
||||||
|
blocked := iptracker.Global.RecordFailedAttempt(remote)
|
||||||
|
if blocked {
|
||||||
|
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
|
||||||
|
err = huma.Error403Forbidden(fmt.Sprintf("Too many failed authentication attempts. Blocked until %s", blockedUntil.Format(time.RFC3339)))
|
||||||
|
} else {
|
||||||
err = huma.Error401Unauthorized("Not Authorized")
|
err = huma.Error401Unauthorized("Not Authorized")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
|
} else {
|
||||||
|
// If authentication is successful, remove any blocks for this IP
|
||||||
|
iptracker.Global.Authenticate(remote)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// get the other pubkeys from the header that will be sent forward
|
// get the other pubkeys from the header that will be sent forward
|
||||||
|
|||||||
@@ -580,7 +580,7 @@ type EventsInput struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EventsOutput struct {
|
type EventsOutput struct {
|
||||||
Body []event.J
|
Body []*event.J
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterEvents is the implementation of the HTTP API Events method.
|
// RegisterEvents is the implementation of the HTTP API Events method.
|
||||||
@@ -667,11 +667,17 @@ Returns events as a JSON array of event objects.`
|
|||||||
}
|
}
|
||||||
tmp = append(tmp, ev)
|
tmp = append(tmp, ev)
|
||||||
}
|
}
|
||||||
|
// cap the number of events to 512 to stop excessively large
|
||||||
|
// response.
|
||||||
|
if len(events) > 512 {
|
||||||
|
break
|
||||||
|
}
|
||||||
events = tmp
|
events = tmp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
output = &EventsOutput{}
|
||||||
for _, ev := range events {
|
for _, ev := range events {
|
||||||
_ = ev
|
output.Body = append(output.Body, ev.ToEventJ())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"orly.dev/pkg/interfaces/server"
|
"orly.dev/pkg/interfaces/server"
|
||||||
"orly.dev/pkg/protocol/auth"
|
"orly.dev/pkg/protocol/auth"
|
||||||
"orly.dev/pkg/utils/chk"
|
"orly.dev/pkg/utils/chk"
|
||||||
|
"orly.dev/pkg/utils/iptracker"
|
||||||
"orly.dev/pkg/utils/log"
|
"orly.dev/pkg/utils/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -71,16 +72,9 @@ func (a *A) HandleAuth(b []byte, srv server.I) (msg []byte) {
|
|||||||
env.Event.Pubkey,
|
env.Event.Pubkey,
|
||||||
)
|
)
|
||||||
a.Listener.SetAuthedPubkey(env.Event.Pubkey)
|
a.Listener.SetAuthedPubkey(env.Event.Pubkey)
|
||||||
// ev := a.Listener.GetPendingEvent()
|
|
||||||
// if ev != nil {
|
// If authentication is successful, remove any blocks for this IP
|
||||||
// var accepted bool
|
iptracker.Global.Authenticate(a.Listener.RealRemote())
|
||||||
// if accepted, msg = a.I.AddEvent(
|
|
||||||
// context.Bg(), srv.Relay(), ev, a.Listener.Request,
|
|
||||||
// a.Listener.RealRemote(),
|
|
||||||
// ); accepted {
|
|
||||||
// log.W.F("saved event %0x", ev.Id)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ import (
|
|||||||
"orly.dev/pkg/interfaces/server"
|
"orly.dev/pkg/interfaces/server"
|
||||||
"orly.dev/pkg/utils/chk"
|
"orly.dev/pkg/utils/chk"
|
||||||
"orly.dev/pkg/utils/context"
|
"orly.dev/pkg/utils/context"
|
||||||
|
"orly.dev/pkg/utils/iptracker"
|
||||||
"orly.dev/pkg/utils/log"
|
"orly.dev/pkg/utils/log"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleEvent processes an incoming event by validating its signature, verifying
|
// HandleEvent processes an incoming event by validating its signature, verifying
|
||||||
@@ -71,7 +73,46 @@ func (a *A) HandleEvent(
|
|||||||
log.I.F("extra '%s'", rem)
|
log.I.F("extra '%s'", rem)
|
||||||
}
|
}
|
||||||
if a.I.AuthRequired() && !a.Listener.IsAuthed() {
|
if a.I.AuthRequired() && !a.Listener.IsAuthed() {
|
||||||
log.I.F("requesting auth from client from %s", a.Listener.RealRemote())
|
remoteIP := a.Listener.RealRemote()
|
||||||
|
log.I.F("requesting auth from client from %s", remoteIP)
|
||||||
|
|
||||||
|
// Check if the IP is blocked due to too many failed auth attempts
|
||||||
|
if iptracker.Global.IsBlocked(remoteIP) {
|
||||||
|
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
|
||||||
|
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
|
||||||
|
blockedUntil.Format(time.RFC3339))
|
||||||
|
|
||||||
|
// Send a notice to the client explaining why they're blocked
|
||||||
|
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the connection
|
||||||
|
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
|
||||||
|
a.Listener.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record the failed authentication attempt
|
||||||
|
blocked := iptracker.Global.RecordFailedAttempt(remoteIP)
|
||||||
|
if blocked {
|
||||||
|
// If this attempt caused the IP to be blocked, close the connection
|
||||||
|
blockedUntil := iptracker.Global.GetBlockedUntil(remoteIP)
|
||||||
|
blockMsg := fmt.Sprintf("Too many failed authentication attempts. Blocked until %s",
|
||||||
|
blockedUntil.Format(time.RFC3339))
|
||||||
|
|
||||||
|
// Send a notice to the client explaining why they're blocked
|
||||||
|
if err = noticeenvelope.NewFrom(blockMsg).Write(a.Listener); chk.E(err) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the connection
|
||||||
|
log.I.F("closing connection from %s due to too many failed auth attempts", remoteIP)
|
||||||
|
a.Listener.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue with normal auth flow for non-blocked IPs
|
||||||
a.Listener.RequestAuth()
|
a.Listener.RequestAuth()
|
||||||
if err = Ok.AuthRequired(a, env.E, "auth required"); chk.E(err) {
|
if err = Ok.AuthRequired(a, env.E, "auth required"); chk.E(err) {
|
||||||
return
|
return
|
||||||
@@ -163,6 +204,12 @@ func (a *A) HandleEvent(
|
|||||||
// check and process delete
|
// check and process delete
|
||||||
if env.E.Kind.K == kind.Deletion.K {
|
if env.E.Kind.K == kind.Deletion.K {
|
||||||
log.I.F("delete event\n%s", env.E.Serialize())
|
log.I.F("delete event\n%s", env.E.Serialize())
|
||||||
|
var ownerDelete bool
|
||||||
|
for _, pk := range a.OwnersPubkeys() {
|
||||||
|
if bytes.Equal(pk, env.Pubkey) {
|
||||||
|
ownerDelete = true
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, t := range env.Tags.ToSliceOfTags() {
|
for _, t := range env.Tags.ToSliceOfTags() {
|
||||||
var res []*event.E
|
var res []*event.E
|
||||||
if t.Len() >= 2 {
|
if t.Len() >= 2 {
|
||||||
@@ -196,15 +243,17 @@ func (a *A) HandleEvent(
|
|||||||
referencedEvent := referencedEvents[0]
|
referencedEvent := referencedEvents[0]
|
||||||
|
|
||||||
// Check if the author of the deletion event matches the
|
// Check if the author of the deletion event matches the
|
||||||
// author of the referenced event
|
// author of the referenced event. Owners can delete
|
||||||
if !bytes.Equal(referencedEvent.Pubkey, env.Pubkey) {
|
// anything.
|
||||||
|
if !bytes.Equal(
|
||||||
|
referencedEvent.Pubkey, env.Pubkey,
|
||||||
|
) && !ownerDelete {
|
||||||
if err = Ok.Blocked(
|
if err = Ok.Blocked(
|
||||||
a, env,
|
a, env,
|
||||||
"blocked: cannot delete events from other authors",
|
"blocked: can't delete events from other authors",
|
||||||
); chk.E(err) {
|
); chk.E(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create eventid.T from the event ID bytes
|
// Create eventid.T from the event ID bytes
|
||||||
@@ -278,7 +327,7 @@ func (a *A) HandleEvent(
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !bytes.Equal(pk, env.E.Pubkey) {
|
if !bytes.Equal(pk, env.E.Pubkey) && !ownerDelete {
|
||||||
if err = Ok.Blocked(
|
if err = Ok.Blocked(
|
||||||
a, env,
|
a, env,
|
||||||
"can't delete other users' events (delete by a tag)",
|
"can't delete other users' events (delete by a tag)",
|
||||||
@@ -327,7 +376,7 @@ func (a *A) HandleEvent(
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !bytes.Equal(target.Pubkey, env.Pubkey) {
|
if !bytes.Equal(target.Pubkey, env.Pubkey) && !ownerDelete {
|
||||||
if err = Ok.Error(
|
if err = Ok.Error(
|
||||||
a, env, "only author can delete event",
|
a, env, "only author can delete event",
|
||||||
); chk.E(err) {
|
); chk.E(err) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"orly.dev/pkg/protocol/ws"
|
"orly.dev/pkg/protocol/ws"
|
||||||
"orly.dev/pkg/utils/chk"
|
"orly.dev/pkg/utils/chk"
|
||||||
"orly.dev/pkg/utils/context"
|
"orly.dev/pkg/utils/context"
|
||||||
|
"orly.dev/pkg/utils/iptracker"
|
||||||
"orly.dev/pkg/utils/log"
|
"orly.dev/pkg/utils/log"
|
||||||
"orly.dev/pkg/utils/units"
|
"orly.dev/pkg/utils/units"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -56,6 +57,18 @@ type A struct {
|
|||||||
func (a *A) Serve(w http.ResponseWriter, r *http.Request, s server.I) {
|
func (a *A) Serve(w http.ResponseWriter, r *http.Request, s server.I) {
|
||||||
c := a.Config()
|
c := a.Config()
|
||||||
remote := helpers.GetRemoteFromReq(r)
|
remote := helpers.GetRemoteFromReq(r)
|
||||||
|
|
||||||
|
// Check if the IP is blocked due to too many failed auth attempts
|
||||||
|
if iptracker.Global.IsBlocked(remote) {
|
||||||
|
blockedUntil := iptracker.Global.GetBlockedUntil(remote)
|
||||||
|
log.I.F("rejecting websocket connection from banned IP %s (blocked until %s)",
|
||||||
|
remote, blockedUntil.Format(time.RFC3339))
|
||||||
|
|
||||||
|
// We can't send a notice to the client here because the websocket connection
|
||||||
|
// hasn't been established yet, so we just reject the connection
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var whitelisted bool
|
var whitelisted bool
|
||||||
if len(c.Whitelist) > 0 {
|
if len(c.Whitelist) > 0 {
|
||||||
for _, addr := range c.Whitelist {
|
for _, addr := range c.Whitelist {
|
||||||
|
|||||||
@@ -452,17 +452,17 @@ func (r *Client) publish(ctx context.T, ev *event.E) (err error) {
|
|||||||
func (r *Client) Subscribe(
|
func (r *Client) Subscribe(
|
||||||
c context.T, ff *filters.T,
|
c context.T, ff *filters.T,
|
||||||
opts ...SubscriptionOption,
|
opts ...SubscriptionOption,
|
||||||
) (*Subscription, error) {
|
) (sub *Subscription, err error) {
|
||||||
sub := r.PrepareSubscription(c, ff, opts...)
|
sub = r.PrepareSubscription(c, ff, opts...)
|
||||||
if r.Connection == nil {
|
if r.Connection == nil {
|
||||||
return nil, errorf.E("not connected to %s", r.URL)
|
return nil, errorf.E("not connected to %s", r.URL)
|
||||||
}
|
}
|
||||||
if err := sub.Fire(); chk.T(err) {
|
if err = sub.Fire(); chk.T(err) {
|
||||||
return nil, errorf.E(
|
return nil, errorf.E(
|
||||||
"couldn't subscribe to %v at %s: %w", ff, r.URL, err,
|
"couldn't subscribe to %v at %s: %w", ff, r.URL, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return sub, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrepareSubscription creates a subscription, but doesn't fire it.
|
// PrepareSubscription creates a subscription, but doesn't fire it.
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func TestPublish(t *testing.T) {
|
|||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
// connect a client and send the text note
|
// connect a client and send the text note
|
||||||
rl := mustRelayConnect(ws.URL)
|
rl := mustRelayConnect(ws.URL)
|
||||||
err = rl.Publish(context.Background(), textNote)
|
err = rl.Publish(context.Bg(), textNote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("publish should have succeeded")
|
t.Errorf("publish should have succeeded")
|
||||||
}
|
}
|
||||||
@@ -137,7 +137,7 @@ func TestPublishBlocked(t *testing.T) {
|
|||||||
|
|
||||||
// connect a client and send a text note
|
// connect a client and send a text note
|
||||||
rl := mustRelayConnect(ws.URL)
|
rl := mustRelayConnect(ws.URL)
|
||||||
if err = rl.Publish(context.Background(), textNote); !chk.E(err) {
|
if err = rl.Publish(context.Bg(), textNote); !chk.E(err) {
|
||||||
t.Errorf("should have failed to publish")
|
t.Errorf("should have failed to publish")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,7 +171,7 @@ func TestPublishWriteFailed(t *testing.T) {
|
|||||||
rl := mustRelayConnect(ws.URL)
|
rl := mustRelayConnect(ws.URL)
|
||||||
// Force brief period of time so that publish always fails on closed socket.
|
// Force brief period of time so that publish always fails on closed socket.
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
err = rl.Publish(context.Background(), textNote)
|
err = rl.Publish(context.Bg(), textNote)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("should have failed to publish")
|
t.Errorf("should have failed to publish")
|
||||||
}
|
}
|
||||||
@@ -192,7 +192,7 @@ func TestConnectContext(t *testing.T) {
|
|||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
|
||||||
// relay client
|
// relay client
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
r, err := RelayConnect(ctx, ws.URL)
|
r, err := RelayConnect(ctx, ws.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -213,7 +213,7 @@ func TestConnectContextCanceled(t *testing.T) {
|
|||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
|
||||||
// relay client
|
// relay client
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.Cancel(context.Bg())
|
||||||
cancel() // make ctx expired
|
cancel() // make ctx expired
|
||||||
_, err := RelayConnect(ctx, ws.URL)
|
_, err := RelayConnect(ctx, ws.URL)
|
||||||
if !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.Canceled) {
|
||||||
@@ -230,9 +230,9 @@ func TestConnectWithOrigin(t *testing.T) {
|
|||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
|
||||||
// relay client
|
// relay client
|
||||||
r := NewRelay(context.Background(), string(normalize.URL(ws.URL)))
|
r := NewRelay(context.Bg(), string(normalize.URL(ws.URL)))
|
||||||
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
|
r.RequestHeader = http.Header{"origin": {"https://example.com"}}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Bg(), 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := r.Connect(ctx)
|
err := r.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -263,7 +263,7 @@ var anyOriginHandshake = func(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func mustRelayConnect(url string) (client *Client) {
|
func mustRelayConnect(url string) (client *Client) {
|
||||||
rl, err := RelayConnect(context.Background(), url)
|
rl, err := RelayConnect(context.Bg(), url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
173
pkg/utils/iptracker/iptracker.go
Normal file
173
pkg/utils/iptracker/iptracker.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
// Package iptracker provides functionality to track and block IP addresses
|
||||||
|
// based on failed authentication attempts.
|
||||||
|
package iptracker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// BlockDuration is the duration for which an IP will be blocked after
|
||||||
|
// exceeding the maximum number of failed attempts.
|
||||||
|
BlockDuration = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPTracker tracks failed authentication attempts by IP address and provides
|
||||||
|
// functionality to block IPs that exceed a threshold.
|
||||||
|
type IPTracker struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
failedAttempts map[string]int
|
||||||
|
blockedUntil map[string]time.Time
|
||||||
|
offenseCount map[string]int // Tracks the number of times an IP has been blocked
|
||||||
|
blockDurations map[string]time.Duration // Stores the current block duration for each IP
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPTracker creates a new IPTracker instance.
|
||||||
|
func NewIPTracker() *IPTracker {
|
||||||
|
return &IPTracker{
|
||||||
|
failedAttempts: make(map[string]int),
|
||||||
|
blockedUntil: make(map[string]time.Time),
|
||||||
|
offenseCount: make(map[string]int),
|
||||||
|
blockDurations: make(map[string]time.Duration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordFailedAttempt records a failed authentication attempt for the given IP address.
|
||||||
|
// If the number of failed attempts exceeds the threshold, the IP is blocked.
|
||||||
|
// For repeat offenders, the block duration doubles with each offense.
|
||||||
|
// Returns true if the IP is now blocked, false otherwise.
|
||||||
|
func (t *IPTracker) RecordFailedAttempt(ip string) bool {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if the IP is already blocked
|
||||||
|
if t.isBlockedNoLock(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the failed attempts counter
|
||||||
|
t.failedAttempts[ip]++
|
||||||
|
|
||||||
|
// If the number of failed attempts exceeds the threshold, block the IP
|
||||||
|
if t.failedAttempts[ip] >= 3 { // Threshold of 3 failed attempts
|
||||||
|
// Increment the offense count
|
||||||
|
t.offenseCount[ip]++
|
||||||
|
|
||||||
|
// Calculate block duration based on offense count
|
||||||
|
// First offense: 10 minutes, then doubles for each subsequent offense
|
||||||
|
duration := BlockDuration
|
||||||
|
if t.offenseCount[ip] > 1 {
|
||||||
|
// For repeat offenses, double the duration for each previous offense
|
||||||
|
// 10 min, then 20, then 40, then 80, etc.
|
||||||
|
for i := 1; i < t.offenseCount[ip]; i++ {
|
||||||
|
duration *= 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the calculated duration
|
||||||
|
t.blockDurations[ip] = duration
|
||||||
|
|
||||||
|
// Set the block time
|
||||||
|
t.blockedUntil[ip] = time.Now().Add(duration)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlocked checks if the given IP address is currently blocked.
|
||||||
|
func (t *IPTracker) IsBlocked(ip string) bool {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return t.isBlockedNoLock(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBlockedNoLock is a helper method that checks if an IP is blocked without
|
||||||
|
// acquiring the lock. It should only be called when the lock is already held.
|
||||||
|
// Blocks persist until authentication, even after the block duration has passed.
|
||||||
|
func (t *IPTracker) isBlockedNoLock(ip string) bool {
|
||||||
|
_, exists := t.blockedUntil[ip]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP is blocked until authenticated, regardless of time elapsed
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasBlockDurationPassed checks if the block duration for an IP has passed,
|
||||||
|
// even though the IP remains blocked until authentication.
|
||||||
|
// This is useful for display purposes.
|
||||||
|
func (t *IPTracker) HasBlockDurationPassed(ip string) bool {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
|
blockedUntil, exists := t.blockedUntil[ip]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now().After(blockedUntil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockedUntil returns the time until which the given IP address is blocked.
|
||||||
|
// If the IP is not blocked, it returns the zero time.
|
||||||
|
// Note: With the new blocking behavior, an IP remains blocked even after this time
|
||||||
|
// until it successfully authenticates.
|
||||||
|
func (t *IPTracker) GetBlockedUntil(ip string) time.Time {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
|
blockedUntil, exists := t.blockedUntil[ip]
|
||||||
|
if !exists {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return blockedUntil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockDuration returns the current block duration for the given IP address.
|
||||||
|
// This is useful for displaying how long the IP would have been blocked before
|
||||||
|
// requiring authentication.
|
||||||
|
func (t *IPTracker) GetBlockDuration(ip string) time.Duration {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
|
duration, exists := t.blockDurations[ip]
|
||||||
|
if !exists {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate records a successful authentication for an IP address.
|
||||||
|
// If the IP was blocked, it removes the block but preserves the offense count.
|
||||||
|
// This allows the IP to access the system again, but if it offends in the future,
|
||||||
|
// the penalty will still be doubled based on past offenses.
|
||||||
|
func (t *IPTracker) Authenticate(ip string) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove the block but keep the offense count
|
||||||
|
delete(t.failedAttempts, ip)
|
||||||
|
delete(t.blockedUntil, ip)
|
||||||
|
// Note: We intentionally don't delete from offenseCount or blockDurations
|
||||||
|
// so that repeat offenses can be tracked
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset completely resets all tracking for the given IP address.
|
||||||
|
// This is different from Authenticate as it also resets the offense count.
|
||||||
|
func (t *IPTracker) Reset(ip string) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
delete(t.failedAttempts, ip)
|
||||||
|
delete(t.blockedUntil, ip)
|
||||||
|
delete(t.offenseCount, ip)
|
||||||
|
delete(t.blockDurations, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Global instance of IPTracker for use across the application
|
||||||
|
var Global = NewIPTracker()
|
||||||
@@ -1 +1 @@
|
|||||||
v0.4.0
|
v0.4.8
|
||||||
Reference in New Issue
Block a user