made all local node accessors concurrent safe

This commit is contained in:
херетик
2023-01-31 17:01:15 +00:00
parent 265cb6e7c2
commit 61537f1aa0
11 changed files with 88 additions and 20 deletions

View File

@@ -76,7 +76,7 @@ func (eng *Engine) Shutdown() {
return
}
log.T.C(func() string {
return "shutting down client " + eng.GetLocalNode().AddrPort.String()
return "shutting down client " + eng.GetLocalNodeAddress().String()
})
eng.ShuttingDown.Store(true)
eng.C.Q()

View File

@@ -115,7 +115,7 @@ func TestClient_SendExit(t *testing.T) {
if i == 0 {
continue
}
_ = clients[i].GetLocalNode().AddService(&service.Service{
_ = clients[i].AddServiceToLocalNode(&service.Service{
Port: port,
Transport: sim,
RelayRate: 18000 * 4,
@@ -165,9 +165,9 @@ out:
log.I.F("success\n\n")
wg.Done()
})
bb := <-clients[3].GetLocalNode().Services[0].Receive()
bb := <-clients[3].ReceiveToLocalNode(port)
log.T.S(bb.ToBytes())
if e = clients[3].GetLocalNode().SendTo(port, respMsg); check(e) {
if e = clients[3].SendFromLocalNode(port, respMsg); check(e) {
t.Error("fail send")
}
log.T.Ln("response sent")

View File

@@ -66,7 +66,7 @@ func (eng *Engine) crypt(on *crypt.Layer, b slice.Bytes,
// This is a little more complicated as we need to decrement the
// amount before sending out the balance.
eng.DecSession(sess.ID,
(eng.GetLocalNode().RelayRate*lnwire.MilliSatoshi(len(b)+oo.
(eng.GetLocalNodeRelayRate()*lnwire.MilliSatoshi(len(b)+oo.
Len())/2)/1024/1024,
false, "directbalance")
o[2].(*balance.Layer).MilliSatoshi = sess.Remaining

View File

@@ -24,12 +24,12 @@ func (eng *Engine) exit(ex *exit.Layer, b slice.Bytes,
h := sha256.Single(ex.Bytes)
log.T.S(h)
log.T.F("received exit id %x", ex.ID)
if e = eng.GetLocalNode().SendTo(ex.Port, ex.Bytes); check(e) {
if e = eng.SendFromLocalNode(ex.Port, ex.Bytes); check(e) {
return
}
timer := time.NewTicker(time.Second * 5) // todo: timeout/retries etc
select {
case result = <-eng.GetLocalNode().ReceiveFrom(ex.Port):
case result = <-eng.ReceiveToLocalNode(ex.Port):
case <-timer.C:
}
// We need to wrap the result in a message crypt.

View File

@@ -14,7 +14,7 @@ func (eng *Engine) forward(on *forward.Layer, b slice.Bytes,
// forward the whole buffer received onwards. Usually there will be a
// crypt.Layer under this which will be unwrapped by the receiver.
if on.AddrPort.String() == eng.GetLocalNode().AddrPort.String() {
if on.AddrPort.String() == eng.GetLocalNodeAddress().String() {
// it is for us, we want to unwrap the next part.
eng.handleMessage(BudgeUp(b, *c), on)
} else {
@@ -23,7 +23,7 @@ func (eng *Engine) forward(on *forward.Layer, b slice.Bytes,
sess := eng.FindSessionByHeader(on1.ToPriv)
if sess != nil {
eng.DecSession(sess.ID,
eng.GetLocalNode().RelayRate*lnwire.MilliSatoshi(len(b))/1024/1024,
eng.GetLocalNodeRelayRate()*lnwire.MilliSatoshi(len(b))/1024/1024,
false, "forward")
}
}

View File

@@ -16,7 +16,7 @@ func (eng *Engine) reverse(on *reverse.Layer, b slice.Bytes,
var e error
var on2 types.Onion
if on.AddrPort.String() == eng.GetLocalNode().AddrPort.String() {
if on.AddrPort.String() == eng.GetLocalNodeAddress().String() {
if on2, e = onion.Peel(b, c); check(e) {
return
}
@@ -30,7 +30,7 @@ func (eng *Engine) reverse(on *reverse.Layer, b slice.Bytes,
hdr, pld, _, _ := eng.FindCloaked(on1.Cloak)
if hdr == nil || pld == nil {
log.E.F("failed to find key for %s",
eng.GetLocalNode().AddrPort.String())
eng.GetLocalNodeAddress().String())
return
}
// We need to find the PayloadPub to match.
@@ -55,7 +55,7 @@ func (eng *Engine) reverse(on *reverse.Layer, b slice.Bytes,
sess := eng.FindSessionByHeader(hdr)
if sess != nil {
eng.DecSession(sess.ID,
eng.GetLocalNode().RelayRate*lnwire.
eng.GetLocalNodeRelayRate()*lnwire.
MilliSatoshi(len(b))/1024/1024, false, "reverse")
eng.handleMessage(BudgeUp(b, start), on1)
}

View File

@@ -24,7 +24,7 @@ import (
func (eng *Engine) handler() (out bool) {
log.T.C(func() string {
return eng.GetLocalNode().AddrPort.String() +
return eng.GetLocalNodeAddress().String() +
" awaiting message"
})
var prev types.Onion
@@ -33,7 +33,7 @@ func (eng *Engine) handler() (out bool) {
eng.Cleanup()
out = true
break
case b := <-eng.GetLocalNode().Receive():
case b := <-eng.ReceiveToLocalNode(0):
eng.handleMessage(b, prev)
case p := <-eng.PaymentChan:
log.D.F("incoming payment for %x: %v", p.ID, p.Amount)
@@ -107,7 +107,7 @@ func (eng *Engine) handleMessage(b slice.Bytes, prev types.Onion) {
func recLog(on types.Onion, b slice.Bytes, cl *Engine) func() string {
return func() string {
return cl.GetLocalNode().AddrPort.String() +
return cl.GetLocalNodeAddress().String() +
" received " +
fmt.Sprint(reflect.TypeOf(on)) + "\n" +
spew.Sdump(b.ToBytes())

View File

@@ -14,10 +14,10 @@ func (eng *Engine) FindCloaked(clk cloak.PubKey) (hdr *prv.Key,
var b cloak.Blinder
copy(b[:], clk[:cloak.BlindLen])
hash := cloak.Cloak(b, eng.GetLocalNode().IdentityBytes)
hash := cloak.Cloak(b, eng.GetLocalNodeIdentityBytes())
if hash == clk {
log.T.F("encrypted to identity key")
hdr = eng.GetLocalNode().IdentityPrv
hdr = eng.GetLocalNodeIdentityPrv()
// there is no payload key for the node, only in sessions.
identity = true
return

View File

@@ -17,7 +17,7 @@ func (eng *Engine) Send(addr *netip.AddrPort, b slice.Bytes) {
eng.ForEachNode(func(n *traffic.Node) bool {
if as == n.AddrPort.String() {
log.T.C(func() string {
return eng.GetLocalNode().AddrPort.String() +
return eng.GetLocalNodeAddress().String() +
" sending to " +
addr.String() +
"\n" +

View File

@@ -36,7 +36,7 @@ func CreateNMockCircuits(inclSessions bool, nCircuits int) (cl []*Engine, e erro
if cl[i], e = NewEngine(transports[i], idPrv, nodes[i], nil); check(e) {
return
}
cl[i].GetLocalNode().AddrPort = nodes[i].AddrPort
cl[i].SetLocalNodeAddress(nodes[i].AddrPort)
cl[i].SetLocalNode(nodes[i])
if inclSessions {
// create a session for all but the first

View File

@@ -4,7 +4,13 @@ import (
"fmt"
"net/netip"
"git-indra.lan/indra-labs/lnd/lnd/lnwire"
"git-indra.lan/indra-labs/indra/pkg/crypto/key/prv"
"git-indra.lan/indra-labs/indra/pkg/crypto/key/pub"
"git-indra.lan/indra-labs/indra/pkg/crypto/nonce"
"git-indra.lan/indra-labs/indra/pkg/service"
"git-indra.lan/indra-labs/indra/pkg/util/slice"
)
// NodesLen returns the length of a Nodes.
@@ -17,8 +23,67 @@ func (sm *SessionManager) NodesLen() int {
// GetLocalNode returns the engine's local Node.
func (sm *SessionManager) GetLocalNode() *Node { return sm.nodes[0] }
// Concurrent safe accessors.
func (sm *SessionManager) GetLocalNodeAddress() (addr *netip.AddrPort) {
sm.Lock()
defer sm.Unlock()
return sm.GetLocalNode().AddrPort
}
func (sm *SessionManager) SetLocalNodeAddress(addr *netip.AddrPort) {
sm.Lock()
defer sm.Unlock()
sm.GetLocalNode().AddrPort = addr
}
func (sm *SessionManager) SendFromLocalNode(port uint16,
b slice.Bytes) (e error) {
sm.Lock()
defer sm.Unlock()
return sm.GetLocalNode().SendTo(port, b)
}
func (sm *SessionManager) ReceiveToLocalNode(port uint16) <-chan slice.Bytes {
sm.Lock()
defer sm.Unlock()
if port == 0 {
return sm.GetLocalNode().Receive()
}
return sm.GetLocalNode().ReceiveFrom(port)
}
func (sm *SessionManager) AddServiceToLocalNode(s *service.Service) (e error) {
sm.Lock()
defer sm.Unlock()
return sm.GetLocalNode().AddService(s)
}
func (sm *SessionManager) GetLocalNodeRelayRate() (rate lnwire.MilliSatoshi) {
sm.Lock()
defer sm.Unlock()
return sm.GetLocalNode().RelayRate
}
func (sm *SessionManager) GetLocalNodeIdentityBytes() (ident pub.Bytes) {
sm.Lock()
defer sm.Unlock()
return sm.GetLocalNode().IdentityBytes
}
func (sm *SessionManager) GetLocalNodeIdentityPrv() (ident *prv.Key) {
sm.Lock()
defer sm.Unlock()
return sm.GetLocalNode().IdentityPrv
}
// SetLocalNode returns the engine's local Node.
func (sm *SessionManager) SetLocalNode(n *Node) { sm.nodes[0] = n }
func (sm *SessionManager) SetLocalNode(n *Node) {
sm.Lock()
defer sm.Unlock()
sm.nodes[0] = n
}
// AddNodes adds a Node to a Nodes.
func (sm *SessionManager) AddNodes(nn ...*Node) {
@@ -84,6 +149,9 @@ func (sm *SessionManager) DeleteNodeByAddrPort(ip *netip.AddrPort) (e error) {
// ForEachNode runs a function over the slice of nodes with the mutex locked,
// and terminates when the function returns true.
//
// Do not call any SessionManager methods above inside this function or there
// will be a mutex double locking panic, except GetLocalNode.
func (sm *SessionManager) ForEachNode(fn func(n *Node) bool) {
sm.Lock()
defer sm.Unlock()