]> git.lizzy.rs Git - mt_rudp.git/blob - src/recv_worker.rs
splits infrastructure
[mt_rudp.git] / src / recv_worker.rs
1 use crate::{error::Error, *};
2 use byteorder::{BigEndian, ReadBytesExt};
3 use std::{
4     cell::Cell,
5     collections::HashMap,
6     io, result,
7     sync::{mpsc, Arc, Mutex, Weak},
8     thread, time,
9 };
10
11 fn to_seqnum(seqnum: u16) -> usize {
12     (seqnum as usize) & (REL_BUFFER - 1)
13 }
14
15 type PktTx = mpsc::Sender<InPkt>;
16 type Result = result::Result<(), Error>;
17
18 struct Split {
19     timestamp: time::Instant,
20 }
21
22 struct Chan {
23     packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
24     splits: HashMap<u16, Split>,
25     seqnum: u16,
26     num: u8,
27 }
28
29 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
30     share: Arc<RudpShare<S>>,
31     chans: Arc<Vec<Mutex<Chan>>>,
32     pkt_tx: PktTx,
33     udp_rx: R,
34 }
35
36 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
37     pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: PktTx) -> Self {
38         Self {
39             udp_rx,
40             share,
41             pkt_tx,
42             chans: Arc::new(
43                 (0..NUM_CHANS as u8)
44                     .map(|num| {
45                         Mutex::new(Chan {
46                             num,
47                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
48                             seqnum: INIT_SEQNUM,
49                             splits: HashMap::new(),
50                         })
51                     })
52                     .collect(),
53             ),
54         }
55     }
56
57     pub fn run(&self) {
58         let cleanup_chans = Arc::downgrade(&self.chans);
59         thread::spawn(move || {
60             let timeout = time::Duration::from_secs(TIMEOUT);
61
62             while let Some(chans) = Weak::upgrade(&cleanup_chans) {
63                 for chan in chans.iter() {
64                     let mut ch = chan.lock().unwrap();
65                     ch.splits = ch
66                         .splits
67                         .drain_filter(|_k, v| v.timestamp.elapsed() < timeout)
68                         .collect();
69                 }
70
71                 thread::sleep(timeout);
72             }
73         });
74
75         loop {
76             if let Err(e) = self.handle(self.recv_pkt()) {
77                 if let Error::LocalDisco = e {
78                     self.share
79                         .send(
80                             PktType::Ctl,
81                             Pkt {
82                                 unrel: true,
83                                 chan: 0,
84                                 data: &[CtlType::Disco as u8],
85                             },
86                         )
87                         .ok();
88                 }
89                 break;
90             }
91         }
92     }
93
94     fn recv_pkt(&self) -> Result {
95         use Error::*;
96
97         // todo: reset timeout
98         let mut cursor = io::Cursor::new(self.udp_rx.recv()?);
99
100         let proto_id = cursor.read_u32::<BigEndian>()?;
101         if proto_id != PROTO_ID {
102             do yeet InvalidProtoId(proto_id);
103         }
104
105         let peer_id = cursor.read_u16::<BigEndian>()?;
106
107         let n_chan = cursor.read_u8()?;
108         let mut chan = self
109             .chans
110             .get(n_chan as usize)
111             .ok_or(InvalidChannel(n_chan))?
112             .lock()
113             .unwrap();
114
115         self.process_pkt(cursor, &mut chan)
116     }
117
118     fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result {
119         use CtlType::*;
120         use Error::*;
121         use PktType::*;
122
123         match cursor.read_u8()?.try_into()? {
124             Ctl => match cursor.read_u8()?.try_into()? {
125                 Disco => return Err(RemoteDisco),
126                 _ => {}
127             },
128             Orig => {
129                 println!("Orig");
130
131                 self.pkt_tx.send(Ok(Pkt {
132                     chan: chan.num,
133                     unrel: true,
134                     data: cursor.remaining_slice().into(),
135                 }))?;
136             }
137             Split => {
138                 println!("Split");
139                 dbg!(cursor.remaining_slice());
140             }
141             Rel => {
142                 println!("Rel");
143
144                 let seqnum = cursor.read_u16::<BigEndian>()?;
145                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
146
147                 while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
148                     self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
149                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
150                 }
151             }
152         }
153
154         Ok(())
155     }
156
157     fn handle(&self, res: Result) -> Result {
158         use Error::*;
159
160         match res {
161             Ok(v) => Ok(v),
162             Err(RemoteDisco) => Err(RemoteDisco),
163             Err(LocalDisco) => Err(LocalDisco),
164             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
165         }
166     }
167 }