git.lizzy.rs Git - mt_rudp.git/blob - src/recv_worker.rs
send acks
[mt_rudp.git] / src / recv_worker.rs
1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
4 use std::{
5     cell::{Cell, OnceCell},
6     collections::HashMap,
7     io,
8     sync::{Arc, Weak},
9     time,
10 };
11 use tokio::sync::{mpsc, Mutex};
12
13 fn to_seqnum(seqnum: u16) -> usize {
14     (seqnum as usize) & (REL_BUFFER - 1)
15 }
16
17 type Result<T> = std::result::Result<T, Error>;
18
/// Reassembly state for one split packet, stored per channel and keyed by the
/// split's sequence number.
struct Split {
    /// How many distinct chunks have been stored in `chunks` so far.
    got: usize,
    /// One slot per expected chunk; each slot can be filled at most once.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// Arrival time of the most recent chunk if it came unreliably, or `None`
    /// if it came through a reliable packet. Consulted by the periodic cleanup
    /// to expire stale unreliable splits.
    timestamp: Option<time::Instant>,
}
24
25 struct Chan {
26     packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
27     splits: HashMap<u16, Split>,
28     seqnum: u16,
29     num: u8,
30 }
31
32 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
33     share: Arc<RudpShare<S>>,
34     chans: Arc<Vec<Mutex<Chan>>>,
35     pkt_tx: mpsc::UnboundedSender<InPkt>,
36     udp_rx: R,
37 }
38
39 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
40     pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
41         Self {
42             udp_rx,
43             share,
44             pkt_tx,
45             chans: Arc::new(
46                 (0..NUM_CHANS as u8)
47                     .map(|num| {
48                         Mutex::new(Chan {
49                             num,
50                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
51                             seqnum: INIT_SEQNUM,
52                             splits: HashMap::new(),
53                         })
54                     })
55                     .collect(),
56             ),
57         }
58     }
59
60     pub async fn run(&self) {
61         let cleanup_chans = Arc::downgrade(&self.chans);
62         tokio::spawn(async move {
63             let timeout = time::Duration::from_secs(TIMEOUT);
64             let mut interval = tokio::time::interval(timeout);
65
66             while let Some(chans) = Weak::upgrade(&cleanup_chans) {
67                 for chan in chans.iter() {
68                     let mut ch = chan.lock().await;
69                     ch.splits = ch
70                         .splits
71                         .drain_filter(
72                             |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
73                         )
74                         .collect();
75                 }
76
77                 interval.tick().await;
78             }
79         });
80
81         loop {
82             if let Err(e) = self.handle(self.recv_pkt().await) {
83                 if let Error::LocalDisco = e {
84                     self.share
85                         .send(
86                             PktType::Ctl,
87                             Pkt {
88                                 unrel: true,
89                                 chan: 0,
90                                 data: &[CtlType::Disco as u8],
91                             },
92                         )
93                         .await
94                         .ok();
95                 }
96                 break;
97             }
98         }
99     }
100
101     async fn recv_pkt(&self) -> Result<()> {
102         use Error::*;
103
104         // todo: reset timeout
105         let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
106
107         let proto_id = cursor.read_u32::<BigEndian>()?;
108         if proto_id != PROTO_ID {
109             return Err(InvalidProtoId(proto_id));
110         }
111
112         let _peer_id = cursor.read_u16::<BigEndian>()?;
113
114         let n_chan = cursor.read_u8()?;
115         let mut chan = self
116             .chans
117             .get(n_chan as usize)
118             .ok_or(InvalidChannel(n_chan))?
119             .lock()
120             .await;
121
122         self.process_pkt(cursor, true, &mut chan).await
123     }
124
125     #[async_recursion]
126     async fn process_pkt(
127         &self,
128         mut cursor: io::Cursor<Vec<u8>>,
129         unrel: bool,
130         chan: &mut Chan,
131     ) -> Result<()> {
132         use Error::*;
133
134         match cursor.read_u8()?.try_into()? {
135             PktType::Ctl => match cursor.read_u8()?.try_into()? {
136                 CtlType::Ack => {
137                     let seqnum = cursor.read_u16::<BigEndian>()?;
138                     if let Some((tx, _)) = self.share.ack_chans.lock().await.remove(&seqnum) {
139                         tx.send(true).ok();
140                     }
141                 }
142                 CtlType::SetPeerID => {
143                     let mut id = self.share.remote_id.write().await;
144
145                     if *id != PeerID::Nil as u16 {
146                         return Err(PeerIDAlreadySet);
147                     }
148
149                     *id = cursor.read_u16::<BigEndian>()?;
150                 }
151                 CtlType::Ping => {}
152                 CtlType::Disco => return Err(RemoteDisco),
153             },
154             PktType::Orig => {
155                 println!("Orig");
156
157                 self.pkt_tx.send(Ok(Pkt {
158                     chan: chan.num,
159                     unrel,
160                     data: cursor.remaining_slice().into(),
161                 }))?;
162             }
163             PktType::Split => {
164                 println!("Split");
165
166                 let seqnum = cursor.read_u16::<BigEndian>()?;
167                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
168                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
169
170                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
171                     got: 0,
172                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
173                     timestamp: None,
174                 });
175
176                 if split.chunks.len() != chunk_count {
177                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
178                 }
179
180                 if split
181                     .chunks
182                     .get(chunk_index)
183                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
184                     .set(cursor.remaining_slice().into())
185                     .is_ok()
186                 {
187                     split.got += 1;
188                 }
189
190                 split.timestamp = if unrel {
191                     Some(time::Instant::now())
192                 } else {
193                     None
194                 };
195
196                 if split.got == chunk_count {
197                     self.pkt_tx.send(Ok(Pkt {
198                         chan: chan.num,
199                         unrel,
200                         data: split
201                             .chunks
202                             .iter()
203                             .flat_map(|chunk| chunk.get().unwrap().iter())
204                             .copied()
205                             .collect(),
206                     }))?;
207
208                     chan.splits.remove(&seqnum);
209                 }
210             }
211             PktType::Rel => {
212                 println!("Rel");
213
214                 let seqnum = cursor.read_u16::<BigEndian>()?;
215                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
216
217                 let mut ack_data = Vec::with_capacity(3);
218                 ack_data.write_u8(CtlType::Ack as u8)?;
219                 ack_data.write_u16::<BigEndian>(seqnum)?;
220
221                 self.share
222                     .send(
223                         PktType::Ctl,
224                         Pkt {
225                             unrel: true,
226                             chan: chan.num,
227                             data: &ack_data,
228                         },
229                     )
230                     .await?;
231
232                 fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
233                     chan.packets[to_seqnum(chan.seqnum)].take()
234                 }
235
236                 while let Some(pkt) = next_pkt(chan) {
237                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
238                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
239                 }
240             }
241         }
242
243         Ok(())
244     }
245
246     fn handle(&self, res: Result<()>) -> Result<()> {
247         use Error::*;
248
249         match res {
250             Ok(v) => Ok(v),
251             Err(RemoteDisco) => Err(RemoteDisco),
252             Err(LocalDisco) => Err(LocalDisco),
253             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
254         }
255     }
256 }