Files
reverse/main.go
2025-08-19 19:42:08 +01:00

479 lines
12 KiB
Go

// Command leproxy implements an HTTPS reverse proxy that automatically obtains
// Let's Encrypt certificates for multiple hostnames/backends.
package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
log2 "log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/artyom/autoflags"
"github.com/mleku/lol/chk"
"github.com/mleku/lol/log"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/sync/errgroup"
)
// main fills runArgs with defaults (HTTPS on :https, redirects/ACME on :http,
// config and cache under $HOME), lets autoflags override them from the command
// line, and runs the proxy until it fails or an interrupt signal arrives. Any
// error returned by run terminates the process via log.Fatal.
func main() {
	home := os.Getenv("HOME")
	args := runArgs{
		Addr:  ":https",
		HTTP:  ":http",
		Conf:  home + "/.config/reverse/mapping.conf",
		Cache: home + "/.cache/reverse",
		RTo:   time.Minute,
		WTo:   5 * time.Minute,
	}
	autoflags.Parse(&args)

	// ctx is cancelled on SIGINT, which run treats as the signal to shut down.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	defer stop()

	err := run(ctx, args)
	if err != nil {
		log2.Fatal(err)
	}
}
// runArgs holds the command-line configuration, populated by autoflags.Parse
// in main. Each `flag` tag has the form "name,help text".
//
// RTo and WTo are applied to the HTTPS server only when positive. Idle only
// takes effect when neither of them was applied (see run): the HTTPS server
// then accepts connections through a listener whose connections time out
// after Idle without a successful read or write.
type runArgs struct {
	Addr  string        `flag:"addr,address to listen at"`
	Conf  string        `flag:"map,file with host/backend mapping"`
	Cache string        `flag:"cacheDir,path to directory to cache key and certificates"`
	HSTS  bool          `flag:"hsts,add Strict-Transport-Security header"`
	Email string        `flag:"email,contact email address presented to letsencrypt CA"`
	HTTP  string        `flag:"http,optional address to serve http-to-https redirects and ACME http-01 challenge responses"`
	RTo   time.Duration `flag:"rto,maximum duration before timing out read of the request"`
	WTo   time.Duration `flag:"wto,maximum duration before timing out write of the response"`
	Idle  time.Duration `flag:"idle,how long idle connection is kept before closing (set rto, wto to 0 to use this)"`
}
// run starts the HTTPS reverse proxy configured by args and blocks until it
// stops.
//
// If args.HTTP is non-empty, a second plain-HTTP server is started on that
// address with the handler returned by setupServer. When no read or write
// timeout is applied to the HTTPS server (args.RTo and args.WTo both <= 0)
// and args.Idle is non-zero, HTTPS connections are accepted through
// tcpKeepAliveListener instead of ListenAndServeTLS.
//
// Cancelling ctx, or any server failing, triggers a graceful Shutdown of every
// server, each bounded by one second. run returns the first error reported by
// any server. The http.ErrServerClosed that the Serve methods return after
// Shutdown is not an error, so an orderly stop returns nil.
func run(ctx context.Context, args runArgs) error {
	if args.Cache == "" {
		return fmt.Errorf("no cache specified")
	}
	srv, httpHandler, err := setupServer(
		args.Addr, args.Conf, args.Cache, args.Email, args.HSTS,
	)
	if err != nil {
		return err
	}
	srv.ReadHeaderTimeout = 5 * time.Second
	if args.RTo > 0 {
		srv.ReadTimeout = args.RTo
	}
	if args.WTo > 0 {
		srv.WriteTimeout = args.WTo
	}
	group, ctx := errgroup.WithContext(ctx)
	if args.HTTP != "" {
		httpServer := &http.Server{
			Addr:         args.HTTP,
			Handler:      httpHandler,
			ReadTimeout:  10 * time.Second,
			WriteTimeout: 10 * time.Second,
		}
		group.Go(func() error {
			return ignoreServerClosed(httpServer.ListenAndServe())
		})
		group.Go(func() error { return shutdownWhenDone(ctx, httpServer) })
	}
	if srv.ReadTimeout != 0 || srv.WriteTimeout != 0 || args.Idle == 0 {
		group.Go(func() error {
			return ignoreServerClosed(srv.ListenAndServeTLS("", ""))
		})
	} else {
		group.Go(func() error {
			ln, err := net.Listen("tcp", srv.Addr)
			if err != nil {
				return err
			}
			defer ln.Close()
			// Without per-request read/write timeouts, rely on the listener to
			// expire connections that go args.Idle without successful I/O.
			ln = tcpKeepAliveListener{
				d:           args.Idle,
				TCPListener: ln.(*net.TCPListener),
			}
			return ignoreServerClosed(srv.ServeTLS(ln, "", ""))
		})
	}
	group.Go(func() error { return shutdownWhenDone(ctx, srv) })
	return group.Wait()
}

