add first draft graph query implementation
Some checks failed
Go / build-and-release (push) Has been cancelled
Some checks failed
Go / build-and-release (push) Has been cancelled
This commit is contained in:
460
pkg/database/etag-graph_test.go
Normal file
460
pkg/database/etag-graph_test.go
Normal file
@@ -0,0 +1,460 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
"next.orly.dev/pkg/database/indexes"
|
||||
"next.orly.dev/pkg/database/indexes/types"
|
||||
"git.mleku.dev/mleku/nostr/encoders/event"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
"git.mleku.dev/mleku/nostr/encoders/tag"
|
||||
)
|
||||
|
||||
func TestETagGraphEdgeCreation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a parent event (the post being replied to)
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
parentID := make([]byte, 32)
|
||||
parentID[0] = 0x10
|
||||
parentSig := make([]byte, 64)
|
||||
parentSig[0] = 0x10
|
||||
|
||||
parentEvent := &event.E{
|
||||
ID: parentID,
|
||||
Pubkey: parentPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("This is the parent post"),
|
||||
Sig: parentSig,
|
||||
Tags: &tag.S{},
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, parentEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save parent event: %v", err)
|
||||
}
|
||||
|
||||
// Create a reply event with e-tag pointing to parent
|
||||
replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
replyID := make([]byte, 32)
|
||||
replyID[0] = 0x20
|
||||
replySig := make([]byte, 64)
|
||||
replySig[0] = 0x20
|
||||
|
||||
replyEvent := &event.E{
|
||||
ID: replyID,
|
||||
Pubkey: replyPubkey,
|
||||
CreatedAt: 1234567891,
|
||||
Kind: 1,
|
||||
Content: []byte("This is a reply"),
|
||||
Sig: replySig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(parentID)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, replyEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save reply event: %v", err)
|
||||
}
|
||||
|
||||
// Get serials for both events
|
||||
parentSerial, err := db.GetSerialById(parentID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get parent serial: %v", err)
|
||||
}
|
||||
replySerial, err := db.GetSerialById(replyID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get reply serial: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Parent serial: %d, Reply serial: %d", parentSerial.Get(), replySerial.Get())
|
||||
|
||||
// Verify forward edge exists (reply -> parent)
|
||||
forwardFound := false
|
||||
prefix := []byte(indexes.EventEventGraphPrefix)
|
||||
|
||||
err = db.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
|
||||
item := it.Item()
|
||||
key := item.KeyCopy(nil)
|
||||
|
||||
// Decode the key
|
||||
srcSer, tgtSer, kind, direction := indexes.EventEventGraphVars()
|
||||
keyReader := bytes.NewReader(key)
|
||||
if err := indexes.EventEventGraphDec(srcSer, tgtSer, kind, direction).UnmarshalRead(keyReader); err != nil {
|
||||
t.Logf("Failed to decode key: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is our edge
|
||||
if srcSer.Get() == replySerial.Get() && tgtSer.Get() == parentSerial.Get() {
|
||||
forwardFound = true
|
||||
if direction.Letter() != types.EdgeDirectionETagOut {
|
||||
t.Errorf("Expected direction %d, got %d", types.EdgeDirectionETagOut, direction.Letter())
|
||||
}
|
||||
if kind.Get() != 1 {
|
||||
t.Errorf("Expected kind 1, got %d", kind.Get())
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("View failed: %v", err)
|
||||
}
|
||||
if !forwardFound {
|
||||
t.Error("Forward edge (reply -> parent) should exist")
|
||||
}
|
||||
|
||||
// Verify reverse edge exists (parent <- reply)
|
||||
reverseFound := false
|
||||
prefix = []byte(indexes.GraphEventEventPrefix)
|
||||
|
||||
err = db.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
|
||||
item := it.Item()
|
||||
key := item.KeyCopy(nil)
|
||||
|
||||
// Decode the key
|
||||
tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars()
|
||||
keyReader := bytes.NewReader(key)
|
||||
if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil {
|
||||
t.Logf("Failed to decode key: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf("Found gee edge: tgt=%d kind=%d dir=%d src=%d",
|
||||
tgtSer.Get(), kind.Get(), direction.Letter(), srcSer.Get())
|
||||
|
||||
// Check if this is our edge
|
||||
if tgtSer.Get() == parentSerial.Get() && srcSer.Get() == replySerial.Get() {
|
||||
reverseFound = true
|
||||
if direction.Letter() != types.EdgeDirectionETagIn {
|
||||
t.Errorf("Expected direction %d, got %d", types.EdgeDirectionETagIn, direction.Letter())
|
||||
}
|
||||
if kind.Get() != 1 {
|
||||
t.Errorf("Expected kind 1, got %d", kind.Get())
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("View failed: %v", err)
|
||||
}
|
||||
if !reverseFound {
|
||||
t.Error("Reverse edge (parent <- reply) should exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestETagGraphMultipleReplies(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a parent event
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
parentID := make([]byte, 32)
|
||||
parentID[0] = 0x10
|
||||
parentSig := make([]byte, 64)
|
||||
parentSig[0] = 0x10
|
||||
|
||||
parentEvent := &event.E{
|
||||
ID: parentID,
|
||||
Pubkey: parentPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("Parent post"),
|
||||
Sig: parentSig,
|
||||
Tags: &tag.S{},
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, parentEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save parent: %v", err)
|
||||
}
|
||||
|
||||
// Create multiple replies
|
||||
numReplies := 5
|
||||
for i := 0; i < numReplies; i++ {
|
||||
replyPubkey := make([]byte, 32)
|
||||
replyPubkey[0] = byte(i + 0x20)
|
||||
replyID := make([]byte, 32)
|
||||
replyID[0] = byte(i + 0x30)
|
||||
replySig := make([]byte, 64)
|
||||
replySig[0] = byte(i + 0x30)
|
||||
|
||||
replyEvent := &event.E{
|
||||
ID: replyID,
|
||||
Pubkey: replyPubkey,
|
||||
CreatedAt: int64(1234567891 + i),
|
||||
Kind: 1,
|
||||
Content: []byte("Reply"),
|
||||
Sig: replySig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(parentID)),
|
||||
),
|
||||
}
|
||||
_, err := db.SaveEvent(ctx, replyEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save reply %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Count inbound edges to parent
|
||||
parentSerial, err := db.GetSerialById(parentID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get parent serial: %v", err)
|
||||
}
|
||||
|
||||
inboundCount := 0
|
||||
prefix := []byte(indexes.GraphEventEventPrefix)
|
||||
|
||||
err = db.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
|
||||
item := it.Item()
|
||||
key := item.KeyCopy(nil)
|
||||
|
||||
tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars()
|
||||
keyReader := bytes.NewReader(key)
|
||||
if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if tgtSer.Get() == parentSerial.Get() {
|
||||
inboundCount++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("View failed: %v", err)
|
||||
}
|
||||
|
||||
if inboundCount != numReplies {
|
||||
t.Errorf("Expected %d inbound edges, got %d", numReplies, inboundCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestETagGraphDifferentKinds(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a parent event (kind 1 - note)
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
parentID := make([]byte, 32)
|
||||
parentID[0] = 0x10
|
||||
parentSig := make([]byte, 64)
|
||||
parentSig[0] = 0x10
|
||||
|
||||
parentEvent := &event.E{
|
||||
ID: parentID,
|
||||
Pubkey: parentPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("A note"),
|
||||
Sig: parentSig,
|
||||
Tags: &tag.S{},
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, parentEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save parent: %v", err)
|
||||
}
|
||||
|
||||
// Create a reaction (kind 7)
|
||||
reactionPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
reactionID := make([]byte, 32)
|
||||
reactionID[0] = 0x20
|
||||
reactionSig := make([]byte, 64)
|
||||
reactionSig[0] = 0x20
|
||||
|
||||
reactionEvent := &event.E{
|
||||
ID: reactionID,
|
||||
Pubkey: reactionPubkey,
|
||||
CreatedAt: 1234567891,
|
||||
Kind: 7,
|
||||
Content: []byte("+"),
|
||||
Sig: reactionSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(parentID)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, reactionEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save reaction: %v", err)
|
||||
}
|
||||
|
||||
// Create a repost (kind 6)
|
||||
repostPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
repostID := make([]byte, 32)
|
||||
repostID[0] = 0x30
|
||||
repostSig := make([]byte, 64)
|
||||
repostSig[0] = 0x30
|
||||
|
||||
repostEvent := &event.E{
|
||||
ID: repostID,
|
||||
Pubkey: repostPubkey,
|
||||
CreatedAt: 1234567892,
|
||||
Kind: 6,
|
||||
Content: []byte(""),
|
||||
Sig: repostSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(parentID)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, repostEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save repost: %v", err)
|
||||
}
|
||||
|
||||
// Query inbound edges by kind
|
||||
parentSerial, err := db.GetSerialById(parentID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get parent serial: %v", err)
|
||||
}
|
||||
|
||||
kindCounts := make(map[uint16]int)
|
||||
prefix := []byte(indexes.GraphEventEventPrefix)
|
||||
|
||||
err = db.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
|
||||
item := it.Item()
|
||||
key := item.KeyCopy(nil)
|
||||
|
||||
tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars()
|
||||
keyReader := bytes.NewReader(key)
|
||||
if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if tgtSer.Get() == parentSerial.Get() {
|
||||
kindCounts[kind.Get()]++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("View failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify we have edges for each kind
|
||||
if kindCounts[7] != 1 {
|
||||
t.Errorf("Expected 1 kind-7 (reaction) edge, got %d", kindCounts[7])
|
||||
}
|
||||
if kindCounts[6] != 1 {
|
||||
t.Errorf("Expected 1 kind-6 (repost) edge, got %d", kindCounts[6])
|
||||
}
|
||||
}
|
||||
|
||||
func TestETagGraphUnknownTarget(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create an event with e-tag pointing to non-existent event
|
||||
unknownID := make([]byte, 32)
|
||||
unknownID[0] = 0xFF
|
||||
unknownID[31] = 0xFF
|
||||
|
||||
replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
replyID := make([]byte, 32)
|
||||
replyID[0] = 0x10
|
||||
replySig := make([]byte, 64)
|
||||
replySig[0] = 0x10
|
||||
|
||||
replyEvent := &event.E{
|
||||
ID: replyID,
|
||||
Pubkey: replyPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("Reply to unknown"),
|
||||
Sig: replySig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(unknownID)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, replyEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save reply: %v", err)
|
||||
}
|
||||
|
||||
// Verify event was saved
|
||||
replySerial, err := db.GetSerialById(replyID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get reply serial: %v", err)
|
||||
}
|
||||
if replySerial == nil {
|
||||
t.Fatal("Reply serial should exist")
|
||||
}
|
||||
|
||||
// Verify no forward edge was created (since target doesn't exist)
|
||||
edgeCount := 0
|
||||
prefix := []byte(indexes.EventEventGraphPrefix)
|
||||
|
||||
err = db.View(func(txn *badger.Txn) error {
|
||||
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
|
||||
item := it.Item()
|
||||
key := item.KeyCopy(nil)
|
||||
|
||||
srcSer, _, _, _ := indexes.EventEventGraphVars()
|
||||
keyReader := bytes.NewReader(key)
|
||||
if err := indexes.EventEventGraphDec(srcSer, new(types.Uint40), new(types.Uint16), new(types.Letter)).UnmarshalRead(keyReader); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if srcSer.Get() == replySerial.Get() {
|
||||
edgeCount++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("View failed: %v", err)
|
||||
}
|
||||
|
||||
if edgeCount != 0 {
|
||||
t.Errorf("Expected no edges for unknown target, got %d", edgeCount)
|
||||
}
|
||||
}
|
||||
42
pkg/database/graph-adapter.go
Normal file
42
pkg/database/graph-adapter.go
Normal file
@@ -0,0 +1,42 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"next.orly.dev/pkg/protocol/graph"
|
||||
)
|
||||
|
||||
// GraphAdapter wraps a database instance and implements graph.GraphDatabase interface.
|
||||
// This allows the graph executor to call database traversal methods without
|
||||
// the database package importing the graph package.
|
||||
type GraphAdapter struct {
|
||||
db *D
|
||||
}
|
||||
|
||||
// NewGraphAdapter creates a new GraphAdapter wrapping the given database.
|
||||
func NewGraphAdapter(db *D) *GraphAdapter {
|
||||
return &GraphAdapter{db: db}
|
||||
}
|
||||
|
||||
// TraverseFollows implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) TraverseFollows(seedPubkey []byte, maxDepth int) (graph.GraphResultI, error) {
|
||||
return a.db.TraverseFollows(seedPubkey, maxDepth)
|
||||
}
|
||||
|
||||
// TraverseFollowers implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) TraverseFollowers(seedPubkey []byte, maxDepth int) (graph.GraphResultI, error) {
|
||||
return a.db.TraverseFollowers(seedPubkey, maxDepth)
|
||||
}
|
||||
|
||||
// FindMentions implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) FindMentions(pubkey []byte, kinds []uint16) (graph.GraphResultI, error) {
|
||||
return a.db.FindMentions(pubkey, kinds)
|
||||
}
|
||||
|
||||
// TraverseThread implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) TraverseThread(seedEventID []byte, maxDepth int, direction string) (graph.GraphResultI, error) {
|
||||
return a.db.TraverseThread(seedEventID, maxDepth, direction)
|
||||
}
|
||||
|
||||
// Verify GraphAdapter implements graph.GraphDatabase
|
||||
var _ graph.GraphDatabase = (*GraphAdapter)(nil)
|
||||
199
pkg/database/graph-follows.go
Normal file
199
pkg/database/graph-follows.go
Normal file
@@ -0,0 +1,199 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"lol.mleku.dev/log"
|
||||
"next.orly.dev/pkg/database/indexes/types"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
)
|
||||
|
||||
// TraverseFollows performs BFS traversal of the follow graph starting from a seed pubkey.
|
||||
// Returns pubkeys grouped by first-discovered depth (no duplicates across depths).
|
||||
//
|
||||
// The traversal works by:
|
||||
// 1. Starting with the seed pubkey at depth 0 (not included in results)
|
||||
// 2. For each pubkey at the current depth, find their kind-3 contact list
|
||||
// 3. Extract p-tags from the contact list to get follows
|
||||
// 4. Add new (unseen) follows to the next depth
|
||||
// 5. Continue until maxDepth is reached or no new pubkeys are found
|
||||
//
|
||||
// Early termination occurs if two consecutive depths yield no new pubkeys.
|
||||
func (d *D) TraverseFollows(seedPubkey []byte, maxDepth int) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
if len(seedPubkey) != 32 {
|
||||
return result, ErrPubkeyNotFound
|
||||
}
|
||||
|
||||
// Get seed pubkey serial
|
||||
seedSerial, err := d.GetPubkeySerial(seedPubkey)
|
||||
if err != nil {
|
||||
log.D.F("TraverseFollows: seed pubkey not in database: %s", hex.Enc(seedPubkey))
|
||||
return result, nil // Not an error - just no results
|
||||
}
|
||||
|
||||
// Track visited pubkeys by serial to avoid cycles
|
||||
visited := make(map[uint64]bool)
|
||||
visited[seedSerial.Get()] = true // Mark seed as visited but don't add to results
|
||||
|
||||
// Current frontier (pubkeys to process at this depth)
|
||||
currentFrontier := []*types.Uint40{seedSerial}
|
||||
|
||||
// Track consecutive empty depths for early termination
|
||||
consecutiveEmptyDepths := 0
|
||||
|
||||
for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
|
||||
var nextFrontier []*types.Uint40
|
||||
newPubkeysAtDepth := 0
|
||||
|
||||
for _, pubkeySerial := range currentFrontier {
|
||||
// Get follows for this pubkey
|
||||
follows, err := d.GetFollowsFromPubkeySerial(pubkeySerial)
|
||||
if err != nil {
|
||||
log.D.F("TraverseFollows: error getting follows for serial %d: %v", pubkeySerial.Get(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, followSerial := range follows {
|
||||
// Skip if already visited
|
||||
if visited[followSerial.Get()] {
|
||||
continue
|
||||
}
|
||||
visited[followSerial.Get()] = true
|
||||
|
||||
// Get pubkey hex for result
|
||||
pubkeyHex, err := d.GetPubkeyHexFromSerial(followSerial)
|
||||
if err != nil {
|
||||
log.D.F("TraverseFollows: error getting pubkey hex for serial %d: %v", followSerial.Get(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to results at this depth
|
||||
result.AddPubkeyAtDepth(pubkeyHex, currentDepth)
|
||||
newPubkeysAtDepth++
|
||||
|
||||
// Add to next frontier for further traversal
|
||||
nextFrontier = append(nextFrontier, followSerial)
|
||||
}
|
||||
}
|
||||
|
||||
log.T.F("TraverseFollows: depth %d found %d new pubkeys", currentDepth, newPubkeysAtDepth)
|
||||
|
||||
// Check for early termination
|
||||
if newPubkeysAtDepth == 0 {
|
||||
consecutiveEmptyDepths++
|
||||
if consecutiveEmptyDepths >= 2 {
|
||||
log.T.F("TraverseFollows: early termination at depth %d (2 consecutive empty depths)", currentDepth)
|
||||
break
|
||||
}
|
||||
} else {
|
||||
consecutiveEmptyDepths = 0
|
||||
}
|
||||
|
||||
// Move to next depth
|
||||
currentFrontier = nextFrontier
|
||||
}
|
||||
|
||||
log.D.F("TraverseFollows: completed with %d total pubkeys across %d depths",
|
||||
result.TotalPubkeys, len(result.PubkeysByDepth))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TraverseFollowers performs BFS traversal to find who follows the seed pubkey.
|
||||
// This is the reverse of TraverseFollows - it finds users whose kind-3 lists
|
||||
// contain the target pubkey(s).
|
||||
//
|
||||
// At each depth:
|
||||
// - Depth 1: Users who directly follow the seed
|
||||
// - Depth 2: Users who follow anyone at depth 1 (followers of followers)
|
||||
// - etc.
|
||||
func (d *D) TraverseFollowers(seedPubkey []byte, maxDepth int) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
if len(seedPubkey) != 32 {
|
||||
return result, ErrPubkeyNotFound
|
||||
}
|
||||
|
||||
// Get seed pubkey serial
|
||||
seedSerial, err := d.GetPubkeySerial(seedPubkey)
|
||||
if err != nil {
|
||||
log.D.F("TraverseFollowers: seed pubkey not in database: %s", hex.Enc(seedPubkey))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Track visited pubkeys
|
||||
visited := make(map[uint64]bool)
|
||||
visited[seedSerial.Get()] = true
|
||||
|
||||
// Current frontier
|
||||
currentFrontier := []*types.Uint40{seedSerial}
|
||||
|
||||
consecutiveEmptyDepths := 0
|
||||
|
||||
for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
|
||||
var nextFrontier []*types.Uint40
|
||||
newPubkeysAtDepth := 0
|
||||
|
||||
for _, targetSerial := range currentFrontier {
|
||||
// Get followers of this pubkey
|
||||
followers, err := d.GetFollowersOfPubkeySerial(targetSerial)
|
||||
if err != nil {
|
||||
log.D.F("TraverseFollowers: error getting followers for serial %d: %v", targetSerial.Get(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, followerSerial := range followers {
|
||||
if visited[followerSerial.Get()] {
|
||||
continue
|
||||
}
|
||||
visited[followerSerial.Get()] = true
|
||||
|
||||
pubkeyHex, err := d.GetPubkeyHexFromSerial(followerSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
result.AddPubkeyAtDepth(pubkeyHex, currentDepth)
|
||||
newPubkeysAtDepth++
|
||||
nextFrontier = append(nextFrontier, followerSerial)
|
||||
}
|
||||
}
|
||||
|
||||
log.T.F("TraverseFollowers: depth %d found %d new pubkeys", currentDepth, newPubkeysAtDepth)
|
||||
|
||||
if newPubkeysAtDepth == 0 {
|
||||
consecutiveEmptyDepths++
|
||||
if consecutiveEmptyDepths >= 2 {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
consecutiveEmptyDepths = 0
|
||||
}
|
||||
|
||||
currentFrontier = nextFrontier
|
||||
}
|
||||
|
||||
log.D.F("TraverseFollowers: completed with %d total pubkeys", result.TotalPubkeys)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TraverseFollowsFromHex is a convenience wrapper that accepts hex-encoded pubkey.
|
||||
func (d *D) TraverseFollowsFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) {
|
||||
seedPubkey, err := hex.Dec(seedPubkeyHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.TraverseFollows(seedPubkey, maxDepth)
|
||||
}
|
||||
|
||||
// TraverseFollowersFromHex is a convenience wrapper that accepts hex-encoded pubkey.
|
||||
func (d *D) TraverseFollowersFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) {
|
||||
seedPubkey, err := hex.Dec(seedPubkeyHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.TraverseFollowers(seedPubkey, maxDepth)
|
||||
}
|
||||
318
pkg/database/graph-follows_test.go
Normal file
318
pkg/database/graph-follows_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
"git.mleku.dev/mleku/nostr/encoders/tag"
|
||||
)
|
||||
|
||||
func TestTraverseFollows(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a simple follow graph:
|
||||
// Alice -> Bob, Carol
|
||||
// Bob -> David, Eve
|
||||
// Carol -> Eve, Frank
|
||||
//
|
||||
// Expected depth 1 from Alice: Bob, Carol
|
||||
// Expected depth 2 from Alice: David, Eve, Frank (Eve deduplicated)
|
||||
|
||||
alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
david, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000004")
|
||||
eve, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000005")
|
||||
frank, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000006")
|
||||
|
||||
// Create Alice's follow list (kind 3)
|
||||
aliceContactID := make([]byte, 32)
|
||||
aliceContactID[0] = 0x10
|
||||
aliceContactSig := make([]byte, 64)
|
||||
aliceContactSig[0] = 0x10
|
||||
aliceContact := &event.E{
|
||||
ID: aliceContactID,
|
||||
Pubkey: alice,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: aliceContactSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(bob)),
|
||||
tag.NewFromAny("p", hex.Enc(carol)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, aliceContact)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save Alice's contact list: %v", err)
|
||||
}
|
||||
|
||||
// Create Bob's follow list
|
||||
bobContactID := make([]byte, 32)
|
||||
bobContactID[0] = 0x20
|
||||
bobContactSig := make([]byte, 64)
|
||||
bobContactSig[0] = 0x20
|
||||
bobContact := &event.E{
|
||||
ID: bobContactID,
|
||||
Pubkey: bob,
|
||||
CreatedAt: 1234567891,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: bobContactSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(david)),
|
||||
tag.NewFromAny("p", hex.Enc(eve)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, bobContact)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save Bob's contact list: %v", err)
|
||||
}
|
||||
|
||||
// Create Carol's follow list
|
||||
carolContactID := make([]byte, 32)
|
||||
carolContactID[0] = 0x30
|
||||
carolContactSig := make([]byte, 64)
|
||||
carolContactSig[0] = 0x30
|
||||
carolContact := &event.E{
|
||||
ID: carolContactID,
|
||||
Pubkey: carol,
|
||||
CreatedAt: 1234567892,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: carolContactSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(eve)),
|
||||
tag.NewFromAny("p", hex.Enc(frank)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, carolContact)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save Carol's contact list: %v", err)
|
||||
}
|
||||
|
||||
// Traverse follows from Alice with depth 2
|
||||
result, err := db.TraverseFollows(alice, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("TraverseFollows failed: %v", err)
|
||||
}
|
||||
|
||||
// Check depth 1: should have Bob and Carol
|
||||
depth1 := result.GetPubkeysAtDepth(1)
|
||||
if len(depth1) != 2 {
|
||||
t.Errorf("Expected 2 pubkeys at depth 1, got %d", len(depth1))
|
||||
}
|
||||
|
||||
depth1Set := make(map[string]bool)
|
||||
for _, pk := range depth1 {
|
||||
depth1Set[pk] = true
|
||||
}
|
||||
if !depth1Set[hex.Enc(bob)] {
|
||||
t.Error("Bob should be at depth 1")
|
||||
}
|
||||
if !depth1Set[hex.Enc(carol)] {
|
||||
t.Error("Carol should be at depth 1")
|
||||
}
|
||||
|
||||
// Check depth 2: should have David, Eve, Frank (Eve deduplicated)
|
||||
depth2 := result.GetPubkeysAtDepth(2)
|
||||
if len(depth2) != 3 {
|
||||
t.Errorf("Expected 3 pubkeys at depth 2, got %d: %v", len(depth2), depth2)
|
||||
}
|
||||
|
||||
depth2Set := make(map[string]bool)
|
||||
for _, pk := range depth2 {
|
||||
depth2Set[pk] = true
|
||||
}
|
||||
if !depth2Set[hex.Enc(david)] {
|
||||
t.Error("David should be at depth 2")
|
||||
}
|
||||
if !depth2Set[hex.Enc(eve)] {
|
||||
t.Error("Eve should be at depth 2")
|
||||
}
|
||||
if !depth2Set[hex.Enc(frank)] {
|
||||
t.Error("Frank should be at depth 2")
|
||||
}
|
||||
|
||||
// Verify total count
|
||||
if result.TotalPubkeys != 5 {
|
||||
t.Errorf("Expected 5 total pubkeys, got %d", result.TotalPubkeys)
|
||||
}
|
||||
|
||||
// Verify ToDepthArrays output
|
||||
arrays := result.ToDepthArrays()
|
||||
if len(arrays) != 2 {
|
||||
t.Errorf("Expected 2 depth arrays, got %d", len(arrays))
|
||||
}
|
||||
if len(arrays[0]) != 2 {
|
||||
t.Errorf("Expected 2 pubkeys in depth 1 array, got %d", len(arrays[0]))
|
||||
}
|
||||
if len(arrays[1]) != 3 {
|
||||
t.Errorf("Expected 3 pubkeys in depth 2 array, got %d", len(arrays[1]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraverseFollowsDepth1(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
|
||||
// Create Alice's follow list
|
||||
aliceContactID := make([]byte, 32)
|
||||
aliceContactID[0] = 0x10
|
||||
aliceContactSig := make([]byte, 64)
|
||||
aliceContactSig[0] = 0x10
|
||||
aliceContact := &event.E{
|
||||
ID: aliceContactID,
|
||||
Pubkey: alice,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: aliceContactSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(bob)),
|
||||
tag.NewFromAny("p", hex.Enc(carol)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, aliceContact)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save contact list: %v", err)
|
||||
}
|
||||
|
||||
// Traverse with depth 1 only
|
||||
result, err := db.TraverseFollows(alice, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("TraverseFollows failed: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalPubkeys != 2 {
|
||||
t.Errorf("Expected 2 pubkeys, got %d", result.TotalPubkeys)
|
||||
}
|
||||
|
||||
arrays := result.ToDepthArrays()
|
||||
if len(arrays) != 1 {
|
||||
t.Errorf("Expected 1 depth array for depth 1 query, got %d", len(arrays))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraverseFollowersBasic(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create scenario: Bob and Carol follow Alice
|
||||
alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
|
||||
// Bob's contact list includes Alice
|
||||
bobContactID := make([]byte, 32)
|
||||
bobContactID[0] = 0x10
|
||||
bobContactSig := make([]byte, 64)
|
||||
bobContactSig[0] = 0x10
|
||||
bobContact := &event.E{
|
||||
ID: bobContactID,
|
||||
Pubkey: bob,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: bobContactSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(alice)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, bobContact)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save Bob's contact list: %v", err)
|
||||
}
|
||||
|
||||
// Carol's contact list includes Alice
|
||||
carolContactID := make([]byte, 32)
|
||||
carolContactID[0] = 0x20
|
||||
carolContactSig := make([]byte, 64)
|
||||
carolContactSig[0] = 0x20
|
||||
carolContact := &event.E{
|
||||
ID: carolContactID,
|
||||
Pubkey: carol,
|
||||
CreatedAt: 1234567891,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: carolContactSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(alice)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, carolContact)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save Carol's contact list: %v", err)
|
||||
}
|
||||
|
||||
// Find Alice's followers
|
||||
result, err := db.TraverseFollowers(alice, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("TraverseFollowers failed: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalPubkeys != 2 {
|
||||
t.Errorf("Expected 2 followers, got %d", result.TotalPubkeys)
|
||||
}
|
||||
|
||||
followers := result.GetPubkeysAtDepth(1)
|
||||
followerSet := make(map[string]bool)
|
||||
for _, pk := range followers {
|
||||
followerSet[pk] = true
|
||||
}
|
||||
if !followerSet[hex.Enc(bob)] {
|
||||
t.Error("Bob should be a follower")
|
||||
}
|
||||
if !followerSet[hex.Enc(carol)] {
|
||||
t.Error("Carol should be a follower")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraverseFollowsNonExistent(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Try to traverse from a pubkey that doesn't exist
|
||||
nonExistent, _ := hex.Dec("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
|
||||
result, err := db.TraverseFollows(nonExistent, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("TraverseFollows should not error for non-existent pubkey: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalPubkeys != 0 {
|
||||
t.Errorf("Expected 0 pubkeys for non-existent seed, got %d", result.TotalPubkeys)
|
||||
}
|
||||
}
|
||||
91
pkg/database/graph-mentions.go
Normal file
91
pkg/database/graph-mentions.go
Normal file
@@ -0,0 +1,91 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"lol.mleku.dev/log"
|
||||
"next.orly.dev/pkg/database/indexes/types"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
)
|
||||
|
||||
// FindMentions finds events that mention a pubkey via p-tags.
|
||||
// This returns events grouped by depth, where depth represents how the events relate:
|
||||
// - Depth 1: Events that directly mention the seed pubkey
|
||||
// - Depth 2+: Not typically used for mentions (reserved for future expansion)
|
||||
//
|
||||
// The kinds parameter filters which event kinds to include (e.g., [1] for notes only,
|
||||
// [1,7] for notes and reactions, etc.)
|
||||
func (d *D) FindMentions(pubkey []byte, kinds []uint16) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
if len(pubkey) != 32 {
|
||||
return result, ErrPubkeyNotFound
|
||||
}
|
||||
|
||||
// Get pubkey serial
|
||||
pubkeySerial, err := d.GetPubkeySerial(pubkey)
|
||||
if err != nil {
|
||||
log.D.F("FindMentions: pubkey not in database: %s", hex.Enc(pubkey))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Find all events that reference this pubkey
|
||||
eventSerials, err := d.GetEventsReferencingPubkey(pubkeySerial, kinds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add each event at depth 1
|
||||
for _, eventSerial := range eventSerials {
|
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
|
||||
if err != nil {
|
||||
log.D.F("FindMentions: error getting event ID for serial %d: %v", eventSerial.Get(), err)
|
||||
continue
|
||||
}
|
||||
result.AddEventAtDepth(eventIDHex, 1)
|
||||
}
|
||||
|
||||
log.D.F("FindMentions: found %d events mentioning pubkey %s", result.TotalEvents, hex.Enc(pubkey))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindMentionsFromHex is a convenience wrapper that accepts hex-encoded pubkey.
|
||||
func (d *D) FindMentionsFromHex(pubkeyHex string, kinds []uint16) (*GraphResult, error) {
|
||||
pubkey, err := hex.Dec(pubkeyHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.FindMentions(pubkey, kinds)
|
||||
}
|
||||
|
||||
// FindMentionsByPubkeys returns events that mention any of the given pubkeys.
|
||||
// Useful for finding mentions across a set of followed accounts.
|
||||
func (d *D) FindMentionsByPubkeys(pubkeySerials []*types.Uint40, kinds []uint16) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
seen := make(map[uint64]bool)
|
||||
|
||||
for _, pubkeySerial := range pubkeySerials {
|
||||
eventSerials, err := d.GetEventsReferencingPubkey(pubkeySerial, kinds)
|
||||
if err != nil {
|
||||
log.D.F("FindMentionsByPubkeys: error for serial %d: %v", pubkeySerial.Get(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, eventSerial := range eventSerials {
|
||||
if seen[eventSerial.Get()] {
|
||||
continue
|
||||
}
|
||||
seen[eventSerial.Get()] = true
|
||||
|
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result.AddEventAtDepth(eventIDHex, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
206
pkg/database/graph-refs.go
Normal file
206
pkg/database/graph-refs.go
Normal file
@@ -0,0 +1,206 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"lol.mleku.dev/log"
|
||||
"next.orly.dev/pkg/database/indexes/types"
|
||||
)
|
||||
|
||||
// AddInboundRefsToResult collects inbound references (events that reference discovered items)
|
||||
// for events at a specific depth in the result.
|
||||
//
|
||||
// For example, if you have a follows graph result and want to find all kind-7 reactions
|
||||
// to posts by users at depth 1, this collects those reactions and adds them to result.InboundRefs.
|
||||
//
|
||||
// Parameters:
|
||||
// - result: The graph result to augment with ref data
|
||||
// - depth: The depth at which to collect refs (0 = all depths)
|
||||
// - kinds: Event kinds to collect (e.g., [7] for reactions, [6] for reposts)
|
||||
func (d *D) AddInboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error {
|
||||
// Determine which events to find refs for
|
||||
var targetEventIDs []string
|
||||
|
||||
if depth == 0 {
|
||||
// Collect for all depths
|
||||
targetEventIDs = result.GetAllEvents()
|
||||
} else {
|
||||
targetEventIDs = result.GetEventsAtDepth(depth)
|
||||
}
|
||||
|
||||
// Also collect refs for events authored by pubkeys in the result
|
||||
// This is common for "find reactions to posts by my follows" queries
|
||||
pubkeys := result.GetAllPubkeys()
|
||||
for _, pubkeyHex := range pubkeys {
|
||||
pubkeySerial, err := d.PubkeyHexToSerial(pubkeyHex)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get events authored by this pubkey
|
||||
// For efficiency, limit to relevant event kinds that might have reactions
|
||||
authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, []uint16{1, 30023}) // notes and articles
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, eventSerial := range authoredEvents {
|
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Add to target list if not already tracking
|
||||
if !result.HasEvent(eventIDHex) {
|
||||
targetEventIDs = append(targetEventIDs, eventIDHex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each target event, find referencing events
|
||||
for _, eventIDHex := range targetEventIDs {
|
||||
eventSerial, err := d.EventIDHexToSerial(eventIDHex)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
refSerials, err := d.GetReferencingEvents(eventSerial, kinds)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, refSerial := range refSerials {
|
||||
refEventIDHex, err := d.GetEventIDFromSerial(refSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the kind of the referencing event
|
||||
// For now, use the first kind in the filter (assumes single kind queries)
|
||||
// TODO: Look up actual event kind from index if needed
|
||||
if len(kinds) > 0 {
|
||||
result.AddInboundRef(kinds[0], eventIDHex, refEventIDHex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.D.F("AddInboundRefsToResult: collected refs for %d target events", len(targetEventIDs))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutboundRefsToResult collects outbound references (events referenced by discovered items).
|
||||
//
|
||||
// For example, find all events that posts by users at depth 1 reference (quoted posts, replied-to posts).
|
||||
func (d *D) AddOutboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error {
|
||||
// Determine source events
|
||||
var sourceEventIDs []string
|
||||
|
||||
if depth == 0 {
|
||||
sourceEventIDs = result.GetAllEvents()
|
||||
} else {
|
||||
sourceEventIDs = result.GetEventsAtDepth(depth)
|
||||
}
|
||||
|
||||
// Also include events authored by pubkeys in result
|
||||
pubkeys := result.GetAllPubkeys()
|
||||
for _, pubkeyHex := range pubkeys {
|
||||
pubkeySerial, err := d.PubkeyHexToSerial(pubkeyHex)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, kinds)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, eventSerial := range authoredEvents {
|
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if !result.HasEvent(eventIDHex) {
|
||||
sourceEventIDs = append(sourceEventIDs, eventIDHex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each source event, find referenced events
|
||||
for _, eventIDHex := range sourceEventIDs {
|
||||
eventSerial, err := d.EventIDHexToSerial(eventIDHex)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
refSerials, err := d.GetETagsFromEventSerial(eventSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, refSerial := range refSerials {
|
||||
refEventIDHex, err := d.GetEventIDFromSerial(refSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use first kind for categorization
|
||||
if len(kinds) > 0 {
|
||||
result.AddOutboundRef(kinds[0], eventIDHex, refEventIDHex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.D.F("AddOutboundRefsToResult: collected refs from %d source events", len(sourceEventIDs))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectRefsForPubkeys collects inbound references to events by specific pubkeys.
|
||||
// This is useful for "find all reactions to posts by these users" queries.
|
||||
//
|
||||
// Parameters:
|
||||
// - pubkeySerials: The pubkeys whose events should be checked for refs
|
||||
// - refKinds: Event kinds to collect (e.g., [7] for reactions)
|
||||
// - eventKinds: Event kinds to check for refs (e.g., [1] for notes)
|
||||
func (d *D) CollectRefsForPubkeys(
|
||||
pubkeySerials []*types.Uint40,
|
||||
refKinds []uint16,
|
||||
eventKinds []uint16,
|
||||
) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
for _, pubkeySerial := range pubkeySerials {
|
||||
// Get events by this author
|
||||
authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, eventKinds)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, eventSerial := range authoredEvents {
|
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find refs to this event
|
||||
refSerials, err := d.GetReferencingEvents(eventSerial, refKinds)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, refSerial := range refSerials {
|
||||
refEventIDHex, err := d.GetEventIDFromSerial(refSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to result
|
||||
if len(refKinds) > 0 {
|
||||
result.AddInboundRef(refKinds[0], eventIDHex, refEventIDHex)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
327
pkg/database/graph-result.go
Normal file
327
pkg/database/graph-result.go
Normal file
@@ -0,0 +1,327 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
// GraphResult contains depth-organized traversal results for graph queries.
|
||||
// It tracks pubkeys and events discovered at each depth level, ensuring
|
||||
// each entity appears only at the depth where it was first discovered.
|
||||
type GraphResult struct {
|
||||
// PubkeysByDepth maps depth -> pubkeys first discovered at that depth.
|
||||
// Each pubkey appears ONLY in the array for the depth where it was first seen.
|
||||
// Depth 1 = direct connections, Depth 2 = connections of connections, etc.
|
||||
PubkeysByDepth map[int][]string
|
||||
|
||||
// EventsByDepth maps depth -> event IDs discovered at that depth.
|
||||
// Used for thread traversal queries.
|
||||
EventsByDepth map[int][]string
|
||||
|
||||
// FirstSeenPubkey tracks which depth each pubkey was first discovered.
|
||||
// Key is pubkey hex, value is the depth (1-indexed).
|
||||
FirstSeenPubkey map[string]int
|
||||
|
||||
// FirstSeenEvent tracks which depth each event was first discovered.
|
||||
// Key is event ID hex, value is the depth (1-indexed).
|
||||
FirstSeenEvent map[string]int
|
||||
|
||||
// TotalPubkeys is the count of unique pubkeys discovered across all depths.
|
||||
TotalPubkeys int
|
||||
|
||||
// TotalEvents is the count of unique events discovered across all depths.
|
||||
TotalEvents int
|
||||
|
||||
// InboundRefs tracks inbound references (events that reference discovered items).
|
||||
// Structure: kind -> target_id -> []referencing_event_ids
|
||||
InboundRefs map[uint16]map[string][]string
|
||||
|
||||
// OutboundRefs tracks outbound references (events referenced by discovered items).
|
||||
// Structure: kind -> source_id -> []referenced_event_ids
|
||||
OutboundRefs map[uint16]map[string][]string
|
||||
}
|
||||
|
||||
// NewGraphResult creates a new initialized GraphResult.
|
||||
func NewGraphResult() *GraphResult {
|
||||
return &GraphResult{
|
||||
PubkeysByDepth: make(map[int][]string),
|
||||
EventsByDepth: make(map[int][]string),
|
||||
FirstSeenPubkey: make(map[string]int),
|
||||
FirstSeenEvent: make(map[string]int),
|
||||
InboundRefs: make(map[uint16]map[string][]string),
|
||||
OutboundRefs: make(map[uint16]map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPubkeyAtDepth adds a pubkey to the result at the specified depth if not already seen.
|
||||
// Returns true if the pubkey was added (first time seen), false if already exists.
|
||||
func (r *GraphResult) AddPubkeyAtDepth(pubkeyHex string, depth int) bool {
|
||||
if _, exists := r.FirstSeenPubkey[pubkeyHex]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
r.FirstSeenPubkey[pubkeyHex] = depth
|
||||
r.PubkeysByDepth[depth] = append(r.PubkeysByDepth[depth], pubkeyHex)
|
||||
r.TotalPubkeys++
|
||||
return true
|
||||
}
|
||||
|
||||
// AddEventAtDepth adds an event ID to the result at the specified depth if not already seen.
|
||||
// Returns true if the event was added (first time seen), false if already exists.
|
||||
func (r *GraphResult) AddEventAtDepth(eventIDHex string, depth int) bool {
|
||||
if _, exists := r.FirstSeenEvent[eventIDHex]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
r.FirstSeenEvent[eventIDHex] = depth
|
||||
r.EventsByDepth[depth] = append(r.EventsByDepth[depth], eventIDHex)
|
||||
r.TotalEvents++
|
||||
return true
|
||||
}
|
||||
|
||||
// HasPubkey returns true if the pubkey has been discovered at any depth.
|
||||
func (r *GraphResult) HasPubkey(pubkeyHex string) bool {
|
||||
_, exists := r.FirstSeenPubkey[pubkeyHex]
|
||||
return exists
|
||||
}
|
||||
|
||||
// HasEvent returns true if the event has been discovered at any depth.
|
||||
func (r *GraphResult) HasEvent(eventIDHex string) bool {
|
||||
_, exists := r.FirstSeenEvent[eventIDHex]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetPubkeyDepth returns the depth at which a pubkey was first discovered.
|
||||
// Returns 0 if the pubkey was not found.
|
||||
func (r *GraphResult) GetPubkeyDepth(pubkeyHex string) int {
|
||||
return r.FirstSeenPubkey[pubkeyHex]
|
||||
}
|
||||
|
||||
// GetEventDepth returns the depth at which an event was first discovered.
|
||||
// Returns 0 if the event was not found.
|
||||
func (r *GraphResult) GetEventDepth(eventIDHex string) int {
|
||||
return r.FirstSeenEvent[eventIDHex]
|
||||
}
|
||||
|
||||
// GetDepthsSorted returns all depths that have pubkeys, sorted ascending.
|
||||
func (r *GraphResult) GetDepthsSorted() []int {
|
||||
depths := make([]int, 0, len(r.PubkeysByDepth))
|
||||
for d := range r.PubkeysByDepth {
|
||||
depths = append(depths, d)
|
||||
}
|
||||
sort.Ints(depths)
|
||||
return depths
|
||||
}
|
||||
|
||||
// GetEventDepthsSorted returns all depths that have events, sorted ascending.
|
||||
func (r *GraphResult) GetEventDepthsSorted() []int {
|
||||
depths := make([]int, 0, len(r.EventsByDepth))
|
||||
for d := range r.EventsByDepth {
|
||||
depths = append(depths, d)
|
||||
}
|
||||
sort.Ints(depths)
|
||||
return depths
|
||||
}
|
||||
|
||||
// ToDepthArrays converts the result to the response format: array of arrays.
|
||||
// Index 0 = depth 1 pubkeys, Index 1 = depth 2 pubkeys, etc.
|
||||
// Empty arrays are included for depths with no pubkeys to maintain index alignment.
|
||||
func (r *GraphResult) ToDepthArrays() [][]string {
|
||||
if len(r.PubkeysByDepth) == 0 {
|
||||
return [][]string{}
|
||||
}
|
||||
|
||||
// Find the maximum depth
|
||||
maxDepth := 0
|
||||
for d := range r.PubkeysByDepth {
|
||||
if d > maxDepth {
|
||||
maxDepth = d
|
||||
}
|
||||
}
|
||||
|
||||
// Create result array with entries for each depth
|
||||
result := make([][]string, maxDepth)
|
||||
for i := 0; i < maxDepth; i++ {
|
||||
depth := i + 1 // depths are 1-indexed
|
||||
if pubkeys, exists := r.PubkeysByDepth[depth]; exists {
|
||||
result[i] = pubkeys
|
||||
} else {
|
||||
result[i] = []string{} // Empty array for depths with no pubkeys
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToEventDepthArrays converts event results to the response format: array of arrays.
|
||||
// Index 0 = depth 1 events, Index 1 = depth 2 events, etc.
|
||||
func (r *GraphResult) ToEventDepthArrays() [][]string {
|
||||
if len(r.EventsByDepth) == 0 {
|
||||
return [][]string{}
|
||||
}
|
||||
|
||||
maxDepth := 0
|
||||
for d := range r.EventsByDepth {
|
||||
if d > maxDepth {
|
||||
maxDepth = d
|
||||
}
|
||||
}
|
||||
|
||||
result := make([][]string, maxDepth)
|
||||
for i := 0; i < maxDepth; i++ {
|
||||
depth := i + 1
|
||||
if events, exists := r.EventsByDepth[depth]; exists {
|
||||
result[i] = events
|
||||
} else {
|
||||
result[i] = []string{}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AddInboundRef records an inbound reference from a referencing event to a target.
|
||||
func (r *GraphResult) AddInboundRef(kind uint16, targetIDHex string, referencingEventIDHex string) {
|
||||
if r.InboundRefs[kind] == nil {
|
||||
r.InboundRefs[kind] = make(map[string][]string)
|
||||
}
|
||||
r.InboundRefs[kind][targetIDHex] = append(r.InboundRefs[kind][targetIDHex], referencingEventIDHex)
|
||||
}
|
||||
|
||||
// AddOutboundRef records an outbound reference from a source event to a referenced event.
|
||||
func (r *GraphResult) AddOutboundRef(kind uint16, sourceIDHex string, referencedEventIDHex string) {
|
||||
if r.OutboundRefs[kind] == nil {
|
||||
r.OutboundRefs[kind] = make(map[string][]string)
|
||||
}
|
||||
r.OutboundRefs[kind][sourceIDHex] = append(r.OutboundRefs[kind][sourceIDHex], referencedEventIDHex)
|
||||
}
|
||||
|
||||
// RefAggregation represents aggregated reference data for a single target/source.
|
||||
type RefAggregation struct {
|
||||
// TargetEventID is the event ID being referenced (for inbound) or referencing (for outbound)
|
||||
TargetEventID string
|
||||
|
||||
// TargetAuthor is the author pubkey of the target event (if known)
|
||||
TargetAuthor string
|
||||
|
||||
// TargetDepth is the depth at which this target was discovered in the graph
|
||||
TargetDepth int
|
||||
|
||||
// RefKind is the kind of the referencing events
|
||||
RefKind uint16
|
||||
|
||||
// RefCount is the number of references to/from this target
|
||||
RefCount int
|
||||
|
||||
// RefEventIDs is the list of event IDs that reference this target
|
||||
RefEventIDs []string
|
||||
}
|
||||
|
||||
// GetInboundRefsSorted returns inbound refs for a kind, sorted by count descending.
|
||||
func (r *GraphResult) GetInboundRefsSorted(kind uint16) []RefAggregation {
|
||||
kindRefs := r.InboundRefs[kind]
|
||||
if kindRefs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
aggs := make([]RefAggregation, 0, len(kindRefs))
|
||||
for targetID, refs := range kindRefs {
|
||||
agg := RefAggregation{
|
||||
TargetEventID: targetID,
|
||||
TargetDepth: r.GetEventDepth(targetID),
|
||||
RefKind: kind,
|
||||
RefCount: len(refs),
|
||||
RefEventIDs: refs,
|
||||
}
|
||||
aggs = append(aggs, agg)
|
||||
}
|
||||
|
||||
// Sort by count descending
|
||||
sort.Slice(aggs, func(i, j int) bool {
|
||||
return aggs[i].RefCount > aggs[j].RefCount
|
||||
})
|
||||
|
||||
return aggs
|
||||
}
|
||||
|
||||
// GetOutboundRefsSorted returns outbound refs for a kind, sorted by count descending.
|
||||
func (r *GraphResult) GetOutboundRefsSorted(kind uint16) []RefAggregation {
|
||||
kindRefs := r.OutboundRefs[kind]
|
||||
if kindRefs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
aggs := make([]RefAggregation, 0, len(kindRefs))
|
||||
for sourceID, refs := range kindRefs {
|
||||
agg := RefAggregation{
|
||||
TargetEventID: sourceID,
|
||||
TargetDepth: r.GetEventDepth(sourceID),
|
||||
RefKind: kind,
|
||||
RefCount: len(refs),
|
||||
RefEventIDs: refs,
|
||||
}
|
||||
aggs = append(aggs, agg)
|
||||
}
|
||||
|
||||
sort.Slice(aggs, func(i, j int) bool {
|
||||
return aggs[i].RefCount > aggs[j].RefCount
|
||||
})
|
||||
|
||||
return aggs
|
||||
}
|
||||
|
||||
// GetAllPubkeys returns all pubkeys discovered across all depths.
|
||||
func (r *GraphResult) GetAllPubkeys() []string {
|
||||
all := make([]string, 0, r.TotalPubkeys)
|
||||
for _, pubkeys := range r.PubkeysByDepth {
|
||||
all = append(all, pubkeys...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
// GetAllEvents returns all event IDs discovered across all depths.
|
||||
func (r *GraphResult) GetAllEvents() []string {
|
||||
all := make([]string, 0, r.TotalEvents)
|
||||
for _, events := range r.EventsByDepth {
|
||||
all = append(all, events...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
// GetPubkeysAtDepth returns pubkeys at a specific depth, or empty slice if none.
|
||||
func (r *GraphResult) GetPubkeysAtDepth(depth int) []string {
|
||||
if pubkeys, exists := r.PubkeysByDepth[depth]; exists {
|
||||
return pubkeys
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// GetEventsAtDepth returns events at a specific depth, or empty slice if none.
|
||||
func (r *GraphResult) GetEventsAtDepth(depth int) []string {
|
||||
if events, exists := r.EventsByDepth[depth]; exists {
|
||||
return events
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Interface methods for external package access (e.g., pkg/protocol/graph)
|
||||
// These allow the graph executor to extract data without direct struct access.
|
||||
|
||||
// GetPubkeysByDepth returns the PubkeysByDepth map for external access.
|
||||
func (r *GraphResult) GetPubkeysByDepth() map[int][]string {
|
||||
return r.PubkeysByDepth
|
||||
}
|
||||
|
||||
// GetEventsByDepth returns the EventsByDepth map for external access.
|
||||
func (r *GraphResult) GetEventsByDepth() map[int][]string {
|
||||
return r.EventsByDepth
|
||||
}
|
||||
|
||||
// GetTotalPubkeys returns the total pubkey count for external access.
|
||||
func (r *GraphResult) GetTotalPubkeys() int {
|
||||
return r.TotalPubkeys
|
||||
}
|
||||
|
||||
// GetTotalEvents returns the total event count for external access.
|
||||
func (r *GraphResult) GetTotalEvents() int {
|
||||
return r.TotalEvents
|
||||
}
|
||||
191
pkg/database/graph-thread.go
Normal file
191
pkg/database/graph-thread.go
Normal file
@@ -0,0 +1,191 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"lol.mleku.dev/log"
|
||||
"next.orly.dev/pkg/database/indexes/types"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
)
|
||||
|
||||
// TraverseThread performs BFS traversal of thread structure via e-tags.
|
||||
// Starting from a seed event, it finds all replies/references at each depth.
|
||||
//
|
||||
// The traversal works bidirectionally:
|
||||
// - Forward: Events that the seed references (parents, quoted posts)
|
||||
// - Backward: Events that reference the seed (replies, reactions, reposts)
|
||||
//
|
||||
// Parameters:
|
||||
// - seedEventID: The event ID to start traversal from
|
||||
// - maxDepth: Maximum depth to traverse
|
||||
// - direction: "both" (default), "inbound" (replies to seed), "outbound" (seed's references)
|
||||
func (d *D) TraverseThread(seedEventID []byte, maxDepth int, direction string) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
if len(seedEventID) != 32 {
|
||||
return result, ErrEventNotFound
|
||||
}
|
||||
|
||||
// Get seed event serial
|
||||
seedSerial, err := d.GetSerialById(seedEventID)
|
||||
if err != nil {
|
||||
log.D.F("TraverseThread: seed event not in database: %s", hex.Enc(seedEventID))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Normalize direction
|
||||
if direction == "" {
|
||||
direction = "both"
|
||||
}
|
||||
|
||||
// Track visited events
|
||||
visited := make(map[uint64]bool)
|
||||
visited[seedSerial.Get()] = true
|
||||
|
||||
// Current frontier
|
||||
currentFrontier := []*types.Uint40{seedSerial}
|
||||
|
||||
consecutiveEmptyDepths := 0
|
||||
|
||||
for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
|
||||
var nextFrontier []*types.Uint40
|
||||
newEventsAtDepth := 0
|
||||
|
||||
for _, eventSerial := range currentFrontier {
|
||||
// Get inbound references (events that reference this event)
|
||||
if direction == "both" || direction == "inbound" {
|
||||
inboundSerials, err := d.GetReferencingEvents(eventSerial, nil)
|
||||
if err != nil {
|
||||
log.D.F("TraverseThread: error getting inbound refs for serial %d: %v", eventSerial.Get(), err)
|
||||
} else {
|
||||
for _, refSerial := range inboundSerials {
|
||||
if visited[refSerial.Get()] {
|
||||
continue
|
||||
}
|
||||
visited[refSerial.Get()] = true
|
||||
|
||||
eventIDHex, err := d.GetEventIDFromSerial(refSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
result.AddEventAtDepth(eventIDHex, currentDepth)
|
||||
newEventsAtDepth++
|
||||
nextFrontier = append(nextFrontier, refSerial)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get outbound references (events this event references)
|
||||
if direction == "both" || direction == "outbound" {
|
||||
outboundSerials, err := d.GetETagsFromEventSerial(eventSerial)
|
||||
if err != nil {
|
||||
log.D.F("TraverseThread: error getting outbound refs for serial %d: %v", eventSerial.Get(), err)
|
||||
} else {
|
||||
for _, refSerial := range outboundSerials {
|
||||
if visited[refSerial.Get()] {
|
||||
continue
|
||||
}
|
||||
visited[refSerial.Get()] = true
|
||||
|
||||
eventIDHex, err := d.GetEventIDFromSerial(refSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
result.AddEventAtDepth(eventIDHex, currentDepth)
|
||||
newEventsAtDepth++
|
||||
nextFrontier = append(nextFrontier, refSerial)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.T.F("TraverseThread: depth %d found %d new events", currentDepth, newEventsAtDepth)
|
||||
|
||||
if newEventsAtDepth == 0 {
|
||||
consecutiveEmptyDepths++
|
||||
if consecutiveEmptyDepths >= 2 {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
consecutiveEmptyDepths = 0
|
||||
}
|
||||
|
||||
currentFrontier = nextFrontier
|
||||
}
|
||||
|
||||
log.D.F("TraverseThread: completed with %d total events", result.TotalEvents)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TraverseThreadFromHex is a convenience wrapper that accepts hex-encoded event ID.
|
||||
func (d *D) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) {
|
||||
seedEventID, err := hex.Dec(seedEventIDHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.TraverseThread(seedEventID, maxDepth, direction)
|
||||
}
|
||||
|
||||
// GetThreadReplies finds all direct replies to an event.
|
||||
// This is a convenience method that returns events at depth 1 with inbound direction.
|
||||
func (d *D) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
if len(eventID) != 32 {
|
||||
return result, ErrEventNotFound
|
||||
}
|
||||
|
||||
eventSerial, err := d.GetSerialById(eventID)
|
||||
if err != nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Get events that reference this event
|
||||
replySerials, err := d.GetReferencingEvents(eventSerial, kinds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, replySerial := range replySerials {
|
||||
eventIDHex, err := d.GetEventIDFromSerial(replySerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result.AddEventAtDepth(eventIDHex, 1)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetThreadParents finds events that a given event references (its parents/quotes).
|
||||
func (d *D) GetThreadParents(eventID []byte) (*GraphResult, error) {
|
||||
result := NewGraphResult()
|
||||
|
||||
if len(eventID) != 32 {
|
||||
return result, ErrEventNotFound
|
||||
}
|
||||
|
||||
eventSerial, err := d.GetSerialById(eventID)
|
||||
if err != nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Get events that this event references
|
||||
parentSerials, err := d.GetETagsFromEventSerial(eventSerial)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, parentSerial := range parentSerials {
|
||||
eventIDHex, err := d.GetEventIDFromSerial(parentSerial)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result.AddEventAtDepth(eventIDHex, 1)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
560
pkg/database/graph-traversal.go
Normal file
560
pkg/database/graph-traversal.go
Normal file
@@ -0,0 +1,560 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
"lol.mleku.dev/chk"
|
||||
"lol.mleku.dev/log"
|
||||
"next.orly.dev/pkg/database/indexes"
|
||||
"next.orly.dev/pkg/database/indexes/types"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
)
|
||||
|
||||
// Graph traversal errors
|
||||
var (
|
||||
ErrPubkeyNotFound = errors.New("pubkey not found in database")
|
||||
ErrEventNotFound = errors.New("event not found in database")
|
||||
)
|
||||
|
||||
// GetPTagsFromEventSerial extracts p-tag pubkey serials from an event by its serial.
|
||||
// This is a pure index-based operation - no event decoding required.
|
||||
// It scans the epg (event-pubkey-graph) index for p-tag edges.
|
||||
func (d *D) GetPTagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) {
|
||||
var pubkeySerials []*types.Uint40
|
||||
|
||||
// Build prefix: epg|event_serial
|
||||
prefix := new(bytes.Buffer)
|
||||
prefix.Write([]byte(indexes.EventPubkeyGraphPrefix))
|
||||
if err := eventSerial.MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Decode key: epg(3)|event_serial(5)|pubkey_serial(5)|kind(2)|direction(1)
|
||||
if len(key) != 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract direction to filter for p-tags only
|
||||
direction := key[15]
|
||||
if direction != types.EdgeDirectionPTagOut {
|
||||
continue // Skip author edges, only want p-tag edges
|
||||
}
|
||||
|
||||
// Extract pubkey serial (bytes 8-12)
|
||||
pubkeySerial := new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[8:13])
|
||||
if err := pubkeySerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
pubkeySerials = append(pubkeySerials, pubkeySerial)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return pubkeySerials, err
|
||||
}
|
||||
|
||||
// GetETagsFromEventSerial extracts e-tag event serials from an event by its serial.
|
||||
// This is a pure index-based operation - no event decoding required.
|
||||
// It scans the eeg (event-event-graph) index for outbound e-tag edges.
|
||||
func (d *D) GetETagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) {
|
||||
var targetSerials []*types.Uint40
|
||||
|
||||
// Build prefix: eeg|source_event_serial
|
||||
prefix := new(bytes.Buffer)
|
||||
prefix.Write([]byte(indexes.EventEventGraphPrefix))
|
||||
if err := eventSerial.MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Decode key: eeg(3)|source_serial(5)|target_serial(5)|kind(2)|direction(1)
|
||||
if len(key) != 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract target serial (bytes 8-12)
|
||||
targetSerial := new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[8:13])
|
||||
if err := targetSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
targetSerials = append(targetSerials, targetSerial)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return targetSerials, err
|
||||
}
|
||||
|
||||
// GetReferencingEvents finds all events that reference a target event via e-tags.
|
||||
// Optionally filters by event kinds. Uses the gee (reverse e-tag graph) index.
|
||||
func (d *D) GetReferencingEvents(targetSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) {
|
||||
var sourceSerials []*types.Uint40
|
||||
|
||||
if len(kinds) == 0 {
|
||||
// No kind filter - scan all kinds
|
||||
prefix := new(bytes.Buffer)
|
||||
prefix.Write([]byte(indexes.GraphEventEventPrefix))
|
||||
if err := targetSerial.MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Decode key: gee(3)|target_serial(5)|kind(2)|direction(1)|source_serial(5)
|
||||
if len(key) != 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract source serial (bytes 11-15)
|
||||
sourceSerial := new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[11:16])
|
||||
if err := sourceSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceSerials = append(sourceSerials, sourceSerial)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return sourceSerials, err
|
||||
}
|
||||
|
||||
// With kind filter - scan each kind's prefix
|
||||
for _, k := range kinds {
|
||||
kind := new(types.Uint16)
|
||||
kind.Set(k)
|
||||
|
||||
direction := new(types.Letter)
|
||||
direction.Set(types.EdgeDirectionETagIn)
|
||||
|
||||
prefix := new(bytes.Buffer)
|
||||
if err := indexes.GraphEventEventEnc(targetSerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Extract source serial (last 5 bytes)
|
||||
if len(key) < 5 {
|
||||
continue
|
||||
}
|
||||
sourceSerial := new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[len(key)-5:])
|
||||
if err := sourceSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceSerials = append(sourceSerials, sourceSerial)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return sourceSerials, nil
|
||||
}
|
||||
|
||||
// FindEventByAuthorAndKind finds the most recent event of a specific kind by an author.
|
||||
// This is used to find kind-3 contact lists for follow graph traversal.
|
||||
// Returns nil, nil if no matching event is found.
|
||||
func (d *D) FindEventByAuthorAndKind(authorSerial *types.Uint40, kind uint16) (*types.Uint40, error) {
|
||||
var eventSerial *types.Uint40
|
||||
|
||||
// First, get the full pubkey from the serial
|
||||
pubkey, err := d.GetPubkeyBySerial(authorSerial)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build prefix for kind-pubkey index: kpc|kind|pubkey_hash
|
||||
pubHash := new(types.PubHash)
|
||||
if err := pubHash.FromPubkey(pubkey); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kindType := new(types.Uint16)
|
||||
kindType.Set(kind)
|
||||
|
||||
prefix := new(bytes.Buffer)
|
||||
prefix.Write([]byte(indexes.KindPubkeyPrefix))
|
||||
if err := kindType.MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
if err := pubHash.MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err = d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
opts.Reverse = true // Most recent first (highest created_at)
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
// Seek to end of prefix range for reverse iteration
|
||||
seekKey := make([]byte, len(searchPrefix)+8+5) // prefix + max timestamp + max serial
|
||||
copy(seekKey, searchPrefix)
|
||||
for i := len(searchPrefix); i < len(seekKey); i++ {
|
||||
seekKey[i] = 0xFF
|
||||
}
|
||||
|
||||
it.Seek(seekKey)
|
||||
if !it.ValidForPrefix(searchPrefix) {
|
||||
// Try going to the first valid key if seek went past
|
||||
it.Rewind()
|
||||
it.Seek(searchPrefix)
|
||||
}
|
||||
|
||||
if it.ValidForPrefix(searchPrefix) {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Decode key: kpc(3)|kind(2)|pubkey_hash(8)|created_at(8)|serial(5)
|
||||
// Total: 26 bytes
|
||||
if len(key) < 26 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract serial (last 5 bytes)
|
||||
eventSerial = new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[len(key)-5:])
|
||||
if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return eventSerial, err
|
||||
}
|
||||
|
||||
// GetPubkeyHexFromSerial converts a pubkey serial to its hex string representation.
|
||||
func (d *D) GetPubkeyHexFromSerial(serial *types.Uint40) (string, error) {
|
||||
pubkey, err := d.GetPubkeyBySerial(serial)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.Enc(pubkey), nil
|
||||
}
|
||||
|
||||
// GetEventIDFromSerial converts an event serial to its hex ID string.
|
||||
func (d *D) GetEventIDFromSerial(serial *types.Uint40) (string, error) {
|
||||
eventID, err := d.GetEventIdBySerial(serial)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.Enc(eventID), nil
|
||||
}
|
||||
|
||||
// GetEventsReferencingPubkey finds all events that reference a pubkey via p-tags.
|
||||
// Uses the peg (pubkey-event-graph) index with direction filter for inbound p-tags.
|
||||
// Optionally filters by event kinds.
|
||||
func (d *D) GetEventsReferencingPubkey(pubkeySerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) {
|
||||
var eventSerials []*types.Uint40
|
||||
|
||||
if len(kinds) == 0 {
|
||||
// No kind filter - we need to scan common kinds since direction comes after kind in the key
|
||||
// Use same approach as QueryPTagGraph
|
||||
commonKinds := []uint16{1, 6, 7, 9735, 10002, 3, 4, 5, 30023}
|
||||
kinds = commonKinds
|
||||
}
|
||||
|
||||
for _, k := range kinds {
|
||||
kind := new(types.Uint16)
|
||||
kind.Set(k)
|
||||
|
||||
direction := new(types.Letter)
|
||||
direction.Set(types.EdgeDirectionPTagIn) // Inbound p-tags
|
||||
|
||||
prefix := new(bytes.Buffer)
|
||||
if err := indexes.PubkeyEventGraphEnc(pubkeySerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Key format: peg(3)|pubkey_serial(5)|kind(2)|direction(1)|event_serial(5) = 16 bytes
|
||||
if len(key) != 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract event serial (last 5 bytes)
|
||||
eventSerial := new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[11:16])
|
||||
if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
eventSerials = append(eventSerials, eventSerial)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return eventSerials, nil
|
||||
}
|
||||
|
||||
// GetEventsByAuthor finds all events authored by a pubkey.
|
||||
// Uses the peg (pubkey-event-graph) index with direction filter for author edges.
|
||||
// Optionally filters by event kinds.
|
||||
func (d *D) GetEventsByAuthor(authorSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) {
|
||||
var eventSerials []*types.Uint40
|
||||
|
||||
if len(kinds) == 0 {
|
||||
// No kind filter - scan for author direction across common kinds
|
||||
// This is less efficient but necessary without kind filter
|
||||
commonKinds := []uint16{0, 1, 3, 6, 7, 30023, 10002}
|
||||
kinds = commonKinds
|
||||
}
|
||||
|
||||
for _, k := range kinds {
|
||||
kind := new(types.Uint16)
|
||||
kind.Set(k)
|
||||
|
||||
direction := new(types.Letter)
|
||||
direction.Set(types.EdgeDirectionAuthor) // Author edges
|
||||
|
||||
prefix := new(bytes.Buffer)
|
||||
if err := indexes.PubkeyEventGraphEnc(authorSerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Key format: peg(3)|pubkey_serial(5)|kind(2)|direction(1)|event_serial(5) = 16 bytes
|
||||
if len(key) != 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract event serial (last 5 bytes)
|
||||
eventSerial := new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[11:16])
|
||||
if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
eventSerials = append(eventSerials, eventSerial)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return eventSerials, nil
|
||||
}
|
||||
|
||||
// GetFollowsFromPubkeySerial returns the pubkey serials that a user follows.
|
||||
// This extracts p-tags from the user's kind-3 contact list event.
|
||||
// Returns an empty slice if no kind-3 event is found.
|
||||
func (d *D) GetFollowsFromPubkeySerial(pubkeySerial *types.Uint40) ([]*types.Uint40, error) {
|
||||
// Find the kind-3 event for this pubkey
|
||||
contactEventSerial, err := d.FindEventByAuthorAndKind(pubkeySerial, 3)
|
||||
if err != nil {
|
||||
log.D.F("GetFollowsFromPubkeySerial: error finding kind-3 for serial %d: %v", pubkeySerial.Get(), err)
|
||||
return nil, nil // No kind-3 event found is not an error
|
||||
}
|
||||
if contactEventSerial == nil {
|
||||
log.T.F("GetFollowsFromPubkeySerial: no kind-3 event found for serial %d", pubkeySerial.Get())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Extract p-tags from the contact list event
|
||||
follows, err := d.GetPTagsFromEventSerial(contactEventSerial)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.T.F("GetFollowsFromPubkeySerial: found %d follows for serial %d", len(follows), pubkeySerial.Get())
|
||||
return follows, nil
|
||||
}
|
||||
|
||||
// GetFollowersOfPubkeySerial returns the pubkey serials of users who follow a given pubkey.
|
||||
// This finds all kind-3 events that have a p-tag referencing the target pubkey.
|
||||
func (d *D) GetFollowersOfPubkeySerial(targetSerial *types.Uint40) ([]*types.Uint40, error) {
|
||||
// Find all kind-3 events that reference this pubkey via p-tag
|
||||
kind3Events, err := d.GetEventsReferencingPubkey(targetSerial, []uint16{3})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract the author serials from these events
|
||||
var followerSerials []*types.Uint40
|
||||
seen := make(map[uint64]bool)
|
||||
|
||||
for _, eventSerial := range kind3Events {
|
||||
// Get the author of this kind-3 event
|
||||
// We need to look up the event to get its author
|
||||
// Use the epg index to find the author edge
|
||||
authorSerial, err := d.GetEventAuthorSerial(eventSerial)
|
||||
if err != nil {
|
||||
log.D.F("GetFollowersOfPubkeySerial: couldn't get author for event %d: %v", eventSerial.Get(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Deduplicate (a user might have multiple kind-3 events)
|
||||
if seen[authorSerial.Get()] {
|
||||
continue
|
||||
}
|
||||
seen[authorSerial.Get()] = true
|
||||
followerSerials = append(followerSerials, authorSerial)
|
||||
}
|
||||
|
||||
log.T.F("GetFollowersOfPubkeySerial: found %d followers for serial %d", len(followerSerials), targetSerial.Get())
|
||||
return followerSerials, nil
|
||||
}
|
||||
|
||||
// GetEventAuthorSerial finds the author pubkey serial for an event.
|
||||
// Uses the epg (event-pubkey-graph) index with author direction.
|
||||
func (d *D) GetEventAuthorSerial(eventSerial *types.Uint40) (*types.Uint40, error) {
|
||||
var authorSerial *types.Uint40
|
||||
|
||||
// Build prefix: epg|event_serial
|
||||
prefix := new(bytes.Buffer)
|
||||
prefix.Write([]byte(indexes.EventPubkeyGraphPrefix))
|
||||
if err := eventSerial.MarshalWrite(prefix); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
searchPrefix := prefix.Bytes()
|
||||
|
||||
err := d.View(func(txn *badger.Txn) error {
|
||||
opts := badger.DefaultIteratorOptions
|
||||
opts.PrefetchValues = false
|
||||
opts.Prefix = searchPrefix
|
||||
|
||||
it := txn.NewIterator(opts)
|
||||
defer it.Close()
|
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() {
|
||||
key := it.Item().KeyCopy(nil)
|
||||
|
||||
// Decode key: epg(3)|event_serial(5)|pubkey_serial(5)|kind(2)|direction(1)
|
||||
if len(key) != 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check direction - we want author (0)
|
||||
direction := key[15]
|
||||
if direction != types.EdgeDirectionAuthor {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract pubkey serial (bytes 8-12)
|
||||
authorSerial = new(types.Uint40)
|
||||
serialReader := bytes.NewReader(key[8:13])
|
||||
if err := authorSerial.UnmarshalRead(serialReader); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil // Found the author
|
||||
}
|
||||
return ErrEventNotFound
|
||||
})
|
||||
|
||||
return authorSerial, err
|
||||
}
|
||||
|
||||
// PubkeyHexToSerial converts a pubkey hex string to its serial, if it exists.
|
||||
// Returns an error if the pubkey is not in the database.
|
||||
func (d *D) PubkeyHexToSerial(pubkeyHex string) (*types.Uint40, error) {
|
||||
pubkeyBytes, err := hex.Dec(pubkeyHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(pubkeyBytes) != 32 {
|
||||
return nil, errors.New("invalid pubkey length")
|
||||
}
|
||||
return d.GetPubkeySerial(pubkeyBytes)
|
||||
}
|
||||
|
||||
// EventIDHexToSerial converts an event ID hex string to its serial, if it exists.
|
||||
// Returns an error if the event is not in the database.
|
||||
func (d *D) EventIDHexToSerial(eventIDHex string) (*types.Uint40, error) {
|
||||
eventIDBytes, err := hex.Dec(eventIDHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(eventIDBytes) != 32 {
|
||||
return nil, errors.New("invalid event ID length")
|
||||
}
|
||||
return d.GetSerialById(eventIDBytes)
|
||||
}
|
||||
547
pkg/database/graph-traversal_test.go
Normal file
547
pkg/database/graph-traversal_test.go
Normal file
@@ -0,0 +1,547 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
"git.mleku.dev/mleku/nostr/encoders/tag"
|
||||
)
|
||||
|
||||
func TestGetPTagsFromEventSerial(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create an author pubkey
|
||||
authorPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
|
||||
// Create p-tag target pubkeys
|
||||
target1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
target2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
|
||||
// Create event with p-tags
|
||||
eventID := make([]byte, 32)
|
||||
eventID[0] = 0x10
|
||||
eventSig := make([]byte, 64)
|
||||
eventSig[0] = 0x10
|
||||
|
||||
ev := &event.E{
|
||||
ID: eventID,
|
||||
Pubkey: authorPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("Test event with p-tags"),
|
||||
Sig: eventSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(target1)),
|
||||
tag.NewFromAny("p", hex.Enc(target2)),
|
||||
),
|
||||
}
|
||||
|
||||
_, err = db.SaveEvent(ctx, ev)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save event: %v", err)
|
||||
}
|
||||
|
||||
// Get the event serial
|
||||
eventSerial, err := db.GetSerialById(eventID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get event serial: %v", err)
|
||||
}
|
||||
|
||||
// Get p-tags from event serial
|
||||
ptagSerials, err := db.GetPTagsFromEventSerial(eventSerial)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPTagsFromEventSerial failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have 2 p-tags
|
||||
if len(ptagSerials) != 2 {
|
||||
t.Errorf("Expected 2 p-tag serials, got %d", len(ptagSerials))
|
||||
}
|
||||
|
||||
// Verify the pubkeys
|
||||
for _, serial := range ptagSerials {
|
||||
pubkey, err := db.GetPubkeyBySerial(serial)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get pubkey for serial: %v", err)
|
||||
continue
|
||||
}
|
||||
pubkeyHex := hex.Enc(pubkey)
|
||||
if pubkeyHex != hex.Enc(target1) && pubkeyHex != hex.Enc(target2) {
|
||||
t.Errorf("Unexpected pubkey: %s", pubkeyHex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetETagsFromEventSerial(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a parent event
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
parentID := make([]byte, 32)
|
||||
parentID[0] = 0x10
|
||||
parentSig := make([]byte, 64)
|
||||
parentSig[0] = 0x10
|
||||
|
||||
parentEvent := &event.E{
|
||||
ID: parentID,
|
||||
Pubkey: parentPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("Parent post"),
|
||||
Sig: parentSig,
|
||||
Tags: &tag.S{},
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, parentEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save parent event: %v", err)
|
||||
}
|
||||
|
||||
// Create a reply event with e-tag
|
||||
replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
replyID := make([]byte, 32)
|
||||
replyID[0] = 0x20
|
||||
replySig := make([]byte, 64)
|
||||
replySig[0] = 0x20
|
||||
|
||||
replyEvent := &event.E{
|
||||
ID: replyID,
|
||||
Pubkey: replyPubkey,
|
||||
CreatedAt: 1234567891,
|
||||
Kind: 1,
|
||||
Content: []byte("Reply"),
|
||||
Sig: replySig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(parentID)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, replyEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save reply event: %v", err)
|
||||
}
|
||||
|
||||
// Get e-tags from reply
|
||||
replySerial, _ := db.GetSerialById(replyID)
|
||||
etagSerials, err := db.GetETagsFromEventSerial(replySerial)
|
||||
if err != nil {
|
||||
t.Fatalf("GetETagsFromEventSerial failed: %v", err)
|
||||
}
|
||||
|
||||
if len(etagSerials) != 1 {
|
||||
t.Errorf("Expected 1 e-tag serial, got %d", len(etagSerials))
|
||||
}
|
||||
|
||||
// Verify the target event
|
||||
if len(etagSerials) > 0 {
|
||||
targetEventID, err := db.GetEventIdBySerial(etagSerials[0])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get event ID from serial: %v", err)
|
||||
}
|
||||
if hex.Enc(targetEventID) != hex.Enc(parentID) {
|
||||
t.Errorf("Expected parent ID, got %s", hex.Enc(targetEventID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetReferencingEvents(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a parent event
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
parentID := make([]byte, 32)
|
||||
parentID[0] = 0x10
|
||||
parentSig := make([]byte, 64)
|
||||
parentSig[0] = 0x10
|
||||
|
||||
parentEvent := &event.E{
|
||||
ID: parentID,
|
||||
Pubkey: parentPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("Parent post"),
|
||||
Sig: parentSig,
|
||||
Tags: &tag.S{},
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, parentEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save parent event: %v", err)
|
||||
}
|
||||
|
||||
// Create multiple replies and reactions
|
||||
for i := 0; i < 3; i++ {
|
||||
replyPubkey := make([]byte, 32)
|
||||
replyPubkey[0] = byte(0x20 + i)
|
||||
replyID := make([]byte, 32)
|
||||
replyID[0] = byte(0x30 + i)
|
||||
replySig := make([]byte, 64)
|
||||
replySig[0] = byte(0x30 + i)
|
||||
|
||||
var evKind uint16 = 1 // Reply
|
||||
if i == 2 {
|
||||
evKind = 7 // Reaction
|
||||
}
|
||||
|
||||
replyEvent := &event.E{
|
||||
ID: replyID,
|
||||
Pubkey: replyPubkey,
|
||||
CreatedAt: int64(1234567891 + i),
|
||||
Kind: evKind,
|
||||
Content: []byte("Response"),
|
||||
Sig: replySig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("e", hex.Enc(parentID)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, replyEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save reply %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get parent serial
|
||||
parentSerial, _ := db.GetSerialById(parentID)
|
||||
|
||||
// Test without kind filter
|
||||
refs, err := db.GetReferencingEvents(parentSerial, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetReferencingEvents failed: %v", err)
|
||||
}
|
||||
if len(refs) != 3 {
|
||||
t.Errorf("Expected 3 referencing events, got %d", len(refs))
|
||||
}
|
||||
|
||||
// Test with kind filter (only replies)
|
||||
refs, err = db.GetReferencingEvents(parentSerial, []uint16{1})
|
||||
if err != nil {
|
||||
t.Fatalf("GetReferencingEvents with kind filter failed: %v", err)
|
||||
}
|
||||
if len(refs) != 2 {
|
||||
t.Errorf("Expected 2 kind-1 referencing events, got %d", len(refs))
|
||||
}
|
||||
|
||||
// Test with kind filter (only reactions)
|
||||
refs, err = db.GetReferencingEvents(parentSerial, []uint16{7})
|
||||
if err != nil {
|
||||
t.Fatalf("GetReferencingEvents with kind 7 filter failed: %v", err)
|
||||
}
|
||||
if len(refs) != 1 {
|
||||
t.Errorf("Expected 1 kind-7 referencing event, got %d", len(refs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFollowsFromPubkeySerial(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create author and their follows
|
||||
authorPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
follow1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
follow2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
follow3, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000004")
|
||||
|
||||
// Create kind-3 contact list
|
||||
eventID := make([]byte, 32)
|
||||
eventID[0] = 0x10
|
||||
eventSig := make([]byte, 64)
|
||||
eventSig[0] = 0x10
|
||||
|
||||
contactList := &event.E{
|
||||
ID: eventID,
|
||||
Pubkey: authorPubkey,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: eventSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(follow1)),
|
||||
tag.NewFromAny("p", hex.Enc(follow2)),
|
||||
tag.NewFromAny("p", hex.Enc(follow3)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, contactList)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save contact list: %v", err)
|
||||
}
|
||||
|
||||
// Get author serial
|
||||
authorSerial, err := db.GetPubkeySerial(authorPubkey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get author serial: %v", err)
|
||||
}
|
||||
|
||||
// Get follows
|
||||
follows, err := db.GetFollowsFromPubkeySerial(authorSerial)
|
||||
if err != nil {
|
||||
t.Fatalf("GetFollowsFromPubkeySerial failed: %v", err)
|
||||
}
|
||||
|
||||
if len(follows) != 3 {
|
||||
t.Errorf("Expected 3 follows, got %d", len(follows))
|
||||
}
|
||||
|
||||
// Verify the follows are correct
|
||||
expectedFollows := map[string]bool{
|
||||
hex.Enc(follow1): false,
|
||||
hex.Enc(follow2): false,
|
||||
hex.Enc(follow3): false,
|
||||
}
|
||||
for _, serial := range follows {
|
||||
pubkey, err := db.GetPubkeyBySerial(serial)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get pubkey from serial: %v", err)
|
||||
continue
|
||||
}
|
||||
pkHex := hex.Enc(pubkey)
|
||||
if _, exists := expectedFollows[pkHex]; exists {
|
||||
expectedFollows[pkHex] = true
|
||||
} else {
|
||||
t.Errorf("Unexpected follow: %s", pkHex)
|
||||
}
|
||||
}
|
||||
for pk, found := range expectedFollows {
|
||||
if !found {
|
||||
t.Errorf("Expected follow not found: %s", pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphResult(t *testing.T) {
|
||||
result := NewGraphResult()
|
||||
|
||||
// Add pubkeys at different depths
|
||||
result.AddPubkeyAtDepth("pubkey1", 1)
|
||||
result.AddPubkeyAtDepth("pubkey2", 1)
|
||||
result.AddPubkeyAtDepth("pubkey3", 2)
|
||||
result.AddPubkeyAtDepth("pubkey4", 2)
|
||||
result.AddPubkeyAtDepth("pubkey5", 3)
|
||||
|
||||
// Try to add duplicate
|
||||
added := result.AddPubkeyAtDepth("pubkey1", 2)
|
||||
if added {
|
||||
t.Error("Should not add duplicate pubkey")
|
||||
}
|
||||
|
||||
// Verify counts
|
||||
if result.TotalPubkeys != 5 {
|
||||
t.Errorf("Expected 5 total pubkeys, got %d", result.TotalPubkeys)
|
||||
}
|
||||
|
||||
// Verify depth tracking
|
||||
if result.GetPubkeyDepth("pubkey1") != 1 {
|
||||
t.Errorf("pubkey1 should be at depth 1")
|
||||
}
|
||||
if result.GetPubkeyDepth("pubkey3") != 2 {
|
||||
t.Errorf("pubkey3 should be at depth 2")
|
||||
}
|
||||
|
||||
// Verify HasPubkey
|
||||
if !result.HasPubkey("pubkey1") {
|
||||
t.Error("Should have pubkey1")
|
||||
}
|
||||
if result.HasPubkey("nonexistent") {
|
||||
t.Error("Should not have nonexistent pubkey")
|
||||
}
|
||||
|
||||
// Verify ToDepthArrays
|
||||
arrays := result.ToDepthArrays()
|
||||
if len(arrays) != 3 {
|
||||
t.Errorf("Expected 3 depth arrays, got %d", len(arrays))
|
||||
}
|
||||
if len(arrays[0]) != 2 {
|
||||
t.Errorf("Expected 2 pubkeys at depth 1, got %d", len(arrays[0]))
|
||||
}
|
||||
if len(arrays[1]) != 2 {
|
||||
t.Errorf("Expected 2 pubkeys at depth 2, got %d", len(arrays[1]))
|
||||
}
|
||||
if len(arrays[2]) != 1 {
|
||||
t.Errorf("Expected 1 pubkey at depth 3, got %d", len(arrays[2]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphResultRefs(t *testing.T) {
|
||||
result := NewGraphResult()
|
||||
|
||||
// Add some pubkeys
|
||||
result.AddPubkeyAtDepth("pubkey1", 1)
|
||||
result.AddEventAtDepth("event1", 1)
|
||||
|
||||
// Add inbound refs (kind 7 reactions)
|
||||
result.AddInboundRef(7, "event1", "reaction1")
|
||||
result.AddInboundRef(7, "event1", "reaction2")
|
||||
result.AddInboundRef(7, "event1", "reaction3")
|
||||
|
||||
// Get sorted refs
|
||||
refs := result.GetInboundRefsSorted(7)
|
||||
if len(refs) != 1 {
|
||||
t.Fatalf("Expected 1 aggregation, got %d", len(refs))
|
||||
}
|
||||
if refs[0].RefCount != 3 {
|
||||
t.Errorf("Expected 3 refs, got %d", refs[0].RefCount)
|
||||
}
|
||||
if refs[0].TargetEventID != "event1" {
|
||||
t.Errorf("Expected event1, got %s", refs[0].TargetEventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFollowersOfPubkeySerial(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create target pubkey (the one being followed)
|
||||
targetPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
|
||||
// Create followers
|
||||
follower1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002")
|
||||
follower2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003")
|
||||
|
||||
// Create kind-3 contact lists for followers
|
||||
for i, followerPubkey := range [][]byte{follower1, follower2} {
|
||||
eventID := make([]byte, 32)
|
||||
eventID[0] = byte(0x10 + i)
|
||||
eventSig := make([]byte, 64)
|
||||
eventSig[0] = byte(0x10 + i)
|
||||
|
||||
contactList := &event.E{
|
||||
ID: eventID,
|
||||
Pubkey: followerPubkey,
|
||||
CreatedAt: int64(1234567890 + i),
|
||||
Kind: 3,
|
||||
Content: []byte(""),
|
||||
Sig: eventSig,
|
||||
Tags: tag.NewS(
|
||||
tag.NewFromAny("p", hex.Enc(targetPubkey)),
|
||||
),
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, contactList)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save contact list %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get target serial
|
||||
targetSerial, err := db.GetPubkeySerial(targetPubkey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get target serial: %v", err)
|
||||
}
|
||||
|
||||
// Get followers
|
||||
followers, err := db.GetFollowersOfPubkeySerial(targetSerial)
|
||||
if err != nil {
|
||||
t.Fatalf("GetFollowersOfPubkeySerial failed: %v", err)
|
||||
}
|
||||
|
||||
if len(followers) != 2 {
|
||||
t.Errorf("Expected 2 followers, got %d", len(followers))
|
||||
}
|
||||
|
||||
// Verify the followers
|
||||
expectedFollowers := map[string]bool{
|
||||
hex.Enc(follower1): false,
|
||||
hex.Enc(follower2): false,
|
||||
}
|
||||
for _, serial := range followers {
|
||||
pubkey, err := db.GetPubkeyBySerial(serial)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get pubkey from serial: %v", err)
|
||||
continue
|
||||
}
|
||||
pkHex := hex.Enc(pubkey)
|
||||
if _, exists := expectedFollowers[pkHex]; exists {
|
||||
expectedFollowers[pkHex] = true
|
||||
} else {
|
||||
t.Errorf("Unexpected follower: %s", pkHex)
|
||||
}
|
||||
}
|
||||
for pk, found := range expectedFollowers {
|
||||
if !found {
|
||||
t.Errorf("Expected follower not found: %s", pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubkeyHexToSerial(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a pubkey by saving an event
|
||||
pubkeyBytes, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001")
|
||||
eventID := make([]byte, 32)
|
||||
eventID[0] = 0x10
|
||||
eventSig := make([]byte, 64)
|
||||
eventSig[0] = 0x10
|
||||
|
||||
ev := &event.E{
|
||||
ID: eventID,
|
||||
Pubkey: pubkeyBytes,
|
||||
CreatedAt: 1234567890,
|
||||
Kind: 1,
|
||||
Content: []byte("Test"),
|
||||
Sig: eventSig,
|
||||
Tags: &tag.S{},
|
||||
}
|
||||
_, err = db.SaveEvent(ctx, ev)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save event: %v", err)
|
||||
}
|
||||
|
||||
// Convert hex to serial
|
||||
pubkeyHex := hex.Enc(pubkeyBytes)
|
||||
serial, err := db.PubkeyHexToSerial(pubkeyHex)
|
||||
if err != nil {
|
||||
t.Fatalf("PubkeyHexToSerial failed: %v", err)
|
||||
}
|
||||
if serial == nil {
|
||||
t.Fatal("Expected non-nil serial")
|
||||
}
|
||||
|
||||
// Convert back and verify
|
||||
backToHex, err := db.GetPubkeyHexFromSerial(serial)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPubkeyHexFromSerial failed: %v", err)
|
||||
}
|
||||
if backToHex != pubkeyHex {
|
||||
t.Errorf("Round-trip failed: %s != %s", backToHex, pubkeyHex)
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,10 @@ const (
|
||||
// Compact event storage indexes
|
||||
SerialEventIdPrefix = I("sei") // event serial -> full 32-byte event ID
|
||||
CompactEventPrefix = I("cmp") // compact event storage with serial references
|
||||
|
||||
// Event-to-event graph indexes (for e-tag references)
|
||||
EventEventGraphPrefix = I("eeg") // source event serial -> target event serial (outbound e-tags)
|
||||
GraphEventEventPrefix = I("gee") // target event serial -> source event serial (reverse e-tags)
|
||||
)
|
||||
|
||||
// Prefix returns the three byte human-readable prefixes that go in front of
|
||||
@@ -142,6 +146,11 @@ func Prefix(prf int) (i I) {
|
||||
return SerialEventIdPrefix
|
||||
case CompactEvent:
|
||||
return CompactEventPrefix
|
||||
|
||||
case EventEventGraph:
|
||||
return EventEventGraphPrefix
|
||||
case GraphEventEvent:
|
||||
return GraphEventEventPrefix
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -205,6 +214,11 @@ func Identify(r io.Reader) (i int, err error) {
|
||||
i = SerialEventId
|
||||
case CompactEventPrefix:
|
||||
i = CompactEvent
|
||||
|
||||
case EventEventGraphPrefix:
|
||||
i = EventEventGraph
|
||||
case GraphEventEventPrefix:
|
||||
i = GraphEventEvent
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -655,3 +669,38 @@ func CompactEventEnc(ser *types.Uint40) (enc *T) {
|
||||
return New(NewPrefix(CompactEvent), ser)
|
||||
}
|
||||
func CompactEventDec(ser *types.Uint40) (enc *T) { return New(NewPrefix(), ser) }
|
||||
|
||||
// EventEventGraph creates a bidirectional graph edge between events via e-tags.
|
||||
// This stores source_event_serial -> target_event_serial relationships with event kind and direction.
|
||||
// Used for thread traversal and finding replies/reactions/reposts to events.
|
||||
// Direction: 0=outbound (this event references target)
|
||||
//
|
||||
// 3 prefix|5 source event serial|5 target event serial|2 kind|1 direction
|
||||
var EventEventGraph = next()
|
||||
|
||||
func EventEventGraphVars() (srcSer *types.Uint40, tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter) {
|
||||
return new(types.Uint40), new(types.Uint40), new(types.Uint16), new(types.Letter)
|
||||
}
|
||||
func EventEventGraphEnc(srcSer *types.Uint40, tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter) (enc *T) {
|
||||
return New(NewPrefix(EventEventGraph), srcSer, tgtSer, kind, direction)
|
||||
}
|
||||
func EventEventGraphDec(srcSer *types.Uint40, tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter) (enc *T) {
|
||||
return New(NewPrefix(), srcSer, tgtSer, kind, direction)
|
||||
}
|
||||
|
||||
// GraphEventEvent creates the reverse edge: target_event_serial -> source_event_serial with kind and direction.
|
||||
// This enables querying all events that reference a target event (e.g., all replies to a post).
|
||||
// Direction: 1=inbound (target is referenced by source)
|
||||
//
|
||||
// 3 prefix|5 target event serial|2 kind|1 direction|5 source event serial
|
||||
var GraphEventEvent = next()
|
||||
|
||||
func GraphEventEventVars() (tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter, srcSer *types.Uint40) {
|
||||
return new(types.Uint40), new(types.Uint16), new(types.Letter), new(types.Uint40)
|
||||
}
|
||||
func GraphEventEventEnc(tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter, srcSer *types.Uint40) (enc *T) {
|
||||
return New(NewPrefix(GraphEventEvent), tgtSer, kind, direction, srcSer)
|
||||
}
|
||||
func GraphEventEventDec(tgtSer *types.Uint40, kind *types.Uint16, direction *types.Letter, srcSer *types.Uint40) (enc *T) {
|
||||
return New(NewPrefix(), tgtSer, kind, direction, srcSer)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ const (
|
||||
EdgeDirectionPTagIn byte = 2 // Inbound: This pubkey is referenced in event's p-tag
|
||||
)
|
||||
|
||||
// Edge direction constants for event-to-event (e-tag) graph relationships
|
||||
const (
|
||||
EdgeDirectionETagOut byte = 0 // Outbound: This event references target event via e-tag
|
||||
EdgeDirectionETagIn byte = 1 // Inbound: This event is referenced by source event via e-tag
|
||||
)
|
||||
|
||||
type Letter struct {
|
||||
val byte
|
||||
}
|
||||
|
||||
@@ -349,24 +349,74 @@ func (d *D) SaveEvent(c context.Context, ev *event.E) (
|
||||
}
|
||||
|
||||
// Create event -> pubkey edge (with kind and direction)
|
||||
keyBuf := new(bytes.Buffer)
|
||||
if err = indexes.EventPubkeyGraphEnc(ser, pkInfo.serial, eventKind, directionForward).MarshalWrite(keyBuf); chk.E(err) {
|
||||
epgKeyBuf := new(bytes.Buffer)
|
||||
if err = indexes.EventPubkeyGraphEnc(ser, pkInfo.serial, eventKind, directionForward).MarshalWrite(epgKeyBuf); chk.E(err) {
|
||||
return
|
||||
}
|
||||
if err = txn.Set(keyBuf.Bytes(), nil); chk.E(err) {
|
||||
// Make a copy of the key bytes to avoid buffer reuse issues in txn
|
||||
epgKey := make([]byte, epgKeyBuf.Len())
|
||||
copy(epgKey, epgKeyBuf.Bytes())
|
||||
if err = txn.Set(epgKey, nil); chk.E(err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create pubkey -> event edge (reverse, with kind and direction for filtering)
|
||||
keyBuf.Reset()
|
||||
if err = indexes.PubkeyEventGraphEnc(pkInfo.serial, eventKind, directionReverse, ser).MarshalWrite(keyBuf); chk.E(err) {
|
||||
pegKeyBuf := new(bytes.Buffer)
|
||||
if err = indexes.PubkeyEventGraphEnc(pkInfo.serial, eventKind, directionReverse, ser).MarshalWrite(pegKeyBuf); chk.E(err) {
|
||||
return
|
||||
}
|
||||
if err = txn.Set(keyBuf.Bytes(), nil); chk.E(err) {
|
||||
if err = txn.Set(pegKeyBuf.Bytes(), nil); chk.E(err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create event-to-event graph edges for e-tags
|
||||
// This enables thread traversal and finding replies/reactions to events
|
||||
eTags := ev.Tags.GetAll([]byte("e"))
|
||||
for _, eTag := range eTags {
|
||||
if eTag.Len() >= 2 {
|
||||
// Get event ID from e-tag, handling both binary and hex storage formats
|
||||
var targetEventID []byte
|
||||
if targetEventID, err = hex.Dec(string(eTag.ValueHex())); err != nil || len(targetEventID) != 32 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Look up the target event's serial (if it exists in our database)
|
||||
var targetSerial *types.Uint40
|
||||
if targetSerial, err = d.GetSerialById(targetEventID); err != nil {
|
||||
// Target event not in our database - skip edge creation
|
||||
// This is normal for replies to events we don't have
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// Create forward edge: source event -> target event (outbound e-tag)
|
||||
directionOut := new(types.Letter)
|
||||
directionOut.Set(types.EdgeDirectionETagOut)
|
||||
eegKeyBuf := new(bytes.Buffer)
|
||||
if err = indexes.EventEventGraphEnc(ser, targetSerial, eventKind, directionOut).MarshalWrite(eegKeyBuf); chk.E(err) {
|
||||
return
|
||||
}
|
||||
// Make a copy of the key bytes to avoid buffer reuse issues in txn
|
||||
eegKey := make([]byte, eegKeyBuf.Len())
|
||||
copy(eegKey, eegKeyBuf.Bytes())
|
||||
if err = txn.Set(eegKey, nil); chk.E(err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create reverse edge: target event -> source event (inbound e-tag)
|
||||
directionIn := new(types.Letter)
|
||||
directionIn.Set(types.EdgeDirectionETagIn)
|
||||
geeKeyBuf := new(bytes.Buffer)
|
||||
if err = indexes.GraphEventEventEnc(targetSerial, eventKind, directionIn, ser).MarshalWrite(geeKeyBuf); chk.E(err) {
|
||||
return
|
||||
}
|
||||
if err = txn.Set(geeKeyBuf.Bytes(), nil); chk.E(err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
},
|
||||
)
|
||||
|
||||
202
pkg/protocol/graph/executor.go
Normal file
202
pkg/protocol/graph/executor.go
Normal file
@@ -0,0 +1,202 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
// Package graph implements NIP-XX Graph Query protocol support.
|
||||
// This file contains the executor that runs graph traversal queries.
|
||||
package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"lol.mleku.dev/chk"
|
||||
"lol.mleku.dev/log"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
"git.mleku.dev/mleku/nostr/encoders/tag"
|
||||
"git.mleku.dev/mleku/nostr/interfaces/signer"
|
||||
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k"
|
||||
)
|
||||
|
||||
// Response kinds for graph queries (ephemeral range, relay-signed)
|
||||
const (
|
||||
KindGraphFollows = 39000 // Response for follows/followers queries
|
||||
KindGraphMentions = 39001 // Response for mentions queries
|
||||
KindGraphThread = 39002 // Response for thread traversal queries
|
||||
)
|
||||
|
||||
// GraphResultI is the interface that database.GraphResult implements.
|
||||
// This allows the executor to work with the database result without importing it.
|
||||
type GraphResultI interface {
|
||||
ToDepthArrays() [][]string
|
||||
ToEventDepthArrays() [][]string
|
||||
GetAllPubkeys() []string
|
||||
GetAllEvents() []string
|
||||
GetPubkeysByDepth() map[int][]string
|
||||
GetEventsByDepth() map[int][]string
|
||||
GetTotalPubkeys() int
|
||||
GetTotalEvents() int
|
||||
}
|
||||
|
||||
// GraphDatabase defines the interface for graph traversal operations.
|
||||
// This is implemented by the database package.
|
||||
type GraphDatabase interface {
|
||||
// TraverseFollows performs BFS traversal of follow graph
|
||||
TraverseFollows(seedPubkey []byte, maxDepth int) (GraphResultI, error)
|
||||
// TraverseFollowers performs BFS traversal to find followers
|
||||
TraverseFollowers(seedPubkey []byte, maxDepth int) (GraphResultI, error)
|
||||
// FindMentions finds events mentioning a pubkey
|
||||
FindMentions(pubkey []byte, kinds []uint16) (GraphResultI, error)
|
||||
// TraverseThread performs BFS traversal of thread structure
|
||||
TraverseThread(seedEventID []byte, maxDepth int, direction string) (GraphResultI, error)
|
||||
}
|
||||
|
||||
// Executor handles graph query execution and response generation.
|
||||
type Executor struct {
|
||||
db GraphDatabase
|
||||
relaySigner signer.I
|
||||
relayPubkey []byte
|
||||
}
|
||||
|
||||
// NewExecutor creates a new graph query executor.
|
||||
// The secretKey should be the 32-byte relay identity secret key.
|
||||
func NewExecutor(db GraphDatabase, secretKey []byte) (*Executor, error) {
|
||||
s, err := p8k.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = s.InitSec(secretKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Executor{
|
||||
db: db,
|
||||
relaySigner: s,
|
||||
relayPubkey: s.Pub(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Execute runs a graph query and returns a relay-signed event with results.
|
||||
func (e *Executor) Execute(q *Query) (*event.E, error) {
|
||||
var result GraphResultI
|
||||
var err error
|
||||
var responseKind uint16
|
||||
|
||||
// Decode seed (hex string to bytes)
|
||||
seedBytes, err := hex.Dec(q.Seed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the appropriate traversal
|
||||
switch q.Method {
|
||||
case "follows":
|
||||
responseKind = KindGraphFollows
|
||||
result, err = e.db.TraverseFollows(seedBytes, q.Depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: follows traversal returned %d pubkeys", result.GetTotalPubkeys())
|
||||
|
||||
case "followers":
|
||||
responseKind = KindGraphFollows
|
||||
result, err = e.db.TraverseFollowers(seedBytes, q.Depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: followers traversal returned %d pubkeys", result.GetTotalPubkeys())
|
||||
|
||||
case "mentions":
|
||||
responseKind = KindGraphMentions
|
||||
// Mentions don't use depth traversal, just find direct mentions
|
||||
// Convert RefSpec kinds to uint16 for the database call
|
||||
var kinds []uint16
|
||||
if len(q.InboundRefs) > 0 {
|
||||
for _, rs := range q.InboundRefs {
|
||||
for _, k := range rs.Kinds {
|
||||
kinds = append(kinds, uint16(k))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
kinds = []uint16{1} // Default to kind 1 (notes)
|
||||
}
|
||||
result, err = e.db.FindMentions(seedBytes, kinds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: mentions query returned %d events", result.GetTotalEvents())
|
||||
|
||||
case "thread":
|
||||
responseKind = KindGraphThread
|
||||
result, err = e.db.TraverseThread(seedBytes, q.Depth, "both")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: thread traversal returned %d events", result.GetTotalEvents())
|
||||
|
||||
default:
|
||||
return nil, ErrInvalidMethod
|
||||
}
|
||||
|
||||
// Generate response event
|
||||
return e.generateResponse(q, result, responseKind)
|
||||
}
|
||||
|
||||
// generateResponse creates a relay-signed event containing the query results.
|
||||
func (e *Executor) generateResponse(q *Query, result GraphResultI, responseKind uint16) (*event.E, error) {
|
||||
// Build content as JSON with depth arrays
|
||||
var content ResponseContent
|
||||
|
||||
if q.Method == "follows" || q.Method == "followers" {
|
||||
// For pubkey-based queries, use pubkeys_by_depth
|
||||
content.PubkeysByDepth = result.ToDepthArrays()
|
||||
content.TotalPubkeys = result.GetTotalPubkeys()
|
||||
} else {
|
||||
// For event-based queries, use events_by_depth
|
||||
content.EventsByDepth = result.ToEventDepthArrays()
|
||||
content.TotalEvents = result.GetTotalEvents()
|
||||
}
|
||||
|
||||
contentBytes, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build tags
|
||||
tags := tag.NewS(
|
||||
tag.NewFromAny("method", q.Method),
|
||||
tag.NewFromAny("seed", q.Seed),
|
||||
tag.NewFromAny("depth", strconv.Itoa(q.Depth)),
|
||||
)
|
||||
|
||||
// Create event
|
||||
ev := &event.E{
|
||||
Kind: responseKind,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
Tags: tags,
|
||||
Content: contentBytes,
|
||||
}
|
||||
|
||||
// Sign with relay identity
|
||||
if err = ev.Sign(e.relaySigner); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// ResponseContent is the JSON structure for graph query responses.
|
||||
type ResponseContent struct {
|
||||
// PubkeysByDepth contains arrays of pubkeys at each depth (1-indexed)
|
||||
// Each pubkey appears ONLY at the depth where it was first discovered.
|
||||
PubkeysByDepth [][]string `json:"pubkeys_by_depth,omitempty"`
|
||||
|
||||
// EventsByDepth contains arrays of event IDs at each depth (1-indexed)
|
||||
EventsByDepth [][]string `json:"events_by_depth,omitempty"`
|
||||
|
||||
// TotalPubkeys is the total count of unique pubkeys discovered
|
||||
TotalPubkeys int `json:"total_pubkeys,omitempty"`
|
||||
|
||||
// TotalEvents is the total count of unique events discovered
|
||||
TotalEvents int `json:"total_events,omitempty"`
|
||||
}
|
||||
183
pkg/protocol/graph/query.go
Normal file
183
pkg/protocol/graph/query.go
Normal file
@@ -0,0 +1,183 @@
|
||||
// Package graph implements NIP-XX Graph Query protocol support.
|
||||
// It provides types and functions for parsing and validating graph traversal queries.
|
||||
package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/filter"
|
||||
)
|
||||
|
||||
// Query represents a graph traversal query from a _graph filter extension.
|
||||
type Query struct {
|
||||
// Method is the traversal method: "follows", "followers", "mentions", "thread"
|
||||
Method string `json:"method"`
|
||||
|
||||
// Seed is the starting point for traversal (pubkey hex or event ID hex)
|
||||
Seed string `json:"seed"`
|
||||
|
||||
// Depth is the maximum traversal depth (1-16, default: 1)
|
||||
Depth int `json:"depth,omitempty"`
|
||||
|
||||
// InboundRefs specifies which inbound references to collect
|
||||
// (events that reference discovered events via e-tags)
|
||||
InboundRefs []RefSpec `json:"inbound_refs,omitempty"`
|
||||
|
||||
// OutboundRefs specifies which outbound references to collect
|
||||
// (events referenced by discovered events via e-tags)
|
||||
OutboundRefs []RefSpec `json:"outbound_refs,omitempty"`
|
||||
}
|
||||
|
||||
// RefSpec specifies which event references to include in results.
|
||||
type RefSpec struct {
|
||||
// Kinds is the list of event kinds to match (OR semantics within this spec)
|
||||
Kinds []int `json:"kinds"`
|
||||
|
||||
// FromDepth specifies the minimum depth at which to collect refs (default: 0)
|
||||
// 0 = include refs from seed itself
|
||||
// 1 = start from first-hop connections
|
||||
FromDepth int `json:"from_depth,omitempty"`
|
||||
}
|
||||
|
||||
// Validation errors
|
||||
var (
|
||||
ErrMissingMethod = errors.New("_graph.method is required")
|
||||
ErrInvalidMethod = errors.New("_graph.method must be one of: follows, followers, mentions, thread")
|
||||
ErrMissingSeed = errors.New("_graph.seed is required")
|
||||
ErrInvalidSeed = errors.New("_graph.seed must be a 64-character hex string")
|
||||
ErrDepthTooHigh = errors.New("_graph.depth cannot exceed 16")
|
||||
ErrEmptyRefSpecKinds = errors.New("ref spec kinds array cannot be empty")
|
||||
)
|
||||
|
||||
// Valid method names
|
||||
var validMethods = map[string]bool{
|
||||
"follows": true,
|
||||
"followers": true,
|
||||
"mentions": true,
|
||||
"thread": true,
|
||||
}
|
||||
|
||||
// Validate checks the query for correctness and applies defaults.
|
||||
func (q *Query) Validate() error {
|
||||
// Method is required
|
||||
if q.Method == "" {
|
||||
return ErrMissingMethod
|
||||
}
|
||||
if !validMethods[q.Method] {
|
||||
return ErrInvalidMethod
|
||||
}
|
||||
|
||||
// Seed is required
|
||||
if q.Seed == "" {
|
||||
return ErrMissingSeed
|
||||
}
|
||||
if len(q.Seed) != 64 {
|
||||
return ErrInvalidSeed
|
||||
}
|
||||
// Validate hex characters
|
||||
for _, c := range q.Seed {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return ErrInvalidSeed
|
||||
}
|
||||
}
|
||||
|
||||
// Apply depth defaults and limits
|
||||
if q.Depth < 1 {
|
||||
q.Depth = 1
|
||||
}
|
||||
if q.Depth > 16 {
|
||||
return ErrDepthTooHigh
|
||||
}
|
||||
|
||||
// Validate ref specs
|
||||
for _, rs := range q.InboundRefs {
|
||||
if len(rs.Kinds) == 0 {
|
||||
return ErrEmptyRefSpecKinds
|
||||
}
|
||||
}
|
||||
for _, rs := range q.OutboundRefs {
|
||||
if len(rs.Kinds) == 0 {
|
||||
return ErrEmptyRefSpecKinds
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasInboundRefs returns true if the query includes inbound reference collection.
|
||||
func (q *Query) HasInboundRefs() bool {
|
||||
return len(q.InboundRefs) > 0
|
||||
}
|
||||
|
||||
// HasOutboundRefs returns true if the query includes outbound reference collection.
|
||||
func (q *Query) HasOutboundRefs() bool {
|
||||
return len(q.OutboundRefs) > 0
|
||||
}
|
||||
|
||||
// HasRefs returns true if the query includes any reference collection.
|
||||
func (q *Query) HasRefs() bool {
|
||||
return q.HasInboundRefs() || q.HasOutboundRefs()
|
||||
}
|
||||
|
||||
// InboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
|
||||
// It aggregates all RefSpecs where from_depth <= depth.
|
||||
func (q *Query) InboundKindsAtDepth(depth int) map[int]bool {
|
||||
kinds := make(map[int]bool)
|
||||
for _, rs := range q.InboundRefs {
|
||||
if rs.FromDepth <= depth {
|
||||
for _, k := range rs.Kinds {
|
||||
kinds[k] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
|
||||
// OutboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
|
||||
func (q *Query) OutboundKindsAtDepth(depth int) map[int]bool {
|
||||
kinds := make(map[int]bool)
|
||||
for _, rs := range q.OutboundRefs {
|
||||
if rs.FromDepth <= depth {
|
||||
for _, k := range rs.Kinds {
|
||||
kinds[k] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
|
||||
// ExtractFromFilter checks if a filter has a _graph extension and parses it.
|
||||
// Returns nil if no _graph field is present.
|
||||
// Returns an error if _graph is present but invalid.
|
||||
func ExtractFromFilter(f *filter.F) (*Query, error) {
|
||||
if f == nil || f.Extra == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
raw, ok := f.Extra["_graph"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var q Query
|
||||
if err := json.Unmarshal(raw, &q); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := q.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &q, nil
|
||||
}
|
||||
|
||||
// IsGraphQuery returns true if the filter contains a _graph extension.
|
||||
// This is a quick check that doesn't parse the full query.
|
||||
func IsGraphQuery(f *filter.F) bool {
|
||||
if f == nil || f.Extra == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := f.Extra["_graph"]
|
||||
return ok
|
||||
}
|
||||
397
pkg/protocol/graph/query_test.go
Normal file
397
pkg/protocol/graph/query_test.go
Normal file
@@ -0,0 +1,397 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/filter"
|
||||
)
|
||||
|
||||
func TestQueryValidate(t *testing.T) {
|
||||
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query Query
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid follows query",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 2,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid followers query",
|
||||
query: Query{
|
||||
Method: "followers",
|
||||
Seed: validSeed,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid mentions query",
|
||||
query: Query{
|
||||
Method: "mentions",
|
||||
Seed: validSeed,
|
||||
Depth: 1,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid thread query",
|
||||
query: Query{
|
||||
Method: "thread",
|
||||
Seed: validSeed,
|
||||
Depth: 10,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid query with inbound refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 2,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}, FromDepth: 1},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid query with multiple ref specs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}, FromDepth: 1},
|
||||
{Kinds: []int{6}, FromDepth: 1},
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}, FromDepth: 0},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "missing method",
|
||||
query: Query{Seed: validSeed},
|
||||
wantErr: ErrMissingMethod,
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
query: Query{
|
||||
Method: "invalid",
|
||||
Seed: validSeed,
|
||||
},
|
||||
wantErr: ErrInvalidMethod,
|
||||
},
|
||||
{
|
||||
name: "missing seed",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
},
|
||||
wantErr: ErrMissingSeed,
|
||||
},
|
||||
{
|
||||
name: "seed too short",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc123",
|
||||
},
|
||||
wantErr: ErrInvalidSeed,
|
||||
},
|
||||
{
|
||||
name: "seed with invalid characters",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg",
|
||||
},
|
||||
wantErr: ErrInvalidSeed,
|
||||
},
|
||||
{
|
||||
name: "depth too high",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 17,
|
||||
},
|
||||
wantErr: ErrDepthTooHigh,
|
||||
},
|
||||
{
|
||||
name: "empty ref spec kinds",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{}, FromDepth: 1},
|
||||
},
|
||||
},
|
||||
wantErr: ErrEmptyRefSpecKinds,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.query.Validate()
|
||||
if tt.wantErr == nil {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryDefaults(t *testing.T) {
|
||||
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
q := Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 0, // Should default to 1
|
||||
}
|
||||
|
||||
err := q.Validate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if q.Depth != 1 {
|
||||
t.Errorf("Depth = %d, want 1 (default)", q.Depth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKindsAtDepth(t *testing.T) {
|
||||
q := Query{
|
||||
Method: "follows",
|
||||
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
Depth: 3,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}, FromDepth: 0}, // From seed
|
||||
{Kinds: []int{6, 16}, FromDepth: 1}, // From depth 1
|
||||
{Kinds: []int{9735}, FromDepth: 2}, // From depth 2
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}, FromDepth: 1},
|
||||
},
|
||||
}
|
||||
|
||||
// Test inbound kinds at depth 0
|
||||
kinds0 := q.InboundKindsAtDepth(0)
|
||||
if !kinds0[7] || kinds0[6] || kinds0[9735] {
|
||||
t.Errorf("InboundKindsAtDepth(0) = %v, want only kind 7", kinds0)
|
||||
}
|
||||
|
||||
// Test inbound kinds at depth 1
|
||||
kinds1 := q.InboundKindsAtDepth(1)
|
||||
if !kinds1[7] || !kinds1[6] || !kinds1[16] || kinds1[9735] {
|
||||
t.Errorf("InboundKindsAtDepth(1) = %v, want kinds 7, 6, 16", kinds1)
|
||||
}
|
||||
|
||||
// Test inbound kinds at depth 2
|
||||
kinds2 := q.InboundKindsAtDepth(2)
|
||||
if !kinds2[7] || !kinds2[6] || !kinds2[16] || !kinds2[9735] {
|
||||
t.Errorf("InboundKindsAtDepth(2) = %v, want all kinds", kinds2)
|
||||
}
|
||||
|
||||
// Test outbound kinds at depth 0
|
||||
outKinds0 := q.OutboundKindsAtDepth(0)
|
||||
if len(outKinds0) != 0 {
|
||||
t.Errorf("OutboundKindsAtDepth(0) = %v, want empty", outKinds0)
|
||||
}
|
||||
|
||||
// Test outbound kinds at depth 1
|
||||
outKinds1 := q.OutboundKindsAtDepth(1)
|
||||
if !outKinds1[1] {
|
||||
t.Errorf("OutboundKindsAtDepth(1) = %v, want kind 1", outKinds1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFromFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filterJSON string
|
||||
wantQuery bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "filter with valid graph query",
|
||||
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":2}}`,
|
||||
wantQuery: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "filter without graph query",
|
||||
filterJSON: `{"kinds":[1,7]}`,
|
||||
wantQuery: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "filter with invalid graph query (missing method)",
|
||||
filterJSON: `{"kinds":[1],"_graph":{"seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}}`,
|
||||
wantQuery: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "filter with complex graph query",
|
||||
filterJSON: `{"kinds":[0],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":3,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`,
|
||||
wantQuery: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &filter.F{}
|
||||
_, err := f.Unmarshal([]byte(tt.filterJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal filter: %v", err)
|
||||
}
|
||||
|
||||
q, err := ExtractFromFilter(f)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantQuery && q == nil {
|
||||
t.Error("expected query, got nil")
|
||||
}
|
||||
if !tt.wantQuery && q != nil {
|
||||
t.Errorf("expected nil query, got %+v", q)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGraphQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filterJSON string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "filter with graph query",
|
||||
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"abc"}}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "filter without graph query",
|
||||
filterJSON: `{"kinds":[1,7]}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "filter with other extension",
|
||||
filterJSON: `{"kinds":[1],"_custom":"value"}`,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &filter.F{}
|
||||
_, err := f.Unmarshal([]byte(tt.filterJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal filter: %v", err)
|
||||
}
|
||||
|
||||
got := IsGraphQuery(f)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsGraphQuery() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryHasRefs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query Query
|
||||
hasInbound bool
|
||||
hasOutbound bool
|
||||
hasRefs bool
|
||||
}{
|
||||
{
|
||||
name: "no refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
},
|
||||
hasInbound: false,
|
||||
hasOutbound: false,
|
||||
hasRefs: false,
|
||||
},
|
||||
{
|
||||
name: "only inbound refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
},
|
||||
hasInbound: true,
|
||||
hasOutbound: false,
|
||||
hasRefs: true,
|
||||
},
|
||||
{
|
||||
name: "only outbound refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}},
|
||||
},
|
||||
},
|
||||
hasInbound: false,
|
||||
hasOutbound: true,
|
||||
hasRefs: true,
|
||||
},
|
||||
{
|
||||
name: "both refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}},
|
||||
},
|
||||
},
|
||||
hasInbound: true,
|
||||
hasOutbound: true,
|
||||
hasRefs: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.query.HasInboundRefs(); got != tt.hasInbound {
|
||||
t.Errorf("HasInboundRefs() = %v, want %v", got, tt.hasInbound)
|
||||
}
|
||||
if got := tt.query.HasOutboundRefs(); got != tt.hasOutbound {
|
||||
t.Errorf("HasOutboundRefs() = %v, want %v", got, tt.hasOutbound)
|
||||
}
|
||||
if got := tt.query.HasRefs(); got != tt.hasRefs {
|
||||
t.Errorf("HasRefs() = %v, want %v", got, tt.hasRefs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
282
pkg/protocol/graph/ratelimit.go
Normal file
282
pkg/protocol/graph/ratelimit.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter implements a token bucket rate limiter with adaptive throttling
|
||||
// based on graph query complexity. It allows cooperative scheduling by inserting
|
||||
// pauses between operations to allow other work to proceed.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
|
||||
// Token bucket parameters
|
||||
tokens float64 // Current available tokens
|
||||
maxTokens float64 // Maximum token capacity
|
||||
refillRate float64 // Tokens per second to add
|
||||
lastRefill time.Time // Last time tokens were refilled
|
||||
|
||||
// Throttling parameters
|
||||
baseDelay time.Duration // Minimum delay between operations
|
||||
maxDelay time.Duration // Maximum delay for complex queries
|
||||
depthFactor float64 // Multiplier per depth level
|
||||
limitFactor float64 // Multiplier based on result limit
|
||||
}
|
||||
|
||||
// RateLimiterConfig configures the rate limiter behavior.
|
||||
type RateLimiterConfig struct {
|
||||
// MaxTokens is the maximum number of tokens in the bucket (default: 100)
|
||||
MaxTokens float64
|
||||
|
||||
// RefillRate is tokens added per second (default: 10)
|
||||
RefillRate float64
|
||||
|
||||
// BaseDelay is the minimum delay between operations (default: 1ms)
|
||||
BaseDelay time.Duration
|
||||
|
||||
// MaxDelay is the maximum delay for complex queries (default: 100ms)
|
||||
MaxDelay time.Duration
|
||||
|
||||
// DepthFactor is the cost multiplier per depth level (default: 2.0)
|
||||
// A depth-3 query costs 2^3 = 8x more tokens than depth-1
|
||||
DepthFactor float64
|
||||
|
||||
// LimitFactor is additional cost per 100 results requested (default: 0.1)
|
||||
LimitFactor float64
|
||||
}
|
||||
|
||||
// DefaultRateLimiterConfig returns sensible defaults for the rate limiter.
|
||||
func DefaultRateLimiterConfig() RateLimiterConfig {
|
||||
return RateLimiterConfig{
|
||||
MaxTokens: 100.0,
|
||||
RefillRate: 10.0, // Refills fully in 10 seconds
|
||||
BaseDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
DepthFactor: 2.0,
|
||||
LimitFactor: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with the given configuration.
|
||||
func NewRateLimiter(cfg RateLimiterConfig) *RateLimiter {
|
||||
if cfg.MaxTokens <= 0 {
|
||||
cfg.MaxTokens = DefaultRateLimiterConfig().MaxTokens
|
||||
}
|
||||
if cfg.RefillRate <= 0 {
|
||||
cfg.RefillRate = DefaultRateLimiterConfig().RefillRate
|
||||
}
|
||||
if cfg.BaseDelay <= 0 {
|
||||
cfg.BaseDelay = DefaultRateLimiterConfig().BaseDelay
|
||||
}
|
||||
if cfg.MaxDelay <= 0 {
|
||||
cfg.MaxDelay = DefaultRateLimiterConfig().MaxDelay
|
||||
}
|
||||
if cfg.DepthFactor <= 0 {
|
||||
cfg.DepthFactor = DefaultRateLimiterConfig().DepthFactor
|
||||
}
|
||||
if cfg.LimitFactor <= 0 {
|
||||
cfg.LimitFactor = DefaultRateLimiterConfig().LimitFactor
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
tokens: cfg.MaxTokens,
|
||||
maxTokens: cfg.MaxTokens,
|
||||
refillRate: cfg.RefillRate,
|
||||
lastRefill: time.Now(),
|
||||
baseDelay: cfg.BaseDelay,
|
||||
maxDelay: cfg.MaxDelay,
|
||||
depthFactor: cfg.DepthFactor,
|
||||
limitFactor: cfg.LimitFactor,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryCost calculates the token cost for a graph query based on its complexity.
|
||||
// Higher depths and larger limits cost exponentially more tokens.
|
||||
func (rl *RateLimiter) QueryCost(q *Query) float64 {
|
||||
if q == nil {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
// Base cost is exponential in depth: depthFactor^depth
|
||||
// This models the exponential growth of traversal work
|
||||
cost := 1.0
|
||||
for i := 0; i < q.Depth; i++ {
|
||||
cost *= rl.depthFactor
|
||||
}
|
||||
|
||||
// Add cost for reference collection (adds ~50% per ref spec)
|
||||
refCost := float64(len(q.InboundRefs)+len(q.OutboundRefs)) * 0.5
|
||||
cost += refCost
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
// OperationCost calculates the token cost for a single traversal operation.
|
||||
// This is used during query execution for per-operation throttling.
|
||||
func (rl *RateLimiter) OperationCost(depth int, nodesAtDepth int) float64 {
|
||||
// Cost increases with depth and number of nodes to process
|
||||
depthMultiplier := 1.0
|
||||
for i := 0; i < depth; i++ {
|
||||
depthMultiplier *= rl.depthFactor
|
||||
}
|
||||
|
||||
// More nodes at this depth = more work
|
||||
nodeFactor := 1.0 + float64(nodesAtDepth)*0.01
|
||||
|
||||
return depthMultiplier * nodeFactor
|
||||
}
|
||||
|
||||
// refillTokens adds tokens based on elapsed time since last refill.
|
||||
func (rl *RateLimiter) refillTokens() {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastRefill).Seconds()
|
||||
rl.lastRefill = now
|
||||
|
||||
rl.tokens += elapsed * rl.refillRate
|
||||
if rl.tokens > rl.maxTokens {
|
||||
rl.tokens = rl.maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire tries to acquire tokens for a query. If not enough tokens are available,
|
||||
// it waits until they become available or the context is cancelled.
|
||||
// Returns the delay that was applied, or an error if context was cancelled.
|
||||
func (rl *RateLimiter) Acquire(ctx context.Context, cost float64) (time.Duration, error) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.refillTokens()
|
||||
|
||||
var totalDelay time.Duration
|
||||
|
||||
// Wait until we have enough tokens
|
||||
for rl.tokens < cost {
|
||||
// Calculate how long we need to wait for tokens to refill
|
||||
tokensNeeded := cost - rl.tokens
|
||||
waitTime := time.Duration(tokensNeeded/rl.refillRate*1000) * time.Millisecond
|
||||
|
||||
// Clamp to max delay
|
||||
if waitTime > rl.maxDelay {
|
||||
waitTime = rl.maxDelay
|
||||
}
|
||||
if waitTime < rl.baseDelay {
|
||||
waitTime = rl.baseDelay
|
||||
}
|
||||
|
||||
// Release lock while waiting
|
||||
rl.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rl.mu.Lock()
|
||||
return totalDelay, ctx.Err()
|
||||
case <-time.After(waitTime):
|
||||
}
|
||||
|
||||
totalDelay += waitTime
|
||||
rl.mu.Lock()
|
||||
rl.refillTokens()
|
||||
}
|
||||
|
||||
// Consume tokens
|
||||
rl.tokens -= cost
|
||||
return totalDelay, nil
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire tokens without waiting.
|
||||
// Returns true if successful, false if insufficient tokens.
|
||||
func (rl *RateLimiter) TryAcquire(cost float64) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.refillTokens()
|
||||
|
||||
if rl.tokens >= cost {
|
||||
rl.tokens -= cost
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Pause inserts a cooperative delay to allow other work to proceed.
|
||||
// The delay is proportional to the current depth and load.
|
||||
// This should be called periodically during long-running traversals.
|
||||
func (rl *RateLimiter) Pause(ctx context.Context, depth int, itemsProcessed int) error {
|
||||
// Calculate adaptive delay based on depth and progress
|
||||
// Deeper traversals and more processed items = longer pauses
|
||||
delay := rl.baseDelay
|
||||
|
||||
// Increase delay with depth
|
||||
for i := 0; i < depth; i++ {
|
||||
delay += rl.baseDelay
|
||||
}
|
||||
|
||||
// Add extra delay every N items to allow other work
|
||||
if itemsProcessed > 0 && itemsProcessed%100 == 0 {
|
||||
delay += rl.baseDelay * 5
|
||||
}
|
||||
|
||||
// Cap at max delay
|
||||
if delay > rl.maxDelay {
|
||||
delay = rl.maxDelay
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AvailableTokens returns the current number of available tokens.
|
||||
func (rl *RateLimiter) AvailableTokens() float64 {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
rl.refillTokens()
|
||||
return rl.tokens
|
||||
}
|
||||
|
||||
// Throttler provides a simple interface for cooperative scheduling during traversal.
|
||||
// It wraps the rate limiter and provides depth-aware throttling.
|
||||
type Throttler struct {
|
||||
rl *RateLimiter
|
||||
depth int
|
||||
itemsProcessed int
|
||||
}
|
||||
|
||||
// NewThrottler creates a throttler for a specific traversal operation.
|
||||
func NewThrottler(rl *RateLimiter, depth int) *Throttler {
|
||||
return &Throttler{
|
||||
rl: rl,
|
||||
depth: depth,
|
||||
}
|
||||
}
|
||||
|
||||
// Tick should be called after processing each item.
|
||||
// It tracks progress and inserts pauses as needed.
|
||||
func (t *Throttler) Tick(ctx context.Context) error {
|
||||
t.itemsProcessed++
|
||||
|
||||
// Insert cooperative pause periodically
|
||||
// More frequent pauses at higher depths
|
||||
interval := 50
|
||||
if t.depth >= 2 {
|
||||
interval = 25
|
||||
}
|
||||
if t.depth >= 4 {
|
||||
interval = 10
|
||||
}
|
||||
|
||||
if t.itemsProcessed%interval == 0 {
|
||||
return t.rl.Pause(ctx, t.depth, t.itemsProcessed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Complete marks the throttler as complete and returns stats.
|
||||
func (t *Throttler) Complete() (itemsProcessed int) {
|
||||
return t.itemsProcessed
|
||||
}
|
||||
267
pkg/protocol/graph/ratelimit_test.go
Normal file
267
pkg/protocol/graph/ratelimit_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiterQueryCost(t *testing.T) {
|
||||
rl := NewRateLimiter(DefaultRateLimiterConfig())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query *Query
|
||||
minCost float64
|
||||
maxCost float64
|
||||
}{
|
||||
{
|
||||
name: "nil query",
|
||||
query: nil,
|
||||
minCost: 1.0,
|
||||
maxCost: 1.0,
|
||||
},
|
||||
{
|
||||
name: "depth 1 no refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 1,
|
||||
},
|
||||
minCost: 1.5, // depthFactor^1 = 2
|
||||
maxCost: 2.5,
|
||||
},
|
||||
{
|
||||
name: "depth 2 no refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 2,
|
||||
},
|
||||
minCost: 3.5, // depthFactor^2 = 4
|
||||
maxCost: 4.5,
|
||||
},
|
||||
{
|
||||
name: "depth 3 no refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 3,
|
||||
},
|
||||
minCost: 7.5, // depthFactor^3 = 8
|
||||
maxCost: 8.5,
|
||||
},
|
||||
{
|
||||
name: "depth 2 with inbound refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 2,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
},
|
||||
minCost: 4.0, // 4 + 0.5 = 4.5
|
||||
maxCost: 5.0,
|
||||
},
|
||||
{
|
||||
name: "depth 2 with both refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 2,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}},
|
||||
},
|
||||
},
|
||||
minCost: 4.5, // 4 + 0.5 + 0.5 = 5
|
||||
maxCost: 5.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cost := rl.QueryCost(tt.query)
|
||||
if cost < tt.minCost || cost > tt.maxCost {
|
||||
t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterOperationCost(t *testing.T) {
|
||||
rl := NewRateLimiter(DefaultRateLimiterConfig())
|
||||
|
||||
// Depth 0, 1 node
|
||||
cost0 := rl.OperationCost(0, 1)
|
||||
if cost0 < 1.0 || cost0 > 1.1 {
|
||||
t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0)
|
||||
}
|
||||
|
||||
// Depth 1, 1 node
|
||||
cost1 := rl.OperationCost(1, 1)
|
||||
if cost1 < 2.0 || cost1 > 2.1 {
|
||||
t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1)
|
||||
}
|
||||
|
||||
// Depth 2, 100 nodes
|
||||
cost2 := rl.OperationCost(2, 100)
|
||||
if cost2 < 8.0 {
|
||||
t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterAcquire(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 10
|
||||
cfg.RefillRate = 100 // Fast refill for testing
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Should acquire immediately when tokens available
|
||||
delay, err := rl.Acquire(ctx, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if delay > time.Millisecond*10 {
|
||||
t.Errorf("expected minimal delay, got %v", delay)
|
||||
}
|
||||
|
||||
// Check remaining tokens
|
||||
remaining := rl.AvailableTokens()
|
||||
if remaining > 6 {
|
||||
t.Errorf("expected ~5 tokens remaining, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterTryAcquire(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 10
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Should succeed with enough tokens
|
||||
if !rl.TryAcquire(5) {
|
||||
t.Error("TryAcquire(5) should succeed with 10 tokens")
|
||||
}
|
||||
|
||||
// Should succeed again
|
||||
if !rl.TryAcquire(5) {
|
||||
t.Error("TryAcquire(5) should succeed with 5 tokens")
|
||||
}
|
||||
|
||||
// Should fail with insufficient tokens
|
||||
if rl.TryAcquire(1) {
|
||||
t.Error("TryAcquire(1) should fail with 0 tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterContextCancellation(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 1
|
||||
cfg.RefillRate = 0.1 // Very slow refill
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Drain tokens
|
||||
rl.TryAcquire(1)
|
||||
|
||||
// Create cancellable context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Try to acquire - should be cancelled
|
||||
_, err := rl.Acquire(ctx, 10)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterRefill(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 10
|
||||
cfg.RefillRate = 1000 // 1000 tokens per second
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Drain tokens
|
||||
rl.TryAcquire(10)
|
||||
|
||||
// Wait for refill
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should have some tokens now
|
||||
available := rl.AvailableTokens()
|
||||
if available < 5 {
|
||||
t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available)
|
||||
}
|
||||
if available > 10 {
|
||||
t.Errorf("expected <= 10 tokens (max), got %v", available)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterPause(t *testing.T) {
|
||||
rl := NewRateLimiter(DefaultRateLimiterConfig())
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := rl.Pause(ctx, 1, 0)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have paused for at least baseDelay
|
||||
if elapsed < rl.baseDelay {
|
||||
t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThrottler(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.BaseDelay = 100 * time.Microsecond // Short for testing
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
throttler := NewThrottler(rl, 1)
|
||||
ctx := context.Background()
|
||||
|
||||
// Process items
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := throttler.Tick(ctx); err != nil {
|
||||
t.Fatalf("unexpected error at tick %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
processed := throttler.Complete()
|
||||
if processed != 100 {
|
||||
t.Errorf("expected 100 items processed, got %d", processed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThrottlerContextCancellation(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Process some items
|
||||
for i := 0; i < 20; i++ {
|
||||
throttler.Tick(ctx)
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Next tick that would pause should return error
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := throttler.Tick(ctx); err != nil {
|
||||
// Expected - context was cancelled
|
||||
return
|
||||
}
|
||||
}
|
||||
// If we get here without error, the throttler didn't check context
|
||||
// This is acceptable if no pause was needed
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
v0.33.1
|
||||
v0.34.0
|
||||
Reference in New Issue
Block a user