feat: basic per-connection rate limiter (#1)
This commit is contained in:
1
go.mod
1
go.mod
@@ -23,6 +23,7 @@ require (
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/tidwall/gjson v1.14.4
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
|
||||
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
1
go.sum
1
go.sum
@@ -573,6 +573,7 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
|
||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE=
|
||||
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
17
handlers.go
17
handlers.go
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/nbd-wtf/go-nostr/nip11"
|
||||
"github.com/nbd-wtf/go-nostr/nip42"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TODO: consider moving these to Server as config params
|
||||
@@ -63,6 +64,13 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
challenge: hex.EncodeToString(challenge),
|
||||
}
|
||||
|
||||
if s.options.perConnectionLimiter != nil {
|
||||
ws.limiter = rate.NewLimiter(
|
||||
s.options.perConnectionLimiter.Limit(),
|
||||
s.options.perConnectionLimiter.Burst(),
|
||||
)
|
||||
}
|
||||
|
||||
// reader
|
||||
go func() {
|
||||
defer func() {
|
||||
@@ -102,6 +110,15 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
break
|
||||
}
|
||||
|
||||
if ws.limiter != nil {
|
||||
// NOTE: Wait will throttle the requests.
|
||||
// To reject requests exceeding the limit, use if !ws.limiter.Allow()
|
||||
if err := ws.limiter.Wait(context.Background()); err != nil {
|
||||
s.Log.Warningf("unexpected limiter error %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if typ == websocket.PingMessage {
|
||||
ws.WriteMessage(websocket.PongMessage, nil)
|
||||
continue
|
||||
|
||||
27
start.go
27
start.go
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/fasthttp/websocket"
|
||||
"github.com/rs/cors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Server is a base for package users to implement nostr relays.
|
||||
@@ -34,6 +35,8 @@ type Server struct {
|
||||
// outputting to stderr.
|
||||
Log Logger
|
||||
|
||||
options *Options
|
||||
|
||||
relay Relay
|
||||
|
||||
// keep a connection reference to all connected clients for Server.Shutdown
|
||||
@@ -52,12 +55,18 @@ func (s *Server) Router() *http.ServeMux {
|
||||
|
||||
// NewServer initializes the relay and its storage using their respective Init methods,
|
||||
// returning any non-nil errors, and returns a Server ready to listen for HTTP requests.
|
||||
func NewServer(relay Relay) (*Server, error) {
|
||||
func NewServer(relay Relay, opts ...Option) (*Server, error) {
|
||||
options := DefaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
Log: defaultLogger(relay.Name() + ": "),
|
||||
relay: relay,
|
||||
clients: make(map[*websocket.Conn]struct{}),
|
||||
serveMux: &http.ServeMux{},
|
||||
options: options,
|
||||
}
|
||||
|
||||
// init the relay
|
||||
@@ -142,6 +151,22 @@ func (s *Server) Shutdown(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
type Option func(*Options)
|
||||
|
||||
type Options struct {
|
||||
perConnectionLimiter *rate.Limiter
|
||||
}
|
||||
|
||||
func DefaultOptions() *Options {
|
||||
return &Options{}
|
||||
}
|
||||
|
||||
func WithPerConnectionLimiter(rps rate.Limit, burst int) Option {
|
||||
return func(o *Options) {
|
||||
o.perConnectionLimiter = rate.NewLimiter(rps, burst)
|
||||
}
|
||||
}
|
||||
|
||||
func defaultLogger(prefix string) Logger {
|
||||
l := log.New(os.Stderr, "", log.LstdFlags|log.Lmsgprefix)
|
||||
l.SetPrefix(prefix)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/fasthttp/websocket"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type WebSocket struct {
|
||||
@@ -13,6 +14,7 @@ type WebSocket struct {
|
||||
// nip42
|
||||
challenge string
|
||||
authed string
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (ws *WebSocket) WriteJSON(any interface{}) error {
|
||||
|
||||
Reference in New Issue
Block a user