]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv_worker.rs
fix async
[mt_rudp.git] / src / recv_worker.rs
1 use crate::{error::Error, *};
2 use byteorder::{BigEndian, ReadBytesExt};
3 use std::{
4     cell::Cell,
5     collections::HashMap,
6     io,
7     sync::{Arc, Weak},
8     time,
9 };
10 use tokio::sync::{mpsc, Mutex};
11
12 fn to_seqnum(seqnum: u16) -> usize {
13     (seqnum as usize) & (REL_BUFFER - 1)
14 }
15
16 type Result = std::result::Result<(), Error>;
17
18 struct Split {
19     timestamp: time::Instant,
20 }
21
22 struct Chan {
23     packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
24     splits: HashMap<u16, Split>,
25     seqnum: u16,
26     num: u8,
27 }
28
29 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
30     share: Arc<RudpShare<S>>,
31     chans: Arc<Vec<Mutex<Chan>>>,
32     pkt_tx: mpsc::UnboundedSender<InPkt>,
33     udp_rx: R,
34 }
35
36 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
37     pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
38         Self {
39             udp_rx,
40             share,
41             pkt_tx,
42             chans: Arc::new(
43                 (0..NUM_CHANS as u8)
44                     .map(|num| {
45                         Mutex::new(Chan {
46                             num,
47                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
48                             seqnum: INIT_SEQNUM,
49                             splits: HashMap::new(),
50                         })
51                     })
52                     .collect(),
53             ),
54         }
55     }
56
57     pub async fn run(&self) {
58         let cleanup_chans = Arc::downgrade(&self.chans);
59         tokio::spawn(async move {
60             let timeout = time::Duration::from_secs(TIMEOUT);
61             let mut interval = tokio::time::interval(timeout);
62
63             while let Some(chans) = Weak::upgrade(&cleanup_chans) {
64                 for chan in chans.iter() {
65                     let mut ch = chan.lock().await;
66                     ch.splits = ch
67                         .splits
68                         .drain_filter(|_k, v| v.timestamp.elapsed() < timeout)
69                         .collect();
70                 }
71
72                 interval.tick().await;
73             }
74         });
75
76         loop {
77             if let Err(e) = self.handle(self.recv_pkt().await) {
78                 if let Error::LocalDisco = e {
79                     self.share
80                         .send(
81                             PktType::Ctl,
82                             Pkt {
83                                 unrel: true,
84                                 chan: 0,
85                                 data: &[CtlType::Disco as u8],
86                             },
87                         )
88                         .await
89                         .ok();
90                 }
91                 break;
92             }
93         }
94     }
95
96     async fn recv_pkt(&self) -> Result {
97         use Error::*;
98
99         // todo: reset timeout
100         let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
101
102         let proto_id = cursor.read_u32::<BigEndian>()?;
103         if proto_id != PROTO_ID {
104             do yeet InvalidProtoId(proto_id);
105         }
106
107         let peer_id = cursor.read_u16::<BigEndian>()?;
108
109         let n_chan = cursor.read_u8()?;
110         let mut chan = self
111             .chans
112             .get(n_chan as usize)
113             .ok_or(InvalidChannel(n_chan))?
114             .lock()
115             .await;
116
117         self.process_pkt(cursor, &mut chan)
118     }
119
120     fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result {
121         use CtlType::*;
122         use Error::*;
123         use PktType::*;
124
125         match cursor.read_u8()?.try_into()? {
126             Ctl => match cursor.read_u8()?.try_into()? {
127                 Disco => return Err(RemoteDisco),
128                 _ => {}
129             },
130             Orig => {
131                 println!("Orig");
132
133                 self.pkt_tx.send(Ok(Pkt {
134                     chan: chan.num,
135                     unrel: true,
136                     data: cursor.remaining_slice().into(),
137                 }))?;
138             }
139             Split => {
140                 println!("Split");
141                 dbg!(cursor.remaining_slice());
142             }
143             Rel => {
144                 println!("Rel");
145
146                 let seqnum = cursor.read_u16::<BigEndian>()?;
147                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
148
149                 while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
150                     self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
151                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
152                 }
153             }
154         }
155
156         Ok(())
157     }
158
159     fn handle(&self, res: Result) -> Result {
160         use Error::*;
161
162         match res {
163             Ok(v) => Ok(v),
164             Err(RemoteDisco) => Err(RemoteDisco),
165             Err(LocalDisco) => Err(LocalDisco),
166             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
167         }
168     }
169 }