]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv_worker.rs
finish receiver
[mt_rudp.git] / src / recv_worker.rs
1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
4 use std::{
5     cell::{Cell, OnceCell},
6     collections::HashMap,
7     io,
8     sync::{Arc, Weak},
9     time,
10 };
11 use tokio::sync::{mpsc, Mutex};
12
13 fn to_seqnum(seqnum: u16) -> usize {
14     (seqnum as usize) & (REL_BUFFER - 1)
15 }
16
17 type Result<T> = std::result::Result<T, Error>;
18
/// Reassembly state for one split (multi-chunk) packet, keyed by the split's
/// sequence number in the owning channel's `splits` map.
struct Split {
    /// One slot per expected chunk; each slot can be filled at most once.
    chunks: Vec<OnceCell<Vec<u8>>>,
    /// Number of distinct chunk slots filled so far.
    got: usize,
    /// Arrival time of the most recent chunk if it came in an unreliable
    /// packet; `None` if the most recent chunk arrived via a reliable packet.
    timestamp: Option<time::Instant>,
}
24
25 struct Chan {
26     packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
27     splits: HashMap<u16, Split>,
28     seqnum: u16,
29     num: u8,
30 }
31
32 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
33     share: Arc<RudpShare<S>>,
34     chans: Arc<Vec<Mutex<Chan>>>,
35     pkt_tx: mpsc::UnboundedSender<InPkt>,
36     udp_rx: R,
37 }
38
39 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
40     pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
41         Self {
42             udp_rx,
43             share,
44             pkt_tx,
45             chans: Arc::new(
46                 (0..NUM_CHANS as u8)
47                     .map(|num| {
48                         Mutex::new(Chan {
49                             num,
50                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
51                             seqnum: INIT_SEQNUM,
52                             splits: HashMap::new(),
53                         })
54                     })
55                     .collect(),
56             ),
57         }
58     }
59
60     pub async fn run(&self) {
61         let cleanup_chans = Arc::downgrade(&self.chans);
62         tokio::spawn(async move {
63             let timeout = time::Duration::from_secs(TIMEOUT);
64             let mut interval = tokio::time::interval(timeout);
65
66             while let Some(chans) = Weak::upgrade(&cleanup_chans) {
67                 for chan in chans.iter() {
68                     let mut ch = chan.lock().await;
69                     ch.splits = ch
70                         .splits
71                         .drain_filter(
72                             |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
73                         )
74                         .collect();
75                 }
76
77                 interval.tick().await;
78             }
79         });
80
81         loop {
82             if let Err(e) = self.handle(self.recv_pkt().await) {
83                 if let Error::LocalDisco = e {
84                     self.share
85                         .send(
86                             PktType::Ctl,
87                             Pkt {
88                                 unrel: true,
89                                 chan: 0,
90                                 data: &[CtlType::Disco as u8],
91                             },
92                         )
93                         .await
94                         .ok();
95                 }
96                 break;
97             }
98         }
99     }
100
101     async fn recv_pkt(&self) -> Result<()> {
102         use Error::*;
103
104         // todo: reset timeout
105         let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
106
107         let proto_id = cursor.read_u32::<BigEndian>()?;
108         if proto_id != PROTO_ID {
109             return Err(InvalidProtoId(proto_id));
110         }
111
112         let _peer_id = cursor.read_u16::<BigEndian>()?;
113
114         let n_chan = cursor.read_u8()?;
115         let mut chan = self
116             .chans
117             .get(n_chan as usize)
118             .ok_or(InvalidChannel(n_chan))?
119             .lock()
120             .await;
121
122         self.process_pkt(cursor, true, &mut chan).await
123     }
124
125     #[async_recursion]
126     async fn process_pkt(
127         &self,
128         mut cursor: io::Cursor<Vec<u8>>,
129         unrel: bool,
130         chan: &mut Chan,
131     ) -> Result<()> {
132         use Error::*;
133
134         match cursor.read_u8()?.try_into()? {
135             PktType::Ctl => match cursor.read_u8()?.try_into()? {
136                 CtlType::Ack => {
137                     let seqnum = cursor.read_u16::<BigEndian>()?;
138                     self.share.ack_chans.lock().await.remove(&seqnum);
139                 }
140                 CtlType::SetPeerID => {
141                     let mut id = self.share.remote_id.write().await;
142
143                     if *id != PeerID::Nil as u16 {
144                         return Err(PeerIDAlreadySet);
145                     }
146
147                     *id = cursor.read_u16::<BigEndian>()?;
148                 }
149                 CtlType::Ping => {}
150                 CtlType::Disco => return Err(RemoteDisco),
151             },
152             PktType::Orig => {
153                 println!("Orig");
154
155                 self.pkt_tx.send(Ok(Pkt {
156                     chan: chan.num,
157                     unrel,
158                     data: cursor.remaining_slice().into(),
159                 }))?;
160             }
161             PktType::Split => {
162                 println!("Split");
163
164                 let seqnum = cursor.read_u16::<BigEndian>()?;
165                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
166                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
167
168                 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
169                     got: 0,
170                     chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
171                     timestamp: None,
172                 });
173
174                 if split.chunks.len() != chunk_count {
175                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
176                 }
177
178                 if split
179                     .chunks
180                     .get(chunk_index)
181                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
182                     .set(cursor.remaining_slice().into())
183                     .is_ok()
184                 {
185                     split.got += 1;
186                 }
187
188                 split.timestamp = if unrel {
189                     Some(time::Instant::now())
190                 } else {
191                     None
192                 };
193
194                 if split.got == chunk_count {
195                     self.pkt_tx.send(Ok(Pkt {
196                         chan: chan.num,
197                         unrel,
198                         data: split
199                             .chunks
200                             .iter()
201                             .flat_map(|chunk| chunk.get().unwrap().iter())
202                             .copied()
203                             .collect(),
204                     }))?;
205
206                     chan.splits.remove(&seqnum);
207                 }
208             }
209             PktType::Rel => {
210                 println!("Rel");
211
212                 let seqnum = cursor.read_u16::<BigEndian>()?;
213                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
214
215                 fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
216                     chan.packets[to_seqnum(chan.seqnum)].take()
217                 }
218
219                 while let Some(pkt) = next_pkt(chan) {
220                     self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
221                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
222                 }
223             }
224         }
225
226         Ok(())
227     }
228
229     fn handle(&self, res: Result<()>) -> Result<()> {
230         use Error::*;
231
232         match res {
233             Ok(v) => Ok(v),
234             Err(RemoteDisco) => Err(RemoteDisco),
235             Err(LocalDisco) => Err(LocalDisco),
236             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
237         }
238     }
239 }