Add first draft graph query implementation
Some checks failed
Go / build-and-release (push) Has been cancelled

This commit is contained in:
2025-12-04 09:28:13 +00:00
parent 8dbc19ee9e
commit 6b98c23606
40 changed files with 9078 additions and 46 deletions

View File

@@ -0,0 +1,460 @@
//go:build !(js && wasm)
package database
import (
"bytes"
"context"
"testing"
"github.com/dgraph-io/badger/v4"
"next.orly.dev/pkg/database/indexes"
"next.orly.dev/pkg/database/indexes/types"
"git.mleku.dev/mleku/nostr/encoders/event"
"git.mleku.dev/mleku/nostr/encoders/hex"
"git.mleku.dev/mleku/nostr/encoders/tag"
)
// TestETagGraphEdgeCreation verifies that saving a reply event carrying an
// e-tag produces both halves of the event-event graph index: a forward edge
// (reply -> parent) under EventEventGraphPrefix and a reverse edge
// (parent <- reply) under GraphEventEventPrefix. Each edge is expected to
// record the referencing event's kind (1) and the matching direction letter.
func TestETagGraphEdgeCreation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a parent event (the post being replied to). IDs/sigs are
	// synthetic fixed-byte values; signature validity is not under test here.
	parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	parentID := make([]byte, 32)
	parentID[0] = 0x10
	parentSig := make([]byte, 64)
	parentSig[0] = 0x10
	parentEvent := &event.E{
		ID:        parentID,
		Pubkey:    parentPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("This is the parent post"),
		Sig:       parentSig,
		Tags:      &tag.S{},
	}
	_, err = db.SaveEvent(ctx, parentEvent)
	if err != nil {
		t.Fatalf("Failed to save parent event: %v", err)
	}
	// Create a reply event with e-tag pointing to parent
	replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	replyID := make([]byte, 32)
	replyID[0] = 0x20
	replySig := make([]byte, 64)
	replySig[0] = 0x20
	replyEvent := &event.E{
		ID:        replyID,
		Pubkey:    replyPubkey,
		CreatedAt: 1234567891,
		Kind:      1,
		Content:   []byte("This is a reply"),
		Sig:       replySig,
		Tags: tag.NewS(
			tag.NewFromAny("e", hex.Enc(parentID)),
		),
	}
	_, err = db.SaveEvent(ctx, replyEvent)
	if err != nil {
		t.Fatalf("Failed to save reply event: %v", err)
	}
	// Get serials for both events; edges are keyed by serial, not by event ID.
	parentSerial, err := db.GetSerialById(parentID)
	if err != nil {
		t.Fatalf("Failed to get parent serial: %v", err)
	}
	replySerial, err := db.GetSerialById(replyID)
	if err != nil {
		t.Fatalf("Failed to get reply serial: %v", err)
	}
	t.Logf("Parent serial: %d, Reply serial: %d", parentSerial.Get(), replySerial.Get())
	// Verify forward edge exists (reply -> parent) by scanning all keys under
	// the forward-edge prefix and decoding each one.
	forwardFound := false
	prefix := []byte(indexes.EventEventGraphPrefix)
	err = db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			item := it.Item()
			key := item.KeyCopy(nil)
			// Decode the key into (source serial, target serial, kind, direction).
			srcSer, tgtSer, kind, direction := indexes.EventEventGraphVars()
			keyReader := bytes.NewReader(key)
			if err := indexes.EventEventGraphDec(srcSer, tgtSer, kind, direction).UnmarshalRead(keyReader); err != nil {
				t.Logf("Failed to decode key: %v", err)
				continue
			}
			// Check if this is our edge
			if srcSer.Get() == replySerial.Get() && tgtSer.Get() == parentSerial.Get() {
				forwardFound = true
				if direction.Letter() != types.EdgeDirectionETagOut {
					t.Errorf("Expected direction %d, got %d", types.EdgeDirectionETagOut, direction.Letter())
				}
				// Kind is taken from the referencing (reply) event.
				if kind.Get() != 1 {
					t.Errorf("Expected kind 1, got %d", kind.Get())
				}
			}
		}
		return nil
	})
	if err != nil {
		t.Fatalf("View failed: %v", err)
	}
	if !forwardFound {
		t.Error("Forward edge (reply -> parent) should exist")
	}
	// Verify reverse edge exists (parent <- reply) under the reverse prefix.
	// Note the reverse key orders fields as (target, kind, direction, source).
	reverseFound := false
	prefix = []byte(indexes.GraphEventEventPrefix)
	err = db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			item := it.Item()
			key := item.KeyCopy(nil)
			// Decode the key
			tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars()
			keyReader := bytes.NewReader(key)
			if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil {
				t.Logf("Failed to decode key: %v", err)
				continue
			}
			t.Logf("Found gee edge: tgt=%d kind=%d dir=%d src=%d",
				tgtSer.Get(), kind.Get(), direction.Letter(), srcSer.Get())
			// Check if this is our edge
			if tgtSer.Get() == parentSerial.Get() && srcSer.Get() == replySerial.Get() {
				reverseFound = true
				if direction.Letter() != types.EdgeDirectionETagIn {
					t.Errorf("Expected direction %d, got %d", types.EdgeDirectionETagIn, direction.Letter())
				}
				if kind.Get() != 1 {
					t.Errorf("Expected kind 1, got %d", kind.Get())
				}
			}
		}
		return nil
	})
	if err != nil {
		t.Fatalf("View failed: %v", err)
	}
	if !reverseFound {
		t.Error("Reverse edge (parent <- reply) should exist")
	}
}
// TestETagGraphMultipleReplies verifies that when several distinct reply
// events all e-tag the same parent, the reverse index (GraphEventEventPrefix)
// accumulates exactly one inbound edge per reply.
func TestETagGraphMultipleReplies(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a parent event
	parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	parentID := make([]byte, 32)
	parentID[0] = 0x10
	parentSig := make([]byte, 64)
	parentSig[0] = 0x10
	parentEvent := &event.E{
		ID:        parentID,
		Pubkey:    parentPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("Parent post"),
		Sig:       parentSig,
		Tags:      &tag.S{},
	}
	_, err = db.SaveEvent(ctx, parentEvent)
	if err != nil {
		t.Fatalf("Failed to save parent: %v", err)
	}
	// Create multiple replies; each gets a unique ID/pubkey/sig derived from
	// the loop index so they are all distinct events.
	numReplies := 5
	for i := 0; i < numReplies; i++ {
		replyPubkey := make([]byte, 32)
		replyPubkey[0] = byte(i + 0x20)
		replyID := make([]byte, 32)
		replyID[0] = byte(i + 0x30)
		replySig := make([]byte, 64)
		replySig[0] = byte(i + 0x30)
		replyEvent := &event.E{
			ID:        replyID,
			Pubkey:    replyPubkey,
			CreatedAt: int64(1234567891 + i),
			Kind:      1,
			Content:   []byte("Reply"),
			Sig:       replySig,
			Tags: tag.NewS(
				tag.NewFromAny("e", hex.Enc(parentID)),
			),
		}
		_, err := db.SaveEvent(ctx, replyEvent)
		if err != nil {
			t.Fatalf("Failed to save reply %d: %v", i, err)
		}
	}
	// Count inbound edges to parent by scanning the reverse-edge prefix.
	parentSerial, err := db.GetSerialById(parentID)
	if err != nil {
		t.Fatalf("Failed to get parent serial: %v", err)
	}
	inboundCount := 0
	prefix := []byte(indexes.GraphEventEventPrefix)
	err = db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			item := it.Item()
			key := item.KeyCopy(nil)
			tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars()
			keyReader := bytes.NewReader(key)
			if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil {
				continue
			}
			// Count every edge targeting the parent, regardless of source.
			if tgtSer.Get() == parentSerial.Get() {
				inboundCount++
			}
		}
		return nil
	})
	if err != nil {
		t.Fatalf("View failed: %v", err)
	}
	if inboundCount != numReplies {
		t.Errorf("Expected %d inbound edges, got %d", numReplies, inboundCount)
	}
}
// TestETagGraphDifferentKinds verifies that inbound graph edges preserve the
// kind of the referencing event: a kind-7 reaction and a kind-6 repost that
// both e-tag the same kind-1 note must appear as one edge each under their
// respective kinds in the reverse index.
func TestETagGraphDifferentKinds(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a parent event (kind 1 - note)
	parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	parentID := make([]byte, 32)
	parentID[0] = 0x10
	parentSig := make([]byte, 64)
	parentSig[0] = 0x10
	parentEvent := &event.E{
		ID:        parentID,
		Pubkey:    parentPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("A note"),
		Sig:       parentSig,
		Tags:      &tag.S{},
	}
	_, err = db.SaveEvent(ctx, parentEvent)
	if err != nil {
		t.Fatalf("Failed to save parent: %v", err)
	}
	// Create a reaction (kind 7)
	reactionPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	reactionID := make([]byte, 32)
	reactionID[0] = 0x20
	reactionSig := make([]byte, 64)
	reactionSig[0] = 0x20
	reactionEvent := &event.E{
		ID:        reactionID,
		Pubkey:    reactionPubkey,
		CreatedAt: 1234567891,
		Kind:      7,
		Content:   []byte("+"),
		Sig:       reactionSig,
		Tags: tag.NewS(
			tag.NewFromAny("e", hex.Enc(parentID)),
		),
	}
	_, err = db.SaveEvent(ctx, reactionEvent)
	if err != nil {
		t.Fatalf("Failed to save reaction: %v", err)
	}
	// Create a repost (kind 6)
	repostPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	repostID := make([]byte, 32)
	repostID[0] = 0x30
	repostSig := make([]byte, 64)
	repostSig[0] = 0x30
	repostEvent := &event.E{
		ID:        repostID,
		Pubkey:    repostPubkey,
		CreatedAt: 1234567892,
		Kind:      6,
		Content:   []byte(""),
		Sig:       repostSig,
		Tags: tag.NewS(
			tag.NewFromAny("e", hex.Enc(parentID)),
		),
	}
	_, err = db.SaveEvent(ctx, repostEvent)
	if err != nil {
		t.Fatalf("Failed to save repost: %v", err)
	}
	// Query inbound edges by kind: scan the reverse prefix and tally the kind
	// field of every edge that targets the parent.
	parentSerial, err := db.GetSerialById(parentID)
	if err != nil {
		t.Fatalf("Failed to get parent serial: %v", err)
	}
	kindCounts := make(map[uint16]int)
	prefix := []byte(indexes.GraphEventEventPrefix)
	err = db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			item := it.Item()
			key := item.KeyCopy(nil)
			tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars()
			keyReader := bytes.NewReader(key)
			if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil {
				continue
			}
			if tgtSer.Get() == parentSerial.Get() {
				kindCounts[kind.Get()]++
			}
		}
		return nil
	})
	if err != nil {
		t.Fatalf("View failed: %v", err)
	}
	// Verify we have exactly one edge for each referencing kind.
	if kindCounts[7] != 1 {
		t.Errorf("Expected 1 kind-7 (reaction) edge, got %d", kindCounts[7])
	}
	if kindCounts[6] != 1 {
		t.Errorf("Expected 1 kind-6 (repost) edge, got %d", kindCounts[6])
	}
}
// TestETagGraphUnknownTarget verifies that an event whose e-tag references an
// event ID absent from the database is still saved successfully, but that no
// forward graph edge is created (there is no target serial to link to).
func TestETagGraphUnknownTarget(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create an event with e-tag pointing to a non-existent event ID.
	unknownID := make([]byte, 32)
	unknownID[0] = 0xFF
	unknownID[31] = 0xFF
	replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	replyID := make([]byte, 32)
	replyID[0] = 0x10
	replySig := make([]byte, 64)
	replySig[0] = 0x10
	replyEvent := &event.E{
		ID:        replyID,
		Pubkey:    replyPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("Reply to unknown"),
		Sig:       replySig,
		Tags: tag.NewS(
			tag.NewFromAny("e", hex.Enc(unknownID)),
		),
	}
	_, err = db.SaveEvent(ctx, replyEvent)
	if err != nil {
		t.Fatalf("Failed to save reply: %v", err)
	}
	// Verify the event itself was saved despite the dangling reference.
	replySerial, err := db.GetSerialById(replyID)
	if err != nil {
		t.Fatalf("Failed to get reply serial: %v", err)
	}
	if replySerial == nil {
		t.Fatal("Reply serial should exist")
	}
	// Verify no forward edge was created (since target doesn't exist).
	edgeCount := 0
	prefix := []byte(indexes.EventEventGraphPrefix)
	err = db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			item := it.Item()
			key := item.KeyCopy(nil)
			// Decode with the canonical var set so this test matches the
			// sibling tests (previously it mixed EventEventGraphVars with
			// ad-hoc new(types.Uint40)/new(types.Uint16)/new(types.Letter)).
			srcSer, tgtSer, kind, direction := indexes.EventEventGraphVars()
			keyReader := bytes.NewReader(key)
			if err := indexes.EventEventGraphDec(srcSer, tgtSer, kind, direction).UnmarshalRead(keyReader); err != nil {
				continue
			}
			// Only the source serial matters here; any edge originating from
			// the reply would indicate a bug.
			if srcSer.Get() == replySerial.Get() {
				edgeCount++
			}
		}
		return nil
	})
	if err != nil {
		t.Fatalf("View failed: %v", err)
	}
	if edgeCount != 0 {
		t.Errorf("Expected no edges for unknown target, got %d", edgeCount)
	}
}

View File

@@ -0,0 +1,42 @@
//go:build !(js && wasm)
package database
import (
"next.orly.dev/pkg/protocol/graph"
)
// GraphAdapter wraps a database instance and implements graph.GraphDatabase interface.
// This allows the graph executor to call database traversal methods without
// the database package importing the graph package (avoiding an import cycle).
// Every method is a thin delegation to the corresponding *D method.
type GraphAdapter struct {
	db *D // the wrapped database; must be non-nil before any method is called
}

// NewGraphAdapter creates a new GraphAdapter wrapping the given database.
func NewGraphAdapter(db *D) *GraphAdapter {
	return &GraphAdapter{db: db}
}

// TraverseFollows implements graph.GraphDatabase by delegating to D.TraverseFollows.
func (a *GraphAdapter) TraverseFollows(seedPubkey []byte, maxDepth int) (graph.GraphResultI, error) {
	return a.db.TraverseFollows(seedPubkey, maxDepth)
}

// TraverseFollowers implements graph.GraphDatabase by delegating to D.TraverseFollowers.
func (a *GraphAdapter) TraverseFollowers(seedPubkey []byte, maxDepth int) (graph.GraphResultI, error) {
	return a.db.TraverseFollowers(seedPubkey, maxDepth)
}

// FindMentions implements graph.GraphDatabase by delegating to D.FindMentions.
func (a *GraphAdapter) FindMentions(pubkey []byte, kinds []uint16) (graph.GraphResultI, error) {
	return a.db.FindMentions(pubkey, kinds)
}

// TraverseThread implements graph.GraphDatabase by delegating to D.TraverseThread.
func (a *GraphAdapter) TraverseThread(seedEventID []byte, maxDepth int, direction string) (graph.GraphResultI, error) {
	return a.db.TraverseThread(seedEventID, maxDepth, direction)
}

// Compile-time check that GraphAdapter implements graph.GraphDatabase.
var _ graph.GraphDatabase = (*GraphAdapter)(nil)

View File

