]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv.rs
Rework structure
[mt_rudp.git] / src / recv.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     cell::OnceCell,
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::sync::{mpsc, watch, Mutex};
13
14 fn to_seqnum(seqnum: u16) -> usize {
15     (seqnum as usize) & (REL_BUFFER - 1)
16 }
17
18 type Result<T> = std::result::Result<T, Error>;
19
/// A split packet that is being reassembled from its chunks.
#[derive(Debug)]
struct Split {
    /// `Some(arrival time)` if the most recently received chunk arrived
    /// unreliably, `None` if it arrived inside a reliable packet. Consulted by
    /// the periodic split cleanup spawned in `RecvWorker::run`.
    timestamp: Option<Instant>,
    /// One slot per chunk; a slot can be filled only once, so duplicate
    /// chunks are ignored.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// Number of distinct chunks received so far.
    got: usize,
}
26
/// Receive-side state of a single channel.
struct RecvChan {
    /// Reliable packet payloads waiting for in-order delivery, indexed by
    /// `to_seqnum(seqnum)`; `RecvWorker::new` allocates `REL_BUFFER` slots.
    packets: Vec<Option<Vec<u8>>>, // char ** 😛
    /// Split packets being reassembled, keyed by the split's sequence number.
    splits: HashMap<u16, Split>,
    /// Next reliable sequence number to deliver on this channel.
    seqnum: u16,
    /// Index of this channel.
    num: u8,
}
33
/// Receiving half of a connection: reads datagrams from `udp_rx`, reorders
/// and reassembles them per channel, and hands the results to `pkt_tx`.
pub(crate) struct RecvWorker<R: UdpReceiver, S: UdpSender> {
    /// State shared with the sending side (outstanding acks per channel,
    /// remote peer id, spawned task set, and the `send` path).
    share: Arc<RudpShare<S>>,
    /// A change on this channel is treated as a local disconnect request.
    close: watch::Receiver<bool>,
    /// Per-channel receive state; also shared with the split cleanup task.
    chans: Arc<Vec<Mutex<RecvChan>>>,
    /// Destination for received packets and non-fatal errors.
    pkt_tx: mpsc::UnboundedSender<InPkt>,
    /// Source of raw datagrams.
    udp_rx: R,
}
41
42 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
43     pub fn new(
44         udp_rx: R,
45         share: Arc<RudpShare<S>>,
46         close: watch::Receiver<bool>,
47         pkt_tx: mpsc::UnboundedSender<InPkt>,
48     ) -> Self {
49         Self {
50             udp_rx,
51             share,
52             close,
53             pkt_tx,
54             chans: Arc::new(
55                 (0..NUM_CHANS as u8)
56                     .map(|num| {
57                         Mutex::new(RecvChan {
58                             num,
59                             packets: (0..REL_BUFFER).map(|_| None).collect(),
60                             seqnum: INIT_SEQNUM,
61                             splits: HashMap::new(),
62                         })
63                     })
64                     .collect(),
65             ),
66         }
67     }
68
69     pub async fn run(&self) {
70         use Error::*;
71
72         let cleanup_chans = Arc::clone(&self.chans);
73         let mut cleanup_close = self.close.clone();
74         self.share
75             .tasks
76             .lock()
77             .await
78             /*.build_task()
79             .name("cleanup_splits")*/
80             .spawn(async move {
81                 let timeout = Duration::from_secs(TIMEOUT);
82
83                 ticker!(timeout, cleanup_close, {
84                     for chan_mtx in cleanup_chans.iter() {
85                         let mut chan = chan_mtx.lock().await;
86                         chan.splits = chan
87                             .splits
88                             .drain_filter(
89                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
90                             )
91                             .collect();
92                     }
93                 });
94             });
95
96         let mut close = self.close.clone();
97         let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT));
98         tokio::pin!(timeout);
99
100         loop {
101             if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) {
102                 // TODO: figure out whether this is a good idea
103                 if let RemoteDisco(to) = e {
104                     self.pkt_tx.send(Err(RemoteDisco(to))).ok();
105                 }
106
107                 #[allow(clippy::single_match)]
108                 match e {
109                                         // anon5's mt notifies the peer on timeout, C++ MT does not
110                                         LocalDisco /*| RemoteDisco(true)*/ => drop(
111                                                 self.share
112                                                         .send(
113                                                                 PktType::Ctl,
114                                                                 Pkt {
115                                                                         unrel: true,
116                                                                         chan: 0,
117                                                                         data: &[CtlType::Disco as u8],
118                                                                 },
119                                                         )
120                                                         .await
121                                                         .ok(),
122                                         ),
123                                         _ => {}
124                                 }
125
126                 break;
127             }
128         }
129     }
130
131     async fn recv_pkt(
132         &self,
133         close: &mut watch::Receiver<bool>,
134         timeout: Pin<&mut tokio::time::Sleep>,
135     ) -> Result<()> {
136         use Error::*;
137
138         let mut cursor = io::Cursor::new(tokio::select! {
139             pkt = self.udp_rx.recv() => pkt?,
140             _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)),
141             _ = close.changed() => return Err(LocalDisco),
142         });
143
144         timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
145
146         let proto_id = cursor.read_u32::<BigEndian>()?;
147         if proto_id != PROTO_ID {
148             return Err(InvalidProtoId(proto_id));
149         }
150
151         let _peer_id = cursor.read_u16::<BigEndian>()?;
152
153         let n_chan = cursor.read_u8()?;
154         let mut chan = self
155             .chans
156             .get(n_chan as usize)
157             .ok_or(InvalidChannel(n_chan))?
158             .lock()
159             .await;
160
161         self.process_pkt(cursor, true, &mut chan).await
162     }
163
164     #[async_recursion]
165     async fn process_pkt(
166         &self,
167         mut cursor: io::Cursor<Vec<u8>>,
168         unrel: bool,
169         chan: &mut RecvChan,
170     ) -> Result<()> {
171         use Error::*;
172
173         match cursor.read_u8()?.try_into()? {
174             PktType::Ctl => match cursor.read_u8()?.try_into()? {
175                 CtlType::Ack => {
176                     println!("Ack");
177
178                     let seqnum = cursor.read_u16::<BigEndian>()?;
179                     if let Some(ack) = self.share.chans[chan.num as usize]
180                         .lock()
181                         .await
182                         .acks
183                         .remove(&seqnum)
184                     {
185                         ack.tx.send(true).ok();
186                     }
187                 }
188                 CtlType::SetPeerID => {
189                     println!("SetPeerID");
190
191                     let mut id = self.share.remote_id.write().await;
192
193                     if *id != PeerID::Nil as u16 {
194                         return Err(PeerIDAlreadySet);
195                     }
196
197                     *id = cursor.read_u16::<BigEndian>()?;
198                 }
199                 CtlType::Ping => {
200                     println!("Ping");
201                 }
202                 CtlType::Disco => {
203                     println!("Disco");
204                     return Err(RemoteDisco(false));
205                 }
206             },
207             PktType::Orig => {
208                 println!("Orig");
209
210                 self.pkt_tx.send(Ok(Pkt {
211                     chan: chan.num,
212                     unrel,
213                     data: cursor.remaining_slice().into(),
214                 }))?;
215             }
216             PktType::Split => {
217                 println!("Split");
218
219                 let seqnum = cursor.read_u16::<BigEndian>()?;
220                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
221                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
222
223                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
224                     got: 0,
225                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
226                     timestamp: None,
227                 });
228
229                 if split.chunks.len() != chunk_count {
230                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
231                 }
232
233                 if split
234                     .chunks
235                     .get(chunk_index)
236                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
237                     .set(cursor.remaining_slice().into())
238                     .is_ok()
239                 {
240                     split.got += 1;
241                 }
242
243                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
244
245                 if split.got == chunk_count {
246                     self.pkt_tx.send(Ok(Pkt {
247                         chan: chan.num,
248                         unrel,
249                         data: split
250                             .chunks
251                             .iter()
252                             .flat_map(|chunk| chunk.get().unwrap().iter())
253                             .copied()
254                             .collect(),
255                     }))?;
256
257                     chan.splits.remove(&seqnum);
258                 }
259             }
260             PktType::Rel => {
261                 println!("Rel");
262
263                 let seqnum = cursor.read_u16::<BigEndian>()?;
264                 chan.packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
265
266                 let mut ack_data = Vec::with_capacity(3);
267                 ack_data.write_u8(CtlType::Ack as u8)?;
268                 ack_data.write_u16::<BigEndian>(seqnum)?;
269
270                 self.share
271                     .send(
272                         PktType::Ctl,
273                         Pkt {
274                             unrel: true,
275                             chan: chan.num,
276                             data: &ack_data,
277                         },
278                     )
279                     .await?;
280
281                 fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
282                     chan.packets[to_seqnum(chan.seqnum)].take()
283                 }
284
285                 while let Some(pkt) = next_pkt(chan) {
286                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
287                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
288                 }
289             }
290         }
291
292         Ok(())
293     }
294
295     fn handle(&self, res: Result<()>) -> Result<()> {
296         use Error::*;
297
298         match res {
299             Ok(v) => Ok(v),
300             Err(RemoteDisco(to)) => Err(RemoteDisco(to)),
301             Err(LocalDisco) => Err(LocalDisco),
302             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
303         }
304     }
305 }