Coverage Report

Created: 2026-03-22 03:56

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/runner/work/lyquor/lyquor/net/src/hub.rs
Line
Count
Source
1
pub use crate::listener::{RequestHandler, RequestReceiver};
2
use crate::pool::Pool;
3
pub use crate::pool::RequestSender;
4
use lyquor_primitives::NodeID;
5
use lyquor_tls::TlsConfig;
6
use std::collections::HashMap;
7
use std::sync::Arc;
8
use thiserror::Error;
9
use tokio_util::sync::CancellationToken;
10
11
#[derive(Debug, Error)]
12
pub enum HubError {
13
    #[error("failed to bind listener on {addr}: {source}")]
14
    BindAddr {
15
        addr: String,
16
        #[source]
17
        source: Box<dyn std::error::Error + Send + Sync>,
18
    },
19
    #[error("peer not exist")]
20
    PeerNotExist,
21
    #[error("TLS error: {0}")]
22
    TlsError(#[from] lyquor_tls::TlsError),
23
    #[error("unknown error")]
24
    Unknown,
25
}
26
27
type Result<T> = std::result::Result<T, HubError>;
28
29
pub struct Hub {
30
    id: NodeID,
31
    peers: HashMap<NodeID, crate::pool::Pool>,
32
    listener: crate::listener::Listener,
33
    shutdown: CancellationToken,
34
    runtime: tokio::runtime::Handle,
35
    spawner: crate::event::AsyncSpawner,
36
    tls_config: TlsConfig,
37
}
38
39
impl std::fmt::Debug for Hub {
40
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41
0
        f.debug_struct("Hub").field("id", &self.id).finish_non_exhaustive()
42
0
    }
43
}
44
45
impl Hub {
46
    #[tracing::instrument(level = "trace", skip(tls_config))]
47
41
    fn new(id: NodeID, listen_addr: String, tls_config: TlsConfig, shutdown: CancellationToken) -> Result<Self> {
48
41
        let runtime = tokio::runtime::Handle::current(); // TODO: use our runtime
49
41
        let server_config = tls_config.server_config().map_err(HubError::TlsError)
?0
;
50
41
        let listener =
51
41
            crate::listener::Listener::new_with_shutdown(id, server_config, listen_addr, shutdown.child_token());
52
41
        Ok(Hub {
53
41
            id,
54
41
            peers: HashMap::new(),
55
41
            listener,
56
41
            shutdown,
57
41
            runtime: runtime.clone(),
58
41
            spawner: crate::event::AsyncSpawner { handle: runtime },
59
41
            tls_config,
60
41
        })
61
41
    }
62
63
36
    pub fn spawner(&self) -> crate::event::AsyncSpawner {
64
36
        self.spawner.clone()
65
36
    }
66
67
36
    pub fn get_id(&self) -> NodeID {
68
36
        self.id
69
36
    }
70
71
36
    pub fn signing_key(&self) -> Arc<lyquor_tls::NodeSigningKey> {
72
36
        self.tls_config.signing_key()
73
36
    }
74
75
41
    pub async fn start(&mut self) -> Result<()> {
76
41
        match self.listener.start().await {
77
0
            Err(source) => Err(HubError::BindAddr {
78
0
                addr: self.listener.addr().to_string(),
79
0
                source,
80
0
            }),
81
41
            Ok(_) => Ok(()),
82
        }
83
41
    }
84
85
    #[tracing::instrument(level = "trace", ret, err)]
86
62
    pub async fn add_peer(&mut self, id: NodeID, addr: String) -> Result<()> {
87
        tracing::info!("add peer for {:?}: {:?} -> {:?}", self.id, id, addr);
88
        self.listener.add_peer(id).await;
89
90
        let config = crate::pool::PoolConfig::builder()
91
            .tls_config(self.tls_config.client_config().unwrap())
92
            .addr(addr)
93
            .node_id(id)
94
            .build();
95
        let pool = Pool::new_with_shutdown(config, self.shutdown.child_token());
96
        let _ = pool.start(self.runtime.clone()).await;
97
98
        self.peers.insert(id, pool);
99
        Ok(())
100
62
    }
101
102
    #[tracing::instrument(level = "trace", ret, err(level = "debug"))]
103
0
    pub async fn remove_peer(&mut self, id: NodeID) -> Result<()> {
104
        self.peers.remove(&id);
105
        Ok(())
106
0
    }
107
108
    #[tracing::instrument(level = "trace", err(level = "debug"))]
109
109
    pub fn outbound(&self, id: NodeID) -> Result<RequestSender> {
110
        let pool = self.peers.get(&id).ok_or_else(|| HubError::PeerNotExist)?;
111
        Ok(RequestSender::new(self.id, pool))
112
109
    }
113
114
    #[tracing::instrument(level = "trace", err(level = "debug"))]
115
60
    pub fn inbound(&self, id: NodeID) -> Result<RequestReceiver> {
116
        self.listener.inbound(id).map_err(|_e| HubError::Unknown)
117
60
    }
118
}
119
120
impl Drop for Hub {
121
41
    fn drop(&mut self) {
122
41
        self.shutdown.cancel();
123
41
        self.listener.stop();
124
41
    }
125
}
126
127
pub struct HubBuilder {
128
    tls_config: TlsConfig,
129
    listen_addr: String,
130
    shutdown: CancellationToken,
131
}
132
133
impl HubBuilder {
134
41
    pub fn new(tls_config: TlsConfig, shutdown: CancellationToken) -> Self {
135
41
        HubBuilder {
136
41
            tls_config,
137
41
            listen_addr: "127.0.0.1:0".to_string(),
138
41
            shutdown,
139
41
        }
140
41
    }
141
142
41
    pub fn listen_addr(mut self, addr: String) -> Self {
143
41
        self.listen_addr = addr;
144
41
        self
145
41
    }
146
147
41
    pub fn build(self) -> Result<Hub> {
148
41
        Hub::new(
149
41
            self.tls_config.node_id(),
150
41
            self.listen_addr,
151
41
            self.tls_config,
152
41
            self.shutdown,
153
        )
154
41
    }
155
}
156
157
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::listener::ResponseBuilder;
    use bytes::Bytes;
    use http_body_util::BodyExt;
    use lyquor_test::test;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;
    use tokio_util::sync::CancellationToken;

    use lyquor_test::generate_unused_socket;

    /// Two hubs exchange a request/response. Cancelling the first hub's shutdown token must
    /// then stop both its outbound pool and its listener.
    #[test(tokio::test)]
    async fn test_hub() {
        let mut nodes = Vec::new();

        for i in 0u8..2 {
            let (node_id, tls_config) = lyquor_tls::generator::test_config(i);
            let addr = generate_unused_socket();
            let shutdown = CancellationToken::new();
            let mut hub = HubBuilder::new(tls_config, shutdown.clone())
                .listen_addr(addr.clone())
                .build()
                .unwrap();
            // Fail fast on a bind error rather than with a confusing failure later on.
            hub.start().await.expect("hub listener should start");
            nodes.push((node_id, addr, Arc::new(Mutex::new(hub)), shutdown));
        }

        // Fully connect the nodes: each hub registers every other node as a peer.
        for (id, _, hub, _) in &nodes {
            for (peer_id, peer_addr, _, _) in &nodes {
                if id == peer_id {
                    continue;
                }
                let mut hub = hub.lock().await;
                hub.add_peer(*peer_id, peer_addr.clone()).await.unwrap();
            }
        }

        let first = nodes[0].2.lock().await;
        let second = nodes[1].2.lock().await;

        let requester = first.outbound(nodes[1].0).unwrap();
        let responder = second.inbound(nodes[0].0).unwrap();

        // Responder on the second node: checks the request body and replies with a fixed body.
        tokio::spawn(async move {
            loop {
                let req = responder.recv().await.unwrap();
                let req_body = req.request.collect().await.unwrap().to_bytes();
                assert_eq!(req_body, Bytes::from("Hello, World!"));
                let res = ResponseBuilder::response("Hello, World!");
                req.response.send(res).unwrap();
            }
        });
        let response = requester
            .send_request(hyper::Request::new(http_body_util::Full::new(Bytes::from(
                "Hello, World!",
            ))))
            .await
            .expect("request to the second node should succeed");

        assert_eq!(response.status(), hyper::StatusCode::OK);
        let resp_body = response.collect().await.unwrap().to_bytes();
        assert_eq!(resp_body, Bytes::from("Hello, World!"));

        let pool = first.peers.get(&nodes[1].0).expect("pool should be registered");
        pool.wait_until_connected().await;

        nodes[0].3.cancel();

        // After cancellation the first node's pool must eventually report `NotConnected`.
        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                let req = hyper::Request::new(http_body_util::Full::new(Bytes::from_static(b"ping")));
                match requester.send_request(req).await {
                    Err(crate::pool::RequestError::NotConnected) => break,
                    _ => tokio::time::sleep(Duration::from_millis(20)).await,
                }
            }
        })
        .await
        .expect("pool should stop after shutdown");

        // ...and its listener must eventually refuse TCP connections.
        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if tokio::net::TcpStream::connect(&nodes[0].1).await.is_err() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
        })
        .await
        .expect("listener should stop after shutdown");
    }

    /// Runs the generated `Hello` RPC service across two hubs, covering small, empty,
    /// large (10–100 MiB) and error responses.
    #[test(tokio::test)]
    async fn test_cc() {
        use async_trait::async_trait;
        use hello::hello_client::HelloClient;
        use hello::hello_server::{Hello, HelloServer};
        use hello::{Command, Empty, HelloMsg, RespMsg};
        use rand::Rng;
        use rand::RngExt;

        use lyquor_net_rpc::TowerService;
        pub mod hello {
            lyquor_net_rpc::include_proto!("lyquor.hello");
        }

        /// Builds a `HelloMsg` with fixed scalar fields, carrying `pld` as its payload.
        fn gen_hellomsg(pld: Bytes) -> HelloMsg {
            HelloMsg {
                v_int32: 1,
                v_int64: 2,
                v_uint32: 3,
                v_uint64: 4,
                v_sint32: 5,
                v_sint64: 6,
                v_fixed32: 7,
                v_fixed64: 8,
                v_sfixed32: 9,
                v_sfixed64: 10,
                v_float: 11.0,
                v_double: 12.0,
                v_bool: true,
                v_string: "HelloMsg".to_string(),
                v_bytes: [1, 2, 3].to_vec(),
                cmd: Command::Pong.into(),
                len: pld.len() as u32,
                payload: pld.to_vec(),
            }
        }

        #[derive(Debug, Default)]
        pub struct HelloImpl {}

        #[async_trait]
        impl Hello for HelloImpl {
            async fn hello(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<HelloMsg>, String> {
                let (_, msg) = request.into_parts();
                println!("hello got a request: {:?}", msg);

                let mut rng = rand::rng();
                let mut pld = vec![0u8; rng.random_range(100..=1000)];
                rng.fill_bytes(&mut pld);

                Ok(hyper::Response::new(gen_hellomsg(pld.into())))
            }

            async fn hello_empty(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<Empty>, String> {
                let (_, msg) = request.into_parts();
                println!("hello_empty got a request: {:?}", msg);
                let reply = Empty {};
                Ok(hyper::Response::new(reply))
            }

            async fn hello_resp(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
                let (_, msg) = request.into_parts();
                println!("hello_resp got a request: {:?}", msg);

                let mut rng = rand::rng();
                let mut pld = vec![0u8; rng.random_range(10..=20)];
                rng.fill_bytes(&mut pld);

                let reply = RespMsg {
                    cmd: Command::Pong.into(),
                    len: pld.len() as u32,
                    payload: pld,
                };
                Ok(hyper::Response::new(reply))
            }

            async fn hello_long(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
                let (_, msg) = request.into_parts();
                println!("hello_long got a request: {:?}", msg);

                let mut rng = rand::rng();
                // 10M - 100M payload
                let mut pld = vec![0u8; rng.random_range(10 * 1024 * 1024..=100 * 1024 * 1024)];
                rng.fill_bytes(&mut pld);

                let reply = RespMsg {
                    cmd: Command::Pong.into(),
                    len: pld.len() as u32,
                    payload: pld,
                };
                Ok(hyper::Response::new(reply))
            }

            async fn hello_error(
                &self, _: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
                Err("Error!".to_string())
            }
        }

        let mut nodes = Vec::new();

        for i in 0u8..2 {
            let (node_id, tls_config) = lyquor_tls::generator::test_config(i);
            let addr = generate_unused_socket();
            let mut hub = HubBuilder::new(tls_config, CancellationToken::new())
                .listen_addr(addr.clone())
                .build()
                .unwrap();
            hub.start().await.expect("hub listener should start");
            nodes.push((node_id, addr, Arc::new(Mutex::new(hub))));
        }

        for (id, _, hub) in &nodes {
            for (peer_id, peer_addr, _) in &nodes {
                if id == peer_id {
                    continue;
                }
                let mut hub = hub.lock().await;
                hub.add_peer(*peer_id, peer_addr.clone()).await.unwrap();
            }
        }

        let first = nodes[0].2.lock().await;
        let second = nodes[1].2.lock().await;

        let requester = first.outbound(nodes[1].0).unwrap();
        let responder = second.inbound(nodes[0].0).unwrap();

        // Serve RPCs arriving at the second node with the `HelloImpl` service.
        let mut hello_srv = HelloServer::new(HelloImpl::default());
        tokio::spawn(async move {
            loop {
                let req = responder.recv().await.unwrap();
                println!("{:?}", req.request);
                assert_eq!(req.request.headers().get("lyquor-pkt-type").unwrap(), "RPC");

                let res = hello_srv.call(req.request).await.unwrap();
                req.response.send(res).unwrap();
            }
        });

        let msg = gen_hellomsg([0].to_vec().into());

        let mut client = HelloClient::new(requester);
        let resp = client.hello(msg.clone()).await;
        let response = resp.unwrap();
        assert_eq!(response.len, response.payload.len() as u32);
        println!("hello RESP: {:?}", response);

        let resp = client.hello_empty(msg.clone()).await;
        let response = resp.unwrap();
        println!("hello_empty RESP: {:?}", response);

        let resp = client.hello_resp(msg.clone()).await;
        let response = resp.unwrap();
        assert_eq!(response.len, response.payload.len() as u32);
        println!("hello_resp RESP: {:?}", response);

        let resp = client.hello_long(msg.clone()).await;
        let response = resp.unwrap();
        assert_eq!(response.len, response.payload.len() as u32);
        println!("hello_long RESP length {:?}KB", response.len / 1024);

        let resp = client.hello_error(msg).await;
        let Err(e) = resp else {
            panic!("hello_error must return an error");
        };
        println!("{:?}", e);
    }
}