@@ -0,0 +1,199 @@
//go:build !(js && wasm)
package database
import (
"lol.mleku.dev/log"
"next.orly.dev/pkg/database/indexes/types"
"git.mleku.dev/mleku/nostr/encoders/hex"
)
// TraverseFollows performs BFS traversal of the follow graph starting from a seed pubkey.
// Returns pubkeys grouped by first-discovered depth (no duplicates across depths).
//
// The traversal works by:
// 1. Starting with the seed pubkey at depth 0 (not included in results)
// 2. For each pubkey at the current depth, find their kind-3 contact list
// 3. Extract p-tags from the contact list to get follows
// 4. Add new (unseen) follows to the next depth
// 5. Continue until maxDepth is reached or no new pubkeys are found
//
// Traversal terminates early as soon as a depth discovers no new pubkeys:
// since the next frontier grows exactly when a new pubkey is found, an empty
// frontier can never yield results at deeper levels.
func (d *D) TraverseFollows(seedPubkey []byte, maxDepth int) (*GraphResult, error) {
	result := NewGraphResult()
	// A pubkey must be exactly 32 bytes; anything else cannot be in the index.
	if len(seedPubkey) != 32 {
		return result, ErrPubkeyNotFound
	}
	// Get seed pubkey serial
	seedSerial, err := d.GetPubkeySerial(seedPubkey)
	if err != nil {
		log.D.F("TraverseFollows: seed pubkey not in database: %s", hex.Enc(seedPubkey))
		return result, nil // Not an error - just no results
	}
	// Track visited pubkeys by serial to avoid cycles
	visited := make(map[uint64]bool)
	visited[seedSerial.Get()] = true // Mark seed as visited but don't add to results
	// Current frontier (pubkeys to process at this depth)
	currentFrontier := []*types.Uint40{seedSerial}
	for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
		var nextFrontier []*types.Uint40
		for _, pubkeySerial := range currentFrontier {
			// Get follows for this pubkey
			follows, err := d.GetFollowsFromPubkeySerial(pubkeySerial)
			if err != nil {
				log.D.F("TraverseFollows: error getting follows for serial %d: %v", pubkeySerial.Get(), err)
				continue
			}
			for _, followSerial := range follows {
				// Skip if already visited (first-discovered depth wins)
				if visited[followSerial.Get()] {
					continue
				}
				visited[followSerial.Get()] = true
				// Get pubkey hex for result
				pubkeyHex, err := d.GetPubkeyHexFromSerial(followSerial)
				if err != nil {
					log.D.F("TraverseFollows: error getting pubkey hex for serial %d: %v", followSerial.Get(), err)
					continue
				}
				// Add to results at this depth
				result.AddPubkeyAtDepth(pubkeyHex, currentDepth)
				// Add to next frontier for further traversal
				nextFrontier = append(nextFrontier, followSerial)
			}
		}
		// len(nextFrontier) equals the number of new pubkeys found this depth.
		log.T.F("TraverseFollows: depth %d found %d new pubkeys", currentDepth, len(nextFrontier))
		// Early termination: an empty frontier cannot produce deeper results.
		// (The previous "two consecutive empty depths" check was dead logic:
		// one empty depth already guarantees every later depth is empty.)
		if len(nextFrontier) == 0 {
			log.T.F("TraverseFollows: early termination at depth %d (frontier exhausted)", currentDepth)
			break
		}
		// Move to next depth
		currentFrontier = nextFrontier
	}
	log.D.F("TraverseFollows: completed with %d total pubkeys across %d depths",
		result.TotalPubkeys, len(result.PubkeysByDepth))
	return result, nil
}
// TraverseFollowers performs BFS traversal to find who follows the seed pubkey.
// This is the reverse of TraverseFollows - it finds users whose kind-3 lists
// contain the target pubkey(s).
//
// At each depth:
// - Depth 1: Users who directly follow the seed
// - Depth 2: Users who follow anyone at depth 1 (followers of followers)
// - etc.
//
// Traversal terminates early as soon as a depth discovers no new pubkeys:
// an empty frontier can never yield results at deeper levels.
func (d *D) TraverseFollowers(seedPubkey []byte, maxDepth int) (*GraphResult, error) {
	result := NewGraphResult()
	// A pubkey must be exactly 32 bytes; anything else cannot be in the index.
	if len(seedPubkey) != 32 {
		return result, ErrPubkeyNotFound
	}
	// Get seed pubkey serial
	seedSerial, err := d.GetPubkeySerial(seedPubkey)
	if err != nil {
		log.D.F("TraverseFollowers: seed pubkey not in database: %s", hex.Enc(seedPubkey))
		return result, nil
	}
	// Track visited pubkeys; the seed itself is excluded from results.
	visited := make(map[uint64]bool)
	visited[seedSerial.Get()] = true
	// Current frontier
	currentFrontier := []*types.Uint40{seedSerial}
	for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
		var nextFrontier []*types.Uint40
		for _, targetSerial := range currentFrontier {
			// Get followers of this pubkey
			followers, err := d.GetFollowersOfPubkeySerial(targetSerial)
			if err != nil {
				log.D.F("TraverseFollowers: error getting followers for serial %d: %v", targetSerial.Get(), err)
				continue
			}
			for _, followerSerial := range followers {
				if visited[followerSerial.Get()] {
					continue
				}
				visited[followerSerial.Get()] = true
				pubkeyHex, err := d.GetPubkeyHexFromSerial(followerSerial)
				if err != nil {
					continue
				}
				result.AddPubkeyAtDepth(pubkeyHex, currentDepth)
				nextFrontier = append(nextFrontier, followerSerial)
			}
		}
		// len(nextFrontier) equals the number of new pubkeys found this depth.
		log.T.F("TraverseFollowers: depth %d found %d new pubkeys", currentDepth, len(nextFrontier))
		// Early termination: an empty frontier cannot produce deeper results.
		// (The previous "two consecutive empty depths" check was dead logic:
		// one empty depth already guarantees every later depth is empty.)
		if len(nextFrontier) == 0 {
			break
		}
		currentFrontier = nextFrontier
	}
	log.D.F("TraverseFollowers: completed with %d total pubkeys", result.TotalPubkeys)
	return result, nil
}
// TraverseFollowsFromHex is a convenience wrapper around TraverseFollows that
// decodes a hex-encoded seed pubkey before delegating.
func (d *D) TraverseFollowsFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) {
	pk, decErr := hex.Dec(seedPubkeyHex)
	if decErr != nil {
		return nil, decErr
	}
	return d.TraverseFollows(pk, maxDepth)
}
// TraverseFollowersFromHex is a convenience wrapper around TraverseFollowers
// that decodes a hex-encoded seed pubkey before delegating.
func (d *D) TraverseFollowersFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) {
	pk, decErr := hex.Dec(seedPubkeyHex)
	if decErr != nil {
		return nil, decErr
	}
	return d.TraverseFollowers(pk, maxDepth)
}

View File

@@ -0,0 +1,318 @@
//go:build !(js && wasm)
package database
import (
"context"
"testing"
"git.mleku.dev/mleku/nostr/encoders/event"
"git.mleku.dev/mleku/nostr/encoders/hex"
"git.mleku.dev/mleku/nostr/encoders/tag"
)
// TestTraverseFollows builds a small two-level follow graph from kind-3
// contact lists and checks that BFS traversal from Alice reports the correct
// pubkeys at each depth, deduplicates across branches (Eve is followed by
// both Bob and Carol but appears once), and produces consistent totals and
// ToDepthArrays output.
func TestTraverseFollows(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a simple follow graph:
	// Alice -> Bob, Carol
	// Bob -> David, Eve
	// Carol -> Eve, Frank
	//
	// Expected depth 1 from Alice: Bob, Carol
	// Expected depth 2 from Alice: David, Eve, Frank (Eve deduplicated)
	alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	david, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000004")
	eve, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000005")
	frank, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000006")
	// Create Alice's follow list (kind 3) with p-tags for Bob and Carol.
	aliceContactID := make([]byte, 32)
	aliceContactID[0] = 0x10
	aliceContactSig := make([]byte, 64)
	aliceContactSig[0] = 0x10
	aliceContact := &event.E{
		ID:        aliceContactID,
		Pubkey:    alice,
		CreatedAt: 1234567890,
		Kind:      3,
		Content:   []byte(""),
		Sig:       aliceContactSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(bob)),
			tag.NewFromAny("p", hex.Enc(carol)),
		),
	}
	_, err = db.SaveEvent(ctx, aliceContact)
	if err != nil {
		t.Fatalf("Failed to save Alice's contact list: %v", err)
	}
	// Create Bob's follow list
	bobContactID := make([]byte, 32)
	bobContactID[0] = 0x20
	bobContactSig := make([]byte, 64)
	bobContactSig[0] = 0x20
	bobContact := &event.E{
		ID:        bobContactID,
		Pubkey:    bob,
		CreatedAt: 1234567891,
		Kind:      3,
		Content:   []byte(""),
		Sig:       bobContactSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(david)),
			tag.NewFromAny("p", hex.Enc(eve)),
		),
	}
	_, err = db.SaveEvent(ctx, bobContact)
	if err != nil {
		t.Fatalf("Failed to save Bob's contact list: %v", err)
	}
	// Create Carol's follow list
	carolContactID := make([]byte, 32)
	carolContactID[0] = 0x30
	carolContactSig := make([]byte, 64)
	carolContactSig[0] = 0x30
	carolContact := &event.E{
		ID:        carolContactID,
		Pubkey:    carol,
		CreatedAt: 1234567892,
		Kind:      3,
		Content:   []byte(""),
		Sig:       carolContactSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(eve)),
			tag.NewFromAny("p", hex.Enc(frank)),
		),
	}
	_, err = db.SaveEvent(ctx, carolContact)
	if err != nil {
		t.Fatalf("Failed to save Carol's contact list: %v", err)
	}
	// Traverse follows from Alice with depth 2
	result, err := db.TraverseFollows(alice, 2)
	if err != nil {
		t.Fatalf("TraverseFollows failed: %v", err)
	}
	// Check depth 1: should have Bob and Carol
	depth1 := result.GetPubkeysAtDepth(1)
	if len(depth1) != 2 {
		t.Errorf("Expected 2 pubkeys at depth 1, got %d", len(depth1))
	}
	depth1Set := make(map[string]bool)
	for _, pk := range depth1 {
		depth1Set[pk] = true
	}
	if !depth1Set[hex.Enc(bob)] {
		t.Error("Bob should be at depth 1")
	}
	if !depth1Set[hex.Enc(carol)] {
		t.Error("Carol should be at depth 1")
	}
	// Check depth 2: should have David, Eve, Frank (Eve deduplicated)
	depth2 := result.GetPubkeysAtDepth(2)
	if len(depth2) != 3 {
		t.Errorf("Expected 3 pubkeys at depth 2, got %d: %v", len(depth2), depth2)
	}
	depth2Set := make(map[string]bool)
	for _, pk := range depth2 {
		depth2Set[pk] = true
	}
	if !depth2Set[hex.Enc(david)] {
		t.Error("David should be at depth 2")
	}
	if !depth2Set[hex.Enc(eve)] {
		t.Error("Eve should be at depth 2")
	}
	if !depth2Set[hex.Enc(frank)] {
		t.Error("Frank should be at depth 2")
	}
	// Verify total count across all depths
	if result.TotalPubkeys != 5 {
		t.Errorf("Expected 5 total pubkeys, got %d", result.TotalPubkeys)
	}
	// Verify ToDepthArrays output mirrors the per-depth results
	arrays := result.ToDepthArrays()
	if len(arrays) != 2 {
		t.Errorf("Expected 2 depth arrays, got %d", len(arrays))
	}
	if len(arrays[0]) != 2 {
		t.Errorf("Expected 2 pubkeys in depth 1 array, got %d", len(arrays[0]))
	}
	if len(arrays[1]) != 3 {
		t.Errorf("Expected 3 pubkeys in depth 2 array, got %d", len(arrays[1]))
	}
}
// TestTraverseFollowsDepth1 verifies that a maxDepth of 1 limits traversal to
// direct follows only: Alice follows Bob and Carol, so the result must contain
// exactly those two pubkeys in a single depth array.
func TestTraverseFollowsDepth1(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	// Create Alice's follow list (kind 3) with p-tags for Bob and Carol.
	aliceContactID := make([]byte, 32)
	aliceContactID[0] = 0x10
	aliceContactSig := make([]byte, 64)
	aliceContactSig[0] = 0x10
	aliceContact := &event.E{
		ID:        aliceContactID,
		Pubkey:    alice,
		CreatedAt: 1234567890,
		Kind:      3,
		Content:   []byte(""),
		Sig:       aliceContactSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(bob)),
			tag.NewFromAny("p", hex.Enc(carol)),
		),
	}
	_, err = db.SaveEvent(ctx, aliceContact)
	if err != nil {
		t.Fatalf("Failed to save contact list: %v", err)
	}
	// Traverse with depth 1 only
	result, err := db.TraverseFollows(alice, 1)
	if err != nil {
		t.Fatalf("TraverseFollows failed: %v", err)
	}
	if result.TotalPubkeys != 2 {
		t.Errorf("Expected 2 pubkeys, got %d", result.TotalPubkeys)
	}
	arrays := result.ToDepthArrays()
	if len(arrays) != 1 {
		t.Errorf("Expected 1 depth array for depth 1 query, got %d", len(arrays))
	}
}
// TestTraverseFollowersBasic verifies the reverse traversal: Bob's and
// Carol's kind-3 lists both p-tag Alice, so TraverseFollowers(alice, 1) must
// report exactly Bob and Carol at depth 1.
func TestTraverseFollowersBasic(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create scenario: Bob and Carol follow Alice
	alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	// Bob's contact list includes Alice
	bobContactID := make([]byte, 32)
	bobContactID[0] = 0x10
	bobContactSig := make([]byte, 64)
	bobContactSig[0] = 0x10
	bobContact := &event.E{
		ID:        bobContactID,
		Pubkey:    bob,
		CreatedAt: 1234567890,
		Kind:      3,
		Content:   []byte(""),
		Sig:       bobContactSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(alice)),
		),
	}
	_, err = db.SaveEvent(ctx, bobContact)
	if err != nil {
		t.Fatalf("Failed to save Bob's contact list: %v", err)
	}
	// Carol's contact list includes Alice
	carolContactID := make([]byte, 32)
	carolContactID[0] = 0x20
	carolContactSig := make([]byte, 64)
	carolContactSig[0] = 0x20
	carolContact := &event.E{
		ID:        carolContactID,
		Pubkey:    carol,
		CreatedAt: 1234567891,
		Kind:      3,
		Content:   []byte(""),
		Sig:       carolContactSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(alice)),
		),
	}
	_, err = db.SaveEvent(ctx, carolContact)
	if err != nil {
		t.Fatalf("Failed to save Carol's contact list: %v", err)
	}
	// Find Alice's followers at depth 1
	result, err := db.TraverseFollowers(alice, 1)
	if err != nil {
		t.Fatalf("TraverseFollowers failed: %v", err)
	}
	if result.TotalPubkeys != 2 {
		t.Errorf("Expected 2 followers, got %d", result.TotalPubkeys)
	}
	followers := result.GetPubkeysAtDepth(1)
	followerSet := make(map[string]bool)
	for _, pk := range followers {
		followerSet[pk] = true
	}
	if !followerSet[hex.Enc(bob)] {
		t.Error("Bob should be a follower")
	}
	if !followerSet[hex.Enc(carol)] {
		t.Error("Carol should be a follower")
	}
}
// TestTraverseFollowsNonExistent verifies the graceful-degradation contract:
// traversing from a pubkey that is not in the database must return an empty
// result and a nil error (not a failure).
func TestTraverseFollowsNonExistent(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Try to traverse from a pubkey that doesn't exist
	nonExistent, _ := hex.Dec("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
	result, err := db.TraverseFollows(nonExistent, 2)
	if err != nil {
		t.Fatalf("TraverseFollows should not error for non-existent pubkey: %v", err)
	}
	if result.TotalPubkeys != 0 {
		t.Errorf("Expected 0 pubkeys for non-existent seed, got %d", result.TotalPubkeys)
	}
}

View File

@@ -0,0 +1,91 @@
//go:build !(js && wasm)
package database
import (
"lol.mleku.dev/log"
"next.orly.dev/pkg/database/indexes/types"
"git.mleku.dev/mleku/nostr/encoders/hex"
)
// FindMentions finds events that mention a pubkey via p-tags.
// This returns events grouped by depth, where depth represents how the events relate:
// - Depth 1: Events that directly mention the seed pubkey
// - Depth 2+: Not typically used for mentions (reserved for future expansion)
//
// The kinds parameter filters which event kinds to include (e.g., [1] for notes only,
// [1,7] for notes and reactions, etc.)
func (d *D) FindMentions(pubkey []byte, kinds []uint16) (*GraphResult, error) {
	res := NewGraphResult()
	// Reject anything that is not a 32-byte pubkey up front.
	if len(pubkey) != 32 {
		return res, ErrPubkeyNotFound
	}
	// Resolve the pubkey to its serial; an unknown pubkey yields an empty
	// result rather than an error.
	serial, serialErr := d.GetPubkeySerial(pubkey)
	if serialErr != nil {
		log.D.F("FindMentions: pubkey not in database: %s", hex.Enc(pubkey))
		return res, nil
	}
	// Collect all events carrying a p-tag reference to this pubkey,
	// restricted to the requested kinds.
	refs, refsErr := d.GetEventsReferencingPubkey(serial, kinds)
	if refsErr != nil {
		return nil, refsErr
	}
	// Resolve each serial to its event ID and record it at depth 1.
	for _, ref := range refs {
		idHex, idErr := d.GetEventIDFromSerial(ref)
		if idErr != nil {
			log.D.F("FindMentions: error getting event ID for serial %d: %v", ref.Get(), idErr)
			continue
		}
		res.AddEventAtDepth(idHex, 1)
	}
	log.D.F("FindMentions: found %d events mentioning pubkey %s", res.TotalEvents, hex.Enc(pubkey))
	return res, nil
}
// FindMentionsFromHex is a convenience wrapper around FindMentions that
// decodes a hex-encoded pubkey before delegating.
func (d *D) FindMentionsFromHex(pubkeyHex string, kinds []uint16) (*GraphResult, error) {
	pk, decErr := hex.Dec(pubkeyHex)
	if decErr != nil {
		return nil, decErr
	}
	return d.FindMentions(pk, kinds)
}
// FindMentionsByPubkeys returns events that mention any of the given pubkeys.
// Useful for finding mentions across a set of followed accounts. Each event
// appears at most once even when it mentions several of the given pubkeys.
func (d *D) FindMentionsByPubkeys(pubkeySerials []*types.Uint40, kinds []uint16) (*GraphResult, error) {
	res := NewGraphResult()
	// Deduplicate events across pubkeys by event serial.
	dedup := make(map[uint64]bool)
	for _, pkSerial := range pubkeySerials {
		evSerials, refsErr := d.GetEventsReferencingPubkey(pkSerial, kinds)
		if refsErr != nil {
			// Best-effort: log and move on to the remaining pubkeys.
			log.D.F("FindMentionsByPubkeys: error for serial %d: %v", pkSerial.Get(), refsErr)
			continue
		}
		for _, evSerial := range evSerials {
			if dedup[evSerial.Get()] {
				continue
			}
			dedup[evSerial.Get()] = true
			idHex, idErr := d.GetEventIDFromSerial(evSerial)
			if idErr != nil {
				continue
			}
			res.AddEventAtDepth(idHex, 1)
		}
	}
	return res, nil
}

206
pkg/database/graph-refs.go Normal file
View File

@@ -0,0 +1,206 @@
//go:build !(js && wasm)
package database
import (
"lol.mleku.dev/log"
"next.orly.dev/pkg/database/indexes/types"
)
// AddInboundRefsToResult collects inbound references (events that reference discovered items)
// for events at a specific depth in the result.
//
// For example, if you have a follows graph result and want to find all kind-7 reactions
// to posts by users at depth 1, this collects those reactions and adds them to result.InboundRefs.
//
// Parameters:
//   - result: The graph result to augment with ref data
//   - depth: The depth at which to collect refs (0 = all depths)
//   - kinds: Event kinds to collect (e.g., [7] for reactions, [6] for reposts)
//
// NOTE(review): refs are categorized under kinds[0] regardless of the
// referencing event's actual kind (see the TODO below), so multi-kind queries
// will mislabel refs of the other kinds.
func (d *D) AddInboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error {
	// Determine which events to find refs for
	var targetEventIDs []string
	if depth == 0 {
		// Collect for all depths
		targetEventIDs = result.GetAllEvents()
	} else {
		targetEventIDs = result.GetEventsAtDepth(depth)
	}
	// Also collect refs for events authored by pubkeys in the result
	// This is common for "find reactions to posts by my follows" queries
	pubkeys := result.GetAllPubkeys()
	for _, pubkeyHex := range pubkeys {
		pubkeySerial, err := d.PubkeyHexToSerial(pubkeyHex)
		if err != nil {
			// Skip pubkeys we cannot resolve; best-effort collection.
			continue
		}
		// Get events authored by this pubkey
		// For efficiency, limit to relevant event kinds that might have reactions
		authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, []uint16{1, 30023}) // notes and articles
		if err != nil {
			continue
		}
		for _, eventSerial := range authoredEvents {
			eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
			if err != nil {
				continue
			}
			// Add to target list if not already tracking
			if !result.HasEvent(eventIDHex) {
				targetEventIDs = append(targetEventIDs, eventIDHex)
			}
		}
	}
	// For each target event, find referencing events
	for _, eventIDHex := range targetEventIDs {
		eventSerial, err := d.EventIDHexToSerial(eventIDHex)
		if err != nil {
			continue
		}
		refSerials, err := d.GetReferencingEvents(eventSerial, kinds)
		if err != nil {
			continue
		}
		for _, refSerial := range refSerials {
			refEventIDHex, err := d.GetEventIDFromSerial(refSerial)
			if err != nil {
				continue
			}
			// Get the kind of the referencing event
			// For now, use the first kind in the filter (assumes single kind queries)
			// TODO: Look up actual event kind from index if needed
			if len(kinds) > 0 {
				result.AddInboundRef(kinds[0], eventIDHex, refEventIDHex)
			}
		}
	}
	log.D.F("AddInboundRefsToResult: collected refs for %d target events", len(targetEventIDs))
	return nil
}
// AddOutboundRefsToResult collects outbound references (events referenced by discovered items).
//
// For example, find all events that posts by users at depth 1 reference (quoted posts, replied-to posts).
//
// Parameters:
//   - result: The graph result to augment with ref data
//   - depth: The depth at which to collect refs (0 = all depths)
//   - kinds: Event kinds used both to filter authored source events and to
//     categorize the collected refs
//
// NOTE(review): like AddInboundRefsToResult, refs are categorized under
// kinds[0] regardless of the referenced event's actual kind, so multi-kind
// queries will mislabel refs of the other kinds.
func (d *D) AddOutboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error {
	// Determine source events
	var sourceEventIDs []string
	if depth == 0 {
		sourceEventIDs = result.GetAllEvents()
	} else {
		sourceEventIDs = result.GetEventsAtDepth(depth)
	}
	// Also include events authored by pubkeys in result
	pubkeys := result.GetAllPubkeys()
	for _, pubkeyHex := range pubkeys {
		pubkeySerial, err := d.PubkeyHexToSerial(pubkeyHex)
		if err != nil {
			// Skip pubkeys we cannot resolve; best-effort collection.
			continue
		}
		authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, kinds)
		if err != nil {
			continue
		}
		for _, eventSerial := range authoredEvents {
			eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
			if err != nil {
				continue
			}
			if !result.HasEvent(eventIDHex) {
				sourceEventIDs = append(sourceEventIDs, eventIDHex)
			}
		}
	}
	// For each source event, find referenced events via its e-tags
	for _, eventIDHex := range sourceEventIDs {
		eventSerial, err := d.EventIDHexToSerial(eventIDHex)
		if err != nil {
			continue
		}
		refSerials, err := d.GetETagsFromEventSerial(eventSerial)
		if err != nil {
			continue
		}
		for _, refSerial := range refSerials {
			refEventIDHex, err := d.GetEventIDFromSerial(refSerial)
			if err != nil {
				continue
			}
			// Use first kind for categorization
			if len(kinds) > 0 {
				result.AddOutboundRef(kinds[0], eventIDHex, refEventIDHex)
			}
		}
	}
	log.D.F("AddOutboundRefsToResult: collected refs from %d source events", len(sourceEventIDs))
	return nil
}
// CollectRefsForPubkeys gathers inbound references to events authored by the
// given pubkeys — e.g. "all reactions (refKinds) to notes (eventKinds) by
// these users".
//
// Parameters:
//   - pubkeySerials: authors whose events are inspected for references
//   - refKinds: kinds of the referencing events to collect (e.g. [7] for reactions)
//   - eventKinds: kinds of the authored events to inspect (e.g. [1] for notes)
func (d *D) CollectRefsForPubkeys(
	pubkeySerials []*types.Uint40,
	refKinds []uint16,
	eventKinds []uint16,
) (*GraphResult, error) {
	out := NewGraphResult()
	for _, author := range pubkeySerials {
		// Events authored by this pubkey, restricted to the requested kinds.
		authored, err := d.GetEventsByAuthor(author, eventKinds)
		if err != nil {
			continue
		}
		for _, evSerial := range authored {
			evHex, err := d.GetEventIDFromSerial(evSerial)
			if err != nil {
				continue
			}
			// Events referencing this authored event, restricted by refKinds.
			refs, err := d.GetReferencingEvents(evSerial, refKinds)
			if err != nil {
				continue
			}
			for _, ref := range refs {
				refHex, err := d.GetEventIDFromSerial(ref)
				if err != nil {
					continue
				}
				// Categorize under the first requested ref kind.
				if len(refKinds) > 0 {
					out.AddInboundRef(refKinds[0], evHex, refHex)
				}
			}
		}
	}
	return out, nil
}

