Coverage Report

Created: 2026-02-04 05:42

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/home/runner/work/lyquor/lyquor/net/src/hub.rs
Line
Count
Source
1
pub use crate::listener::{RequestHandler, RequestReceiver};
2
use crate::pool::Pool;
3
pub use crate::pool::RequestSender;
4
use lyquor_primitives::NodeID;
5
use lyquor_tls::TlsConfig;
6
use std::collections::HashMap;
7
use std::sync::Arc;
8
use thiserror::Error;
9
10
#[derive(Debug, Error)]
11
pub enum HubError {
12
    #[error("invalid address")]
13
    InvalidAddr,
14
    #[error("peer not exist")]
15
    PeerNotExist,
16
    #[error("TLS error: {0}")]
17
    TlsError(#[from] lyquor_tls::TlsError),
18
    #[error("unknown error")]
19
    Unknown,
20
}
21
22
type Result<T> = std::result::Result<T, HubError>;
23
24
pub struct Hub {
25
    id: NodeID,
26
    peers: HashMap<NodeID, crate::pool::Pool>,
27
    listener: crate::listener::Listener,
28
    runtime: tokio::runtime::Handle,
29
    spawner: crate::event::AsyncSpawner,
30
    tls_config: TlsConfig,
31
}
32
33
impl std::fmt::Debug for Hub {
34
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35
0
        f.debug_struct("Hub").field("id", &self.id).finish_non_exhaustive()
36
0
    }
37
}
38
39
impl Hub {
40
    #[tracing::instrument(level = "trace", skip(tls_config))]
41
35
    fn new(id: NodeID, listen_addr: String, tls_config: TlsConfig) -> Result<Self> {
42
35
        let runtime = tokio::runtime::Handle::current(); // TODO: use our runtime
43
35
        let server_config = tls_config.server_config().map_err(|e| HubError::TlsError(
e0
))
?0
;
44
35
        let listener = crate::listener::Listener::new(id, server_config, listen_addr);
45
35
        Ok(Hub {
46
35
            id,
47
35
            peers: HashMap::new(),
48
35
            listener: listener,
49
35
            runtime: runtime.clone(),
50
35
            spawner: crate::event::AsyncSpawner { handle: runtime },
51
35
            tls_config: tls_config,
52
35
        })
53
35
    }
54
55
30
    pub fn spawner(&self) -> crate::event::AsyncSpawner {
56
30
        self.spawner.clone()
57
30
    }
58
59
30
    pub fn get_id(&self) -> NodeID {
60
30
        self.id.clone()
61
30
    }
62
63
30
    pub fn signing_key(&self) -> Arc<lyquor_tls::NodeSigningKey> {
64
30
        self.tls_config.signing_key()
65
30
    }
66
67
35
    pub async fn start(&mut self) -> Result<()> {
68
35
        match self.listener.start().await {
69
            // listener.start() only returns bind error
70
0
            Err(_) => Err(HubError::InvalidAddr),
71
35
            Ok(_) => Ok(()),
72
        }
73
35
    }
74
75
    #[tracing::instrument(level = "trace", ret, err)]
76
62
    pub async fn add_peer(&mut self, id: NodeID, addr: String) -> Result<()> {
77
        tracing::info!("add peer for {:?}: {:?} -> {:?}", self.id, id, addr);
78
        self.listener.add_peer(id.clone()).await;
79
80
        let config = crate::pool::PoolConfig::builder()
81
            .tls_config(self.tls_config.client_config().unwrap())
82
            .addr(addr)
83
            .node_id(id)
84
            .build();
85
        let pool = Pool::new(config);
86
        let _ = pool.start(self.runtime.clone()).await;
87
88
        self.peers.insert(id, pool);
89
        Ok(())
90
62
    }
91
92
    #[tracing::instrument(level = "trace", ret, err(level = "debug"))]
93
0
    pub async fn remove_peer(&mut self, id: NodeID) -> Result<()> {
94
        self.peers.remove(&id);
95
        Ok(())
96
0
    }
97
98
    #[tracing::instrument(level = "trace", err(level = "debug"))]
99
112
    pub fn outbound(&self, id: NodeID) -> Result<RequestSender> {
100
        let pool = self.peers.get(&id).ok_or_else(|| HubError::PeerNotExist)?;
101
        Ok(RequestSender::new(self.id.clone(), pool))
102
112
    }
103
104
    #[tracing::instrument(level = "trace", err(level = "debug"))]
105
60
    pub fn inbound(&self, id: NodeID) -> Result<RequestReceiver> {
106
        self.listener.inbound(id).map_err(|_e| HubError::Unknown)
107
60
    }
108
}
109
110
impl Drop for Hub {
111
35
    fn drop(&mut self) {
112
35
        self.listener.stop();
113
35
    }
114
}
115
116
pub struct HubBuilder {
117
    tls_config: TlsConfig,
118
    listen_addr: String,
119
}
120
121
impl HubBuilder {
122
35
    pub fn new(tls_config: TlsConfig) -> Self {
123
35
        HubBuilder {
124
35
            tls_config: tls_config,
125
35
            listen_addr: "127.0.0.1:0".to_string(),
126
35
        }
127
35
    }
128
129
35
    pub fn listen_addr(mut self, addr: String) -> Self {
130
35
        self.listen_addr = addr;
131
35
        self
132
35
    }
133
134
35
    pub fn build(self) -> Result<Hub> {
135
35
        Hub::new(self.tls_config.node_id(), self.listen_addr, self.tls_config)
136
35
    }
137
}
138
139
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::listener::ResponseBuilder;
    use bytes::Bytes;
    use http_body_util::BodyExt;
    use lyquor_test::test;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use lyquor_test::generate_unused_socket;

    /// A started hub together with its node id and listen address.
    type TestNode = (NodeID, String, Arc<Mutex<Hub>>);

    /// Starts two hubs on fresh local sockets and registers each as a peer of the other.
    async fn connected_pair() -> Vec<TestNode> {
        let mut nodes: Vec<TestNode> = Vec::new();

        for i in 0u8..2 {
            let (node_id, tls_config) = lyquor_tls::generator::test_config(i);
            let addr = generate_unused_socket();
            let mut hub = HubBuilder::new(tls_config).listen_addr(addr.clone()).build().unwrap();
            // Fail here with a clear message rather than with a confusing connection error later.
            hub.start().await.expect("hub failed to start");
            nodes.push((node_id, addr, Arc::new(Mutex::new(hub))));
        }

        for (id, _, hub) in &nodes {
            for (peer_id, peer_addr, _) in &nodes {
                if id == peer_id {
                    continue;
                }
                let mut hub = hub.lock().await;
                hub.add_peer(peer_id.clone(), peer_addr.clone()).await.unwrap();
            }
        }

        nodes
    }

    #[test(tokio::test)]
    async fn test_hub() {
        let nodes = connected_pair().await;

        let first = nodes[0].2.lock().await;
        let second = nodes[1].2.lock().await;

        let requester = first.outbound(nodes[1].0.clone()).unwrap();
        let responder = second.inbound(nodes[0].0.clone()).unwrap();

        // Echo responder: checks the request body and answers with the same text.
        tokio::spawn(async move {
            loop {
                let req = responder.recv().await.unwrap();
                let req_body = req.request.collect().await.unwrap().to_bytes();
                assert_eq!(req_body, Bytes::from("Hello, World!"));
                let res = ResponseBuilder::response("Hello, World!");
                req.response.send(res).unwrap();
            }
        });

        let response = requester
            .send_request(hyper::Request::new(http_body_util::Full::new(Bytes::from(
                "Hello, World!",
            ))))
            .await
            .expect("request to peer failed");

        assert_eq!(response.status(), hyper::StatusCode::OK);
        let resp_body = response.collect().await.unwrap().to_bytes();
        assert_eq!(resp_body, Bytes::from("Hello, World!"));
    }

    #[test(tokio::test)]
    async fn test_cc() {
        use async_trait::async_trait;
        use hello::hello_client::HelloClient;
        use hello::hello_server::{Hello, HelloServer};
        use hello::{Command, Empty, HelloMsg, RespMsg};
        use rand::Rng;
        use rand::RngCore;

        use lyquor_net_rpc::TowerService;
        pub mod hello {
            lyquor_net_rpc::include_proto!("lyquor.hello");
        }

        /// Builds a `HelloMsg` with fixed scalar fields and `pld` as its payload.
        fn gen_hellomsg(pld: Bytes) -> HelloMsg {
            HelloMsg {
                v_int32: 1,
                v_int64: 2,
                v_uint32: 3,
                v_uint64: 4,
                v_sint32: 5,
                v_sint64: 6,
                v_fixed32: 7,
                v_fixed64: 8,
                v_sfixed32: 9,
                v_sfixed64: 10,
                v_float: 11.0,
                v_double: 12.0,
                v_bool: true,
                v_string: "HelloMsg".to_string(),
                v_bytes: [1, 2, 3].to_vec(),
                cmd: Command::Pong.into(),
                len: pld.len() as u32,
                payload: pld.to_vec(),
            }
        }

        #[derive(Debug, Default)]
        pub struct HelloImpl {}

        #[async_trait]
        impl Hello for HelloImpl {
            async fn hello(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<HelloMsg>, String> {
                let (_, msg) = request.into_parts();
                println!("hello got a request: {:?}", msg);

                let mut rng = rand::rng();
                let mut pld = vec![0u8; rng.random_range(100..=1000)];
                rng.fill_bytes(&mut pld);

                Ok(hyper::Response::new(gen_hellomsg(pld.into())))
            }

            async fn hello_empty(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<Empty>, String> {
                let (_, msg) = request.into_parts();
                println!("hello_empty got a request: {:?}", msg);
                let reply = Empty {};
                Ok(hyper::Response::new(reply))
            }

            async fn hello_resp(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
                let (_, msg) = request.into_parts();
                println!("hello_resp got a request: {:?}", msg);

                let mut rng = rand::rng();
                let mut pld = vec![0u8; rng.random_range(10..=20)];
                rng.fill_bytes(&mut pld);

                let reply = RespMsg {
                    cmd: Command::Pong.into(),
                    len: pld.len() as u32,
                    payload: pld,
                };
                Ok(hyper::Response::new(reply))
            }

            async fn hello_long(
                &self, request: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
                let (_, msg) = request.into_parts();
                println!("hello_long got a request: {:?}", msg);

                let mut rng = rand::rng();
                // 10M - 100M payload
                let mut pld = vec![0u8; rng.random_range(10 * 1024 * 1024..=100 * 1024 * 1024)];
                rng.fill_bytes(&mut pld);

                let reply = RespMsg {
                    cmd: Command::Pong.into(),
                    len: pld.len() as u32,
                    payload: pld,
                };
                Ok(hyper::Response::new(reply))
            }

            async fn hello_error(
                &self, _: hyper::Request<HelloMsg>,
            ) -> std::result::Result<hyper::Response<RespMsg>, String> {
                Err("Error!".to_string())
            }
        }

        let nodes = connected_pair().await;

        let first = nodes[0].2.lock().await;
        let second = nodes[1].2.lock().await;

        let requester = first.outbound(nodes[1].0.clone()).unwrap();
        let responder = second.inbound(nodes[0].0.clone()).unwrap();

        let mut hello_srv = HelloServer::new(HelloImpl::default());
        // Forward every inbound request (tagged as RPC) to the generated gRPC-style service.
        tokio::spawn(async move {
            loop {
                let req = responder.recv().await.unwrap();
                println!("{:?}", req.request);
                assert_eq!(req.request.headers().get("lyquor-pkt-type").unwrap(), "RPC");

                let res = hello_srv.call(req.request).await.unwrap();
                req.response.send(res).unwrap();
            }
        });

        let msg = gen_hellomsg([0].to_vec().into());

        let mut client = HelloClient::new(requester);
        let resp = client.hello(msg.clone()).await;
        let response = resp.unwrap();
        assert_eq!(response.len, response.payload.len() as u32);
        println!("hello RESP: {:?}", response);

        let resp = client.hello_empty(msg.clone()).await;
        let response = resp.unwrap();
        println!("hello_empty RESP: {:?}", response);

        let resp = client.hello_resp(msg.clone()).await;
        let response = resp.unwrap();
        assert_eq!(response.len, response.payload.len() as u32);
        println!("hello_resp RESP: {:?}", response);

        let resp = client.hello_long(msg.clone()).await;
        let response = resp.unwrap();
        assert_eq!(response.len, response.payload.len() as u32);
        println!("hello_long RESP length {:?}KB", response.len / 1024);

        // The handler always fails, so the client must surface an error.
        let err = client
            .hello_error(msg.clone())
            .await
            .expect_err("hello_error must return an error");
        println!("{:?}", err);
    }
}