]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv.rs
timeouts
[mt_rudp.git] / src / recv.rs
1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
4 use std::{
5     cell::{Cell, OnceCell},
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::sync::{mpsc, Mutex};
13
14 fn to_seqnum(seqnum: u16) -> usize {
15     (seqnum as usize) & (REL_BUFFER - 1)
16 }
17
18 type Result<T> = std::result::Result<T, Error>;
19
/// Reassembly state of one split packet whose chunks arrive separately.
struct Split {
    /// Number of distinct chunks stored so far.
    got: usize,
    /// One write-once slot per expected chunk, indexed by chunk index.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// When the most recent chunk arrived, if it came in an unreliable
    /// packet; `None` otherwise.
    timestamp: Option<Instant>,
}
25
26 struct RecvChan {
27     packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
28     splits: HashMap<u16, Split>,
29     seqnum: u16,
30     num: u8,
31 }
32
33 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
34     share: Arc<RudpShare<S>>,
35     close: watch::Receiver<bool>,
36     chans: Arc<Vec<Mutex<RecvChan>>>,
37     pkt_tx: mpsc::UnboundedSender<InPkt>,
38     udp_rx: R,
39 }
40
41 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
42     pub fn new(
43         udp_rx: R,
44         share: Arc<RudpShare<S>>,
45         close: watch::Receiver<bool>,
46         pkt_tx: mpsc::UnboundedSender<InPkt>,
47     ) -> Self {
48         Self {
49             udp_rx,
50             share,
51             close,
52             pkt_tx,
53             chans: Arc::new(
54                 (0..NUM_CHANS as u8)
55                     .map(|num| {
56                         Mutex::new(RecvChan {
57                             num,
58                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
59                             seqnum: INIT_SEQNUM,
60                             splits: HashMap::new(),
61                         })
62                     })
63                     .collect(),
64             ),
65         }
66     }
67
68     pub async fn run(&self) {
69         use Error::*;
70
71         let cleanup_chans = Arc::clone(&self.chans);
72         let mut cleanup_close = self.close.clone();
73         self.share
74             .tasks
75             .lock()
76             .await
77             /*.build_task()
78             .name("cleanup_splits")*/
79             .spawn(async move {
80                 let timeout = Duration::from_secs(TIMEOUT);
81
82                 ticker!(timeout, cleanup_close, {
83                     for chan_mtx in cleanup_chans.iter() {
84                         let mut chan = chan_mtx.lock().await;
85                         chan.splits = chan
86                             .splits
87                             .drain_filter(
88                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
89                             )
90                             .collect();
91                     }
92                 });
93             });
94
95         let mut close = self.close.clone();
96         let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT));
97         tokio::pin!(timeout);
98
99         loop {
100             if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) {
101                 // TODO: figure out whether this is a good idea
102                 if let RemoteDisco(to) = e {
103                     self.pkt_tx.send(Err(RemoteDisco(to))).ok();
104                 }
105
106                 match e {
107                                         // anon5's mt notifies the peer on timeout, C++ MT does not
108                                         LocalDisco /*| RemoteDisco(true)*/ => drop(
109                                                 self.share
110                                                         .send(
111                                                                 PktType::Ctl,
112                                                                 Pkt {
113                                                                         unrel: true,
114                                                                         chan: 0,
115                                                                         data: &[CtlType::Disco as u8],
116                                                                 },
117                                                         )
118                                                         .await
119                                                         .ok(),
120                                         ),
121                                         _ => {}
122                                 }
123
124                 break;
125             }
126         }
127     }
128
129     async fn recv_pkt(
130         &self,
131         close: &mut watch::Receiver<bool>,
132         timeout: Pin<&mut tokio::time::Sleep>,
133     ) -> Result<()> {
134         use Error::*;
135
136         // TODO: reset timeout
137         let mut cursor = io::Cursor::new(tokio::select! {
138             pkt = self.udp_rx.recv() => pkt?,
139             _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)),
140             _ = close.changed() => return Err(LocalDisco),
141         });
142
143         timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
144
145         let proto_id = cursor.read_u32::<BigEndian>()?;
146         if proto_id != PROTO_ID {
147             return Err(InvalidProtoId(proto_id));
148         }
149
150         let _peer_id = cursor.read_u16::<BigEndian>()?;
151
152         let n_chan = cursor.read_u8()?;
153         let mut chan = self
154             .chans
155             .get(n_chan as usize)
156             .ok_or(InvalidChannel(n_chan))?
157             .lock()
158             .await;
159
160         self.process_pkt(cursor, true, &mut chan).await
161     }
162
163     #[async_recursion]
164     async fn process_pkt(
165         &self,
166         mut cursor: io::Cursor<Vec<u8>>,
167         unrel: bool,
168         chan: &mut RecvChan,
169     ) -> Result<()> {
170         use Error::*;
171
172         match cursor.read_u8()?.try_into()? {
173             PktType::Ctl => match cursor.read_u8()?.try_into()? {
174                 CtlType::Ack => {
175                     println!("Ack");
176
177                     let seqnum = cursor.read_u16::<BigEndian>()?;
178                     if let Some(ack) = self.share.chans[chan.num as usize]
179                         .lock()
180                         .await
181                         .acks
182                         .remove(&seqnum)
183                     {
184                         ack.tx.send(true).ok();
185                     }
186                 }
187                 CtlType::SetPeerID => {
188                     println!("SetPeerID");
189
190                     let mut id = self.share.remote_id.write().await;
191
192                     if *id != PeerID::Nil as u16 {
193                         return Err(PeerIDAlreadySet);
194                     }
195
196                     *id = cursor.read_u16::<BigEndian>()?;
197                 }
198                 CtlType::Ping => {
199                     println!("Ping");
200                 }
201                 CtlType::Disco => {
202                     println!("Disco");
203                     return Err(RemoteDisco(false));
204                 }
205             },
206             PktType::Orig => {
207                 println!("Orig");
208
209                 self.pkt_tx.send(Ok(Pkt {
210                     chan: chan.num,
211                     unrel,
212                     data: cursor.remaining_slice().into(),
213                 }))?;
214             }
215             PktType::Split => {
216                 println!("Split");
217
218                 let seqnum = cursor.read_u16::<BigEndian>()?;
219                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
220                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
221
222                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
223                     got: 0,
224                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
225                     timestamp: None,
226                 });
227
228                 if split.chunks.len() != chunk_count {
229                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
230                 }
231
232                 if split
233                     .chunks
234                     .get(chunk_index)
235                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
236                     .set(cursor.remaining_slice().into())
237                     .is_ok()
238                 {
239                     split.got += 1;
240                 }
241
242                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
243
244                 if split.got == chunk_count {
245                     self.pkt_tx.send(Ok(Pkt {
246                         chan: chan.num,
247                         unrel,
248                         data: split
249                             .chunks
250                             .iter()
251                             .flat_map(|chunk| chunk.get().unwrap().iter())
252                             .copied()
253                             .collect(),
254                     }))?;
255
256                     chan.splits.remove(&seqnum);
257                 }
258             }
259             PktType::Rel => {
260                 println!("Rel");
261
262                 let seqnum = cursor.read_u16::<BigEndian>()?;
263                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
264
265                 let mut ack_data = Vec::with_capacity(3);
266                 ack_data.write_u8(CtlType::Ack as u8)?;
267                 ack_data.write_u16::<BigEndian>(seqnum)?;
268
269                 self.share
270                     .send(
271                         PktType::Ctl,
272                         Pkt {
273                             unrel: true,
274                             chan: chan.num,
275                             data: &ack_data,
276                         },
277                     )
278                     .await?;
279
280                 fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
281                     chan.packets[to_seqnum(chan.seqnum)].take()
282                 }
283
284                 while let Some(pkt) = next_pkt(chan) {
285                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
286                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
287                 }
288             }
289         }
290
291         Ok(())
292     }
293
294     fn handle(&self, res: Result<()>) -> Result<()> {
295         use Error::*;
296
297         match res {
298             Ok(v) => Ok(v),
299             Err(RemoteDisco(to)) => Err(RemoteDisco(to)),
300             Err(LocalDisco) => Err(LocalDisco),
301             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
302         }
303     }
304 }