refactoring rpc

This commit is contained in:
greg stone
2023-03-04 01:14:00 +00:00
parent 678a433ae4
commit 18ca0e492a
10 changed files with 131 additions and 105 deletions

View File

@@ -17,7 +17,7 @@ func configureDevice() {
var err error var err error
dev.SetPrivateKey(tunKey.AsDeviceKey()) dev.SetPrivateKey(tunKey.AsDeviceKey())
dev.IpcSet("listen_port=" + strconv.Itoa(int(tunnelPort))) dev.IpcSet("listen_port=" + strconv.Itoa(int(o.tunPort)))
for _, peer_whitelist := range tunWhitelist { for _, peer_whitelist := range tunWhitelist {

View File

@@ -3,39 +3,38 @@ package rpc
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"os"
) )
var ( var (
unixPathFlag = "rpc-unix-listen" UnixPathFlag = "rpc-unix-listen"
tunEnableFlag = "rpc-tun-enable" TunEnableFlag = "rpc-tun-enable"
tunKeyFlag = "rpc-tun-key" tunKeyFlag = "rpc-tun-key"
tunPortFlag = "rpc-tun-port" TunPortFlag = "rpc-tun-port"
tunPeersFlag = "rpc-tun-peer" TunPeersFlag = "rpc-tun-peer"
) )
var ( var (
unixPath string
tunEnabled bool = false
tunKeyRaw string tunKeyRaw string
tunPeersRaw = []string{} tunPort int = 0
tunPeersRaw = []string{}
) )
func InitFlags(cmd *cobra.Command) { func InitFlags(cmd *cobra.Command) {
cobra.OnInitialize(initUnixSockPath) cmd.PersistentFlags().StringVarP(&unixPath, UnixPathFlag, "",
cmd.PersistentFlags().StringVarP(&unixPath, unixPathFlag, "",
"", "",
"binds to a unix socket with path (default is $HOME/.indra/indra.sock)", "binds to a unix socket with path (default is $HOME/.indra/indra.sock)",
) )
viper.BindPFlag(unixPathFlag, cmd.PersistentFlags().Lookup(unixPathFlag)) viper.BindPFlag(UnixPathFlag, cmd.PersistentFlags().Lookup(UnixPathFlag))
cmd.PersistentFlags().BoolVarP(&isTunnelEnabled, tunEnableFlag, "", cmd.PersistentFlags().BoolVarP(&tunEnabled, TunEnableFlag, "",
false, false,
"enables the rpc server tunnel (default false)", "enables the rpc server tunnel (default false)",
) )
viper.BindPFlag(tunEnableFlag, cmd.PersistentFlags().Lookup(tunEnableFlag)) viper.BindPFlag(TunEnableFlag, cmd.PersistentFlags().Lookup(TunEnableFlag))
//cmd.Flags().StringVarP(&tunKeyRaw, tunKeyFlag, "", //cmd.Flags().StringVarP(&tunKeyRaw, tunKeyFlag, "",
// "", // "",
@@ -44,30 +43,17 @@ func InitFlags(cmd *cobra.Command) {
// //
//viper.BindPFlag(tunKeyFlag, cmd.Flags().Lookup(tunKeyFlag)) //viper.BindPFlag(tunKeyFlag, cmd.Flags().Lookup(tunKeyFlag))
cmd.PersistentFlags().IntVarP(&tunnelPort, tunPortFlag, "", cmd.PersistentFlags().IntVarP(&tunPort, TunPortFlag, "",
tunnelPort, tunPort,
"binds the udp server to port (random if not selected)", "binds the udp server to port (random if not selected)",
) )
viper.BindPFlag(tunPortFlag, cmd.PersistentFlags().Lookup(tunPortFlag)) viper.BindPFlag(TunPortFlag, cmd.PersistentFlags().Lookup(TunPortFlag))
cmd.PersistentFlags().StringSliceVarP(&tunPeersRaw, tunPeersFlag, "", cmd.PersistentFlags().StringSliceVarP(&tunPeersRaw, TunPeersFlag, "",
tunPeersRaw, tunPeersRaw,
"adds a peer id to the whitelist for access", "adds a peer id to the whitelist for access",
) )
viper.BindPFlag(tunPeersFlag, cmd.PersistentFlags().Lookup(tunPeersFlag)) viper.BindPFlag(TunPeersFlag, cmd.PersistentFlags().Lookup(TunPeersFlag))
}
func initUnixSockPath() {
if viper.GetString(unixPathFlag) != "" {
return
}
home, err := os.UserHomeDir()
cobra.CheckErr(err)
viper.Set(unixPathFlag, home+"/.indra/indra.sock")
} }

View File

@@ -7,7 +7,7 @@ import (
var ( var (
server *grpc.Server server *grpc.Server
o *serverOptions o *ServerOptions
) )
var ( var (
@@ -24,22 +24,41 @@ func RunWith(r func(srv *grpc.Server), opts ...ServerOption) {
log.I.Ln("initializing the rpc server") log.I.Ln("initializing the rpc server")
o = &serverOptions{false, &storeMem{}} o = &ServerOptions{
&storeMem{},
unixPathDefault,
false,
NullPort,
[]string{},
}
for _, opt := range opts { for _, opt := range opts {
opt.apply(o) opt.apply(o)
} }
if o.unixPath != "" {
log.I.Ln("enabling rpc unix listener:")
log.I.F("- [/unix%s]", o.unixPath)
isUnixSockEnabled = true
}
if o.tunEnable {
configureTunnel()
}
isConfigured <- true
server = grpc.NewServer() server = grpc.NewServer()
configureUnixSocket()
configureTunnel()
r(server) r(server)
go start() go start()
} }
func Options() *ServerOptions {
return o
}
func start() { func start() {
log.I.Ln("starting rpc server") log.I.Ln("starting rpc server")

View File

@@ -1,36 +1,55 @@
package rpc package rpc
type serverOptions struct { type ServerOptions struct {
disableTunnel bool store Store
store Store unixPath string
tunEnable bool
tunPort uint16
tunPeers []string
} }
func (s *ServerOptions) GetTunPort() uint16 { return s.tunPort }
type ServerOption interface { type ServerOption interface {
apply(*serverOptions) apply(*ServerOptions)
} }
type funcServerOption struct { type funcServerOption struct {
f func(*serverOptions) f func(*ServerOptions)
} }
func (fdo *funcServerOption) apply(do *serverOptions) { func (fdo *funcServerOption) apply(do *ServerOptions) {
fdo.f(do) fdo.f(do)
} }
func newFuncServerOption(f func(*serverOptions)) *funcServerOption { func newFuncServerOption(f func(*ServerOptions)) *funcServerOption {
return &funcServerOption{ return &funcServerOption{
f: f, f: f,
} }
} }
func WithDisableTunnel() ServerOption { func WithDisableTunnel() ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *ServerOptions) {
o.disableTunnel = true o.tunEnable = false
}) })
} }
func WithStore(store Store) ServerOption { func WithStore(store Store) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *ServerOptions) {
o.store = store o.store = store
}) })
} }
func WithUnixPath(path string) ServerOption {
return newFuncServerOption(func(o *ServerOptions) {
o.unixPath = path
})
}
func WithTunOptions(port uint16, peers []string) ServerOption {
return newFuncServerOption(func(o *ServerOptions) {
o.tunEnable = true
o.tunPort = port
o.tunPeers = peers
})
}

