1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
5 cell::{Cell, OnceCell},
9 time::{Duration, Instant},
11 use tokio::sync::{mpsc, Mutex};
/// Converts a 16-bit sequence number into a `usize` slot index by masking it
/// with `REL_BUFFER - 1`. The result is used to index `chan.packets`, which is
/// created with `REL_BUFFER` slots.
///
/// NOTE(review): the mask only equals `seqnum % REL_BUFFER` if `REL_BUFFER` is a
/// power of two; its definition is not visible here, so confirm.
13 fn to_seqnum(seqnum: u16) -> usize {
14 (seqnum as usize) & (REL_BUFFER - 1)
/// File-local shorthand for `std::result::Result` with the crate `Error` as the
/// error type.
17 type Result<T> = std::result::Result<T, Error>;
// NOTE(review): the header of the struct these fields belong to is not visible;
// usage in process_pkt suggests it is `Split`.
// Arrival time of the most recent chunk when that chunk was processed with the
// `unrel` flag set, `None` otherwise. The cleanup predicate in `run` inspects it.
20 timestamp: Option<Instant>,
// One write-once cell per chunk of the split packet; created with `chunk_count`
// empty cells and filled by chunk index.
21 chunks: Vec<OnceCell<Vec<u8>>>,
// NOTE(review): the header of the struct these fields belong to is not visible;
// usage (`chan.packets`, `chan.splits`) suggests it is `RecvChan`.
// Slots for buffered reliable packet payloads, indexed by `to_seqnum(seqnum)`.
// A slot holds `None` until a payload is stored and is emptied again by `take`.
26 packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛
// Split packets still being reassembled, keyed by the sequence number read from
// the split header.
27 splits: HashMap<u16, Split>,
/// Worker that receives raw UDP packets, decodes them and forwards the resulting
/// packets (or errors) over `pkt_tx`.
///
/// NOTE(review): not all fields are visible in this view; `recv_pkt` also reads a
/// `udp_rx` field.
32 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
/// Shared connection state; `process_pkt` reads `chans` and `remote_id` from it.
33 share: Arc<RudpShare<S>>,
/// Close signal; `recv_pkt` returns `LocalDisco` when it changes.
34 close: watch::Receiver<bool>,
/// Per-channel receive state, each entry behind its own mutex. Indexed by the
/// channel number read from the packet header.
35 chans: Arc<Vec<Mutex<RecvChan>>>,
/// Outgoing side of the channel on which decoded packets and forwarded errors
/// are delivered.
36 pkt_tx: mpsc::UnboundedSender<InPkt>,
// Inherent methods of the receive worker.
40 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
// NOTE(review): the signature line of the function these parameters belong to is
// not visible; presumably this is the constructor — confirm.
43 share: Arc<RudpShare<S>>,
44 close: watch::Receiver<bool>,
45 pkt_tx: mpsc::UnboundedSender<InPkt>,
// Each channel starts with REL_BUFFER empty packet slots ...
57 packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
// ... and with no split packets in progress.
59 splits: HashMap::new(),
/// Drives the receive side: sets up a periodic cleanup over the per-channel split
/// tables, then passes the result of `recv_pkt` through `handle`.
///
/// NOTE(review): several lines of this function are not visible in this view
/// (task spawn, loop header, send call), so the comments below describe only the
/// visible statements.
67 pub async fn run(&self) {
// Clones of the channel list and the close signal, used by the cleanup ticker below.
68 let cleanup_chans = Arc::clone(&self.chans);
69 let mut cleanup_close = self.close.clone();
75 .name("cleanup_splits")*/
// TIMEOUT seconds; used both as the ticker argument and as the age threshold in
// the predicate below.
77 let timeout = Duration::from_secs(TIMEOUT);
// `ticker!` is defined elsewhere; presumably it runs the body once per `timeout`
// until `cleanup_close` fires — confirm against the macro definition.
79 ticker!(timeout, cleanup_close, {
80 for chan_mtx in cleanup_chans.iter() {
// Lock one channel at a time while inspecting it.
81 let mut chan = chan_mtx.lock().await;
// NOTE(review): presumably the predicate of a `retain` call on `chan.splits` (the
// call line is not visible). If so, as written it keeps entries whose timestamp is
// `None` or at least `timeout` old and removes entries younger than `timeout`,
// which looks inverted for a timeout cleanup — confirm whether `<` should be `>`.
85 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
// Own copy of the close signal, handed to `recv_pkt`.
92 let mut close = self.close.clone();
// `handle` forwards ordinary errors to `pkt_tx`; only the errors it returns
// (disconnects, or a failed send on `pkt_tx`) reach this branch.
94 if let Err(e) = self.handle(self.recv_pkt(&mut close).await) {
// `LocalDisco` is what `recv_pkt` returns when the close signal changes.
95 if let Error::LocalDisco = e {
// Payload consisting of the single control byte `CtlType::Disco`; presumably sent
// to the peer to announce the disconnect (the send call is not visible).
102 data: &[CtlType::Disco as u8],
/// Waits for one raw UDP packet, validates its header and hands the remainder to
/// `process_pkt`.
///
/// Visible header layout, all big-endian: `u32` protocol id, `u16` peer id, `u8`
/// channel number.
///
/// # Errors
/// - `LocalDisco` if `close` changes before a packet arrives.
/// - `InvalidProtoId` if the protocol id differs from `PROTO_ID`.
/// - `InvalidChannel` if the channel number has no entry in the channel list.
/// - Errors from the UDP receiver, from short reads, and from `process_pkt` are
///   propagated.
113 async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> {
116 // TODO: reset timeout
// Race the next datagram against the close signal; whichever completes first wins.
117 let mut cursor = io::Cursor::new(tokio::select! {
118 pkt = self.udp_rx.recv() => pkt?,
119 _ = close.changed() => return Err(LocalDisco),
// Reject packets that do not start with the expected protocol id.
124 let proto_id = cursor.read_u32::<BigEndian>()?;
125 if proto_id != PROTO_ID {
126 return Err(InvalidProtoId(proto_id));
// The peer id is read to advance the cursor but is not used.
129 let _peer_id = cursor.read_u16::<BigEndian>()?;
// Channel number selects the per-channel receive state.
131 let n_chan = cursor.read_u8()?;
134 .get(n_chan as usize)
135 .ok_or(InvalidChannel(n_chan))?
// Second argument is `true` for packets coming straight off the wire; buffered
// reliable payloads are re-processed with `false` (presumably the `unrel` flag).
139 self.process_pkt(cursor, true, &mut chan).await
/// Decodes one packet body (everything after the channel byte) and acts on it
/// according to its packet type.
///
/// NOTE(review): parts of the signature and several match arms are not visible in
/// this view. The comments below describe only the visible statements, and the arm
/// names given for hidden arms are inferred.
143 async fn process_pkt(
145 mut cursor: io::Cursor<Vec<u8>>,
// First byte selects the packet type; an unknown value fails the conversion.
151 match cursor.read_u8()?.try_into()? {
// Control packets carry a second type byte.
152 PktType::Ctl => match cursor.read_u8()?.try_into()? {
// Ack: read the acknowledged sequence number and look up the matching pending
// ack entry in the shared per-channel state.
156 let seqnum = cursor.read_u16::<BigEndian>()?;
157 if let Some(ack) = self.share.chans[chan.num as usize]
// Notify whoever waits on this ack; a failed send is deliberately ignored.
163 ack.tx.send(true).ok();
166 CtlType::SetPeerID => {
167 println!("SetPeerID");
// Hold the write lock while checking and updating the remote id.
169 let mut id = self.share.remote_id.write().await;
// The id may only be assigned while it still has the `Nil` value.
171 if *id != PeerID::Nil as u16 {
172 return Err(PeerIDAlreadySet);
175 *id = cursor.read_u16::<BigEndian>()?;
// Presumably the `CtlType::Disco` arm (arm header not visible): the remote side
// disconnected.
182 return Err(RemoteDisco);
// Presumably the original-packet arm (arm header not visible): forward the rest
// of the buffer as the payload.
188 self.pkt_tx.send(Ok(Pkt {
191 data: cursor.remaining_slice().into(),
// Split packet header: sequence number, index of this chunk, total chunk count.
197 let seqnum = cursor.read_u16::<BigEndian>()?;
198 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
199 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
// Reuse the reassembly state for this sequence number or create it with one empty
// cell per announced chunk.
201 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
203 chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
// A later chunk must announce the same chunk count as the one that created the entry.
207 if split.chunks.len() != chunk_count {
208 return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
// Out-of-range chunk index is an error; otherwise store the payload in its cell.
// `OnceCell::set` fails if the cell is already filled (handling of that result is
// not visible here).
214 .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
215 .set(cursor.remaining_slice().into())
// Only splits whose latest chunk arrived with `unrel` set carry a timestamp.
221 split.timestamp = if unrel { Some(Instant::now()) } else { None };
// All chunks present: concatenate them in order, forward the packet and drop the
// reassembly state.
223 if split.got == chunk_count {
224 self.pkt_tx.send(Ok(Pkt {
// Every cell is expected to be filled at this point, hence the unwrap.
230 .flat_map(|chunk| chunk.get().unwrap().iter())
235 chan.splits.remove(&seqnum);
// Reliable packet: buffer the payload in the slot derived from its sequence number.
// NOTE(review): no visible check that `seqnum` lies within the current window, so
// a slot may be overwritten — confirm whether that is intended.
241 let seqnum = cursor.read_u16::<BigEndian>()?;
242 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
// Build the 3-byte ack payload: the Ack control type byte followed by the
// big-endian sequence number.
244 let mut ack_data = Vec::with_capacity(3);
245 ack_data.write_u8(CtlType::Ack as u8)?;
246 ack_data.write_u16::<BigEndian>(seqnum)?;
// Takes the buffered payload for the sequence number the channel expects next,
// leaving the slot empty.
259 fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
260 chan.packets[to_seqnum(chan.seqnum)].take()
// Deliver buffered payloads in order for as long as the next expected one is
// present. Each payload is processed with the second argument `false`, errors go
// through `handle`, and the expected sequence number advances with wrap-around.
263 while let Some(pkt) = next_pkt(chan) {
264 self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
265 chan.seqnum = chan.seqnum.overflowing_add(1).0;
/// Routes the outcome of packet processing. Disconnect errors are returned to the
/// caller; every other error is forwarded to the packet consumer through `pkt_tx`.
///
/// # Errors
/// Returns `RemoteDisco` or `LocalDisco` unchanged. Also returns an error if the
/// forwarding send on `pkt_tx` fails.
///
/// NOTE(review): the `match` header and the success arm are not visible in this view.
273 fn handle(&self, res: Result<()>) -> Result<()> {
// Disconnects end the receive loop, so they are passed back instead of forwarded.
278 Err(RemoteDisco) => Err(RemoteDisco),
279 Err(LocalDisco) => Err(LocalDisco),
// Any other error is reported to the consumer; a failed send is propagated by `?`.
280 Err(e) => Ok(self.pkt_tx.send(Err(e))?),