Coverage Report

Created: 2025-12-04 08:57

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/runner/work/lyquor/lyquor/net/src/hub.rs
Line
Count
Source
1
pub use crate::listener::{RequestHandler, RequestReceiver};
2
use crate::pool::Pool;
3
pub use crate::pool::RequestSender;
4
use lyquor_primitives::NodeID;
5
use lyquor_tls::TlsConfig;
6
use std::collections::HashMap;
7
use std::sync::Arc;
8
use thiserror::Error;
9
10
#[derive(Debug, Error)]
11
pub enum HubError {
12
    #[error("invalid address")]
13
    InvalidAddr,
14
    #[error("peer not exist")]
15
    PeerNotExist,
16
    #[error("TLS error: {0}")]
17
    TlsError(#[from] lyquor_tls::TlsError),
18
    #[error("unknown error")]
19
    Unknown,
20
}
21
22
type Result<T> = std::result::Result<T, HubError>;
23
24
pub struct Hub {
25
    id: NodeID,
26
    peers: HashMap<NodeID, crate::pool::Pool>,
27
    listener: crate::listener::Listener,
28
    runtime: tokio::runtime::Handle,
29
    spawner: crate::event::AsyncSpawner,
30
    tls_config: TlsConfig,
31
}
32
33
impl std::fmt::Debug for Hub {
34
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35
0
        f.debug_struct("Hub").field("id", &self.id).finish_non_exhaustive()
36
0
    }
37
}
38
39
impl Hub {
40
    #[tracing::instrument(level = "trace", skip(tls_config))]
41
30
    fn new(id: NodeID, listen_addr: String, tls_config: TlsConfig) -> Result<Self> {
42
30
        let runtime = tokio::runtime::Handle::current(); // TODO: use our runtime
43
30
        let server_config = tls_config.server_config().map_err(|e| HubError::TlsError(
e0
))
?0
;
44
30
        let listener = crate::listener::Listener::new(id, server_config, listen_addr);
45
30
        Ok(Hub {
46
30
            id,
47
30
            peers: HashMap::new(),
48
30
            listener: listener,
49
30
            runtime: runtime.clone(),
50
30
            spawner: crate::event::AsyncSpawner { handle: runtime },
51
30
            tls_config: tls_config,
52
30
        })
53
30
    }
54
55
26
    pub fn spawner(&self) -> crate::event::AsyncSpawner {
56
26
        self.spawner.clone()
57
26
    }
58
59
26
    pub fn get_id(&self) -> NodeID {
60
26
        self.id.clone()
61
26
    }
62
63
26
    pub fn signing_key(&self) -> Arc<lyquor_tls::NodeSigningKey> {
64
26
        self.tls_config.signing_key()
65
26
    }
66
67
30
    pub async fn start(&mut self) -> Result<()> {
68
30
        match self.listener.start().await {
69
            // listener.start() only returns bind error
70
0
            Err(_) => Err(HubError::InvalidAddr),
71
30
            Ok(_) => Ok(()),
72
        }
73
30
    }
74
75
    #[tracing::instrument(level = "trace", ret, err)]
76
62
    pub async fn add_peer(&mut self, id: NodeID, addr: String) -> Result<()> {
77
        tracing::info!("add peer for {:?}: {:?} -> {:?}", self.id, id, addr);
78
        self.listener.add_peer(id.clone()).await;
79
80
        let config = crate::pool::PoolConfig::builder()
81
            .tls_config(self.tls_config.client_config().unwrap())
82
            .addr(addr)
83
            .build();
84
        let pool = Pool::new(config);
85
        let _ = pool.start(self.runtime.clone()).await;
86
87
        // TODO: make sure the pool connected to the node matches the id.
88
        self.peers.insert(id, pool);
89
        Ok(())
90
62
    }
91
92
    #[tracing::instrument(level = "trace", ret, err(level = "debug"))]
93
0
    pub async fn remove_peer(&mut self, id: NodeID) -> Result<()> {
94
        self.peers.remove(&id);
95
        Ok(())
96
0
    }
97
98
    #[tracing::instrument(level = "trace", err(level = "debug"))]
99
112
    pub fn outbound(&self, id: NodeID) -> Result<RequestSender> {
100
        let pool = self.peers.get(&id).ok_or_else(|| HubError::PeerNotExist)?;
101
        Ok(RequestSender::new(self.id.clone(), pool))
102
112
    }
103
104
    #[tracing::instrument(level = "trace", err(level = "debug"))]
105
60
    pub fn inbound(&self, id: NodeID) -> Result<RequestReceiver> {
106
        self.listener.inbound(id).map_err(|_e| HubError::Unknown)
107
60
    }
108
}
109
110
impl Drop for Hub {
111
30
    fn drop(&mut self) {
112
30
        self.listener.stop();
113
30
    }
114
}
115
116
pub struct HubBuilder {
117
    tls_config: TlsConfig,
118
    listen_addr: String,
119
}
120
121
impl HubBuilder {
122
30
    pub fn new(tls_config: TlsConfig) -> Self {
123
30
        HubBuilder {
124
30
            tls_config: tls_config,
125
30
            listen_addr: "127.0.0.1:0".to_string(),
126
30
        }
127
30
    }
128
129
30
    pub fn listen_addr(mut self, addr: String) -> Self {
130
30
        self.listen_addr = addr;
131
30
        self
132
30
    }
133
134
30
    pub fn build(self) -> Result<Hub> {
135
30
        Hub::new(self.tls_config.node_id(), self.listen_addr, self.tls_config)
136
30
    }
137
}
138
139
#[cfg(test)]
140
pub(crate) mod tests {
141
    use super::*;
142
    use crate::listener::ResponseBuilder;
143
    use bytes::Bytes;
144
    use http_body_util::BodyExt;
145
    use lyquor_test::test;
146
    use std::sync::Arc;
147
    use tokio::sync::Mutex;
148
149
4
    fn generate_unused_socket() -> String {
150
4
        std::net::TcpListener::bind("127.0.0.1:0")
151
4
            .unwrap()
152
4
            .local_addr()
153
4
            .unwrap()
154
4
            .to_string()
155
4
    }
156
157
    #[test(tokio::test)]
158
    async fn test_hub() {
159
        let mut nodes = Vec::new();
160
161
        for i in 0u8..2 {
162
            let (node_id, tls_config) = lyquor_tls::generator::test_config(i);
163
            let addr = generate_unused_socket();
164
            let mut hub = HubBuilder::new(tls_config).listen_addr(addr.clone()).build().unwrap();
165
            let _ = hub.start().await;
166
            nodes.push((node_id, addr, Arc::new(Mutex::new(hub))));
167
        }
168
169
        for (id, _, hub) in &nodes {
170
            for (peer_id, peer_addr, _) in &nodes {
171
                if id == peer_id {
172
                    continue;
173
                }
174
                let mut hub = hub.lock().await;
175
                hub.add_peer(peer_id.clone(), peer_addr.clone()).await.unwrap();
176
            }
177
        }
178
179
        let first = nodes[0].2.lock().await;
180
        let second = nodes[1].2.lock().await;
181
182
        let requester = first.outbound(nodes[1].0.clone()).unwrap();
183
        let responder = second.inbound(nodes[0].0.clone()).unwrap();
184
185
1
        tokio::spawn(async move {
186
            loop {
187
2
                let 
req1
= responder.recv().await.
unwrap1
();
188
1
                let req_body = req.request.collect().await.unwrap().to_bytes();
189
1
                assert_eq!(req_body, Bytes::from("Hello, World!"));
190
1
                let res = ResponseBuilder::response("Hello, World!");
191
1
                req.response.send(res).unwrap();
192
            }
193
        });
194
        let response = requester
195
            .send_request(hyper::Request::new(http_body_util::Full::new(Bytes::from(
196
                "Hello, World!",
197
            ))))
198
            .await;
199
        dbg!(&response);
200
        assert!(response.is_ok());
201
202
        let response = response.unwrap();
203
204
        assert_eq!(response.status(), hyper::StatusCode::OK);
205
        let resp_body = response.collect().await.unwrap().to_bytes();
206
        assert_eq!(resp_body, Bytes::from("Hello, World!"));
207
    }
208
209
    #[test(tokio::test)]
210
    async fn test_cc() {
211
        use async_trait::async_trait;
212
        use hello::hello_client::HelloClient;
213
        use hello::hello_server::{Hello, HelloServer};
214
        use hello::{Command, Empty, HelloMsg, RespMsg};
215
        use rand::Rng;
216
        use rand::RngCore;
217
218
        use lyquor_net_rpc::TowerService;
219
        pub mod hello {
220
            lyquor_net_rpc::include_proto!("lyquor.hello");
221
        }
222
223
2
        fn gen_hellomsg(pld: Bytes) -> HelloMsg {
224
2
            HelloMsg {
225
2
                v_int32: 1,
226
2
                v_int64: 2,
227
2
                v_uint32: 3,
228
2
                v_uint64: 4,
229
2
                v_sint32: 5,
230
2
                v_sint64: 6,
231
2
                v_fixed32: 7,
232
2
                v_fixed64: 8,
233
2
                v_sfixed32: 9,
234
2
                v_sfixed64: 10,
235
2
                v_float: 11.0,
236
2
                v_double: 12.0,
237
2
                v_bool: true,
238
2
                v_string: "HelloMsg".to_string(),
239
2
                v_bytes: [1, 2, 3].to_vec(),
240
2
                cmd: Command::Pong.into(),
241
2
                len: pld.len() as u32,
242
2
                payload: pld.to_vec(),
243
2
            }
244
2
        }
245
246
        #[derive(Debug, Default)]
247
        pub struct HelloImpl {}
248
249
        #[async_trait]
250
        impl Hello for HelloImpl {
251
            async fn hello(
252
                &self, request: hyper::Request<HelloMsg>,
253
1
            ) -> std::result::Result<hyper::Response<HelloMsg>, String> {
254
                let (_, msg) = request.into_parts();
255
                println!("hello got a request: {:?}", msg);
256
257
                let mut rng = rand::rng();
258
                let mut pld = vec![0u8; rand::rng().random_range(100..=1000)];
259
                rng.fill_bytes(&mut pld);
260
261
                Ok(hyper::Response::new(gen_hellomsg(pld.into())))
262
1
            }
263
264
            async fn hello_empty(
265
                &self, request: hyper::Request<HelloMsg>,
266
1
            ) -> std::result::Result<hyper::Response<Empty>, String> {
267
                let (_, msg) = request.into_parts();
268
                println!("hello_empty got a request: {:?}", msg);
269
                let reply = Empty {};
270
                Ok(hyper::Response::new(reply))
271
1
            }
272
273
            async fn hello_resp(
274
                &self, request: hyper::Request<HelloMsg>,
275
1
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
276
                let (_, msg) = request.into_parts();
277
                println!("hello_resp got a request: {:?}", msg);
278
279
                let mut rng = rand::rng();
280
                let mut pld = vec![0u8; rand::rng().random_range(10..=20)];
281
                rng.fill_bytes(&mut pld);
282
283
                let reply = RespMsg {
284
                    cmd: Command::Pong.into(),
285
                    len: pld.len() as u32,
286
                    payload: pld,
287
                };
288
                Ok(hyper::Response::new(reply))
289
1
            }
290
291
            async fn hello_long(
292
                &self, request: hyper::Request<HelloMsg>,
293
1
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
294
                let (_, msg) = request.into_parts();
295
                println!("hello_long got a request: {:?}", msg);
296
297
                let mut rng = rand::rng();
298
                // 10M - 100M payload
299
                let mut pld = vec![0u8; rand::rng().random_range(10 * 1024 * 1024..=100 * 1024 * 1024)];
300
                rng.fill_bytes(&mut pld);
301
302
                let reply = RespMsg {
303
                    cmd: Command::Pong.into(),
304
                    len: pld.len() as u32,
305
                    payload: pld,
306
                };
307
                Ok(hyper::Response::new(reply))
308
1
            }
309
310
            async fn hello_error(
311
                &self, _: hyper::Request<HelloMsg>,
312
1
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
313
                Err("Error!".to_string())
314
1
            }
315
        }
316
317
        let mut nodes = Vec::new();
318
319
        for i in 0u8..2 {
320
            let (node_id, tls_config) = lyquor_tls::generator::test_config(i);
321
            let addr = generate_unused_socket();
322
            let mut hub = HubBuilder::new(tls_config).listen_addr(addr.clone()).build().unwrap();
323
            let _ = hub.start().await;
324
            nodes.push((node_id, addr, Arc::new(Mutex::new(hub))));
325
        }
326
327
        for (id, _, hub) in &nodes {
328
            for (peer_id, peer_addr, _) in &nodes {
329
                if id == peer_id {
330
                    continue;
331
                }
332
                let mut hub = hub.lock().await;
333
                hub.add_peer(peer_id.clone(), peer_addr.clone()).await.unwrap();
334
            }
335
        }
336
337
        let first = nodes[0].2.lock().await;
338
        let second = nodes[1].2.lock().await;
339
340
        let requester = first.outbound(nodes[1].0.clone()).unwrap();
341
        let responder = second.inbound(nodes[0].0.clone()).unwrap();
342
343
        let mut hello_srv = HelloServer::new(HelloImpl::default());
344
1
        tokio::spawn(async move {
345
            loop {
346
6
                let 
req5
= responder.recv().await.
unwrap5
();
347
5
                println!("{:?}", req.request);
348
5
                assert_eq!(req.request.headers().get("lyquor-pkt-type").unwrap(), "RPC");
349
350
5
                let res = hello_srv.call(req.request).await.unwrap();
351
5
                req.response.send(res).unwrap();
352
            }
353
        });
354
355
        let msg = gen_hellomsg([0].to_vec().into());
356
357
        let mut client = HelloClient::new(requester);
358
        let resp = client.hello(msg.clone()).await;
359
        let response = resp.unwrap();
360
        assert_eq!(response.len, response.payload.len() as u32);
361
        println!("hello RESP: {:?}", response);
362
363
        let resp = client.hello_empty(msg.clone()).await;
364
        let response = resp.unwrap();
365
        println!("hello_empty RESP: {:?}", response);
366
367
        let resp = client.hello_resp(msg.clone()).await;
368
        let response = resp.unwrap();
369
        assert_eq!(response.len, response.payload.len() as u32);
370
        println!("hello_resp RESP: {:?}", response);
371
372
        let resp = client.hello_long(msg.clone()).await;
373
        let response = resp.unwrap();
374
        assert_eq!(response.len, response.payload.len() as u32);
375
        println!("hello_resp RESP length {:?}KB", response.len / 1023);
376
377
        let resp = client.hello_error(msg.clone()).await;
378
        match resp {
379
            Ok(_) => {
380
                // Shouldn't see this
381
                assert_eq!(0, 1);
382
            }
383
            Err(e) => {
384
                println!("{:?}", e);
385
            }
386
        };
387
    }
388
}