]> git.lizzy.rs Git - mt_rudp.git/blobdiff - src/recv_worker.rs
implement splits
[mt_rudp.git] / src / recv_worker.rs
index 5a156eb9a6f07883a95adc6ed6919254fde882be..f83e8efb7dca69ec722723f3f85ef130f94f57de 100644 (file)
@@ -1,7 +1,8 @@
 use crate::{error::Error, *};
+use async_recursion::async_recursion;
 use byteorder::{BigEndian, ReadBytesExt};
 use std::{
-    cell::Cell,
+    cell::{Cell, OnceCell},
     collections::HashMap,
     io,
     sync::{Arc, Weak},
@@ -13,14 +14,16 @@ fn to_seqnum(seqnum: u16) -> usize {
     (seqnum as usize) & (REL_BUFFER - 1)
 }
 
-type Result = std::result::Result<(), Error>;
+type Result<T> = std::result::Result<T, Error>;
 
 struct Split {
-    timestamp: time::Instant,
+    timestamp: Option<time::Instant>,
+    chunks: Vec<OnceCell<Vec<u8>>>,
+    got: usize,
 }
 
 struct Chan {
-    packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
+    packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
     splits: HashMap<u16, Split>,
     seqnum: u16,
     num: u8,
@@ -65,7 +68,9 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                     let mut ch = chan.lock().await;
                     ch.splits = ch
                         .splits
-                        .drain_filter(|_k, v| v.timestamp.elapsed() < timeout)
+                        .drain_filter(
+                            |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
+                        )
                         .collect();
                 }
 
@@ -93,7 +98,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
         }
     }
 
-    async fn recv_pkt(&self) -> Result {
+    async fn recv_pkt(&self) -> Result<()> {
         use Error::*;
 
         // todo: reset timeout
@@ -101,10 +106,10 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
 
         let proto_id = cursor.read_u32::<BigEndian>()?;
         if proto_id != PROTO_ID {
-            do yeet InvalidProtoId(proto_id);
+            return Err(InvalidProtoId(proto_id));
         }
 
-        let peer_id = cursor.read_u16::<BigEndian>()?;
+        let _peer_id = cursor.read_u16::<BigEndian>()?;
 
         let n_chan = cursor.read_u8()?;
         let mut chan = self
@@ -114,40 +119,102 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
             .lock()
             .await;
 
-        self.process_pkt(cursor, &mut chan)
+        self.process_pkt(cursor, true, &mut chan).await
     }
 
-    fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result {
-        use CtlType::*;
+    #[async_recursion]
+    async fn process_pkt(
+        &self,
+        mut cursor: io::Cursor<Vec<u8>>,
+        unrel: bool,
+        chan: &mut Chan,
+    ) -> Result<()> {
         use Error::*;
-        use PktType::*;
 
         match cursor.read_u8()?.try_into()? {
-            Ctl => match cursor.read_u8()?.try_into()? {
-                Disco => return Err(RemoteDisco),
-                _ => {}
+            PktType::Ctl => match cursor.read_u8()?.try_into()? {
+                CtlType::Ack => { /* TODO */ }
+                CtlType::SetPeerID => {
+                    let mut id = self.share.remote_id.write().await;
+
+                    if *id != PeerID::Nil as u16 {
+                        return Err(PeerIDAlreadySet);
+                    }
+
+                    *id = cursor.read_u16::<BigEndian>()?;
+                }
+                CtlType::Ping => {}
+                CtlType::Disco => return Err(RemoteDisco),
             },
-            Orig => {
+            PktType::Orig => {
                 println!("Orig");
 
                 self.pkt_tx.send(Ok(Pkt {
                     chan: chan.num,
-                    unrel: true,
+                    unrel,
                     data: cursor.remaining_slice().into(),
                 }))?;
             }
-            Split => {
+            PktType::Split => {
                 println!("Split");
-                dbg!(cursor.remaining_slice());
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
+                let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
+
+                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
+                    got: 0,
+                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
+                    timestamp: None,
+                });
+
+                if split.chunks.len() != chunk_count {
+                    return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
+                }
+
+                if split
+                    .chunks
+                    .get(chunk_index)
+                    .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
+                    .set(cursor.remaining_slice().into())
+                    .is_ok()
+                {
+                    split.got += 1;
+                }
+
+                split.timestamp = if unrel {
+                    Some(time::Instant::now())
+                } else {
+                    None
+                };
+
+                if split.got == chunk_count {
+                    self.pkt_tx.send(Ok(Pkt {
+                        chan: chan.num,
+                        unrel,
+                        data: split
+                            .chunks
+                            .iter()
+                            .flat_map(|chunk| chunk.get().unwrap().iter())
+                            .copied()
+                            .collect(),
+                    }))?;
+
+                    chan.splits.remove(&seqnum);
+                }
             }
-            Rel => {
+            PktType::Rel => {
                 println!("Rel");
 
                 let seqnum = cursor.read_u16::<BigEndian>()?;
                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
 
-                while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
-                    self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
+                fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
+                    chan.packets[to_seqnum(chan.seqnum)].take()
+                }
+
+                while let Some(pkt) = next_pkt(chan) {
+                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
                 }
             }
@@ -156,7 +223,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
         Ok(())
     }
 
-    fn handle(&self, res: Result) -> Result {
+    fn handle(&self, res: Result<()>) -> Result<()> {
         use Error::*;
 
         match res {