refactor lerproxy to be better laid out
This commit is contained in:
104
cmd/lerproxy/app/app.go
Normal file
104
cmd/lerproxy/app/app.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"net"
|
||||
"net/http"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RunArgs struct {
|
||||
Addr string `arg:"-l,--listen" default:":https" help:"address to listen at"`
|
||||
Conf string `arg:"-m,--map" default:"mapping.txt" help:"file with host/backend mapping"`
|
||||
Cache string `arg:"-c,--cachedir" default:"/var/cache/letsencrypt" help:"path to directory to cache key and certificates"`
|
||||
HSTS bool `arg:"-h,--hsts" help:"add Strict-Transport-Security header"`
|
||||
Email string `arg:"-e,--email" help:"contact email address presented to letsencrypt CA"`
|
||||
HTTP string `arg:"--http" default:":http" help:"optional address to serve http-to-https redirects and ACME http-01 challenge responses"`
|
||||
RTO time.Duration `arg:"-r,--rto" default:"1m" help:"maximum duration before timing out read of the request"`
|
||||
WTO time.Duration `arg:"-w,--wto" default:"5m" help:"maximum duration before timing out write of the response"`
|
||||
Idle time.Duration `arg:"-i,--idle" help:"how long idle connection is kept before closing (set rto, wto to 0 to use this)"`
|
||||
Certs []string `arg:"--cert,separate" help:"certificates and the domain they match: eg: orly.dev:/path/to/cert - this will indicate to load two, one with extension .key and one with .crt, each expected to be PEM encoded TLS private and public keys, respectively"`
|
||||
// Rewrites string `arg:"-r,--rewrites" default:"rewrites.txt"`
|
||||
}
|
||||
|
||||
func Run(c context.T, args RunArgs) (err error) {
|
||||
if args.Cache == "" {
|
||||
err = log.E.Err("no cache specified")
|
||||
return
|
||||
}
|
||||
var srv *http.Server
|
||||
var httpHandler http.Handler
|
||||
if srv, httpHandler, err = SetupServer(args); chk.E(err) {
|
||||
return
|
||||
}
|
||||
srv.ReadHeaderTimeout = 5 * time.Second
|
||||
if args.RTO > 0 {
|
||||
srv.ReadTimeout = args.RTO
|
||||
}
|
||||
if args.WTO > 0 {
|
||||
srv.WriteTimeout = args.WTO
|
||||
}
|
||||
group, ctx := errgroup.WithContext(c)
|
||||
if args.HTTP != "" {
|
||||
httpServer := http.Server{
|
||||
Addr: args.HTTP,
|
||||
Handler: httpHandler,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
group.Go(
|
||||
func() (err error) {
|
||||
chk.E(httpServer.ListenAndServe())
|
||||
return
|
||||
},
|
||||
)
|
||||
group.Go(
|
||||
func() error {
|
||||
<-ctx.Done()
|
||||
ctx, cancel := context.Timeout(
|
||||
context.Bg(),
|
||||
time.Second,
|
||||
)
|
||||
defer cancel()
|
||||
return httpServer.Shutdown(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
if srv.ReadTimeout != 0 || srv.WriteTimeout != 0 || args.Idle == 0 {
|
||||
group.Go(
|
||||
func() (err error) {
|
||||
chk.E(srv.ListenAndServeTLS("", ""))
|
||||
return
|
||||
},
|
||||
)
|
||||
} else {
|
||||
group.Go(
|
||||
func() (err error) {
|
||||
var ln net.Listener
|
||||
if ln, err = net.Listen("tcp", srv.Addr); chk.E(err) {
|
||||
return
|
||||
}
|
||||
defer ln.Close()
|
||||
ln = Listener{
|
||||
Duration: args.Idle,
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
}
|
||||
err = srv.ServeTLS(ln, "", "")
|
||||
chk.E(err)
|
||||
return
|
||||
},
|
||||
)
|
||||
}
|
||||
group.Go(
|
||||
func() error {
|
||||
<-ctx.Done()
|
||||
ctx, cancel := context.Timeout(context.Bg(), time.Second)
|
||||
defer cancel()
|
||||
return srv.Shutdown(ctx)
|
||||
},
|
||||
)
|
||||
return group.Wait()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package app
|
||||
|
||||
import "sync"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package app
|
||||
|
||||
import (
|
||||
"net"
|
||||
63
cmd/lerproxy/app/go-vanity.go
Normal file
63
cmd/lerproxy/app/go-vanity.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GoVanity configures an HTTP handler for redirecting requests to vanity URLs
|
||||
// based on the provided hostname and backend address.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - hn (string): The hostname associated with the vanity URL.
|
||||
//
|
||||
// - ba (string): The backend address, expected to be in the format
|
||||
// "git+<repository-path>".
|
||||
//
|
||||
// - mux (*http.ServeMux): The HTTP serve multiplexer where the handler will be
|
||||
// registered.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Splits the backend address to extract the repository path from the "git+" prefix.
|
||||
//
|
||||
// - If the split fails, logs an error and returns without registering a handler.
|
||||
//
|
||||
// - Generates an HTML redirect page containing metadata for Go import and
|
||||
// redirects to the extracted repository path.
|
||||
//
|
||||
// - Registers a handler on the provided ServeMux that serves this redirect page
|
||||
// when requests are made to the specified hostname.
|
||||
func GoVanity(hn, ba string, mux *http.ServeMux) {
|
||||
split := strings.Split(ba, "git+")
|
||||
if len(split) != 2 {
|
||||
log.E.Ln("invalid go vanity redirect: %s: %s", hn, ba)
|
||||
return
|
||||
}
|
||||
redirector := fmt.Sprintf(
|
||||
`<html><head><meta name="go-import" content="%s git %s"/><meta http-equiv = "refresh" content = " 3 ; url = %s"/></head><body>redirecting to <a href="%s">%s</a></body></html>`,
|
||||
hn, split[1], split[1], split[1], split[1],
|
||||
)
|
||||
mux.HandleFunc(
|
||||
hn+"/",
|
||||
func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.Header().Set("Content-Type", "text/html")
|
||||
writer.Header().Set(
|
||||
"Content-Length", fmt.Sprint(len(redirector)),
|
||||
)
|
||||
writer.Header().Set(
|
||||
"strict-transport-security",
|
||||
"max-age=0; includeSubDomains",
|
||||
)
|
||||
fmt.Fprint(writer, redirector)
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package app
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Period can be changed prior to opening a Listener to alter its'
|
||||
// Period can be changed before opening a Listener to alter its
|
||||
// KeepAlivePeriod.
|
||||
var Period = 3 * time.Minute
|
||||
|
||||
// Listener sets TCP keep-alive timeouts on accepted connections.
|
||||
// It's used by ListenAndServe and ListenAndServeTLS so dead TCP connections
|
||||
// It is used by ListenAndServe and ListenAndServeTLS so dead TCP connections
|
||||
// (e.g. closing laptop mid-download) eventually go away.
|
||||
type Listener struct {
|
||||
time.Duration
|
||||
80
cmd/lerproxy/app/nostr-dns.go
Normal file
80
cmd/lerproxy/app/nostr-dns.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"os"
|
||||
)
|
||||
|
||||
type NostrJSON struct {
|
||||
Names map[string]string `json:"names"`
|
||||
Relays map[string][]string `json:"relays"`
|
||||
}
|
||||
|
||||
// NostrDNS handles the configuration and registration of a Nostr DNS endpoint
|
||||
// for a given hostname and backend address.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - hn (string): The hostname for which the Nostr DNS entry is being configured.
|
||||
//
|
||||
// - ba (string): The path to the JSON file containing the Nostr DNS data.
|
||||
//
|
||||
// - mux (*http.ServeMux): The HTTP serve multiplexer to which the Nostr DNS
|
||||
// handler will be registered.
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - err (error): An error if any step fails during the configuration or
|
||||
// registration process.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Reads the JSON file specified by `ba` and parses its contents into a
|
||||
// NostrJSON struct.
|
||||
//
|
||||
// - Registers a new HTTP handler on the provided `mux` for the
|
||||
// `.well-known/nostr.json` endpoint under the specified hostname.
|
||||
//
|
||||
// - The handler serves the parsed Nostr DNS data with appropriate HTTP headers
|
||||
// set for CORS and content type.
|
||||
func NostrDNS(hn, ba string, mux *http.ServeMux) (err error) {
|
||||
log.T.Ln(hn, ba)
|
||||
var fb []byte
|
||||
if fb, err = os.ReadFile(ba); chk.E(err) {
|
||||
return
|
||||
}
|
||||
var v NostrJSON
|
||||
if err = json.Unmarshal(fb, &v); chk.E(err) {
|
||||
return
|
||||
}
|
||||
var jb []byte
|
||||
if jb, err = json.Marshal(v); chk.E(err) {
|
||||
return
|
||||
}
|
||||
nostrJSON := string(jb)
|
||||
mux.HandleFunc(
|
||||
hn+"/.well-known/nostr.json",
|
||||
func(writer http.ResponseWriter, request *http.Request) {
|
||||
log.T.Ln("serving nostr json to", hn)
|
||||
writer.Header().Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.Header().Set(
|
||||
"Content-Length", fmt.Sprint(len(nostrJSON)),
|
||||
)
|
||||
writer.Header().Set(
|
||||
"strict-transport-security",
|
||||
"max-age=0; includeSubDomains",
|
||||
)
|
||||
fmt.Fprint(writer, nostrJSON)
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
15
cmd/lerproxy/app/proxy.go
Normal file
15
cmd/lerproxy/app/proxy.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package app
|
||||
|
||||
import "net/http"
|
||||
|
||||
type Proxy struct {
|
||||
http.Handler
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(
|
||||
"Strict-Transport-Security",
|
||||
"max-age=31536000; includeSubDomains; preload",
|
||||
)
|
||||
p.Handler.ServeHTTP(w, r)
|
||||
}
|
||||
62
cmd/lerproxy/app/read-mapping.go
Normal file
62
cmd/lerproxy/app/read-mapping.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ReadMapping reads a mapping file and returns a map of hostnames to backend
|
||||
// addresses.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - file (string): The path to the mapping file to read.
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - m (map[string]string): A map containing the hostname to backend address
|
||||
// mappings parsed from the file.
|
||||
//
|
||||
// - err (error): An error if any step during reading or parsing fails.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Opens the specified file and reads its contents line by line.
|
||||
//
|
||||
// - Skips lines that are empty or start with a '#'.
|
||||
//
|
||||
// - Splits each valid line into two parts using the first colon as the
|
||||
// separator.
|
||||
//
|
||||
// - Trims whitespace from both parts and adds them to the map.
|
||||
//
|
||||
// - Returns any error encountered during file operations or parsing.
|
||||
func ReadMapping(file string) (m map[string]string, err error) {
|
||||
var f *os.File
|
||||
if f, err = os.Open(file); chk.E(err) {
|
||||
return
|
||||
}
|
||||
m = make(map[string]string)
|
||||
sc := bufio.NewScanner(f)
|
||||
for sc.Scan() {
|
||||
if b := sc.Bytes(); len(b) == 0 || b[0] == '#' {
|
||||
continue
|
||||
}
|
||||
s := strings.SplitN(sc.Text(), ":", 2)
|
||||
if len(s) != 2 {
|
||||
err = fmt.Errorf("invalid line: %q", sc.Text())
|
||||
log.E.Ln(err)
|
||||
chk.E(f.Close())
|
||||
return
|
||||
}
|
||||
m[strings.TrimSpace(s[0])] = strings.TrimSpace(s[1])
|
||||
}
|
||||
err = sc.Err()
|
||||
chk.E(err)
|
||||
chk.E(f.Close())
|
||||
return
|
||||
}
|
||||
@@ -1,24 +1,27 @@
|
||||
// Package main implements a reverse proxy with proper forwarding headers.
|
||||
package main
|
||||
package app
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"orly.dev/cmd/lerproxy/utils"
|
||||
"orly.dev/pkg/utils/log"
|
||||
)
|
||||
|
||||
// NewSingleHostReverseProxy is a copy of httputil.NewSingleHostReverseProxy
|
||||
// with the addition of forwarding headers:
|
||||
//
|
||||
// - Legacy X-Forwarded-* headers (X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Host)
|
||||
// - Standardized Forwarded header according to RFC 7239 (https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Forwarded)
|
||||
//
|
||||
// - Standardized Forwarded header according to RFC 7239
|
||||
// (https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Forwarded)
|
||||
func NewSingleHostReverseProxy(target *url.URL) (rp *httputil.ReverseProxy) {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
log.D.S(req)
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = SingleJoiningSlash(target.Path, req.URL.Path)
|
||||
req.URL.Path = utils.SingleJoiningSlash(target.Path, req.URL.Path)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
@@ -27,28 +30,23 @@ func NewSingleHostReverseProxy(target *url.URL) (rp *httputil.ReverseProxy) {
|
||||
if _, ok := req.Header["User-Agent"]; !ok {
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
// Set X-Forwarded-* headers for backward compatibility
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
// Get client IP address
|
||||
clientIP := req.RemoteAddr
|
||||
if fwdFor := req.Header.Get("X-Forwarded-For"); fwdFor != "" {
|
||||
clientIP = fwdFor + ", " + clientIP
|
||||
}
|
||||
req.Header.Set("X-Forwarded-For", clientIP)
|
||||
|
||||
// Set X-Forwarded-Host if not already set
|
||||
if _, exists := req.Header["X-Forwarded-Host"]; !exists {
|
||||
req.Header.Set("X-Forwarded-Host", req.Host)
|
||||
}
|
||||
|
||||
// Set standardized Forwarded header according to RFC 7239
|
||||
// Format: Forwarded: by=<identifier>;for=<identifier>;host=<host>;proto=<http|https>
|
||||
forwardedProto := "https"
|
||||
forwardedHost := req.Host
|
||||
forwardedFor := clientIP
|
||||
|
||||
// Build the Forwarded header value
|
||||
forwardedHeader := "proto=" + forwardedProto
|
||||
if forwardedFor != "" {
|
||||
@@ -57,7 +55,6 @@ func NewSingleHostReverseProxy(target *url.URL) (rp *httputil.ReverseProxy) {
|
||||
if forwardedHost != "" {
|
||||
forwardedHeader += ";host=" + forwardedHost
|
||||
}
|
||||
|
||||
req.Header.Set("Forwarded", forwardedHeader)
|
||||
}
|
||||
rp = &httputil.ReverseProxy{Director: director}
|
||||
124
cmd/lerproxy/app/set-proxy.go
Normal file
124
cmd/lerproxy/app/set-proxy.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
log2 "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SetProxy creates an HTTP handler that routes incoming requests to specified
|
||||
// backend addresses based on hostname mappings.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - mapping (map[string]string): A map where keys are hostnames and values are
|
||||
// the corresponding backend addresses.
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - h (http.Handler): The HTTP handler configured with the proxy settings.
|
||||
// - err (error): An error if the mapping is empty or invalid.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Validates that the provided hostname to backend address mapping is not empty.
|
||||
//
|
||||
// - Creates a new ServeMux and configures it to route requests based on the
|
||||
// specified hostnames and backend addresses.
|
||||
//
|
||||
// - Handles special cases such as vanity URLs, Nostr DNS entries, and Unix
|
||||
// socket connections.
|
||||
func SetProxy(mapping map[string]string) (h http.Handler, err error) {
|
||||
if len(mapping) == 0 {
|
||||
return nil, fmt.Errorf("empty mapping")
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
for hostname, backendAddr := range mapping {
|
||||
hn, ba := hostname, backendAddr
|
||||
if strings.ContainsRune(hn, os.PathSeparator) {
|
||||
err = log.E.Err("invalid hostname: %q", hn)
|
||||
return
|
||||
}
|
||||
network := "tcp"
|
||||
if ba != "" && ba[0] == '@' && runtime.GOOS == "linux" {
|
||||
// append \0 to address so addrlen for connect(2) is calculated in a
|
||||
// way compatible with some other implementations (i.e. uwsgi)
|
||||
network, ba = "unix", ba+string(byte(0))
|
||||
} else if strings.HasPrefix(ba, "git+") {
|
||||
GoVanity(hn, ba, mux)
|
||||
continue
|
||||
} else if filepath.IsAbs(ba) {
|
||||
network = "unix"
|
||||
switch {
|
||||
case strings.HasSuffix(ba, string(os.PathSeparator)):
|
||||
// path specified as directory with explicit trailing slash; add
|
||||
// this path as static site
|
||||
fs := http.FileServer(http.Dir(ba))
|
||||
mux.Handle(hn+"/", fs)
|
||||
continue
|
||||
case strings.HasSuffix(ba, "nostr.json"):
|
||||
if err = NostrDNS(hn, ba, mux); err != nil {
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else if u, err := url.Parse(ba); err == nil {
|
||||
switch u.Scheme {
|
||||
case "http", "https":
|
||||
rp := NewSingleHostReverseProxy(u)
|
||||
modifyCORSResponse := func(res *http.Response) error {
|
||||
res.Header.Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
// res.Header.Set("Access-Control-Allow-Credentials", "true")
|
||||
res.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
return nil
|
||||
}
|
||||
rp.ModifyResponse = modifyCORSResponse
|
||||
rp.ErrorLog = log2.New(
|
||||
os.Stderr, "lerproxy", log2.Llongfile,
|
||||
)
|
||||
rp.BufferPool = Pool{}
|
||||
mux.Handle(hn+"/", rp)
|
||||
continue
|
||||
}
|
||||
}
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = req.Host
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-For", req.RemoteAddr)
|
||||
req.Header.Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
req.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
log.D.Ln(req.URL, req.RemoteAddr)
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(c context.T, n, addr string) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return net.DialTimeout(network, ba, 5*time.Second)
|
||||
},
|
||||
},
|
||||
ErrorLog: log2.New(io.Discard, "", 0),
|
||||
BufferPool: Pool{},
|
||||
}
|
||||
mux.Handle(hn+"/", rp)
|
||||
}
|
||||
return mux, nil
|
||||
}
|
||||
81
cmd/lerproxy/app/setup-server.go
Normal file
81
cmd/lerproxy/app/setup-server.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"net/http"
|
||||
"orly.dev/cmd/lerproxy/utils"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"os"
|
||||
)
|
||||
|
||||
// SetupServer configures and returns an HTTP server instance with proxy
|
||||
// handling and automatic certificate management based on the provided RunArgs
|
||||
// configuration.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - a (RunArgs): The configuration arguments containing settings for the server
|
||||
// address, cache directory, mapping file, HSTS header, email, and certificates.
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - s (*http.Server): The configured HTTP server instance.
|
||||
//
|
||||
// - h (http.Handler): The HTTP handler used for proxying requests and managing
|
||||
// automatic certificate challenges.
|
||||
//
|
||||
// - err (error): An error if any step during setup fails.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Reads the hostname to backend address mapping from the specified
|
||||
// configuration file.
|
||||
//
|
||||
// - Sets up a proxy handler that routes incoming requests based on the defined
|
||||
// mappings.
|
||||
//
|
||||
// - Enables HSTS header support if enabled in the RunArgs.
|
||||
//
|
||||
// - Creates the cache directory for storing certificates and keys if it does not
|
||||
// already exist.
|
||||
//
|
||||
// - Configures an autocert.Manager to handle automatic certificate management,
|
||||
// including hostname whitelisting, email contact, and cache storage.
|
||||
//
|
||||
// - Initializes the HTTP server with proxy handler, address, and TLS
|
||||
// configuration.
|
||||
func SetupServer(a RunArgs) (s *http.Server, h http.Handler, err error) {
|
||||
var mapping map[string]string
|
||||
if mapping, err = ReadMapping(a.Conf); chk.E(err) {
|
||||
return
|
||||
}
|
||||
var proxy http.Handler
|
||||
if proxy, err = SetProxy(mapping); chk.E(err) {
|
||||
return
|
||||
}
|
||||
if a.HSTS {
|
||||
proxy = &Proxy{Handler: proxy}
|
||||
}
|
||||
if err = os.MkdirAll(a.Cache, 0700); chk.E(err) {
|
||||
err = fmt.Errorf(
|
||||
"cannot create cache directory %q: %v",
|
||||
a.Cache, err,
|
||||
)
|
||||
chk.E(err)
|
||||
return
|
||||
}
|
||||
m := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(a.Cache),
|
||||
HostPolicy: autocert.HostWhitelist(utils.GetKeys(mapping)...),
|
||||
Email: a.Email,
|
||||
}
|
||||
s = &http.Server{
|
||||
Handler: proxy,
|
||||
Addr: a.Addr,
|
||||
TLSConfig: TLSConfig(&m, a.Certs...),
|
||||
}
|
||||
h = m.HTTPHandler(nil)
|
||||
return
|
||||
}
|
||||
87
cmd/lerproxy/app/tls-config.go
Normal file
87
cmd/lerproxy/app/tls-config.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TLSConfig creates a custom TLS configuration that combines automatic
|
||||
// certificate management with explicitly provided certificates.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - m (*autocert.Manager): The autocert manager used for managing automatic
|
||||
// certificate generation and retrieval.
|
||||
//
|
||||
// - certs (...string): A variadic list of certificate definitions in the format
|
||||
// "domain:/path/to/cert", where each domain maps to a certificate file. The
|
||||
// corresponding key file is expected to be at "/path/to/cert.key".
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - tc (*tls.Config): A new TLS configuration that prioritises explicitly
|
||||
// provided certificates over automatically generated ones.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Loads all explicitly provided certificates and maps them to their
|
||||
// respective domains.
|
||||
//
|
||||
// - Creates a custom GetCertificate function that checks if the requested
|
||||
// domain matches any of the explicitly provided certificates, returning those
|
||||
// first.
|
||||
//
|
||||
// - Falls back to the autocert manager's GetCertificate method if no explicit
|
||||
// certificate is found for the requested domain.
|
||||
func TLSConfig(m *autocert.Manager, certs ...string) (tc *tls.Config) {
|
||||
certMap := make(map[string]*tls.Certificate)
|
||||
var mx sync.Mutex
|
||||
for _, cert := range certs {
|
||||
split := strings.Split(cert, ":")
|
||||
if len(split) != 2 {
|
||||
log.E.F("invalid certificate parameter format: `%s`", cert)
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
var c tls.Certificate
|
||||
if c, err = tls.LoadX509KeyPair(
|
||||
split[1]+".crt", split[1]+".key",
|
||||
); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
certMap[split[0]] = &c
|
||||
}
|
||||
tc = m.TLSConfig()
|
||||
tc.GetCertificate = func(helo *tls.ClientHelloInfo) (
|
||||
cert *tls.Certificate, err error,
|
||||
) {
|
||||
mx.Lock()
|
||||
var own string
|
||||
for i := range certMap {
|
||||
// to also handle explicit subdomain certs, prioritize over a root
|
||||
// wildcard.
|
||||
if helo.ServerName == i {
|
||||
own = i
|
||||
break
|
||||
}
|
||||
// if it got to us and ends in the same-name dot tld assume the
|
||||
// subdomain was redirected, or it is a wildcard certificate; thus
|
||||
// only the ending needs to match.
|
||||
if strings.HasSuffix(helo.ServerName, i) {
|
||||
own = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if own != "" {
|
||||
defer mx.Unlock()
|
||||
return certMap[own], nil
|
||||
}
|
||||
mx.Unlock()
|
||||
return m.GetCertificate(helo)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,395 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
stdLog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"orly.dev/cmd/lerproxy/app"
|
||||
"orly.dev/pkg/utils/chk"
|
||||
"orly.dev/pkg/utils/context"
|
||||
"orly.dev/pkg/utils/log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/alexflint/go-arg"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type RunArgs struct {
|
||||
Addr string `arg:"-l,--listen" default:":https" help:"address to listen at"`
|
||||
Conf string `arg:"-m,--map" default:"mapping.txt" help:"file with host/backend mapping"`
|
||||
Cache string `arg:"-c,--cachedir" default:"/var/cache/letsencrypt" help:"path to directory to cache key and certificates"`
|
||||
HSTS bool `arg:"-h,--hsts" help:"add Strict-Transport-Security header"`
|
||||
Email string `arg:"-e,--email" help:"contact email address presented to letsencrypt CA"`
|
||||
HTTP string `arg:"--http" default:":http" help:"optional address to serve http-to-https redirects and ACME http-01 challenge responses"`
|
||||
RTO time.Duration `arg:"-r,--rto" default:"1m" help:"maximum duration before timing out read of the request"`
|
||||
WTO time.Duration `arg:"-w,--wto" default:"5m" help:"maximum duration before timing out write of the response"`
|
||||
Idle time.Duration `arg:"-i,--idle" help:"how long idle connection is kept before closing (set rto, wto to 0 to use this)"`
|
||||
Certs []string `arg:"--cert,separate" help:"certificates and the domain they match: eg: orly.dev:/path/to/cert - this will indicate to load two, one with extension .key and one with .crt, each expected to be PEM encoded TLS private and public keys, respectively"`
|
||||
// Rewrites string `arg:"-r,--rewrites" default:"rewrites.txt"`
|
||||
}
|
||||
|
||||
var args RunArgs
|
||||
var args app.RunArgs
|
||||
|
||||
func main() {
|
||||
arg.MustParse(&args)
|
||||
ctx, cancel := signal.NotifyContext(context.Bg(), os.Interrupt)
|
||||
defer cancel()
|
||||
if err := Run(ctx, args); chk.T(err) {
|
||||
if err := app.Run(ctx, args); chk.T(err) {
|
||||
log.F.Ln(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Run(c context.T, args RunArgs) (err error) {
|
||||
|
||||
if args.Cache == "" {
|
||||
err = log.E.Err("no cache specified")
|
||||
return
|
||||
}
|
||||
|
||||
var srv *http.Server
|
||||
var httpHandler http.Handler
|
||||
if srv, httpHandler, err = SetupServer(args); chk.E(err) {
|
||||
return
|
||||
}
|
||||
srv.ReadHeaderTimeout = 5 * time.Second
|
||||
if args.RTO > 0 {
|
||||
srv.ReadTimeout = args.RTO
|
||||
}
|
||||
if args.WTO > 0 {
|
||||
srv.WriteTimeout = args.WTO
|
||||
}
|
||||
group, ctx := errgroup.WithContext(c)
|
||||
if args.HTTP != "" {
|
||||
httpServer := http.Server{
|
||||
Addr: args.HTTP,
|
||||
Handler: httpHandler,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
group.Go(
|
||||
func() (err error) {
|
||||
chk.E(httpServer.ListenAndServe())
|
||||
return
|
||||
},
|
||||
)
|
||||
group.Go(
|
||||
func() error {
|
||||
<-ctx.Done()
|
||||
ctx, cancel := context.Timeout(
|
||||
context.Bg(),
|
||||
time.Second,
|
||||
)
|
||||
defer cancel()
|
||||
return httpServer.Shutdown(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
if srv.ReadTimeout != 0 || srv.WriteTimeout != 0 || args.Idle == 0 {
|
||||
group.Go(
|
||||
func() (err error) {
|
||||
chk.E(srv.ListenAndServeTLS("", ""))
|
||||
return
|
||||
},
|
||||
)
|
||||
} else {
|
||||
group.Go(
|
||||
func() (err error) {
|
||||
var ln net.Listener
|
||||
if ln, err = net.Listen("tcp", srv.Addr); chk.E(err) {
|
||||
return
|
||||
}
|
||||
defer ln.Close()
|
||||
ln = Listener{
|
||||
Duration: args.Idle,
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
}
|
||||
err = srv.ServeTLS(ln, "", "")
|
||||
chk.E(err)
|
||||
return
|
||||
},
|
||||
)
|
||||
}
|
||||
group.Go(
|
||||
func() error {
|
||||
<-ctx.Done()
|
||||
ctx, cancel := context.Timeout(context.Bg(), time.Second)
|
||||
defer cancel()
|
||||
return srv.Shutdown(ctx)
|
||||
},
|
||||
)
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// TLSConfig returns a TLSConfig that works with a LetsEncrypt automatic SSL
|
||||
// cert issuer as well as any provided .pem certificates from providers.
|
||||
//
|
||||
// The certs are provided in the form "example.com:/path/to/cert.pem"
|
||||
func TLSConfig(m *autocert.Manager, certs ...string) (tc *tls.Config) {
|
||||
certMap := make(map[string]*tls.Certificate)
|
||||
var mx sync.Mutex
|
||||
for _, cert := range certs {
|
||||
split := strings.Split(cert, ":")
|
||||
if len(split) != 2 {
|
||||
log.E.F("invalid certificate parameter format: `%s`", cert)
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
var c tls.Certificate
|
||||
if c, err = tls.LoadX509KeyPair(
|
||||
split[1]+".crt", split[1]+".key",
|
||||
); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
certMap[split[0]] = &c
|
||||
}
|
||||
tc = m.TLSConfig()
|
||||
tc.GetCertificate = func(helo *tls.ClientHelloInfo) (
|
||||
cert *tls.Certificate, err error,
|
||||
) {
|
||||
mx.Lock()
|
||||
var own string
|
||||
for i := range certMap {
|
||||
// to also handle explicit subdomain certs, prioritize over a root wildcard.
|
||||
if helo.ServerName == i {
|
||||
own = i
|
||||
break
|
||||
}
|
||||
// if it got to us and ends in the same-name dot tld assume the
|
||||
// subdomain was redirected, or it is a wildcard certificate; thus
|
||||
// only the ending needs to match.
|
||||
if strings.HasSuffix(helo.ServerName, i) {
|
||||
own = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if own != "" {
|
||||
defer mx.Unlock()
|
||||
return certMap[own], nil
|
||||
}
|
||||
mx.Unlock()
|
||||
return m.GetCertificate(helo)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func SetupServer(a RunArgs) (s *http.Server, h http.Handler, err error) {
|
||||
var mapping map[string]string
|
||||
if mapping, err = ReadMapping(a.Conf); chk.E(err) {
|
||||
return
|
||||
}
|
||||
var proxy http.Handler
|
||||
if proxy, err = SetProxy(mapping); chk.E(err) {
|
||||
return
|
||||
}
|
||||
if a.HSTS {
|
||||
proxy = &Proxy{Handler: proxy}
|
||||
}
|
||||
if err = os.MkdirAll(a.Cache, 0700); chk.E(err) {
|
||||
err = fmt.Errorf(
|
||||
"cannot create cache directory %q: %v",
|
||||
a.Cache, err,
|
||||
)
|
||||
chk.E(err)
|
||||
return
|
||||
}
|
||||
m := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(a.Cache),
|
||||
HostPolicy: autocert.HostWhitelist(GetKeys(mapping)...),
|
||||
Email: a.Email,
|
||||
}
|
||||
s = &http.Server{
|
||||
Handler: proxy,
|
||||
Addr: a.Addr,
|
||||
TLSConfig: TLSConfig(&m, a.Certs...),
|
||||
}
|
||||
h = m.HTTPHandler(nil)
|
||||
return
|
||||
}
|
||||
|
||||
type NostrJSON struct {
|
||||
Names map[string]string `json:"names"`
|
||||
Relays map[string][]string `json:"relays"`
|
||||
}
|
||||
|
||||
func SetProxy(mapping map[string]string) (h http.Handler, err error) {
|
||||
if len(mapping) == 0 {
|
||||
return nil, fmt.Errorf("empty mapping")
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
for hostname, backendAddr := range mapping {
|
||||
hn, ba := hostname, backendAddr
|
||||
if strings.ContainsRune(hn, os.PathSeparator) {
|
||||
err = log.E.Err("invalid hostname: %q", hn)
|
||||
return
|
||||
}
|
||||
network := "tcp"
|
||||
if ba != "" && ba[0] == '@' && runtime.GOOS == "linux" {
|
||||
// append \0 to address so addrlen for connect(2) is calculated in a
|
||||
// way compatible with some other implementations (i.e. uwsgi)
|
||||
network, ba = "unix", ba+string(byte(0))
|
||||
} else if strings.HasPrefix(ba, "git+") {
|
||||
split := strings.Split(ba, "git+")
|
||||
if len(split) != 2 {
|
||||
log.E.Ln("invalid go vanity redirect: %s: %s", hn, ba)
|
||||
continue
|
||||
}
|
||||
redirector := fmt.Sprintf(
|
||||
`<html><head><meta name="go-import" content="%s git %s"/><meta http-equiv = "refresh" content = " 3 ; url = %s"/></head><body>redirecting to <a href="%s">%s</a></body></html>`,
|
||||
hn, split[1], split[1], split[1], split[1],
|
||||
)
|
||||
mux.HandleFunc(
|
||||
hn+"/",
|
||||
func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.Header().Set("Content-Type", "text/html")
|
||||
writer.Header().Set(
|
||||
"Content-Length", fmt.Sprint(len(redirector)),
|
||||
)
|
||||
writer.Header().Set(
|
||||
"strict-transport-security",
|
||||
"max-age=0; includeSubDomains",
|
||||
)
|
||||
fmt.Fprint(writer, redirector)
|
||||
},
|
||||
)
|
||||
continue
|
||||
} else if filepath.IsAbs(ba) {
|
||||
network = "unix"
|
||||
switch {
|
||||
case strings.HasSuffix(ba, string(os.PathSeparator)):
|
||||
// path specified as directory with explicit trailing slash; add
|
||||
// this path as static site
|
||||
fs := http.FileServer(http.Dir(ba))
|
||||
mux.Handle(hn+"/", fs)
|
||||
continue
|
||||
case strings.HasSuffix(ba, "nostr.json"):
|
||||
log.I.Ln(hn, ba)
|
||||
var fb []byte
|
||||
if fb, err = os.ReadFile(ba); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
var v NostrJSON
|
||||
if err = json.Unmarshal(fb, &v); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
var jb []byte
|
||||
if jb, err = json.Marshal(v); chk.E(err) {
|
||||
continue
|
||||
}
|
||||
nostrJSON := string(jb)
|
||||
mux.HandleFunc(
|
||||
hn+"/.well-known/nostr.json",
|
||||
func(writer http.ResponseWriter, request *http.Request) {
|
||||
log.I.Ln("serving nostr json to", hn)
|
||||
writer.Header().Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.Header().Set(
|
||||
"Content-Length", fmt.Sprint(len(nostrJSON)),
|
||||
)
|
||||
writer.Header().Set(
|
||||
"strict-transport-security",
|
||||
"max-age=0; includeSubDomains",
|
||||
)
|
||||
fmt.Fprint(writer, nostrJSON)
|
||||
},
|
||||
)
|
||||
continue
|
||||
}
|
||||
} else if u, err := url.Parse(ba); err == nil {
|
||||
switch u.Scheme {
|
||||
case "http", "https":
|
||||
rp := NewSingleHostReverseProxy(u)
|
||||
modifyCORSResponse := func(res *http.Response) error {
|
||||
res.Header.Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
// res.Header.Set("Access-Control-Allow-Credentials", "true")
|
||||
res.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
return nil
|
||||
}
|
||||
rp.ModifyResponse = modifyCORSResponse
|
||||
rp.ErrorLog = stdLog.New(
|
||||
os.Stderr, "lerproxy", stdLog.Llongfile,
|
||||
)
|
||||
rp.BufferPool = Pool{}
|
||||
mux.Handle(hn+"/", rp)
|
||||
continue
|
||||
}
|
||||
}
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = req.Host
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-For", req.RemoteAddr)
|
||||
req.Header.Set(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
)
|
||||
// req.Header.Set("Access-Control-Allow-Credentials", "true")
|
||||
req.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
log.D.Ln(req.URL, req.RemoteAddr)
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(c context.T, n, addr string) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return net.DialTimeout(network, ba, 5*time.Second)
|
||||
},
|
||||
},
|
||||
ErrorLog: stdLog.New(io.Discard, "", 0),
|
||||
BufferPool: Pool{},
|
||||
}
|
||||
mux.Handle(hn+"/", rp)
|
||||
}
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
func ReadMapping(file string) (m map[string]string, err error) {
|
||||
var f *os.File
|
||||
if f, err = os.Open(file); chk.E(err) {
|
||||
return
|
||||
}
|
||||
m = make(map[string]string)
|
||||
sc := bufio.NewScanner(f)
|
||||
for sc.Scan() {
|
||||
if b := sc.Bytes(); len(b) == 0 || b[0] == '#' {
|
||||
continue
|
||||
}
|
||||
s := strings.SplitN(sc.Text(), ":", 2)
|
||||
if len(s) != 2 {
|
||||
err = fmt.Errorf("invalid line: %q", sc.Text())
|
||||
log.E.Ln(err)
|
||||
chk.E(f.Close())
|
||||
return
|
||||
}
|
||||
m[strings.TrimSpace(s[0])] = strings.TrimSpace(s[1])
|
||||
}
|
||||
err = sc.Err()
|
||||
chk.E(err)
|
||||
chk.E(f.Close())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package main
|
||||
|
||||
import "net/http"
|
||||
|
||||
type Proxy struct {
|
||||
http.Handler
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().
|
||||
Set(
|
||||
"Strict-Transport-Security",
|
||||
"max-age=31536000; includeSubDomains; preload",
|
||||
)
|
||||
p.ServeHTTP(w, r)
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// Package util provides some helpers for lerproxy, a tool to convert maps of
|
||||
// strings to slices of the same strings, and a helper to avoid putting two / in
|
||||
// a URL.
|
||||
package main
|
||||
|
||||
import "strings"
|
||||
|
||||
func GetKeys(m map[string]string) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func SingleJoiningSlash(a, b string) string {
|
||||
suffixSlash := strings.HasSuffix(a, "/")
|
||||
prefixSlash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case suffixSlash && prefixSlash:
|
||||
return a + b[1:]
|
||||
case !suffixSlash && !prefixSlash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
62
cmd/lerproxy/utils/utils.go
Normal file
62
cmd/lerproxy/utils/utils.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package utils
|
||||
|
||||
import "strings"
|
||||
|
||||
// GetKeys returns a slice containing all the keys from the provided map.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - m (map[string]string): The input map from which to extract keys.
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - []string: A slice of strings representing the keys in the map.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - Iterates over each key in the map and appends it to a new slice.
|
||||
//
|
||||
// - Returns the slice containing all the keys.
|
||||
func GetKeys(m map[string]string) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SingleJoiningSlash joins two strings with a single slash between them,
|
||||
// ensuring that the resulting path doesn't contain multiple consecutive
|
||||
// slashes.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - a (string): The first string to join.
|
||||
//
|
||||
// - b (string): The second string to join.
|
||||
//
|
||||
// # Return Values
|
||||
//
|
||||
// - result (string): The joined string with a single slash between them if
|
||||
// needed.
|
||||
//
|
||||
// # Expected behaviour
|
||||
//
|
||||
// - If both a and b start and end with a slash, the resulting string will have
|
||||
// only one slash between them.
|
||||
//
|
||||
// - If neither a nor b starts or ends with a slash, the strings will be joined
|
||||
// with a single slash in between.
|
||||
//
|
||||
// - Otherwise, the two strings are simply concatenated.
|
||||
func SingleJoiningSlash(a, b string) string {
|
||||
suffixSlash := strings.HasSuffix(a, "/")
|
||||
prefixSlash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case suffixSlash && prefixSlash:
|
||||
return a + b[1:]
|
||||
case !suffixSlash && !prefixSlash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
Reference in New Issue
Block a user