View File

@@ -0,0 +1,327 @@
//go:build !(js && wasm)
package database
import (
"sort"
)
// GraphResult contains depth-organized traversal results for graph queries.
// It tracks pubkeys and events discovered at each depth level, ensuring
// each entity appears only at the depth where it was first discovered.
type GraphResult struct {
	// PubkeysByDepth maps depth -> pubkeys first discovered at that depth.
	// Each pubkey appears ONLY in the array for the depth where it was first seen.
	// Depth 1 = direct connections, Depth 2 = connections of connections, etc.
	PubkeysByDepth map[int][]string
	// EventsByDepth maps depth -> event IDs discovered at that depth.
	// Used for thread traversal queries.
	EventsByDepth map[int][]string
	// FirstSeenPubkey tracks which depth each pubkey was first discovered.
	// Key is pubkey hex, value is the depth (1-indexed).
	FirstSeenPubkey map[string]int
	// FirstSeenEvent tracks which depth each event was first discovered.
	// Key is event ID hex, value is the depth (1-indexed).
	FirstSeenEvent map[string]int
	// TotalPubkeys is the count of unique pubkeys discovered across all depths.
	TotalPubkeys int
	// TotalEvents is the count of unique events discovered across all depths.
	TotalEvents int
	// InboundRefs tracks inbound references (events that reference discovered items).
	// Structure: kind -> target_id -> []referencing_event_ids
	InboundRefs map[uint16]map[string][]string
	// OutboundRefs tracks outbound references (events referenced by discovered items).
	// Structure: kind -> source_id -> []referenced_event_ids
	OutboundRefs map[uint16]map[string][]string
}

// NewGraphResult creates a new initialized GraphResult with all maps allocated
// so the zero counts and empty maps are immediately usable.
func NewGraphResult() *GraphResult {
	return &GraphResult{
		PubkeysByDepth:  make(map[int][]string),
		EventsByDepth:   make(map[int][]string),
		FirstSeenPubkey: make(map[string]int),
		FirstSeenEvent:  make(map[string]int),
		InboundRefs:     make(map[uint16]map[string][]string),
		OutboundRefs:    make(map[uint16]map[string][]string),
	}
}

// AddPubkeyAtDepth adds a pubkey to the result at the specified depth if not already seen.
// Returns true if the pubkey was added (first time seen), false if already exists.
func (r *GraphResult) AddPubkeyAtDepth(pubkeyHex string, depth int) bool {
	if _, exists := r.FirstSeenPubkey[pubkeyHex]; exists {
		return false
	}
	r.FirstSeenPubkey[pubkeyHex] = depth
	r.PubkeysByDepth[depth] = append(r.PubkeysByDepth[depth], pubkeyHex)
	r.TotalPubkeys++
	return true
}

// AddEventAtDepth adds an event ID to the result at the specified depth if not already seen.
// Returns true if the event was added (first time seen), false if already exists.
func (r *GraphResult) AddEventAtDepth(eventIDHex string, depth int) bool {
	if _, exists := r.FirstSeenEvent[eventIDHex]; exists {
		return false
	}
	r.FirstSeenEvent[eventIDHex] = depth
	r.EventsByDepth[depth] = append(r.EventsByDepth[depth], eventIDHex)
	r.TotalEvents++
	return true
}

// HasPubkey returns true if the pubkey has been discovered at any depth.
func (r *GraphResult) HasPubkey(pubkeyHex string) bool {
	_, exists := r.FirstSeenPubkey[pubkeyHex]
	return exists
}

// HasEvent returns true if the event has been discovered at any depth.
func (r *GraphResult) HasEvent(eventIDHex string) bool {
	_, exists := r.FirstSeenEvent[eventIDHex]
	return exists
}

// GetPubkeyDepth returns the depth at which a pubkey was first discovered.
// Returns 0 if the pubkey was not found (depths are 1-indexed).
func (r *GraphResult) GetPubkeyDepth(pubkeyHex string) int {
	return r.FirstSeenPubkey[pubkeyHex]
}

// GetEventDepth returns the depth at which an event was first discovered.
// Returns 0 if the event was not found (depths are 1-indexed).
func (r *GraphResult) GetEventDepth(eventIDHex string) int {
	return r.FirstSeenEvent[eventIDHex]
}

// GetDepthsSorted returns all depths that have pubkeys, sorted ascending.
func (r *GraphResult) GetDepthsSorted() []int {
	depths := make([]int, 0, len(r.PubkeysByDepth))
	for d := range r.PubkeysByDepth {
		depths = append(depths, d)
	}
	sort.Ints(depths)
	return depths
}

// GetEventDepthsSorted returns all depths that have events, sorted ascending.
func (r *GraphResult) GetEventDepthsSorted() []int {
	depths := make([]int, 0, len(r.EventsByDepth))
	for d := range r.EventsByDepth {
		depths = append(depths, d)
	}
	sort.Ints(depths)
	return depths
}

// ToDepthArrays converts the result to the response format: array of arrays.
// Index 0 = depth 1 pubkeys, Index 1 = depth 2 pubkeys, etc.
// Empty arrays are included for depths with no pubkeys to maintain index alignment.
func (r *GraphResult) ToDepthArrays() [][]string {
	if len(r.PubkeysByDepth) == 0 {
		return [][]string{}
	}
	// Find the maximum depth so the output covers 1..maxDepth contiguously.
	maxDepth := 0
	for d := range r.PubkeysByDepth {
		if d > maxDepth {
			maxDepth = d
		}
	}
	result := make([][]string, maxDepth)
	for i := 0; i < maxDepth; i++ {
		depth := i + 1 // depths are 1-indexed
		if pubkeys, exists := r.PubkeysByDepth[depth]; exists {
			result[i] = pubkeys
		} else {
			result[i] = []string{} // Empty array for depths with no pubkeys
		}
	}
	return result
}

// ToEventDepthArrays converts event results to the response format: array of arrays.
// Index 0 = depth 1 events, Index 1 = depth 2 events, etc.
func (r *GraphResult) ToEventDepthArrays() [][]string {
	if len(r.EventsByDepth) == 0 {
		return [][]string{}
	}
	maxDepth := 0
	for d := range r.EventsByDepth {
		if d > maxDepth {
			maxDepth = d
		}
	}
	result := make([][]string, maxDepth)
	for i := 0; i < maxDepth; i++ {
		depth := i + 1
		if events, exists := r.EventsByDepth[depth]; exists {
			result[i] = events
		} else {
			result[i] = []string{}
		}
	}
	return result
}

// AddInboundRef records an inbound reference from a referencing event to a target.
// Duplicates are not deduplicated; callers are expected to pass each ref once.
func (r *GraphResult) AddInboundRef(kind uint16, targetIDHex string, referencingEventIDHex string) {
	if r.InboundRefs[kind] == nil {
		r.InboundRefs[kind] = make(map[string][]string)
	}
	r.InboundRefs[kind][targetIDHex] = append(r.InboundRefs[kind][targetIDHex], referencingEventIDHex)
}

// AddOutboundRef records an outbound reference from a source event to a referenced event.
func (r *GraphResult) AddOutboundRef(kind uint16, sourceIDHex string, referencedEventIDHex string) {
	if r.OutboundRefs[kind] == nil {
		r.OutboundRefs[kind] = make(map[string][]string)
	}
	r.OutboundRefs[kind][sourceIDHex] = append(r.OutboundRefs[kind][sourceIDHex], referencedEventIDHex)
}

// RefAggregation represents aggregated reference data for a single target/source.
type RefAggregation struct {
	// TargetEventID is the event ID being referenced (for inbound) or referencing (for outbound)
	TargetEventID string
	// TargetAuthor is the author pubkey of the target event (if known)
	TargetAuthor string
	// TargetDepth is the depth at which this target was discovered in the graph
	TargetDepth int
	// RefKind is the kind of the referencing events
	RefKind uint16
	// RefCount is the number of references to/from this target
	RefCount int
	// RefEventIDs is the list of event IDs that reference this target
	RefEventIDs []string
}

// GetInboundRefsSorted returns inbound refs for a kind, sorted by count
// descending. Ties are broken by target event ID so that output is
// deterministic despite random map iteration order. Returns nil if no refs
// were recorded for the kind.
func (r *GraphResult) GetInboundRefsSorted(kind uint16) []RefAggregation {
	kindRefs := r.InboundRefs[kind]
	if kindRefs == nil {
		return nil
	}
	aggs := make([]RefAggregation, 0, len(kindRefs))
	for targetID, refs := range kindRefs {
		aggs = append(aggs, RefAggregation{
			TargetEventID: targetID,
			TargetDepth:   r.GetEventDepth(targetID),
			RefKind:       kind,
			RefCount:      len(refs),
			RefEventIDs:   refs,
		})
	}
	sortRefAggregations(aggs)
	return aggs
}

// GetOutboundRefsSorted returns outbound refs for a kind, sorted by count
// descending with deterministic ID tie-breaking. Returns nil if no refs were
// recorded for the kind.
func (r *GraphResult) GetOutboundRefsSorted(kind uint16) []RefAggregation {
	kindRefs := r.OutboundRefs[kind]
	if kindRefs == nil {
		return nil
	}
	aggs := make([]RefAggregation, 0, len(kindRefs))
	for sourceID, refs := range kindRefs {
		aggs = append(aggs, RefAggregation{
			TargetEventID: sourceID,
			TargetDepth:   r.GetEventDepth(sourceID),
			RefKind:       kind,
			RefCount:      len(refs),
			RefEventIDs:   refs,
		})
	}
	sortRefAggregations(aggs)
	return aggs
}

// sortRefAggregations orders aggregations by RefCount descending, breaking
// ties by TargetEventID ascending so equal counts sort deterministically.
func sortRefAggregations(aggs []RefAggregation) {
	sort.Slice(aggs, func(i, j int) bool {
		if aggs[i].RefCount != aggs[j].RefCount {
			return aggs[i].RefCount > aggs[j].RefCount
		}
		return aggs[i].TargetEventID < aggs[j].TargetEventID
	})
}

// GetAllPubkeys returns all pubkeys discovered across all depths, ordered by
// ascending depth (and insertion order within each depth) so the output is
// deterministic rather than following random map iteration order.
func (r *GraphResult) GetAllPubkeys() []string {
	all := make([]string, 0, r.TotalPubkeys)
	for _, d := range r.GetDepthsSorted() {
		all = append(all, r.PubkeysByDepth[d]...)
	}
	return all
}

// GetAllEvents returns all event IDs discovered across all depths, ordered by
// ascending depth (and insertion order within each depth) for determinism.
func (r *GraphResult) GetAllEvents() []string {
	all := make([]string, 0, r.TotalEvents)
	for _, d := range r.GetEventDepthsSorted() {
		all = append(all, r.EventsByDepth[d]...)
	}
	return all
}

// GetPubkeysAtDepth returns pubkeys at a specific depth, or empty slice if none.
func (r *GraphResult) GetPubkeysAtDepth(depth int) []string {
	if pubkeys, exists := r.PubkeysByDepth[depth]; exists {
		return pubkeys
	}
	return []string{}
}

// GetEventsAtDepth returns events at a specific depth, or empty slice if none.
func (r *GraphResult) GetEventsAtDepth(depth int) []string {
	if events, exists := r.EventsByDepth[depth]; exists {
		return events
	}
	return []string{}
}

// Interface methods for external package access (e.g. pkg/protocol/graph).
// These allow the graph executor to extract data without direct struct access.

// GetPubkeysByDepth returns the PubkeysByDepth map for external access.
func (r *GraphResult) GetPubkeysByDepth() map[int][]string {
	return r.PubkeysByDepth
}

// GetEventsByDepth returns the EventsByDepth map for external access.
func (r *GraphResult) GetEventsByDepth() map[int][]string {
	return r.EventsByDepth
}

// GetTotalPubkeys returns the total pubkey count for external access.
func (r *GraphResult) GetTotalPubkeys() int {
	return r.TotalPubkeys
}

// GetTotalEvents returns the total event count for external access.
func (r *GraphResult) GetTotalEvents() int {
	return r.TotalEvents
}

View File

