]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv.rs
Don't spawn tasks
[mt_rudp.git] / src / recv.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     borrow::Cow,
6     collections::{HashMap, VecDeque},
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::{
13     sync::watch,
14     time::{interval, sleep, Interval, Sleep},
15 };
16
17 fn to_seqnum(seqnum: u16) -> usize {
18     (seqnum as usize) & (REL_BUFFER - 1)
19 }
20
21 type Result<T> = std::result::Result<T, Error>;
22
/// Reassembly state for one split (multi-chunk) packet, stored in
/// `RecvChan::splits` under the split's sequence number.
#[derive(Debug)]
struct Split {
    /// Arrival time of the most recent chunk if it came in unreliably;
    /// `None` if it came through the reliable layer. Inspected by the
    /// periodic cleanup in `RudpReceiver::recv`.
    timestamp: Option<Instant>,
    /// One slot per chunk index; a slot becomes `Some` once that chunk is received.
    chunks: Vec<Option<Vec<u8>>>,
    /// Number of distinct chunk slots filled so far; the split is complete
    /// when this equals `chunks.len()`.
    got: usize,
}
29
/// Receive-side state of a single channel.
struct RecvChan {
    /// Ring buffer of reliable packet bodies waiting to be delivered in order,
    /// indexed by `to_seqnum(seqnum)`; `None` marks an empty slot.
    packets: Vec<Option<Vec<u8>>>,
    /// Split packets currently being reassembled, keyed by split seqnum.
    splits: HashMap<u16, Split>,
    /// Sequence number of the next reliable packet to deliver.
    seqnum: u16,
}
35
/// Receiving half of a RUDP connection.
///
/// Besides decoding incoming datagrams, `recv` also services the connection's
/// timers (ack resends, pings, split cleanup, inactivity timeout); those only
/// run while `recv` is being awaited.
pub struct RudpReceiver<P: UdpPeer> {
    /// Connection state shared with the sending side (used here for sending,
    /// pending acks and the remote peer id).
    pub(crate) share: Arc<RudpShare<P>>,
    /// Per-channel receive state.
    chans: [RecvChan; NUM_CHANS],
    /// Source of raw incoming datagrams.
    udp: P::Receiver,
    /// Local shutdown signal; a change notification closes the receiver.
    close: watch::Receiver<bool>,
    /// Set once the connection is closed; `recv` then returns `None`.
    closed: bool,
    /// Timer for resending packets that are still waiting for an ack.
    resend: Interval,
    /// Timer for sending keep-alive pings.
    ping: Interval,
    /// Timer for purging stale unreliable split packets.
    cleanup: Interval,
    /// Inactivity deadline, pushed back whenever a datagram is received.
    timeout: Pin<Box<Sleep>>,
    /// Results ready to be handed out by `recv`, in arrival order.
    queue: VecDeque<Result<Pkt<'static>>>,
}
48
49 impl<P: UdpPeer> RudpReceiver<P> {
50     pub(crate) fn new(
51         udp: P::Receiver,
52         share: Arc<RudpShare<P>>,
53         close: watch::Receiver<bool>,
54     ) -> Self {
55         Self {
56             udp,
57             share,
58             close,
59             closed: false,
60             resend: interval(Duration::from_millis(500)),
61             ping: interval(Duration::from_secs(PING_TIMEOUT)),
62             cleanup: interval(Duration::from_secs(TIMEOUT)),
63             timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))),
64             chans: std::array::from_fn(|_| RecvChan {
65                 packets: (0..REL_BUFFER).map(|_| None).collect(),
66                 seqnum: INIT_SEQNUM,
67                 splits: HashMap::new(),
68             }),
69             queue: VecDeque::new(),
70         }
71     }
72
73     fn handle_err(&mut self, res: Result<()>) -> Result<()> {
74         use Error::*;
75
76         match res {
77             Err(RemoteDisco(_)) | Err(LocalDisco) => {
78                 self.closed = true;
79                 res
80             }
81             Ok(_) => res,
82             Err(e) => {
83                 self.queue.push_back(Err(e));
84                 Ok(())
85             }
86         }
87     }
88
89     pub async fn recv(&mut self) -> Option<Result<Pkt<'static>>> {
90         use Error::*;
91
92         if self.closed {
93             return None;
94         }
95
96         loop {
97             if let Some(x) = self.queue.pop_front() {
98                 return Some(x);
99             }
100
101             tokio::select! {
102                 _ = self.close.changed() => {
103                     self.closed = true;
104                     return Some(Err(LocalDisco));
105                 },
106                 _ = self.cleanup.tick() => {
107                     let timeout = Duration::from_secs(TIMEOUT);
108
109                     for chan in self.chans.iter_mut() {
110                         chan.splits = chan
111                             .splits
112                             .drain_filter(
113                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
114                             )
115                             .collect();
116                     }
117                 },
118                 _ = self.resend.tick() => {
119                     for chan in self.share.chans.iter() {
120                         for (_, ack) in chan.lock().await.acks.iter() {
121                             self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?)
122                         }
123                     }
124                 },
125                 _ = self.ping.tick() => {
126                     self.share
127                         .send(
128                             PktType::Ctl,
129                             Pkt {
130                                 chan: 0,
131                                 unrel: false,
132                                 data: Cow::Borrowed(&[CtlType::Ping as u8]),
133                             },
134                         )
135                         .await
136                         .ok();
137                 }
138                 _ = &mut self.timeout => {
139                     self.closed = true;
140                     return Some(Err(RemoteDisco(true)));
141                 },
142                 pkt = self.udp.recv() => {
143                     if let Err(e) = self.handle_pkt(pkt).await {
144                         return Some(Err(e));
145                     }
146                 }
147             }
148         }
149     }
150
151     async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> {
152         use Error::*;
153
154         let mut cursor = io::Cursor::new(pkt?);
155
156         self.timeout
157             .as_mut()
158             .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
159
160         let proto_id = cursor.read_u32::<BigEndian>()?;
161         if proto_id != PROTO_ID {
162             return Err(InvalidProtoId(proto_id));
163         }
164
165         let _peer_id = cursor.read_u16::<BigEndian>()?;
166
167         let chan = cursor.read_u8()?;
168         if chan >= NUM_CHANS as u8 {
169             return Err(InvalidChannel(chan));
170         }
171
172         let res = self.process_pkt(cursor, true, chan).await;
173         self.handle_err(res)?;
174
175         Ok(())
176     }
177
178     #[async_recursion]
179     async fn process_pkt(
180         &mut self,
181         mut cursor: io::Cursor<Vec<u8>>,
182         unrel: bool,
183         chan: u8,
184     ) -> Result<()> {
185         use Error::*;
186
187         let ch = chan as usize;
188         match cursor.read_u8()?.try_into()? {
189             PktType::Ctl => match cursor.read_u8()?.try_into()? {
190                 CtlType::Ack => {
191                     // println!("Ack");
192
193                     let seqnum = cursor.read_u16::<BigEndian>()?;
194                     if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) {
195                         ack.tx.send(true).ok();
196                     }
197                 }
198                 CtlType::SetPeerID => {
199                     // println!("SetPeerID");
200
201                     let mut id = self.share.remote_id.write().await;
202
203                     if *id != PeerID::Nil as u16 {
204                         return Err(PeerIDAlreadySet);
205                     }
206
207                     *id = cursor.read_u16::<BigEndian>()?;
208                 }
209                 CtlType::Ping => {
210                     // println!("Ping");
211                 }
212                 CtlType::Disco => {
213                     // println!("Disco");
214                     return Err(RemoteDisco(false));
215                 }
216             },
217             PktType::Orig => {
218                 // println!("Orig");
219
220                 self.queue.push_back(Ok(Pkt {
221                     chan,
222                     unrel,
223                     data: Cow::Owned(cursor.remaining_slice().into()),
224                 }));
225             }
226             PktType::Split => {
227                 // println!("Split");
228
229                 let seqnum = cursor.read_u16::<BigEndian>()?;
230                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
231                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
232
233                 let mut split = self.chans[ch]
234                     .splits
235                     .entry(seqnum)
236                     .or_insert_with(|| Split {
237                         got: 0,
238                         chunks: (0..chunk_count).map(|_| None).collect(),
239                         timestamp: None,
240                     });
241
242                 if split.chunks.len() != chunk_count {
243                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
244                 }
245
246                 if split
247                     .chunks
248                     .get_mut(chunk_index)
249                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
250                     .replace(cursor.remaining_slice().into())
251                     .is_none()
252                 {
253                     split.got += 1;
254                 }
255
256                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
257
258                 if split.got == chunk_count {
259                     let split = self.chans[ch].splits.remove(&seqnum).unwrap();
260
261                     self.queue.push_back(Ok(Pkt {
262                         chan,
263                         unrel,
264                         data: split
265                             .chunks
266                             .into_iter()
267                             .map(|x| x.unwrap())
268                             .reduce(|mut a, mut b| {
269                                 a.append(&mut b);
270                                 a
271                             })
272                             .unwrap_or_default()
273                             .into(),
274                     }));
275                 }
276             }
277             PktType::Rel => {
278                 // println!("Rel");
279
280                 let seqnum = cursor.read_u16::<BigEndian>()?;
281                 self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
282
283                 let mut ack_data = Vec::with_capacity(3);
284                 ack_data.write_u8(CtlType::Ack as u8)?;
285                 ack_data.write_u16::<BigEndian>(seqnum)?;
286
287                 self.share
288                     .send(
289                         PktType::Ctl,
290                         Pkt {
291                             chan,
292                             unrel: true,
293                             data: ack_data.into(),
294                         },
295                     )
296                     .await?;
297
298                 let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take();
299                 while let Some(pkt) = next_pkt(&mut self.chans[ch]) {
300                     let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await;
301                     self.handle_err(res)?;
302                     self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0;
303                 }
304             }
305         }
306
307         Ok(())
308     }
309 }