From: Lizzy Fleckenstein Date: Thu, 29 Dec 2022 01:19:56 +0000 (+0100) Subject: implement splits X-Git-Url: https://git.lizzy.rs/?a=commitdiff_plain;h=944c16adfb83976149701086e20146797d4330df;p=mt_rudp.git implement splits --- diff --git a/Cargo.toml b/Cargo.toml index 2947bad..edcdf2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +async-recursion = "1.0.0" async-trait = "0.1.60" byteorder = "1.4.3" num_enum = "0.5.7" diff --git a/src/error.rs b/src/error.rs index 9b84ace..f434804 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,10 +7,12 @@ use tokio::sync::mpsc::error::SendError; pub enum Error { IoError(io::Error), InvalidProtoId(u32), - InvalidPeerID, InvalidChannel(u8), InvalidType(u8), InvalidCtlType(u8), + PeerIDAlreadySet, + InvalidChunkIndex(usize, usize), + InvalidChunkCount(usize, usize), RemoteDisco, LocalDisco, } @@ -35,24 +37,26 @@ impl From> for Error { impl From> for Error { fn from(_err: SendError) -> Self { - Self::LocalDisco // technically not a disconnect but a local drop + Self::LocalDisco } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use Error::*; - write!(f, "RUDP Error: ")?; + write!(f, "RUDP error: ")?; match self { - IoError(err) => write!(f, "IO Error: {}", err), - InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"), - InvalidPeerID => write!(f, "Invalid Peer ID"), - InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"), - InvalidType(tp) => write!(f, "Invalid Type: {tp}"), - InvalidCtlType(tp) => write!(f, "Invalid Control Type: {tp}"), - RemoteDisco => write!(f, "Remote Disconnected"), - LocalDisco => write!(f, "Local Disconnected"), + IoError(err) => write!(f, "IO error: {}", err), + InvalidProtoId(id) => write!(f, "invalid protocol ID: {id}"), + InvalidChannel(ch) => write!(f, "invalid channel: {ch}"), + InvalidType(tp) => write!(f, "invalid type: {tp}"), + InvalidCtlType(tp) => write!(f, "invalid control type: {tp}"), + PeerIDAlreadySet => write!(f, "peer ID already set"), + InvalidChunkIndex(i, n) => write!(f, "chunk index {i} bigger than chunk count {n}"), + InvalidChunkCount(o, n) => write!(f, "chunk count changed from {o} to {n}"), + RemoteDisco => write!(f, "remote disconnected"), + LocalDisco => write!(f, "local disconnected"), } } } diff --git a/src/main.rs b/src/main.rs index a190bcd..d5fa952 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ -#![feature(yeet_expr)] #![feature(cursor_remaining)] #![feature(hash_drain_filter)] +#![feature(once_cell)] mod client; pub mod error; mod recv_worker; @@ -9,13 +9,12 @@ use async_trait::async_trait; use byteorder::{BigEndian, WriteBytesExt}; pub use client::{connect, Sender as Client}; use num_enum::TryFromPrimitive; -use std::future::Future; use std::{ io::{self, Write}, ops, sync::Arc, }; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, RwLock}; pub const PROTO_ID: u32 = 0x4f457403; pub const UDP_PKT_SIZE: usize = 512; @@ -34,7 +33,8 @@ pub trait UdpReceiver: Send + Sync + 'static { async fn recv(&self) -> io::Result>; } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] +#[repr(u16)] pub enum PeerID { Nil = 0, Srv, @@ -75,7 +75,7 @@ pub struct AckChan; #[derive(Debug)] pub struct RudpShare { pub id: u16, - pub remote_id: u16, + pub remote_id: RwLock, pub chans: Vec, udp_tx: S, } @@ -95,7 +95,7 @@ impl RudpShare { pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len()); buf.write_u32::(PROTO_ID)?; - buf.write_u16::(self.remote_id)?; + buf.write_u16::(*self.remote_id.read().await)?; buf.write_u8(pkt.chan as u8)?; buf.write_u8(tp as u8)?; buf.write(pkt.data)?; @@ -110,6 +110,24 @@ impl RudpSender { pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> { self.share.send(PktType::Orig, pkt).await // TODO } + + pub async fn peer_id(&self) -> u16 { + self.share.id + } + + pub async fn is_server(&self) -> bool { + self.share.id == PeerID::Srv as u16 + } +} + +impl RudpReceiver { + pub async fn peer_id(&self) -> u16 { + self.share.id + } + + pub async fn is_server(&self) -> bool { + self.share.id == PeerID::Srv as u16 + } } impl ops::Deref for RudpReceiver { @@ -136,7 +154,7 @@ pub fn new( let share = Arc::new(RudpShare { id, - remote_id, + remote_id: RwLock::new(remote_id), udp_tx, chans: (0..NUM_CHANS).map(|_| AckChan).collect(), }); @@ -159,7 +177,6 @@ pub fn new( #[tokio::main] async fn main() -> io::Result<()> { - //println!("{}", x.deep_size_of()); let (tx, mut rx) = connect("127.0.0.1:30000").await?; let mut mtpkt = vec![]; diff --git a/src/recv_worker.rs b/src/recv_worker.rs index 5a156eb..f83e8ef 100644 --- a/src/recv_worker.rs +++ b/src/recv_worker.rs @@ -1,7 +1,8 @@ use crate::{error::Error, *}; +use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt}; use std::{ - cell::Cell, + cell::{Cell, OnceCell}, collections::HashMap, io, sync::{Arc, Weak}, @@ -13,14 +14,16 @@ fn to_seqnum(seqnum: u16) -> usize { (seqnum as usize) & (REL_BUFFER - 1) } -type Result = std::result::Result<(), Error>; +type Result = std::result::Result; struct Split { - timestamp: time::Instant, + timestamp: Option, + chunks: Vec>>, + got: usize, } struct Chan { - packets: Vec>>>, // in the good old days this used to be called char ** + packets: Vec>>>, // char ** 😛 splits: HashMap, seqnum: u16, num: u8, @@ -65,7 +68,9 @@ impl RecvWorker { let mut ch = chan.lock().await; ch.splits = ch .splits - .drain_filter(|_k, v| v.timestamp.elapsed() < timeout) + .drain_filter( + |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout), + ) .collect(); } @@ -93,7 +98,7 @@ impl RecvWorker { } } - async fn recv_pkt(&self) -> Result { + async fn recv_pkt(&self) -> Result<()> { use Error::*; // todo: reset timeout @@ -101,10 +106,10 @@ impl RecvWorker { let proto_id = cursor.read_u32::()?; if proto_id != PROTO_ID { - do yeet InvalidProtoId(proto_id); + return Err(InvalidProtoId(proto_id)); } - let peer_id = cursor.read_u16::()?; + let _peer_id = cursor.read_u16::()?; let n_chan = cursor.read_u8()?; let mut chan = self @@ -114,40 +119,102 @@ impl RecvWorker { .lock() .await; - self.process_pkt(cursor, &mut chan) + self.process_pkt(cursor, true, &mut chan).await } - fn process_pkt(&self, mut cursor: io::Cursor>, chan: &mut Chan) -> Result { - use CtlType::*; + #[async_recursion] + async fn process_pkt( + &self, + mut cursor: io::Cursor>, + unrel: bool, + chan: &mut Chan, + ) -> Result<()> { use Error::*; - use PktType::*; match cursor.read_u8()?.try_into()? { - Ctl => match cursor.read_u8()?.try_into()? { - Disco => return Err(RemoteDisco), - _ => {} + PktType::Ctl => match cursor.read_u8()?.try_into()? { + CtlType::Ack => { /* TODO */ } + CtlType::SetPeerID => { + let mut id = self.share.remote_id.write().await; + + if *id != PeerID::Nil as u16 { + return Err(PeerIDAlreadySet); + } + + *id = cursor.read_u16::()?; + } + CtlType::Ping => {} + CtlType::Disco => return Err(RemoteDisco), }, - Orig => { + PktType::Orig => { println!("Orig"); self.pkt_tx.send(Ok(Pkt { chan: chan.num, - unrel: true, + unrel, data: cursor.remaining_slice().into(), }))?; } - Split => { + PktType::Split => { println!("Split"); - dbg!(cursor.remaining_slice()); + + let seqnum = cursor.read_u16::()?; + let chunk_index = cursor.read_u16::()? as usize; + let chunk_count = cursor.read_u16::()? as usize; + + let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(), + timestamp: None, + }); + + if split.chunks.len() != chunk_count { + return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); + } + + if split + .chunks + .get(chunk_index) + .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? + .set(cursor.remaining_slice().into()) + .is_ok() + { + split.got += 1; + } + + split.timestamp = if unrel { + Some(time::Instant::now()) + } else { + None + }; + + if split.got == chunk_count { + self.pkt_tx.send(Ok(Pkt { + chan: chan.num, + unrel, + data: split + .chunks + .iter() + .flat_map(|chunk| chunk.get().unwrap().iter()) + .copied() + .collect(), + }))?; + + chan.splits.remove(&seqnum); + } } - Rel => { + PktType::Rel => { println!("Rel"); let seqnum = cursor.read_u16::()?; chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into())); - while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() { - self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?; + fn next_pkt(chan: &mut Chan) -> Option> { + chan.packets[to_seqnum(chan.seqnum)].take() + } + + while let Some(pkt) = next_pkt(chan) { + self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?; chan.seqnum = chan.seqnum.overflowing_add(1).0; } } @@ -156,7 +223,7 @@ impl RecvWorker { Ok(()) } - fn handle(&self, res: Result) -> Result { + fn handle(&self, res: Result<()>) -> Result<()> { use Error::*; match res {