]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv.rs
cleanup; readme
[mt_rudp.git] / src / recv.rs
1 use crate::{prelude::*, ticker, RecvChan, RecvWorker, RudpShare, Split};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     cell::{Cell, OnceCell},
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::sync::{mpsc, watch, Mutex};
13
14 fn to_seqnum(seqnum: u16) -> usize {
15     (seqnum as usize) & (REL_BUFFER - 1)
16 }
17
/// Shorthand for results on the receive path, using the crate's `Error` type
/// (in scope via the `prelude::*` glob import).
type Result<T> = std::result::Result<T, Error>;
19
20 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
21     pub fn new(
22         udp_rx: R,
23         share: Arc<RudpShare<S>>,
24         close: watch::Receiver<bool>,
25         pkt_tx: mpsc::UnboundedSender<InPkt>,
26     ) -> Self {
27         Self {
28             udp_rx,
29             share,
30             close,
31             pkt_tx,
32             chans: Arc::new(
33                 (0..NUM_CHANS as u8)
34                     .map(|num| {
35                         Mutex::new(RecvChan {
36                             num,
37                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
38                             seqnum: INIT_SEQNUM,
39                             splits: HashMap::new(),
40                         })
41                     })
42                     .collect(),
43             ),
44         }
45     }
46
47     pub async fn run(&self) {
48         use Error::*;
49
50         let cleanup_chans = Arc::clone(&self.chans);
51         let mut cleanup_close = self.close.clone();
52         self.share
53             .tasks
54             .lock()
55             .await
56             /*.build_task()
57             .name("cleanup_splits")*/
58             .spawn(async move {
59                 let timeout = Duration::from_secs(TIMEOUT);
60
61                 ticker!(timeout, cleanup_close, {
62                     for chan_mtx in cleanup_chans.iter() {
63                         let mut chan = chan_mtx.lock().await;
64                         chan.splits = chan
65                             .splits
66                             .drain_filter(
67                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
68                             )
69                             .collect();
70                     }
71                 });
72             });
73
74         let mut close = self.close.clone();
75         let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT));
76         tokio::pin!(timeout);
77
78         loop {
79             if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) {
80                 // TODO: figure out whether this is a good idea
81                 if let RemoteDisco(to) = e {
82                     self.pkt_tx.send(Err(RemoteDisco(to))).ok();
83                 }
84
85                 match e {
86                                         // anon5's mt notifies the peer on timeout, C++ MT does not
87                                         LocalDisco /*| RemoteDisco(true)*/ => drop(
88                                                 self.share
89                                                         .send(
90                                                                 PktType::Ctl,
91                                                                 Pkt {
92                                                                         unrel: true,
93                                                                         chan: 0,
94                                                                         data: &[CtlType::Disco as u8],
95                                                                 },
96                                                         )
97                                                         .await
98                                                         .ok(),
99                                         ),
100                                         _ => {}
101                                 }
102
103                 break;
104             }
105         }
106     }
107
108     async fn recv_pkt(
109         &self,
110         close: &mut watch::Receiver<bool>,
111         timeout: Pin<&mut tokio::time::Sleep>,
112     ) -> Result<()> {
113         use Error::*;
114
115         let mut cursor = io::Cursor::new(tokio::select! {
116             pkt = self.udp_rx.recv() => pkt?,
117             _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)),
118             _ = close.changed() => return Err(LocalDisco),
119         });
120
121         timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
122
123         let proto_id = cursor.read_u32::<BigEndian>()?;
124         if proto_id != PROTO_ID {
125             return Err(InvalidProtoId(proto_id));
126         }
127
128         let _peer_id = cursor.read_u16::<BigEndian>()?;
129
130         let n_chan = cursor.read_u8()?;
131         let mut chan = self
132             .chans
133             .get(n_chan as usize)
134             .ok_or(InvalidChannel(n_chan))?
135             .lock()
136             .await;
137
138         self.process_pkt(cursor, true, &mut chan).await
139     }
140
141     #[async_recursion]
142     async fn process_pkt(
143         &self,
144         mut cursor: io::Cursor<Vec<u8>>,
145         unrel: bool,
146         chan: &mut RecvChan,
147     ) -> Result<()> {
148         use Error::*;
149
150         match cursor.read_u8()?.try_into()? {
151             PktType::Ctl => match cursor.read_u8()?.try_into()? {
152                 CtlType::Ack => {
153                     println!("Ack");
154
155                     let seqnum = cursor.read_u16::<BigEndian>()?;
156                     if let Some(ack) = self.share.chans[chan.num as usize]
157                         .lock()
158                         .await
159                         .acks
160                         .remove(&seqnum)
161                     {
162                         ack.tx.send(true).ok();
163                     }
164                 }
165                 CtlType::SetPeerID => {
166                     println!("SetPeerID");
167
168                     let mut id = self.share.remote_id.write().await;
169
170                     if *id != PeerID::Nil as u16 {
171                         return Err(PeerIDAlreadySet);
172                     }
173
174                     *id = cursor.read_u16::<BigEndian>()?;
175                 }
176                 CtlType::Ping => {
177                     println!("Ping");
178                 }
179                 CtlType::Disco => {
180                     println!("Disco");
181                     return Err(RemoteDisco(false));
182                 }
183             },
184             PktType::Orig => {
185                 println!("Orig");
186
187                 self.pkt_tx.send(Ok(Pkt {
188                     chan: chan.num,
189                     unrel,
190                     data: cursor.remaining_slice().into(),
191                 }))?;
192             }
193             PktType::Split => {
194                 println!("Split");
195
196                 let seqnum = cursor.read_u16::<BigEndian>()?;
197                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
198                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
199
200                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
201                     got: 0,
202                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
203                     timestamp: None,
204                 });
205
206                 if split.chunks.len() != chunk_count {
207                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
208                 }
209
210                 if split
211                     .chunks
212                     .get(chunk_index)
213                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
214                     .set(cursor.remaining_slice().into())
215                     .is_ok()
216                 {
217                     split.got += 1;
218                 }
219
220                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
221
222                 if split.got == chunk_count {
223                     self.pkt_tx.send(Ok(Pkt {
224                         chan: chan.num,
225                         unrel,
226                         data: split
227                             .chunks
228                             .iter()
229                             .flat_map(|chunk| chunk.get().unwrap().iter())
230                             .copied()
231                             .collect(),
232                     }))?;
233
234                     chan.splits.remove(&seqnum);
235                 }
236             }
237             PktType::Rel => {
238                 println!("Rel");
239
240                 let seqnum = cursor.read_u16::<BigEndian>()?;
241                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
242
243                 let mut ack_data = Vec::with_capacity(3);
244                 ack_data.write_u8(CtlType::Ack as u8)?;
245                 ack_data.write_u16::<BigEndian>(seqnum)?;
246
247                 self.share
248                     .send(
249                         PktType::Ctl,
250                         Pkt {
251                             unrel: true,
252                             chan: chan.num,
253                             data: &ack_data,
254                         },
255                     )
256                     .await?;
257
258                 fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
259                     chan.packets[to_seqnum(chan.seqnum)].take()
260                 }
261
262                 while let Some(pkt) = next_pkt(chan) {
263                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
264                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
265                 }
266             }
267         }
268
269         Ok(())
270     }
271
272     fn handle(&self, res: Result<()>) -> Result<()> {
273         use Error::*;
274
275         match res {
276             Ok(v) => Ok(v),
277             Err(RemoteDisco(to)) => Err(RemoteDisco(to)),
278             Err(LocalDisco) => Err(LocalDisco),
279             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
280         }
281     }
282 }