// ignoreServerClosed maps http.ErrServerClosed, which http.Server's Serve and
// ListenAndServe methods return once Shutdown has been called, to nil, and
// returns any other err unchanged.
func ignoreServerClosed(err error) error {
	// The Serve methods return the sentinel itself, not a wrapped error.
	if err == http.ErrServerClosed {
		return nil
	}
	return err
}

// shutdownWhenDone blocks until ctx is done, then gives s one second to shut
// down gracefully and returns the result of s.Shutdown.
func shutdownWhenDone(ctx context.Context, s *http.Server) error {
	<-ctx.Done()
	shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	return s.Shutdown(shutdownCtx)
}
// setupServer builds the HTTPS server listening on addr, and the handler to
// use for the optional plain-HTTP listener.
//
// It reads the host/backend mapping from mapfile and builds the routing
// handler from it. If hsts is set, that handler is wrapped to add a
// Strict-Transport-Security header. It then creates cacheDir (mode 0700) if
// needed and configures an autocert.Manager. The manager accepts the CA's
// terms of service, caches keys and certificates in cacheDir, only allows
// hostnames present in the mapping, and registers email as the contact
// address.
//
// The returned server uses the manager's TLS config. The returned handler is
// the manager's HTTPHandler(nil).
func setupServer(
	addr, mapfile, cacheDir, email string, hsts bool,
) (*http.Server, http.Handler, error) {
	mapping, err := readMapping(mapfile)
	if err != nil {
		return nil, nil, err
	}
	proxy, err := setProxy(mapping)
	if err != nil {
		return nil, nil, err
	}
	if hsts {
		proxy = &hstsProxy{proxy}
	}
	if err := os.MkdirAll(cacheDir, 0700); err != nil {
		// %w keeps the underlying *os.PathError available to errors.Is/As.
		return nil, nil, fmt.Errorf(
			"cannot create cache directory %q: %w", cacheDir, err,
		)
	}
	// A pointer makes explicit that the TLS config and the HTTP handler share
	// one Manager.
	m := &autocert.Manager{
		Prompt:     autocert.AcceptTOS,
		Cache:      autocert.DirCache(cacheDir),
		HostPolicy: autocert.HostWhitelist(keys(mapping)...),
		Email:      email,
	}
	srv := &http.Server{
		Handler:   proxy,
		Addr:      addr,
		TLSConfig: m.TLSConfig(),
	}
	return srv, m.HTTPHandler(nil), nil
}
// setProxy builds the multiplexer that routes each hostname in mapping to its
// backend. Each handler is registered under the pattern "<hostname>/". The
// backend value selects the handler kind:
//
//   - "@name" on Linux: reverse proxy over that abstract unix socket.
//   - "git+<repo>": Go vanity import page (GoVanity).
//   - absolute path ending in a path separator: static files served from it.
//   - absolute path ending in "nostr.json": the file is served at
//     "<hostname>/.well-known/nostr.json" (NostrDNS). Failures are
//     best-effort and leave the host without a handler.
//   - any other absolute path: reverse proxy over the unix socket at it.
//   - http:// or https:// URL: reverse proxy to that URL.
//   - anything else: reverse proxy over TCP to the value used as an address.
//
// It returns an error if mapping is empty or a hostname contains a path
// separator.
func setProxy(mapping map[string]string) (http.Handler, error) {
	if len(mapping) == 0 {
		return nil, fmt.Errorf("empty mapping")
	}
	mux := http.NewServeMux()
	// net.Dialer's methods are safe for concurrent use, so one dialer serves
	// every backend.
	dialer := &net.Dialer{Timeout: 5 * time.Second}
	for hostname, backendAddr := range mapping {
		hn, ba := hostname, backendAddr // per-iteration copies captured by closures below
		if strings.ContainsRune(hn, os.PathSeparator) {
			return nil, fmt.Errorf("invalid hostname: %q", hn)
		}
		network := "tcp"
		if strings.HasPrefix(ba, "@") && runtime.GOOS == "linux" {
			// append \0 to address so addrlen for connect(2) is
			// calculated in a way compatible with some other
			// implementations (i.e. uwsgi)
			network, ba = "unix", ba+string(byte(0))
		} else if strings.HasPrefix(ba, "git+") {
			GoVanity(hn, ba, mux)
			continue
		} else if filepath.IsAbs(ba) {
			network = "unix"
			switch {
			case strings.HasSuffix(ba, string(os.PathSeparator)):
				// Directory with an explicit trailing slash: serve it as a
				// static site.
				mux.Handle(hn+"/", http.FileServer(http.Dir(ba)))
				continue
			case strings.HasSuffix(ba, "nostr.json"):
				// Best effort: NostrDNS logs its own failures via chk.E, and a
				// bad file must not abort startup for the other hosts.
				_ = NostrDNS(hn, ba, mux)
				continue
			}
		} else if u, err := url.Parse(ba); err == nil {
			switch u.Scheme {
			case "http", "https":
				rp := newSingleHostReverseProxy(u)
				rp.ErrorLog = log2.New(io.Discard, "", 0)
				rp.BufferPool = bufPool{}
				mux.Handle(hn+"/", rp)
				continue
			}
		}
		rp := &httputil.ReverseProxy{
			Director: func(req *http.Request) {
				req.URL.Scheme = "http"
				req.URL.Host = req.Host
				req.Header.Set("X-Forwarded-Proto", "https")
			},
			Transport: &http.Transport{
				// The backend is fixed per hostname, so the network and
				// address chosen by the transport are ignored. DialContext
				// (unlike the deprecated Dial) aborts the dial when the
				// client request is cancelled.
				DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
					return dialer.DialContext(ctx, network, ba)
				},
			},
			ErrorLog:   log2.New(io.Discard, "", 0),
			BufferPool: bufPool{},
		}
		mux.Handle(hn+"/", rp)
	}
	return mux, nil
}
// readMapping parses the host/backend mapping file at path file.
//
// Each line is trimmed of surrounding whitespace. Blank lines and lines
// starting with '#' are skipped. Every other line must have the form
// "hostname: backend". The text before the first ':' is the hostname and
// the rest is the backend, both trimmed. A later line for the same hostname
// overrides an earlier one.
//
// It returns an error if the file cannot be opened or read, or if a line
// contains no ':'. On error the returned map is nil.
func readMapping(file string) (map[string]string, error) {
	f, err := os.Open(file)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	m := make(map[string]string)
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		// Trim before classifying the line. Otherwise whitespace-only lines
		// would be rejected, and indented comments would be rejected or,
		// if they contain ':', turned into bogus mapping entries.
		line := strings.TrimSpace(sc.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		s := strings.SplitN(line, ":", 2)
		if len(s) != 2 {
			return nil, fmt.Errorf("invalid line: %q", sc.Text())
		}
		m[strings.TrimSpace(s[0])] = strings.TrimSpace(s[1])
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return m, nil
}
// keys returns the keys of m in unspecified (map iteration) order. The result
// is a non-nil slice even when m is empty or nil.
func keys(m map[string]string) []string {
	names := make([]string, len(m))
	i := 0
	for name := range m {
		names[i] = name
		i++
	}
	return names
}
// hstsProxy wraps another handler and sets a one-year
// Strict-Transport-Security header (with includeSubDomains and preload) on
// every response before delegating. Because the header is set first, the
// wrapped handler can still replace it.
type hstsProxy struct {
	http.Handler
}

// ServeHTTP sets the HSTS header and then invokes the wrapped handler.
func (h *hstsProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const hstsValue = "max-age=31536000; includeSubDomains; preload"
	hdr := w.Header()
	hdr.Set("Strict-Transport-Security", hstsValue)
	h.Handler.ServeHTTP(w, r)
}
// bufferPool holds reusable 32 KiB scratch buffers. They are stored as
// *[]byte so that the pooled values are pointers.
var bufferPool = &sync.Pool{
	New: func() interface{} {
		scratch := make([]byte, 32*1024)
		return &scratch
	},
}

// bufPool adapts bufferPool to the Get/Put []byte interface expected by
// httputil.ReverseProxy's BufferPool field.
type bufPool struct{}

// Get returns a buffer from bufferPool, allocating a new 32 KiB one if the
// pool is empty.
func (bufPool) Get() []byte {
	p := bufferPool.Get().(*[]byte)
	return *p
}

// Put returns b to bufferPool for reuse.
func (bufPool) Put(b []byte) {
	bufferPool.Put(&b)
}
// newSingleHostReverseProxy returns a ReverseProxy that sends every request to
// target. It mirrors httputil.NewSingleHostReverseProxy:
//
//   - the request path is appended to target's path, keeping any escaped form
//     (RawPath) consistent;
//   - target's query is prepended to the request's;
//   - an absent User-Agent header is set to "" so the transport sends none.
//
// It additionally sets "X-Forwarded-Proto: https".
func newSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy {
	targetQuery := target.RawQuery
	director := func(req *http.Request) {
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		// Updating Path alone would leave a stale RawPath, and escaped
		// characters such as %2F would then be re-encoded as plain slashes.
		req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			// explicitly disable User-Agent so it's not set to default value
			req.Header.Set("User-Agent", "")
		}
		req.Header.Set("X-Forwarded-Proto", "https")
	}
	return &httputil.ReverseProxy{Director: director}
}

