add first draft graph query implementation

2025-12-04 09:28:13 +00:00
parent 8dbc19ee9e
commit 6b98c23606
40 changed files with 9078 additions and 46 deletions


@@ -0,0 +1,202 @@
//go:build !(js && wasm)

// Package graph implements NIP-XX Graph Query protocol support.
// This file contains the executor that runs graph traversal queries.
package graph
import (
"encoding/json"
"strconv"
"time"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"git.mleku.dev/mleku/nostr/encoders/event"
"git.mleku.dev/mleku/nostr/encoders/hex"
"git.mleku.dev/mleku/nostr/encoders/tag"
"git.mleku.dev/mleku/nostr/interfaces/signer"
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k"
)
// Response kinds for graph queries (relay-signed). Note: 39000-39002 fall in
// NIP-01's addressable range (30000-39999), not the ephemeral range (20000-29999).
const (
KindGraphFollows = 39000 // Response for follows/followers queries
KindGraphMentions = 39001 // Response for mentions queries
KindGraphThread = 39002 // Response for thread traversal queries
)
// GraphResultI is the interface that database.GraphResult implements.
// It lets the executor consume database results without importing the database package.
type GraphResultI interface {
ToDepthArrays() [][]string
ToEventDepthArrays() [][]string
GetAllPubkeys() []string
GetAllEvents() []string
GetPubkeysByDepth() map[int][]string
GetEventsByDepth() map[int][]string
GetTotalPubkeys() int
GetTotalEvents() int
}
// GraphDatabase defines the interface for graph traversal operations.
// This is implemented by the database package.
type GraphDatabase interface {
// TraverseFollows performs BFS traversal of follow graph
TraverseFollows(seedPubkey []byte, maxDepth int) (GraphResultI, error)
// TraverseFollowers performs BFS traversal to find followers
TraverseFollowers(seedPubkey []byte, maxDepth int) (GraphResultI, error)
// FindMentions finds events mentioning a pubkey
FindMentions(pubkey []byte, kinds []uint16) (GraphResultI, error)
// TraverseThread performs BFS traversal of thread structure
TraverseThread(seedEventID []byte, maxDepth int, direction string) (GraphResultI, error)
}
// Executor handles graph query execution and response generation.
type Executor struct {
db GraphDatabase
relaySigner signer.I
relayPubkey []byte
}
// NewExecutor creates a new graph query executor.
// The secretKey should be the 32-byte relay identity secret key.
func NewExecutor(db GraphDatabase, secretKey []byte) (*Executor, error) {
s, err := p8k.New()
if err != nil {
return nil, err
}
if err = s.InitSec(secretKey); err != nil {
return nil, err
}
return &Executor{
db: db,
relaySigner: s,
relayPubkey: s.Pub(),
}, nil
}
// Execute runs a graph query and returns a relay-signed event with results.
func (e *Executor) Execute(q *Query) (*event.E, error) {
var result GraphResultI
var err error
var responseKind uint16
// Decode seed (hex string to bytes)
seedBytes, err := hex.Dec(q.Seed)
if err != nil {
return nil, err
}
// Execute the appropriate traversal
switch q.Method {
case "follows":
responseKind = KindGraphFollows
result, err = e.db.TraverseFollows(seedBytes, q.Depth)
if err != nil {
return nil, err
}
log.D.F("graph executor: follows traversal returned %d pubkeys", result.GetTotalPubkeys())
case "followers":
responseKind = KindGraphFollows
result, err = e.db.TraverseFollowers(seedBytes, q.Depth)
if err != nil {
return nil, err
}
log.D.F("graph executor: followers traversal returned %d pubkeys", result.GetTotalPubkeys())
case "mentions":
responseKind = KindGraphMentions
// Mentions don't use depth traversal; we just find direct mentions.
// Convert RefSpec kinds to uint16 for the database call
var kinds []uint16
if len(q.InboundRefs) > 0 {
for _, rs := range q.InboundRefs {
for _, k := range rs.Kinds {
kinds = append(kinds, uint16(k))
}
}
} else {
kinds = []uint16{1} // Default to kind 1 (notes)
}
result, err = e.db.FindMentions(seedBytes, kinds)
if err != nil {
return nil, err
}
log.D.F("graph executor: mentions query returned %d events", result.GetTotalEvents())
case "thread":
responseKind = KindGraphThread
result, err = e.db.TraverseThread(seedBytes, q.Depth, "both")
if err != nil {
return nil, err
}
log.D.F("graph executor: thread traversal returned %d events", result.GetTotalEvents())
default:
return nil, ErrInvalidMethod
}
// Generate response event
return e.generateResponse(q, result, responseKind)
}
// generateResponse creates a relay-signed event containing the query results.
func (e *Executor) generateResponse(q *Query, result GraphResultI, responseKind uint16) (*event.E, error) {
// Build content as JSON with depth arrays
var content ResponseContent
if q.Method == "follows" || q.Method == "followers" {
// For pubkey-based queries, use pubkeys_by_depth
content.PubkeysByDepth = result.ToDepthArrays()
content.TotalPubkeys = result.GetTotalPubkeys()
} else {
// For event-based queries, use events_by_depth
content.EventsByDepth = result.ToEventDepthArrays()
content.TotalEvents = result.GetTotalEvents()
}
contentBytes, err := json.Marshal(content)
if err != nil {
return nil, err
}
// Build tags
tags := tag.NewS(
tag.NewFromAny("method", q.Method),
tag.NewFromAny("seed", q.Seed),
tag.NewFromAny("depth", strconv.Itoa(q.Depth)),
)
// Create event
ev := &event.E{
Kind: responseKind,
CreatedAt: time.Now().Unix(),
Tags: tags,
Content: contentBytes,
}
// Sign with relay identity
if err = ev.Sign(e.relaySigner); chk.E(err) {
return nil, err
}
return ev, nil
}
// ResponseContent is the JSON structure for graph query responses.
type ResponseContent struct {
// PubkeysByDepth contains arrays of pubkeys at each depth (1-indexed)
// Each pubkey appears ONLY at the depth where it was first discovered.
PubkeysByDepth [][]string `json:"pubkeys_by_depth,omitempty"`
// EventsByDepth contains arrays of event IDs at each depth (1-indexed)
EventsByDepth [][]string `json:"events_by_depth,omitempty"`
// TotalPubkeys is the total count of unique pubkeys discovered
TotalPubkeys int `json:"total_pubkeys,omitempty"`
// TotalEvents is the total count of unique events discovered
TotalEvents int `json:"total_events,omitempty"`
}
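For context, a minimal usage sketch of the executor (assuming a value myDB that implements graph.GraphDatabase and a 32-byte relaySecretKey; both names are hypothetical, everything else is the API above):

	q := &graph.Query{
		Method: "follows",
		Seed:   "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
		Depth:  2,
	}
	if err := q.Validate(); err != nil {
		// reject the query
	}
	exec, err := graph.NewExecutor(myDB, relaySecretKey)
	if err != nil {
		// relay signer could not be initialized
	}
	ev, err := exec.Execute(q)
	// ev is a kind-39000 event signed with the relay identity; its Content is
	// JSON shaped like {"pubkeys_by_depth":[["..."],["..."]],"total_pubkeys":3}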

pkg/protocol/graph/query.go (new file)

@@ -0,0 +1,183 @@
// Package graph implements NIP-XX Graph Query protocol support.
// It provides types and functions for parsing and validating graph traversal queries.
package graph
import (
"encoding/json"
"errors"
"git.mleku.dev/mleku/nostr/encoders/filter"
)
// Query represents a graph traversal query from a _graph filter extension.
type Query struct {
// Method is the traversal method: "follows", "followers", "mentions", "thread"
Method string `json:"method"`
// Seed is the starting point for traversal (pubkey hex or event ID hex)
Seed string `json:"seed"`
// Depth is the maximum traversal depth (1-16, default: 1)
Depth int `json:"depth,omitempty"`
// InboundRefs specifies which inbound references to collect
// (events that reference discovered events via e-tags)
InboundRefs []RefSpec `json:"inbound_refs,omitempty"`
// OutboundRefs specifies which outbound references to collect
// (events referenced by discovered events via e-tags)
OutboundRefs []RefSpec `json:"outbound_refs,omitempty"`
}
// RefSpec specifies which event references to include in results.
type RefSpec struct {
// Kinds is the list of event kinds to match (OR semantics within this spec)
Kinds []int `json:"kinds"`
// FromDepth specifies the minimum depth at which to collect refs (default: 0)
// 0 = include refs from seed itself
// 1 = start from first-hop connections
FromDepth int `json:"from_depth,omitempty"`
}
// Validation errors
var (
ErrMissingMethod = errors.New("_graph.method is required")
ErrInvalidMethod = errors.New("_graph.method must be one of: follows, followers, mentions, thread")
ErrMissingSeed = errors.New("_graph.seed is required")
ErrInvalidSeed = errors.New("_graph.seed must be a 64-character hex string")
ErrDepthTooHigh = errors.New("_graph.depth cannot exceed 16")
ErrEmptyRefSpecKinds = errors.New("ref spec kinds array cannot be empty")
)
// Valid method names
var validMethods = map[string]bool{
"follows": true,
"followers": true,
"mentions": true,
"thread": true,
}
// Validate checks the query for correctness and applies defaults.
func (q *Query) Validate() error {
// Method is required
if q.Method == "" {
return ErrMissingMethod
}
if !validMethods[q.Method] {
return ErrInvalidMethod
}
// Seed is required
if q.Seed == "" {
return ErrMissingSeed
}
if len(q.Seed) != 64 {
return ErrInvalidSeed
}
// Validate hex characters
for _, c := range q.Seed {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
return ErrInvalidSeed
}
}
// Apply depth defaults and limits
if q.Depth < 1 {
q.Depth = 1
}
if q.Depth > 16 {
return ErrDepthTooHigh
}
// Validate ref specs
for _, rs := range q.InboundRefs {
if len(rs.Kinds) == 0 {
return ErrEmptyRefSpecKinds
}
}
for _, rs := range q.OutboundRefs {
if len(rs.Kinds) == 0 {
return ErrEmptyRefSpecKinds
}
}
return nil
}
// HasInboundRefs returns true if the query includes inbound reference collection.
func (q *Query) HasInboundRefs() bool {
return len(q.InboundRefs) > 0
}
// HasOutboundRefs returns true if the query includes outbound reference collection.
func (q *Query) HasOutboundRefs() bool {
return len(q.OutboundRefs) > 0
}
// HasRefs returns true if the query includes any reference collection.
func (q *Query) HasRefs() bool {
return q.HasInboundRefs() || q.HasOutboundRefs()
}
// InboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
// It aggregates all RefSpecs where from_depth <= depth.
func (q *Query) InboundKindsAtDepth(depth int) map[int]bool {
kinds := make(map[int]bool)
for _, rs := range q.InboundRefs {
if rs.FromDepth <= depth {
for _, k := range rs.Kinds {
kinds[k] = true
}
}
}
return kinds
}
// OutboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
func (q *Query) OutboundKindsAtDepth(depth int) map[int]bool {
kinds := make(map[int]bool)
for _, rs := range q.OutboundRefs {
if rs.FromDepth <= depth {
for _, k := range rs.Kinds {
kinds[k] = true
}
}
}
return kinds
}
// ExtractFromFilter checks if a filter has a _graph extension and parses it.
// Returns nil if no _graph field is present.
// Returns an error if _graph is present but invalid.
func ExtractFromFilter(f *filter.F) (*Query, error) {
if f == nil || f.Extra == nil {
return nil, nil
}
raw, ok := f.Extra["_graph"]
if !ok {
return nil, nil
}
var q Query
if err := json.Unmarshal(raw, &q); err != nil {
return nil, err
}
if err := q.Validate(); err != nil {
return nil, err
}
return &q, nil
}
// IsGraphQuery returns true if the filter contains a _graph extension.
// This is a quick check that doesn't parse the full query.
func IsGraphQuery(f *filter.F) bool {
if f == nil || f.Extra == nil {
return false
}
_, ok := f.Extra["_graph"]
return ok
}
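A short sketch of how a relay-side handler might detect and parse the extension from an incoming REQ filter (the JSON literal is illustrative; the calls mirror the tests below):

	raw := []byte(`{"kinds":[1],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":2}}`)
	f := &filter.F{}
	if _, err := f.Unmarshal(raw); err != nil {
		// malformed filter JSON
	}
	if graph.IsGraphQuery(f) {
		q, err := graph.ExtractFromFilter(f)
		if err != nil {
			// _graph present but invalid: report the validation error to the client
		}
		_ = q // here q.Method == "follows" and q.Depth == 2
	}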


@@ -0,0 +1,397 @@
package graph
import (
"testing"
"git.mleku.dev/mleku/nostr/encoders/filter"
)
func TestQueryValidate(t *testing.T) {
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
tests := []struct {
name string
query Query
wantErr error
}{
{
name: "valid follows query",
query: Query{
Method: "follows",
Seed: validSeed,
Depth: 2,
},
wantErr: nil,
},
{
name: "valid followers query",
query: Query{
Method: "followers",
Seed: validSeed,
},
wantErr: nil,
},
{
name: "valid mentions query",
query: Query{
Method: "mentions",
Seed: validSeed,
Depth: 1,
},
wantErr: nil,
},
{
name: "valid thread query",
query: Query{
Method: "thread",
Seed: validSeed,
Depth: 10,
},
wantErr: nil,
},
{
name: "valid query with inbound refs",
query: Query{
Method: "follows",
Seed: validSeed,
Depth: 2,
InboundRefs: []RefSpec{
{Kinds: []int{7}, FromDepth: 1},
},
},
wantErr: nil,
},
{
name: "valid query with multiple ref specs",
query: Query{
Method: "follows",
Seed: validSeed,
InboundRefs: []RefSpec{
{Kinds: []int{7}, FromDepth: 1},
{Kinds: []int{6}, FromDepth: 1},
},
OutboundRefs: []RefSpec{
{Kinds: []int{1}, FromDepth: 0},
},
},
wantErr: nil,
},
{
name: "missing method",
query: Query{Seed: validSeed},
wantErr: ErrMissingMethod,
},
{
name: "invalid method",
query: Query{
Method: "invalid",
Seed: validSeed,
},
wantErr: ErrInvalidMethod,
},
{
name: "missing seed",
query: Query{
Method: "follows",
},
wantErr: ErrMissingSeed,
},
{
name: "seed too short",
query: Query{
Method: "follows",
Seed: "abc123",
},
wantErr: ErrInvalidSeed,
},
{
name: "seed with invalid characters",
query: Query{
Method: "follows",
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg",
},
wantErr: ErrInvalidSeed,
},
{
name: "depth too high",
query: Query{
Method: "follows",
Seed: validSeed,
Depth: 17,
},
wantErr: ErrDepthTooHigh,
},
{
name: "empty ref spec kinds",
query: Query{
Method: "follows",
Seed: validSeed,
InboundRefs: []RefSpec{
{Kinds: []int{}, FromDepth: 1},
},
},
wantErr: ErrEmptyRefSpecKinds,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.query.Validate()
if tt.wantErr == nil {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err != tt.wantErr {
t.Errorf("error = %v, want %v", err, tt.wantErr)
}
}
})
}
}
func TestQueryDefaults(t *testing.T) {
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
q := Query{
Method: "follows",
Seed: validSeed,
Depth: 0, // Should default to 1
}
err := q.Validate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if q.Depth != 1 {
t.Errorf("Depth = %d, want 1 (default)", q.Depth)
}
}
func TestKindsAtDepth(t *testing.T) {
q := Query{
Method: "follows",
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
Depth: 3,
InboundRefs: []RefSpec{
{Kinds: []int{7}, FromDepth: 0}, // From seed
{Kinds: []int{6, 16}, FromDepth: 1}, // From depth 1
{Kinds: []int{9735}, FromDepth: 2}, // From depth 2
},
OutboundRefs: []RefSpec{
{Kinds: []int{1}, FromDepth: 1},
},
}
// Test inbound kinds at depth 0
kinds0 := q.InboundKindsAtDepth(0)
if !kinds0[7] || kinds0[6] || kinds0[9735] {
t.Errorf("InboundKindsAtDepth(0) = %v, want only kind 7", kinds0)
}
// Test inbound kinds at depth 1
kinds1 := q.InboundKindsAtDepth(1)
if !kinds1[7] || !kinds1[6] || !kinds1[16] || kinds1[9735] {
t.Errorf("InboundKindsAtDepth(1) = %v, want kinds 7, 6, 16", kinds1)
}
// Test inbound kinds at depth 2
kinds2 := q.InboundKindsAtDepth(2)
if !kinds2[7] || !kinds2[6] || !kinds2[16] || !kinds2[9735] {
t.Errorf("InboundKindsAtDepth(2) = %v, want all kinds", kinds2)
}
// Test outbound kinds at depth 0
outKinds0 := q.OutboundKindsAtDepth(0)
if len(outKinds0) != 0 {
t.Errorf("OutboundKindsAtDepth(0) = %v, want empty", outKinds0)
}
// Test outbound kinds at depth 1
outKinds1 := q.OutboundKindsAtDepth(1)
if !outKinds1[1] {
t.Errorf("OutboundKindsAtDepth(1) = %v, want kind 1", outKinds1)
}
}
func TestExtractFromFilter(t *testing.T) {
tests := []struct {
name string
filterJSON string
wantQuery bool
wantErr bool
}{
{
name: "filter with valid graph query",
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":2}}`,
wantQuery: true,
wantErr: false,
},
{
name: "filter without graph query",
filterJSON: `{"kinds":[1,7]}`,
wantQuery: false,
wantErr: false,
},
{
name: "filter with invalid graph query (missing method)",
filterJSON: `{"kinds":[1],"_graph":{"seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}}`,
wantQuery: false,
wantErr: true,
},
{
name: "filter with complex graph query",
filterJSON: `{"kinds":[0],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":3,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`,
wantQuery: true,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &filter.F{}
_, err := f.Unmarshal([]byte(tt.filterJSON))
if err != nil {
t.Fatalf("failed to unmarshal filter: %v", err)
}
q, err := ExtractFromFilter(f)
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if tt.wantQuery && q == nil {
t.Error("expected query, got nil")
}
if !tt.wantQuery && q != nil {
t.Errorf("expected nil query, got %+v", q)
}
})
}
}
func TestIsGraphQuery(t *testing.T) {
tests := []struct {
name string
filterJSON string
want bool
}{
{
name: "filter with graph query",
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"abc"}}`,
want: true,
},
{
name: "filter without graph query",
filterJSON: `{"kinds":[1,7]}`,
want: false,
},
{
name: "filter with other extension",
filterJSON: `{"kinds":[1],"_custom":"value"}`,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &filter.F{}
_, err := f.Unmarshal([]byte(tt.filterJSON))
if err != nil {
t.Fatalf("failed to unmarshal filter: %v", err)
}
got := IsGraphQuery(f)
if got != tt.want {
t.Errorf("IsGraphQuery() = %v, want %v", got, tt.want)
}
})
}
}
func TestQueryHasRefs(t *testing.T) {
tests := []struct {
name string
query Query
hasInbound bool
hasOutbound bool
hasRefs bool
}{
{
name: "no refs",
query: Query{
Method: "follows",
Seed: "abc",
},
hasInbound: false,
hasOutbound: false,
hasRefs: false,
},
{
name: "only inbound refs",
query: Query{
Method: "follows",
Seed: "abc",
InboundRefs: []RefSpec{
{Kinds: []int{7}},
},
},
hasInbound: true,
hasOutbound: false,
hasRefs: true,
},
{
name: "only outbound refs",
query: Query{
Method: "follows",
Seed: "abc",
OutboundRefs: []RefSpec{
{Kinds: []int{1}},
},
},
hasInbound: false,
hasOutbound: true,
hasRefs: true,
},
{
name: "both refs",
query: Query{
Method: "follows",
Seed: "abc",
InboundRefs: []RefSpec{
{Kinds: []int{7}},
},
OutboundRefs: []RefSpec{
{Kinds: []int{1}},
},
},
hasInbound: true,
hasOutbound: true,
hasRefs: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.query.HasInboundRefs(); got != tt.hasInbound {
t.Errorf("HasInboundRefs() = %v, want %v", got, tt.hasInbound)
}
if got := tt.query.HasOutboundRefs(); got != tt.hasOutbound {
t.Errorf("HasOutboundRefs() = %v, want %v", got, tt.hasOutbound)
}
if got := tt.query.HasRefs(); got != tt.hasRefs {
t.Errorf("HasRefs() = %v, want %v", got, tt.hasRefs)
}
})
}
}


@@ -0,0 +1,282 @@
package graph
import (
"context"
"sync"
"time"
)
// RateLimiter implements a token bucket rate limiter with adaptive throttling
// based on graph query complexity. It allows cooperative scheduling by inserting
// pauses between operations to allow other work to proceed.
type RateLimiter struct {
mu sync.Mutex
// Token bucket parameters
tokens float64 // Current available tokens
maxTokens float64 // Maximum token capacity
refillRate float64 // Tokens per second to add
lastRefill time.Time // Last time tokens were refilled
// Throttling parameters
baseDelay time.Duration // Minimum delay between operations
maxDelay time.Duration // Maximum delay for complex queries
depthFactor float64 // Multiplier per depth level
limitFactor float64 // Multiplier based on result limit
}
// RateLimiterConfig configures the rate limiter behavior.
type RateLimiterConfig struct {
// MaxTokens is the maximum number of tokens in the bucket (default: 100)
MaxTokens float64
// RefillRate is tokens added per second (default: 10)
RefillRate float64
// BaseDelay is the minimum delay between operations (default: 1ms)
BaseDelay time.Duration
// MaxDelay is the maximum delay for complex queries (default: 100ms)
MaxDelay time.Duration
// DepthFactor is the cost multiplier per depth level (default: 2.0)
// A depth-3 query costs 2.0^3 = 8 tokens, versus 2 tokens for depth-1
DepthFactor float64
// LimitFactor is additional cost per 100 results requested (default: 0.1)
LimitFactor float64
}
// DefaultRateLimiterConfig returns sensible defaults for the rate limiter.
func DefaultRateLimiterConfig() RateLimiterConfig {
return RateLimiterConfig{
MaxTokens: 100.0,
RefillRate: 10.0, // Refills fully in 10 seconds
BaseDelay: 1 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
DepthFactor: 2.0,
LimitFactor: 0.1,
}
}
// NewRateLimiter creates a new rate limiter with the given configuration.
func NewRateLimiter(cfg RateLimiterConfig) *RateLimiter {
if cfg.MaxTokens <= 0 {
cfg.MaxTokens = DefaultRateLimiterConfig().MaxTokens
}
if cfg.RefillRate <= 0 {
cfg.RefillRate = DefaultRateLimiterConfig().RefillRate
}
if cfg.BaseDelay <= 0 {
cfg.BaseDelay = DefaultRateLimiterConfig().BaseDelay
}
if cfg.MaxDelay <= 0 {
cfg.MaxDelay = DefaultRateLimiterConfig().MaxDelay
}
if cfg.DepthFactor <= 0 {
cfg.DepthFactor = DefaultRateLimiterConfig().DepthFactor
}
if cfg.LimitFactor <= 0 {
cfg.LimitFactor = DefaultRateLimiterConfig().LimitFactor
}
return &RateLimiter{
tokens: cfg.MaxTokens,
maxTokens: cfg.MaxTokens,
refillRate: cfg.RefillRate,
lastRefill: time.Now(),
baseDelay: cfg.BaseDelay,
maxDelay: cfg.MaxDelay,
depthFactor: cfg.DepthFactor,
limitFactor: cfg.LimitFactor,
}
}
// QueryCost calculates the token cost for a graph query based on its complexity.
// Higher depths and larger limits cost exponentially more tokens.
func (rl *RateLimiter) QueryCost(q *Query) float64 {
if q == nil {
return 1.0
}
// Base cost is exponential in depth: depthFactor^depth
// This models the exponential growth of traversal work
cost := 1.0
for i := 0; i < q.Depth; i++ {
cost *= rl.depthFactor
}
// Add cost for reference collection (0.5 tokens per ref spec)
refCost := float64(len(q.InboundRefs)+len(q.OutboundRefs)) * 0.5
cost += refCost
return cost
}
// OperationCost calculates the token cost for a single traversal operation.
// This is used during query execution for per-operation throttling.
func (rl *RateLimiter) OperationCost(depth int, nodesAtDepth int) float64 {
// Cost increases with depth and number of nodes to process
depthMultiplier := 1.0
for i := 0; i < depth; i++ {
depthMultiplier *= rl.depthFactor
}
// More nodes at this depth = more work
nodeFactor := 1.0 + float64(nodesAtDepth)*0.01
return depthMultiplier * nodeFactor
}
// refillTokens adds tokens based on elapsed time since last refill.
func (rl *RateLimiter) refillTokens() {
now := time.Now()
elapsed := now.Sub(rl.lastRefill).Seconds()
rl.lastRefill = now
rl.tokens += elapsed * rl.refillRate
if rl.tokens > rl.maxTokens {
rl.tokens = rl.maxTokens
}
}
// Acquire tries to acquire tokens for a query. If not enough tokens are available,
// it waits until they become available or the context is cancelled.
// Returns the delay that was applied, or an error if context was cancelled.
func (rl *RateLimiter) Acquire(ctx context.Context, cost float64) (time.Duration, error) {
rl.mu.Lock()
defer rl.mu.Unlock()
rl.refillTokens()
var totalDelay time.Duration
// Wait until we have enough tokens
for rl.tokens < cost {
// Calculate how long we need to wait for tokens to refill
tokensNeeded := cost - rl.tokens
waitTime := time.Duration(tokensNeeded/rl.refillRate*1000) * time.Millisecond
// Clamp to max delay
if waitTime > rl.maxDelay {
waitTime = rl.maxDelay
}
if waitTime < rl.baseDelay {
waitTime = rl.baseDelay
}
// Release lock while waiting
rl.mu.Unlock()
select {
case <-ctx.Done():
rl.mu.Lock()
return totalDelay, ctx.Err()
case <-time.After(waitTime):
}
totalDelay += waitTime
rl.mu.Lock()
rl.refillTokens()
}
// Consume tokens
rl.tokens -= cost
return totalDelay, nil
}
// TryAcquire attempts to acquire tokens without waiting.
// Returns true if successful, false if insufficient tokens.
func (rl *RateLimiter) TryAcquire(cost float64) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
rl.refillTokens()
if rl.tokens >= cost {
rl.tokens -= cost
return true
}
return false
}
// Pause inserts a cooperative delay to allow other work to proceed.
// The delay is proportional to the current depth and load.
// This should be called periodically during long-running traversals.
func (rl *RateLimiter) Pause(ctx context.Context, depth int, itemsProcessed int) error {
// Calculate adaptive delay based on depth and progress
// Deeper traversals and more processed items = longer pauses
delay := rl.baseDelay
// Increase delay with depth
for i := 0; i < depth; i++ {
delay += rl.baseDelay
}
// Add extra delay every 100 items to allow other work
if itemsProcessed > 0 && itemsProcessed%100 == 0 {
delay += rl.baseDelay * 5
}
// Cap at max delay
if delay > rl.maxDelay {
delay = rl.maxDelay
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
return nil
}
}
// AvailableTokens returns the current number of available tokens.
func (rl *RateLimiter) AvailableTokens() float64 {
rl.mu.Lock()
defer rl.mu.Unlock()
rl.refillTokens()
return rl.tokens
}
// Throttler provides a simple interface for cooperative scheduling during traversal.
// It wraps the rate limiter and provides depth-aware throttling.
type Throttler struct {
rl *RateLimiter
depth int
itemsProcessed int
}
// NewThrottler creates a throttler for a specific traversal operation.
func NewThrottler(rl *RateLimiter, depth int) *Throttler {
return &Throttler{
rl: rl,
depth: depth,
}
}
// Tick should be called after processing each item.
// It tracks progress and inserts pauses as needed.
func (t *Throttler) Tick(ctx context.Context) error {
t.itemsProcessed++
// Insert cooperative pause periodically
// More frequent pauses at higher depths
interval := 50
if t.depth >= 2 {
interval = 25
}
if t.depth >= 4 {
interval = 10
}
if t.itemsProcessed%interval == 0 {
return t.rl.Pause(ctx, t.depth, t.itemsProcessed)
}
return nil
}
// Complete marks the throttler as finished and returns the number of items processed.
func (t *Throttler) Complete() (itemsProcessed int) {
return t.itemsProcessed
}
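A sketch of how the limiter and throttler could be threaded through a traversal loop (ctx, q, frontier, and visit are hypothetical stand-ins; the limiter calls are the API above):

	rl := graph.NewRateLimiter(graph.DefaultRateLimiterConfig())
	// Charge the whole query up front; Acquire waits for refill or fails on cancellation.
	if _, err := rl.Acquire(ctx, rl.QueryCost(q)); err != nil {
		return err
	}
	for depth := 1; depth <= q.Depth; depth++ {
		th := graph.NewThrottler(rl, depth)
		for _, node := range frontier(depth) { // hypothetical BFS frontier at this depth
			visit(node) // hypothetical per-node work
			// Periodic cooperative pause; Tick pauses more often at depth >= 2.
			if err := th.Tick(ctx); err != nil {
				return err
			}
		}
	}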


@@ -0,0 +1,267 @@
package graph
import (
"context"
"testing"
"time"
)
func TestRateLimiterQueryCost(t *testing.T) {
rl := NewRateLimiter(DefaultRateLimiterConfig())
tests := []struct {
name string
query *Query
minCost float64
maxCost float64
}{
{
name: "nil query",
query: nil,
minCost: 1.0,
maxCost: 1.0,
},
{
name: "depth 1 no refs",
query: &Query{
Method: "follows",
Seed: "abc",
Depth: 1,
},
minCost: 1.5, // depthFactor^1 = 2
maxCost: 2.5,
},
{
name: "depth 2 no refs",
query: &Query{
Method: "follows",
Seed: "abc",
Depth: 2,
},
minCost: 3.5, // depthFactor^2 = 4
maxCost: 4.5,
},
{
name: "depth 3 no refs",
query: &Query{
Method: "follows",
Seed: "abc",
Depth: 3,
},
minCost: 7.5, // depthFactor^3 = 8
maxCost: 8.5,
},
{
name: "depth 2 with inbound refs",
query: &Query{
Method: "follows",
Seed: "abc",
Depth: 2,
InboundRefs: []RefSpec{
{Kinds: []int{7}},
},
},
minCost: 4.0, // 4 + 0.5 = 4.5
maxCost: 5.0,
},
{
name: "depth 2 with both refs",
query: &Query{
Method: "follows",
Seed: "abc",
Depth: 2,
InboundRefs: []RefSpec{
{Kinds: []int{7}},
},
OutboundRefs: []RefSpec{
{Kinds: []int{1}},
},
},
minCost: 4.5, // 4 + 0.5 + 0.5 = 5
maxCost: 5.5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cost := rl.QueryCost(tt.query)
if cost < tt.minCost || cost > tt.maxCost {
t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost)
}
})
}
}
func TestRateLimiterOperationCost(t *testing.T) {
rl := NewRateLimiter(DefaultRateLimiterConfig())
// Depth 0, 1 node
cost0 := rl.OperationCost(0, 1)
if cost0 < 1.0 || cost0 > 1.1 {
t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0)
}
// Depth 1, 1 node
cost1 := rl.OperationCost(1, 1)
if cost1 < 2.0 || cost1 > 2.1 {
t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1)
}
// Depth 2, 100 nodes
cost2 := rl.OperationCost(2, 100)
if cost2 < 8.0 {
t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2)
}
}
func TestRateLimiterAcquire(t *testing.T) {
cfg := DefaultRateLimiterConfig()
cfg.MaxTokens = 10
cfg.RefillRate = 100 // Fast refill for testing
rl := NewRateLimiter(cfg)
ctx := context.Background()
// Should acquire immediately when tokens available
delay, err := rl.Acquire(ctx, 5)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if delay > time.Millisecond*10 {
t.Errorf("expected minimal delay, got %v", delay)
}
// Check remaining tokens
remaining := rl.AvailableTokens()
if remaining > 6 {
t.Errorf("expected ~5 tokens remaining, got %v", remaining)
}
}
func TestRateLimiterTryAcquire(t *testing.T) {
cfg := DefaultRateLimiterConfig()
cfg.MaxTokens = 10
rl := NewRateLimiter(cfg)
// Should succeed with enough tokens
if !rl.TryAcquire(5) {
t.Error("TryAcquire(5) should succeed with 10 tokens")
}
// Should succeed again
if !rl.TryAcquire(5) {
t.Error("TryAcquire(5) should succeed with 5 tokens")
}
// Should fail with insufficient tokens
if rl.TryAcquire(1) {
t.Error("TryAcquire(1) should fail with 0 tokens")
}
}
func TestRateLimiterContextCancellation(t *testing.T) {
cfg := DefaultRateLimiterConfig()
cfg.MaxTokens = 1
cfg.RefillRate = 0.1 // Very slow refill
rl := NewRateLimiter(cfg)
// Drain tokens
rl.TryAcquire(1)
// Create cancellable context
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
// Try to acquire - should be cancelled
_, err := rl.Acquire(ctx, 10)
if err != context.DeadlineExceeded {
t.Errorf("expected DeadlineExceeded, got %v", err)
}
}
func TestRateLimiterRefill(t *testing.T) {
cfg := DefaultRateLimiterConfig()
cfg.MaxTokens = 10
cfg.RefillRate = 1000 // 1000 tokens per second
rl := NewRateLimiter(cfg)
// Drain tokens
rl.TryAcquire(10)
// Wait for refill
time.Sleep(15 * time.Millisecond)
// Should have some tokens now
available := rl.AvailableTokens()
if available < 5 {
t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available)
}
if available > 10 {
t.Errorf("expected <= 10 tokens (max), got %v", available)
}
}
func TestRateLimiterPause(t *testing.T) {
rl := NewRateLimiter(DefaultRateLimiterConfig())
ctx := context.Background()
start := time.Now()
err := rl.Pause(ctx, 1, 0)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Should have paused for at least baseDelay
if elapsed < rl.baseDelay {
t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay)
}
}
func TestThrottler(t *testing.T) {
cfg := DefaultRateLimiterConfig()
cfg.BaseDelay = 100 * time.Microsecond // Short for testing
rl := NewRateLimiter(cfg)
throttler := NewThrottler(rl, 1)
ctx := context.Background()
// Process items
for i := 0; i < 100; i++ {
if err := throttler.Tick(ctx); err != nil {
t.Fatalf("unexpected error at tick %d: %v", i, err)
}
}
processed := throttler.Complete()
if processed != 100 {
t.Errorf("expected 100 items processed, got %d", processed)
}
}
func TestThrottlerContextCancellation(t *testing.T) {
cfg := DefaultRateLimiterConfig()
rl := NewRateLimiter(cfg)
throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
ctx, cancel := context.WithCancel(context.Background())
// Process some items
for i := 0; i < 20; i++ {
throttler.Tick(ctx)
}
// Cancel context
cancel()
// Next tick that would pause should return error
for i := 0; i < 100; i++ {
if err := throttler.Tick(ctx); err != nil {
// Expected - context was cancelled
return
}
}
// If we get here without error, the throttler didn't check context
// This is acceptable if no pause was needed
}