@@ -0,0 +1,191 @@
//go:build !(js && wasm)
package database
import (
"lol.mleku.dev/log"
"next.orly.dev/pkg/database/indexes/types"
"git.mleku.dev/mleku/nostr/encoders/hex"
)
// TraverseThread performs BFS traversal of thread structure via e-tags.
// Starting from a seed event, it finds all replies/references at each depth.
//
// The traversal works bidirectionally:
//   - Forward: Events that the seed references (parents, quoted posts)
//   - Backward: Events that reference the seed (replies, reactions, reposts)
//
// Parameters:
//   - seedEventID: The event ID to start traversal from (32 bytes)
//   - maxDepth: Maximum depth to traverse
//   - direction: "both" (default), "inbound" (replies to seed), "outbound" (seed's references)
//
// A seed that is not in the database returns an empty result and nil error.
func (d *D) TraverseThread(seedEventID []byte, maxDepth int, direction string) (*GraphResult, error) {
	result := NewGraphResult()
	if len(seedEventID) != 32 {
		return result, ErrEventNotFound
	}
	// Get seed event serial; an unknown seed is "nothing found", not an error.
	seedSerial, err := d.GetSerialById(seedEventID)
	if err != nil {
		log.D.F("TraverseThread: seed event not in database: %s", hex.Enc(seedEventID))
		return result, nil
	}
	// Normalize direction
	if direction == "" {
		direction = "both"
	}
	// Track visited events so cycles and diamonds are expanded only once.
	visited := make(map[uint64]bool)
	visited[seedSerial.Get()] = true
	// Current frontier
	currentFrontier := []*types.Uint40{seedSerial}
	for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
		var nextFrontier []*types.Uint40
		newEventsAtDepth := 0
		// visit records unseen serials at the current depth and queues them
		// for the next frontier; shared by the inbound and outbound passes.
		visit := func(serials []*types.Uint40) {
			for _, refSerial := range serials {
				if visited[refSerial.Get()] {
					continue
				}
				visited[refSerial.Get()] = true
				eventIDHex, err := d.GetEventIDFromSerial(refSerial)
				if err != nil {
					continue
				}
				result.AddEventAtDepth(eventIDHex, currentDepth)
				newEventsAtDepth++
				nextFrontier = append(nextFrontier, refSerial)
			}
		}
		for _, eventSerial := range currentFrontier {
			// Inbound: events that reference this event (replies, reactions).
			if direction == "both" || direction == "inbound" {
				inboundSerials, err := d.GetReferencingEvents(eventSerial, nil)
				if err != nil {
					log.D.F("TraverseThread: error getting inbound refs for serial %d: %v", eventSerial.Get(), err)
				} else {
					visit(inboundSerials)
				}
			}
			// Outbound: events this event references (parents, quotes).
			if direction == "both" || direction == "outbound" {
				outboundSerials, err := d.GetETagsFromEventSerial(eventSerial)
				if err != nil {
					log.D.F("TraverseThread: error getting outbound refs for serial %d: %v", eventSerial.Get(), err)
				} else {
					visit(outboundSerials)
				}
			}
		}
		log.T.F("TraverseThread: depth %d found %d new events", currentDepth, newEventsAtDepth)
		// An empty next frontier means no further depth can discover anything;
		// stop immediately instead of iterating over an empty frontier.
		if len(nextFrontier) == 0 {
			break
		}
		currentFrontier = nextFrontier
	}
	log.D.F("TraverseThread: completed with %d total events", result.TotalEvents)
	return result, nil
}
// TraverseThreadFromHex decodes a hex-encoded event ID and delegates to
// TraverseThread with the same depth and direction semantics.
func (d *D) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) {
	id, err := hex.Dec(seedEventIDHex)
	if err != nil {
		return nil, err
	}
	return d.TraverseThread(id, maxDepth, direction)
}
// GetThreadReplies returns the direct (depth-1) replies to an event,
// optionally filtered by reply kinds. An event that is not stored locally
// yields an empty result and nil error.
func (d *D) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) {
	result := NewGraphResult()
	if len(eventID) != 32 {
		return result, ErrEventNotFound
	}
	serial, err := d.GetSerialById(eventID)
	if err != nil {
		// Unknown event: empty (but valid) result.
		return result, nil
	}
	// Events that reference this event via e-tag are its replies.
	replies, err := d.GetReferencingEvents(serial, kinds)
	if err != nil {
		return nil, err
	}
	for _, reply := range replies {
		idHex, err := d.GetEventIDFromSerial(reply)
		if err != nil {
			continue
		}
		result.AddEventAtDepth(idHex, 1)
	}
	return result, nil
}
// GetThreadParents returns the events a given event references via e-tags
// (its parents and quoted posts), recorded at depth 1. An event that is not
// stored locally yields an empty result and nil error.
func (d *D) GetThreadParents(eventID []byte) (*GraphResult, error) {
	result := NewGraphResult()
	if len(eventID) != 32 {
		return result, ErrEventNotFound
	}
	serial, err := d.GetSerialById(eventID)
	if err != nil {
		// Unknown event: empty (but valid) result.
		return result, nil
	}
	// Outbound e-tag edges point at the referenced (parent/quoted) events.
	parents, err := d.GetETagsFromEventSerial(serial)
	if err != nil {
		return nil, err
	}
	for _, parent := range parents {
		idHex, err := d.GetEventIDFromSerial(parent)
		if err != nil {
			continue
		}
		result.AddEventAtDepth(idHex, 1)
	}
	return result, nil
}

View File

@@ -0,0 +1,560 @@
//go:build !(js && wasm)
package database
import (
"bytes"
"errors"
"github.com/dgraph-io/badger/v4"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"next.orly.dev/pkg/database/indexes"
"next.orly.dev/pkg/database/indexes/types"
"git.mleku.dev/mleku/nostr/encoders/hex"
)
// Graph traversal errors returned by the index-based graph query helpers.
var (
	// ErrPubkeyNotFound indicates a pubkey could not be resolved in the database.
	ErrPubkeyNotFound = errors.New("pubkey not found in database")
	// ErrEventNotFound indicates an event ID or serial could not be resolved.
	ErrEventNotFound = errors.New("event not found in database")
)
// GetPTagsFromEventSerial extracts p-tag pubkey serials from an event by its serial.
// This is a pure index-based operation - no event decoding required.
// It scans the epg (event-pubkey-graph) index for p-tag edges.
// Returns a nil slice if the event has no outbound p-tag edges.
func (d *D) GetPTagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) {
	var pubkeySerials []*types.Uint40
	// Build prefix: epg|event_serial
	prefix := new(bytes.Buffer)
	prefix.Write([]byte(indexes.EventPubkeyGraphPrefix))
	if err := eventSerial.MarshalWrite(prefix); chk.E(err) {
		return nil, err
	}
	searchPrefix := prefix.Bytes()
	err := d.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		// Key-only scan: all the edge data lives in the key itself.
		opts.PrefetchValues = false
		opts.Prefix = searchPrefix
		it := txn.NewIterator(opts)
		defer it.Close()
		for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
			key := it.Item().KeyCopy(nil)
			// Decode key: epg(3)|event_serial(5)|pubkey_serial(5)|kind(2)|direction(1)
			// Fixed 16-byte layout; anything else is skipped as malformed.
			if len(key) != 16 {
				continue
			}
			// Extract direction to filter for p-tags only
			direction := key[15]
			if direction != types.EdgeDirectionPTagOut {
				continue // Skip author edges, only want p-tag edges
			}
			// Extract pubkey serial (bytes 8-12)
			pubkeySerial := new(types.Uint40)
			serialReader := bytes.NewReader(key[8:13])
			if err := pubkeySerial.UnmarshalRead(serialReader); chk.E(err) {
				continue
			}
			pubkeySerials = append(pubkeySerials, pubkeySerial)
		}
		return nil
	})
	return pubkeySerials, err
}
// GetETagsFromEventSerial extracts e-tag event serials from an event by its serial.
// This is a pure index-based operation - no event decoding required.
// It scans the eeg (event-event-graph) index for outbound e-tag edges.
// Returns a nil slice if the event references no other events.
func (d *D) GetETagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) {
	var targetSerials []*types.Uint40
	// Build prefix: eeg|source_event_serial
	prefix := new(bytes.Buffer)
	prefix.Write([]byte(indexes.EventEventGraphPrefix))
	if err := eventSerial.MarshalWrite(prefix); chk.E(err) {
		return nil, err
	}
	searchPrefix := prefix.Bytes()
	err := d.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		// Key-only scan: the edge is fully encoded in the key.
		opts.PrefetchValues = false
		opts.Prefix = searchPrefix
		it := txn.NewIterator(opts)
		defer it.Close()
		for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
			key := it.Item().KeyCopy(nil)
			// Decode key: eeg(3)|source_serial(5)|target_serial(5)|kind(2)|direction(1)
			// Fixed 16-byte layout; anything else is skipped as malformed.
			// NOTE(review): the direction byte (key[15]) is not filtered here —
			// assumes the eeg index only holds outbound edges under this prefix;
			// confirm against the index writer.
			if len(key) != 16 {
				continue
			}
			// Extract target serial (bytes 8-12)
			targetSerial := new(types.Uint40)
			serialReader := bytes.NewReader(key[8:13])
			if err := targetSerial.UnmarshalRead(serialReader); chk.E(err) {
				continue
			}
			targetSerials = append(targetSerials, targetSerial)
		}
		return nil
	})
	return targetSerials, err
}
// GetReferencingEvents finds all events that reference a target event via e-tags.
// Optionally filters by event kinds. Uses the gee (reverse e-tag graph) index.
// With an empty kinds slice, all kinds are scanned under one prefix; with
// kinds given, one prefix scan is performed per kind.
func (d *D) GetReferencingEvents(targetSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) {
	var sourceSerials []*types.Uint40
	if len(kinds) == 0 {
		// No kind filter - scan all kinds
		// NOTE(review): this branch does not check the direction byte of the
		// key, unlike the kind-filtered branch which encodes
		// EdgeDirectionETagIn into the prefix — it assumes every gee entry is
		// an inbound e-tag edge; confirm against the index writer.
		prefix := new(bytes.Buffer)
		prefix.Write([]byte(indexes.GraphEventEventPrefix))
		if err := targetSerial.MarshalWrite(prefix); chk.E(err) {
			return nil, err
		}
		searchPrefix := prefix.Bytes()
		err := d.View(func(txn *badger.Txn) error {
			opts := badger.DefaultIteratorOptions
			// Key-only scan: edge data is fully encoded in the key.
			opts.PrefetchValues = false
			opts.Prefix = searchPrefix
			it := txn.NewIterator(opts)
			defer it.Close()
			for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
				key := it.Item().KeyCopy(nil)
				// Decode key: gee(3)|target_serial(5)|kind(2)|direction(1)|source_serial(5)
				if len(key) != 16 {
					continue
				}
				// Extract source serial (bytes 11-15)
				sourceSerial := new(types.Uint40)
				serialReader := bytes.NewReader(key[11:16])
				if err := sourceSerial.UnmarshalRead(serialReader); chk.E(err) {
					continue
				}
				sourceSerials = append(sourceSerials, sourceSerial)
			}
			return nil
		})
		return sourceSerials, err
	}
	// With kind filter - scan each kind's prefix
	for _, k := range kinds {
		kind := new(types.Uint16)
		kind.Set(k)
		direction := new(types.Letter)
		direction.Set(types.EdgeDirectionETagIn)
		// Encode gee|target|kind|direction as the scan prefix; the trailing
		// source serial varies and is what we collect.
		prefix := new(bytes.Buffer)
		if err := indexes.GraphEventEventEnc(targetSerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) {
			return nil, err
		}
		searchPrefix := prefix.Bytes()
		err := d.View(func(txn *badger.Txn) error {
			opts := badger.DefaultIteratorOptions
			opts.PrefetchValues = false
			opts.Prefix = searchPrefix
			it := txn.NewIterator(opts)
			defer it.Close()
			for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
				key := it.Item().KeyCopy(nil)
				// Extract source serial (last 5 bytes)
				if len(key) < 5 {
					continue
				}
				sourceSerial := new(types.Uint40)
				serialReader := bytes.NewReader(key[len(key)-5:])
				if err := sourceSerial.UnmarshalRead(serialReader); chk.E(err) {
					continue
				}
				sourceSerials = append(sourceSerials, sourceSerial)
			}
			return nil
		})
		if chk.E(err) {
			return nil, err
		}
	}
	return sourceSerials, nil
}
// FindEventByAuthorAndKind finds the most recent event of a specific kind by an author.
// This is used to find kind-3 contact lists for follow graph traversal.
// Returns nil, nil if no matching event is found.
func (d *D) FindEventByAuthorAndKind(authorSerial *types.Uint40, kind uint16) (*types.Uint40, error) {
	var eventSerial *types.Uint40
	// First, get the full pubkey from the serial
	pubkey, err := d.GetPubkeyBySerial(authorSerial)
	if err != nil {
		return nil, err
	}
	// Build prefix for kind-pubkey index: kpc|kind|pubkey_hash
	pubHash := new(types.PubHash)
	if err := pubHash.FromPubkey(pubkey); chk.E(err) {
		return nil, err
	}
	kindType := new(types.Uint16)
	kindType.Set(kind)
	prefix := new(bytes.Buffer)
	prefix.Write([]byte(indexes.KindPubkeyPrefix))
	if err := kindType.MarshalWrite(prefix); chk.E(err) {
		return nil, err
	}
	if err := pubHash.MarshalWrite(prefix); chk.E(err) {
		return nil, err
	}
	searchPrefix := prefix.Bytes()
	err = d.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		opts.PrefetchValues = false
		opts.Prefix = searchPrefix
		opts.Reverse = true // Most recent first (highest created_at)
		it := txn.NewIterator(opts)
		defer it.Close()
		// Seek to end of prefix range for reverse iteration: pad the prefix
		// with 0xFF so the reverse iterator lands on the largest key within
		// the prefix range (highest created_at, then highest serial).
		seekKey := make([]byte, len(searchPrefix)+8+5) // prefix + max timestamp + max serial
		copy(seekKey, searchPrefix)
		for i := len(searchPrefix); i < len(seekKey); i++ {
			seekKey[i] = 0xFF
		}
		it.Seek(seekKey)
		if !it.ValidForPrefix(searchPrefix) {
			// Try going to the first valid key if seek went past
			// NOTE(review): with opts.Reverse, badger's Seek positions at the
			// largest key <= the target, so seeking to the bare prefix lands
			// *before* the prefix range; this fallback is likely dead or
			// ineffective — confirm against badger reverse-iterator semantics.
			it.Rewind()
			it.Seek(searchPrefix)
		}
		if it.ValidForPrefix(searchPrefix) {
			key := it.Item().KeyCopy(nil)
			// Decode key: kpc(3)|kind(2)|pubkey_hash(8)|created_at(8)|serial(5)
			// Total: 26 bytes
			if len(key) < 26 {
				return nil
			}
			// Extract serial (last 5 bytes)
			eventSerial = new(types.Uint40)
			serialReader := bytes.NewReader(key[len(key)-5:])
			if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) {
				return err
			}
		}
		return nil
	})
	return eventSerial, err
}
// GetPubkeyHexFromSerial resolves a pubkey serial to its hex-encoded 32-byte key.
func (d *D) GetPubkeyHexFromSerial(serial *types.Uint40) (string, error) {
	pk, err := d.GetPubkeyBySerial(serial)
	if err != nil {
		return "", err
	}
	return hex.Enc(pk), nil
}
// GetEventIDFromSerial resolves an event serial to its hex-encoded 32-byte ID.
func (d *D) GetEventIDFromSerial(serial *types.Uint40) (string, error) {
	id, err := d.GetEventIdBySerial(serial)
	if err != nil {
		return "", err
	}
	return hex.Enc(id), nil
}
// GetEventsReferencingPubkey finds all events that reference a pubkey via p-tags.
// Uses the peg (pubkey-event-graph) index with direction filter for inbound p-tags.
// Optionally filters by event kinds.
//
// With an empty kinds slice, only a hardcoded list of common kinds is scanned
// (direction follows kind in the key, so a full scan is not possible); events
// of kinds outside that list will be missed in the no-filter case.
func (d *D) GetEventsReferencingPubkey(pubkeySerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) {
	var eventSerials []*types.Uint40
	if len(kinds) == 0 {
		// No kind filter - we need to scan common kinds since direction comes after kind in the key
		// Use same approach as QueryPTagGraph
		commonKinds := []uint16{1, 6, 7, 9735, 10002, 3, 4, 5, 30023}
		kinds = commonKinds
	}
	for _, k := range kinds {
		kind := new(types.Uint16)
		kind.Set(k)
		direction := new(types.Letter)
		direction.Set(types.EdgeDirectionPTagIn) // Inbound p-tags
		// Encode peg|pubkey|kind|direction as the scan prefix; the trailing
		// event serial varies and is what we collect.
		prefix := new(bytes.Buffer)
		if err := indexes.PubkeyEventGraphEnc(pubkeySerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) {
			return nil, err
		}
		searchPrefix := prefix.Bytes()
		err := d.View(func(txn *badger.Txn) error {
			opts := badger.DefaultIteratorOptions
			// Key-only scan: edge data is fully encoded in the key.
			opts.PrefetchValues = false
			opts.Prefix = searchPrefix
			it := txn.NewIterator(opts)
			defer it.Close()
			for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
				key := it.Item().KeyCopy(nil)
				// Key format: peg(3)|pubkey_serial(5)|kind(2)|direction(1)|event_serial(5) = 16 bytes
				if len(key) != 16 {
					continue
				}
				// Extract event serial (last 5 bytes)
				eventSerial := new(types.Uint40)
				serialReader := bytes.NewReader(key[11:16])
				if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) {
					continue
				}
				eventSerials = append(eventSerials, eventSerial)
			}
			return nil
		})
		if chk.E(err) {
			return nil, err
		}
	}
	return eventSerials, nil
}
// GetEventsByAuthor finds all events authored by a pubkey.
// Uses the peg (pubkey-event-graph) index with direction filter for author edges.
// Optionally filters by event kinds.
//
// With an empty kinds slice, only a hardcoded list of common kinds is scanned;
// authored events of other kinds will be missed in the no-filter case.
// NOTE(review): the fallback kind list here differs from the one in
// GetEventsReferencingPubkey — confirm whether that divergence is intended.
func (d *D) GetEventsByAuthor(authorSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) {
	var eventSerials []*types.Uint40
	if len(kinds) == 0 {
		// No kind filter - scan for author direction across common kinds
		// This is less efficient but necessary without kind filter
		commonKinds := []uint16{0, 1, 3, 6, 7, 30023, 10002}
		kinds = commonKinds
	}
	for _, k := range kinds {
		kind := new(types.Uint16)
		kind.Set(k)
		direction := new(types.Letter)
		direction.Set(types.EdgeDirectionAuthor) // Author edges
		// Encode peg|pubkey|kind|direction as the scan prefix; the trailing
		// event serial varies and is what we collect.
		prefix := new(bytes.Buffer)
		if err := indexes.PubkeyEventGraphEnc(authorSerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) {
			return nil, err
		}
		searchPrefix := prefix.Bytes()
		err := d.View(func(txn *badger.Txn) error {
			opts := badger.DefaultIteratorOptions
			// Key-only scan: edge data is fully encoded in the key.
			opts.PrefetchValues = false
			opts.Prefix = searchPrefix
			it := txn.NewIterator(opts)
			defer it.Close()
			for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
				key := it.Item().KeyCopy(nil)
				// Key format: peg(3)|pubkey_serial(5)|kind(2)|direction(1)|event_serial(5) = 16 bytes
				if len(key) != 16 {
					continue
				}
				// Extract event serial (last 5 bytes)
				eventSerial := new(types.Uint40)
				serialReader := bytes.NewReader(key[11:16])
				if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) {
					continue
				}
				eventSerials = append(eventSerials, eventSerial)
			}
			return nil
		})
		if chk.E(err) {
			return nil, err
		}
	}
	return eventSerials, nil
}
// GetFollowsFromPubkeySerial returns the pubkey serials a user follows,
// extracted from the p-tags of the user's most recent kind-3 contact list.
// A missing contact list yields (nil, nil) rather than an error.
func (d *D) GetFollowsFromPubkeySerial(pubkeySerial *types.Uint40) ([]*types.Uint40, error) {
	// Locate the user's kind-3 contact list event.
	contactSerial, err := d.FindEventByAuthorAndKind(pubkeySerial, 3)
	if err != nil {
		log.D.F("GetFollowsFromPubkeySerial: error finding kind-3 for serial %d: %v", pubkeySerial.Get(), err)
		return nil, nil // No kind-3 event found is not an error
	}
	if contactSerial == nil {
		log.T.F("GetFollowsFromPubkeySerial: no kind-3 event found for serial %d", pubkeySerial.Get())
		return nil, nil
	}
	// The p-tags of the contact list are the followed pubkeys.
	follows, err := d.GetPTagsFromEventSerial(contactSerial)
	if err != nil {
		return nil, err
	}
	log.T.F("GetFollowsFromPubkeySerial: found %d follows for serial %d", len(follows), pubkeySerial.Get())
	return follows, nil
}
// GetFollowersOfPubkeySerial returns the pubkey serials of users who follow a
// given pubkey: the deduplicated authors of kind-3 contact lists that p-tag it.
func (d *D) GetFollowersOfPubkeySerial(targetSerial *types.Uint40) ([]*types.Uint40, error) {
	// Kind-3 events p-tagging the target are candidate contact lists.
	contactLists, err := d.GetEventsReferencingPubkey(targetSerial, []uint16{3})
	if err != nil {
		return nil, err
	}
	var followers []*types.Uint40
	dedup := make(map[uint64]bool)
	for _, listSerial := range contactLists {
		// Resolve the contact list's author via the epg author edge.
		author, err := d.GetEventAuthorSerial(listSerial)
		if err != nil {
			log.D.F("GetFollowersOfPubkeySerial: couldn't get author for event %d: %v", listSerial.Get(), err)
			continue
		}
		// Deduplicate: a user may have multiple kind-3 events.
		if dedup[author.Get()] {
			continue
		}
		dedup[author.Get()] = true
		followers = append(followers, author)
	}
	log.T.F("GetFollowersOfPubkeySerial: found %d followers for serial %d", len(followers), targetSerial.Get())
	return followers, nil
}
// GetEventAuthorSerial finds the author pubkey serial for an event.
// Uses the epg (event-pubkey-graph) index with author direction.
// Returns ErrEventNotFound if no author edge exists for the event.
func (d *D) GetEventAuthorSerial(eventSerial *types.Uint40) (*types.Uint40, error) {
	var authorSerial *types.Uint40
	// Build prefix: epg|event_serial
	prefix := new(bytes.Buffer)
	prefix.Write([]byte(indexes.EventPubkeyGraphPrefix))
	if err := eventSerial.MarshalWrite(prefix); chk.E(err) {
		return nil, err
	}
	searchPrefix := prefix.Bytes()
	err := d.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		// Key-only scan: edge data is fully encoded in the key.
		opts.PrefetchValues = false
		opts.Prefix = searchPrefix
		it := txn.NewIterator(opts)
		defer it.Close()
		for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
			key := it.Item().KeyCopy(nil)
			// Decode key: epg(3)|event_serial(5)|pubkey_serial(5)|kind(2)|direction(1)
			if len(key) != 16 {
				continue
			}
			// Check direction - we want author (0)
			direction := key[15]
			if direction != types.EdgeDirectionAuthor {
				continue
			}
			// Extract pubkey serial (bytes 8-12)
			authorSerial = new(types.Uint40)
			serialReader := bytes.NewReader(key[8:13])
			if err := authorSerial.UnmarshalRead(serialReader); chk.E(err) {
				continue
			}
			return nil // Found the author
		}
		// Loop exhausted without an author edge: report the event as missing.
		return ErrEventNotFound
	})
	return authorSerial, err
}
// PubkeyHexToSerial converts a pubkey hex string to its serial, if it exists.
// Returns an error for malformed hex, a wrong-length key, or an unknown pubkey.
func (d *D) PubkeyHexToSerial(pubkeyHex string) (*types.Uint40, error) {
	raw, err := hex.Dec(pubkeyHex)
	if err != nil {
		return nil, err
	}
	if len(raw) != 32 {
		return nil, errors.New("invalid pubkey length")
	}
	return d.GetPubkeySerial(raw)
}
// EventIDHexToSerial converts an event ID hex string to its serial, if it exists.
// Returns an error if the hex is malformed, the decoded ID is not 32 bytes,
// or the event is not in the database.
func (d *D) EventIDHexToSerial(eventIDHex string) (*types.Uint40, error) {
	decoded, decErr := hex.Dec(eventIDHex)
	switch {
	case decErr != nil:
		return nil, decErr
	case len(decoded) != 32:
		return nil, errors.New("invalid event ID length")
	}
	return d.GetSerialById(decoded)
}

