feat: basic per-connection rate limiter (#1)

This commit is contained in:
Ben Woodward
2023-08-05 07:50:38 -07:00
committed by fiatjaf
parent 36603006f4
commit 1935f62c29
5 changed files with 47 additions and 1 deletions

1
go.mod
View File

@@ -23,6 +23,7 @@ require (
github.com/stretchr/testify v1.8.0
github.com/tidwall/gjson v1.14.4
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
)
require (

1
go.sum
View File

@@ -573,6 +573,7 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -15,6 +15,7 @@ import (
"github.com/nbd-wtf/go-nostr/nip11"
"github.com/nbd-wtf/go-nostr/nip42"
"golang.org/x/exp/slices"
"golang.org/x/time/rate"
)
// TODO: consider moving these to Server as config params
@@ -63,6 +64,13 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
challenge: hex.EncodeToString(challenge),
}
if s.options.perConnectionLimiter != nil {
ws.limiter = rate.NewLimiter(
s.options.perConnectionLimiter.Limit(),
s.options.perConnectionLimiter.Burst(),
)
}
// reader
go func() {
defer func() {
@@ -102,6 +110,15 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
break
}
if ws.limiter != nil {
// NOTE: Wait will throttle the requests.
// To reject requests exceeding the limit, use if !ws.limiter.Allow()
if err := ws.limiter.Wait(context.Background()); err != nil {
s.Log.Warningf("unexpected limiter error %v", err)
continue
}
}
if typ == websocket.PingMessage {
ws.WriteMessage(websocket.PongMessage, nil)
continue

View File

@@ -13,6 +13,7 @@ import (
"github.com/fasthttp/websocket"
"github.com/rs/cors"
"golang.org/x/time/rate"
)
// Server is a base for package users to implement nostr relays.
@@ -34,6 +35,8 @@ type Server struct {
// outputting to stderr.
Log Logger
options *Options
relay Relay
// keep a connection reference to all connected clients for Server.Shutdown
@@ -52,12 +55,18 @@ func (s *Server) Router() *http.ServeMux {
// NewServer initializes the relay and its storage using their respective Init methods,
// returning any non-nil errors, and returns a Server ready to listen for HTTP requests.
func NewServer(relay Relay) (*Server, error) {
func NewServer(relay Relay, opts ...Option) (*Server, error) {
options := DefaultOptions()
for _, opt := range opts {
opt(options)
}
srv := &Server{
Log: defaultLogger(relay.Name() + ": "),
relay: relay,
clients: make(map[*websocket.Conn]struct{}),
serveMux: &http.ServeMux{},
options: options,
}
// init the relay
@@ -142,6 +151,22 @@ func (s *Server) Shutdown(ctx context.Context) {
}
}
type Option func(*Options)
type Options struct {
perConnectionLimiter *rate.Limiter
}
func DefaultOptions() *Options {
return &Options{}
}
func WithPerConnectionLimiter(rps rate.Limit, burst int) Option {
return func(o *Options) {
o.perConnectionLimiter = rate.NewLimiter(rps, burst)
}
}
func defaultLogger(prefix string) Logger {
l := log.New(os.Stderr, "", log.LstdFlags|log.Lmsgprefix)
l.SetPrefix(prefix)

View File

@@ -4,6 +4,7 @@ import (
"sync"
"github.com/fasthttp/websocket"
"golang.org/x/time/rate"
)
type WebSocket struct {
@@ -13,6 +14,7 @@ type WebSocket struct {
// nip42
challenge string
authed string
limiter *rate.Limiter
}
func (ws *WebSocket) WriteJSON(any interface{}) error {