]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv_worker.rs
f83e8efb7dca69ec722723f3f85ef130f94f57de
[mt_rudp.git] / src / recv_worker.rs
1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
4 use std::{
5     cell::{Cell, OnceCell},
6     collections::HashMap,
7     io,
8     sync::{Arc, Weak},
9     time,
10 };
11 use tokio::sync::{mpsc, Mutex};
12
13 fn to_seqnum(seqnum: u16) -> usize {
14     (seqnum as usize) & (REL_BUFFER - 1)
15 }
16
/// Result alias used throughout this module; the error type is always the crate's `Error`.
type Result<T> = std::result::Result<T, Error>;
18
/// Reassembly state for one split packet (a payload delivered in several chunks).
struct Split {
    /// Arrival time of the most recent chunk if it came in an unreliable packet,
    /// `None` if it came in a reliable one. Read by the cleanup task in `run`.
    timestamp: Option<time::Instant>,
    /// One write-once cell per expected chunk, indexed by chunk index.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// Number of cells in `chunks` that have been filled so far.
    got: usize,
}
24
/// Receive-side state of a single channel.
struct Chan {
    /// Buffered bodies of reliable packets, stored at the slot `to_seqnum(seqnum)`
    /// until they can be processed in sequence order.
    packets: Vec<Cell<Option<Vec<u8>>>>,
    /// Split packets still being reassembled, keyed by the split's sequence number.
    splits: HashMap<u16, Split>,
    /// Sequence number of the next reliable packet to process.
    seqnum: u16,
    /// Channel number; copied into every `Pkt` emitted for this channel.
    num: u8,
}
31
/// Receives raw datagrams, decodes them and forwards the resulting packets
/// (or decoding errors) to the consumer side of `pkt_tx`.
pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
    /// State shared with the sending side; used here to store the remote peer id
    /// and to send a disconnect packet.
    share: Arc<RudpShare<S>>,
    /// Per-channel receive state, indexed by channel number. Shared (weakly) with
    /// the cleanup task spawned by `run`.
    chans: Arc<Vec<Mutex<Chan>>>,
    /// Sink for decoded packets and non-fatal errors.
    pkt_tx: mpsc::UnboundedSender<InPkt>,
    /// Source of raw datagrams.
    udp_rx: R,
}
38
39 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
40     pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
41         Self {
42             udp_rx,
43             share,
44             pkt_tx,
45             chans: Arc::new(
46                 (0..NUM_CHANS as u8)
47                     .map(|num| {
48                         Mutex::new(Chan {
49                             num,
50                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
51                             seqnum: INIT_SEQNUM,
52                             splits: HashMap::new(),
53                         })
54                     })
55                     .collect(),
56             ),
57         }
58     }
59
60     pub async fn run(&self) {
61         let cleanup_chans = Arc::downgrade(&self.chans);
62         tokio::spawn(async move {
63             let timeout = time::Duration::from_secs(TIMEOUT);
64             let mut interval = tokio::time::interval(timeout);
65
66             while let Some(chans) = Weak::upgrade(&cleanup_chans) {
67                 for chan in chans.iter() {
68                     let mut ch = chan.lock().await;
69                     ch.splits = ch
70                         .splits
71                         .drain_filter(
72                             |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
73                         )
74                         .collect();
75                 }
76
77                 interval.tick().await;
78             }
79         });
80
81         loop {
82             if let Err(e) = self.handle(self.recv_pkt().await) {
83                 if let Error::LocalDisco = e {
84                     self.share
85                         .send(
86                             PktType::Ctl,
87                             Pkt {
88                                 unrel: true,
89                                 chan: 0,
90                                 data: &[CtlType::Disco as u8],
91                             },
92                         )
93                         .await
94                         .ok();
95                 }
96                 break;
97             }
98         }
99     }
100
101     async fn recv_pkt(&self) -> Result<()> {
102         use Error::*;
103
104         // todo: reset timeout
105         let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
106
107         let proto_id = cursor.read_u32::<BigEndian>()?;
108         if proto_id != PROTO_ID {
109             return Err(InvalidProtoId(proto_id));
110         }
111
112         let _peer_id = cursor.read_u16::<BigEndian>()?;
113
114         let n_chan = cursor.read_u8()?;
115         let mut chan = self
116             .chans
117             .get(n_chan as usize)
118             .ok_or(InvalidChannel(n_chan))?
119             .lock()
120             .await;
121
122         self.process_pkt(cursor, true, &mut chan).await
123     }
124
125     #[async_recursion]
126     async fn process_pkt(
127         &self,
128         mut cursor: io::Cursor<Vec<u8>>,
129         unrel: bool,
130         chan: &mut Chan,
131     ) -> Result<()> {
132         use Error::*;
133
134         match cursor.read_u8()?.try_into()? {
135             PktType::Ctl => match cursor.read_u8()?.try_into()? {
136                 CtlType::Ack => { /* TODO */ }
137                 CtlType::SetPeerID => {
138                     let mut id = self.share.remote_id.write().await;
139
140                     if *id != PeerID::Nil as u16 {
141                         return Err(PeerIDAlreadySet);
142                     }
143
144                     *id = cursor.read_u16::<BigEndian>()?;
145                 }
146                 CtlType::Ping => {}
147                 CtlType::Disco => return Err(RemoteDisco),
148             },
149             PktType::Orig => {
150                 println!("Orig");
151
152                 self.pkt_tx.send(Ok(Pkt {
153                     chan: chan.num,
154                     unrel,
155                     data: cursor.remaining_slice().into(),
156                 }))?;
157             }
158             PktType::Split => {
159                 println!("Split");
160
161                 let seqnum = cursor.read_u16::<BigEndian>()?;
162                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
163                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
164
165                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
166                     got: 0,
167                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
168                     timestamp: None,
169                 });
170
171                 if split.chunks.len() != chunk_count {
172                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
173                 }
174
175                 if split
176                     .chunks
177                     .get(chunk_index)
178                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
179                     .set(cursor.remaining_slice().into())
180                     .is_ok()
181                 {
182                     split.got += 1;
183                 }
184
185                 split.timestamp = if unrel {
186                     Some(time::Instant::now())
187                 } else {
188                     None
189                 };
190
191                 if split.got == chunk_count {
192                     self.pkt_tx.send(Ok(Pkt {
193                         chan: chan.num,
194                         unrel,
195                         data: split
196                             .chunks
197                             .iter()
198                             .flat_map(|chunk| chunk.get().unwrap().iter())
199                             .copied()
200                             .collect(),
201                     }))?;
202
203                     chan.splits.remove(&seqnum);
204                 }
205             }
206             PktType::Rel => {
207                 println!("Rel");
208
209                 let seqnum = cursor.read_u16::<BigEndian>()?;
210                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
211
212                 fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
213                     chan.packets[to_seqnum(chan.seqnum)].take()
214                 }
215
216                 while let Some(pkt) = next_pkt(chan) {
217                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
218                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
219                 }
220             }
221         }
222
223         Ok(())
224     }
225
226     fn handle(&self, res: Result<()>) -> Result<()> {
227         use Error::*;
228
229         match res {
230             Ok(v) => Ok(v),
231             Err(RemoteDisco) => Err(RemoteDisco),
232             Err(LocalDisco) => Err(LocalDisco),
233             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
234         }
235     }
236 }