]> git.lizzy.rs Git - mt_rudp.git/blob - src/worker.rs
Properly close on rudp disco
[mt_rudp.git] / src / worker.rs
1 use super::*;
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use std::{
5     borrow::Cow,
6     collections::HashMap,
7     io,
8     pin::Pin,
9     sync::Arc,
10     time::{Duration, Instant},
11 };
12 use tokio::{
13     sync::{mpsc, watch},
14     time::{interval, sleep, Interval, Sleep},
15 };
16
17 fn to_seqnum(seqnum: u16) -> usize {
18     (seqnum as usize) & (REL_BUFFER - 1)
19 }
20
21 type Result<T> = std::result::Result<T, Error>;
22
/// Reassembly state for one split packet whose chunks are still arriving.
#[derive(Debug)]
struct Split {
    /// When the most recent chunk arrived, if that chunk was received unreliably;
    /// `None` when it was received reliably. Read by the periodic cleanup.
    timestamp: Option<Instant>,
    /// One slot per expected chunk; a slot is `Some` once that chunk has been received.
    chunks: Vec<Option<Vec<u8>>>,
    /// Number of slots in `chunks` that are filled.
    got: usize,
}
29
30 #[derive(Debug)]
31 struct RecvChan {
32     packets: Vec<Option<Vec<u8>>>, // char ** ðŸ˜›
33     splits: HashMap<u16, Split>,
34     seqnum: u16,
35 }
36
37 #[derive(Debug)]
38 pub struct Worker<S: UdpSender, R: UdpReceiver> {
39     sender: Arc<Sender<S>>,
40     chans: [RecvChan; NUM_CHANS],
41     input: R,
42     close: watch::Receiver<bool>,
43     resend: Interval,
44     ping: Interval,
45     cleanup: Interval,
46     timeout: Pin<Box<Sleep>>,
47     output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
48     closed: bool,
49 }
50
51 impl<S: UdpSender, R: UdpReceiver> Worker<S, R> {
52     pub(crate) fn new(
53         input: R,
54         close: watch::Receiver<bool>,
55         sender: Arc<Sender<S>>,
56         output: mpsc::UnboundedSender<Result<Pkt<'static>>>,
57     ) -> Self {
58         Self {
59             input,
60             sender,
61             close,
62             output,
63             resend: interval(Duration::from_millis(500)),
64             ping: interval(Duration::from_secs(PING_TIMEOUT)),
65             cleanup: interval(Duration::from_secs(TIMEOUT)),
66             timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))),
67             closed: false,
68             chans: std::array::from_fn(|_| RecvChan {
69                 packets: (0..REL_BUFFER).map(|_| None).collect(),
70                 seqnum: INIT_SEQNUM,
71                 splits: HashMap::new(),
72             }),
73         }
74     }
75
76     pub async fn run(mut self) {
77         use Error::*;
78
79         while !self.closed {
80             tokio::select! {
81                 _ = self.close.changed() => {
82                     self.sender
83                         .send_rudp_type(
84                             PktType::Ctl,
85                             Pkt {
86                                 unrel: true,
87                                 chan: 0,
88                                 data: Cow::Borrowed(&[CtlType::Disco as u8]),
89                             },
90                         )
91                         .await
92                         .ok();
93
94                     self.output.send(Err(LocalDisco)).ok();
95                     break;
96                 },
97                 _ = &mut self.timeout => {
98                     self.output.send(Err(RemoteDisco(true))).ok();
99                     break;
100                 },
101                 _ = self.cleanup.tick() => {
102                     let timeout = Duration::from_secs(TIMEOUT);
103
104                     for chan in self.chans.iter_mut() {
105                         chan.splits = chan
106                             .splits
107                             .drain_filter(
108                                 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
109                             )
110                             .collect();
111                     }
112                 },
113                 _ = self.resend.tick() => {
114                     for chan in self.sender.chans.iter() {
115                         for (_, ack) in chan.lock().await.acks.iter() {
116                             self.sender.send_udp(&ack.data).await.ok();
117                         }
118                     }
119                 },
120                 _ = self.ping.tick() => {
121                     self.sender
122                         .send_rudp_type(
123                             PktType::Ctl,
124                             Pkt {
125                                 chan: 0,
126                                 unrel: false,
127                                 data: Cow::Borrowed(&[CtlType::Ping as u8]),
128                             },
129                         )
130                         .await
131                         .ok();
132                 }
133                 pkt = self.input.recv() => {
134                     if let Err(e) = self.handle_pkt(pkt).await {
135                         self.output.send(Err(e)).ok();
136                     }
137                 }
138             }
139         }
140     }
141
142     async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> {
143         use Error::*;
144
145         let mut cursor = io::Cursor::new(pkt?);
146
147         self.timeout
148             .as_mut()
149             .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
150
151         let proto_id = cursor.read_u32::<BigEndian>()?;
152         if proto_id != PROTO_ID {
153             return Err(InvalidProtoId(proto_id));
154         }
155
156         let _peer_id = cursor.read_u16::<BigEndian>()?;
157
158         let chan = cursor.read_u8()?;
159         if chan >= NUM_CHANS as u8 {
160             return Err(InvalidChannel(chan));
161         }
162
163         self.process_pkt(cursor, true, chan).await
164     }
165
166     #[async_recursion]
167     async fn process_pkt(
168         &mut self,
169         mut cursor: io::Cursor<Vec<u8>>,
170         unrel: bool,
171         chan: u8,
172     ) -> Result<()> {
173         use Error::*;
174
175         let ch = chan as usize;
176         match cursor.read_u8()?.try_into()? {
177             PktType::Ctl => match cursor.read_u8()?.try_into()? {
178                 CtlType::Ack => {
179                     let seqnum = cursor.read_u16::<BigEndian>()?;
180                     if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) {
181                         ack.tx.send(true).ok();
182                     }
183                 }
184                 CtlType::SetPeerID => {
185                     let mut id = self.sender.remote_id.write().await;
186
187                     if *id != PeerID::Nil as u16 {
188                         return Err(PeerIDAlreadySet);
189                     }
190
191                     *id = cursor.read_u16::<BigEndian>()?;
192                 }
193                 CtlType::Ping => {}
194                 CtlType::Disco => {
195                     self.closed = true;
196                     return Err(RemoteDisco(false));
197                 }
198             },
199             PktType::Orig => {
200                 self.output
201                     .send(Ok(Pkt {
202                         chan,
203                         unrel,
204                         data: Cow::Owned(cursor.remaining_slice().into()),
205                     }))
206                     .ok();
207             }
208             PktType::Split => {
209                 let seqnum = cursor.read_u16::<BigEndian>()?;
210                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
211                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
212
213                 let mut split = self.chans[ch]
214                     .splits
215                     .entry(seqnum)
216                     .or_insert_with(|| Split {
217                         got: 0,
218                         chunks: (0..chunk_count).map(|_| None).collect(),
219                         timestamp: None,
220                     });
221
222                 if split.chunks.len() != chunk_count {
223                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
224                 }
225
226                 if split
227                     .chunks
228                     .get_mut(chunk_index)
229                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
230                     .replace(cursor.remaining_slice().into())
231                     .is_none()
232                 {
233                     split.got += 1;
234                 }
235
236                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
237
238                 if split.got == chunk_count {
239                     let split = self.chans[ch].splits.remove(&seqnum).unwrap();
240
241                     self.output
242                         .send(Ok(Pkt {
243                             chan,
244                             unrel,
245                             data: split
246                                 .chunks
247                                 .into_iter()
248                                 .map(|x| x.unwrap())
249                                 .reduce(|mut a, mut b| {
250                                     a.append(&mut b);
251                                     a
252                                 })
253                                 .unwrap_or_default()
254                                 .into(),
255                         }))
256                         .ok();
257                 }
258             }
259             PktType::Rel => {
260                 let seqnum = cursor.read_u16::<BigEndian>()?;
261                 self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
262
263                 let mut ack_data = Vec::with_capacity(3);
264                 ack_data.write_u8(CtlType::Ack as u8)?;
265                 ack_data.write_u16::<BigEndian>(seqnum)?;
266
267                 self.sender
268                     .send_rudp_type(
269                         PktType::Ctl,
270                         Pkt {
271                             chan,
272                             unrel: true,
273                             data: ack_data.into(),
274                         },
275                     )
276                     .await?;
277
278                 let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take();
279                 while let Some(pkt) = next_pkt(&mut self.chans[ch]) {
280                     if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await {
281                         self.output.send(Err(e)).ok();
282                     }
283
284                     self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0;
285                 }
286             }
287         }
288
289         Ok(())
290     }
291 }