git.lizzy.rs Git - mt_rudp.git/blob - src/worker.rs
Remove debug print
[mt_rudp.git] / src / worker.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     borrow::Cow,
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::{
13     sync::{mpsc, watch},
14     time::{interval, sleep, Interval, Sleep},
15 };
16
17 fn to_seqnum(seqnum: u16) -> usize {
18     (seqnum as usize) & (REL_BUFFER - 1)
19 }
20
21 type Result<T> = std::result::Result<T, Error>;
22
/// Reassembly state for one split packet whose chunks are still arriving.
#[derive(Debug)]
struct Split {
    /// Time the most recent chunk arrived, set only for splits received
    /// unreliably; `None` for splits delivered over the reliable path.
    timestamp: Option<Instant>,
    /// One slot per chunk index; `None` until that chunk has been received.
    chunks: Vec<Option<Vec<u8>>>,
    /// Number of slots in `chunks` that have been filled so far.
    got: usize,
}
29
/// Receive-side state of a single channel.
#[derive(Debug)]
struct RecvChan {
    /// Reliable payloads waiting for in-order delivery, indexed by
    /// `to_seqnum(seqnum)`; `None` means nothing is buffered in that slot.
    packets: Vec<Option<Vec<u8>>>, // char ** 😛
    /// In-progress split packets, keyed by the split's sequence number.
    splits: HashMap<u16, Split>,
    /// Sequence number of the next reliable packet to deliver.
    seqnum: u16,
}
36
/// Background task that receives datagrams for one connection, decodes them,
/// acknowledges reliable packets, and drives pings, retransmissions and timeouts.
#[derive(Debug)]
pub struct Worker<S: UdpSender, R: UdpReceiver> {
    /// Shared send half, used for control packets, acks, retransmissions and
    /// the remote peer ID.
    sender: Arc<Sender<S>>,
    /// Receive state per channel, indexed by channel number.
    chans: [RecvChan; NUM_CHANS],
    /// Source of incoming datagrams.
    input: R,
    /// A change on this watch channel triggers a local disconnect.
    close: watch::Receiver<bool>,
    /// Ticks every 500 ms to resend reliable packets that are still unacked.
    resend: Interval,
    /// Ticks every `PING_TIMEOUT` seconds to send a ping.
    ping: Interval,
    /// Ticks every `TIMEOUT` seconds to discard stalled unreliable splits.
    cleanup: Interval,
    /// Inactivity timer, reset whenever a datagram arrives; firing means the
    /// peer timed out.
    timeout: Pin<Box<Sleep>>,
    /// Receives every decoded packet and every error.
    output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
    /// Set when the peer sends a `Disco` control packet; ends the run loop.
    closed: bool,
}
50
51 impl<S: UdpSender, R: UdpReceiver> Worker<S, R> {
52     pub(crate) fn new(
53         input: R,
54         close: watch::Receiver<bool>,
55         sender: Arc<Sender<S>>,
56         output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
57     ) -> Self {
58         Self {
59             input,
60             sender,
61             close,
62             output,
63             resend: interval(Duration::from_millis(500)),
64             ping: interval(Duration::from_secs(PING_TIMEOUT)),
65             cleanup: interval(Duration::from_secs(TIMEOUT)),
66             timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))),
67             closed: false,
68             chans: std::array::from_fn(|_| RecvChan {
69                 packets: (0..REL_BUFFER).map(|_| None).collect(),
70                 seqnum: INIT_SEQNUM,
71                 splits: HashMap::new(),
72             }),
73         }
74     }
75
76     pub async fn run(mut self) {
77         use Error::*;
78
79         while !self.closed {
80             tokio::select! {
81                 _ = self.close.changed() => {
82                     self.sender
83                         .send_rudp_type(
84                             PktType::Ctl,
85                             None,
86                             Pkt {
87                                 unrel: true,
88                                 chan: 0,
89                                 data: Cow::Borrowed(&[CtlType::Disco as u8]),
90                             },
91                         )
92                         .await
93                         .ok();
94
95                     self.output.send(Err(LocalDisco)).ok();
96                     break;
97                 },
98                 _ = &mut self.timeout => {
99                     self.output.send(Err(RemoteDisco(true))).ok();
100                     break;
101                 },
102                 _ = self.cleanup.tick() => {
103                     let timeout = Duration::from_secs(TIMEOUT);
104
105                     for chan in self.chans.iter_mut() {
106                         chan.splits.retain(|_, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout));
107                     }
108                 },
109                 _ = self.resend.tick() => {
110                     for chan in self.sender.chans.iter() {
111                         for (_, ack) in chan.lock().await.acks.iter() {
112                             self.sender.send_udp(&ack.data).await.ok();
113                         }
114                     }
115                 },
116                 _ = self.ping.tick() => {
117                     self.sender
118                         .send_rudp_type(
119                             PktType::Ctl,
120                             None,
121                             Pkt {
122                                 chan: 0,
123                                 unrel: false,
124                                 data: Cow::Borrowed(&[CtlType::Ping as u8]),
125                             },
126                         )
127                         .await
128                         .ok();
129                 }
130                 pkt = self.input.recv() => {
131                     if let Err(e) = self.handle_pkt(pkt).await {
132                         self.output.send(Err(e)).ok();
133                     }
134                 }
135             }
136         }
137     }
138
139     async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> {
140         use Error::*;
141
142         let mut cursor = io::Cursor::new(pkt?);
143
144         self.timeout
145             .as_mut()
146             .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
147
148         let proto_id = cursor.read_u32::<BigEndian>()?;
149         if proto_id != PROTO_ID {
150             return Err(InvalidProtoId(proto_id));
151         }
152
153         let _peer_id = cursor.read_u16::<BigEndian>()?;
154
155         let chan = cursor.read_u8()?;
156         if chan >= NUM_CHANS as u8 {
157             return Err(InvalidChannel(chan));
158         }
159
160         self.process_pkt(cursor, true, chan).await
161     }
162
163     #[async_recursion]
164     async fn process_pkt(
165         &mut self,
166         mut cursor: io::Cursor<Vec<u8>>,
167         unrel: bool,
168         chan: u8,
169     ) -> Result<()> {
170         use Error::*;
171
172         let ch = chan as usize;
173         match cursor.read_u8()?.try_into()? {
174             PktType::Ctl => match cursor.read_u8()?.try_into()? {
175                 CtlType::Ack => {
176                     let seqnum = cursor.read_u16::<BigEndian>()?;
177                     if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) {
178                         ack.tx.send(true).ok();
179                     }
180                 }
181                 CtlType::SetPeerID => {
182                     let mut id = self.sender.remote_id.write().await;
183
184                     if *id != PeerID::Nil as u16 {
185                         return Err(PeerIDAlreadySet);
186                     }
187
188                     *id = cursor.read_u16::<BigEndian>()?;
189                 }
190                 CtlType::Ping => {}
191                 CtlType::Disco => {
192                     self.closed = true;
193                     return Err(RemoteDisco(false));
194                 }
195             },
196             PktType::Orig => {
197                 self.output
198                     .send(Ok(Pkt {
199                         chan,
200                         unrel,
201                         data: Cow::Owned(cursor.remaining_slice().into()),
202                     }))
203                     .ok();
204             }
205             PktType::Split => {
206                 let seqnum = cursor.read_u16::<BigEndian>()?;
207                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
208                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
209
210                 let mut split = self.chans[ch]
211                     .splits
212                     .entry(seqnum)
213                     .or_insert_with(|| Split {
214                         got: 0,
215                         chunks: (0..chunk_count).map(|_| None).collect(),
216                         timestamp: None,
217                     });
218
219                 if split.chunks.len() != chunk_count {
220                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
221                 }
222
223                 if split
224                     .chunks
225                     .get_mut(chunk_index)
226                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
227                     .replace(cursor.remaining_slice().into())
228                     .is_none()
229                 {
230                     split.got += 1;
231                 }
232
233                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
234
235                 if split.got == chunk_count {
236                     let split = self.chans[ch].splits.remove(&seqnum).unwrap();
237
238                     self.output
239                         .send(Ok(Pkt {
240                             chan,
241                             unrel,
242                             data: split
243                                 .chunks
244                                 .into_iter()
245                                 .map(|x| x.unwrap())
246                                 .reduce(|mut a, mut b| {
247                                     a.append(&mut b);
248                                     a
249                                 })
250                                 .unwrap_or_default()
251                                 .into(),
252                         }))
253                         .ok();
254                 }
255             }
256             PktType::Rel => {
257                 let seqnum = cursor.read_u16::<BigEndian>()?;
258                 self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
259
260                 let mut ack_data = Vec::with_capacity(3);
261                 ack_data.write_u8(CtlType::Ack as u8)?;
262                 ack_data.write_u16::<BigEndian>(seqnum)?;
263
264                 self.sender
265                     .send_rudp_type(
266                         PktType::Ctl,
267                         None,
268                         Pkt {
269                             chan,
270                             unrel: true,
271                             data: ack_data.into(),
272                         },
273                     )
274                     .await?;
275
276                 let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take();
277                 while let Some(pkt) = next_pkt(&mut self.chans[ch]) {
278                     if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await {
279                         self.output.send(Err(e)).ok();
280                     }
281
282                     self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0;
283                 }
284             }
285         }
286
287         Ok(())
288     }
289 }