Coverage Report

Created: 2026-05-15 09:02

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/runner/work/lyquor/lyquor/net/src/peermanager.rs
Line
Count
Source
1
use crate::hub::HubError;
2
use crate::inbound::{ConnectionInfo, Handler, Inbound, InboundConnection};
3
pub use crate::outbound::RequestSender;
4
use crate::outbound::{Outbound, RPCSender};
5
use lyquor_primitives::NodeID;
6
use scc::HashMap;
7
use std::str::FromStr;
8
use std::sync::Arc;
9
use tokio::sync::watch;
10
use tokio_util::sync::CancellationToken;
11
12
type Result<T> = std::result::Result<T, HubError>;
13
14
struct PeerState {
15
    outbound_endpoint: Option<String>,
16
    inbound: Inbound,
17
    outbound: Option<Outbound>,
18
}
19
20
pub struct PeerManager {
21
    local_id: NodeID,
22
    client_config: Arc<rustls::ClientConfig>,
23
    shutdown: CancellationToken,
24
    peers: HashMap<NodeID, PeerState>,
25
    rpc_service: watch::Receiver<Handler>,
26
}
27
28
impl PeerManager {
29
75
    fn canonical_endpoint(addr: &str) -> String {
30
75
        if let Ok(socket_addr) = std::net::SocketAddr::from_str(addr) {
31
75
            return socket_addr.to_string();
32
0
        }
33
0
        addr.to_owned()
34
75
    }
35
36
58
    fn with_peer_state<T, F>(&self, id: &NodeID, accessor: F) -> Result<T>
37
58
    where
38
58
        F: FnOnce(&PeerState) -> Result<T>,
39
    {
40
58
        let peer_state = self.peers.get_sync(id).ok_or(HubError::PeerNotExist)
?0
;
41
58
        accessor(&peer_state)
42
58
    }
43
44
48
    pub fn new(
45
48
        local_id: NodeID, client_config: Arc<rustls::ClientConfig>, shutdown: CancellationToken,
46
48
        rpc_service: watch::Receiver<Handler>,
47
48
    ) -> Self {
48
48
        Self {
49
48
            local_id,
50
48
            client_config,
51
48
            shutdown,
52
48
            peers: HashMap::<NodeID, PeerState>::new(),
53
48
            rpc_service,
54
48
        }
55
48
    }
56
57
    /// Add(connect) to a peer with NodeID or IP addr.
58
75
    pub async fn add_outbound(&self, id: &NodeID, addr: &str) -> Result<()> {
59
75
        if *id == self.local_id {
60
0
            return Err(HubError::SelfConnection);
61
75
        }
62
75
        let canonical_endpoint = Self::canonical_endpoint(addr);
63
64
        // Check if we already connected
65
75
        if let Some(
existing58
) = self.peers.get_async(id).await {
66
58
            if existing.outbound_endpoint.as_deref() == Some(canonical_endpoint.as_str()) && 
existing.outbound0
.
is_some0
()
67
            {
68
0
                return Ok(());
69
58
            }
70
17
        }
71
72
        // Connect and store outbound
73
75
        let config = crate::outbound::OutboundConfig::builder()
74
75
            .tls_config(self.client_config.clone())
75
75
            .addr(addr.to_owned())
76
75
            .node_id(id.clone())
77
75
            .build();
78
75
        let outbound = Outbound::new_with_shutdown(config, self.shutdown.child_token());
79
75
        outbound.start().await;
80
81
        // Check again in case value changed during connect
82
75
        let mut old_outbound = None;
83
75
        self.peers
84
75
            .entry_async(*id)
85
75
            .await
86
75
            .and_modify(|peer_state| 
{58
87
58
                if peer_state.outbound_endpoint.as_deref() != Some(canonical_endpoint.as_str()) {
88
58
                    peer_state.outbound_endpoint = Some(canonical_endpoint.clone());
89
58
                
}0
90
58
                old_outbound = peer_state.outbound.take();
91
58
                peer_state.outbound = Some(outbound.clone());
92
58
            })
93
75
            .or_insert_with(|| PeerState {
94
17
                outbound_endpoint: Some(canonical_endpoint),
95
17
                inbound: Inbound::new(),
96
17
                outbound: Some(outbound),
97
17
            });
98
99
75
        if let Some(
outbound0
) = old_outbound {
100
0
            outbound.disconnect_wait().await;
101
75
        }
102
103
75
        Ok(())
104
75
    }
105
106
    /// Add an inbound connection for a peer
107
    /// This will be called from listener
108
63
    pub fn add_inbound(&self, conn_info: ConnectionInfo) -> Result<()> {
109
        // NodeID and Addr has been verified in previous stage.
110
        // c_info cannot be copied so we have to deal with it in this style
111
63
        let entry = self.peers.entry_sync(conn_info.id);
112
63
        match entry {
113
63
            scc::hash_map::Entry::Occupied(mut peer_state) => {
114
63
                if let Some(
iconn0
) = peer_state.inbound.connection.as_mut() {
115
0
                    iconn.disconnect();
116
63
                }
117
63
                let mut conn = InboundConnection::new(
118
63
                    peer_state.inbound.service(),
119
63
                    self.rpc_service.clone(),
120
63
                    self.shutdown.child_token(),
121
                );
122
63
                conn.start_connection(conn_info);
123
63
                peer_state.inbound.connection = Some(conn);
124
            }
125
0
            scc::hash_map::Entry::Vacant(e) => {
126
0
                let mut inbound = Inbound::new();
127
0
                let mut iconn =
128
0
                    InboundConnection::new(inbound.service(), self.rpc_service.clone(), self.shutdown.child_token());
129
0
                iconn.start_connection(conn_info);
130
0
                inbound.connection = Some(iconn);
131
0
                e.insert_entry(PeerState {
132
0
                    outbound_endpoint: None,
133
0
                    inbound,
134
0
                    outbound: None,
135
0
                });
136
0
            }
137
        }
138
139
63
        Ok(())
140
63
    }
141
142
    /// Add a peer to our hub
143
74
    pub async fn add_peer(&self, id: &NodeID, addr: &str) -> Result<()> {
144
74
        self.add_outbound(id, addr).await
145
74
    }
146
147
62
    pub(crate) fn serve_inbound(&self, id: &NodeID, service: Handler) -> Result<()> {
148
62
        if *id == self.local_id {
149
0
            return Err(HubError::SelfConnection);
150
62
        }
151
152
62
        let entry = self.peers.entry_sync(*id);
153
62
        match entry {
154
4
            scc::hash_map::Entry::Occupied(peer_state) => {
155
4
                peer_state.inbound.serve(service);
156
4
            }
157
58
            scc::hash_map::Entry::Vacant(e) => {
158
58
                let inbound = Inbound::new();
159
58
                inbound.serve(service);
160
58
                e.insert_entry(PeerState {
161
58
                    outbound_endpoint: None,
162
58
                    inbound,
163
58
                    outbound: None,
164
58
                });
165
58
            }
166
        }
167
62
        Ok(())
168
62
    }
169
170
    /// Remove a peer from out hub
171
1
    pub async fn remove_peer(&self, id: &NodeID) -> Result<()> {
172
1
        let peer_state = self.peers.remove_async(id).await;
173
1
        if let Some((_id, mut peer)) = peer_state {
174
1
            if let Some(iconn) = peer.inbound.connection.as_mut() {
175
1
                iconn.disconnect_wait().await;
176
0
            }
177
1
            if let Some(outbound) = peer.outbound.as_ref() {
178
1
                outbound.disconnect_wait().await;
179
0
            }
180
0
        }
181
1
        Ok(())
182
1
    }
183
184
    /// Remove all peers from hub
185
1
    pub async fn remove_all_peers(&self) -> Result<()> {
186
1
        let peer_ids = self.get_peers().await;
187
1
        for peer_id in peer_ids {
188
1
            let _ = self.remove_peer(&peer_id).await;
189
        }
190
1
        Ok(())
191
1
    }
192
193
54
    pub fn outbound(&self, id: &NodeID) -> Result<RequestSender> {
194
54
        self.with_peer_state(id, |peer_state| {
195
54
            let outbound = peer_state.outbound.as_ref().ok_or(HubError::OutboundError)
?0
;
196
54
            Ok(RequestSender::new(outbound))
197
54
        })
198
54
    }
199
200
4
    pub fn rpc_outbound(&self, id: &NodeID) -> Result<RPCSender> {
201
4
        self.with_peer_state(id, |peer_state| {
202
4
            let outbound = peer_state.outbound.as_ref().ok_or(HubError::OutboundError)
?0
;
203
4
            Ok(RPCSender::new(
204
4
                peer_state.outbound_endpoint.clone().unwrap_or_default(),
205
4
                RequestSender::new(outbound),
206
4
            ))
207
4
        })
208
4
    }
209
210
3
    pub async fn get_peers(&self) -> Vec<NodeID> {
211
3
        let mut peers = Vec::<NodeID>::new();
212
3
        self.peers
213
3
            .iter_async(|id, _| 
{2
214
2
                peers.push(*id);
215
2
                true
216
2
            })
217
3
            .await;
218
3
        peers
219
3
    }
220
}
221
222
#[cfg(test)]
223
pub mod tests {
224
    use std::sync::Arc;
225
226
    use crate::{
227
        inbound::{BoxError, ConnectionInfo, Inbound, Request, ResponseBuilder, boxed_handler, handler_channel},
228
        outbound::RequestError,
229
        peermanager::PeerManager,
230
    };
231
    use bytes::Bytes;
232
    use http_body_util::BodyExt;
233
    use lyquor_test::test;
234
    use tokio::net::TcpListener;
235
    use tokio_util::sync::CancellationToken;
236
    use tower::ServiceExt;
237
238
    #[test(tokio::test)]
239
    // We add inbound and outbound, then we remove all peers, lastly we verify if connections are getting closed properly.
240
    async fn test_hub_peermanager() {
241
        let (local_node_id, local_tls_config) = lyquor_tls::generator::test_config(9);
242
        let (remote_node_id, remote_tls_config) = lyquor_tls::generator::test_config(10);
243
        let (_rpc_service, rpc_receiver) = handler_channel();
244
        let peermgr = PeerManager::new(
245
            local_node_id,
246
            local_tls_config.client_config().unwrap(),
247
            CancellationToken::new(),
248
            rpc_receiver.clone(),
249
        );
250
        let peermgr = Arc::new(peermgr);
251
252
        let remote_socket = lyquor_test::reserve_unused_socket();
253
        let remote_addr = remote_socket.local_addr();
254
        drop(remote_socket);
255
        let remote_listener = TcpListener::bind(&remote_addr).await.unwrap();
256
257
        // Add outbound
258
        let pm_clone = peermgr.clone();
259
        let remote_addr_clone = remote_addr.clone();
260
1
        let connect_task = tokio::spawn(async move {
261
1
            pm_clone
262
1
                .add_outbound(&remote_node_id, &remote_addr_clone)
263
1
                .await
264
1
                .unwrap();
265
1
        });
266
267
        let remote_acceptor = tokio_rustls::TlsAcceptor::from(remote_tls_config.server_config().unwrap());
268
        let (remote_conn, remote_peer_addr) = remote_listener.accept().await.unwrap();
269
        let remote_stream = remote_acceptor.accept(remote_conn).await.unwrap();
270
        let (_, remote_conn_ref) = remote_stream.get_ref();
271
        let remote_peer_id = lyquor_tls::get_nodeid_from_cert(remote_conn_ref.peer_certificates()).unwrap();
272
        assert_eq!(remote_peer_id, local_node_id);
273
274
        let remote_inbound = Inbound::new();
275
1
        remote_inbound.serve(boxed_handler(tower::service_fn(|req: Request| async move {
276
1
            let req_body = req.into_body().collect().await.unwrap().to_bytes();
277
1
            assert_eq!(req_body, Bytes::from("You should see this."));
278
1
            Ok::<_, BoxError>(ResponseBuilder::response("You should see this."))
279
2
        })));
280
        let mut i_conn =
281
            crate::inbound::InboundConnection::new(remote_inbound.service(), rpc_receiver, CancellationToken::new());
282
        i_conn.start_connection(ConnectionInfo {
283
            id: remote_peer_id,
284
            stream: remote_stream,
285
            addr: remote_peer_addr,
286
        });
287
        connect_task.await.unwrap();
288
289
        let sender = peermgr.outbound(&remote_node_id).unwrap();
290
        let send_task = tokio::spawn({
291
            let sender = sender.clone();
292
1
            async move {
293
1
                sender
294
1
                    .oneshot(hyper::Request::new(http_body_util::Full::new(Bytes::from(
295
1
                        "You should see this.",
296
1
                    ))))
297
1
                    .await
298
1
            }
299
        });
300
        send_task.await.unwrap().unwrap();
301
302
        let local_socket = lyquor_test::reserve_unused_socket();
303
        let local_addr = local_socket.local_addr();
304
        drop(local_socket);
305
        let local_listener = TcpListener::bind(&local_addr).await.unwrap();
306
307
        let (send_after_close_tx, send_after_close_rx) = tokio::sync::oneshot::channel::<()>();
308
        let (send_after_close_result_tx, send_after_close_result_rx) =
309
            tokio::sync::oneshot::channel::<std::result::Result<hyper::StatusCode, hyper::Error>>();
310
1
        let remote_inbound_task = tokio::spawn(async move {
311
1
            let connector = tokio_rustls::TlsConnector::from(remote_tls_config.client_config().unwrap());
312
1
            let tcp = tokio::net::TcpStream::connect(&local_addr).await.unwrap();
313
1
            let tls_stream = connector
314
1
                .connect(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), tcp)
315
1
                .await
316
1
                .unwrap();
317
318
1
            let io = hyper_util::rt::TokioIo::new(tls_stream);
319
1
            let (mut sender, conn) = hyper::client::conn::http2::handshake(hyper_util::rt::TokioExecutor::new(), io)
320
1
                .await
321
1
                .unwrap();
322
1
            let conn_task = tokio::spawn(async move {
323
1
                let _ = conn.await;
324
1
            });
325
326
1
            let _ = send_after_close_rx.await;
327
1
            let send_result = sender
328
1
                .send_request(hyper::Request::new(http_body_util::Full::new(Bytes::from(
329
1
                    "should get a bad statuscode",
330
1
                ))))
331
1
                .await
332
1
                .map(|res| 
res0
.
status0
());
333
1
            let _ = send_after_close_result_tx.send(send_result);
334
335
1
            conn_task.abort();
336
1
            let _ = conn_task.await;
337
1
        });
338
339
        // Add inbound
340
        let local_acceptor = tokio_rustls::TlsAcceptor::from(local_tls_config.server_config().unwrap());
341
        let (conn2, addr2) = local_listener.accept().await.unwrap();
342
        let stream2 = local_acceptor.accept(conn2).await.unwrap();
343
        let (_, conn2) = stream2.get_ref();
344
        let certs2 = conn2.peer_certificates();
345
        let id2 = lyquor_tls::get_nodeid_from_cert(certs2).unwrap();
346
        assert_eq!(id2, remote_node_id);
347
        peermgr
348
            .add_inbound(ConnectionInfo {
349
                id: id2,
350
                stream: stream2,
351
                addr: addr2,
352
            })
353
            .unwrap();
354
355
        assert_eq!(peermgr.get_peers().await.len(), 1);
356
357
        // Clean up
358
        let _ = peermgr.remove_all_peers().await;
359
        assert!(
360
            peermgr.get_peers().await.is_empty(),
361
            "PeerManager is not empty after removal."
362
        );
363
364
        // Check outbound is closed
365
        let resp = sender
366
            .oneshot(hyper::Request::new(http_body_util::Full::new(Bytes::from(
367
                "You should not see this.",
368
            ))))
369
            .await;
370
        assert!(
371
            matches!(resp, Err(RequestError::NotConnected)),
372
            "Connection is not closed: {resp:?}"
373
        );
374
375
        // Check inbound connection is closed
376
        let _ = send_after_close_tx.send(());
377
        let post_close_result = send_after_close_result_rx.await.unwrap();
378
        assert!(post_close_result.is_err(), "Inbound connection is not closed.");
379
        remote_inbound_task.await.unwrap();
380
    }
381
}