// joinURLPath joins the paths of a and b with a single slash at the seam and
// returns both the decoded path and the matching escaped path. rawpath is ""
// when neither URL carries an explicit RawPath, matching url.URL's convention
// for paths whose default encoding is canonical.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	apath, bpath := a.EscapedPath(), b.EscapedPath()
	aslash := strings.HasSuffix(apath, "/")
	bslash := strings.HasPrefix(bpath, "/")
	switch {
	case aslash && bslash:
		path, rawpath = a.Path+b.Path[1:], apath+bpath[1:]
	case !aslash && !bslash:
		path, rawpath = a.Path+"/"+b.Path, apath+"/"+bpath
	default:
		path, rawpath = a.Path+b.Path, apath+bpath
	}
	if a.RawPath == "" && b.RawPath == "" {
		rawpath = ""
	}
	return path, rawpath
}
// singleJoiningSlash concatenates a and b, normalising the seam between them.
// If a ends with "/" and b starts with "/", b's leading slash is dropped. If
// neither has a slash there, one is inserted. Otherwise they are joined as-is.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + strings.TrimPrefix(b, "/")
	}
	if !trailing && !leading {
		return a + "/" + b
	}
	return a + b
}
// tcpKeepAliveListener wraps a *net.TCPListener. It enables TCP keep-alive,
// with a 3 minute period, on every accepted connection, so dead peers
// (e.g. a laptop closed mid-download) are eventually dropped. When d is
// non-zero, each connection is also wrapped in timeoutConn so that it expires
// after d without a successful read or write. run serves HTTPS through this
// listener when the idle timeout is in effect.
type tcpKeepAliveListener struct {
	d time.Duration
	*net.TCPListener
}

