made all local node accessors concurrent safe
This commit is contained in:
@@ -76,7 +76,7 @@ func (eng *Engine) Shutdown() {
|
||||
return
|
||||
}
|
||||
log.T.C(func() string {
|
||||
return "shutting down client " + eng.GetLocalNode().AddrPort.String()
|
||||
return "shutting down client " + eng.GetLocalNodeAddress().String()
|
||||
})
|
||||
eng.ShuttingDown.Store(true)
|
||||
eng.C.Q()
|
||||
|
||||
@@ -115,7 +115,7 @@ func TestClient_SendExit(t *testing.T) {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
_ = clients[i].GetLocalNode().AddService(&service.Service{
|
||||
_ = clients[i].AddServiceToLocalNode(&service.Service{
|
||||
Port: port,
|
||||
Transport: sim,
|
||||
RelayRate: 18000 * 4,
|
||||
@@ -165,9 +165,9 @@ out:
|
||||
log.I.F("success\n\n")
|
||||
wg.Done()
|
||||
})
|
||||
bb := <-clients[3].GetLocalNode().Services[0].Receive()
|
||||
bb := <-clients[3].ReceiveToLocalNode(port)
|
||||
log.T.S(bb.ToBytes())
|
||||
if e = clients[3].GetLocalNode().SendTo(port, respMsg); check(e) {
|
||||
if e = clients[3].SendFromLocalNode(port, respMsg); check(e) {
|
||||
t.Error("fail send")
|
||||
}
|
||||
log.T.Ln("response sent")
|
||||
|
||||
@@ -66,7 +66,7 @@ func (eng *Engine) crypt(on *crypt.Layer, b slice.Bytes,
|
||||
// This is a little more complicated as we need to decrement the
|
||||
// amount before sending out the balance.
|
||||
eng.DecSession(sess.ID,
|
||||
(eng.GetLocalNode().RelayRate*lnwire.MilliSatoshi(len(b)+oo.
|
||||
(eng.GetLocalNodeRelayRate()*lnwire.MilliSatoshi(len(b)+oo.
|
||||
Len())/2)/1024/1024,
|
||||
false, "directbalance")
|
||||
o[2].(*balance.Layer).MilliSatoshi = sess.Remaining
|
||||
|
||||
@@ -24,12 +24,12 @@ func (eng *Engine) exit(ex *exit.Layer, b slice.Bytes,
|
||||
h := sha256.Single(ex.Bytes)
|
||||
log.T.S(h)
|
||||
log.T.F("received exit id %x", ex.ID)
|
||||
if e = eng.GetLocalNode().SendTo(ex.Port, ex.Bytes); check(e) {
|
||||
if e = eng.SendFromLocalNode(ex.Port, ex.Bytes); check(e) {
|
||||
return
|
||||
}
|
||||
timer := time.NewTicker(time.Second * 5) // todo: timeout/retries etc
|
||||
select {
|
||||
case result = <-eng.GetLocalNode().ReceiveFrom(ex.Port):
|
||||
case result = <-eng.ReceiveToLocalNode(ex.Port):
|
||||
case <-timer.C:
|
||||
}
|
||||
// We need to wrap the result in a message crypt.
|
||||
|
||||
@@ -14,7 +14,7 @@ func (eng *Engine) forward(on *forward.Layer, b slice.Bytes,
|
||||
|
||||
// forward the whole buffer received onwards. Usually there will be a
|
||||
// crypt.Layer under this which will be unwrapped by the receiver.
|
||||
if on.AddrPort.String() == eng.GetLocalNode().AddrPort.String() {
|
||||
if on.AddrPort.String() == eng.GetLocalNodeAddress().String() {
|
||||
// it is for us, we want to unwrap the next part.
|
||||
eng.handleMessage(BudgeUp(b, *c), on)
|
||||
} else {
|
||||
@@ -23,7 +23,7 @@ func (eng *Engine) forward(on *forward.Layer, b slice.Bytes,
|
||||
sess := eng.FindSessionByHeader(on1.ToPriv)
|
||||
if sess != nil {
|
||||
eng.DecSession(sess.ID,
|
||||
eng.GetLocalNode().RelayRate*lnwire.MilliSatoshi(len(b))/1024/1024,
|
||||
eng.GetLocalNodeRelayRate()*lnwire.MilliSatoshi(len(b))/1024/1024,
|
||||
false, "forward")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func (eng *Engine) reverse(on *reverse.Layer, b slice.Bytes,
|
||||
|
||||
var e error
|
||||
var on2 types.Onion
|
||||
if on.AddrPort.String() == eng.GetLocalNode().AddrPort.String() {
|
||||
if on.AddrPort.String() == eng.GetLocalNodeAddress().String() {
|
||||
if on2, e = onion.Peel(b, c); check(e) {
|
||||
return
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (eng *Engine) reverse(on *reverse.Layer, b slice.Bytes,
|
||||
hdr, pld, _, _ := eng.FindCloaked(on1.Cloak)
|
||||
if hdr == nil || pld == nil {
|
||||
log.E.F("failed to find key for %s",
|
||||
eng.GetLocalNode().AddrPort.String())
|
||||
eng.GetLocalNodeAddress().String())
|
||||
return
|
||||
}
|
||||
// We need to find the PayloadPub to match.
|
||||
@@ -55,7 +55,7 @@ func (eng *Engine) reverse(on *reverse.Layer, b slice.Bytes,
|
||||
sess := eng.FindSessionByHeader(hdr)
|
||||
if sess != nil {
|
||||
eng.DecSession(sess.ID,
|
||||
eng.GetLocalNode().RelayRate*lnwire.
|
||||
eng.GetLocalNodeRelayRate()*lnwire.
|
||||
MilliSatoshi(len(b))/1024/1024, false, "reverse")
|
||||
eng.handleMessage(BudgeUp(b, start), on1)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
|
||||
func (eng *Engine) handler() (out bool) {
|
||||
log.T.C(func() string {
|
||||
return eng.GetLocalNode().AddrPort.String() +
|
||||
return eng.GetLocalNodeAddress().String() +
|
||||
" awaiting message"
|
||||
})
|
||||
var prev types.Onion
|
||||
@@ -33,7 +33,7 @@ func (eng *Engine) handler() (out bool) {
|
||||
eng.Cleanup()
|
||||
out = true
|
||||
break
|
||||
case b := <-eng.GetLocalNode().Receive():
|
||||
case b := <-eng.ReceiveToLocalNode(0):
|
||||
eng.handleMessage(b, prev)
|
||||
case p := <-eng.PaymentChan:
|
||||
log.D.F("incoming payment for %x: %v", p.ID, p.Amount)
|
||||
@@ -107,7 +107,7 @@ func (eng *Engine) handleMessage(b slice.Bytes, prev types.Onion) {
|
||||
|
||||
func recLog(on types.Onion, b slice.Bytes, cl *Engine) func() string {
|
||||
return func() string {
|
||||
return cl.GetLocalNode().AddrPort.String() +
|
||||
return cl.GetLocalNodeAddress().String() +
|
||||
" received " +
|
||||
fmt.Sprint(reflect.TypeOf(on)) + "\n" +
|
||||
spew.Sdump(b.ToBytes())
|
||||
|
||||
@@ -14,10 +14,10 @@ func (eng *Engine) FindCloaked(clk cloak.PubKey) (hdr *prv.Key,
|
||||
|
||||
var b cloak.Blinder
|
||||
copy(b[:], clk[:cloak.BlindLen])
|
||||
hash := cloak.Cloak(b, eng.GetLocalNode().IdentityBytes)
|
||||
hash := cloak.Cloak(b, eng.GetLocalNodeIdentityBytes())
|
||||
if hash == clk {
|
||||
log.T.F("encrypted to identity key")
|
||||
hdr = eng.GetLocalNode().IdentityPrv
|
||||
hdr = eng.GetLocalNodeIdentityPrv()
|
||||
// there is no payload key for the node, only in sessions.
|
||||
identity = true
|
||||
return
|
||||
|
||||
@@ -17,7 +17,7 @@ func (eng *Engine) Send(addr *netip.AddrPort, b slice.Bytes) {
|
||||
eng.ForEachNode(func(n *traffic.Node) bool {
|
||||
if as == n.AddrPort.String() {
|
||||
log.T.C(func() string {
|
||||
return eng.GetLocalNode().AddrPort.String() +
|
||||
return eng.GetLocalNodeAddress().String() +
|
||||
" sending to " +
|
||||
addr.String() +
|
||||
"\n" +
|
||||
|
||||
@@ -36,7 +36,7 @@ func CreateNMockCircuits(inclSessions bool, nCircuits int) (cl []*Engine, e erro
|
||||
if cl[i], e = NewEngine(transports[i], idPrv, nodes[i], nil); check(e) {
|
||||
return
|
||||
}
|
||||
cl[i].GetLocalNode().AddrPort = nodes[i].AddrPort
|
||||
cl[i].SetLocalNodeAddress(nodes[i].AddrPort)
|
||||
cl[i].SetLocalNode(nodes[i])
|
||||
if inclSessions {
|
||||
// create a session for all but the first
|
||||
|
||||
@@ -4,7 +4,13 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"git-indra.lan/indra-labs/lnd/lnd/lnwire"
|
||||
|
||||
"git-indra.lan/indra-labs/indra/pkg/crypto/key/prv"
|
||||
"git-indra.lan/indra-labs/indra/pkg/crypto/key/pub"
|
||||
"git-indra.lan/indra-labs/indra/pkg/crypto/nonce"
|
||||
"git-indra.lan/indra-labs/indra/pkg/service"
|
||||
"git-indra.lan/indra-labs/indra/pkg/util/slice"
|
||||
)
|
||||
|
||||
// NodesLen returns the length of a Nodes.
|
||||
@@ -17,8 +23,67 @@ func (sm *SessionManager) NodesLen() int {
|
||||
// GetLocalNode returns the engine's local Node.
|
||||
func (sm *SessionManager) GetLocalNode() *Node { return sm.nodes[0] }
|
||||
|
||||
// Concurrent safe accessors.
|
||||
|
||||
func (sm *SessionManager) GetLocalNodeAddress() (addr *netip.AddrPort) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
return sm.GetLocalNode().AddrPort
|
||||
}
|
||||
|
||||
func (sm *SessionManager) SetLocalNodeAddress(addr *netip.AddrPort) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
sm.GetLocalNode().AddrPort = addr
|
||||
}
|
||||
|
||||
func (sm *SessionManager) SendFromLocalNode(port uint16,
|
||||
b slice.Bytes) (e error) {
|
||||
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
return sm.GetLocalNode().SendTo(port, b)
|
||||
}
|
||||
|
||||
func (sm *SessionManager) ReceiveToLocalNode(port uint16) <-chan slice.Bytes {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
if port == 0 {
|
||||
return sm.GetLocalNode().Receive()
|
||||
}
|
||||
return sm.GetLocalNode().ReceiveFrom(port)
|
||||
}
|
||||
|
||||
func (sm *SessionManager) AddServiceToLocalNode(s *service.Service) (e error) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
return sm.GetLocalNode().AddService(s)
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetLocalNodeRelayRate() (rate lnwire.MilliSatoshi) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
return sm.GetLocalNode().RelayRate
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetLocalNodeIdentityBytes() (ident pub.Bytes) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
return sm.GetLocalNode().IdentityBytes
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetLocalNodeIdentityPrv() (ident *prv.Key) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
return sm.GetLocalNode().IdentityPrv
|
||||
}
|
||||
|
||||
// SetLocalNode returns the engine's local Node.
|
||||
func (sm *SessionManager) SetLocalNode(n *Node) { sm.nodes[0] = n }
|
||||
func (sm *SessionManager) SetLocalNode(n *Node) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
sm.nodes[0] = n
|
||||
}
|
||||
|
||||
// AddNodes adds a Node to a Nodes.
|
||||
func (sm *SessionManager) AddNodes(nn ...*Node) {
|
||||
@@ -84,6 +149,9 @@ func (sm *SessionManager) DeleteNodeByAddrPort(ip *netip.AddrPort) (e error) {
|
||||
|
||||
// ForEachNode runs a function over the slice of nodes with the mutex locked,
|
||||
// and terminates when the function returns true.
|
||||
//
|
||||
// Do not call any SessionManager methods above inside this function or there
|
||||
// will be a mutex double locking panic, except GetLocalNode.
|
||||
func (sm *SessionManager) ForEachNode(fn func(n *Node) bool) {
|
||||
sm.Lock()
|
||||
defer sm.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user