/home/runner/work/lyquor/lyquor/net/src/peermanager.rs
Line | Count | Source |
1 | | use crate::hub::HubError; |
2 | | use crate::inbound::{ConnectionInfo, Handler, Inbound, InboundConnection}; |
3 | | pub use crate::outbound::RequestSender; |
4 | | use crate::outbound::{Outbound, RPCSender}; |
5 | | use lyquor_primitives::NodeID; |
6 | | use scc::HashMap; |
7 | | use std::str::FromStr; |
8 | | use std::sync::Arc; |
9 | | use tokio::sync::watch; |
10 | | use tokio_util::sync::CancellationToken; |
11 | | |
12 | | type Result<T> = std::result::Result<T, HubError>; |
13 | | |
14 | | struct PeerState { |
15 | | outbound_endpoint: Option<String>, |
16 | | inbound: Inbound, |
17 | | outbound: Option<Outbound>, |
18 | | } |
19 | | |
20 | | pub struct PeerManager { |
21 | | local_id: NodeID, |
22 | | client_config: Arc<rustls::ClientConfig>, |
23 | | shutdown: CancellationToken, |
24 | | peers: HashMap<NodeID, PeerState>, |
25 | | rpc_service: watch::Receiver<Handler>, |
26 | | } |
27 | | |
28 | | impl PeerManager { |
29 | 75 | fn canonical_endpoint(addr: &str) -> String { |
30 | 75 | if let Ok(socket_addr) = std::net::SocketAddr::from_str(addr) { |
31 | 75 | return socket_addr.to_string(); |
32 | 0 | } |
33 | 0 | addr.to_owned() |
34 | 75 | } |
35 | | |
36 | 58 | fn with_peer_state<T, F>(&self, id: &NodeID, accessor: F) -> Result<T> |
37 | 58 | where |
38 | 58 | F: FnOnce(&PeerState) -> Result<T>, |
39 | | { |
40 | 58 | let peer_state = self.peers.get_sync(id).ok_or(HubError::PeerNotExist)?0 ; |
41 | 58 | accessor(&peer_state) |
42 | 58 | } |
43 | | |
44 | 48 | pub fn new( |
45 | 48 | local_id: NodeID, client_config: Arc<rustls::ClientConfig>, shutdown: CancellationToken, |
46 | 48 | rpc_service: watch::Receiver<Handler>, |
47 | 48 | ) -> Self { |
48 | 48 | Self { |
49 | 48 | local_id, |
50 | 48 | client_config, |
51 | 48 | shutdown, |
52 | 48 | peers: HashMap::<NodeID, PeerState>::new(), |
53 | 48 | rpc_service, |
54 | 48 | } |
55 | 48 | } |
56 | | |
57 | | /// Add(connect) to a peer with NodeID or IP addr. |
58 | 75 | pub async fn add_outbound(&self, id: &NodeID, addr: &str) -> Result<()> { |
59 | 75 | if *id == self.local_id { |
60 | 0 | return Err(HubError::SelfConnection); |
61 | 75 | } |
62 | 75 | let canonical_endpoint = Self::canonical_endpoint(addr); |
63 | | |
64 | | // Check if we already connected |
65 | 75 | if let Some(existing58 ) = self.peers.get_async(id).await { |
66 | 58 | if existing.outbound_endpoint.as_deref() == Some(canonical_endpoint.as_str()) && existing.outbound0 .is_some0 () |
67 | | { |
68 | 0 | return Ok(()); |
69 | 58 | } |
70 | 17 | } |
71 | | |
72 | | // Connect and store outbound |
73 | 75 | let config = crate::outbound::OutboundConfig::builder() |
74 | 75 | .tls_config(self.client_config.clone()) |
75 | 75 | .addr(addr.to_owned()) |
76 | 75 | .node_id(id.clone()) |
77 | 75 | .build(); |
78 | 75 | let outbound = Outbound::new_with_shutdown(config, self.shutdown.child_token()); |
79 | 75 | outbound.start().await; |
80 | | |
81 | | // Check again in case value changed during connect |
82 | 75 | let mut old_outbound = None; |
83 | 75 | self.peers |
84 | 75 | .entry_async(*id) |
85 | 75 | .await |
86 | 75 | .and_modify(|peer_state| {58 |
87 | 58 | if peer_state.outbound_endpoint.as_deref() != Some(canonical_endpoint.as_str()) { |
88 | 58 | peer_state.outbound_endpoint = Some(canonical_endpoint.clone()); |
89 | 58 | }0 |
90 | 58 | old_outbound = peer_state.outbound.take(); |
91 | 58 | peer_state.outbound = Some(outbound.clone()); |
92 | 58 | }) |
93 | 75 | .or_insert_with(|| PeerState { |
94 | 17 | outbound_endpoint: Some(canonical_endpoint), |
95 | 17 | inbound: Inbound::new(), |
96 | 17 | outbound: Some(outbound), |
97 | 17 | }); |
98 | | |
99 | 75 | if let Some(outbound0 ) = old_outbound { |
100 | 0 | outbound.disconnect_wait().await; |
101 | 75 | } |
102 | | |
103 | 75 | Ok(()) |
104 | 75 | } |
105 | | |
106 | | /// Add an inbound connection for a peer |
107 | | /// This will be called from listener |
108 | 63 | pub fn add_inbound(&self, conn_info: ConnectionInfo) -> Result<()> { |
109 | | // NodeID and Addr has been verified in previous stage. |
110 | | // c_info cannot be copied so we have to deal with it in this style |
111 | 63 | let entry = self.peers.entry_sync(conn_info.id); |
112 | 63 | match entry { |
113 | 63 | scc::hash_map::Entry::Occupied(mut peer_state) => { |
114 | 63 | if let Some(iconn0 ) = peer_state.inbound.connection.as_mut() { |
115 | 0 | iconn.disconnect(); |
116 | 63 | } |
117 | 63 | let mut conn = InboundConnection::new( |
118 | 63 | peer_state.inbound.service(), |
119 | 63 | self.rpc_service.clone(), |
120 | 63 | self.shutdown.child_token(), |
121 | | ); |
122 | 63 | conn.start_connection(conn_info); |
123 | 63 | peer_state.inbound.connection = Some(conn); |
124 | | } |
125 | 0 | scc::hash_map::Entry::Vacant(e) => { |
126 | 0 | let mut inbound = Inbound::new(); |
127 | 0 | let mut iconn = |
128 | 0 | InboundConnection::new(inbound.service(), self.rpc_service.clone(), self.shutdown.child_token()); |
129 | 0 | iconn.start_connection(conn_info); |
130 | 0 | inbound.connection = Some(iconn); |
131 | 0 | e.insert_entry(PeerState { |
132 | 0 | outbound_endpoint: None, |
133 | 0 | inbound, |
134 | 0 | outbound: None, |
135 | 0 | }); |
136 | 0 | } |
137 | | } |
138 | | |
139 | 63 | Ok(()) |
140 | 63 | } |
141 | | |
142 | | /// Add a peer to our hub |
143 | 74 | pub async fn add_peer(&self, id: &NodeID, addr: &str) -> Result<()> { |
144 | 74 | self.add_outbound(id, addr).await |
145 | 74 | } |
146 | | |
147 | 62 | pub(crate) fn serve_inbound(&self, id: &NodeID, service: Handler) -> Result<()> { |
148 | 62 | if *id == self.local_id { |
149 | 0 | return Err(HubError::SelfConnection); |
150 | 62 | } |
151 | | |
152 | 62 | let entry = self.peers.entry_sync(*id); |
153 | 62 | match entry { |
154 | 4 | scc::hash_map::Entry::Occupied(peer_state) => { |
155 | 4 | peer_state.inbound.serve(service); |
156 | 4 | } |
157 | 58 | scc::hash_map::Entry::Vacant(e) => { |
158 | 58 | let inbound = Inbound::new(); |
159 | 58 | inbound.serve(service); |
160 | 58 | e.insert_entry(PeerState { |
161 | 58 | outbound_endpoint: None, |
162 | 58 | inbound, |
163 | 58 | outbound: None, |
164 | 58 | }); |
165 | 58 | } |
166 | | } |
167 | 62 | Ok(()) |
168 | 62 | } |
169 | | |
170 | | /// Remove a peer from out hub |
171 | 1 | pub async fn remove_peer(&self, id: &NodeID) -> Result<()> { |
172 | 1 | let peer_state = self.peers.remove_async(id).await; |
173 | 1 | if let Some((_id, mut peer)) = peer_state { |
174 | 1 | if let Some(iconn) = peer.inbound.connection.as_mut() { |
175 | 1 | iconn.disconnect_wait().await; |
176 | 0 | } |
177 | 1 | if let Some(outbound) = peer.outbound.as_ref() { |
178 | 1 | outbound.disconnect_wait().await; |
179 | 0 | } |
180 | 0 | } |
181 | 1 | Ok(()) |
182 | 1 | } |
183 | | |
184 | | /// Remove all peers from hub |
185 | 1 | pub async fn remove_all_peers(&self) -> Result<()> { |
186 | 1 | let peer_ids = self.get_peers().await; |
187 | 1 | for peer_id in peer_ids { |
188 | 1 | let _ = self.remove_peer(&peer_id).await; |
189 | | } |
190 | 1 | Ok(()) |
191 | 1 | } |
192 | | |
193 | 54 | pub fn outbound(&self, id: &NodeID) -> Result<RequestSender> { |
194 | 54 | self.with_peer_state(id, |peer_state| { |
195 | 54 | let outbound = peer_state.outbound.as_ref().ok_or(HubError::OutboundError)?0 ; |
196 | 54 | Ok(RequestSender::new(outbound)) |
197 | 54 | }) |
198 | 54 | } |
199 | | |
200 | 4 | pub fn rpc_outbound(&self, id: &NodeID) -> Result<RPCSender> { |
201 | 4 | self.with_peer_state(id, |peer_state| { |
202 | 4 | let outbound = peer_state.outbound.as_ref().ok_or(HubError::OutboundError)?0 ; |
203 | 4 | Ok(RPCSender::new( |
204 | 4 | peer_state.outbound_endpoint.clone().unwrap_or_default(), |
205 | 4 | RequestSender::new(outbound), |
206 | 4 | )) |
207 | 4 | }) |
208 | 4 | } |
209 | | |
210 | 3 | pub async fn get_peers(&self) -> Vec<NodeID> { |
211 | 3 | let mut peers = Vec::<NodeID>::new(); |
212 | 3 | self.peers |
213 | 3 | .iter_async(|id, _| {2 |
214 | 2 | peers.push(*id); |
215 | 2 | true |
216 | 2 | }) |
217 | 3 | .await; |
218 | 3 | peers |
219 | 3 | } |
220 | | } |
221 | | |
222 | | #[cfg(test)] |
223 | | pub mod tests { |
224 | | use std::sync::Arc; |
225 | | |
226 | | use crate::{ |
227 | | inbound::{BoxError, ConnectionInfo, Inbound, Request, ResponseBuilder, boxed_handler, handler_channel}, |
228 | | outbound::RequestError, |
229 | | peermanager::PeerManager, |
230 | | }; |
231 | | use bytes::Bytes; |
232 | | use http_body_util::BodyExt; |
233 | | use lyquor_test::test; |
234 | | use tokio::net::TcpListener; |
235 | | use tokio_util::sync::CancellationToken; |
236 | | use tower::ServiceExt; |
237 | | |
238 | | #[test(tokio::test)] |
239 | | // We add inbound and outbound, then we remove all peers, lastly we verify if connections are getting closed properly. |
240 | | async fn test_hub_peermanager() { |
241 | | let (local_node_id, local_tls_config) = lyquor_tls::generator::test_config(9); |
242 | | let (remote_node_id, remote_tls_config) = lyquor_tls::generator::test_config(10); |
243 | | let (_rpc_service, rpc_receiver) = handler_channel(); |
244 | | let peermgr = PeerManager::new( |
245 | | local_node_id, |
246 | | local_tls_config.client_config().unwrap(), |
247 | | CancellationToken::new(), |
248 | | rpc_receiver.clone(), |
249 | | ); |
250 | | let peermgr = Arc::new(peermgr); |
251 | | |
252 | | let remote_socket = lyquor_test::reserve_unused_socket(); |
253 | | let remote_addr = remote_socket.local_addr(); |
254 | | drop(remote_socket); |
255 | | let remote_listener = TcpListener::bind(&remote_addr).await.unwrap(); |
256 | | |
257 | | // Add outbound |
258 | | let pm_clone = peermgr.clone(); |
259 | | let remote_addr_clone = remote_addr.clone(); |
260 | 1 | let connect_task = tokio::spawn(async move { |
261 | 1 | pm_clone |
262 | 1 | .add_outbound(&remote_node_id, &remote_addr_clone) |
263 | 1 | .await |
264 | 1 | .unwrap(); |
265 | 1 | }); |
266 | | |
267 | | let remote_acceptor = tokio_rustls::TlsAcceptor::from(remote_tls_config.server_config().unwrap()); |
268 | | let (remote_conn, remote_peer_addr) = remote_listener.accept().await.unwrap(); |
269 | | let remote_stream = remote_acceptor.accept(remote_conn).await.unwrap(); |
270 | | let (_, remote_conn_ref) = remote_stream.get_ref(); |
271 | | let remote_peer_id = lyquor_tls::get_nodeid_from_cert(remote_conn_ref.peer_certificates()).unwrap(); |
272 | | assert_eq!(remote_peer_id, local_node_id); |
273 | | |
274 | | let remote_inbound = Inbound::new(); |
275 | 1 | remote_inbound.serve(boxed_handler(tower::service_fn(|req: Request| async move { |
276 | 1 | let req_body = req.into_body().collect().await.unwrap().to_bytes(); |
277 | 1 | assert_eq!(req_body, Bytes::from("You should see this.")); |
278 | 1 | Ok::<_, BoxError>(ResponseBuilder::response("You should see this.")) |
279 | 2 | }))); |
280 | | let mut i_conn = |
281 | | crate::inbound::InboundConnection::new(remote_inbound.service(), rpc_receiver, CancellationToken::new()); |
282 | | i_conn.start_connection(ConnectionInfo { |
283 | | id: remote_peer_id, |
284 | | stream: remote_stream, |
285 | | addr: remote_peer_addr, |
286 | | }); |
287 | | connect_task.await.unwrap(); |
288 | | |
289 | | let sender = peermgr.outbound(&remote_node_id).unwrap(); |
290 | | let send_task = tokio::spawn({ |
291 | | let sender = sender.clone(); |
292 | 1 | async move { |
293 | 1 | sender |
294 | 1 | .oneshot(hyper::Request::new(http_body_util::Full::new(Bytes::from( |
295 | 1 | "You should see this.", |
296 | 1 | )))) |
297 | 1 | .await |
298 | 1 | } |
299 | | }); |
300 | | send_task.await.unwrap().unwrap(); |
301 | | |
302 | | let local_socket = lyquor_test::reserve_unused_socket(); |
303 | | let local_addr = local_socket.local_addr(); |
304 | | drop(local_socket); |
305 | | let local_listener = TcpListener::bind(&local_addr).await.unwrap(); |
306 | | |
307 | | let (send_after_close_tx, send_after_close_rx) = tokio::sync::oneshot::channel::<()>(); |
308 | | let (send_after_close_result_tx, send_after_close_result_rx) = |
309 | | tokio::sync::oneshot::channel::<std::result::Result<hyper::StatusCode, hyper::Error>>(); |
310 | 1 | let remote_inbound_task = tokio::spawn(async move { |
311 | 1 | let connector = tokio_rustls::TlsConnector::from(remote_tls_config.client_config().unwrap()); |
312 | 1 | let tcp = tokio::net::TcpStream::connect(&local_addr).await.unwrap(); |
313 | 1 | let tls_stream = connector |
314 | 1 | .connect(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), tcp) |
315 | 1 | .await |
316 | 1 | .unwrap(); |
317 | | |
318 | 1 | let io = hyper_util::rt::TokioIo::new(tls_stream); |
319 | 1 | let (mut sender, conn) = hyper::client::conn::http2::handshake(hyper_util::rt::TokioExecutor::new(), io) |
320 | 1 | .await |
321 | 1 | .unwrap(); |
322 | 1 | let conn_task = tokio::spawn(async move { |
323 | 1 | let _ = conn.await; |
324 | 1 | }); |
325 | | |
326 | 1 | let _ = send_after_close_rx.await; |
327 | 1 | let send_result = sender |
328 | 1 | .send_request(hyper::Request::new(http_body_util::Full::new(Bytes::from( |
329 | 1 | "should get a bad statuscode", |
330 | 1 | )))) |
331 | 1 | .await |
332 | 1 | .map(|res| res0 .status0 ()); |
333 | 1 | let _ = send_after_close_result_tx.send(send_result); |
334 | | |
335 | 1 | conn_task.abort(); |
336 | 1 | let _ = conn_task.await; |
337 | 1 | }); |
338 | | |
339 | | // Add inbound |
340 | | let local_acceptor = tokio_rustls::TlsAcceptor::from(local_tls_config.server_config().unwrap()); |
341 | | let (conn2, addr2) = local_listener.accept().await.unwrap(); |
342 | | let stream2 = local_acceptor.accept(conn2).await.unwrap(); |
343 | | let (_, conn2) = stream2.get_ref(); |
344 | | let certs2 = conn2.peer_certificates(); |
345 | | let id2 = lyquor_tls::get_nodeid_from_cert(certs2).unwrap(); |
346 | | assert_eq!(id2, remote_node_id); |
347 | | peermgr |
348 | | .add_inbound(ConnectionInfo { |
349 | | id: id2, |
350 | | stream: stream2, |
351 | | addr: addr2, |
352 | | }) |
353 | | .unwrap(); |
354 | | |
355 | | assert_eq!(peermgr.get_peers().await.len(), 1); |
356 | | |
357 | | // Clean up |
358 | | let _ = peermgr.remove_all_peers().await; |
359 | | assert!( |
360 | | peermgr.get_peers().await.is_empty(), |
361 | | "PeerManager is not empty after removal." |
362 | | ); |
363 | | |
364 | | // Check outbound is closed |
365 | | let resp = sender |
366 | | .oneshot(hyper::Request::new(http_body_util::Full::new(Bytes::from( |
367 | | "You should not see this.", |
368 | | )))) |
369 | | .await; |
370 | | assert!( |
371 | | matches!(resp, Err(RequestError::NotConnected)), |
372 | | "Connection is not closed: {resp:?}" |
373 | | ); |
374 | | |
375 | | // Check inbound connection is closed |
376 | | let _ = send_after_close_tx.send(()); |
377 | | let post_close_result = send_after_close_result_rx.await.unwrap(); |
378 | | assert!(post_close_result.is_err(), "Inbound connection is not closed."); |
379 | | remote_inbound_task.await.unwrap(); |
380 | | } |
381 | | } |