add first draft graph query implementation
Some checks failed
Go / build-and-release (push) Has been cancelled
Some checks failed
Go / build-and-release (push) Has been cancelled
This commit is contained in:
202
pkg/protocol/graph/executor.go
Normal file
202
pkg/protocol/graph/executor.go
Normal file
@@ -0,0 +1,202 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
// Package graph implements NIP-XX Graph Query protocol support.
|
||||
// This file contains the executor that runs graph traversal queries.
|
||||
package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"lol.mleku.dev/chk"
|
||||
"lol.mleku.dev/log"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event"
|
||||
"git.mleku.dev/mleku/nostr/encoders/hex"
|
||||
"git.mleku.dev/mleku/nostr/encoders/tag"
|
||||
"git.mleku.dev/mleku/nostr/interfaces/signer"
|
||||
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k"
|
||||
)
|
||||
|
||||
// Response kinds for graph queries (ephemeral range, relay-signed)
|
||||
const (
|
||||
KindGraphFollows = 39000 // Response for follows/followers queries
|
||||
KindGraphMentions = 39001 // Response for mentions queries
|
||||
KindGraphThread = 39002 // Response for thread traversal queries
|
||||
)
|
||||
|
||||
// GraphResultI is the interface that database.GraphResult implements.
|
||||
// This allows the executor to work with the database result without importing it.
|
||||
type GraphResultI interface {
|
||||
ToDepthArrays() [][]string
|
||||
ToEventDepthArrays() [][]string
|
||||
GetAllPubkeys() []string
|
||||
GetAllEvents() []string
|
||||
GetPubkeysByDepth() map[int][]string
|
||||
GetEventsByDepth() map[int][]string
|
||||
GetTotalPubkeys() int
|
||||
GetTotalEvents() int
|
||||
}
|
||||
|
||||
// GraphDatabase defines the interface for graph traversal operations.
|
||||
// This is implemented by the database package.
|
||||
type GraphDatabase interface {
|
||||
// TraverseFollows performs BFS traversal of follow graph
|
||||
TraverseFollows(seedPubkey []byte, maxDepth int) (GraphResultI, error)
|
||||
// TraverseFollowers performs BFS traversal to find followers
|
||||
TraverseFollowers(seedPubkey []byte, maxDepth int) (GraphResultI, error)
|
||||
// FindMentions finds events mentioning a pubkey
|
||||
FindMentions(pubkey []byte, kinds []uint16) (GraphResultI, error)
|
||||
// TraverseThread performs BFS traversal of thread structure
|
||||
TraverseThread(seedEventID []byte, maxDepth int, direction string) (GraphResultI, error)
|
||||
}
|
||||
|
||||
// Executor handles graph query execution and response generation.
|
||||
type Executor struct {
|
||||
db GraphDatabase
|
||||
relaySigner signer.I
|
||||
relayPubkey []byte
|
||||
}
|
||||
|
||||
// NewExecutor creates a new graph query executor.
|
||||
// The secretKey should be the 32-byte relay identity secret key.
|
||||
func NewExecutor(db GraphDatabase, secretKey []byte) (*Executor, error) {
|
||||
s, err := p8k.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = s.InitSec(secretKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Executor{
|
||||
db: db,
|
||||
relaySigner: s,
|
||||
relayPubkey: s.Pub(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Execute runs a graph query and returns a relay-signed event with results.
|
||||
func (e *Executor) Execute(q *Query) (*event.E, error) {
|
||||
var result GraphResultI
|
||||
var err error
|
||||
var responseKind uint16
|
||||
|
||||
// Decode seed (hex string to bytes)
|
||||
seedBytes, err := hex.Dec(q.Seed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the appropriate traversal
|
||||
switch q.Method {
|
||||
case "follows":
|
||||
responseKind = KindGraphFollows
|
||||
result, err = e.db.TraverseFollows(seedBytes, q.Depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: follows traversal returned %d pubkeys", result.GetTotalPubkeys())
|
||||
|
||||
case "followers":
|
||||
responseKind = KindGraphFollows
|
||||
result, err = e.db.TraverseFollowers(seedBytes, q.Depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: followers traversal returned %d pubkeys", result.GetTotalPubkeys())
|
||||
|
||||
case "mentions":
|
||||
responseKind = KindGraphMentions
|
||||
// Mentions don't use depth traversal, just find direct mentions
|
||||
// Convert RefSpec kinds to uint16 for the database call
|
||||
var kinds []uint16
|
||||
if len(q.InboundRefs) > 0 {
|
||||
for _, rs := range q.InboundRefs {
|
||||
for _, k := range rs.Kinds {
|
||||
kinds = append(kinds, uint16(k))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
kinds = []uint16{1} // Default to kind 1 (notes)
|
||||
}
|
||||
result, err = e.db.FindMentions(seedBytes, kinds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: mentions query returned %d events", result.GetTotalEvents())
|
||||
|
||||
case "thread":
|
||||
responseKind = KindGraphThread
|
||||
result, err = e.db.TraverseThread(seedBytes, q.Depth, "both")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.D.F("graph executor: thread traversal returned %d events", result.GetTotalEvents())
|
||||
|
||||
default:
|
||||
return nil, ErrInvalidMethod
|
||||
}
|
||||
|
||||
// Generate response event
|
||||
return e.generateResponse(q, result, responseKind)
|
||||
}
|
||||
|
||||
// generateResponse creates a relay-signed event containing the query results.
|
||||
func (e *Executor) generateResponse(q *Query, result GraphResultI, responseKind uint16) (*event.E, error) {
|
||||
// Build content as JSON with depth arrays
|
||||
var content ResponseContent
|
||||
|
||||
if q.Method == "follows" || q.Method == "followers" {
|
||||
// For pubkey-based queries, use pubkeys_by_depth
|
||||
content.PubkeysByDepth = result.ToDepthArrays()
|
||||
content.TotalPubkeys = result.GetTotalPubkeys()
|
||||
} else {
|
||||
// For event-based queries, use events_by_depth
|
||||
content.EventsByDepth = result.ToEventDepthArrays()
|
||||
content.TotalEvents = result.GetTotalEvents()
|
||||
}
|
||||
|
||||
contentBytes, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build tags
|
||||
tags := tag.NewS(
|
||||
tag.NewFromAny("method", q.Method),
|
||||
tag.NewFromAny("seed", q.Seed),
|
||||
tag.NewFromAny("depth", strconv.Itoa(q.Depth)),
|
||||
)
|
||||
|
||||
// Create event
|
||||
ev := &event.E{
|
||||
Kind: responseKind,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
Tags: tags,
|
||||
Content: contentBytes,
|
||||
}
|
||||
|
||||
// Sign with relay identity
|
||||
if err = ev.Sign(e.relaySigner); chk.E(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// ResponseContent is the JSON structure for graph query responses.
|
||||
type ResponseContent struct {
|
||||
// PubkeysByDepth contains arrays of pubkeys at each depth (1-indexed)
|
||||
// Each pubkey appears ONLY at the depth where it was first discovered.
|
||||
PubkeysByDepth [][]string `json:"pubkeys_by_depth,omitempty"`
|
||||
|
||||
// EventsByDepth contains arrays of event IDs at each depth (1-indexed)
|
||||
EventsByDepth [][]string `json:"events_by_depth,omitempty"`
|
||||
|
||||
// TotalPubkeys is the total count of unique pubkeys discovered
|
||||
TotalPubkeys int `json:"total_pubkeys,omitempty"`
|
||||
|
||||
// TotalEvents is the total count of unique events discovered
|
||||
TotalEvents int `json:"total_events,omitempty"`
|
||||
}
|
||||
183
pkg/protocol/graph/query.go
Normal file
183
pkg/protocol/graph/query.go
Normal file
@@ -0,0 +1,183 @@
|
||||
// Package graph implements NIP-XX Graph Query protocol support.
|
||||
// It provides types and functions for parsing and validating graph traversal queries.
|
||||
package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/filter"
|
||||
)
|
||||
|
||||
// Query represents a graph traversal query from a _graph filter extension.
|
||||
type Query struct {
|
||||
// Method is the traversal method: "follows", "followers", "mentions", "thread"
|
||||
Method string `json:"method"`
|
||||
|
||||
// Seed is the starting point for traversal (pubkey hex or event ID hex)
|
||||
Seed string `json:"seed"`
|
||||
|
||||
// Depth is the maximum traversal depth (1-16, default: 1)
|
||||
Depth int `json:"depth,omitempty"`
|
||||
|
||||
// InboundRefs specifies which inbound references to collect
|
||||
// (events that reference discovered events via e-tags)
|
||||
InboundRefs []RefSpec `json:"inbound_refs,omitempty"`
|
||||
|
||||
// OutboundRefs specifies which outbound references to collect
|
||||
// (events referenced by discovered events via e-tags)
|
||||
OutboundRefs []RefSpec `json:"outbound_refs,omitempty"`
|
||||
}
|
||||
|
||||
// RefSpec specifies which event references to include in results.
|
||||
type RefSpec struct {
|
||||
// Kinds is the list of event kinds to match (OR semantics within this spec)
|
||||
Kinds []int `json:"kinds"`
|
||||
|
||||
// FromDepth specifies the minimum depth at which to collect refs (default: 0)
|
||||
// 0 = include refs from seed itself
|
||||
// 1 = start from first-hop connections
|
||||
FromDepth int `json:"from_depth,omitempty"`
|
||||
}
|
||||
|
||||
// Validation errors
|
||||
var (
|
||||
ErrMissingMethod = errors.New("_graph.method is required")
|
||||
ErrInvalidMethod = errors.New("_graph.method must be one of: follows, followers, mentions, thread")
|
||||
ErrMissingSeed = errors.New("_graph.seed is required")
|
||||
ErrInvalidSeed = errors.New("_graph.seed must be a 64-character hex string")
|
||||
ErrDepthTooHigh = errors.New("_graph.depth cannot exceed 16")
|
||||
ErrEmptyRefSpecKinds = errors.New("ref spec kinds array cannot be empty")
|
||||
)
|
||||
|
||||
// Valid method names
|
||||
var validMethods = map[string]bool{
|
||||
"follows": true,
|
||||
"followers": true,
|
||||
"mentions": true,
|
||||
"thread": true,
|
||||
}
|
||||
|
||||
// Validate checks the query for correctness and applies defaults.
|
||||
func (q *Query) Validate() error {
|
||||
// Method is required
|
||||
if q.Method == "" {
|
||||
return ErrMissingMethod
|
||||
}
|
||||
if !validMethods[q.Method] {
|
||||
return ErrInvalidMethod
|
||||
}
|
||||
|
||||
// Seed is required
|
||||
if q.Seed == "" {
|
||||
return ErrMissingSeed
|
||||
}
|
||||
if len(q.Seed) != 64 {
|
||||
return ErrInvalidSeed
|
||||
}
|
||||
// Validate hex characters
|
||||
for _, c := range q.Seed {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return ErrInvalidSeed
|
||||
}
|
||||
}
|
||||
|
||||
// Apply depth defaults and limits
|
||||
if q.Depth < 1 {
|
||||
q.Depth = 1
|
||||
}
|
||||
if q.Depth > 16 {
|
||||
return ErrDepthTooHigh
|
||||
}
|
||||
|
||||
// Validate ref specs
|
||||
for _, rs := range q.InboundRefs {
|
||||
if len(rs.Kinds) == 0 {
|
||||
return ErrEmptyRefSpecKinds
|
||||
}
|
||||
}
|
||||
for _, rs := range q.OutboundRefs {
|
||||
if len(rs.Kinds) == 0 {
|
||||
return ErrEmptyRefSpecKinds
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasInboundRefs returns true if the query includes inbound reference collection.
|
||||
func (q *Query) HasInboundRefs() bool {
|
||||
return len(q.InboundRefs) > 0
|
||||
}
|
||||
|
||||
// HasOutboundRefs returns true if the query includes outbound reference collection.
|
||||
func (q *Query) HasOutboundRefs() bool {
|
||||
return len(q.OutboundRefs) > 0
|
||||
}
|
||||
|
||||
// HasRefs returns true if the query includes any reference collection.
|
||||
func (q *Query) HasRefs() bool {
|
||||
return q.HasInboundRefs() || q.HasOutboundRefs()
|
||||
}
|
||||
|
||||
// InboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
|
||||
// It aggregates all RefSpecs where from_depth <= depth.
|
||||
func (q *Query) InboundKindsAtDepth(depth int) map[int]bool {
|
||||
kinds := make(map[int]bool)
|
||||
for _, rs := range q.InboundRefs {
|
||||
if rs.FromDepth <= depth {
|
||||
for _, k := range rs.Kinds {
|
||||
kinds[k] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
|
||||
// OutboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
|
||||
func (q *Query) OutboundKindsAtDepth(depth int) map[int]bool {
|
||||
kinds := make(map[int]bool)
|
||||
for _, rs := range q.OutboundRefs {
|
||||
if rs.FromDepth <= depth {
|
||||
for _, k := range rs.Kinds {
|
||||
kinds[k] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
|
||||
// ExtractFromFilter checks if a filter has a _graph extension and parses it.
|
||||
// Returns nil if no _graph field is present.
|
||||
// Returns an error if _graph is present but invalid.
|
||||
func ExtractFromFilter(f *filter.F) (*Query, error) {
|
||||
if f == nil || f.Extra == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
raw, ok := f.Extra["_graph"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var q Query
|
||||
if err := json.Unmarshal(raw, &q); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := q.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &q, nil
|
||||
}
|
||||
|
||||
// IsGraphQuery returns true if the filter contains a _graph extension.
|
||||
// This is a quick check that doesn't parse the full query.
|
||||
func IsGraphQuery(f *filter.F) bool {
|
||||
if f == nil || f.Extra == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := f.Extra["_graph"]
|
||||
return ok
|
||||
}
|
||||
397
pkg/protocol/graph/query_test.go
Normal file
397
pkg/protocol/graph/query_test.go
Normal file
@@ -0,0 +1,397 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/filter"
|
||||
)
|
||||
|
||||
func TestQueryValidate(t *testing.T) {
|
||||
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query Query
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid follows query",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 2,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid followers query",
|
||||
query: Query{
|
||||
Method: "followers",
|
||||
Seed: validSeed,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid mentions query",
|
||||
query: Query{
|
||||
Method: "mentions",
|
||||
Seed: validSeed,
|
||||
Depth: 1,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid thread query",
|
||||
query: Query{
|
||||
Method: "thread",
|
||||
Seed: validSeed,
|
||||
Depth: 10,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid query with inbound refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 2,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}, FromDepth: 1},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid query with multiple ref specs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}, FromDepth: 1},
|
||||
{Kinds: []int{6}, FromDepth: 1},
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}, FromDepth: 0},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "missing method",
|
||||
query: Query{Seed: validSeed},
|
||||
wantErr: ErrMissingMethod,
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
query: Query{
|
||||
Method: "invalid",
|
||||
Seed: validSeed,
|
||||
},
|
||||
wantErr: ErrInvalidMethod,
|
||||
},
|
||||
{
|
||||
name: "missing seed",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
},
|
||||
wantErr: ErrMissingSeed,
|
||||
},
|
||||
{
|
||||
name: "seed too short",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc123",
|
||||
},
|
||||
wantErr: ErrInvalidSeed,
|
||||
},
|
||||
{
|
||||
name: "seed with invalid characters",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg",
|
||||
},
|
||||
wantErr: ErrInvalidSeed,
|
||||
},
|
||||
{
|
||||
name: "depth too high",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 17,
|
||||
},
|
||||
wantErr: ErrDepthTooHigh,
|
||||
},
|
||||
{
|
||||
name: "empty ref spec kinds",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{}, FromDepth: 1},
|
||||
},
|
||||
},
|
||||
wantErr: ErrEmptyRefSpecKinds,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.query.Validate()
|
||||
if tt.wantErr == nil {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryDefaults(t *testing.T) {
|
||||
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
q := Query{
|
||||
Method: "follows",
|
||||
Seed: validSeed,
|
||||
Depth: 0, // Should default to 1
|
||||
}
|
||||
|
||||
err := q.Validate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if q.Depth != 1 {
|
||||
t.Errorf("Depth = %d, want 1 (default)", q.Depth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKindsAtDepth(t *testing.T) {
|
||||
q := Query{
|
||||
Method: "follows",
|
||||
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
Depth: 3,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}, FromDepth: 0}, // From seed
|
||||
{Kinds: []int{6, 16}, FromDepth: 1}, // From depth 1
|
||||
{Kinds: []int{9735}, FromDepth: 2}, // From depth 2
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}, FromDepth: 1},
|
||||
},
|
||||
}
|
||||
|
||||
// Test inbound kinds at depth 0
|
||||
kinds0 := q.InboundKindsAtDepth(0)
|
||||
if !kinds0[7] || kinds0[6] || kinds0[9735] {
|
||||
t.Errorf("InboundKindsAtDepth(0) = %v, want only kind 7", kinds0)
|
||||
}
|
||||
|
||||
// Test inbound kinds at depth 1
|
||||
kinds1 := q.InboundKindsAtDepth(1)
|
||||
if !kinds1[7] || !kinds1[6] || !kinds1[16] || kinds1[9735] {
|
||||
t.Errorf("InboundKindsAtDepth(1) = %v, want kinds 7, 6, 16", kinds1)
|
||||
}
|
||||
|
||||
// Test inbound kinds at depth 2
|
||||
kinds2 := q.InboundKindsAtDepth(2)
|
||||
if !kinds2[7] || !kinds2[6] || !kinds2[16] || !kinds2[9735] {
|
||||
t.Errorf("InboundKindsAtDepth(2) = %v, want all kinds", kinds2)
|
||||
}
|
||||
|
||||
// Test outbound kinds at depth 0
|
||||
outKinds0 := q.OutboundKindsAtDepth(0)
|
||||
if len(outKinds0) != 0 {
|
||||
t.Errorf("OutboundKindsAtDepth(0) = %v, want empty", outKinds0)
|
||||
}
|
||||
|
||||
// Test outbound kinds at depth 1
|
||||
outKinds1 := q.OutboundKindsAtDepth(1)
|
||||
if !outKinds1[1] {
|
||||
t.Errorf("OutboundKindsAtDepth(1) = %v, want kind 1", outKinds1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFromFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filterJSON string
|
||||
wantQuery bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "filter with valid graph query",
|
||||
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":2}}`,
|
||||
wantQuery: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "filter without graph query",
|
||||
filterJSON: `{"kinds":[1,7]}`,
|
||||
wantQuery: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "filter with invalid graph query (missing method)",
|
||||
filterJSON: `{"kinds":[1],"_graph":{"seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}}`,
|
||||
wantQuery: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "filter with complex graph query",
|
||||
filterJSON: `{"kinds":[0],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":3,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`,
|
||||
wantQuery: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &filter.F{}
|
||||
_, err := f.Unmarshal([]byte(tt.filterJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal filter: %v", err)
|
||||
}
|
||||
|
||||
q, err := ExtractFromFilter(f)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantQuery && q == nil {
|
||||
t.Error("expected query, got nil")
|
||||
}
|
||||
if !tt.wantQuery && q != nil {
|
||||
t.Errorf("expected nil query, got %+v", q)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGraphQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filterJSON string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "filter with graph query",
|
||||
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"abc"}}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "filter without graph query",
|
||||
filterJSON: `{"kinds":[1,7]}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "filter with other extension",
|
||||
filterJSON: `{"kinds":[1],"_custom":"value"}`,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &filter.F{}
|
||||
_, err := f.Unmarshal([]byte(tt.filterJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal filter: %v", err)
|
||||
}
|
||||
|
||||
got := IsGraphQuery(f)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsGraphQuery() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryHasRefs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query Query
|
||||
hasInbound bool
|
||||
hasOutbound bool
|
||||
hasRefs bool
|
||||
}{
|
||||
{
|
||||
name: "no refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
},
|
||||
hasInbound: false,
|
||||
hasOutbound: false,
|
||||
hasRefs: false,
|
||||
},
|
||||
{
|
||||
name: "only inbound refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
},
|
||||
hasInbound: true,
|
||||
hasOutbound: false,
|
||||
hasRefs: true,
|
||||
},
|
||||
{
|
||||
name: "only outbound refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}},
|
||||
},
|
||||
},
|
||||
hasInbound: false,
|
||||
hasOutbound: true,
|
||||
hasRefs: true,
|
||||
},
|
||||
{
|
||||
name: "both refs",
|
||||
query: Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}},
|
||||
},
|
||||
},
|
||||
hasInbound: true,
|
||||
hasOutbound: true,
|
||||
hasRefs: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.query.HasInboundRefs(); got != tt.hasInbound {
|
||||
t.Errorf("HasInboundRefs() = %v, want %v", got, tt.hasInbound)
|
||||
}
|
||||
if got := tt.query.HasOutboundRefs(); got != tt.hasOutbound {
|
||||
t.Errorf("HasOutboundRefs() = %v, want %v", got, tt.hasOutbound)
|
||||
}
|
||||
if got := tt.query.HasRefs(); got != tt.hasRefs {
|
||||
t.Errorf("HasRefs() = %v, want %v", got, tt.hasRefs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
282
pkg/protocol/graph/ratelimit.go
Normal file
282
pkg/protocol/graph/ratelimit.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter implements a token bucket rate limiter with adaptive throttling
|
||||
// based on graph query complexity. It allows cooperative scheduling by inserting
|
||||
// pauses between operations to allow other work to proceed.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
|
||||
// Token bucket parameters
|
||||
tokens float64 // Current available tokens
|
||||
maxTokens float64 // Maximum token capacity
|
||||
refillRate float64 // Tokens per second to add
|
||||
lastRefill time.Time // Last time tokens were refilled
|
||||
|
||||
// Throttling parameters
|
||||
baseDelay time.Duration // Minimum delay between operations
|
||||
maxDelay time.Duration // Maximum delay for complex queries
|
||||
depthFactor float64 // Multiplier per depth level
|
||||
limitFactor float64 // Multiplier based on result limit
|
||||
}
|
||||
|
||||
// RateLimiterConfig configures the rate limiter behavior.
|
||||
type RateLimiterConfig struct {
|
||||
// MaxTokens is the maximum number of tokens in the bucket (default: 100)
|
||||
MaxTokens float64
|
||||
|
||||
// RefillRate is tokens added per second (default: 10)
|
||||
RefillRate float64
|
||||
|
||||
// BaseDelay is the minimum delay between operations (default: 1ms)
|
||||
BaseDelay time.Duration
|
||||
|
||||
// MaxDelay is the maximum delay for complex queries (default: 100ms)
|
||||
MaxDelay time.Duration
|
||||
|
||||
// DepthFactor is the cost multiplier per depth level (default: 2.0)
|
||||
// A depth-3 query costs 2^3 = 8x more tokens than depth-1
|
||||
DepthFactor float64
|
||||
|
||||
// LimitFactor is additional cost per 100 results requested (default: 0.1)
|
||||
LimitFactor float64
|
||||
}
|
||||
|
||||
// DefaultRateLimiterConfig returns sensible defaults for the rate limiter.
|
||||
func DefaultRateLimiterConfig() RateLimiterConfig {
|
||||
return RateLimiterConfig{
|
||||
MaxTokens: 100.0,
|
||||
RefillRate: 10.0, // Refills fully in 10 seconds
|
||||
BaseDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
DepthFactor: 2.0,
|
||||
LimitFactor: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with the given configuration.
|
||||
func NewRateLimiter(cfg RateLimiterConfig) *RateLimiter {
|
||||
if cfg.MaxTokens <= 0 {
|
||||
cfg.MaxTokens = DefaultRateLimiterConfig().MaxTokens
|
||||
}
|
||||
if cfg.RefillRate <= 0 {
|
||||
cfg.RefillRate = DefaultRateLimiterConfig().RefillRate
|
||||
}
|
||||
if cfg.BaseDelay <= 0 {
|
||||
cfg.BaseDelay = DefaultRateLimiterConfig().BaseDelay
|
||||
}
|
||||
if cfg.MaxDelay <= 0 {
|
||||
cfg.MaxDelay = DefaultRateLimiterConfig().MaxDelay
|
||||
}
|
||||
if cfg.DepthFactor <= 0 {
|
||||
cfg.DepthFactor = DefaultRateLimiterConfig().DepthFactor
|
||||
}
|
||||
if cfg.LimitFactor <= 0 {
|
||||
cfg.LimitFactor = DefaultRateLimiterConfig().LimitFactor
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
tokens: cfg.MaxTokens,
|
||||
maxTokens: cfg.MaxTokens,
|
||||
refillRate: cfg.RefillRate,
|
||||
lastRefill: time.Now(),
|
||||
baseDelay: cfg.BaseDelay,
|
||||
maxDelay: cfg.MaxDelay,
|
||||
depthFactor: cfg.DepthFactor,
|
||||
limitFactor: cfg.LimitFactor,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryCost calculates the token cost for a graph query based on its complexity.
|
||||
// Higher depths and larger limits cost exponentially more tokens.
|
||||
func (rl *RateLimiter) QueryCost(q *Query) float64 {
|
||||
if q == nil {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
// Base cost is exponential in depth: depthFactor^depth
|
||||
// This models the exponential growth of traversal work
|
||||
cost := 1.0
|
||||
for i := 0; i < q.Depth; i++ {
|
||||
cost *= rl.depthFactor
|
||||
}
|
||||
|
||||
// Add cost for reference collection (adds ~50% per ref spec)
|
||||
refCost := float64(len(q.InboundRefs)+len(q.OutboundRefs)) * 0.5
|
||||
cost += refCost
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
// OperationCost calculates the token cost for a single traversal operation.
|
||||
// This is used during query execution for per-operation throttling.
|
||||
func (rl *RateLimiter) OperationCost(depth int, nodesAtDepth int) float64 {
|
||||
// Cost increases with depth and number of nodes to process
|
||||
depthMultiplier := 1.0
|
||||
for i := 0; i < depth; i++ {
|
||||
depthMultiplier *= rl.depthFactor
|
||||
}
|
||||
|
||||
// More nodes at this depth = more work
|
||||
nodeFactor := 1.0 + float64(nodesAtDepth)*0.01
|
||||
|
||||
return depthMultiplier * nodeFactor
|
||||
}
|
||||
|
||||
// refillTokens adds tokens based on elapsed time since last refill.
|
||||
func (rl *RateLimiter) refillTokens() {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastRefill).Seconds()
|
||||
rl.lastRefill = now
|
||||
|
||||
rl.tokens += elapsed * rl.refillRate
|
||||
if rl.tokens > rl.maxTokens {
|
||||
rl.tokens = rl.maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire tries to acquire tokens for a query. If not enough tokens are available,
|
||||
// it waits until they become available or the context is cancelled.
|
||||
// Returns the delay that was applied, or an error if context was cancelled.
|
||||
func (rl *RateLimiter) Acquire(ctx context.Context, cost float64) (time.Duration, error) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.refillTokens()
|
||||
|
||||
var totalDelay time.Duration
|
||||
|
||||
// Wait until we have enough tokens
|
||||
for rl.tokens < cost {
|
||||
// Calculate how long we need to wait for tokens to refill
|
||||
tokensNeeded := cost - rl.tokens
|
||||
waitTime := time.Duration(tokensNeeded/rl.refillRate*1000) * time.Millisecond
|
||||
|
||||
// Clamp to max delay
|
||||
if waitTime > rl.maxDelay {
|
||||
waitTime = rl.maxDelay
|
||||
}
|
||||
if waitTime < rl.baseDelay {
|
||||
waitTime = rl.baseDelay
|
||||
}
|
||||
|
||||
// Release lock while waiting
|
||||
rl.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rl.mu.Lock()
|
||||
return totalDelay, ctx.Err()
|
||||
case <-time.After(waitTime):
|
||||
}
|
||||
|
||||
totalDelay += waitTime
|
||||
rl.mu.Lock()
|
||||
rl.refillTokens()
|
||||
}
|
||||
|
||||
// Consume tokens
|
||||
rl.tokens -= cost
|
||||
return totalDelay, nil
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire tokens without waiting.
|
||||
// Returns true if successful, false if insufficient tokens.
|
||||
func (rl *RateLimiter) TryAcquire(cost float64) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.refillTokens()
|
||||
|
||||
if rl.tokens >= cost {
|
||||
rl.tokens -= cost
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Pause inserts a cooperative delay to allow other work to proceed.
|
||||
// The delay is proportional to the current depth and load.
|
||||
// This should be called periodically during long-running traversals.
|
||||
func (rl *RateLimiter) Pause(ctx context.Context, depth int, itemsProcessed int) error {
|
||||
// Calculate adaptive delay based on depth and progress
|
||||
// Deeper traversals and more processed items = longer pauses
|
||||
delay := rl.baseDelay
|
||||
|
||||
// Increase delay with depth
|
||||
for i := 0; i < depth; i++ {
|
||||
delay += rl.baseDelay
|
||||
}
|
||||
|
||||
// Add extra delay every N items to allow other work
|
||||
if itemsProcessed > 0 && itemsProcessed%100 == 0 {
|
||||
delay += rl.baseDelay * 5
|
||||
}
|
||||
|
||||
// Cap at max delay
|
||||
if delay > rl.maxDelay {
|
||||
delay = rl.maxDelay
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AvailableTokens returns the current number of available tokens.
|
||||
func (rl *RateLimiter) AvailableTokens() float64 {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
rl.refillTokens()
|
||||
return rl.tokens
|
||||
}
|
||||
|
||||
// Throttler provides a simple interface for cooperative scheduling during traversal.
|
||||
// It wraps the rate limiter and provides depth-aware throttling.
|
||||
type Throttler struct {
|
||||
rl *RateLimiter
|
||||
depth int
|
||||
itemsProcessed int
|
||||
}
|
||||
|
||||
// NewThrottler creates a throttler for a specific traversal operation.
|
||||
func NewThrottler(rl *RateLimiter, depth int) *Throttler {
|
||||
return &Throttler{
|
||||
rl: rl,
|
||||
depth: depth,
|
||||
}
|
||||
}
|
||||
|
||||
// Tick should be called after processing each item.
|
||||
// It tracks progress and inserts pauses as needed.
|
||||
func (t *Throttler) Tick(ctx context.Context) error {
|
||||
t.itemsProcessed++
|
||||
|
||||
// Insert cooperative pause periodically
|
||||
// More frequent pauses at higher depths
|
||||
interval := 50
|
||||
if t.depth >= 2 {
|
||||
interval = 25
|
||||
}
|
||||
if t.depth >= 4 {
|
||||
interval = 10
|
||||
}
|
||||
|
||||
if t.itemsProcessed%interval == 0 {
|
||||
return t.rl.Pause(ctx, t.depth, t.itemsProcessed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Complete marks the throttler as complete and returns stats.
|
||||
func (t *Throttler) Complete() (itemsProcessed int) {
|
||||
return t.itemsProcessed
|
||||
}
|
||||
267
pkg/protocol/graph/ratelimit_test.go
Normal file
267
pkg/protocol/graph/ratelimit_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiterQueryCost(t *testing.T) {
|
||||
rl := NewRateLimiter(DefaultRateLimiterConfig())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query *Query
|
||||
minCost float64
|
||||
maxCost float64
|
||||
}{
|
||||
{
|
||||
name: "nil query",
|
||||
query: nil,
|
||||
minCost: 1.0,
|
||||
maxCost: 1.0,
|
||||
},
|
||||
{
|
||||
name: "depth 1 no refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 1,
|
||||
},
|
||||
minCost: 1.5, // depthFactor^1 = 2
|
||||
maxCost: 2.5,
|
||||
},
|
||||
{
|
||||
name: "depth 2 no refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 2,
|
||||
},
|
||||
minCost: 3.5, // depthFactor^2 = 4
|
||||
maxCost: 4.5,
|
||||
},
|
||||
{
|
||||
name: "depth 3 no refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 3,
|
||||
},
|
||||
minCost: 7.5, // depthFactor^3 = 8
|
||||
maxCost: 8.5,
|
||||
},
|
||||
{
|
||||
name: "depth 2 with inbound refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 2,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
},
|
||||
minCost: 4.0, // 4 + 0.5 = 4.5
|
||||
maxCost: 5.0,
|
||||
},
|
||||
{
|
||||
name: "depth 2 with both refs",
|
||||
query: &Query{
|
||||
Method: "follows",
|
||||
Seed: "abc",
|
||||
Depth: 2,
|
||||
InboundRefs: []RefSpec{
|
||||
{Kinds: []int{7}},
|
||||
},
|
||||
OutboundRefs: []RefSpec{
|
||||
{Kinds: []int{1}},
|
||||
},
|
||||
},
|
||||
minCost: 4.5, // 4 + 0.5 + 0.5 = 5
|
||||
maxCost: 5.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cost := rl.QueryCost(tt.query)
|
||||
if cost < tt.minCost || cost > tt.maxCost {
|
||||
t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterOperationCost(t *testing.T) {
|
||||
rl := NewRateLimiter(DefaultRateLimiterConfig())
|
||||
|
||||
// Depth 0, 1 node
|
||||
cost0 := rl.OperationCost(0, 1)
|
||||
if cost0 < 1.0 || cost0 > 1.1 {
|
||||
t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0)
|
||||
}
|
||||
|
||||
// Depth 1, 1 node
|
||||
cost1 := rl.OperationCost(1, 1)
|
||||
if cost1 < 2.0 || cost1 > 2.1 {
|
||||
t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1)
|
||||
}
|
||||
|
||||
// Depth 2, 100 nodes
|
||||
cost2 := rl.OperationCost(2, 100)
|
||||
if cost2 < 8.0 {
|
||||
t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterAcquire(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 10
|
||||
cfg.RefillRate = 100 // Fast refill for testing
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Should acquire immediately when tokens available
|
||||
delay, err := rl.Acquire(ctx, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if delay > time.Millisecond*10 {
|
||||
t.Errorf("expected minimal delay, got %v", delay)
|
||||
}
|
||||
|
||||
// Check remaining tokens
|
||||
remaining := rl.AvailableTokens()
|
||||
if remaining > 6 {
|
||||
t.Errorf("expected ~5 tokens remaining, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterTryAcquire(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 10
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Should succeed with enough tokens
|
||||
if !rl.TryAcquire(5) {
|
||||
t.Error("TryAcquire(5) should succeed with 10 tokens")
|
||||
}
|
||||
|
||||
// Should succeed again
|
||||
if !rl.TryAcquire(5) {
|
||||
t.Error("TryAcquire(5) should succeed with 5 tokens")
|
||||
}
|
||||
|
||||
// Should fail with insufficient tokens
|
||||
if rl.TryAcquire(1) {
|
||||
t.Error("TryAcquire(1) should fail with 0 tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterContextCancellation(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 1
|
||||
cfg.RefillRate = 0.1 // Very slow refill
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Drain tokens
|
||||
rl.TryAcquire(1)
|
||||
|
||||
// Create cancellable context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Try to acquire - should be cancelled
|
||||
_, err := rl.Acquire(ctx, 10)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterRefill(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.MaxTokens = 10
|
||||
cfg.RefillRate = 1000 // 1000 tokens per second
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
// Drain tokens
|
||||
rl.TryAcquire(10)
|
||||
|
||||
// Wait for refill
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should have some tokens now
|
||||
available := rl.AvailableTokens()
|
||||
if available < 5 {
|
||||
t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available)
|
||||
}
|
||||
if available > 10 {
|
||||
t.Errorf("expected <= 10 tokens (max), got %v", available)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterPause(t *testing.T) {
|
||||
rl := NewRateLimiter(DefaultRateLimiterConfig())
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := rl.Pause(ctx, 1, 0)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have paused for at least baseDelay
|
||||
if elapsed < rl.baseDelay {
|
||||
t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThrottler(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
cfg.BaseDelay = 100 * time.Microsecond // Short for testing
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
throttler := NewThrottler(rl, 1)
|
||||
ctx := context.Background()
|
||||
|
||||
// Process items
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := throttler.Tick(ctx); err != nil {
|
||||
t.Fatalf("unexpected error at tick %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
processed := throttler.Complete()
|
||||
if processed != 100 {
|
||||
t.Errorf("expected 100 items processed, got %d", processed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThrottlerContextCancellation(t *testing.T) {
|
||||
cfg := DefaultRateLimiterConfig()
|
||||
rl := NewRateLimiter(cfg)
|
||||
|
||||
throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Process some items
|
||||
for i := 0; i < 20; i++ {
|
||||
throttler.Tick(ctx)
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Next tick that would pause should return error
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := throttler.Tick(ctx); err != nil {
|
||||
// Expected - context was cancelled
|
||||
return
|
||||
}
|
||||
}
|
||||
// If we get here without error, the throttler didn't check context
|
||||
// This is acceptable if no pause was needed
|
||||
}
|
||||
Reference in New Issue
Block a user