View File

@@ -2,6 +2,7 @@ package rpc
var ( var (
startupErrors = make(chan error, 128) startupErrors = make(chan error, 128)
isConfigured = make(chan bool, 1)
isReady = make(chan bool, 1) isReady = make(chan bool, 1)
) )
@@ -9,6 +10,10 @@ func WhenStartFailed() chan error {
return startupErrors return startupErrors
} }
func IsConfigured() chan bool {
return isConfigured
}
func IsReady() chan bool { func IsReady() chan bool {
return isReady return isReady
} }

View File

@@ -6,10 +6,11 @@ import (
"os" "os"
) )
const unixPathDefault = "/tmp/indra.sock"
var ( var (
isUnixSockEnabled bool = false isUnixSockEnabled bool = false
unixSock net.Listener unixSock net.Listener
unixPath string
) )
func startUnixSocket(srv *grpc.Server) (err error) { func startUnixSocket(srv *grpc.Server) (err error) {
@@ -18,7 +19,7 @@ func startUnixSocket(srv *grpc.Server) (err error) {
return return
} }
if unixSock, err = net.Listen("unix", unixPath); err != nil { if unixSock, err = net.Listen("unix", o.unixPath); err != nil {
return return
} }
@@ -39,7 +40,7 @@ func stopUnixSocket() (err error) {
} }
} }
os.Remove(unixPath) os.Remove(o.unixPath)
return return
} }