View File

@@ -0,0 +1,547 @@
//go:build !(js && wasm)
package database
import (
"context"
"testing"
"git.mleku.dev/mleku/nostr/encoders/event"
"git.mleku.dev/mleku/nostr/encoders/hex"
"git.mleku.dev/mleku/nostr/encoders/tag"
)
// TestGetPTagsFromEventSerial verifies that saving an event carrying two
// p-tags records graph edges that GetPTagsFromEventSerial can read back, and
// that each returned serial resolves to one of the tagged pubkeys.
func TestGetPTagsFromEventSerial(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create an author pubkey
	authorPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	// Create p-tag target pubkeys
	target1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	target2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	// Create event with p-tags; ID and Sig are synthetic placeholders, the
	// test only exercises index writes, not signature verification.
	eventID := make([]byte, 32)
	eventID[0] = 0x10
	eventSig := make([]byte, 64)
	eventSig[0] = 0x10
	ev := &event.E{
		ID:        eventID,
		Pubkey:    authorPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("Test event with p-tags"),
		Sig:       eventSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(target1)),
			tag.NewFromAny("p", hex.Enc(target2)),
		),
	}
	_, err = db.SaveEvent(ctx, ev)
	if err != nil {
		t.Fatalf("Failed to save event: %v", err)
	}
	// Get the event serial
	eventSerial, err := db.GetSerialById(eventID)
	if err != nil {
		t.Fatalf("Failed to get event serial: %v", err)
	}
	// Get p-tags from event serial
	ptagSerials, err := db.GetPTagsFromEventSerial(eventSerial)
	if err != nil {
		t.Fatalf("GetPTagsFromEventSerial failed: %v", err)
	}
	// Should have 2 p-tags
	if len(ptagSerials) != 2 {
		t.Errorf("Expected 2 p-tag serials, got %d", len(ptagSerials))
	}
	// Verify each serial round-trips to one of the two tagged pubkeys.
	for _, serial := range ptagSerials {
		pubkey, err := db.GetPubkeyBySerial(serial)
		if err != nil {
			t.Errorf("Failed to get pubkey for serial: %v", err)
			continue
		}
		pubkeyHex := hex.Enc(pubkey)
		if pubkeyHex != hex.Enc(target1) && pubkeyHex != hex.Enc(target2) {
			t.Errorf("Unexpected pubkey: %s", pubkeyHex)
		}
	}
}
// TestGetETagsFromEventSerial verifies that saving a reply whose e-tag points
// at an already-stored parent creates an event-graph edge readable via
// GetETagsFromEventSerial, and that the edge resolves back to the parent ID.
func TestGetETagsFromEventSerial(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a parent event (must be saved first so the e-tag edge can be
	// resolved to an existing serial at reply-save time)
	parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	parentID := make([]byte, 32)
	parentID[0] = 0x10
	parentSig := make([]byte, 64)
	parentSig[0] = 0x10
	parentEvent := &event.E{
		ID:        parentID,
		Pubkey:    parentPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("Parent post"),
		Sig:       parentSig,
		Tags:      &tag.S{},
	}
	_, err = db.SaveEvent(ctx, parentEvent)
	if err != nil {
		t.Fatalf("Failed to save parent event: %v", err)
	}
	// Create a reply event with e-tag
	replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	replyID := make([]byte, 32)
	replyID[0] = 0x20
	replySig := make([]byte, 64)
	replySig[0] = 0x20
	replyEvent := &event.E{
		ID:        replyID,
		Pubkey:    replyPubkey,
		CreatedAt: 1234567891,
		Kind:      1,
		Content:   []byte("Reply"),
		Sig:       replySig,
		Tags: tag.NewS(
			tag.NewFromAny("e", hex.Enc(parentID)),
		),
	}
	_, err = db.SaveEvent(ctx, replyEvent)
	if err != nil {
		t.Fatalf("Failed to save reply event: %v", err)
	}
	// Get e-tags from reply
	// NOTE(review): the lookup error is discarded; a failed lookup surfaces
	// only as a nil serial passed to the call below.
	replySerial, _ := db.GetSerialById(replyID)
	etagSerials, err := db.GetETagsFromEventSerial(replySerial)
	if err != nil {
		t.Fatalf("GetETagsFromEventSerial failed: %v", err)
	}
	if len(etagSerials) != 1 {
		t.Errorf("Expected 1 e-tag serial, got %d", len(etagSerials))
	}
	// Verify the target event
	if len(etagSerials) > 0 {
		targetEventID, err := db.GetEventIdBySerial(etagSerials[0])
		if err != nil {
			t.Fatalf("Failed to get event ID from serial: %v", err)
		}
		if hex.Enc(targetEventID) != hex.Enc(parentID) {
			t.Errorf("Expected parent ID, got %s", hex.Enc(targetEventID))
		}
	}
}
// TestGetReferencingEvents verifies reverse e-tag edge queries: three events
// (two kind-1 replies and one kind-7 reaction) reference a parent, and
// GetReferencingEvents must return all of them with a nil kind filter, and
// the correct subsets when filtered by kind 1 or kind 7.
func TestGetReferencingEvents(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a parent event
	parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	parentID := make([]byte, 32)
	parentID[0] = 0x10
	parentSig := make([]byte, 64)
	parentSig[0] = 0x10
	parentEvent := &event.E{
		ID:        parentID,
		Pubkey:    parentPubkey,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("Parent post"),
		Sig:       parentSig,
		Tags:      &tag.S{},
	}
	_, err = db.SaveEvent(ctx, parentEvent)
	if err != nil {
		t.Fatalf("Failed to save parent event: %v", err)
	}
	// Create multiple replies and reactions; the third event (i == 2) is a
	// kind-7 reaction so the kind filters below have distinct subsets.
	for i := 0; i < 3; i++ {
		replyPubkey := make([]byte, 32)
		replyPubkey[0] = byte(0x20 + i)
		replyID := make([]byte, 32)
		replyID[0] = byte(0x30 + i)
		replySig := make([]byte, 64)
		replySig[0] = byte(0x30 + i)
		var evKind uint16 = 1 // Reply
		if i == 2 {
			evKind = 7 // Reaction
		}
		replyEvent := &event.E{
			ID:        replyID,
			Pubkey:    replyPubkey,
			CreatedAt: int64(1234567891 + i),
			Kind:      evKind,
			Content:   []byte("Response"),
			Sig:       replySig,
			Tags: tag.NewS(
				tag.NewFromAny("e", hex.Enc(parentID)),
			),
		}
		_, err = db.SaveEvent(ctx, replyEvent)
		if err != nil {
			t.Fatalf("Failed to save reply %d: %v", i, err)
		}
	}
	// Get parent serial
	parentSerial, _ := db.GetSerialById(parentID)
	// Test without kind filter
	refs, err := db.GetReferencingEvents(parentSerial, nil)
	if err != nil {
		t.Fatalf("GetReferencingEvents failed: %v", err)
	}
	if len(refs) != 3 {
		t.Errorf("Expected 3 referencing events, got %d", len(refs))
	}
	// Test with kind filter (only replies)
	refs, err = db.GetReferencingEvents(parentSerial, []uint16{1})
	if err != nil {
		t.Fatalf("GetReferencingEvents with kind filter failed: %v", err)
	}
	if len(refs) != 2 {
		t.Errorf("Expected 2 kind-1 referencing events, got %d", len(refs))
	}
	// Test with kind filter (only reactions)
	refs, err = db.GetReferencingEvents(parentSerial, []uint16{7})
	if err != nil {
		t.Fatalf("GetReferencingEvents with kind 7 filter failed: %v", err)
	}
	if len(refs) != 1 {
		t.Errorf("Expected 1 kind-7 referencing event, got %d", len(refs))
	}
}
// TestGetFollowsFromPubkeySerial verifies that saving a kind-3 contact list
// with three p-tags makes all three follow targets discoverable through
// GetFollowsFromPubkeySerial, with no extras and no omissions.
func TestGetFollowsFromPubkeySerial(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create author and their follows
	authorPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	follow1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	follow2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	follow3, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000004")
	// Create kind-3 contact list
	eventID := make([]byte, 32)
	eventID[0] = 0x10
	eventSig := make([]byte, 64)
	eventSig[0] = 0x10
	contactList := &event.E{
		ID:        eventID,
		Pubkey:    authorPubkey,
		CreatedAt: 1234567890,
		Kind:      3,
		Content:   []byte(""),
		Sig:       eventSig,
		Tags: tag.NewS(
			tag.NewFromAny("p", hex.Enc(follow1)),
			tag.NewFromAny("p", hex.Enc(follow2)),
			tag.NewFromAny("p", hex.Enc(follow3)),
		),
	}
	_, err = db.SaveEvent(ctx, contactList)
	if err != nil {
		t.Fatalf("Failed to save contact list: %v", err)
	}
	// Get author serial
	authorSerial, err := db.GetPubkeySerial(authorPubkey)
	if err != nil {
		t.Fatalf("Failed to get author serial: %v", err)
	}
	// Get follows
	follows, err := db.GetFollowsFromPubkeySerial(authorSerial)
	if err != nil {
		t.Fatalf("GetFollowsFromPubkeySerial failed: %v", err)
	}
	if len(follows) != 3 {
		t.Errorf("Expected 3 follows, got %d", len(follows))
	}
	// Verify the follows are correct: check off each expected pubkey as it
	// is seen, then confirm none are missing.
	expectedFollows := map[string]bool{
		hex.Enc(follow1): false,
		hex.Enc(follow2): false,
		hex.Enc(follow3): false,
	}
	for _, serial := range follows {
		pubkey, err := db.GetPubkeyBySerial(serial)
		if err != nil {
			t.Errorf("Failed to get pubkey from serial: %v", err)
			continue
		}
		pkHex := hex.Enc(pubkey)
		if _, exists := expectedFollows[pkHex]; exists {
			expectedFollows[pkHex] = true
		} else {
			t.Errorf("Unexpected follow: %s", pkHex)
		}
	}
	for pk, found := range expectedFollows {
		if !found {
			t.Errorf("Expected follow not found: %s", pk)
		}
	}
}
// TestGraphResult exercises the in-memory GraphResult accumulator: pubkeys
// are recorded only at the depth where first discovered, duplicates are
// rejected, and ToDepthArrays groups them by depth (index 0 = depth 1).
func TestGraphResult(t *testing.T) {
	result := NewGraphResult()
	// Add pubkeys at different depths
	result.AddPubkeyAtDepth("pubkey1", 1)
	result.AddPubkeyAtDepth("pubkey2", 1)
	result.AddPubkeyAtDepth("pubkey3", 2)
	result.AddPubkeyAtDepth("pubkey4", 2)
	result.AddPubkeyAtDepth("pubkey5", 3)
	// Try to add duplicate (already present at depth 1; must be rejected
	// even though a different depth is given)
	added := result.AddPubkeyAtDepth("pubkey1", 2)
	if added {
		t.Error("Should not add duplicate pubkey")
	}
	// Verify counts
	if result.TotalPubkeys != 5 {
		t.Errorf("Expected 5 total pubkeys, got %d", result.TotalPubkeys)
	}
	// Verify depth tracking
	if result.GetPubkeyDepth("pubkey1") != 1 {
		t.Errorf("pubkey1 should be at depth 1")
	}
	if result.GetPubkeyDepth("pubkey3") != 2 {
		t.Errorf("pubkey3 should be at depth 2")
	}
	// Verify HasPubkey
	if !result.HasPubkey("pubkey1") {
		t.Error("Should have pubkey1")
	}
	if result.HasPubkey("nonexistent") {
		t.Error("Should not have nonexistent pubkey")
	}
	// Verify ToDepthArrays: one array per depth, 1-indexed depths mapped to
	// 0-indexed array slots
	arrays := result.ToDepthArrays()
	if len(arrays) != 3 {
		t.Errorf("Expected 3 depth arrays, got %d", len(arrays))
	}
	if len(arrays[0]) != 2 {
		t.Errorf("Expected 2 pubkeys at depth 1, got %d", len(arrays[0]))
	}
	if len(arrays[1]) != 2 {
		t.Errorf("Expected 2 pubkeys at depth 2, got %d", len(arrays[1]))
	}
	if len(arrays[2]) != 1 {
		t.Errorf("Expected 1 pubkey at depth 3, got %d", len(arrays[2]))
	}
}
// TestGraphResultRefs verifies inbound-reference aggregation: three kind-7
// refs to the same target collapse into one aggregation entry with a ref
// count of 3 and the correct target event ID.
func TestGraphResultRefs(t *testing.T) {
	result := NewGraphResult()
	// Add some pubkeys
	result.AddPubkeyAtDepth("pubkey1", 1)
	result.AddEventAtDepth("event1", 1)
	// Add inbound refs (kind 7 reactions)
	result.AddInboundRef(7, "event1", "reaction1")
	result.AddInboundRef(7, "event1", "reaction2")
	result.AddInboundRef(7, "event1", "reaction3")
	// Get sorted refs
	refs := result.GetInboundRefsSorted(7)
	if len(refs) != 1 {
		t.Fatalf("Expected 1 aggregation, got %d", len(refs))
	}
	if refs[0].RefCount != 3 {
		t.Errorf("Expected 3 refs, got %d", refs[0].RefCount)
	}
	if refs[0].TargetEventID != "event1" {
		t.Errorf("Expected event1, got %s", refs[0].TargetEventID)
	}
}
// TestGetFollowersOfPubkeySerial verifies the reverse follow query: two
// distinct authors publish kind-3 lists p-tagging the same target, and
// GetFollowersOfPubkeySerial must return exactly those two authors.
func TestGetFollowersOfPubkeySerial(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create target pubkey (the one being followed)
	targetPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	// Create followers
	follower1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
	follower2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
	// Create kind-3 contact lists for followers (unique IDs/sigs per loop
	// index so the two events don't collide)
	for i, followerPubkey := range [][]byte{follower1, follower2} {
		eventID := make([]byte, 32)
		eventID[0] = byte(0x10 + i)
		eventSig := make([]byte, 64)
		eventSig[0] = byte(0x10 + i)
		contactList := &event.E{
			ID:        eventID,
			Pubkey:    followerPubkey,
			CreatedAt: int64(1234567890 + i),
			Kind:      3,
			Content:   []byte(""),
			Sig:       eventSig,
			Tags: tag.NewS(
				tag.NewFromAny("p", hex.Enc(targetPubkey)),
			),
		}
		_, err = db.SaveEvent(ctx, contactList)
		if err != nil {
			t.Fatalf("Failed to save contact list %d: %v", i, err)
		}
	}
	// Get target serial
	targetSerial, err := db.GetPubkeySerial(targetPubkey)
	if err != nil {
		t.Fatalf("Failed to get target serial: %v", err)
	}
	// Get followers
	followers, err := db.GetFollowersOfPubkeySerial(targetSerial)
	if err != nil {
		t.Fatalf("GetFollowersOfPubkeySerial failed: %v", err)
	}
	if len(followers) != 2 {
		t.Errorf("Expected 2 followers, got %d", len(followers))
	}
	// Verify the followers: check off each expected pubkey, then confirm
	// none are missing.
	expectedFollowers := map[string]bool{
		hex.Enc(follower1): false,
		hex.Enc(follower2): false,
	}
	for _, serial := range followers {
		pubkey, err := db.GetPubkeyBySerial(serial)
		if err != nil {
			t.Errorf("Failed to get pubkey from serial: %v", err)
			continue
		}
		pkHex := hex.Enc(pubkey)
		if _, exists := expectedFollowers[pkHex]; exists {
			expectedFollowers[pkHex] = true
		} else {
			t.Errorf("Unexpected follower: %s", pkHex)
		}
	}
	for pk, found := range expectedFollowers {
		if !found {
			t.Errorf("Expected follower not found: %s", pk)
		}
	}
}
// TestPubkeyHexToSerial verifies the hex -> serial -> hex round trip for a
// pubkey that was registered in the database by saving an event.
func TestPubkeyHexToSerial(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	db, err := New(ctx, cancel, t.TempDir(), "info")
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create a pubkey by saving an event (pubkey serials are only assigned
	// as a side effect of event storage)
	pubkeyBytes, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
	eventID := make([]byte, 32)
	eventID[0] = 0x10
	eventSig := make([]byte, 64)
	eventSig[0] = 0x10
	ev := &event.E{
		ID:        eventID,
		Pubkey:    pubkeyBytes,
		CreatedAt: 1234567890,
		Kind:      1,
		Content:   []byte("Test"),
		Sig:       eventSig,
		Tags:      &tag.S{},
	}
	_, err = db.SaveEvent(ctx, ev)
	if err != nil {
		t.Fatalf("Failed to save event: %v", err)
	}
	// Convert hex to serial
	pubkeyHex := hex.Enc(pubkeyBytes)
	serial, err := db.PubkeyHexToSerial(pubkeyHex)
	if err != nil {
		t.Fatalf("PubkeyHexToSerial failed: %v", err)
	}
	if serial == nil {
		t.Fatal("Expected non-nil serial")
	}
	// Convert back and verify
	backToHex, err := db.GetPubkeyHexFromSerial(serial)
	if err != nil {
		t.Fatalf("GetPubkeyHexFromSerial failed: %v", err)
	}
	if backToHex != pubkeyHex {
		t.Errorf("Round-trip failed: %s != %s", backToHex, pubkeyHex)
	}
}

View File

@@ -85,6 +85,10 @@ const (
// Compact event storage indexes
SerialEventIdPrefix = I("sei") // event serial -> full 32-byte event ID
CompactEventPrefix = I("cmp") // compact event storage with serial references
// Event-to-event graph indexes (for e-tag references)
EventEventGraphPrefix = I("eeg") // source event serial -> target event serial (outbound e-tags)
GraphEventEventPrefix = I("gee") // target event serial -> source event serial (reverse e-tags)
)
// Prefix returns the three byte human-readable prefixes that go in front of
@@ -142,6 +146,11 @@ func Prefix(prf int) (i I) {
return SerialEventIdPrefix
case CompactEvent:
return CompactEventPrefix
case EventEventGraph:
return EventEventGraphPrefix
case GraphEventEvent:
return GraphEventEventPrefix
}
return
}
@@ -205,6 +214,11 @@ func Identify(r io.Reader) (i int, err error) {
i = SerialEventId
case CompactEventPrefix:
i = CompactEvent
case EventEventGraphPrefix:
i = EventEventGraph
case GraphEventEventPrefix:
i = GraphEventEvent
}
return
}
@@ -655,3 +669,38 @@ func CompactEventEnc(ser *types.Uint40) (enc *T) {
return New(NewPrefix(CompactEvent), ser)
}
func CompactEventDec(ser *types.Uint40) (enc *T) { return New(NewPrefix(), ser) }
// EventEventGraph creates a bidirectional graph edge between events via e-tags.
// This stores source_event_serial -> target_event_serial relationships with event kind and direction.
// Used for thread traversal and finding replies/reactions/reposts to events.
// Direction: 0=outbound (this event references target)
//
// 3 prefix|5 source event serial|5 target event serial|2 kind|1 direction
var EventEventGraph = next()
func EventEventGraphVars() (srcSer *types.Uint40, tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter) {
return new(types.Uint40), new(types.Uint40), new(types.Uint16), new(types.Letter)
}
func EventEventGraphEnc(srcSer *types.Uint40, tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter) (enc *T) {
return New(NewPrefix(EventEventGraph), srcSer, tgtSer, kind, direction)
}
func EventEventGraphDec(srcSer *types.Uint40, tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter) (enc *T) {
return New(NewPrefix(), srcSer, tgtSer, kind, direction)
}
// GraphEventEvent creates the reverse edge: target_event_serial -> source_event_serial with kind and direction.
// This enables querying all events that reference a target event (e.g., all replies to a post).
// Direction: 1=inbound (target is referenced by source)
//
// 3 prefix|5 target event serial|2 kind|1 direction|5 source event serial
var GraphEventEvent = next()
func GraphEventEventVars() (tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter, srcSer *types.Uint40) {
return new(types.Uint40), new(types.Uint16), new(types.Letter), new(types.Uint40)
}
func GraphEventEventEnc(tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter, srcSer *types.Uint40) (enc *T) {
return New(NewPrefix(GraphEventEvent), tgtSer, kind, direction, srcSer)
}
func GraphEventEventDec(tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter, srcSer *types.Uint40) (enc *T) {
return New(NewPrefix(), tgtSer, kind, direction, srcSer)
}