// Accept waits for the next TCP connection and configures it as described on
// tcpKeepAliveListener.
func (l tcpKeepAliveListener) Accept() (net.Conn, error) {
	conn, err := l.AcceptTCP()
	if err != nil {
		return nil, err
	}
	// Keep-alive is best effort; failing to enable it does not reject the
	// connection.
	_ = conn.SetKeepAlive(true)
	_ = conn.SetKeepAlivePeriod(3 * time.Minute)
	if l.d != 0 {
		return timeoutConn{d: l.d, TCPConn: conn}, nil
	}
	return conn, nil
}
// timeoutConn wraps a *net.TCPConn. After every successful Read or Write it
// pushes both the read and write deadlines to d from now. The wrapper itself
// sets no deadline before the first successful operation. After that, the
// connection times out once it goes d without successful I/O, unless
// something else resets its deadlines.
type timeoutConn struct {
	d time.Duration
	*net.TCPConn
}

// Read reads from the underlying connection and, only on success, extends the
// deadline.
func (c timeoutConn) Read(p []byte) (int, error) {
	n, err := c.TCPConn.Read(p)
	if err != nil {
		return n, err
	}
	// Deliberately ignored: the read itself succeeded.
	_ = c.TCPConn.SetDeadline(time.Now().Add(c.d))
	return n, nil
}

