]> git.lizzy.rs Git - mt_rudp.git/commitdiff
implement splits
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Thu, 29 Dec 2022 01:19:56 +0000 (02:19 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Thu, 29 Dec 2022 01:19:56 +0000 (02:19 +0100)
Cargo.toml
src/error.rs
src/main.rs
src/recv_worker.rs

index 2947badc7617114a12dbbc5c541f2f1e2952993c..edcdf2ed01a0e4caa3ed78d97fa14eeb247a3443 100644 (file)
@@ -4,6 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+async-recursion = "1.0.0"
 async-trait = "0.1.60"
 byteorder = "1.4.3"
 num_enum = "0.5.7"
index 9b84ace0f55984077a0b9eaa1748e47f060991bf..f434804e1e6fb9315801ae2b56b13fbdbbd346a0 100644 (file)
@@ -7,10 +7,12 @@ use tokio::sync::mpsc::error::SendError;
 pub enum Error {
     IoError(io::Error),
     InvalidProtoId(u32),
-    InvalidPeerID,
     InvalidChannel(u8),
     InvalidType(u8),
     InvalidCtlType(u8),
+    PeerIDAlreadySet,
+    InvalidChunkIndex(usize, usize),
+    InvalidChunkCount(usize, usize),
     RemoteDisco,
     LocalDisco,
 }
@@ -35,24 +37,26 @@ impl From<TryFromPrimitiveError<CtlType>> for Error {
 
 impl From<SendError<InPkt>> for Error {
     fn from(_err: SendError<InPkt>) -> Self {
-        Self::LocalDisco // technically not a disconnect but a local drop
+        Self::LocalDisco
     }
 }
 
 impl fmt::Display for Error {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         use Error::*;
-        write!(f, "RUDP Error: ")?;
+        write!(f, "RUDP error: ")?;
 
         match self {
-            IoError(err) => write!(f, "IO Error: {}", err),
-            InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
-            InvalidPeerID => write!(f, "Invalid Peer ID"),
-            InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
-            InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
-            InvalidCtlType(tp) => write!(f, "Invalid Control Type: {tp}"),
-            RemoteDisco => write!(f, "Remote Disconnected"),
-            LocalDisco => write!(f, "Local Disconnected"),
+            IoError(err) => write!(f, "IO error: {}", err),
+            InvalidProtoId(id) => write!(f, "invalid protocol ID: {id}"),
+            InvalidChannel(ch) => write!(f, "invalid channel: {ch}"),
+            InvalidType(tp) => write!(f, "invalid type: {tp}"),
+            InvalidCtlType(tp) => write!(f, "invalid control type: {tp}"),
+            PeerIDAlreadySet => write!(f, "peer ID already set"),
+            InvalidChunkIndex(i, n) => write!(f, "chunk index {i} bigger than chunk count {n}"),
+            InvalidChunkCount(o, n) => write!(f, "chunk count changed from {o} to {n}"),
+            RemoteDisco => write!(f, "remote disconnected"),
+            LocalDisco => write!(f, "local disconnected"),
         }
     }
 }
index a190bcd094447afb6bdd928b1e0c2d3023f12e43..d5fa9526d89e7365dcd7d7abb0267df55ebc6cf7 100644 (file)
@@ -1,6 +1,6 @@
-#![feature(yeet_expr)]
 #![feature(cursor_remaining)]
 #![feature(hash_drain_filter)]
+#![feature(once_cell)]
 mod client;
 pub mod error;
 mod recv_worker;
@@ -9,13 +9,12 @@ use async_trait::async_trait;
 use byteorder::{BigEndian, WriteBytesExt};
 pub use client::{connect, Sender as Client};
 use num_enum::TryFromPrimitive;
-use std::future::Future;
 use std::{
     io::{self, Write},
     ops,
     sync::Arc,
 };
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, RwLock};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -34,7 +33,8 @@ pub trait UdpReceiver: Send + Sync + 'static {
     async fn recv(&self) -> io::Result<Vec<u8>>;
 }
 
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[repr(u16)]
 pub enum PeerID {
     Nil = 0,
     Srv,
@@ -75,7 +75,7 @@ pub struct AckChan;
 #[derive(Debug)]
 pub struct RudpShare<S: UdpSender> {
     pub id: u16,
-    pub remote_id: u16,
+    pub remote_id: RwLock<u16>,
     pub chans: Vec<AckChan>,
     udp_tx: S,
 }
@@ -95,7 +95,7 @@ impl<S: UdpSender> RudpShare<S> {
     pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
         buf.write_u32::<BigEndian>(PROTO_ID)?;
-        buf.write_u16::<BigEndian>(self.remote_id)?;
+        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
         buf.write_u8(pkt.chan as u8)?;
         buf.write_u8(tp as u8)?;
         buf.write(pkt.data)?;
@@ -110,6 +110,24 @@ impl<S: UdpSender> RudpSender<S> {
     pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
         self.share.send(PktType::Orig, pkt).await // TODO
     }
+
+    pub async fn peer_id(&self) -> u16 {
+        self.share.id
+    }
+
+    pub async fn is_server(&self) -> bool {
+        self.share.id == PeerID::Srv as u16
+    }
+}
+
+impl<S: UdpSender> RudpReceiver<S> {
+    pub async fn peer_id(&self) -> u16 {
+        self.share.id
+    }
+
+    pub async fn is_server(&self) -> bool {
+        self.share.id == PeerID::Srv as u16
+    }
 }
 
 impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
@@ -136,7 +154,7 @@ pub fn new<S: UdpSender, R: UdpReceiver>(
 
     let share = Arc::new(RudpShare {
         id,
-        remote_id,
+        remote_id: RwLock::new(remote_id),
         udp_tx,
         chans: (0..NUM_CHANS).map(|_| AckChan).collect(),
     });
@@ -159,7 +177,6 @@ pub fn new<S: UdpSender, R: UdpReceiver>(
 
 #[tokio::main]
 async fn main() -> io::Result<()> {
-    //println!("{}", x.deep_size_of());
     let (tx, mut rx) = connect("127.0.0.1:30000").await?;
 
     let mut mtpkt = vec![];
index 5a156eb9a6f07883a95adc6ed6919254fde882be..f83e8efb7dca69ec722723f3f85ef130f94f57de 100644 (file)
@@ -1,7 +1,8 @@
 use crate::{error::Error, *};
+use async_recursion::async_recursion;
 use byteorder::{BigEndian, ReadBytesExt};
 use std::{
-    cell::Cell,
+    cell::{Cell, OnceCell},
     collections::HashMap,
     io,
     sync::{Arc, Weak},
@@ -13,14 +14,16 @@ fn to_seqnum(seqnum: u16) -> usize {
     (seqnum as usize) & (REL_BUFFER - 1)
 }
 
-type Result = std::result::Result<(), Error>;
+type Result<T> = std::result::Result<T, Error>;
 
 struct Split {
-    timestamp: time::Instant,
+    timestamp: Option<time::Instant>,
+    chunks: Vec<OnceCell<Vec<u8>>>,
+    got: usize,
 }
 
 struct Chan {
-    packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
+    packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
     splits: HashMap<u16, Split>,
     seqnum: u16,
     num: u8,
@@ -65,7 +68,9 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                     let mut ch = chan.lock().await;
                     ch.splits = ch
                         .splits
-                        .drain_filter(|_k, v| v.timestamp.elapsed() < timeout)
+                        .drain_filter(
+                            |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
+                        )
                         .collect();
                 }
 
@@ -93,7 +98,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
         }
     }
 
-    async fn recv_pkt(&self) -> Result {
+    async fn recv_pkt(&self) -> Result<()> {
         use Error::*;
 
         // todo: reset timeout
@@ -101,10 +106,10 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
 
         let proto_id = cursor.read_u32::<BigEndian>()?;
         if proto_id != PROTO_ID {
-            do yeet InvalidProtoId(proto_id);
+            return Err(InvalidProtoId(proto_id));
         }
 
-        let peer_id = cursor.read_u16::<BigEndian>()?;
+        let _peer_id = cursor.read_u16::<BigEndian>()?;
 
         let n_chan = cursor.read_u8()?;
         let mut chan = self
@@ -114,40 +119,102 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
             .lock()
             .await;
 
-        self.process_pkt(cursor, &mut chan)
+        self.process_pkt(cursor, true, &mut chan).await
     }
 
-    fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result {
-        use CtlType::*;
+    #[async_recursion]
+    async fn process_pkt(
+        &self,
+        mut cursor: io::Cursor<Vec<u8>>,
+        unrel: bool,
+        chan: &mut Chan,
+    ) -> Result<()> {
         use Error::*;
-        use PktType::*;
 
         match cursor.read_u8()?.try_into()? {
-            Ctl => match cursor.read_u8()?.try_into()? {
-                Disco => return Err(RemoteDisco),
-                _ => {}
+            PktType::Ctl => match cursor.read_u8()?.try_into()? {
+                CtlType::Ack => { /* TODO */ }
+                CtlType::SetPeerID => {
+                    let mut id = self.share.remote_id.write().await;
+
+                    if *id != PeerID::Nil as u16 {
+                        return Err(PeerIDAlreadySet);
+                    }
+
+                    *id = cursor.read_u16::<BigEndian>()?;
+                }
+                CtlType::Ping => {}
+                CtlType::Disco => return Err(RemoteDisco),
             },
-            Orig => {
+            PktType::Orig => {
                 println!("Orig");
 
                 self.pkt_tx.send(Ok(Pkt {
                     chan: chan.num,
-                    unrel: true,
+                    unrel,
                     data: cursor.remaining_slice().into(),
                 }))?;
             }
-            Split => {
+            PktType::Split => {
                 println!("Split");
-                dbg!(cursor.remaining_slice());
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
+                let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
+
+                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
+                    got: 0,
+                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
+                    timestamp: None,
+                });
+
+                if split.chunks.len() != chunk_count {
+                    return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
+                }
+
+                if split
+                    .chunks
+                    .get(chunk_index)
+                    .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
+                    .set(cursor.remaining_slice().into())
+                    .is_ok()
+                {
+                    split.got += 1;
+                }
+
+                split.timestamp = if unrel {
+                    Some(time::Instant::now())
+                } else {
+                    None
+                };
+
+                if split.got == chunk_count {
+                    self.pkt_tx.send(Ok(Pkt {
+                        chan: chan.num,
+                        unrel,
+                        data: split
+                            .chunks
+                            .iter()
+                            .flat_map(|chunk| chunk.get().unwrap().iter())
+                            .copied()
+                            .collect(),
+                    }))?;
+
+                    chan.splits.remove(&seqnum);
+                }
             }
-            Rel => {
+            PktType::Rel => {
                 println!("Rel");
 
                 let seqnum = cursor.read_u16::<BigEndian>()?;
                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
 
-                while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
-                    self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
+                fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
+                    chan.packets[to_seqnum(chan.seqnum)].take()
+                }
+
+                while let Some(pkt) = next_pkt(chan) {
+                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
                 }
             }
@@ -156,7 +223,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
         Ok(())
     }
 
-    fn handle(&self, res: Result) -> Result {
+    fn handle(&self, res: Result<()>) -> Result<()> {
         use Error::*;
 
         match res {