]> git.lizzy.rs Git - mt_rudp.git/blob - src/worker.rs
Implement sending splits
[mt_rudp.git] / src / worker.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     borrow::Cow,
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::{
13     sync::{mpsc, watch},
14     time::{interval, sleep, Interval, Sleep},
15 };
16
17 fn to_seqnum(seqnum: u16) -> usize {
18     (seqnum as usize) & (REL_BUFFER - 1)
19 }
20
/// File-local shorthand for a `std::result::Result` whose error type is `Error`
/// (the `Error` type itself is declared outside this file, presumably reached via `use super::*`).
type Result<T> = std::result::Result<T, Error>;
22
/// Reassembly state of one split packet whose chunks are still arriving.
#[derive(Debug)]
struct Split {
    /// Arrival time of the most recent chunk if that chunk was received as
    /// unreliable, `None` if it was received as reliable. The periodic cleanup
    /// in `Worker::run` inspects this field to decide which entries to keep.
    timestamp: Option<Instant>,
    /// One slot per chunk, indexed by chunk index; `None` until that chunk arrives.
    chunks: Vec<Option<Vec<u8>>>,
    /// Number of slots in `chunks` that have been filled so far.
    got: usize,
}
29
/// Receive-side state of a single channel.
#[derive(Debug)]
struct RecvChan {
    /// Buffer of received reliable packet payloads awaiting in-order processing.
    /// It holds `REL_BUFFER` slots and is indexed via `to_seqnum`; a slot is
    /// `None` while nothing is buffered for it.
    packets: Vec<Option<Vec<u8>>>,
    /// Partially reassembled split packets, keyed by the split's sequence number.
    splits: HashMap<u16, Split>,
    /// Sequence number of the next reliable packet to be processed.
    seqnum: u16,
}
36
/// Receive loop of one connection: reads raw datagrams from `input`, handles
/// control traffic, reorders reliable packets, reassembles split packets and
/// forwards the resulting packets (or errors) through `output`.
#[derive(Debug)]
pub struct Worker<S: UdpSender, R: UdpReceiver> {
    /// Shared sending half; used to send ACKs, pings, disconnect notices and to
    /// re-send packets that are still waiting for an ACK.
    sender: Arc<Sender<S>>,
    /// Per-channel receive state, one entry for each of the `NUM_CHANS` channels.
    chans: [RecvChan; NUM_CHANS],
    /// Source of incoming raw datagrams.
    input: R,
    /// Watch channel; any change on it makes `run` disconnect locally and stop.
    close: watch::Receiver<bool>,
    /// Ticks every 500 ms; on each tick all un-ACKed packets are sent again.
    resend: Interval,
    /// Ticks every `PING_TIMEOUT` seconds; on each tick a ping is sent.
    ping: Interval,
    /// Ticks every `TIMEOUT` seconds; on each tick the split buffers are cleaned up.
    cleanup: Interval,
    /// Connection timeout; re-armed to `TIMEOUT` seconds whenever a datagram is received.
    timeout: Pin<Box<Sleep>>,
    /// Channel on which received packets and errors are handed to the consumer.
    output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
    /// Set once the remote side has sent a disconnect; ends the `run` loop.
    closed: bool,
}
50
51 impl<S: UdpSender, R: UdpReceiver> Worker<S, R> {
52     pub(crate) fn new(
53         input: R,
54         close: watch::Receiver<bool>,
55         sender: Arc<Sender<S>>,
56         output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
57     ) -> Self {
58         Self {
59             input,
60             sender,
61             close,
62             output,
63             resend: interval(Duration::from_millis(500)),
64             ping: interval(Duration::from_secs(PING_TIMEOUT)),
65             cleanup: interval(Duration::from_secs(TIMEOUT)),
66             timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))),
67             closed: false,
68             chans: std::array::from_fn(|_| RecvChan {
69                 packets: (0..REL_BUFFER).map(|_| None).collect(),
70                 seqnum: INIT_SEQNUM,
71                 splits: HashMap::new(),
72             }),
73         }
74     }
75
76     pub async fn run(mut self) {
77         use Error::*;
78
79         while !self.closed {
80             tokio::select! {
81                 _ = self.close.changed() => {
82                     self.sender
83                         .send_rudp_type(
84                             PktType::Ctl,
85                             None,
86                             Pkt {
87                                 unrel: true,
88                                 chan: 0,
89                                 data: Cow::Borrowed(&[CtlType::Disco as u8]),
90                             },
91                         )
92                         .await
93                         .ok();
94
95                     self.output.send(Err(LocalDisco)).ok();
96                     break;
97                 },
98                 _ = &mut self.timeout => {
99                     self.output.send(Err(RemoteDisco(true))).ok();
100                     break;
101                 },
102                 _ = self.cleanup.tick() => {
103                     let timeout = Duration::from_secs(TIMEOUT);
104
105                     for chan in self.chans.iter_mut() {
106                         chan.splits = chan
107                             .splits
108                             .drain_filter(
109                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
110                             )
111                             .collect();
112                     }
113                 },
114                 _ = self.resend.tick() => {
115                     for chan in self.sender.chans.iter() {
116                         for (_, ack) in chan.lock().await.acks.iter() {
117                             self.sender.send_udp(&ack.data).await.ok();
118                         }
119                     }
120                 },
121                 _ = self.ping.tick() => {
122                     self.sender
123                         .send_rudp_type(
124                             PktType::Ctl,
125                             None,
126                             Pkt {
127                                 chan: 0,
128                                 unrel: false,
129                                 data: Cow::Borrowed(&[CtlType::Ping as u8]),
130                             },
131                         )
132                         .await
133                         .ok();
134                 }
135                 pkt = self.input.recv() => {
136                     if let Err(e) = self.handle_pkt(pkt).await {
137                         self.output.send(Err(e)).ok();
138                     }
139                 }
140             }
141         }
142     }
143
144     async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> {
145         use Error::*;
146
147         let mut cursor = io::Cursor::new(pkt?);
148
149         self.timeout
150             .as_mut()
151             .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
152
153         let proto_id = cursor.read_u32::<BigEndian>()?;
154         if proto_id != PROTO_ID {
155             return Err(InvalidProtoId(proto_id));
156         }
157
158         let _peer_id = cursor.read_u16::<BigEndian>()?;
159
160         let chan = cursor.read_u8()?;
161         if chan >= NUM_CHANS as u8 {
162             return Err(InvalidChannel(chan));
163         }
164
165         self.process_pkt(cursor, true, chan).await
166     }
167
168     #[async_recursion]
169     async fn process_pkt(
170         &mut self,
171         mut cursor: io::Cursor<Vec<u8>>,
172         unrel: bool,
173         chan: u8,
174     ) -> Result<()> {
175         use Error::*;
176
177         let ch = chan as usize;
178         match cursor.read_u8()?.try_into()? {
179             PktType::Ctl => match cursor.read_u8()?.try_into()? {
180                 CtlType::Ack => {
181                     let seqnum = cursor.read_u16::<BigEndian>()?;
182                     if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) {
183                         ack.tx.send(true).ok();
184                     }
185                 }
186                 CtlType::SetPeerID => {
187                     let mut id = self.sender.remote_id.write().await;
188
189                     if *id != PeerID::Nil as u16 {
190                         return Err(PeerIDAlreadySet);
191                     }
192
193                     *id = cursor.read_u16::<BigEndian>()?;
194                 }
195                 CtlType::Ping => {}
196                 CtlType::Disco => {
197                     self.closed = true;
198                     return Err(RemoteDisco(false));
199                 }
200             },
201             PktType::Orig => {
202                 self.output
203                     .send(Ok(Pkt {
204                         chan,
205                         unrel,
206                         data: Cow::Owned(cursor.remaining_slice().into()),
207                     }))
208                     .ok();
209             }
210             PktType::Split => {
211                 let seqnum = cursor.read_u16::<BigEndian>()?;
212                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
213                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
214
215                 let mut split = self.chans[ch]
216                     .splits
217                     .entry(seqnum)
218                     .or_insert_with(|| Split {
219                         got: 0,
220                         chunks: (0..chunk_count).map(|_| None).collect(),
221                         timestamp: None,
222                     });
223
224                 if split.chunks.len() != chunk_count {
225                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
226                 }
227
228                 if split
229                     .chunks
230                     .get_mut(chunk_index)
231                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
232                     .replace(cursor.remaining_slice().into())
233                     .is_none()
234                 {
235                     split.got += 1;
236                 }
237
238                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
239
240                 if split.got == chunk_count {
241                     let split = self.chans[ch].splits.remove(&seqnum).unwrap();
242
243                     self.output
244                         .send(Ok(Pkt {
245                             chan,
246                             unrel,
247                             data: split
248                                 .chunks
249                                 .into_iter()
250                                 .map(|x| x.unwrap())
251                                 .reduce(|mut a, mut b| {
252                                     a.append(&mut b);
253                                     a
254                                 })
255                                 .unwrap_or_default()
256                                 .into(),
257                         }))
258                         .ok();
259                 }
260             }
261             PktType::Rel => {
262                 let seqnum = cursor.read_u16::<BigEndian>()?;
263                 self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
264
265                 println!("{seqnum}");
266
267                 let mut ack_data = Vec::with_capacity(3);
268                 ack_data.write_u8(CtlType::Ack as u8)?;
269                 ack_data.write_u16::<BigEndian>(seqnum)?;
270
271                 self.sender
272                     .send_rudp_type(
273                         PktType::Ctl,
274                         None,
275                         Pkt {
276                             chan,
277                             unrel: true,
278                             data: ack_data.into(),
279                         },
280                     )
281                     .await?;
282
283                 let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take();
284                 while let Some(pkt) = next_pkt(&mut self.chans[ch]) {
285                     if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await {
286                         self.output.send(Err(e)).ok();
287                     }
288
289                     self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0;
290                 }
291             }
292         }
293
294         Ok(())
295     }
296 }