diff --git a/realy/handleWebsocket.go b/realy/handleWebsocket.go index 7b340e9..7bc9e02 100644 --- a/realy/handleWebsocket.go +++ b/realy/handleWebsocket.go @@ -2,89 +2,11 @@ package realy import ( "net/http" - "time" - "github.com/fasthttp/websocket" - - "realy.mleku.dev/context" - "realy.mleku.dev/envelopes/authenvelope" - "realy.mleku.dev/realy/subscribers" "realy.mleku.dev/socketapi" ) func (s *Server) handleWebsocket(w http.ResponseWriter, r *http.Request) { - conn, err := subscribers.Upgrader.Upgrade(w, r, nil) - if err != nil { - log.E.F("failed to upgrade websocket: %v", err) - return - } - s.clientsMu.Lock() - defer s.clientsMu.Unlock() - s.clients[conn] = struct{}{} - ticker := time.NewTicker(s.listeners.PingPeriod) - ip := conn.RemoteAddr().String() - var realIP string - if realIP = r.Header.Get("X-Forwarded-For"); realIP != "" { - ip = realIP - } else if realIP = r.Header.Get("X-Real-Ip"); realIP != "" { - ip = realIP - } - log.T.F("connected from %s", ip) - ws := s.listeners.GetChallenge(conn, r, ip) - ctx, cancel := context.Cancel(context.Bg()) - sto := s.relay.Storage() - go func() { - defer func() { - cancel() - ticker.Stop() - s.clientsMu.Lock() - if _, ok := s.clients[conn]; ok { - chk.E(conn.Close()) - delete(s.clients, conn) - s.listeners.RemoveSubscriber(ws) - } - s.clientsMu.Unlock() - }() - conn.SetReadLimit(s.listeners.MaxMessageSize) - chk.E(conn.SetReadDeadline(time.Now().Add(s.listeners.PongWait))) - conn.SetPongHandler(func(string) error { - chk.E(conn.SetReadDeadline(time.Now().Add(s.listeners.PongWait))) - return nil - }) - if s.authRequired { - ws.RequestAuth() - } - if ws.AuthRequested() && len(ws.Authed()) == 0 { - log.I.F("requesting auth from client from %s", ws.RealRemote()) - if err = authenvelope.NewChallengeWith(ws.Challenge()).Write(ws); chk.E(err) { - return - } - // return - } - var message []byte - var typ int - for { - typ, message, err = conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, - websocket.CloseNormalClosure, - websocket.CloseGoingAway, - websocket.CloseNoStatusReceived, - websocket.CloseAbnormalClosure, - ) { - log.W.F("unexpected close error from %s: %v", - r.Header.Get("X-Forwarded-For"), err) - } - break - } - if typ == websocket.PingMessage { - if err = ws.WriteMessage(websocket.PongMessage, nil); chk.E(err) { - } - continue - } - a := &socketapi.A{ws} - go s.handleMessage(ctx, a, message, sto) - } - }() - go s.pinger(ctx, ws, conn, ticker, cancel) + a := &socketapi.A{Server: s, ClientsMu: &s.clientsMu, Clients: s.clients} + a.Serve(w, r, s) } diff --git a/realy/interfaces/interfaces.go b/realy/interfaces/interfaces.go index 5a0c859..7f71212 100644 --- a/realy/interfaces/interfaces.go +++ b/realy/interfaces/interfaces.go @@ -6,6 +6,7 @@ import ( "realy.mleku.dev/context" "realy.mleku.dev/event" + "realy.mleku.dev/realy/options" "realy.mleku.dev/realy/subscribers" "realy.mleku.dev/relay" "realy.mleku.dev/store" @@ -33,4 +34,5 @@ type Server interface { SetConfiguration(*store.Configuration) Shutdown() Storage() store.I + Options() *options.T } diff --git a/realy/server-impl.go b/realy/server-impl.go index 18aaf43..5075a73 100644 --- a/realy/server-impl.go +++ b/realy/server-impl.go @@ -7,6 +7,7 @@ import ( "realy.mleku.dev/context" "realy.mleku.dev/event" "realy.mleku.dev/realy/interfaces" + "realy.mleku.dev/realy/options" "realy.mleku.dev/realy/subscribers" "realy.mleku.dev/relay" "realy.mleku.dev/store" @@ -61,4 +62,6 @@ func (s *Server) Owners() [][]byte { return s.owners } func (s *Server) AuthRequired() bool { return s.authRequired } +func (s *Server) Options() *options.T { return s.options } + var _ interfaces.Server = &Server{} diff --git a/realy/subscribers/subscribers.go b/realy/subscribers/subscribers.go index 604cfc1..e488bc7 100644 --- a/realy/subscribers/subscribers.go +++ b/realy/subscribers/subscribers.go @@ -78,10 +78,6 @@ const ( var ( NIP20prefixmatcher = regexp.MustCompile(`^\w+: `) - Upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true - }} ) // New creates a new subscribers.S. @@ -114,8 +110,7 @@ func New(ctx context.T) (l *S) { } // GetChallenge generates a new challenge for a subscriber. -func (s *S) GetChallenge(conn *websocket.Conn, req *http.Request, - addr string) (w *ws.Listener) { +func (s *S) GetChallenge(conn *websocket.Conn, req *http.Request) (w *ws.Listener) { var err error cb := make([]byte, s.ChallengeLength) if _, err = rand.Read(cb); chk.E(err) { diff --git a/realy/ws-handle.go b/socketapi/handleMessage.go similarity index 69% rename from realy/ws-handle.go rename to socketapi/handleMessage.go index ce2a933..ed0ae04 100644 --- a/realy/ws-handle.go +++ b/socketapi/handleMessage.go @@ -1,9 +1,8 @@ -package realy +package socketapi import ( "fmt" - "realy.mleku.dev/context" "realy.mleku.dev/envelopes" "realy.mleku.dev/envelopes/authenvelope" "realy.mleku.dev/envelopes/closeenvelope" @@ -11,11 +10,9 @@ import ( "realy.mleku.dev/envelopes/noticeenvelope" "realy.mleku.dev/envelopes/reqenvelope" "realy.mleku.dev/relay" - "realy.mleku.dev/socketapi" - "realy.mleku.dev/store" ) -func (s *Server) handleMessage(c context.T, a *socketapi.A, msg []byte, sto store.I) { +func (a *A) HandleMessage(msg []byte) { var notice []byte var err error var t string @@ -23,17 +20,16 @@ func (s *Server) handleMessage(c context.T, a *socketapi.A, msg []byte, sto stor if t, rem, err = envelopes.Identify(msg); chk.E(err) { notice = []byte(err.Error()) } - skipEventFunc := s.options.SkipEventFunc - rl := s.relay + rl := a.Relay() switch t { case eventenvelope.L: - notice = a.HandleEvent(c, rem, s) + notice = a.HandleEvent(a.Context(), rem, a.Server) case reqenvelope.L: - notice = a.HandleReq(c, rem, skipEventFunc, s) + notice = a.HandleReq(a.Context(), rem, a.Options().SkipEventFunc, a.Server) case closeenvelope.L: - notice = a.HandleClose(rem, s) + notice = a.HandleClose(rem, a.Server) case authenvelope.L: - notice = a.HandleAuth(rem, s) + notice = a.HandleAuth(rem, a.Server) default: if wsh, ok := rl.(relay.WebSocketHandler); ok { wsh.HandleUnknownType(a.Listener, t, rem) @@ -47,4 +43,5 @@ func (s *Server) handleMessage(c context.T, a *socketapi.A, msg []byte, sto stor return } } + } diff --git a/socketapi/handleWebsocket.go b/socketapi/handleWebsocket.go new file mode 100644 index 0000000..e5f9927 --- /dev/null +++ b/socketapi/handleWebsocket.go @@ -0,0 +1 @@ +package socketapi diff --git a/realy/ping.go b/socketapi/pinger.go similarity index 50% rename from realy/ping.go rename to socketapi/pinger.go index 031354f..aa3b360 100644 --- a/realy/ping.go +++ b/socketapi/pinger.go @@ -1,4 +1,4 @@ -package realy +package socketapi import ( "time" @@ -6,27 +6,26 @@ import ( "github.com/fasthttp/websocket" "realy.mleku.dev/context" - "realy.mleku.dev/ws" + "realy.mleku.dev/realy/interfaces" ) -func (s *Server) pinger(ctx context.T, ws *ws.Listener, conn *websocket.Conn, - ticker *time.Ticker, cancel context.F) { +func (a *A) Pinger(ctx context.T, ticker *time.Ticker, cancel context.F, s interfaces.Server) { defer func() { cancel() ticker.Stop() - _ = conn.Close() + _ = a.Listener.Conn.Close() }() var err error for { select { case <-ticker.C: - err = conn.WriteControl(websocket.PingMessage, nil, - time.Now().Add(s.listeners.WriteWait)) + err = a.Listener.Conn.WriteControl(websocket.PingMessage, nil, + time.Now().Add(s.Listeners().WriteWait)) if err != nil { log.E.F("error writing ping: %v; closing websocket", err) return } - ws.RealRemote() + a.Listener.RealRemote() case <-ctx.Done(): return } diff --git a/socketapi/socketapi.go b/socketapi/socketapi.go index b732343..8ec9aaa 100644 --- a/socketapi/socketapi.go +++ b/socketapi/socketapi.go @@ -1,7 +1,90 @@ package socketapi import ( + "net/http" + "sync" + "time" + + "github.com/fasthttp/websocket" + + "realy.mleku.dev/context" + "realy.mleku.dev/envelopes/authenvelope" + "realy.mleku.dev/realy/interfaces" "realy.mleku.dev/ws" ) -type A struct{ *ws.Listener } +type A struct { + *ws.Listener + interfaces.Server + ClientsMu *sync.Mutex + Clients map[*websocket.Conn]struct{} +} + +func (a *A) Serve(w http.ResponseWriter, r *http.Request, s interfaces.Server) { + + var err error + ticker := time.NewTicker(s.Listeners().PingPeriod) + ctx, cancel := context.Cancel(context.Bg()) + var conn *websocket.Conn + conn, err = Upgrader.Upgrade(w, r, nil) + if err != nil { + log.E.F("failed to upgrade websocket: %v", err) + return + } + a.ClientsMu.Lock() + defer a.ClientsMu.Unlock() + a.Clients[conn] = struct{}{} + a.Listener = s.Listeners().GetChallenge(conn, r) + + defer func() { + cancel() + ticker.Stop() + a.ClientsMu.Lock() + if _, ok := a.Clients[a.Listener.Conn]; ok { + chk.E(a.Listener.Conn.Close()) + delete(a.Clients, a.Listener.Conn) + a.Listeners().RemoveSubscriber(a.Listener) + } + a.ClientsMu.Unlock() + }() + conn.SetReadLimit(a.Listeners().MaxMessageSize) + chk.E(conn.SetReadDeadline(time.Now().Add(a.Listeners().PongWait))) + conn.SetPongHandler(func(string) error { + chk.E(conn.SetReadDeadline(time.Now().Add(a.Listeners().PongWait))) + return nil + }) + if a.Server.AuthRequired() { + a.Listener.RequestAuth() + } + if a.Listener.AuthRequested() && len(a.Listener.Authed()) == 0 { + log.I.F("requesting auth from client from %s", a.Listener.RealRemote()) + if err = authenvelope.NewChallengeWith(a.Listener.Challenge()).Write(a.Listener); chk.E(err) { + return + } + // return + } + go a.Pinger(ctx, ticker, cancel, a.Server) + var message []byte + var typ int + for { + typ, message, err = conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure, + ) { + log.W.F("unexpected close error from %s: %v", + a.Listener.Request.Header.Get("X-Forwarded-For"), err) + } + break + } + if typ == websocket.PingMessage { + if err = a.Listener.WriteMessage(websocket.PongMessage, nil); chk.E(err) { + } + continue + } + go a.HandleMessage(message) + } +} diff --git a/socketapi/upgrader.go b/socketapi/upgrader.go new file mode 100644 index 0000000..8ade039 --- /dev/null +++ b/socketapi/upgrader.go @@ -0,0 +1,12 @@ +package socketapi + +import ( + "net/http" + + "github.com/fasthttp/websocket" +) + +var Upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }} diff --git a/ws/listener.go b/ws/listener.go index 80e5c6b..da918bc 100644 --- a/ws/listener.go +++ b/ws/listener.go @@ -14,8 +14,8 @@ import ( // Listener is a websocket implementation for a relay listener. type Listener struct { mutex sync.Mutex - conn *websocket.Conn - req *http.Request + Conn *websocket.Conn + Request *http.Request challenge atomic.String remote atomic.String authed atomic.String @@ -28,7 +28,7 @@ func NewListener( req *http.Request, challenge []byte, ) (ws *Listener) { - ws = &Listener{conn: conn, req: req} + ws = &Listener{Conn: conn, Request: req} ws.challenge.Store(string(challenge)) ws.authRequested.Store(false) ws.setRemoteFromReq(req) @@ -62,7 +62,7 @@ func (ws *Listener) setRemoteFromReq(r *http.Request) { if rr == "" { // if that fails, fall back to the remote (probably the proxy, unless the realy is // actually directly listening) - rr = ws.conn.NetConn().RemoteAddr().String() + rr = ws.Conn.NetConn().RemoteAddr().String() } ws.remote.Store(rr) } @@ -71,7 +71,7 @@ func (ws *Listener) setRemoteFromReq(r *http.Request) { func (ws *Listener) Write(p []byte) (n int, err error) { ws.mutex.Lock() defer ws.mutex.Unlock() - err = ws.conn.WriteMessage(websocket.TextMessage, p) + err = ws.Conn.WriteMessage(websocket.TextMessage, p) if err != nil { n = len(p) if strings.Contains(err.Error(), "close sent") { @@ -88,7 +88,7 @@ func (ws *Listener) Write(p []byte) (n int, err error) { func (ws *Listener) WriteJSON(any interface{}) error { ws.mutex.Lock() defer ws.mutex.Unlock() - return ws.conn.WriteJSON(any) + return ws.Conn.WriteJSON(any) } // WriteMessage is a wrapper around the websocket WriteMessage, which includes a websocket @@ -96,7 +96,7 @@ func (ws *Listener) WriteJSON(any interface{}) error { func (ws *Listener) WriteMessage(t int, b []byte) error { ws.mutex.Lock() defer ws.mutex.Unlock() - return ws.conn.WriteMessage(t, b) + return ws.Conn.WriteMessage(t, b) } // Challenge returns the current auth challenge string on the socket. @@ -122,7 +122,7 @@ func (ws *Listener) SetAuthed(s string) { } // Req returns the http.Request associated with the client connection to the Listener. -func (ws *Listener) Req() *http.Request { return ws.req } +func (ws *Listener) Req() *http.Request { return ws.Request } // Close the Listener connection from the Listener side. -func (ws *Listener) Close() (err error) { return ws.conn.Close() } +func (ws *Listener) Close() (err error) { return ws.Conn.Close() }