git.lizzy.rs Git - mt_rudp.git/blob - src/worker.rs
72bf2b5741ffeb93ec9b4009e7573a10e9b35f9b
[mt_rudp.git] / src / worker.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     borrow::Cow,
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::{
13     sync::{mpsc, watch},
14     time::{interval, sleep, Interval, Sleep},
15 };
16
17 fn to_seqnum(seqnum: u16) -> usize {
18     (seqnum as usize) & (REL_BUFFER - 1)
19 }
20
21 type Result<T> = std::result::Result<T, Error>;
22
/// A split packet whose chunks are still being collected.
#[derive(Debug)]
struct Split {
    /// Arrival time of the most recent chunk when that chunk came in
    /// unreliably; `None` when it came in over the reliable path.
    timestamp: Option<Instant>,
    /// One slot per chunk index, `None` until that chunk has arrived.
    chunks: Vec<Option<Vec<u8>>>,
    /// How many slots of `chunks` are filled.
    got: usize,
}
29
30 #[derive(Debug)]
31 struct RecvChan {
32     packets: Vec<Option<Vec<u8>>>, // char ** ðŸ˜›
33     splits: HashMap<u16, Split>,
34     seqnum: u16,
35 }
36
37 #[derive(Debug)]
38 pub struct Worker<S: UdpSender, R: UdpReceiver> {
39     sender: Arc<Sender<S>>,
40     chans: [RecvChan; NUM_CHANS],
41     input: R,
42     close: watch::Receiver<bool>,
43     resend: Interval,
44     ping: Interval,
45     cleanup: Interval,
46     timeout: Pin<Box<Sleep>>,
47     output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
48 }
49
50 impl<S: UdpSender, R: UdpReceiver> Worker<S, R> {
51     pub(crate) fn new(
52         input: R,
53         close: watch::Receiver<bool>,
54         sender: Arc<Sender<S>>,
55         output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
56     ) -> Self {
57         Self {
58             input,
59             sender,
60             close,
61             output,
62             resend: interval(Duration::from_millis(500)),
63             ping: interval(Duration::from_secs(PING_TIMEOUT)),
64             cleanup: interval(Duration::from_secs(TIMEOUT)),
65             timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))),
66             chans: std::array::from_fn(|_| RecvChan {
67                 packets: (0..REL_BUFFER).map(|_| None).collect(),
68                 seqnum: INIT_SEQNUM,
69                 splits: HashMap::new(),
70             }),
71         }
72     }
73
74     pub async fn run(mut self) {
75         use Error::*;
76
77         loop {
78             tokio::select! {
79                 _ = self.close.changed() => {
80                     self.sender
81                         .send_rudp_type(
82                             PktType::Ctl,
83                             Pkt {
84                                 unrel: true,
85                                 chan: 0,
86                                 data: Cow::Borrowed(&[CtlType::Disco as u8]),
87                             },
88                         )
89                         .await
90                         .ok();
91
92                     self.output.send(Err(LocalDisco)).ok();
93                     break;
94                 },
95                 _ = &mut self.timeout => {
96                     self.output.send(Err(RemoteDisco(true))).ok();
97                     break;
98                 },
99                 _ = self.cleanup.tick() => {
100                     let timeout = Duration::from_secs(TIMEOUT);
101
102                     for chan in self.chans.iter_mut() {
103                         chan.splits = chan
104                             .splits
105                             .drain_filter(
106                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
107                             )
108                             .collect();
109                     }
110                 },
111                 _ = self.resend.tick() => {
112                     for chan in self.sender.chans.iter() {
113                         for (_, ack) in chan.lock().await.acks.iter() {
114                             self.sender.send_udp(&ack.data).await.ok();
115                         }
116                     }
117                 },
118                 _ = self.ping.tick() => {
119                     self.sender
120                         .send_rudp_type(
121                             PktType::Ctl,
122                             Pkt {
123                                 chan: 0,
124                                 unrel: false,
125                                 data: Cow::Borrowed(&[CtlType::Ping as u8]),
126                             },
127                         )
128                         .await
129                         .ok();
130                 }
131                 pkt = self.input.recv() => {
132                     if let Err(e) = self.handle_pkt(pkt).await {
133                         self.output.send(Err(e)).ok();
134                     }
135                 }
136             }
137         }
138     }
139
140     async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> {
141         use Error::*;
142
143         let mut cursor = io::Cursor::new(pkt?);
144
145         self.timeout
146             .as_mut()
147             .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
148
149         let proto_id = cursor.read_u32::<BigEndian>()?;
150         if proto_id != PROTO_ID {
151             return Err(InvalidProtoId(proto_id));
152         }
153
154         let _peer_id = cursor.read_u16::<BigEndian>()?;
155
156         let chan = cursor.read_u8()?;
157         if chan >= NUM_CHANS as u8 {
158             return Err(InvalidChannel(chan));
159         }
160
161         self.process_pkt(cursor, true, chan).await
162     }
163
164     #[async_recursion]
165     async fn process_pkt(
166         &mut self,
167         mut cursor: io::Cursor<Vec<u8>>,
168         unrel: bool,
169         chan: u8,
170     ) -> Result<()> {
171         use Error::*;
172
173         let ch = chan as usize;
174         match cursor.read_u8()?.try_into()? {
175             PktType::Ctl => match cursor.read_u8()?.try_into()? {
176                 CtlType::Ack => {
177                     let seqnum = cursor.read_u16::<BigEndian>()?;
178                     if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) {
179                         ack.tx.send(true).ok();
180                     }
181                 }
182                 CtlType::SetPeerID => {
183                     let mut id = self.sender.remote_id.write().await;
184
185                     if *id != PeerID::Nil as u16 {
186                         return Err(PeerIDAlreadySet);
187                     }
188
189                     *id = cursor.read_u16::<BigEndian>()?;
190                 }
191                 CtlType::Ping => {}
192                 CtlType::Disco => {
193                     return Err(RemoteDisco(false));
194                 }
195             },
196             PktType::Orig => {
197                 self.output
198                     .send(Ok(Pkt {
199                         chan,
200                         unrel,
201                         data: Cow::Owned(cursor.remaining_slice().into()),
202                     }))
203                     .ok();
204             }
205             PktType::Split => {
206                 let seqnum = cursor.read_u16::<BigEndian>()?;
207                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
208                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
209
210                 let mut split = self.chans[ch]
211                     .splits
212                     .entry(seqnum)
213                     .or_insert_with(|| Split {
214                         got: 0,
215                         chunks: (0..chunk_count).map(|_| None).collect(),
216                         timestamp: None,
217                     });
218
219                 if split.chunks.len() != chunk_count {
220                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
221                 }
222
223                 if split
224                     .chunks
225                     .get_mut(chunk_index)
226                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
227                     .replace(cursor.remaining_slice().into())
228                     .is_none()
229                 {
230                     split.got += 1;
231                 }
232
233                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
234
235                 if split.got == chunk_count {
236                     let split = self.chans[ch].splits.remove(&seqnum).unwrap();
237
238                     self.output
239                         .send(Ok(Pkt {
240                             chan,
241                             unrel,
242                             data: split
243                                 .chunks
244                                 .into_iter()
245                                 .map(|x| x.unwrap())
246                                 .reduce(|mut a, mut b| {
247                                     a.append(&mut b);
248                                     a
249                                 })
250                                 .unwrap_or_default()
251                                 .into(),
252                         }))
253                         .ok();
254                 }
255             }
256             PktType::Rel => {
257                 let seqnum = cursor.read_u16::<BigEndian>()?;
258                 self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
259
260                 let mut ack_data = Vec::with_capacity(3);
261                 ack_data.write_u8(CtlType::Ack as u8)?;
262                 ack_data.write_u16::<BigEndian>(seqnum)?;
263
264                 self.sender
265                     .send_rudp_type(
266                         PktType::Ctl,
267                         Pkt {
268                             chan,
269                             unrel: true,
270                             data: ack_data.into(),
271                         },
272                     )
273                     .await?;
274
275                 let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take();
276                 while let Some(pkt) = next_pkt(&mut self.chans[ch]) {
277                     if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await {
278                         self.output.send(Err(e)).ok();
279                     }
280
281                     self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0;
282                 }
283             }
284         }
285
286         Ok(())
287     }
288 }