View File

@@ -15,6 +15,12 @@ const (
EdgeDirectionPTagIn byte = 2 // Inbound: This pubkey is referenced in event's p-tag
)
// Edge direction constants for event-to-event (e-tag) graph relationships
const (
EdgeDirectionETagOut byte = 0 // Outbound: This event references target event via e-tag
EdgeDirectionETagIn byte = 1 // Inbound: This event is referenced by source event via e-tag
)
type Letter struct {
val byte
}

View File

@@ -349,24 +349,74 @@ func (d *D) SaveEvent(c context.Context, ev *event.E) (
}
// Create event -> pubkey edge (with kind and direction)
keyBuf := new(bytes.Buffer)
if err = indexes.EventPubkeyGraphEnc(ser, pkInfo.serial, eventKind, directionForward).MarshalWrite(keyBuf); chk.E(err) {
epgKeyBuf := new(bytes.Buffer)
if err = indexes.EventPubkeyGraphEnc(ser, pkInfo.serial, eventKind, directionForward).MarshalWrite(epgKeyBuf); chk.E(err) {
return
}
if err = txn.Set(keyBuf.Bytes(), nil); chk.E(err) {
// Make a copy of the key bytes to avoid buffer reuse issues in txn
epgKey := make([]byte, epgKeyBuf.Len())
copy(epgKey, epgKeyBuf.Bytes())
if err = txn.Set(epgKey, nil); chk.E(err) {
return
}
// Create pubkey -> event edge (reverse, with kind and direction for filtering)
keyBuf.Reset()
if err = indexes.PubkeyEventGraphEnc(pkInfo.serial, eventKind, directionReverse, ser).MarshalWrite(keyBuf); chk.E(err) {
pegKeyBuf := new(bytes.Buffer)
if err = indexes.PubkeyEventGraphEnc(pkInfo.serial, eventKind, directionReverse, ser).MarshalWrite(pegKeyBuf); chk.E(err) {
return
}
if err = txn.Set(keyBuf.Bytes(), nil); chk.E(err) {
if err = txn.Set(pegKeyBuf.Bytes(), nil); chk.E(err) {
return
}
}
// Create event-to-event graph edges for e-tags
// This enables thread traversal and finding replies/reactions to events
eTags := ev.Tags.GetAll([]byte("e"))
for _, eTag := range eTags {
if eTag.Len() >= 2 {
// Get event ID from e-tag, handling both binary and hex storage formats
var targetEventID []byte
if targetEventID, err = hex.Dec(string(eTag.ValueHex())); err != nil || len(targetEventID) != 32 {
continue
}
// Look up the target event's serial (if it exists in our database)
var targetSerial *types.Uint40
if targetSerial, err = d.GetSerialById(targetEventID); err != nil {
// Target event not in our database - skip edge creation
// This is normal for replies to events we don't have
err = nil
continue
}
// Create forward edge: source event -> target event (outbound e-tag)
directionOut := new(types.Letter)
directionOut.Set(types.EdgeDirectionETagOut)
eegKeyBuf := new(bytes.Buffer)
if err = indexes.EventEventGraphEnc(ser, targetSerial, eventKind, directionOut).MarshalWrite(eegKeyBuf); chk.E(err) {
return
}
// Make a copy of the key bytes to avoid buffer reuse issues in txn
eegKey := make([]byte, eegKeyBuf.Len())
copy(eegKey, eegKeyBuf.Bytes())
if err = txn.Set(eegKey, nil); chk.E(err) {
return
}
// Create reverse edge: target event -> source event (inbound e-tag)
directionIn := new(types.Letter)
directionIn.Set(types.EdgeDirectionETagIn)
geeKeyBuf := new(bytes.Buffer)
if err = indexes.GraphEventEventEnc(targetSerial, eventKind, directionIn, ser).MarshalWrite(geeKeyBuf); chk.E(err) {
return
}
if err = txn.Set(geeKeyBuf.Bytes(), nil); chk.E(err) {
return
}
}
}
return
},
)

View File

@@ -0,0 +1,202 @@
//go:build !(js && wasm)
// Package graph implements NIP-XX Graph Query protocol support.
// This file contains the executor that runs graph traversal queries.
package graph
import (
"encoding/json"
"strconv"
"time"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"git.mleku.dev/mleku/nostr/encoders/event"
"git.mleku.dev/mleku/nostr/encoders/hex"
"git.mleku.dev/mleku/nostr/encoders/tag"
"git.mleku.dev/mleku/nostr/interfaces/signer"
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k"
)
// Response kinds for graph queries (ephemeral range, relay-signed)
//
// NOTE(review): 39000-39002 fall inside NIP-01's addressable-event range
// (30000-39999), not the ephemeral range (20000-29999) — confirm the
// intended kind numbers against the protocol draft.
const (
	KindGraphFollows  = 39000 // Response for follows/followers queries
	KindGraphMentions = 39001 // Response for mentions queries
	KindGraphThread   = 39002 // Response for thread traversal queries
)
// GraphResultI is the interface that database.GraphResult implements.
// This allows the executor to work with the database result without importing it.
type GraphResultI interface {
	// ToDepthArrays returns discovered pubkeys grouped by discovery depth
	// (index 0 = depth 1).
	ToDepthArrays() [][]string
	// ToEventDepthArrays returns discovered event IDs grouped by depth.
	ToEventDepthArrays() [][]string
	// GetAllPubkeys returns all discovered pubkeys as a flat list.
	GetAllPubkeys() []string
	// GetAllEvents returns all discovered event IDs as a flat list.
	GetAllEvents() []string
	// GetPubkeysByDepth maps depth to the pubkeys first discovered there.
	GetPubkeysByDepth() map[int][]string
	// GetEventsByDepth maps depth to the event IDs first discovered there.
	GetEventsByDepth() map[int][]string
	// GetTotalPubkeys returns the count of unique pubkeys discovered.
	GetTotalPubkeys() int
	// GetTotalEvents returns the count of unique events discovered.
	GetTotalEvents() int
}
// GraphDatabase defines the interface for graph traversal operations.
// This is implemented by the database package. Seed arguments are raw
// 32-byte values (pubkey or event ID), already decoded from hex.
type GraphDatabase interface {
	// TraverseFollows performs BFS traversal of follow graph
	TraverseFollows(seedPubkey []byte, maxDepth int) (GraphResultI, error)
	// TraverseFollowers performs BFS traversal to find followers
	TraverseFollowers(seedPubkey []byte, maxDepth int) (GraphResultI, error)
	// FindMentions finds events mentioning a pubkey, limited to the given kinds
	FindMentions(pubkey []byte, kinds []uint16) (GraphResultI, error)
	// TraverseThread performs BFS traversal of thread structure; direction
	// selects traversal along inbound/outbound e-tag edges
	TraverseThread(seedEventID []byte, maxDepth int, direction string) (GraphResultI, error)
}
// Executor handles graph query execution and response generation.
// Responses are signed with the relay's own identity key.
type Executor struct {
	db          GraphDatabase // backend used for all graph traversals
	relaySigner signer.I      // relay identity key used to sign responses
	relayPubkey []byte        // cached public key derived from relaySigner
}
// NewExecutor creates a new graph query executor.
// The secretKey should be the 32-byte relay identity secret key; it is
// loaded into a fresh signer whose public key is cached on the executor.
func NewExecutor(db GraphDatabase, secretKey []byte) (*Executor, error) {
	sign, err := p8k.New()
	if err != nil {
		return nil, err
	}
	if err = sign.InitSec(secretKey); err != nil {
		return nil, err
	}
	exec := &Executor{
		db:          db,
		relaySigner: sign,
		relayPubkey: sign.Pub(),
	}
	return exec, nil
}
// Execute runs a graph query and returns a relay-signed event with results.
//
// The query is expected to have been validated (see Query.Validate). The
// seed is interpreted as a pubkey for follows/followers/mentions and as an
// event ID for thread queries. Returns ErrInvalidMethod for any method not
// handled below.
func (e *Executor) Execute(q *Query) (*event.E, error) {
	var result GraphResultI
	var err error
	var responseKind uint16
	// Decode seed (hex string to bytes)
	seedBytes, err := hex.Dec(q.Seed)
	if err != nil {
		return nil, err
	}
	// Execute the appropriate traversal
	switch q.Method {
	case "follows":
		responseKind = KindGraphFollows
		result, err = e.db.TraverseFollows(seedBytes, q.Depth)
		if err != nil {
			return nil, err
		}
		log.D.F("graph executor: follows traversal returned %d pubkeys", result.GetTotalPubkeys())
	case "followers":
		responseKind = KindGraphFollows
		result, err = e.db.TraverseFollowers(seedBytes, q.Depth)
		if err != nil {
			return nil, err
		}
		log.D.F("graph executor: followers traversal returned %d pubkeys", result.GetTotalPubkeys())
	case "mentions":
		responseKind = KindGraphMentions
		// Mentions don't use depth traversal, just find direct mentions.
		// Collect kinds from the ref specs, deduplicating across specs and
		// skipping values outside the uint16 range — the previous unchecked
		// uint16(k) conversion silently truncated negative or oversized
		// kinds into bogus filter values.
		var kinds []uint16
		seen := make(map[uint16]bool)
		for _, rs := range q.InboundRefs {
			for _, k := range rs.Kinds {
				if k < 0 || k > 0xFFFF {
					continue // not representable as an event kind
				}
				kk := uint16(k)
				if seen[kk] {
					continue
				}
				seen[kk] = true
				kinds = append(kinds, kk)
			}
		}
		if len(kinds) == 0 {
			kinds = []uint16{1} // Default to kind 1 (notes)
		}
		result, err = e.db.FindMentions(seedBytes, kinds)
		if err != nil {
			return nil, err
		}
		log.D.F("graph executor: mentions query returned %d events", result.GetTotalEvents())
	case "thread":
		responseKind = KindGraphThread
		result, err = e.db.TraverseThread(seedBytes, q.Depth, "both")
		if err != nil {
			return nil, err
		}
		log.D.F("graph executor: thread traversal returned %d events", result.GetTotalEvents())
	default:
		return nil, ErrInvalidMethod
	}
	// Generate response event
	return e.generateResponse(q, result, responseKind)
}
// generateResponse creates a relay-signed event containing the query results.
// Pubkey-based methods (follows/followers) serialize pubkeys-by-depth;
// event-based methods (mentions/thread) serialize events-by-depth. The
// query parameters are echoed back as tags for client correlation.
func (e *Executor) generateResponse(q *Query, result GraphResultI, responseKind uint16) (*event.E, error) {
	var content ResponseContent
	switch q.Method {
	case "follows", "followers":
		// Pubkey-based queries report pubkeys grouped by discovery depth.
		content.PubkeysByDepth = result.ToDepthArrays()
		content.TotalPubkeys = result.GetTotalPubkeys()
	default:
		// Event-based queries report event IDs grouped by depth.
		content.EventsByDepth = result.ToEventDepthArrays()
		content.TotalEvents = result.GetTotalEvents()
	}
	payload, err := json.Marshal(content)
	if err != nil {
		return nil, err
	}
	// Echo the query parameters so the client can match the response.
	responseTags := tag.NewS(
		tag.NewFromAny("method", q.Method),
		tag.NewFromAny("seed", q.Seed),
		tag.NewFromAny("depth", strconv.Itoa(q.Depth)),
	)
	ev := &event.E{
		Kind:      responseKind,
		CreatedAt: time.Now().Unix(),
		Tags:      responseTags,
		Content:   payload,
	}
	// Sign with relay identity
	if err = ev.Sign(e.relaySigner); chk.E(err) {
		return nil, err
	}
	return ev, nil
}
// ResponseContent is the JSON structure for graph query responses.
// All fields carry omitempty, so pubkey-based responses omit the event
// fields and vice versa.
type ResponseContent struct {
	// PubkeysByDepth contains arrays of pubkeys at each depth (1-indexed)
	// Each pubkey appears ONLY at the depth where it was first discovered.
	PubkeysByDepth [][]string `json:"pubkeys_by_depth,omitempty"`
	// EventsByDepth contains arrays of event IDs at each depth (1-indexed)
	EventsByDepth [][]string `json:"events_by_depth,omitempty"`
	// TotalPubkeys is the total count of unique pubkeys discovered
	TotalPubkeys int `json:"total_pubkeys,omitempty"`
	// TotalEvents is the total count of unique events discovered
	TotalEvents int `json:"total_events,omitempty"`
}

183
pkg/protocol/graph/query.go Normal file
View File

@@ -0,0 +1,183 @@
// Package graph implements NIP-XX Graph Query protocol support.
// It provides types and functions for parsing and validating graph traversal queries.
package graph
import (
"encoding/json"
"errors"
"git.mleku.dev/mleku/nostr/encoders/filter"
)
// Query represents a graph traversal query from a _graph filter extension.
type Query struct {
	// Method is the traversal method: "follows", "followers", "mentions", "thread"
	Method string `json:"method"`
	// Seed is the starting point for traversal (pubkey hex or event ID hex)
	Seed string `json:"seed"`
	// Depth is the maximum traversal depth (1-16, default: 1)
	Depth int `json:"depth,omitempty"`
	// InboundRefs specifies which inbound references to collect
	// (events that reference discovered events via e-tags)
	InboundRefs []RefSpec `json:"inbound_refs,omitempty"`
	// OutboundRefs specifies which outbound references to collect
	// (events referenced by discovered events via e-tags)
	OutboundRefs []RefSpec `json:"outbound_refs,omitempty"`
}

// RefSpec specifies which event references to include in results.
type RefSpec struct {
	// Kinds is the list of event kinds to match (OR semantics within this spec)
	Kinds []int `json:"kinds"`
	// FromDepth specifies the minimum depth at which to collect refs (default: 0)
	// 0 = include refs from seed itself
	// 1 = start from first-hop connections
	FromDepth int `json:"from_depth,omitempty"`
}

// Validation errors
var (
	ErrMissingMethod     = errors.New("_graph.method is required")
	ErrInvalidMethod     = errors.New("_graph.method must be one of: follows, followers, mentions, thread")
	ErrMissingSeed       = errors.New("_graph.seed is required")
	ErrInvalidSeed       = errors.New("_graph.seed must be a 64-character hex string")
	ErrDepthTooHigh      = errors.New("_graph.depth cannot exceed 16")
	ErrEmptyRefSpecKinds = errors.New("ref spec kinds array cannot be empty")
)

// Valid method names
var validMethods = map[string]bool{
	"follows":   true,
	"followers": true,
	"mentions":  true,
	"thread":    true,
}

// Validate checks the query for correctness and applies defaults: a depth
// below 1 is raised to 1, depths above 16 are rejected, and the seed must
// be exactly 64 hex characters.
func (q *Query) Validate() error {
	switch {
	case q.Method == "":
		return ErrMissingMethod
	case !validMethods[q.Method]:
		return ErrInvalidMethod
	case q.Seed == "":
		return ErrMissingSeed
	case !isHex64(q.Seed):
		return ErrInvalidSeed
	}
	// Apply depth defaults and limits
	if q.Depth < 1 {
		q.Depth = 1
	}
	if q.Depth > 16 {
		return ErrDepthTooHigh
	}
	// Every ref spec, inbound or outbound, must name at least one kind.
	if err := checkRefSpecs(q.InboundRefs); err != nil {
		return err
	}
	return checkRefSpecs(q.OutboundRefs)
}

// isHex64 reports whether s is exactly 64 hexadecimal characters
// (either case accepted).
func isHex64(s string) bool {
	if len(s) != 64 {
		return false
	}
	for _, c := range s {
		digit := c >= '0' && c <= '9'
		lower := c >= 'a' && c <= 'f'
		upper := c >= 'A' && c <= 'F'
		if !digit && !lower && !upper {
			return false
		}
	}
	return true
}

// checkRefSpecs returns ErrEmptyRefSpecKinds if any spec has no kinds.
func checkRefSpecs(specs []RefSpec) error {
	for _, spec := range specs {
		if len(spec.Kinds) == 0 {
			return ErrEmptyRefSpecKinds
		}
	}
	return nil
}

// HasInboundRefs returns true if the query includes inbound reference collection.
func (q *Query) HasInboundRefs() bool {
	return len(q.InboundRefs) != 0
}

// HasOutboundRefs returns true if the query includes outbound reference collection.
func (q *Query) HasOutboundRefs() bool {
	return len(q.OutboundRefs) != 0
}

// HasRefs returns true if the query includes any reference collection.
func (q *Query) HasRefs() bool {
	return q.HasInboundRefs() || q.HasOutboundRefs()
}

// kindsAtDepth aggregates the kinds of every spec whose FromDepth is at or
// below the given depth into a set.
func kindsAtDepth(specs []RefSpec, depth int) map[int]bool {
	collected := make(map[int]bool)
	for _, spec := range specs {
		if spec.FromDepth > depth {
			continue
		}
		for _, kind := range spec.Kinds {
			collected[kind] = true
		}
	}
	return collected
}

// InboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
// It aggregates all RefSpecs where from_depth <= depth.
func (q *Query) InboundKindsAtDepth(depth int) map[int]bool {
	return kindsAtDepth(q.InboundRefs, depth)
}

// OutboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
func (q *Query) OutboundKindsAtDepth(depth int) map[int]bool {
	return kindsAtDepth(q.OutboundRefs, depth)
}
// ExtractFromFilter checks if a filter has a _graph extension and parses it.
// Returns nil if no _graph field is present.
// Returns an error if _graph is present but fails to unmarshal or validate.
func ExtractFromFilter(f *filter.F) (*Query, error) {
	if f == nil || f.Extra == nil {
		return nil, nil
	}
	raw, present := f.Extra["_graph"]
	if !present {
		return nil, nil
	}
	parsed := new(Query)
	if err := json.Unmarshal(raw, parsed); err != nil {
		return nil, err
	}
	if err := parsed.Validate(); err != nil {
		return nil, err
	}
	return parsed, nil
}
// IsGraphQuery returns true if the filter contains a _graph extension.
// This is a quick presence check that does not parse or validate the query.
func IsGraphQuery(f *filter.F) bool {
	if f == nil || f.Extra == nil {
		return false
	}
	_, present := f.Extra["_graph"]
	return present
}

View File

@@ -0,0 +1,397 @@
package graph
import (
"testing"
"git.mleku.dev/mleku/nostr/encoders/filter"
)
// TestQueryValidate exercises Query.Validate across valid queries for every
// supported method, queries with reference specs, and each failure mode
// (missing/invalid method, missing/short/non-hex seed, excessive depth,
// empty ref-spec kinds), asserting the expected sentinel error or nil.
func TestQueryValidate(t *testing.T) {
	// 64 lowercase hex characters: a well-formed seed value.
	validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
	tests := []struct {
		name    string
		query   Query
		wantErr error
	}{
		{
			name: "valid follows query",
			query: Query{
				Method: "follows",
				Seed:   validSeed,
				Depth:  2,
			},
			wantErr: nil,
		},
		{
			name: "valid followers query",
			query: Query{
				Method: "followers",
				Seed:   validSeed,
			},
			wantErr: nil,
		},
		{
			name: "valid mentions query",
			query: Query{
				Method: "mentions",
				Seed:   validSeed,
				Depth:  1,
			},
			wantErr: nil,
		},
		{
			name: "valid thread query",
			query: Query{
				Method: "thread",
				Seed:   validSeed,
				Depth:  10,
			},
			wantErr: nil,
		},
		{
			name: "valid query with inbound refs",
			query: Query{
				Method: "follows",
				Seed:   validSeed,
				Depth:  2,
				InboundRefs: []RefSpec{
					{Kinds: []int{7}, FromDepth: 1},
				},
			},
			wantErr: nil,
		},
		{
			name: "valid query with multiple ref specs",
			query: Query{
				Method: "follows",
				Seed:   validSeed,
				InboundRefs: []RefSpec{
					{Kinds: []int{7}, FromDepth: 1},
					{Kinds: []int{6}, FromDepth: 1},
				},
				OutboundRefs: []RefSpec{
					{Kinds: []int{1}, FromDepth: 0},
				},
			},
			wantErr: nil,
		},
		{
			name:    "missing method",
			query:   Query{Seed: validSeed},
			wantErr: ErrMissingMethod,
		},
		{
			name: "invalid method",
			query: Query{
				Method: "invalid",
				Seed:   validSeed,
			},
			wantErr: ErrInvalidMethod,
		},
		{
			name: "missing seed",
			query: Query{
				Method: "follows",
			},
			wantErr: ErrMissingSeed,
		},
		{
			name: "seed too short",
			query: Query{
				Method: "follows",
				Seed:   "abc123",
			},
			wantErr: ErrInvalidSeed,
		},
		{
			// Trailing 'g' is outside the hex alphabet.
			name: "seed with invalid characters",
			query: Query{
				Method: "follows",
				Seed:   "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg",
			},
			wantErr: ErrInvalidSeed,
		},
		{
			// Depth 17 exceeds the maximum (16, per ErrDepthTooHigh).
			name: "depth too high",
			query: Query{
				Method: "follows",
				Seed:   validSeed,
				Depth:  17,
			},
			wantErr: ErrDepthTooHigh,
		},
		{
			name: "empty ref spec kinds",
			query: Query{
				Method: "follows",
				Seed:   validSeed,
				InboundRefs: []RefSpec{
					{Kinds: []int{}, FromDepth: 1},
				},
			},
			wantErr: ErrEmptyRefSpecKinds,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.query.Validate()
			if tt.wantErr == nil {
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
			} else {
				// Sentinel errors are compared by identity.
				if err != tt.wantErr {
					t.Errorf("error = %v, want %v", err, tt.wantErr)
				}
			}
		})
	}
}
// TestQueryDefaults checks that Validate fills in the default depth of 1
// when the caller leaves Depth at its zero value.
func TestQueryDefaults(t *testing.T) {
	q := Query{
		Method: "follows",
		Seed:   "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
		// Depth deliberately left at 0; Validate should default it to 1.
	}
	if err := q.Validate(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := q.Depth; got != 1 {
		t.Errorf("Depth = %d, want 1 (default)", got)
	}
}
// TestKindsAtDepth verifies that InboundKindsAtDepth and OutboundKindsAtDepth
// return the union of kinds from all ref specs whose FromDepth is at or below
// the queried depth, and exclude specs that only activate deeper.
func TestKindsAtDepth(t *testing.T) {
	q := Query{
		Method: "follows",
		Seed:   "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
		Depth:  3,
		InboundRefs: []RefSpec{
			{Kinds: []int{7}, FromDepth: 0},     // From seed
			{Kinds: []int{6, 16}, FromDepth: 1}, // From depth 1
			{Kinds: []int{9735}, FromDepth: 2},  // From depth 2
		},
		OutboundRefs: []RefSpec{
			{Kinds: []int{1}, FromDepth: 1},
		},
	}
	// Test inbound kinds at depth 0: only the FromDepth-0 spec applies.
	kinds0 := q.InboundKindsAtDepth(0)
	if !kinds0[7] || kinds0[6] || kinds0[9735] {
		t.Errorf("InboundKindsAtDepth(0) = %v, want only kind 7", kinds0)
	}
	// Test inbound kinds at depth 1: FromDepth 0 and 1 specs apply.
	kinds1 := q.InboundKindsAtDepth(1)
	if !kinds1[7] || !kinds1[6] || !kinds1[16] || kinds1[9735] {
		t.Errorf("InboundKindsAtDepth(1) = %v, want kinds 7, 6, 16", kinds1)
	}
	// Test inbound kinds at depth 2: all three specs apply.
	kinds2 := q.InboundKindsAtDepth(2)
	if !kinds2[7] || !kinds2[6] || !kinds2[16] || !kinds2[9735] {
		t.Errorf("InboundKindsAtDepth(2) = %v, want all kinds", kinds2)
	}
	// Test outbound kinds at depth 0: the only outbound spec starts at 1.
	outKinds0 := q.OutboundKindsAtDepth(0)
	if len(outKinds0) != 0 {
		t.Errorf("OutboundKindsAtDepth(0) = %v, want empty", outKinds0)
	}
	// Test outbound kinds at depth 1
	outKinds1 := q.OutboundKindsAtDepth(1)
	if !outKinds1[1] {
		t.Errorf("OutboundKindsAtDepth(1) = %v, want kind 1", outKinds1)
	}
}
// TestExtractFromFilter round-trips raw filter JSON through filter.F and
// checks that ExtractFromFilter parses the _graph extension only when it is
// present and well-formed, and surfaces validation errors otherwise.
func TestExtractFromFilter(t *testing.T) {
	tests := []struct {
		name       string
		filterJSON string
		wantQuery  bool // expect a non-nil *Query on success
		wantErr    bool // expect ExtractFromFilter to fail
	}{
		{
			name:       "filter with valid graph query",
			filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":2}}`,
			wantQuery:  true,
			wantErr:    false,
		},
		{
			name:       "filter without graph query",
			filterJSON: `{"kinds":[1,7]}`,
			wantQuery:  false,
			wantErr:    false,
		},
		{
			name:       "filter with invalid graph query (missing method)",
			filterJSON: `{"kinds":[1],"_graph":{"seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}}`,
			wantQuery:  false,
			wantErr:    true,
		},
		{
			name:       "filter with complex graph query",
			filterJSON: `{"kinds":[0],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":3,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`,
			wantQuery:  true,
			wantErr:    false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			f := &filter.F{}
			_, err := f.Unmarshal([]byte(tt.filterJSON))
			if err != nil {
				t.Fatalf("failed to unmarshal filter: %v", err)
			}
			q, err := ExtractFromFilter(f)
			if tt.wantErr {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}
			if tt.wantQuery && q == nil {
				t.Error("expected query, got nil")
			}
			if !tt.wantQuery && q != nil {
				t.Errorf("expected nil query, got %+v", q)
			}
		})
	}
}
// TestIsGraphQuery checks the fast presence probe: it should report true for
// any filter carrying a _graph key (even with an invalid payload, since the
// probe does not parse it) and false otherwise.
func TestIsGraphQuery(t *testing.T) {
	tests := []struct {
		name       string
		filterJSON string
		want       bool
	}{
		{
			// Seed "abc" is invalid, but IsGraphQuery only checks presence.
			name:       "filter with graph query",
			filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"abc"}}`,
			want:       true,
		},
		{
			name:       "filter without graph query",
			filterJSON: `{"kinds":[1,7]}`,
			want:       false,
		},
		{
			// Unrelated extension keys must not trigger the probe.
			name:       "filter with other extension",
			filterJSON: `{"kinds":[1],"_custom":"value"}`,
			want:       false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			f := &filter.F{}
			_, err := f.Unmarshal([]byte(tt.filterJSON))
			if err != nil {
				t.Fatalf("failed to unmarshal filter: %v", err)
			}
			got := IsGraphQuery(f)
			if got != tt.want {
				t.Errorf("IsGraphQuery() = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestQueryHasRefs covers the four combinations of inbound/outbound ref specs
// and asserts HasInboundRefs, HasOutboundRefs, and HasRefs for each.
func TestQueryHasRefs(t *testing.T) {
	tests := []struct {
		name        string
		query       Query
		hasInbound  bool
		hasOutbound bool
		hasRefs     bool
	}{
		{
			name: "no refs",
			query: Query{
				Method: "follows",
				Seed:   "abc",
			},
			hasInbound:  false,
			hasOutbound: false,
			hasRefs:     false,
		},
		{
			name: "only inbound refs",
			query: Query{
				Method: "follows",
				Seed:   "abc",
				InboundRefs: []RefSpec{
					{Kinds: []int{7}},
				},
			},
			hasInbound:  true,
			hasOutbound: false,
			hasRefs:     true,
		},
		{
			name: "only outbound refs",
			query: Query{
				Method: "follows",
				Seed:   "abc",
				OutboundRefs: []RefSpec{
					{Kinds: []int{1}},
				},
			},
			hasInbound:  false,
			hasOutbound: true,
			hasRefs:     true,
		},
		{
			name: "both refs",
			query: Query{
				Method: "follows",
				Seed:   "abc",
				InboundRefs: []RefSpec{
					{Kinds: []int{7}},
				},
				OutboundRefs: []RefSpec{
					{Kinds: []int{1}},
				},
			},
			hasInbound:  true,
			hasOutbound: true,
			hasRefs:     true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.query.HasInboundRefs(); got != tt.hasInbound {
				t.Errorf("HasInboundRefs() = %v, want %v", got, tt.hasInbound)
			}
			if got := tt.query.HasOutboundRefs(); got != tt.hasOutbound {
				t.Errorf("HasOutboundRefs() = %v, want %v", got, tt.hasOutbound)
			}
			if got := tt.query.HasRefs(); got != tt.hasRefs {
				t.Errorf("HasRefs() = %v, want %v", got, tt.hasRefs)
			}
		})
	}
}

View File

@@ -0,0 +1,282 @@
package graph
import (
"context"
"sync"
"time"
)
// RateLimiter implements a token bucket rate limiter with adaptive throttling
// based on graph query complexity. It allows cooperative scheduling by inserting
// pauses between operations to allow other work to proceed.
//
// All mutable fields are guarded by mu; unexported helpers such as
// refillTokens expect the lock to already be held by the caller.
type RateLimiter struct {
	mu sync.Mutex
	// Token bucket parameters
	tokens     float64   // Current available tokens
	maxTokens  float64   // Maximum token capacity
	refillRate float64   // Tokens per second to add
	lastRefill time.Time // Last time tokens were refilled
	// Throttling parameters
	baseDelay   time.Duration // Minimum delay between operations
	maxDelay    time.Duration // Maximum delay for complex queries
	depthFactor float64       // Multiplier per depth level
	limitFactor float64       // Multiplier based on result limit
	// NOTE(review): limitFactor is stored but never read by QueryCost,
	// OperationCost, Acquire, or Pause in this file — confirm intended use.
}
// RateLimiterConfig configures the rate limiter behavior.
// Any field left at a non-positive value is replaced with the corresponding
// default from DefaultRateLimiterConfig by NewRateLimiter.
type RateLimiterConfig struct {
	// MaxTokens is the maximum number of tokens in the bucket (default: 100)
	MaxTokens float64
	// RefillRate is tokens added per second (default: 10)
	RefillRate float64
	// BaseDelay is the minimum delay between operations (default: 1ms)
	BaseDelay time.Duration
	// MaxDelay is the maximum delay for complex queries (default: 100ms)
	MaxDelay time.Duration
	// DepthFactor is the cost multiplier per depth level (default: 2.0)
	// A depth-3 query costs 2^3 = 8x more tokens than depth-1
	DepthFactor float64
	// LimitFactor is additional cost per 100 results requested (default: 0.1)
	LimitFactor float64
}
// DefaultRateLimiterConfig returns sensible defaults for the rate limiter.
func DefaultRateLimiterConfig() RateLimiterConfig {
	var cfg RateLimiterConfig
	cfg.MaxTokens = 100.0
	cfg.RefillRate = 10.0 // a drained bucket refills fully in ten seconds
	cfg.BaseDelay = 1 * time.Millisecond
	cfg.MaxDelay = 100 * time.Millisecond
	cfg.DepthFactor = 2.0
	cfg.LimitFactor = 0.1
	return cfg
}
// NewRateLimiter creates a new rate limiter with the given configuration.
// Any non-positive field in cfg is replaced with the corresponding value
// from DefaultRateLimiterConfig, so a zero RateLimiterConfig yields a
// fully defaulted limiter. The bucket starts full.
func NewRateLimiter(cfg RateLimiterConfig) *RateLimiter {
	// Build the defaults once rather than re-constructing them per field.
	def := DefaultRateLimiterConfig()
	if cfg.MaxTokens <= 0 {
		cfg.MaxTokens = def.MaxTokens
	}
	if cfg.RefillRate <= 0 {
		cfg.RefillRate = def.RefillRate
	}
	if cfg.BaseDelay <= 0 {
		cfg.BaseDelay = def.BaseDelay
	}
	if cfg.MaxDelay <= 0 {
		cfg.MaxDelay = def.MaxDelay
	}
	if cfg.DepthFactor <= 0 {
		cfg.DepthFactor = def.DepthFactor
	}
	if cfg.LimitFactor <= 0 {
		cfg.LimitFactor = def.LimitFactor
	}
	return &RateLimiter{
		tokens:      cfg.MaxTokens, // start with a full bucket
		maxTokens:   cfg.MaxTokens,
		refillRate:  cfg.RefillRate,
		lastRefill:  time.Now(),
		baseDelay:   cfg.BaseDelay,
		maxDelay:    cfg.MaxDelay,
		depthFactor: cfg.DepthFactor,
		limitFactor: cfg.LimitFactor,
	}
}
// QueryCost calculates the token cost for a graph query based on its
// complexity. The cost grows exponentially with depth (depthFactor^Depth,
// mirroring the fan-out of traversal work), and each reference spec —
// inbound or outbound — adds a flat surcharge of half a token.
// A nil query costs exactly one token.
func (rl *RateLimiter) QueryCost(q *Query) float64 {
	if q == nil {
		return 1.0
	}
	// depthFactor raised to the query depth, computed iteratively.
	cost := 1.0
	for d := 0; d < q.Depth; d++ {
		cost *= rl.depthFactor
	}
	// Half a token per requested reference spec.
	specCount := len(q.InboundRefs) + len(q.OutboundRefs)
	return cost + float64(specCount)*0.5
}
// OperationCost calculates the token cost of a single traversal operation,
// used for per-operation throttling during query execution. The cost is
// depthFactor^depth, scaled up by 1% per node pending at that depth.
func (rl *RateLimiter) OperationCost(depth int, nodesAtDepth int) float64 {
	multiplier := 1.0
	for d := 0; d < depth; d++ {
		multiplier *= rl.depthFactor
	}
	// More nodes waiting at this depth means proportionally more work.
	return multiplier * (1.0 + float64(nodesAtDepth)*0.01)
}
// refillTokens adds tokens based on elapsed time since last refill.
// The caller must hold rl.mu; every call site in this file does.
func (rl *RateLimiter) refillTokens() {
	now := time.Now()
	// Fractional seconds since the previous refill.
	elapsed := now.Sub(rl.lastRefill).Seconds()
	rl.lastRefill = now
	rl.tokens += elapsed * rl.refillRate
	// Clamp to the bucket capacity.
	if rl.tokens > rl.maxTokens {
		rl.tokens = rl.maxTokens
	}
}
// Acquire tries to acquire tokens for a query. If not enough tokens are available,
// it waits until they become available or the context is cancelled.
// Returns the delay that was applied, or an error if context was cancelled.
func (rl *RateLimiter) Acquire(ctx context.Context, cost float64) (time.Duration, error) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.refillTokens()
	var totalDelay time.Duration
	// Wait until we have enough tokens
	for rl.tokens < cost {
		// Calculate how long we need to wait for tokens to refill
		tokensNeeded := cost - rl.tokens
		waitTime := time.Duration(tokensNeeded/rl.refillRate*1000) * time.Millisecond
		// Clamp to max delay; waiting less than needed just means another
		// loop iteration after the next refill.
		if waitTime > rl.maxDelay {
			waitTime = rl.maxDelay
		}
		if waitTime < rl.baseDelay {
			waitTime = rl.baseDelay
		}
		// Release lock while waiting so other goroutines can acquire/refill.
		rl.mu.Unlock()
		select {
		case <-ctx.Done():
			// Re-lock so the deferred Unlock above stays balanced.
			rl.mu.Lock()
			return totalDelay, ctx.Err()
		case <-time.After(waitTime):
		}
		totalDelay += waitTime
		rl.mu.Lock()
		rl.refillTokens()
	}
	// Consume tokens
	rl.tokens -= cost
	return totalDelay, nil
}
// TryAcquire attempts to acquire tokens without waiting.
// It reports whether the tokens were consumed; on false the bucket is
// left untouched (aside from the time-based refill).
func (rl *RateLimiter) TryAcquire(cost float64) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.refillTokens()
	if rl.tokens < cost {
		return false
	}
	rl.tokens -= cost
	return true
}
// Pause inserts a cooperative delay to allow other work to proceed.
// The delay is proportional to the traversal depth and the number of items
// already processed, capped at the configured maximum. It should be called
// periodically during long-running traversals; it returns ctx.Err() if the
// context is cancelled while waiting, nil otherwise.
func (rl *RateLimiter) Pause(ctx context.Context, depth int, itemsProcessed int) error {
	// Delay grows linearly with depth: baseDelay * (depth + 1).
	// (Replaces the original accumulation loop with direct arithmetic.)
	delay := rl.baseDelay * time.Duration(depth+1)
	// Every 100th item earns an extra-long pause so other work can run.
	if itemsProcessed > 0 && itemsProcessed%100 == 0 {
		delay += rl.baseDelay * 5
	}
	// Cap at max delay
	if delay > rl.maxDelay {
		delay = rl.maxDelay
	}
	// Use an explicit timer so it is stopped promptly if ctx fires first,
	// instead of lingering like a time.After channel would.
	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
// AvailableTokens returns the current number of available tokens,
// applying the time-based refill first so the figure is up to date.
func (rl *RateLimiter) AvailableTokens() float64 {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.refillTokens()
	return rl.tokens
}
// Throttler provides a simple interface for cooperative scheduling during traversal.
// It wraps the rate limiter and provides depth-aware throttling.
// A Throttler tracks one traversal pass and is not safe for concurrent use.
type Throttler struct {
	rl             *RateLimiter // shared limiter used for pauses
	depth          int          // traversal depth this throttler operates at
	itemsProcessed int          // running count of Tick calls
}
// NewThrottler creates a throttler bound to rl for a single traversal
// operation running at the given depth.
func NewThrottler(rl *RateLimiter, depth int) *Throttler {
	return &Throttler{rl: rl, depth: depth}
}
// Tick should be called after processing each item. It counts progress and
// periodically inserts a cooperative pause via the rate limiter — more often
// at greater depths (every 50 items at depth 0-1, every 25 at depth 2-3,
// every 10 at depth 4+). It returns the pause's error, if any.
func (t *Throttler) Tick(ctx context.Context) error {
	t.itemsProcessed++
	// Pick the pause interval from the traversal depth.
	var interval int
	switch {
	case t.depth >= 4:
		interval = 10
	case t.depth >= 2:
		interval = 25
	default:
		interval = 50
	}
	if t.itemsProcessed%interval != 0 {
		return nil
	}
	return t.rl.Pause(ctx, t.depth, t.itemsProcessed)
}
// Complete marks the throttler as complete and returns stats
// (currently just the total number of Tick calls observed).
func (t *Throttler) Complete() (itemsProcessed int) {
	return t.itemsProcessed
}

View File

@@ -0,0 +1,267 @@
package graph
import (
	"context"
	"errors"
	"testing"
	"time"
)
// TestRateLimiterQueryCost checks that QueryCost scales exponentially with
// depth (depthFactor^depth with the default factor of 2) and adds ~0.5 per
// reference spec; expectations are ranges to avoid float brittleness.
func TestRateLimiterQueryCost(t *testing.T) {
	rl := NewRateLimiter(DefaultRateLimiterConfig())
	tests := []struct {
		name    string
		query   *Query
		minCost float64
		maxCost float64
	}{
		{
			// nil query must cost exactly one token.
			name:    "nil query",
			query:   nil,
			minCost: 1.0,
			maxCost: 1.0,
		},
		{
			name: "depth 1 no refs",
			query: &Query{
				Method: "follows",
				Seed:   "abc",
				Depth:  1,
			},
			minCost: 1.5, // depthFactor^1 = 2
			maxCost: 2.5,
		},
		{
			name: "depth 2 no refs",
			query: &Query{
				Method: "follows",
				Seed:   "abc",
				Depth:  2,
			},
			minCost: 3.5, // depthFactor^2 = 4
			maxCost: 4.5,
		},
		{
			name: "depth 3 no refs",
			query: &Query{
				Method: "follows",
				Seed:   "abc",
				Depth:  3,
			},
			minCost: 7.5, // depthFactor^3 = 8
			maxCost: 8.5,
		},
		{
			name: "depth 2 with inbound refs",
			query: &Query{
				Method: "follows",
				Seed:   "abc",
				Depth:  2,
				InboundRefs: []RefSpec{
					{Kinds: []int{7}},
				},
			},
			minCost: 4.0, // 4 + 0.5 = 4.5
			maxCost: 5.0,
		},
		{
			name: "depth 2 with both refs",
			query: &Query{
				Method: "follows",
				Seed:   "abc",
				Depth:  2,
				InboundRefs: []RefSpec{
					{Kinds: []int{7}},
				},
				OutboundRefs: []RefSpec{
					{Kinds: []int{1}},
				},
			},
			minCost: 4.5, // 4 + 0.5 + 0.5 = 5
			maxCost: 5.5,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cost := rl.QueryCost(tt.query)
			if cost < tt.minCost || cost > tt.maxCost {
				t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost)
			}
		})
	}
}
// TestRateLimiterOperationCost spot-checks OperationCost at a few
// depth/node-count combinations: depthFactor^depth times (1 + 0.01*nodes).
func TestRateLimiterOperationCost(t *testing.T) {
	rl := NewRateLimiter(DefaultRateLimiterConfig())
	// Depth 0, 1 node: 2^0 * 1.01 = 1.01.
	cost0 := rl.OperationCost(0, 1)
	if cost0 < 1.0 || cost0 > 1.1 {
		t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0)
	}
	// Depth 1, 1 node: 2^1 * 1.01 = 2.02.
	cost1 := rl.OperationCost(1, 1)
	if cost1 < 2.0 || cost1 > 2.1 {
		t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1)
	}
	// Depth 2, 100 nodes: 2^2 * 2.0 = 8.
	cost2 := rl.OperationCost(2, 100)
	if cost2 < 8.0 {
		t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2)
	}
}
// TestRateLimiterAcquire verifies the fast path: with tokens available,
// Acquire returns promptly and deducts the cost from the bucket.
// The bounds are loose because the bucket keeps refilling in real time.
func TestRateLimiterAcquire(t *testing.T) {
	cfg := DefaultRateLimiterConfig()
	cfg.MaxTokens = 10
	cfg.RefillRate = 100 // Fast refill for testing
	rl := NewRateLimiter(cfg)
	ctx := context.Background()
	// Should acquire immediately when tokens available
	delay, err := rl.Acquire(ctx, 5)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if delay > time.Millisecond*10 {
		t.Errorf("expected minimal delay, got %v", delay)
	}
	// Check remaining tokens
	remaining := rl.AvailableTokens()
	if remaining > 6 {
		t.Errorf("expected ~5 tokens remaining, got %v", remaining)
	}
}
// TestRateLimiterTryAcquire verifies non-blocking acquisition: two takes of
// five succeed against a ten-token bucket, and a third take of even one
// token fails once the bucket is drained.
func TestRateLimiterTryAcquire(t *testing.T) {
	cfg := DefaultRateLimiterConfig()
	cfg.MaxTokens = 10
	rl := NewRateLimiter(cfg)
	// Should succeed with enough tokens
	if ok := rl.TryAcquire(5); !ok {
		t.Error("TryAcquire(5) should succeed with 10 tokens")
	}
	// Should succeed again
	if ok := rl.TryAcquire(5); !ok {
		t.Error("TryAcquire(5) should succeed with 5 tokens")
	}
	// Should fail with insufficient tokens
	if ok := rl.TryAcquire(1); ok {
		t.Error("TryAcquire(1) should fail with 0 tokens")
	}
}
// TestRateLimiterContextCancellation verifies that Acquire aborts with the
// context's error when the deadline expires before enough tokens refill.
func TestRateLimiterContextCancellation(t *testing.T) {
	cfg := DefaultRateLimiterConfig()
	cfg.MaxTokens = 1
	cfg.RefillRate = 0.1 // Very slow refill
	rl := NewRateLimiter(cfg)
	// Drain tokens
	rl.TryAcquire(1)
	// Create cancellable context
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	// Try to acquire - should be cancelled. Compare with errors.Is rather
	// than == so the assertion still holds if the error is ever wrapped.
	_, err := rl.Acquire(ctx, 10)
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Errorf("expected DeadlineExceeded, got %v", err)
	}
}
// TestRateLimiterRefill drains the bucket, sleeps, and checks that tokens
// come back at roughly RefillRate per second without exceeding MaxTokens.
// The lower bound is generous (5 of the ~15 expected) to tolerate timer jitter.
func TestRateLimiterRefill(t *testing.T) {
	cfg := DefaultRateLimiterConfig()
	cfg.MaxTokens = 10
	cfg.RefillRate = 1000 // 1000 tokens per second
	rl := NewRateLimiter(cfg)
	// Drain tokens
	rl.TryAcquire(10)
	// Wait for refill
	time.Sleep(15 * time.Millisecond)
	// Should have some tokens now
	available := rl.AvailableTokens()
	if available < 5 {
		t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available)
	}
	if available > 10 {
		t.Errorf("expected <= 10 tokens (max), got %v", available)
	}
}
// TestRateLimiterPause checks that a depth-1 pause with an uncancelled
// context returns nil and actually sleeps for at least the base delay.
func TestRateLimiterPause(t *testing.T) {
	rl := NewRateLimiter(DefaultRateLimiterConfig())
	ctx := context.Background()
	start := time.Now()
	err := rl.Pause(ctx, 1, 0)
	elapsed := time.Since(start)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Should have paused for at least baseDelay
	if elapsed < rl.baseDelay {
		t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay)
	}
}
// TestThrottler drives 100 ticks through a depth-1 throttler with a short
// base delay and verifies that Complete reports the exact item count.
func TestThrottler(t *testing.T) {
	cfg := DefaultRateLimiterConfig()
	cfg.BaseDelay = 100 * time.Microsecond // Short for testing
	throttler := NewThrottler(NewRateLimiter(cfg), 1)
	ctx := context.Background()
	const items = 100
	for i := 0; i < items; i++ {
		if err := throttler.Tick(ctx); err != nil {
			t.Fatalf("unexpected error at tick %d: %v", i, err)
		}
	}
	if processed := throttler.Complete(); processed != items {
		t.Errorf("expected 100 items processed, got %d", processed)
	}
}
// TestThrottlerContextCancellation cancels the context mid-traversal and
// expects a subsequent Tick that pauses to surface the context error.
// (If no pause lands in the second loop, silently passing is acceptable.)
func TestThrottlerContextCancellation(t *testing.T) {
	cfg := DefaultRateLimiterConfig()
	rl := NewRateLimiter(cfg)
	throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
	ctx, cancel := context.WithCancel(context.Background())
	// Warm up while the context is live; no error is expected here,
	// so check it instead of silently discarding it.
	for i := 0; i < 20; i++ {
		if err := throttler.Tick(ctx); err != nil {
			t.Fatalf("unexpected error before cancellation: %v", err)
		}
	}
	// Cancel context
	cancel()
	// Next tick that would pause should return error
	for i := 0; i < 100; i++ {
		if err := throttler.Tick(ctx); err != nil {
			// Expected - context was cancelled
			return
		}
	}
	// If we get here without error, the throttler didn't check context
	// This is acceptable if no pause was needed
}

View File

@@ -1 +1 @@
v0.33.1
v0.34.0