/home/runner/work/lyquor/lyquor/net/src/peermanager.rs
Line | Count | Source |
1 | | use crate::inbound::{ConnectionInfo, Inbound, InboundConnection}; |
2 | | use crate::outbound::Outbound; |
3 | | pub use crate::outbound::RequestSender; |
4 | | use crate::{hub::HubError, inbound::RequestReceiver}; |
5 | | use lyquor_primitives::NodeID; |
6 | | use scc::HashMap; |
7 | | use std::str::FromStr; |
8 | | use std::sync::Arc; |
9 | | use tokio_util::sync::CancellationToken; |
10 | | |
11 | | type Result<T> = std::result::Result<T, HubError>; |
12 | | |
13 | | struct PeerState { |
14 | | outbound_endpoint: Option<String>, |
15 | | inbound: Inbound, |
16 | | outbound: Option<Outbound>, |
17 | | } |
18 | | |
19 | | pub struct PeerManager { |
20 | | local_id: NodeID, |
21 | | client_config: Arc<rustls::ClientConfig>, |
22 | | runtime: tokio::runtime::Handle, |
23 | | shutdown: CancellationToken, |
24 | | peers: HashMap<NodeID, PeerState>, |
25 | | } |
26 | | |
27 | | impl PeerManager { |
28 | 62 | fn canonical_endpoint(addr: &str) -> String { |
29 | 62 | if let Ok(socket_addr) = std::net::SocketAddr::from_str(addr) { |
30 | 62 | return socket_addr.to_string(); |
31 | 0 | } |
32 | 0 | addr.to_owned() |
33 | 62 | } |
34 | | |
35 | 169 | fn with_peer_state<T, F>(&self, id: &NodeID, accessor: F) -> Result<T> |
36 | 169 | where |
37 | 169 | F: FnOnce(&PeerState) -> Result<T>, |
38 | | { |
39 | 169 | let peer_state = self.peers.get_sync(id).ok_or(HubError::PeerNotExist)?0 ; |
40 | 169 | accessor(&peer_state) |
41 | 169 | } |
42 | | |
43 | 42 | pub fn new( |
44 | 42 | local_id: NodeID, client_config: Arc<rustls::ClientConfig>, runtime: tokio::runtime::Handle, |
45 | 42 | shutdown: CancellationToken, |
46 | 42 | ) -> Self { |
47 | 42 | Self { |
48 | 42 | local_id, |
49 | 42 | client_config, |
50 | 42 | runtime, |
51 | 42 | shutdown, |
52 | 42 | peers: HashMap::<NodeID, PeerState>::new(), |
53 | 42 | } |
54 | 42 | } |
55 | | |
56 | | /// Add(connect) to a peer with NodeID or IP addr. |
57 | 62 | pub async fn add_outbound(&self, id: &NodeID, addr: &str) -> Result<()> { |
58 | 62 | if *id == self.local_id { |
59 | 0 | return Err(HubError::SelfConnection); |
60 | 62 | } |
61 | 62 | let canonical_endpoint = Self::canonical_endpoint(addr); |
62 | | |
63 | | // Check if we already connected |
64 | 62 | if let Some(existing0 ) = self.peers.get_async(id).await { |
65 | 0 | if existing.outbound_endpoint.as_deref() == Some(canonical_endpoint.as_str()) && existing.outbound.is_some() |
66 | | { |
67 | 0 | return Ok(()); |
68 | 0 | } |
69 | 62 | } |
70 | | |
71 | | // Connect and store outbound |
72 | 62 | let config = crate::outbound::OutboundConfig::builder() |
73 | 62 | .tls_config(self.client_config.clone()) |
74 | 62 | .addr(addr.to_owned()) |
75 | 62 | .node_id(id.clone()) |
76 | 62 | .build(); |
77 | 62 | let outbound = Outbound::new_with_shutdown(config, self.shutdown.child_token()); |
78 | 62 | outbound.start(self.runtime.clone()).await; |
79 | | |
80 | | // Check again in case value changed during connect |
81 | 62 | self.peers |
82 | 62 | .entry_async(*id) |
83 | 62 | .await |
84 | 62 | .and_modify(|peer_state| {0 |
85 | 0 | if peer_state.outbound_endpoint.as_deref() != Some(canonical_endpoint.as_str()) { |
86 | 0 | peer_state.outbound_endpoint = Some(canonical_endpoint.clone()); |
87 | 0 | } |
88 | 0 | peer_state.outbound = Some(outbound.clone()); |
89 | 0 | }) |
90 | 62 | .or_insert_with(|| PeerState { |
91 | 62 | outbound_endpoint: Some(canonical_endpoint), |
92 | 62 | inbound: Inbound::new(), |
93 | 62 | outbound: Some(outbound), |
94 | 62 | }); |
95 | 62 | Ok(()) |
96 | 62 | } |
97 | | |
98 | | /// Add an inbound connection for a peer |
99 | | /// This will be called from listener |
100 | 62 | pub fn add_inbound(&self, conn_info: ConnectionInfo) -> Result<()> { |
101 | | // NodeID and Addr has been verified in previous stage. |
102 | 62 | let runtime = conn_info.runtime.clone(); |
103 | | |
104 | | // c_info cannot be copied so we have to deal with it in this style |
105 | 62 | let entry = self.peers.entry_sync(conn_info.id); |
106 | 62 | match entry { |
107 | 62 | scc::hash_map::Entry::Occupied(mut peer_state) => { |
108 | 62 | if let Some(iconn0 ) = peer_state.inbound.connection.as_mut() { |
109 | 0 | iconn.restart_connection(conn_info, &runtime); |
110 | 62 | } else { |
111 | 62 | let mut conn = InboundConnection::new(peer_state.inbound.tx(), self.shutdown.child_token()); |
112 | 62 | conn.start_connection(conn_info); |
113 | 62 | peer_state.inbound.connection = Some(conn); |
114 | 62 | } |
115 | | } |
116 | 0 | scc::hash_map::Entry::Vacant(e) => { |
117 | 0 | let mut inbound = Inbound::new(); |
118 | 0 | let tx = inbound.tx(); |
119 | 0 | let mut iconn = InboundConnection::new(tx, self.shutdown.child_token()); |
120 | 0 | iconn.start_connection(conn_info); |
121 | 0 | inbound.connection = Some(iconn); |
122 | 0 | e.insert_entry(PeerState { |
123 | 0 | outbound_endpoint: None, |
124 | 0 | inbound, |
125 | 0 | outbound: None, |
126 | 0 | }); |
127 | 0 | } |
128 | | } |
129 | | |
130 | 62 | Ok(()) |
131 | 62 | } |
132 | | |
133 | | /// Add a peer to our hub |
134 | 62 | pub async fn add_peer(&self, id: &NodeID, addr: &str) -> Result<()> { |
135 | 62 | self.add_outbound(id, addr).await |
136 | 62 | } |
137 | | |
138 | | /// Remove a peer from out hub |
139 | 0 | pub async fn remove_peer(&self, id: &NodeID) -> Result<()> { |
140 | 0 | let runtime = self.runtime.clone(); |
141 | 0 | let peer_state = self.peers.remove_async(id).await; |
142 | 0 | if let Some((_id, mut peer)) = peer_state { |
143 | 0 | if let Some(iconn) = peer.inbound.connection.as_mut() { |
144 | 0 | iconn.disconnect(&runtime); |
145 | 0 | } |
146 | 0 | if let Some(outbound) = peer.outbound.as_ref() { |
147 | 0 | outbound.shutdown(); |
148 | 0 | } |
149 | 0 | } |
150 | 0 | Ok(()) |
151 | 0 | } |
152 | | |
153 | | /// Remove all peers from hub |
154 | 1 | pub async fn remove_all_peers(&self) -> Result<()> { |
155 | 1 | let runtime = self.runtime.clone(); |
156 | 1 | self.peers |
157 | 1 | .retain_async(|_, peer_state| {0 |
158 | 0 | if let Some(inbound_connection) = peer_state.inbound.connection.as_mut() { |
159 | 0 | inbound_connection.disconnect(&runtime); |
160 | 0 | } |
161 | 0 | if let Some(outbound) = peer_state.outbound.as_ref() { |
162 | 0 | outbound.shutdown(); |
163 | 0 | } |
164 | 0 | false |
165 | 0 | }) |
166 | 1 | .await; |
167 | | |
168 | 1 | Ok(()) |
169 | 1 | } |
170 | | |
171 | 109 | pub fn outbound(&self, id: &NodeID) -> Result<RequestSender> { |
172 | 109 | self.with_peer_state(id, |peer_state| { |
173 | 109 | let outbound = peer_state.outbound.as_ref().ok_or(HubError::OutboundError)?0 ; |
174 | 109 | Ok(RequestSender::new(self.local_id, outbound)) |
175 | 109 | }) |
176 | 109 | } |
177 | | |
178 | 60 | pub fn inbound(&self, id: &NodeID) -> Result<RequestReceiver> { |
179 | 60 | self.with_peer_state(id, |peer_state| Ok(peer_state.inbound.rx())) |
180 | 60 | } |
181 | | } |