]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv_worker.rs
async
[mt_rudp.git] / src / recv_worker.rs
1 use crate::{error::Error, *};
2 use byteorder::{BigEndian, ReadBytesExt};
3 use std::{
4     cell::Cell,
5     collections::HashMap,
6     io,
7     sync::{Arc, Weak},
8     time,
9 };
10 use tokio::sync::{mpsc, Mutex};
11
12 fn to_seqnum(seqnum: u16) -> usize {
13     (seqnum as usize) & (REL_BUFFER - 1)
14 }
15
16 type Result = std::result::Result<(), Error>;
17
18 struct Split {
19     timestamp: time::Instant,
20 }
21
22 struct Chan {
23     packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
24     splits: HashMap<u16, Split>,
25     seqnum: u16,
26     num: u8,
27 }
28
29 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
30     share: Arc<RudpShare<S>>,
31     chans: Arc<Vec<Mutex<Chan>>>,
32     pkt_tx: mpsc::UnboundedSender<InPkt>,
33     udp_rx: R,
34 }
35
36 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
37     pub async fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) {
38         Self {
39             udp_rx,
40             share,
41             pkt_tx,
42             chans: Arc::new(
43                 (0..NUM_CHANS as u8)
44                     .map(|num| {
45                         Mutex::new(Chan {
46                             num,
47                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
48                             seqnum: INIT_SEQNUM,
49                             splits: HashMap::new(),
50                         })
51                     })
52                     .collect(),
53             ),
54         }
55         .run()
56         .await
57     }
58
59     pub async fn run(&self) {
60         let cleanup_chans = Arc::downgrade(&self.chans);
61         tokio::spawn(async move {
62             let timeout = time::Duration::from_secs(TIMEOUT);
63             let mut interval = tokio::time::interval(timeout);
64
65             while let Some(chans) = Weak::upgrade(&cleanup_chans) {
66                 for chan in chans.iter() {
67                     let mut ch = chan.lock().await;
68                     ch.splits = ch
69                         .splits
70                         .drain_filter(|_k, v| v.timestamp.elapsed() < timeout)
71                         .collect();
72                 }
73
74                 interval.tick().await;
75             }
76         });
77
78         loop {
79             if let Err(e) = self.handle(self.recv_pkt().await) {
80                 if let Error::LocalDisco = e {
81                     self.share
82                         .send(
83                             PktType::Ctl,
84                             Pkt {
85                                 unrel: true,
86                                 chan: 0,
87                                 data: &[CtlType::Disco as u8],
88                             },
89                         )
90                         .await
91                         .ok();
92                 }
93                 break;
94             }
95         }
96     }
97
98     async fn recv_pkt(&self) -> Result {
99         use Error::*;
100
101         // todo: reset timeout
102         let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
103
104         let proto_id = cursor.read_u32::<BigEndian>()?;
105         if proto_id != PROTO_ID {
106             do yeet InvalidProtoId(proto_id);
107         }
108
109         let peer_id = cursor.read_u16::<BigEndian>()?;
110
111         let n_chan = cursor.read_u8()?;
112         let mut chan = self
113             .chans
114             .get(n_chan as usize)
115             .ok_or(InvalidChannel(n_chan))?
116             .lock()
117             .await;
118
119         self.process_pkt(cursor, &mut chan)
120     }
121
122     fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result {
123         use CtlType::*;
124         use Error::*;
125         use PktType::*;
126
127         match cursor.read_u8()?.try_into()? {
128             Ctl => match cursor.read_u8()?.try_into()? {
129                 Disco => return Err(RemoteDisco),
130                 _ => {}
131             },
132             Orig => {
133                 println!("Orig");
134
135                 self.pkt_tx.send(Ok(Pkt {
136                     chan: chan.num,
137                     unrel: true,
138                     data: cursor.remaining_slice().into(),
139                 }))?;
140             }
141             Split => {
142                 println!("Split");
143                 dbg!(cursor.remaining_slice());
144             }
145             Rel => {
146                 println!("Rel");
147
148                 let seqnum = cursor.read_u16::<BigEndian>()?;
149                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
150
151                 while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
152                     self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
153                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
154                 }
155             }
156         }
157
158         Ok(())
159     }
160
161     fn handle(&self, res: Result) -> Result {
162         use Error::*;
163
164         match res {
165             Ok(v) => Ok(v),
166             Err(RemoteDisco) => Err(RemoteDisco),
167             Err(LocalDisco) => Err(LocalDisco),
168             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
169         }
170     }
171 }