1 use crate::{error::Error, *};
2 use async_recursion::async_recursion;
3 use byteorder::{BigEndian, ReadBytesExt};
5 cell::{Cell, OnceCell},
11 use tokio::sync::{mpsc, Mutex};
13 fn to_seqnum(seqnum: u16) -> usize {
14 (seqnum as usize) & (REL_BUFFER - 1)
/// Crate-local result alias: every fallible operation in this module fails with
/// `Error`, which the `use crate::{error::Error, *}` at the top of the file brings in.
17 type Result<T> = std::result::Result<T, Error>;
20 timestamp: Option<time::Instant>,
21 chunks: Vec<OnceCell<Vec<u8>>>,
26 packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛
27 splits: HashMap<u16, Split>,
32 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
33 share: Arc<RudpShare<S>>,
34 chans: Arc<Vec<Mutex<Chan>>>,
35 pkt_tx: mpsc::UnboundedSender<InPkt>,
39 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
40 pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
50 packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
52 splits: HashMap::new(),
60 pub async fn run(&self) {
61 let cleanup_chans = Arc::downgrade(&self.chans);
62 tokio::spawn(async move {
63 let timeout = time::Duration::from_secs(TIMEOUT);
64 let mut interval = tokio::time::interval(timeout);
66 while let Some(chans) = Weak::upgrade(&cleanup_chans) {
67 for chan in chans.iter() {
68 let mut ch = chan.lock().await;
72 |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
77 interval.tick().await;
82 if let Err(e) = self.handle(self.recv_pkt().await) {
83 if let Error::LocalDisco = e {
90 data: &[CtlType::Disco as u8],
101 async fn recv_pkt(&self) -> Result<()> {
104 // todo: reset timeout
105 let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
107 let proto_id = cursor.read_u32::<BigEndian>()?;
108 if proto_id != PROTO_ID {
109 return Err(InvalidProtoId(proto_id));
112 let _peer_id = cursor.read_u16::<BigEndian>()?;
114 let n_chan = cursor.read_u8()?;
117 .get(n_chan as usize)
118 .ok_or(InvalidChannel(n_chan))?
122 self.process_pkt(cursor, true, &mut chan).await
126 async fn process_pkt(
128 mut cursor: io::Cursor<Vec<u8>>,
134 match cursor.read_u8()?.try_into()? {
135 PktType::Ctl => match cursor.read_u8()?.try_into()? {
137 let seqnum = cursor.read_u16::<BigEndian>()?;
138 if let Some((tx, _)) = self.share.ack_chans.lock().await.remove(&seqnum) {
142 CtlType::SetPeerID => {
143 let mut id = self.share.remote_id.write().await;
145 if *id != PeerID::Nil as u16 {
146 return Err(PeerIDAlreadySet);
149 *id = cursor.read_u16::<BigEndian>()?;
152 CtlType::Disco => return Err(RemoteDisco),
157 self.pkt_tx.send(Ok(Pkt {
160 data: cursor.remaining_slice().into(),
166 let seqnum = cursor.read_u16::<BigEndian>()?;
167 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
168 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
170 let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
172 chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
176 if split.chunks.len() != chunk_count {
177 return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
183 .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
184 .set(cursor.remaining_slice().into())
190 split.timestamp = if unrel {
191 Some(time::Instant::now())
196 if split.got == chunk_count {
197 self.pkt_tx.send(Ok(Pkt {
203 .flat_map(|chunk| chunk.get().unwrap().iter())
208 chan.splits.remove(&seqnum);
214 let seqnum = cursor.read_u16::<BigEndian>()?;
215 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
217 let mut ack_data = Vec::with_capacity(3);
218 ack_data.write_u8(CtlType::Ack as u8)?;
219 ack_data.write_u16::<BigEndian>(seqnum)?;
232 fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
233 chan.packets[to_seqnum(chan.seqnum)].take()
236 while let Some(pkt) = next_pkt(chan) {
237 self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
238 chan.seqnum = chan.seqnum.overflowing_add(1).0;
246 fn handle(&self, res: Result<()>) -> Result<()> {
251 Err(RemoteDisco) => Err(RemoteDisco),
252 Err(LocalDisco) => Err(LocalDisco),
253 Err(e) => Ok(self.pkt_tx.send(Err(e))?),