View File

@@ -12,10 +12,6 @@ import (
const NullPort = 0 const NullPort = 0
var (
isTunnelEnabled bool = false
)
var ( var (
network *netstack.Net network *netstack.Net
tunnel tun.Device tunnel tun.Device
@@ -25,7 +21,6 @@ var (
var ( var (
tunKey *RPCPrivateKey tunKey *RPCPrivateKey
tunWhitelist []RPCPublicKey tunWhitelist []RPCPublicKey
tunnelPort int = 0
tunnelMTU int = 1420 tunnelMTU int = 1420
) )
@@ -43,7 +38,7 @@ func createTunnel() {
func startTunnel(srv *grpc.Server) (err error) { func startTunnel(srv *grpc.Server) (err error) {
if !isTunnelEnabled { if !o.tunEnable {
return return
} }
@@ -66,7 +61,7 @@ func startTunnel(srv *grpc.Server) (err error) {
func stopTunnel() (err error) { func stopTunnel() (err error) {
if !isTunnelEnabled { if !o.tunEnable {
return return
} }

View File

@@ -1,28 +1,8 @@
package rpc package rpc
import (
"github.com/spf13/viper"
)
func configureUnixSocket() {
unixPath = viper.GetString(unixPathFlag)
if unixPath == "" {
return
}
log.I.Ln("enabling rpc unix listener:")
log.I.F("- [/unix%s]", unixPath)
isUnixSockEnabled = true
}
func configureTunnel() { func configureTunnel() {
isTunnelEnabled = viper.GetBool(tunEnableFlag) if !o.tunEnable {
if !isTunnelEnabled {
return return
} }
@@ -31,7 +11,7 @@ func configureTunnel() {
configureTunnelPort() configureTunnelPort()
log.I.Ln("rpc tunnel listeners:") log.I.Ln("rpc tunnel listeners:")
log.I.F("- [/ip4/0.0.0.0/udp/%d /ip6/:::/udp/%d]", viper.GetUint16(tunPortFlag), viper.GetUint16(tunPortFlag)) log.I.F("- [/ip4/0.0.0.0/udp/%d /ip6/:::/udp/%d]", o.tunPort, o.tunPort)
configureTunnelKey() configureTunnelKey()
configurePeerWhitelist() configurePeerWhitelist()
@@ -69,29 +49,24 @@ func configureTunnelKey() {
func configureTunnelPort() { func configureTunnelPort() {
if viper.GetUint16(tunPortFlag) != NullPort { if o.tunPort != NullPort {
tunnelPort = int(viper.GetUint16(tunPortFlag))
return return
} }
log.I.Ln("rpc tunnel port not provided, generating a random one.") log.I.Ln("rpc tunnel port not provided, generating a random one.")
viper.Set(tunPortFlag, genRandomPort(10000)) o.tunPort = genRandomPort(10000)
tunnelPort = int(viper.GetUint16(tunPortFlag))
} }
func configurePeerWhitelist() { func configurePeerWhitelist() {
if len(viper.GetStringSlice(tunPeersFlag)) == 0 { if len(o.tunPeers) == 0 {
return return
} }
log.I.Ln("rpc tunnel whitelisted peers:") log.I.Ln("rpc tunnel whitelisted peers:")
for _, peer := range viper.GetStringSlice(tunPeersFlag) { for _, peer := range o.tunPeers {
var pubKey RPCPublicKey var pubKey RPCPublicKey

View File

@@ -5,6 +5,7 @@ import (
"git-indra.lan/indra-labs/indra/pkg/p2p" "git-indra.lan/indra-labs/indra/pkg/p2p"
"git-indra.lan/indra-labs/indra/pkg/rpc" "git-indra.lan/indra-labs/indra/pkg/rpc"
"git-indra.lan/indra-labs/indra/pkg/storage" "git-indra.lan/indra-labs/indra/pkg/storage"
"github.com/spf13/viper"
"github.com/tutorialedge/go-grpc-tutorial/chat" "github.com/tutorialedge/go-grpc-tutorial/chat"
"google.golang.org/grpc" "google.golang.org/grpc"
"sync" "sync"
@@ -63,22 +64,44 @@ func Run(ctx context.Context) {
// RPC // RPC
// //
go rpc.RunWith(func(srv *grpc.Server) { opts := []rpc.ServerOption{
chat.RegisterChatServiceServer(srv, &chat.Server{}) rpc.WithUnixPath(
}, viper.GetString(rpc.UnixPathFlag),
rpc.WithStore(&rpc.BadgerStore{storage.DB()}), ),
) rpc.WithStore(
&rpc.BadgerStore{storage.DB()},
),
}
select { if viper.GetBool(rpc.TunEnableFlag) {
case err := <-rpc.WhenStartFailed(): opts = append(opts,
log.E.Ln("rpc can't start:", err) rpc.WithTunOptions(
startupErrors <- err viper.GetUint16(rpc.TunPortFlag),
return viper.GetStringSlice(rpc.TunPeersFlag),
case <-rpc.IsReady(): ))
// continue }
case <-ctx.Done():
Shutdown() services := func(srv *grpc.Server) {
return chat.RegisterChatServiceServer(srv, &chat.Server{})
}
go rpc.RunWith(services, opts...)
for {
select {
case <-rpc.IsConfigured():
// We need to get the randomly generated port
viper.Set(rpc.TunPortFlag, rpc.Options().GetTunPort())
case err := <-rpc.WhenStartFailed():
log.E.Ln("rpc can't start:", err)
startupErrors <- err
return
case <-rpc.IsReady():
// continue
case <-ctx.Done():
Shutdown()
return
}
} }
// //

View File

@@ -3,6 +3,7 @@ package storage
import ( import (
"git-indra.lan/indra-labs/indra/pkg/rpc" "git-indra.lan/indra-labs/indra/pkg/rpc"
"github.com/dgraph-io/badger/v3" "github.com/dgraph-io/badger/v3"
"github.com/spf13/viper"
"google.golang.org/grpc" "google.golang.org/grpc"
"sync" "sync"
) )
@@ -65,7 +66,7 @@ signals:
func(srv *grpc.Server) { func(srv *grpc.Server) {
RegisterUnlockServiceServer(srv, NewUnlockService()) RegisterUnlockServiceServer(srv, NewUnlockService())
}, },
rpc.WithDisableTunnel(), rpc.WithUnixPath(viper.GetString(rpc.UnixPathFlag)),
) )
case <-rpc.IsReady(): case <-rpc.IsReady():
log.I.Ln("... awaiting unlock over rpc") log.I.Ln("... awaiting unlock over rpc")
@@ -92,6 +93,8 @@ func Shutdown() (err error) {
log.I.Ln("- storage db closing, it may take a minute...") log.I.Ln("- storage db closing, it may take a minute...")
db.RunValueLogGC(0.5)
if err = db.Close(); err != nil { if err = db.Close(); err != nil {
log.W.Ln("- storage shutdown warning: ", err) log.W.Ln("- storage shutdown warning: ", err)
} }