// Write writes to the underlying connection and, only on success, extends the
// deadline.
func (c timeoutConn) Write(p []byte) (int, error) {
	n, err := c.TCPConn.Write(p)
	if err != nil {
		return n, err
	}
	// Deliberately ignored: the write itself succeeded.
	_ = c.TCPConn.SetDeadline(time.Now().Add(c.d))
	return n, nil
}
// GoVanity registers on mux a handler, for every path under hostname hn, that
// serves a small HTML page for Go "vanity" import paths.
//
// # Parameters
//
//   - hn (string): the hostname the handler is registered for. It is also the
//     import prefix written into the go-import meta tag.
//
//   - ba (string): the backend address, in the form "git+<repository-URL>".
//
//   - mux (*http.ServeMux): the multiplexer the handler is registered on.
//
// # Expected behaviour
//
//   - Strips the "git+" prefix from ba to get the repository URL. If ba does
//     not start with "git+", or nothing follows the prefix, an error is logged
//     and no handler is registered.
//
//   - The page carries a go-import meta tag ("<hn> git <repo>") for the go
//     tool. It also has a meta refresh that sends browsers to the repository
//     after 3 seconds, and a plain link to it.
//
//   - Every response sets permissive CORS headers and
//     "Strict-Transport-Security: max-age=0; includeSubDomains".
func GoVanity(hn, ba string, mux *http.ServeMux) {
	// TrimPrefix (rather than Split) keeps repository URLs that themselves
	// contain "git+" intact.
	repo := strings.TrimPrefix(ba, "git+")
	if repo == ba || repo == "" {
		// Ln does not interpret format verbs; F does.
		log.E.F("invalid go vanity redirect: %s: %s", hn, ba)
		return
	}
	redirector := fmt.Sprintf(
		`<html><head><meta name="go-import" content="%s git %s"/><meta http-equiv = "refresh" content = " 3 ; url = %s"/></head><body>redirecting to <a href="%s">%s</a></body></html>`,
		hn, repo, repo, repo, repo,
	)
	mux.HandleFunc(
		hn+"/",
		func(writer http.ResponseWriter, request *http.Request) {
			writer.Header().Set(
				"Access-Control-Allow-Methods",
				"GET,HEAD,PUT,PATCH,POST,DELETE",
			)
			writer.Header().Set("Access-Control-Allow-Origin", "*")
			writer.Header().Set("Content-Type", "text/html")
			writer.Header().Set(
				"Content-Length", fmt.Sprint(len(redirector)),
			)
			// Header.Set canonicalises the key, so this replaces any
			// Strict-Transport-Security value set earlier in the chain.
			writer.Header().Set(
				"strict-transport-security",
				"max-age=0; includeSubDomains",
			)
			fmt.Fprint(writer, redirector)
		},
	)
}
// NostrJSON is the document served at "<hostname>/.well-known/nostr.json" by
// NostrDNS. The source file is decoded into this struct and re-encoded, so
// only these two fields are served. Their declaration order is the order they
// appear in the output JSON.
type NostrJSON struct {
	// Names maps a name to a key string. Presumably a hex public key; verify
	// against the nostr.json files in use.
	Names map[string]string `json:"names"`
	// Relays maps a key (presumably the values used in Names; TODO confirm)
	// to a list of relay URLs.
	Relays map[string][]string `json:"relays"`
}
// NostrDNS loads the nostr.json document at path ba and registers a handler
// for it on mux under "<hn>/.well-known/nostr.json".
//
// # Parameters
//
//   - hn (string): hostname the endpoint is registered for.
//
//   - ba (string): path of the JSON file to serve.
//
//   - mux (*http.ServeMux): multiplexer the handler is registered on.
//
// # Return Values
//
//   - err (error): the error from reading, decoding or re-encoding the file
//     (each is also logged via chk.E). No handler is registered in that case.
//
// # Expected behaviour
//
//   - The file is read once, when NostrDNS is called. It is decoded into
//     NostrJSON and re-encoded, so unknown fields are dropped.
//
//   - Every response logs a trace message and carries permissive CORS headers,
//     "Content-Type: application/json", an explicit Content-Length and
//     "Strict-Transport-Security: max-age=0; includeSubDomains".
func NostrDNS(hn, ba string, mux *http.ServeMux) (err error) {
	log.T.Ln(hn, ba)
	raw, err := os.ReadFile(ba)
	if chk.E(err) {
		return err
	}
	var doc NostrJSON
	if err = json.Unmarshal(raw, &doc); chk.E(err) {
		return err
	}
	normalized, err := json.Marshal(doc)
	if chk.E(err) {
		return err
	}
	body := string(normalized)
	handler := func(w http.ResponseWriter, _ *http.Request) {
		log.T.Ln("serving nostr json to", hn)
		h := w.Header()
		h.Set("Access-Control-Allow-Methods", "GET,HEAD,PUT,PATCH,POST,DELETE")
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Content-Type", "application/json")
		h.Set("Content-Length", fmt.Sprint(len(body)))
		// Header.Set canonicalises the key, so this replaces any
		// Strict-Transport-Security value set earlier in the chain.
		h.Set("strict-transport-security", "max-age=0; includeSubDomains")
		fmt.Fprint(w, body)
	}
	mux.HandleFunc(hn+"/.well-known/nostr.json", handler)
	return nil
}