git.lizzy.rs Git - mt_rudp.git/blob - src/recv.rs
clean shutdown; send reliables
[mt_rudp.git] / src / recv.rs
1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
4 use std::{
5     cell::{Cell, OnceCell},
6     collections::HashMap,
7     io,
8     sync::Arc,
9     time::{Duration, Instant},
10 };
11 use tokio::sync::{mpsc, Mutex};
12
13 fn to_seqnum(seqnum: u16) -> usize {
14     (seqnum as usize) & (REL_BUFFER - 1)
15 }
16
17 type Result<T> = std::result::Result<T, Error>;
18
/// A split (multi-chunk) packet that is still being reassembled.
struct Split {
    /// One slot per chunk. Each slot is filled at most once, so duplicate chunks are ignored.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// Number of distinct chunks received so far.
    got: usize,
    /// Arrival time of the latest chunk if it came unreliably; `None` if it came reliably.
    timestamp: Option<Instant>,
}
24
25 struct RecvChan {
26     packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
27     splits: HashMap<u16, Split>,
28     seqnum: u16,
29     num: u8,
30 }
31
32 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
33     share: Arc<RudpShare<S>>,
34     close: watch::Receiver<bool>,
35     chans: Arc<Vec<Mutex<RecvChan>>>,
36     pkt_tx: mpsc::UnboundedSender<InPkt>,
37     udp_rx: R,
38 }
39
40 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
41     pub fn new(
42         udp_rx: R,
43         share: Arc<RudpShare<S>>,
44         close: watch::Receiver<bool>,
45         pkt_tx: mpsc::UnboundedSender<InPkt>,
46     ) -> Self {
47         Self {
48             udp_rx,
49             share,
50             close,
51             pkt_tx,
52             chans: Arc::new(
53                 (0..NUM_CHANS as u8)
54                     .map(|num| {
55                         Mutex::new(RecvChan {
56                             num,
57                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
58                             seqnum: INIT_SEQNUM,
59                             splits: HashMap::new(),
60                         })
61                     })
62                     .collect(),
63             ),
64         }
65     }
66
67     pub async fn run(&self) {
68         let cleanup_chans = Arc::clone(&self.chans);
69         let mut cleanup_close = self.close.clone();
70         self.share
71             .tasks
72             .lock()
73             .await
74             /*.build_task()
75             .name("cleanup_splits")*/
76             .spawn(async move {
77                 let timeout = Duration::from_secs(TIMEOUT);
78
79                 ticker!(timeout, cleanup_close, {
80                     for chan_mtx in cleanup_chans.iter() {
81                         let mut chan = chan_mtx.lock().await;
82                         chan.splits = chan
83                             .splits
84                             .drain_filter(
85                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
86                             )
87                             .collect();
88                     }
89                 });
90             });
91
92         let mut close = self.close.clone();
93         loop {
94             if let Err(e) = self.handle(self.recv_pkt(&mut close).await) {
95                 if let Error::LocalDisco = e {
96                     self.share
97                         .send(
98                             PktType::Ctl,
99                             Pkt {
100                                 unrel: true,
101                                 chan: 0,
102                                 data: &[CtlType::Disco as u8],
103                             },
104                         )
105                         .await
106                         .ok();
107                 }
108                 break;
109             }
110         }
111     }
112
113     async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> {
114         use Error::*;
115
116         // TODO: reset timeout
117         let mut cursor = io::Cursor::new(tokio::select! {
118             pkt = self.udp_rx.recv() => pkt?,
119             _ = close.changed() => return Err(LocalDisco),
120         });
121
122         println!("recv");
123
124         let proto_id = cursor.read_u32::<BigEndian>()?;
125         if proto_id != PROTO_ID {
126             return Err(InvalidProtoId(proto_id));
127         }
128
129         let _peer_id = cursor.read_u16::<BigEndian>()?;
130
131         let n_chan = cursor.read_u8()?;
132         let mut chan = self
133             .chans
134             .get(n_chan as usize)
135             .ok_or(InvalidChannel(n_chan))?
136             .lock()
137             .await;
138
139         self.process_pkt(cursor, true, &mut chan).await
140     }
141
142     #[async_recursion]
143     async fn process_pkt(
144         &self,
145         mut cursor: io::Cursor<Vec<u8>>,
146         unrel: bool,
147         chan: &mut RecvChan,
148     ) -> Result<()> {
149         use Error::*;
150
151         match cursor.read_u8()?.try_into()? {
152             PktType::Ctl => match cursor.read_u8()?.try_into()? {
153                 CtlType::Ack => {
154                     println!("Ack");
155
156                     let seqnum = cursor.read_u16::<BigEndian>()?;
157                     if let Some(ack) = self.share.chans[chan.num as usize]
158                         .lock()
159                         .await
160                         .acks
161                         .remove(&seqnum)
162                     {
163                         ack.tx.send(true).ok();
164                     }
165                 }
166                 CtlType::SetPeerID => {
167                     println!("SetPeerID");
168
169                     let mut id = self.share.remote_id.write().await;
170
171                     if *id != PeerID::Nil as u16 {
172                         return Err(PeerIDAlreadySet);
173                     }
174
175                     *id = cursor.read_u16::<BigEndian>()?;
176                 }
177                 CtlType::Ping => {
178                     println!("Ping");
179                 }
180                 CtlType::Disco => {
181                     println!("Disco");
182                     return Err(RemoteDisco);
183                 }
184             },
185             PktType::Orig => {
186                 println!("Orig");
187
188                 self.pkt_tx.send(Ok(Pkt {
189                     chan: chan.num,
190                     unrel,
191                     data: cursor.remaining_slice().into(),
192                 }))?;
193             }
194             PktType::Split => {
195                 println!("Split");
196
197                 let seqnum = cursor.read_u16::<BigEndian>()?;
198                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
199                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
200
201                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
202                     got: 0,
203                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
204                     timestamp: None,
205                 });
206
207                 if split.chunks.len() != chunk_count {
208                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
209                 }
210
211                 if split
212                     .chunks
213                     .get(chunk_index)
214                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
215                     .set(cursor.remaining_slice().into())
216                     .is_ok()
217                 {
218                     split.got += 1;
219                 }
220
221                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
222
223                 if split.got == chunk_count {
224                     self.pkt_tx.send(Ok(Pkt {
225                         chan: chan.num,
226                         unrel,
227                         data: split
228                             .chunks
229                             .iter()
230                             .flat_map(|chunk| chunk.get().unwrap().iter())
231                             .copied()
232                             .collect(),
233                     }))?;
234
235                     chan.splits.remove(&seqnum);
236                 }
237             }
238             PktType::Rel => {
239                 println!("Rel");
240
241                 let seqnum = cursor.read_u16::<BigEndian>()?;
242                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
243
244                 let mut ack_data = Vec::with_capacity(3);
245                 ack_data.write_u8(CtlType::Ack as u8)?;
246                 ack_data.write_u16::<BigEndian>(seqnum)?;
247
248                 self.share
249                     .send(
250                         PktType::Ctl,
251                         Pkt {
252                             unrel: true,
253                             chan: chan.num,
254                             data: &ack_data,
255                         },
256                     )
257                     .await?;
258
259                 fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
260                     chan.packets[to_seqnum(chan.seqnum)].take()
261                 }
262
263                 while let Some(pkt) = next_pkt(chan) {
264                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
265                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
266                 }
267             }
268         }
269
270         Ok(())
271     }
272
273     fn handle(&self, res: Result<()>) -> Result<()> {
274         use Error::*;
275
276         match res {
277             Ok(v) => Ok(v),
278             Err(RemoteDisco) => Err(RemoteDisco),
279             Err(LocalDisco) => Err(LocalDisco),
280             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
281         }
282     }
283 }