diff --git a/app/config/config.go b/app/config/config.go index 32c5fa6..b6003fd 100644 --- a/app/config/config.go +++ b/app/config/config.go @@ -102,8 +102,22 @@ type C struct { Neo4jPassword string `env:"ORLY_NEO4J_PASSWORD" default:"password" usage:"Neo4j authentication password (only used when ORLY_DB_TYPE=neo4j)"` // Advanced database tuning - SerialCachePubkeys int `env:"ORLY_SERIAL_CACHE_PUBKEYS" default:"100000" usage:"max pubkeys to cache for compact event storage (default: 100000, ~3.2MB memory)"` - SerialCacheEventIds int `env:"ORLY_SERIAL_CACHE_EVENT_IDS" default:"500000" usage:"max event IDs to cache for compact event storage (default: 500000, ~16MB memory)"` + SerialCachePubkeys int `env:"ORLY_SERIAL_CACHE_PUBKEYS" default:"100000" usage:"max pubkeys to cache for compact event storage (default: 100000, ~3.2MB memory)"` + SerialCacheEventIds int `env:"ORLY_SERIAL_CACHE_EVENT_IDS" default:"500000" usage:"max event IDs to cache for compact event storage (default: 500000, ~16MB memory)"` + + // Adaptive rate limiting (PID-controlled) + RateLimitEnabled bool `env:"ORLY_RATE_LIMIT_ENABLED" default:"false" usage:"enable adaptive PID-controlled rate limiting for database operations"` + RateLimitTargetMB int `env:"ORLY_RATE_LIMIT_TARGET_MB" default:"1500" usage:"target memory limit in MB for rate limiting (default: 1500 = 1.5GB)"` + RateLimitWriteKp float64 `env:"ORLY_RATE_LIMIT_WRITE_KP" default:"0.5" usage:"PID proportional gain for write operations"` + RateLimitWriteKi float64 `env:"ORLY_RATE_LIMIT_WRITE_KI" default:"0.1" usage:"PID integral gain for write operations"` + RateLimitWriteKd float64 `env:"ORLY_RATE_LIMIT_WRITE_KD" default:"0.05" usage:"PID derivative gain for write operations (filtered)"` + RateLimitReadKp float64 `env:"ORLY_RATE_LIMIT_READ_KP" default:"0.3" usage:"PID proportional gain for read operations"` + RateLimitReadKi float64 `env:"ORLY_RATE_LIMIT_READ_KI" default:"0.05" usage:"PID integral gain for read operations"` + RateLimitReadKd float64 `env:"ORLY_RATE_LIMIT_READ_KD" default:"0.02" usage:"PID derivative gain for read operations (filtered)"` + RateLimitMaxWriteMs int `env:"ORLY_RATE_LIMIT_MAX_WRITE_MS" default:"1000" usage:"maximum delay for write operations in milliseconds"` + RateLimitMaxReadMs int `env:"ORLY_RATE_LIMIT_MAX_READ_MS" default:"500" usage:"maximum delay for read operations in milliseconds"` + RateLimitWriteTarget float64 `env:"ORLY_RATE_LIMIT_WRITE_TARGET" default:"0.85" usage:"PID setpoint for writes (throttle when load exceeds this, 0.0-1.0)"` + RateLimitReadTarget float64 `env:"ORLY_RATE_LIMIT_READ_TARGET" default:"0.90" usage:"PID setpoint for reads (throttle when load exceeds this, 0.0-1.0)"` // TLS configuration TLSDomains []string `env:"ORLY_TLS_DOMAINS" usage:"comma-separated list of domains to respond to for TLS"` @@ -432,3 +446,22 @@ func (cfg *C) GetDatabaseConfigValues() ( cfg.DBZSTDLevel, cfg.Neo4jURI, cfg.Neo4jUser, cfg.Neo4jPassword } + +// GetRateLimitConfigValues returns the rate limiting configuration values. +// This avoids circular imports with pkg/ratelimit while allowing main.go to construct +// a ratelimit.Config with the correct type. 
+func (cfg *C) GetRateLimitConfigValues() ( + enabled bool, + targetMB int, + writeKp, writeKi, writeKd float64, + readKp, readKi, readKd float64, + maxWriteMs, maxReadMs int, + writeTarget, readTarget float64, +) { + return cfg.RateLimitEnabled, + cfg.RateLimitTargetMB, + cfg.RateLimitWriteKp, cfg.RateLimitWriteKi, cfg.RateLimitWriteKd, + cfg.RateLimitReadKp, cfg.RateLimitReadKi, cfg.RateLimitReadKd, + cfg.RateLimitMaxWriteMs, cfg.RateLimitMaxReadMs, + cfg.RateLimitWriteTarget, cfg.RateLimitReadTarget +} diff --git a/app/main.go b/app/main.go index d197bd7..080b6d7 100644 --- a/app/main.go +++ b/app/main.go @@ -21,12 +21,13 @@ import ( "next.orly.dev/pkg/protocol/graph" "next.orly.dev/pkg/protocol/nip43" "next.orly.dev/pkg/protocol/publish" + "next.orly.dev/pkg/ratelimit" "next.orly.dev/pkg/spider" dsync "next.orly.dev/pkg/sync" ) func Run( - ctx context.Context, cfg *config.C, db database.Database, + ctx context.Context, cfg *config.C, db database.Database, limiter *ratelimit.Limiter, ) (quit chan struct{}) { quit = make(chan struct{}) var once sync.Once @@ -64,14 +65,15 @@ func Run( } // start listener l := &Server{ - Ctx: ctx, - Config: cfg, - DB: db, - publishers: publish.New(NewPublisher(ctx)), - Admins: adminKeys, - Owners: ownerKeys, - cfg: cfg, - db: db, + Ctx: ctx, + Config: cfg, + DB: db, + publishers: publish.New(NewPublisher(ctx)), + Admins: adminKeys, + Owners: ownerKeys, + rateLimiter: limiter, + cfg: cfg, + db: db, } // Initialize NIP-43 invite manager if enabled @@ -360,6 +362,12 @@ func Run( } } + // Start rate limiter if enabled + if limiter != nil && limiter.IsEnabled() { + limiter.Start() + log.I.F("adaptive rate limiter started") + } + // Wait for database to be ready before accepting requests log.I.F("waiting for database warmup to complete...") <-db.Ready() @@ -457,6 +465,12 @@ func Run( log.I.F("directory spider stopped") } + // Stop rate limiter if running + if l.rateLimiter != nil && l.rateLimiter.IsEnabled() { + l.rateLimiter.Stop() + log.I.F("rate limiter stopped") + } + // Create shutdown context with timeout shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 10*time.Second) defer cancelShutdown() diff --git a/app/server.go b/app/server.go index 75a89e1..038f1d5 100644 --- a/app/server.go +++ b/app/server.go @@ -29,6 +29,7 @@ import ( "next.orly.dev/pkg/protocol/graph" "next.orly.dev/pkg/protocol/nip43" "next.orly.dev/pkg/protocol/publish" + "next.orly.dev/pkg/ratelimit" "next.orly.dev/pkg/spider" dsync "next.orly.dev/pkg/sync" ) @@ -64,6 +65,7 @@ type Server struct { blossomServer *blossom.Server InviteManager *nip43.InviteManager graphExecutor *graph.Executor + rateLimiter *ratelimit.Limiter cfg *config.C db database.Database // Changed from *database.D to interface } diff --git a/main.go b/main.go index deb993f..79db6cf 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "next.orly.dev/pkg/database" _ "next.orly.dev/pkg/neo4j" // Import to register neo4j factory "git.mleku.dev/mleku/nostr/encoders/hex" + "next.orly.dev/pkg/ratelimit" "next.orly.dev/pkg/utils/interrupt" "next.orly.dev/pkg/version" ) @@ -336,6 +337,37 @@ func main() { } acl.Registry.Syncer() + // Create rate limiter if enabled + var limiter *ratelimit.Limiter + rateLimitEnabled, targetMB, + writeKp, writeKi, writeKd, + readKp, readKi, readKd, + maxWriteMs, maxReadMs, + writeTarget, readTarget := cfg.GetRateLimitConfigValues() + + if rateLimitEnabled { + rlConfig := ratelimit.NewConfigFromValues( + rateLimitEnabled, targetMB, + writeKp, writeKi, writeKd, + 
readKp, readKi, readKd, + maxWriteMs, maxReadMs, + writeTarget, readTarget, + ) + + // Create appropriate monitor based on database type + if badgerDB, ok := db.(*database.D); ok { + limiter = ratelimit.NewBadgerLimiter(rlConfig, badgerDB.DB) + log.I.F("rate limiter configured for Badger backend (target: %dMB)", targetMB) + } else { + // For Neo4j or other backends, create a disabled limiter for now + // Neo4j monitor requires access to the querySem which is internal + limiter = ratelimit.NewDisabledLimiter() + log.I.F("rate limiter disabled for non-Badger backend") + } + } else { + limiter = ratelimit.NewDisabledLimiter() + } + // Start HTTP pprof server if enabled if cfg.PprofHTTP { pprofAddr := fmt.Sprintf("%s:%d", cfg.Listen, 6060) @@ -413,7 +445,7 @@ func main() { }() } - quit := app.Run(ctx, cfg, db) + quit := app.Run(ctx, cfg, db, limiter) sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) for { diff --git a/pkg/interfaces/loadmonitor/loadmonitor.go b/pkg/interfaces/loadmonitor/loadmonitor.go new file mode 100644 index 0000000..30d1755 --- /dev/null +++ b/pkg/interfaces/loadmonitor/loadmonitor.go @@ -0,0 +1,58 @@ +// Package loadmonitor defines the interface for database load monitoring. +// This allows different database backends to provide their own load metrics +// while the rate limiter remains database-agnostic. +package loadmonitor + +import "time" + +// Metrics contains load metrics from a database backend. +// All values are normalized to 0.0-1.0 where 0 means no load and 1 means at capacity. +type Metrics struct { + // MemoryPressure indicates memory usage relative to a target limit (0.0-1.0+). + // Values above 1.0 indicate the target has been exceeded. + MemoryPressure float64 + + // WriteLoad indicates the write-side load level (0.0-1.0). + // For Badger: L0 tables and compaction score + // For Neo4j: active write transactions + WriteLoad float64 + + // ReadLoad indicates the read-side load level (0.0-1.0). + // For Badger: cache hit ratio (inverted) + // For Neo4j: active read transactions + ReadLoad float64 + + // QueryLatency is the recent average query latency. + QueryLatency time.Duration + + // WriteLatency is the recent average write latency. + WriteLatency time.Duration + + // Timestamp is when these metrics were collected. + Timestamp time.Time +} + +// Monitor defines the interface for database load monitoring. +// Implementations are database-specific (Badger, Neo4j, etc.). +type Monitor interface { + // GetMetrics returns the current load metrics. + // This should be efficient as it may be called frequently. + GetMetrics() Metrics + + // RecordQueryLatency records a query latency sample for averaging. + RecordQueryLatency(latency time.Duration) + + // RecordWriteLatency records a write latency sample for averaging. + RecordWriteLatency(latency time.Duration) + + // SetMemoryTarget sets the target memory limit in bytes. + // Memory pressure is calculated relative to this target. + SetMemoryTarget(bytes uint64) + + // Start begins background metric collection. + // Returns a channel that will be closed when the monitor is stopped. + Start() <-chan struct{} + + // Stop halts background metric collection. 
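The Monitor interface closes with Stop() just below. Because every method is small, a backend with no native metrics can satisfy the interface with a static stub. The sketch below is hypothetical (the type, constructor, and package name are not part of this change); it only assumes the loadmonitor package defined in this file:

```go
package relaytest // hypothetical package for the sketch

import (
	"time"

	"next.orly.dev/pkg/interfaces/loadmonitor"
)

// staticMonitor is a minimal loadmonitor.Monitor that reports a fixed load
// level. It could stand in for backends (or tests) that expose no metrics.
type staticMonitor struct {
	load    float64
	stopped chan struct{}
}

// Compile-time check, mirroring the convention used by the real monitors.
var _ loadmonitor.Monitor = (*staticMonitor)(nil)

func NewStaticMonitor(load float64) loadmonitor.Monitor {
	return &staticMonitor{load: load, stopped: make(chan struct{})}
}

func (s *staticMonitor) GetMetrics() loadmonitor.Metrics {
	return loadmonitor.Metrics{
		MemoryPressure: s.load,
		WriteLoad:      s.load,
		ReadLoad:       s.load,
		Timestamp:      time.Now(),
	}
}

func (s *staticMonitor) RecordQueryLatency(time.Duration) {}
func (s *staticMonitor) RecordWriteLatency(time.Duration) {}
func (s *staticMonitor) SetMemoryTarget(uint64)           {}
func (s *staticMonitor) Start() <-chan struct{}           { return s.stopped }
func (s *staticMonitor) Stop()                            { close(s.stopped) }
```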
+ Stop() +} diff --git a/pkg/ratelimit/badger_monitor.go b/pkg/ratelimit/badger_monitor.go new file mode 100644 index 0000000..faeb502 --- /dev/null +++ b/pkg/ratelimit/badger_monitor.go @@ -0,0 +1,237 @@ +//go:build !(js && wasm) + +package ratelimit + +import ( + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/dgraph-io/badger/v4" + "next.orly.dev/pkg/interfaces/loadmonitor" +) + +// BadgerMonitor implements loadmonitor.Monitor for the Badger database. +// It collects metrics from Badger's LSM tree, caches, and Go runtime. +type BadgerMonitor struct { + db *badger.DB + + // Target memory for pressure calculation + targetMemoryBytes atomic.Uint64 + + // Latency tracking with exponential moving average + queryLatencyNs atomic.Int64 + writeLatencyNs atomic.Int64 + latencyAlpha float64 // EMA coefficient (default 0.1) + + // Cached metrics (updated by background goroutine) + metricsLock sync.RWMutex + cachedMetrics loadmonitor.Metrics + lastL0Tables int + lastL0Score float64 + + // Background collection + stopChan chan struct{} + stopped chan struct{} + interval time.Duration +} + +// Compile-time check that BadgerMonitor implements loadmonitor.Monitor +var _ loadmonitor.Monitor = (*BadgerMonitor)(nil) + +// NewBadgerMonitor creates a new Badger load monitor. +// The updateInterval controls how often metrics are collected (default 100ms). +func NewBadgerMonitor(db *badger.DB, updateInterval time.Duration) *BadgerMonitor { + if updateInterval <= 0 { + updateInterval = 100 * time.Millisecond + } + + m := &BadgerMonitor{ + db: db, + latencyAlpha: 0.1, // 10% new, 90% old for smooth EMA + stopChan: make(chan struct{}), + stopped: make(chan struct{}), + interval: updateInterval, + } + + // Set a default target (1.5GB) + m.targetMemoryBytes.Store(1500 * 1024 * 1024) + + return m +} + +// GetMetrics returns the current load metrics. +func (m *BadgerMonitor) GetMetrics() loadmonitor.Metrics { + m.metricsLock.RLock() + defer m.metricsLock.RUnlock() + return m.cachedMetrics +} + +// RecordQueryLatency records a query latency sample using exponential moving average. +func (m *BadgerMonitor) RecordQueryLatency(latency time.Duration) { + ns := latency.Nanoseconds() + for { + old := m.queryLatencyNs.Load() + if old == 0 { + if m.queryLatencyNs.CompareAndSwap(0, ns) { + return + } + continue + } + // EMA: new = alpha * sample + (1-alpha) * old + newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) + if m.queryLatencyNs.CompareAndSwap(old, newVal) { + return + } + } +} + +// RecordWriteLatency records a write latency sample using exponential moving average. +func (m *BadgerMonitor) RecordWriteLatency(latency time.Duration) { + ns := latency.Nanoseconds() + for { + old := m.writeLatencyNs.Load() + if old == 0 { + if m.writeLatencyNs.CompareAndSwap(0, ns) { + return + } + continue + } + // EMA: new = alpha * sample + (1-alpha) * old + newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) + if m.writeLatencyNs.CompareAndSwap(old, newVal) { + return + } + } +} + +// SetMemoryTarget sets the target memory limit in bytes. +func (m *BadgerMonitor) SetMemoryTarget(bytes uint64) { + m.targetMemoryBytes.Store(bytes) +} + +// Start begins background metric collection. +func (m *BadgerMonitor) Start() <-chan struct{} { + go m.collectLoop() + return m.stopped +} + +// Stop halts background metric collection. +func (m *BadgerMonitor) Stop() { + close(m.stopChan) + <-m.stopped +} + +// collectLoop periodically collects metrics from Badger. 
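The two Record*Latency methods above implement an exponential moving average (EMA), new = alpha*sample + (1-alpha)*old with alpha = 0.1, inside a compare-and-swap loop. A standalone sketch of the same arithmetic (sample values invented) shows why a single slow query barely moves the smoothed figure:

```go
package main

import "fmt"

// Demonstrates the EMA used by RecordQueryLatency / RecordWriteLatency.
func main() {
	const alpha = 0.1
	var ema float64 // smoothed latency in milliseconds

	samples := []float64{5, 5, 5, 250, 5, 5, 5} // one 250ms spike among 5ms queries
	for i, s := range samples {
		if i == 0 {
			ema = s // the first sample seeds the average, as in the CAS loop above
		} else {
			ema = alpha*s + (1-alpha)*ema
		}
		fmt.Printf("sample=%6.1fms  ema=%6.2fms\n", s, ema)
	}
	// The single spike nudges the average up by roughly a tenth of its
	// magnitude rather than replacing it, which keeps the PID input from
	// overreacting to one-off slow queries.
}
```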
+func (m *BadgerMonitor) collectLoop() { + defer close(m.stopped) + + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + for { + select { + case <-m.stopChan: + return + case <-ticker.C: + m.updateMetrics() + } + } +} + +// updateMetrics collects current metrics from Badger and runtime. +func (m *BadgerMonitor) updateMetrics() { + if m.db == nil || m.db.IsClosed() { + return + } + + metrics := loadmonitor.Metrics{ + Timestamp: time.Now(), + } + + // Calculate memory pressure from Go runtime + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + targetBytes := m.targetMemoryBytes.Load() + if targetBytes > 0 { + // Use HeapAlloc as primary memory metric + // This represents the actual live heap objects + metrics.MemoryPressure = float64(memStats.HeapAlloc) / float64(targetBytes) + } + + // Get Badger LSM tree information for write load + levels := m.db.Levels() + var l0Tables int + var maxScore float64 + + for _, level := range levels { + if level.Level == 0 { + l0Tables = level.NumTables + } + if level.Score > maxScore { + maxScore = level.Score + } + } + + // Calculate write load based on L0 tables and compaction score + // L0 tables stall at NumLevelZeroTablesStall (default 16) + // We consider write pressure high when approaching that limit + const l0StallThreshold = 16 + l0Load := float64(l0Tables) / float64(l0StallThreshold) + if l0Load > 1.0 { + l0Load = 1.0 + } + + // Compaction score > 1.0 means compaction is needed + // We blend L0 tables and compaction score for write load + compactionLoad := maxScore / 2.0 // Score of 2.0 = fully loaded + if compactionLoad > 1.0 { + compactionLoad = 1.0 + } + + // Blend: 60% L0 (immediate backpressure), 40% compaction score + metrics.WriteLoad = 0.6*l0Load + 0.4*compactionLoad + + // Calculate read load from cache metrics + blockMetrics := m.db.BlockCacheMetrics() + indexMetrics := m.db.IndexCacheMetrics() + + var blockHitRatio, indexHitRatio float64 + if blockMetrics != nil { + blockHitRatio = blockMetrics.Ratio() + } + if indexMetrics != nil { + indexHitRatio = indexMetrics.Ratio() + } + + // Average cache hit ratio (0 = no hits = high load, 1 = all hits = low load) + avgHitRatio := (blockHitRatio + indexHitRatio) / 2.0 + + // Invert: low hit ratio = high read load + // Use 0.5 as the threshold (below 50% hit ratio is concerning) + if avgHitRatio < 0.5 { + metrics.ReadLoad = 1.0 - avgHitRatio*2 // 0% hits = 1.0 load, 50% hits = 0.0 load + } else { + metrics.ReadLoad = 0 // Above 50% hit ratio = minimal load + } + + // Store latencies + metrics.QueryLatency = time.Duration(m.queryLatencyNs.Load()) + metrics.WriteLatency = time.Duration(m.writeLatencyNs.Load()) + + // Update cached metrics + m.metricsLock.Lock() + m.cachedMetrics = metrics + m.lastL0Tables = l0Tables + m.lastL0Score = maxScore + m.metricsLock.Unlock() +} + +// GetL0Stats returns L0-specific statistics for debugging. +func (m *BadgerMonitor) GetL0Stats() (tables int, score float64) { + m.metricsLock.RLock() + defer m.metricsLock.RUnlock() + return m.lastL0Tables, m.lastL0Score +} diff --git a/pkg/ratelimit/factory.go b/pkg/ratelimit/factory.go new file mode 100644 index 0000000..c92255a --- /dev/null +++ b/pkg/ratelimit/factory.go @@ -0,0 +1,56 @@ +//go:build !(js && wasm) + +package ratelimit + +import ( + "time" + + "github.com/dgraph-io/badger/v4" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "next.orly.dev/pkg/interfaces/loadmonitor" +) + +// NewBadgerLimiter creates a rate limiter configured for a Badger database. 
+// It automatically creates a BadgerMonitor for the provided database. +func NewBadgerLimiter(config Config, db *badger.DB) *Limiter { + monitor := NewBadgerMonitor(db, 100*time.Millisecond) + return NewLimiter(config, monitor) +} + +// NewNeo4jLimiter creates a rate limiter configured for a Neo4j database. +// It automatically creates a Neo4jMonitor for the provided driver. +// querySem should be the semaphore used to limit concurrent queries. +// maxConcurrency is typically 10 (matching the semaphore size). +func NewNeo4jLimiter( + config Config, + driver neo4j.DriverWithContext, + querySem chan struct{}, + maxConcurrency int, +) *Limiter { + monitor := NewNeo4jMonitor(driver, querySem, maxConcurrency, 100*time.Millisecond) + return NewLimiter(config, monitor) +} + +// NewDisabledLimiter creates a rate limiter that is disabled. +// This is useful when rate limiting is not configured. +func NewDisabledLimiter() *Limiter { + config := DefaultConfig() + config.Enabled = false + return NewLimiter(config, nil) +} + +// MonitorFromBadgerDB creates a BadgerMonitor from a Badger database. +// Exported for use when you need to create the monitor separately. +func MonitorFromBadgerDB(db *badger.DB) loadmonitor.Monitor { + return NewBadgerMonitor(db, 100*time.Millisecond) +} + +// MonitorFromNeo4jDriver creates a Neo4jMonitor from a Neo4j driver. +// Exported for use when you need to create the monitor separately. +func MonitorFromNeo4jDriver( + driver neo4j.DriverWithContext, + querySem chan struct{}, + maxConcurrency int, +) loadmonitor.Monitor { + return NewNeo4jMonitor(driver, querySem, maxConcurrency, 100*time.Millisecond) +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go new file mode 100644 index 0000000..76e3179 --- /dev/null +++ b/pkg/ratelimit/limiter.go @@ -0,0 +1,409 @@ +package ratelimit + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "next.orly.dev/pkg/interfaces/loadmonitor" +) + +// OperationType distinguishes between read and write operations +// for applying different rate limiting strategies. +type OperationType int + +const ( + // Read operations (REQ queries) + Read OperationType = iota + // Write operations (EVENT saves, imports) + Write +) + +// String returns a human-readable name for the operation type. +func (o OperationType) String() string { + switch o { + case Read: + return "read" + case Write: + return "write" + default: + return "unknown" + } +} + +// Config holds configuration for the adaptive rate limiter. +type Config struct { + // Enabled controls whether rate limiting is active. + Enabled bool + + // TargetMemoryMB is the target memory limit in megabytes. + // Memory pressure is calculated relative to this target. + TargetMemoryMB int + + // WriteSetpoint is the target process variable for writes (0.0-1.0). + // Default: 0.85 (throttle when load exceeds 85%) + WriteSetpoint float64 + + // ReadSetpoint is the target process variable for reads (0.0-1.0). + // Default: 0.90 (more tolerant for reads) + ReadSetpoint float64 + + // PID gains for writes + WriteKp float64 + WriteKi float64 + WriteKd float64 + + // PID gains for reads + ReadKp float64 + ReadKi float64 + ReadKd float64 + + // MaxWriteDelayMs is the maximum delay for write operations in milliseconds. + MaxWriteDelayMs int + + // MaxReadDelayMs is the maximum delay for read operations in milliseconds. + MaxReadDelayMs int + + // MetricUpdateInterval is how often to poll the load monitor. 
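The remaining Config fields continue below. Putting the factory functions together with the config, a Badger-backed caller would typically start from DefaultConfig and override a couple of knobs. This is a sketch only, with a made-up database path:

```go
package main

import (
	"log"

	"github.com/dgraph-io/badger/v4"
	"next.orly.dev/pkg/ratelimit"
)

func main() {
	// Open Badger; the path here is a placeholder for the sketch.
	db, err := badger.Open(badger.DefaultOptions("/tmp/orly-example"))
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Start from the package defaults and override a couple of knobs.
	cfg := ratelimit.DefaultConfig()
	cfg.TargetMemoryMB = 1024  // throttle toward ~1GB of heap instead of 1.5GB
	cfg.MaxWriteDelayMs = 2000 // allow writes to back off for up to 2s

	limiter := ratelimit.NewBadgerLimiter(cfg, db)
	limiter.Start()
	defer limiter.Stop()
}
```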
+ MetricUpdateInterval time.Duration + + // MemoryWeight is the weight given to memory pressure in process variable (0.0-1.0). + // The remaining weight is given to the load metric. + // Default: 0.7 (70% memory, 30% load) + MemoryWeight float64 +} + +// DefaultConfig returns a default configuration for the rate limiter. +func DefaultConfig() Config { + return Config{ + Enabled: true, + TargetMemoryMB: 1500, // 1.5GB target + WriteSetpoint: 0.85, + ReadSetpoint: 0.90, + WriteKp: 0.5, + WriteKi: 0.1, + WriteKd: 0.05, + ReadKp: 0.3, + ReadKi: 0.05, + ReadKd: 0.02, + MaxWriteDelayMs: 1000, // 1 second max + MaxReadDelayMs: 500, // 500ms max + MetricUpdateInterval: 100 * time.Millisecond, + MemoryWeight: 0.7, + } +} + +// NewConfigFromValues creates a Config from individual configuration values. +// This is useful when loading configuration from environment variables. +func NewConfigFromValues( + enabled bool, + targetMB int, + writeKp, writeKi, writeKd float64, + readKp, readKi, readKd float64, + maxWriteMs, maxReadMs int, + writeTarget, readTarget float64, +) Config { + return Config{ + Enabled: enabled, + TargetMemoryMB: targetMB, + WriteSetpoint: writeTarget, + ReadSetpoint: readTarget, + WriteKp: writeKp, + WriteKi: writeKi, + WriteKd: writeKd, + ReadKp: readKp, + ReadKi: readKi, + ReadKd: readKd, + MaxWriteDelayMs: maxWriteMs, + MaxReadDelayMs: maxReadMs, + MetricUpdateInterval: 100 * time.Millisecond, + MemoryWeight: 0.7, + } +} + +// Limiter implements adaptive rate limiting using PID control. +// It monitors database load metrics and computes appropriate delays +// to keep the system within its target operating range. +type Limiter struct { + config Config + monitor loadmonitor.Monitor + + // PID controllers for reads and writes + writePID *PIDController + readPID *PIDController + + // Cached metrics (updated periodically) + metricsLock sync.RWMutex + currentMetrics loadmonitor.Metrics + + // Statistics + totalWriteDelayMs atomic.Int64 + totalReadDelayMs atomic.Int64 + writeThrottles atomic.Int64 + readThrottles atomic.Int64 + + // Lifecycle + ctx context.Context + cancel context.CancelFunc + stopOnce sync.Once + stopped chan struct{} + wg sync.WaitGroup +} + +// NewLimiter creates a new adaptive rate limiter. +// If monitor is nil, the limiter will be disabled. +func NewLimiter(config Config, monitor loadmonitor.Monitor) *Limiter { + ctx, cancel := context.WithCancel(context.Background()) + + l := &Limiter{ + config: config, + monitor: monitor, + ctx: ctx, + cancel: cancel, + stopped: make(chan struct{}), + } + + // Create PID controllers with configured gains + l.writePID = NewPIDController( + config.WriteKp, config.WriteKi, config.WriteKd, + config.WriteSetpoint, + 0.2, // Strong filtering for writes + -2.0, float64(config.MaxWriteDelayMs)/1000.0*2, // Anti-windup limits + 0, float64(config.MaxWriteDelayMs)/1000.0, + ) + + l.readPID = NewPIDController( + config.ReadKp, config.ReadKi, config.ReadKd, + config.ReadSetpoint, + 0.15, // Very strong filtering for reads + -1.0, float64(config.MaxReadDelayMs)/1000.0*2, + 0, float64(config.MaxReadDelayMs)/1000.0, + ) + + // Set memory target on monitor + if monitor != nil && config.TargetMemoryMB > 0 { + monitor.SetMemoryTarget(uint64(config.TargetMemoryMB) * 1024 * 1024) + } + + return l +} + +// Start begins the rate limiter's background metric collection. 
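(Start and the rest of the limiter lifecycle follow below.) One detail worth noting: the env-var defaults declared in app/config/config.go line up with DefaultConfig, so NewConfigFromValues fed those defaults reproduces it field for field. A small check, assuming the enabled flag is passed as true:

```go
package main

import (
	"fmt"

	"next.orly.dev/pkg/ratelimit"
)

func main() {
	fromEnvDefaults := ratelimit.NewConfigFromValues(
		true, 1500, // enabled, target MB
		0.5, 0.1, 0.05, // write Kp/Ki/Kd
		0.3, 0.05, 0.02, // read Kp/Ki/Kd
		1000, 500, // max write/read delay (ms)
		0.85, 0.90, // write/read setpoints
	)
	fmt.Println(fromEnvDefaults == ratelimit.DefaultConfig()) // true
}
```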
+func (l *Limiter) Start() { + if l.monitor == nil || !l.config.Enabled { + return + } + + // Start the monitor + l.monitor.Start() + + // Start metric update loop + l.wg.Add(1) + go l.updateLoop() +} + +// updateLoop periodically fetches metrics from the monitor. +func (l *Limiter) updateLoop() { + defer l.wg.Done() + + ticker := time.NewTicker(l.config.MetricUpdateInterval) + defer ticker.Stop() + + for { + select { + case <-l.ctx.Done(): + return + case <-ticker.C: + if l.monitor != nil { + metrics := l.monitor.GetMetrics() + l.metricsLock.Lock() + l.currentMetrics = metrics + l.metricsLock.Unlock() + } + } + } +} + +// Stop halts the rate limiter. +func (l *Limiter) Stop() { + l.stopOnce.Do(func() { + l.cancel() + if l.monitor != nil { + l.monitor.Stop() + } + l.wg.Wait() + close(l.stopped) + }) +} + +// Stopped returns a channel that closes when the limiter has stopped. +func (l *Limiter) Stopped() <-chan struct{} { + return l.stopped +} + +// Wait blocks until the rate limiter permits the operation to proceed. +// It returns the delay that was applied, or 0 if no delay was needed. +// If the context is cancelled, it returns immediately. +func (l *Limiter) Wait(ctx context.Context, opType OperationType) time.Duration { + if !l.config.Enabled || l.monitor == nil { + return 0 + } + + delay := l.ComputeDelay(opType) + if delay <= 0 { + return 0 + } + + // Apply the delay + select { + case <-ctx.Done(): + return 0 + case <-time.After(delay): + return delay + } +} + +// ComputeDelay calculates the recommended delay for an operation. +// This can be used to check the delay without actually waiting. +func (l *Limiter) ComputeDelay(opType OperationType) time.Duration { + if !l.config.Enabled || l.monitor == nil { + return 0 + } + + // Get current metrics + l.metricsLock.RLock() + metrics := l.currentMetrics + l.metricsLock.RUnlock() + + // Compute process variable as weighted combination of memory and load + var loadMetric float64 + switch opType { + case Write: + loadMetric = metrics.WriteLoad + case Read: + loadMetric = metrics.ReadLoad + } + + // Combine memory pressure and load + // Process variable = memoryWeight * memoryPressure + (1-memoryWeight) * loadMetric + pv := l.config.MemoryWeight*metrics.MemoryPressure + (1-l.config.MemoryWeight)*loadMetric + + // Select the appropriate PID controller + var delaySec float64 + switch opType { + case Write: + delaySec = l.writePID.Update(pv) + if delaySec > 0 { + l.writeThrottles.Add(1) + l.totalWriteDelayMs.Add(int64(delaySec * 1000)) + } + case Read: + delaySec = l.readPID.Update(pv) + if delaySec > 0 { + l.readThrottles.Add(1) + l.totalReadDelayMs.Add(int64(delaySec * 1000)) + } + } + + if delaySec <= 0 { + return 0 + } + + return time.Duration(delaySec * float64(time.Second)) +} + +// RecordLatency records an operation latency for the monitor. +func (l *Limiter) RecordLatency(opType OperationType, latency time.Duration) { + if l.monitor == nil { + return + } + + switch opType { + case Write: + l.monitor.RecordWriteLatency(latency) + case Read: + l.monitor.RecordQueryLatency(latency) + } +} + +// Stats returns rate limiter statistics. +type Stats struct { + WriteThrottles int64 + ReadThrottles int64 + TotalWriteDelayMs int64 + TotalReadDelayMs int64 + CurrentMetrics loadmonitor.Metrics + WritePIDState PIDState + ReadPIDState PIDState +} + +// PIDState contains the internal state of a PID controller. 
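ComputeDelay above weights memory pressure at 70% and the backend load metric at 30% by default before feeding the PID controller. On the call-site side, the intended pattern is Wait before the operation and RecordLatency after it. The sketch below uses a placeholder saveEvent and package name; it is not the relay's actual write path:

```go
package relay // hypothetical package for the sketch

import (
	"context"
	"time"

	"next.orly.dev/pkg/ratelimit"
)

// saveEvent is a placeholder for the real database write.
func saveEvent(ctx context.Context, ev []byte) error { return nil }

// storeEvent sketches how a write path could consult the limiter: wait for
// the PID-recommended delay, do the write, then feed the latency back in.
func storeEvent(ctx context.Context, limiter *ratelimit.Limiter, ev []byte) error {
	limiter.Wait(ctx, ratelimit.Write) // returns immediately when disabled or unloaded

	start := time.Now()
	err := saveEvent(ctx, ev)
	limiter.RecordLatency(ratelimit.Write, time.Since(start))
	return err
}
```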
+type PIDState struct { + Integral float64 + PrevError float64 + PrevFilteredError float64 +} + +// GetStats returns current rate limiter statistics. +func (l *Limiter) GetStats() Stats { + l.metricsLock.RLock() + metrics := l.currentMetrics + l.metricsLock.RUnlock() + + wIntegral, wPrevErr, wPrevFiltered := l.writePID.GetState() + rIntegral, rPrevErr, rPrevFiltered := l.readPID.GetState() + + return Stats{ + WriteThrottles: l.writeThrottles.Load(), + ReadThrottles: l.readThrottles.Load(), + TotalWriteDelayMs: l.totalWriteDelayMs.Load(), + TotalReadDelayMs: l.totalReadDelayMs.Load(), + CurrentMetrics: metrics, + WritePIDState: PIDState{ + Integral: wIntegral, + PrevError: wPrevErr, + PrevFilteredError: wPrevFiltered, + }, + ReadPIDState: PIDState{ + Integral: rIntegral, + PrevError: rPrevErr, + PrevFilteredError: rPrevFiltered, + }, + } +} + +// Reset clears all PID controller state and statistics. +func (l *Limiter) Reset() { + l.writePID.Reset() + l.readPID.Reset() + l.writeThrottles.Store(0) + l.readThrottles.Store(0) + l.totalWriteDelayMs.Store(0) + l.totalReadDelayMs.Store(0) +} + +// IsEnabled returns whether rate limiting is active. +func (l *Limiter) IsEnabled() bool { + return l.config.Enabled && l.monitor != nil +} + +// UpdateConfig updates the rate limiter configuration. +// This is useful for dynamic tuning. +func (l *Limiter) UpdateConfig(config Config) { + l.config = config + + // Update PID controllers + l.writePID.SetSetpoint(config.WriteSetpoint) + l.writePID.SetGains(config.WriteKp, config.WriteKi, config.WriteKd) + l.writePID.OutputMax = float64(config.MaxWriteDelayMs) / 1000.0 + + l.readPID.SetSetpoint(config.ReadSetpoint) + l.readPID.SetGains(config.ReadKp, config.ReadKi, config.ReadKd) + l.readPID.OutputMax = float64(config.MaxReadDelayMs) / 1000.0 + + // Update memory target + if l.monitor != nil && config.TargetMemoryMB > 0 { + l.monitor.SetMemoryTarget(uint64(config.TargetMemoryMB) * 1024 * 1024) + } +} diff --git a/pkg/ratelimit/neo4j_monitor.go b/pkg/ratelimit/neo4j_monitor.go new file mode 100644 index 0000000..b4f69d0 --- /dev/null +++ b/pkg/ratelimit/neo4j_monitor.go @@ -0,0 +1,259 @@ +package ratelimit + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "next.orly.dev/pkg/interfaces/loadmonitor" +) + +// Neo4jMonitor implements loadmonitor.Monitor for Neo4j database. +// Since Neo4j driver doesn't expose detailed metrics, we track: +// - Memory pressure via Go runtime +// - Query concurrency via the semaphore +// - Latency via recording +type Neo4jMonitor struct { + driver neo4j.DriverWithContext + querySem chan struct{} // Reference to the query semaphore + + // Target memory for pressure calculation + targetMemoryBytes atomic.Uint64 + + // Latency tracking with exponential moving average + queryLatencyNs atomic.Int64 + writeLatencyNs atomic.Int64 + latencyAlpha float64 // EMA coefficient (default 0.1) + + // Concurrency tracking + activeReads atomic.Int32 + activeWrites atomic.Int32 + maxConcurrency int + + // Cached metrics (updated by background goroutine) + metricsLock sync.RWMutex + cachedMetrics loadmonitor.Metrics + + // Background collection + stopChan chan struct{} + stopped chan struct{} + interval time.Duration +} + +// Compile-time check that Neo4jMonitor implements loadmonitor.Monitor +var _ loadmonitor.Monitor = (*Neo4jMonitor)(nil) + +// NewNeo4jMonitor creates a new Neo4j load monitor. +// The querySem should be the same semaphore used for limiting concurrent queries. 
+// maxConcurrency is the maximum concurrent query limit (typically 10). +func NewNeo4jMonitor( + driver neo4j.DriverWithContext, + querySem chan struct{}, + maxConcurrency int, + updateInterval time.Duration, +) *Neo4jMonitor { + if updateInterval <= 0 { + updateInterval = 100 * time.Millisecond + } + if maxConcurrency <= 0 { + maxConcurrency = 10 + } + + m := &Neo4jMonitor{ + driver: driver, + querySem: querySem, + maxConcurrency: maxConcurrency, + latencyAlpha: 0.1, // 10% new, 90% old for smooth EMA + stopChan: make(chan struct{}), + stopped: make(chan struct{}), + interval: updateInterval, + } + + // Set a default target (1.5GB) + m.targetMemoryBytes.Store(1500 * 1024 * 1024) + + return m +} + +// GetMetrics returns the current load metrics. +func (m *Neo4jMonitor) GetMetrics() loadmonitor.Metrics { + m.metricsLock.RLock() + defer m.metricsLock.RUnlock() + return m.cachedMetrics +} + +// RecordQueryLatency records a query latency sample using exponential moving average. +func (m *Neo4jMonitor) RecordQueryLatency(latency time.Duration) { + ns := latency.Nanoseconds() + for { + old := m.queryLatencyNs.Load() + if old == 0 { + if m.queryLatencyNs.CompareAndSwap(0, ns) { + return + } + continue + } + // EMA: new = alpha * sample + (1-alpha) * old + newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) + if m.queryLatencyNs.CompareAndSwap(old, newVal) { + return + } + } +} + +// RecordWriteLatency records a write latency sample using exponential moving average. +func (m *Neo4jMonitor) RecordWriteLatency(latency time.Duration) { + ns := latency.Nanoseconds() + for { + old := m.writeLatencyNs.Load() + if old == 0 { + if m.writeLatencyNs.CompareAndSwap(0, ns) { + return + } + continue + } + // EMA: new = alpha * sample + (1-alpha) * old + newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) + if m.writeLatencyNs.CompareAndSwap(old, newVal) { + return + } + } +} + +// SetMemoryTarget sets the target memory limit in bytes. +func (m *Neo4jMonitor) SetMemoryTarget(bytes uint64) { + m.targetMemoryBytes.Store(bytes) +} + +// Start begins background metric collection. +func (m *Neo4jMonitor) Start() <-chan struct{} { + go m.collectLoop() + return m.stopped +} + +// Stop halts background metric collection. +func (m *Neo4jMonitor) Stop() { + close(m.stopChan) + <-m.stopped +} + +// collectLoop periodically collects metrics. +func (m *Neo4jMonitor) collectLoop() { + defer close(m.stopped) + + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + for { + select { + case <-m.stopChan: + return + case <-ticker.C: + m.updateMetrics() + } + } +} + +// updateMetrics collects current metrics. 
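updateMetrics, shown next, derives load from the query semaphore: the semaphore is a buffered channel, so len(sem) counts the slots currently held. A standalone illustration of that calculation (the relay's real semaphore is internal to its Neo4j backend, as noted in main.go above):

```go
package main

import "fmt"

func main() {
	const maxConcurrency = 10
	sem := make(chan struct{}, maxConcurrency)

	// Pretend seven queries are in flight.
	for i := 0; i < 7; i++ {
		sem <- struct{}{}
	}

	load := float64(len(sem)) / float64(maxConcurrency)
	fmt.Printf("in flight: %d/%d, load: %.2f\n", len(sem), maxConcurrency, load) // 0.70
}
```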
+func (m *Neo4jMonitor) updateMetrics() { + metrics := loadmonitor.Metrics{ + Timestamp: time.Now(), + } + + // Calculate memory pressure from Go runtime + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + targetBytes := m.targetMemoryBytes.Load() + if targetBytes > 0 { + // Use HeapAlloc as primary memory metric + metrics.MemoryPressure = float64(memStats.HeapAlloc) / float64(targetBytes) + } + + // Calculate load from semaphore usage + // querySem is a buffered channel - count how many slots are taken + if m.querySem != nil { + usedSlots := len(m.querySem) + concurrencyLoad := float64(usedSlots) / float64(m.maxConcurrency) + if concurrencyLoad > 1.0 { + concurrencyLoad = 1.0 + } + // Both read and write use the same semaphore + metrics.WriteLoad = concurrencyLoad + metrics.ReadLoad = concurrencyLoad + } + + // Add latency-based load adjustment + // High latency indicates the database is struggling + queryLatencyNs := m.queryLatencyNs.Load() + writeLatencyNs := m.writeLatencyNs.Load() + + // Consider > 500ms query latency as concerning + const latencyThresholdNs = 500 * 1e6 // 500ms + if queryLatencyNs > 0 { + latencyLoad := float64(queryLatencyNs) / float64(latencyThresholdNs) + if latencyLoad > 1.0 { + latencyLoad = 1.0 + } + // Blend concurrency and latency for read load + metrics.ReadLoad = 0.5*metrics.ReadLoad + 0.5*latencyLoad + } + + if writeLatencyNs > 0 { + latencyLoad := float64(writeLatencyNs) / float64(latencyThresholdNs) + if latencyLoad > 1.0 { + latencyLoad = 1.0 + } + // Blend concurrency and latency for write load + metrics.WriteLoad = 0.5*metrics.WriteLoad + 0.5*latencyLoad + } + + // Store latencies + metrics.QueryLatency = time.Duration(queryLatencyNs) + metrics.WriteLatency = time.Duration(writeLatencyNs) + + // Update cached metrics + m.metricsLock.Lock() + m.cachedMetrics = metrics + m.metricsLock.Unlock() +} + +// IncrementActiveReads tracks an active read operation. +// Call this when starting a read, and call the returned function when done. +func (m *Neo4jMonitor) IncrementActiveReads() func() { + m.activeReads.Add(1) + return func() { + m.activeReads.Add(-1) + } +} + +// IncrementActiveWrites tracks an active write operation. +// Call this when starting a write, and call the returned function when done. +func (m *Neo4jMonitor) IncrementActiveWrites() func() { + m.activeWrites.Add(1) + return func() { + m.activeWrites.Add(-1) + } +} + +// GetConcurrencyStats returns current concurrency statistics for debugging. +func (m *Neo4jMonitor) GetConcurrencyStats() (reads, writes int32, semUsed int) { + reads = m.activeReads.Load() + writes = m.activeWrites.Load() + if m.querySem != nil { + semUsed = len(m.querySem) + } + return +} + +// CheckConnectivity performs a connectivity check to Neo4j. +// This can be used to verify the database is responsive. +func (m *Neo4jMonitor) CheckConnectivity(ctx context.Context) error { + if m.driver == nil { + return nil + } + return m.driver.VerifyConnectivity(ctx) +} diff --git a/pkg/ratelimit/pid.go b/pkg/ratelimit/pid.go new file mode 100644 index 0000000..4987df5 --- /dev/null +++ b/pkg/ratelimit/pid.go @@ -0,0 +1,218 @@ +// Package ratelimit provides adaptive rate limiting using PID control. +// The PID controller uses proportional, integral, and derivative terms +// with a low-pass filter on the derivative to suppress high-frequency noise. +package ratelimit + +import ( + "math" + "sync" + "time" +) + +// PIDController implements a PID controller with filtered derivative. 
+// It is designed for rate limiting database operations based on load metrics. +// +// The controller computes a delay recommendation based on: +// - Proportional (P): Immediate response to current error +// - Integral (I): Accumulated error to eliminate steady-state offset +// - Derivative (D): Rate of change prediction (filtered to reduce noise) +// +// The filtered derivative uses a low-pass filter to attenuate high-frequency +// noise that would otherwise cause erratic control behavior. +type PIDController struct { + // Gains + Kp float64 // Proportional gain + Ki float64 // Integral gain + Kd float64 // Derivative gain + + // Setpoint is the target process variable value (e.g., 0.85 for 85% of target memory). + // The controller drives the process variable toward this setpoint. + Setpoint float64 + + // DerivativeFilterAlpha is the low-pass filter coefficient for the derivative term. + // Range: 0.0-1.0, where lower values provide stronger filtering. + // Recommended: 0.2 for strong filtering, 0.5 for moderate filtering. + DerivativeFilterAlpha float64 + + // Integral limits for anti-windup + IntegralMax float64 + IntegralMin float64 + + // Output limits + OutputMin float64 // Minimum output (typically 0 = no delay) + OutputMax float64 // Maximum output (max delay in seconds) + + // Internal state (protected by mutex) + mu sync.Mutex + integral float64 + prevError float64 + prevFilteredError float64 + lastUpdate time.Time + initialized bool +} + +// DefaultPIDControllerForWrites creates a PID controller tuned for write operations. +// Writes benefit from aggressive integral and moderate proportional response. +func DefaultPIDControllerForWrites() *PIDController { + return &PIDController{ + Kp: 0.5, // Moderate proportional response + Ki: 0.1, // Steady integral to eliminate offset + Kd: 0.05, // Small derivative for prediction + Setpoint: 0.85, // Target 85% of memory limit + DerivativeFilterAlpha: 0.2, // Strong filtering (20% new, 80% old) + IntegralMax: 10.0, // Anti-windup: max 10 seconds accumulated + IntegralMin: -2.0, // Allow small negative for faster recovery + OutputMin: 0.0, // No delay minimum + OutputMax: 1.0, // Max 1 second delay per write + } +} + +// DefaultPIDControllerForReads creates a PID controller tuned for read operations. +// Reads should be more responsive but with less aggressive throttling. +func DefaultPIDControllerForReads() *PIDController { + return &PIDController{ + Kp: 0.3, // Lower proportional (reads are more important) + Ki: 0.05, // Lower integral (don't accumulate as aggressively) + Kd: 0.02, // Very small derivative + Setpoint: 0.90, // Target 90% (more tolerant of memory use) + DerivativeFilterAlpha: 0.15, // Very strong filtering + IntegralMax: 5.0, // Lower anti-windup limit + IntegralMin: -1.0, // Allow small negative + OutputMin: 0.0, // No delay minimum + OutputMax: 0.5, // Max 500ms delay per read + } +} + +// NewPIDController creates a new PID controller with custom parameters. +func NewPIDController( + kp, ki, kd float64, + setpoint float64, + derivativeFilterAlpha float64, + integralMin, integralMax float64, + outputMin, outputMax float64, +) *PIDController { + return &PIDController{ + Kp: kp, + Ki: ki, + Kd: kd, + Setpoint: setpoint, + DerivativeFilterAlpha: derivativeFilterAlpha, + IntegralMin: integralMin, + IntegralMax: integralMax, + OutputMin: outputMin, + OutputMax: outputMax, + } +} + +// Update computes the PID output based on the current process variable. 
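Update, defined next, turns a load sample into a delay recommendation. A quick way to get a feel for the tuning is to drive the write controller with a sustained load above its 0.85 setpoint; exact numbers depend on wall-clock timing, so this sketch only demonstrates the trend:

```go
package main

import (
	"fmt"
	"time"

	"next.orly.dev/pkg/ratelimit"
)

func main() {
	pid := ratelimit.DefaultPIDControllerForWrites()

	pid.Update(0.95) // the first call only initializes controller state
	for i := 0; i < 5; i++ {
		time.Sleep(100 * time.Millisecond)
		delay := pid.Update(0.95) // sustained 0.95 load, above the 0.85 setpoint
		fmt.Printf("tick %d: delay %.0fms\n", i, delay*1000)
	}
	// The proportional term supplies an immediate backoff and the integral
	// term slowly ratchets the delay upward while the overload persists.
}
```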
+// The process variable should be in the range [0.0, 1.0+] representing load level. +// +// Returns the recommended delay in seconds. A value of 0 means no delay needed. +func (p *PIDController) Update(processVariable float64) float64 { + p.mu.Lock() + defer p.mu.Unlock() + + now := time.Now() + + // Initialize on first call + if !p.initialized { + p.lastUpdate = now + p.prevError = processVariable - p.Setpoint + p.prevFilteredError = p.prevError + p.initialized = true + return 0 // No delay on first call + } + + // Calculate time delta + dt := now.Sub(p.lastUpdate).Seconds() + if dt <= 0 { + dt = 0.001 // Minimum 1ms to avoid division by zero + } + p.lastUpdate = now + + // Calculate current error (positive when above setpoint = need to throttle) + error := processVariable - p.Setpoint + + // Proportional term: immediate response to current error + pTerm := p.Kp * error + + // Integral term: accumulate error over time + // Apply anti-windup by clamping the integral + p.integral += error * dt + p.integral = clamp(p.integral, p.IntegralMin, p.IntegralMax) + iTerm := p.Ki * p.integral + + // Derivative term with low-pass filter + // Apply exponential moving average to filter high-frequency noise: + // filtered = alpha * new + (1 - alpha) * old + // This is equivalent to a first-order low-pass filter + filteredError := p.DerivativeFilterAlpha*error + (1-p.DerivativeFilterAlpha)*p.prevFilteredError + + // Derivative of the filtered error + var dTerm float64 + if dt > 0 { + dTerm = p.Kd * (filteredError - p.prevFilteredError) / dt + } + + // Update previous values for next iteration + p.prevError = error + p.prevFilteredError = filteredError + + // Compute total output and clamp to limits + output := pTerm + iTerm + dTerm + output = clamp(output, p.OutputMin, p.OutputMax) + + // Only return positive delays (throttle when above setpoint) + if output < 0 { + return 0 + } + return output +} + +// Reset clears the controller state, useful when conditions change significantly. +func (p *PIDController) Reset() { + p.mu.Lock() + defer p.mu.Unlock() + + p.integral = 0 + p.prevError = 0 + p.prevFilteredError = 0 + p.initialized = false +} + +// SetSetpoint updates the target setpoint. +func (p *PIDController) SetSetpoint(setpoint float64) { + p.mu.Lock() + defer p.mu.Unlock() + p.Setpoint = setpoint +} + +// SetGains updates the PID gains. +func (p *PIDController) SetGains(kp, ki, kd float64) { + p.mu.Lock() + defer p.mu.Unlock() + p.Kp = kp + p.Ki = ki + p.Kd = kd +} + +// GetState returns the current internal state for monitoring/debugging. +func (p *PIDController) GetState() (integral, prevError, prevFilteredError float64) { + p.mu.Lock() + defer p.mu.Unlock() + return p.integral, p.prevError, p.prevFilteredError +} + +// clamp restricts a value to the range [min, max]. 
+func clamp(value, min, max float64) float64 { + if math.IsNaN(value) { + return 0 + } + if value < min { + return min + } + if value > max { + return max + } + return value +} diff --git a/pkg/ratelimit/pid_test.go b/pkg/ratelimit/pid_test.go new file mode 100644 index 0000000..b75c19b --- /dev/null +++ b/pkg/ratelimit/pid_test.go @@ -0,0 +1,176 @@ +package ratelimit + +import ( + "testing" + "time" +) + +func TestPIDController_BasicOperation(t *testing.T) { + pid := DefaultPIDControllerForWrites() + + // First call should return 0 (initialization) + delay := pid.Update(0.5) + if delay != 0 { + t.Errorf("expected 0 delay on first call, got %v", delay) + } + + // Sleep a bit to ensure dt > 0 + time.Sleep(10 * time.Millisecond) + + // Process variable below setpoint (0.5 < 0.85) should return 0 delay + delay = pid.Update(0.5) + if delay != 0 { + t.Errorf("expected 0 delay when below setpoint, got %v", delay) + } + + // Process variable above setpoint should return positive delay + time.Sleep(10 * time.Millisecond) + delay = pid.Update(0.95) // 0.95 > 0.85 setpoint + if delay <= 0 { + t.Errorf("expected positive delay when above setpoint, got %v", delay) + } +} + +func TestPIDController_IntegralAccumulation(t *testing.T) { + pid := NewPIDController( + 0.5, 0.5, 0.0, // High Ki, no Kd + 0.5, // setpoint + 0.2, // filter alpha + -10, 10, // integral bounds + 0, 1.0, // output bounds + ) + + // Initialize + pid.Update(0.5) + time.Sleep(10 * time.Millisecond) + + // Continuously above setpoint should accumulate integral + for i := 0; i < 10; i++ { + time.Sleep(10 * time.Millisecond) + pid.Update(0.8) // 0.3 above setpoint + } + + integral, _, _ := pid.GetState() + if integral <= 0 { + t.Errorf("expected positive integral after sustained error, got %v", integral) + } +} + +func TestPIDController_FilteredDerivative(t *testing.T) { + pid := NewPIDController( + 0.0, 0.0, 1.0, // Only Kd + 0.5, // setpoint + 0.5, // 50% filtering + -10, 10, + 0, 1.0, + ) + + // Initialize with low value + pid.Update(0.5) + time.Sleep(10 * time.Millisecond) + + // Second call with same value - derivative should be near zero + pid.Update(0.5) + _, _, prevFiltered := pid.GetState() + + time.Sleep(10 * time.Millisecond) + + // Big jump - filtered derivative should be dampened + delay := pid.Update(1.0) + + // The filtered derivative should cause some response, but dampened + // Since we only have Kd=1.0 and alpha=0.5, the response should be modest + if delay < 0 { + t.Errorf("expected non-negative delay, got %v", delay) + } + + _, _, newFiltered := pid.GetState() + // Filtered error should have moved toward the new error but not fully + if newFiltered <= prevFiltered { + t.Errorf("filtered error should increase with rising process variable") + } +} + +func TestPIDController_AntiWindup(t *testing.T) { + pid := NewPIDController( + 0.0, 1.0, 0.0, // Only Ki + 0.5, // setpoint + 0.2, // filter alpha + -1.0, 1.0, // tight integral bounds + 0, 10.0, // wide output bounds + ) + + // Initialize + pid.Update(0.5) + + // Drive the integral to its limit + for i := 0; i < 100; i++ { + time.Sleep(1 * time.Millisecond) + pid.Update(1.0) // Large positive error + } + + integral, _, _ := pid.GetState() + if integral > 1.0 { + t.Errorf("integral should be clamped at 1.0, got %v", integral) + } +} + +func TestPIDController_Reset(t *testing.T) { + pid := DefaultPIDControllerForWrites() + + // Build up some state + pid.Update(0.5) + time.Sleep(10 * time.Millisecond) + pid.Update(0.9) + time.Sleep(10 * time.Millisecond) + pid.Update(0.95) + 
+ // Reset + pid.Reset() + + integral, prevErr, prevFiltered := pid.GetState() + if integral != 0 || prevErr != 0 || prevFiltered != 0 { + t.Errorf("expected all state to be zero after reset") + } + + // Next call should behave like first call + delay := pid.Update(0.9) + if delay != 0 { + t.Errorf("expected 0 delay on first call after reset, got %v", delay) + } +} + +func TestPIDController_SetGains(t *testing.T) { + pid := DefaultPIDControllerForWrites() + + // Change gains + pid.SetGains(1.0, 0.5, 0.1) + + if pid.Kp != 1.0 || pid.Ki != 0.5 || pid.Kd != 0.1 { + t.Errorf("gains not updated correctly") + } +} + +func TestPIDController_SetSetpoint(t *testing.T) { + pid := DefaultPIDControllerForWrites() + + pid.SetSetpoint(0.7) + + if pid.Setpoint != 0.7 { + t.Errorf("setpoint not updated, got %v", pid.Setpoint) + } +} + +func TestDefaultControllers(t *testing.T) { + writePID := DefaultPIDControllerForWrites() + readPID := DefaultPIDControllerForReads() + + // Write controller should have higher gains and lower setpoint + if writePID.Kp <= readPID.Kp { + t.Errorf("write Kp should be higher than read Kp") + } + + if writePID.Setpoint >= readPID.Setpoint { + t.Errorf("write setpoint should be lower than read setpoint") + } +} diff --git a/pkg/run/run.go b/pkg/run/run.go index 934b2ae..c66eebd 100644 --- a/pkg/run/run.go +++ b/pkg/run/run.go @@ -16,6 +16,7 @@ import ( "next.orly.dev/app/config" "next.orly.dev/pkg/acl" "next.orly.dev/pkg/database" + "next.orly.dev/pkg/ratelimit" ) // Options configures relay startup behavior. @@ -126,8 +127,11 @@ func Start(cfg *config.C, opts *Options) (relay *Relay, err error) { } acl.Registry.Syncer() + // Create rate limiter (disabled for test relay instances) + limiter := ratelimit.NewDisabledLimiter() + // Start the relay - relay.quit = app.Run(relay.ctx, cfg, relay.db) + relay.quit = app.Run(relay.ctx, cfg, relay.db, limiter) return }
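Beyond Wait and RecordLatency, GetStats exposes throttle counts and accumulated delay, which is useful for operational logging once the limiter is live. A sketch of a periodic reporter; the package name and the logf callback are placeholders, not part of this change:

```go
package relay // hypothetical package for the sketch

import (
	"context"
	"time"

	"next.orly.dev/pkg/ratelimit"
)

// logLimiterStats periodically reports limiter activity via Limiter.GetStats.
// logf stands in for whatever logger the relay uses.
func logLimiterStats(ctx context.Context, l *ratelimit.Limiter, logf func(string, ...any)) {
	if !l.IsEnabled() {
		return
	}
	t := time.NewTicker(time.Minute)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			s := l.GetStats()
			logf(
				"ratelimit: writes throttled %d (%dms total), reads throttled %d (%dms total), mem pressure %.2f",
				s.WriteThrottles, s.TotalWriteDelayMs,
				s.ReadThrottles, s.TotalReadDelayMs,
				s.CurrentMetrics.MemoryPressure,
			)
		}
	}
}
```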