git.lizzy.rs Git - mt_rudp.git/blob - src/recv.rs
Comment out println debugging
[mt_rudp.git] / src / recv.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     borrow::Cow,
6     cell::OnceCell,
7     collections::HashMap,
8     io,
9     pin::Pin,
10     sync::Arc,
11     time::{Duration, Instant},
12 };
13 use tokio::sync::{mpsc, watch, Mutex};
14
15 fn to_seqnum(seqnum: u16) -> usize {
16     (seqnum as usize) & (REL_BUFFER - 1)
17 }
18
19 type Result<T> = std::result::Result<T, Error>;
20
/// Reassembly state for one split packet, stored in a channel's `splits`
/// map under the split's sequence number.
#[derive(Debug)]
struct Split {
    /// `Some(arrival time)` if the most recent chunk arrived unreliably,
    /// `None` if it arrived inside a reliable packet. The periodic split
    /// cleanup reads this to decide which unreliable splits have gone stale.
    timestamp: Option<Instant>,
    /// One slot per announced chunk. A slot is filled by the first copy of
    /// that chunk; later duplicates fail to `set` and are ignored.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// Number of filled slots in `chunks`; the split is complete once this
    /// reaches the announced chunk count.
    got: usize,
}
27
28 struct RecvChan {
29     packets: Vec<Option<Vec<u8>>>, // char ** ðŸ˜›
30     splits: HashMap<u16, Split>,
31     seqnum: u16,
32     num: u8,
33 }
34
35 pub(crate) struct RecvWorker<R: UdpReceiver, S: UdpSender> {
36     share: Arc<RudpShare<S>>,
37     close: watch::Receiver<bool>,
38     chans: Arc<Vec<Mutex<RecvChan>>>,
39     pkt_tx: mpsc::UnboundedSender<InPkt>,
40     udp_rx: R,
41 }
42
43 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
44     pub fn new(
45         udp_rx: R,
46         share: Arc<RudpShare<S>>,
47         close: watch::Receiver<bool>,
48         pkt_tx: mpsc::UnboundedSender<InPkt>,
49     ) -> Self {
50         Self {
51             udp_rx,
52             share,
53             close,
54             pkt_tx,
55             chans: Arc::new(
56                 (0..NUM_CHANS as u8)
57                     .map(|num| {
58                         Mutex::new(RecvChan {
59                             num,
60                             packets: (0..REL_BUFFER).map(|_| None).collect(),
61                             seqnum: INIT_SEQNUM,
62                             splits: HashMap::new(),
63                         })
64                     })
65                     .collect(),
66             ),
67         }
68     }
69
70     pub async fn run(&self) {
71         use Error::*;
72
73         let cleanup_chans = Arc::clone(&self.chans);
74         let mut cleanup_close = self.close.clone();
75         self.share
76             .tasks
77             .lock()
78             .await
79             /*.build_task()
80             .name("cleanup_splits")*/
81             .spawn(async move {
82                 let timeout = Duration::from_secs(TIMEOUT);
83
84                 ticker!(timeout, cleanup_close, {
85                     for chan_mtx in cleanup_chans.iter() {
86                         let mut chan = chan_mtx.lock().await;
87                         chan.splits = chan
88                             .splits
89                             .drain_filter(
90                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
91                             )
92                             .collect();
93                     }
94                 });
95             });
96
97         let mut close = self.close.clone();
98         let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT));
99         tokio::pin!(timeout);
100
101         loop {
102             if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) {
103                 // TODO: figure out whether this is a good idea
104                 if let RemoteDisco(to) = e {
105                     self.pkt_tx.send(Err(RemoteDisco(to))).ok();
106                 }
107
108                 #[allow(clippy::single_match)]
109                 match e {
110                                         // anon5's mt notifies the peer on timeout, C++ MT does not
111                                         LocalDisco /*| RemoteDisco(true)*/ => drop(
112                                                 self.share
113                                                         .send(
114                                                                 PktType::Ctl,
115                                                                 Pkt {
116                                                                         unrel: true,
117                                                                         chan: 0,
118                                                                         data: Cow::Borrowed(&[CtlType::Disco as u8]),
119                                                                 },
120                                                         )
121                                                         .await
122                                                         .ok(),
123                                         ),
124                                         _ => {}
125                                 }
126
127                 break;
128             }
129         }
130     }
131
132     async fn recv_pkt(
133         &self,
134         close: &mut watch::Receiver<bool>,
135         timeout: Pin<&mut tokio::time::Sleep>,
136     ) -> Result<()> {
137         use Error::*;
138
139         let mut cursor = io::Cursor::new(tokio::select! {
140             pkt = self.udp_rx.recv() => pkt?,
141             _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)),
142             _ = close.changed() => return Err(LocalDisco),
143         });
144
145         timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
146
147         let proto_id = cursor.read_u32::<BigEndian>()?;
148         if proto_id != PROTO_ID {
149             return Err(InvalidProtoId(proto_id));
150         }
151
152         let _peer_id = cursor.read_u16::<BigEndian>()?;
153
154         let n_chan = cursor.read_u8()?;
155         let mut chan = self
156             .chans
157             .get(n_chan as usize)
158             .ok_or(InvalidChannel(n_chan))?
159             .lock()
160             .await;
161
162         self.process_pkt(cursor, true, &mut chan).await
163     }
164
165     #[async_recursion]
166     async fn process_pkt(
167         &self,
168         mut cursor: io::Cursor<Vec<u8>>,
169         unrel: bool,
170         chan: &mut RecvChan,
171     ) -> Result<()> {
172         use Error::*;
173
174         match cursor.read_u8()?.try_into()? {
175             PktType::Ctl => match cursor.read_u8()?.try_into()? {
176                 CtlType::Ack => {
177                     // println!("Ack");
178
179                     let seqnum = cursor.read_u16::<BigEndian>()?;
180                     if let Some(ack) = self.share.chans[chan.num as usize]
181                         .lock()
182                         .await
183                         .acks
184                         .remove(&seqnum)
185                     {
186                         ack.tx.send(true).ok();
187                     }
188                 }
189                 CtlType::SetPeerID => {
190                     // println!("SetPeerID");
191
192                     let mut id = self.share.remote_id.write().await;
193
194                     if *id != PeerID::Nil as u16 {
195                         return Err(PeerIDAlreadySet);
196                     }
197
198                     *id = cursor.read_u16::<BigEndian>()?;
199                 }
200                 CtlType::Ping => {
201                     // println!("Ping");
202                 }
203                 CtlType::Disco => {
204                     // println!("Disco");
205                     return Err(RemoteDisco(false));
206                 }
207             },
208             PktType::Orig => {
209                 // println!("Orig");
210
211                 self.pkt_tx.send(Ok(Pkt {
212                     chan: chan.num,
213                     unrel,
214                     data: Cow::Owned(cursor.remaining_slice().into()),
215                 }))?;
216             }
217             PktType::Split => {
218                 // println!("Split");
219
220                 let seqnum = cursor.read_u16::<BigEndian>()?;
221                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
222                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
223
224                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
225                     got: 0,
226                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
227                     timestamp: None,
228                 });
229
230                 if split.chunks.len() != chunk_count {
231                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
232                 }
233
234                 if split
235                     .chunks
236                     .get(chunk_index)
237                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
238                     .set(cursor.remaining_slice().into())
239                     .is_ok()
240                 {
241                     split.got += 1;
242                 }
243
244                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
245
246                 if split.got == chunk_count {
247                     self.pkt_tx.send(Ok(Pkt {
248                         chan: chan.num,
249                         unrel,
250                         data: split
251                             .chunks
252                             .iter()
253                             .flat_map(|chunk| chunk.get().unwrap().iter())
254                             .copied()
255                             .collect(),
256                     }))?;
257
258                     chan.splits.remove(&seqnum);
259                 }
260             }
261             PktType::Rel => {
262                 // println!("Rel");
263
264                 let seqnum = cursor.read_u16::<BigEndian>()?;
265                 chan.packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
266
267                 let mut ack_data = Vec::with_capacity(3);
268                 ack_data.write_u8(CtlType::Ack as u8)?;
269                 ack_data.write_u16::<BigEndian>(seqnum)?;
270
271                 self.share
272                     .send(
273                         PktType::Ctl,
274                         Pkt {
275                             unrel: true,
276                             chan: chan.num,
277                             data: Cow::Borrowed(&ack_data),
278                         },
279                     )
280                     .await?;
281
282                 fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
283                     chan.packets[to_seqnum(chan.seqnum)].take()
284                 }
285
286                 while let Some(pkt) = next_pkt(chan) {
287                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
288                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
289                 }
290             }
291         }
292
293         Ok(())
294     }
295
296     fn handle(&self, res: Result<()>) -> Result<()> {
297         use Error::*;
298
299         match res {
300             Ok(v) => Ok(v),
301             Err(RemoteDisco(to)) => Err(RemoteDisco(to)),
302             Err(LocalDisco) => Err(LocalDisco),
303             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
304         }
305     }
306 }