From: Lizzy Fleckenstein
Date: Fri, 6 Jan 2023 16:45:16 +0000 (+0100)
Subject: clean shutdown; send reliables
X-Git-Url: https://git.lizzy.rs/?a=commitdiff_plain;h=fd23bb3a2b57d43c115005dcd70f1e18bb005032;p=mt_rudp.git

clean shutdown; send reliables
---

diff --git a/Cargo.toml b/Cargo.toml
index edcdf2e..7772e3b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,4 +8,5 @@ async-recursion = "1.0.0"
 async-trait = "0.1.60"
 byteorder = "1.4.3"
 num_enum = "0.5.7"
+pretty-hex = "0.3.0"
 tokio = { version = "1.23.0", features = ["full"] }
diff --git a/src/client.rs b/src/client.rs
index d416e53..172aa96 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -8,8 +8,8 @@ pub struct Sender {
 
 #[async_trait]
 impl UdpSender for Sender {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()> {
-        self.sock.send(&data).await?;
+    async fn send(&self, data: &[u8]) -> io::Result<()> {
+        self.sock.send(data).await?;
         Ok(())
     }
 }
@@ -42,5 +42,6 @@ pub async fn connect(addr: &str) -> io::Result<(RudpSender<Sender>, RudpReceiver<Receiver>)> {
             sock: Arc::clone(&sock),
         },
         Receiver { sock },
-    ))
+    )
+    .await?)
 }
diff --git a/src/main.rs b/src/main.rs
index 1f0aca0..0510db7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,19 +3,27 @@
 #![feature(once_cell)]
 
 mod client;
 pub mod error;
-mod recv_worker;
+mod new;
+mod recv;
+mod send;
 
 use async_trait::async_trait;
 use byteorder::{BigEndian, WriteBytesExt};
 pub use client::{connect, Sender as Client};
+pub use new::new;
 use num_enum::TryFromPrimitive;
+use pretty_hex::PrettyHex;
 use std::{
     collections::HashMap,
     io::{self, Write},
     ops,
     sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    sync::{mpsc, watch, Mutex, RwLock},
+    task::JoinSet,
 };
-use tokio::sync::{mpsc, watch, Mutex, RwLock};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -24,9 +32,25 @@
 pub const REL_BUFFER: usize = 0x8000;
 pub const INIT_SEQNUM: u16 = 65500;
 pub const TIMEOUT: u64 = 30;
 
+mod ticker_mod {
+    #[macro_export]
+    macro_rules! ticker {
+        ($duration:expr, $close:expr, $body:block) => {
+            let mut interval = tokio::time::interval($duration);
+
+            while tokio::select! {
+                _ = interval.tick() => true,
+                _ = $close.changed() => false,
+            } $body
+        };
+    }
+
+    //pub(crate) use ticker;
+}
+
 #[async_trait]
 pub trait UdpSender: Send + Sync + 'static {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()>;
+    async fn send(&self, data: &[u8]) -> io::Result<()>;
 }
 
 #[async_trait]
@@ -69,14 +93,28 @@ pub struct Pkt<T> {
 
 pub type Error = error::Error;
 pub type InPkt = Result<Pkt<Vec<u8>>, Error>;
-type AckChan = (watch::Sender<bool>, watch::Receiver<bool>);
+
+#[derive(Debug)]
+struct Ack {
+    tx: watch::Sender<bool>,
+    rx: watch::Receiver<bool>,
+    data: Vec<u8>,
+}
+
+#[derive(Debug)]
+struct Chan {
+    acks: HashMap<u16, Ack>,
+    seqnum: u16,
+}
 
 #[derive(Debug)]
 pub struct RudpShare<S: UdpSender> {
-    pub id: u16,
-    pub remote_id: RwLock<u16>,
-    pub ack_chans: Mutex<HashMap<u16, AckChan>>,
+    id: u16,
+    remote_id: RwLock<u16>,
+    chans: Vec<Mutex<Chan>>,
     udp_tx: S,
+    close_tx: watch::Sender<bool>,
+    tasks: Mutex<JoinSet<()>>,
 }
 
 #[derive(Debug)]
@@ -90,44 +128,31 @@ pub struct RudpSender<S: UdpSender> {
     share: Arc<RudpShare<S>>,
 }
 
-impl<S: UdpSender> RudpShare<S> {
-    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
-        buf.write_u32::<BigEndian>(PROTO_ID)?;
-        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
-        buf.write_u8(pkt.chan as u8)?;
-        buf.write_u8(tp as u8)?;
-        buf.write(pkt.data)?;
-
-        self.udp_tx.send(buf).await?;
-
-        Ok(())
-    }
-}
+macro_rules! impl_share {
+    ($T:ident) => {
+        impl<S: UdpSender> $T<S> {
+            pub async fn peer_id(&self) -> u16 {
+                self.share.id
+            }
 
-impl<S: UdpSender> RudpSender<S> {
-    pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        self.share.send(PktType::Orig, pkt).await // TODO
-    }
+            pub async fn is_server(&self) -> bool {
+                self.share.id == PeerID::Srv as u16
+            }
 
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
+            pub async fn close(self) {
+                self.share.close_tx.send(true).ok();
 
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
+                let mut tasks = self.share.tasks.lock().await;
+                while let Some(res) = tasks.join_next().await {
+                    res.ok(); // TODO: handle error (?)
+                }
+            }
+        }
+    };
 }
 
-impl<S: UdpSender> RudpReceiver<S> {
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
-
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
-}
+impl_share!(RudpReceiver);
+impl_share!(RudpSender);
 
 impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
     type Target = mpsc::UnboundedReceiver<InPkt>;
@@ -143,49 +168,16 @@ impl<S: UdpSender> ops::DerefMut for RudpReceiver<S> {
     }
 }
 
-pub fn new<S: UdpSender, R: UdpReceiver>(
-    id: u16,
-    remote_id: u16,
-    udp_tx: S,
-    udp_rx: R,
-) -> (RudpSender<S>, RudpReceiver<S>) {
-    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
-
-    let share = Arc::new(RudpShare {
-        id,
-        remote_id: RwLock::new(remote_id),
-        udp_tx,
-        ack_chans: Mutex::new(HashMap::new()),
-    });
-    let recv_share = Arc::clone(&share);
-
-    tokio::spawn(async {
-        let worker = recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx);
-        worker.run().await;
-    });
-
-    (
-        RudpSender {
-            share: Arc::clone(&share),
-        },
-        RudpReceiver { share, pkt_rx },
-    )
-}
-
-// connect
-
-#[tokio::main]
-async fn main() -> io::Result<()> {
-    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
-
+async fn example(tx: &RudpSender<Client>, rx: &mut RudpReceiver<Client>) -> io::Result<()> {
+    // send hello packet
     let mut mtpkt = vec![];
     mtpkt.write_u16::<BigEndian>(2)?; // high level type
     mtpkt.write_u8(29)?; // serialize ver
     mtpkt.write_u16::<BigEndian>(0)?; // compression modes
     mtpkt.write_u16::<BigEndian>(40)?; // MinProtoVer
     mtpkt.write_u16::<BigEndian>(40)?; // MaxProtoVer
-    mtpkt.write_u16::<BigEndian>(3)?; // player name length
-    mtpkt.write(b"foo")?; // player name
+    mtpkt.write_u16::<BigEndian>(6)?; // player name length
+    mtpkt.write(b"foobar")?; // player name
 
     tx.send(Pkt {
         unrel: true,
@@ -194,17 +186,34 @@ async fn main() -> io::Result<()> {
     })
     .await?;
 
+    // handle incoming packets
     while let Some(result) = rx.recv().await {
         match result {
             Ok(pkt) => {
-                io::stdout().write(pkt.data.as_slice())?;
+                println!("{}", pkt.data.hex_dump());
             }
             Err(err) => eprintln!("Error: {}", err),
         }
     }
 
-    println!("disco");
-    // close()ing rx is not needed because it has been consumed to the end
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> io::Result<()> {
+    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
+
+    tokio::select! {
+        _ = tokio::signal::ctrl_c() => println!("canceled"),
+        res = example(&tx, &mut rx) => {
+            res?;
+            println!("disconnected");
+        }
+    }
+
+    // close either the receiver or the sender
+    // this shuts down associated tasks
+    rx.close().await;
 
     Ok(())
 }
diff --git a/src/new.rs b/src/new.rs
new file mode 100644
index 0000000..a70b117
--- /dev/null
+++ b/src/new.rs
@@ -0,0 +1,63 @@
+use crate::*;
+
+pub async fn new<S: UdpSender, R: UdpReceiver>(
+    id: u16,
+    remote_id: u16,
+    udp_tx: S,
+    udp_rx: R,
+) -> io::Result<(RudpSender<S>, RudpReceiver<S>)> {
+    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
+    let (close_tx, close_rx) = watch::channel(false);
+
+    let share = Arc::new(RudpShare {
+        id,
+        remote_id: RwLock::new(remote_id),
+        udp_tx,
+        close_tx,
+        chans: (0..NUM_CHANS)
+            .map(|_| {
+                Mutex::new(Chan {
+                    acks: HashMap::new(),
+                    seqnum: INIT_SEQNUM,
+                })
+            })
+            .collect(),
+        tasks: Mutex::new(JoinSet::new()),
+    });
+
+    let mut tasks = share.tasks.lock().await;
+
+    let recv_share = Arc::clone(&share);
+    let recv_close = close_rx.clone();
+    tasks
+        /*.build_task()
+        .name("recv")*/
+        .spawn(async move {
+            let worker = recv::RecvWorker::new(udp_rx, recv_share, recv_close, pkt_tx);
+            worker.run().await;
+        });
+
+    let resend_share = Arc::clone(&share);
+    let mut resend_close = close_rx.clone();
+    tasks
+        /*.build_task()
+        .name("resend")*/
+        .spawn(async move {
+            ticker!(Duration::from_millis(500), resend_close, {
+                for chan in resend_share.chans.iter() {
+                    for (_, ack) in chan.lock().await.acks.iter() {
+                        resend_share.send_raw(&ack.data).await.ok(); // TODO: handle error (?)
+                    }
+                }
+            });
+        });
+
+    drop(tasks);
+
+    Ok((
+        RudpSender {
+            share: Arc::clone(&share),
+        },
+        RudpReceiver { share, pkt_rx },
+    ))
+}
diff --git a/src/recv.rs b/src/recv.rs
new file mode 100644
index 0000000..15811f2
--- /dev/null
+++ b/src/recv.rs
@@ -0,0 +1,283 @@
+use crate::{error::Error, *};
+use async_recursion::async_recursion;
+use byteorder::{BigEndian, ReadBytesExt};
+use std::{
+    cell::{Cell, OnceCell},
+    collections::HashMap,
+    io,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use tokio::sync::{mpsc, Mutex};
+
+fn to_seqnum(seqnum: u16) -> usize {
+    (seqnum as usize) & (REL_BUFFER - 1)
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+struct Split {
+    timestamp: Option<Instant>,
+    chunks: Vec<OnceCell<Vec<u8>>>,
+    got: usize,
+}
+
+struct RecvChan {
+    packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛
+    splits: HashMap<u16, Split>,
+    seqnum: u16,
+    num: u8,
+}
+
+pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
+    share: Arc<RudpShare<S>>,
+    close: watch::Receiver<bool>,
+    chans: Arc<Vec<Mutex<RecvChan>>>,
+    pkt_tx: mpsc::UnboundedSender<InPkt>,
+    udp_rx: R,
+}
+
+impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
+    pub fn new(
+        udp_rx: R,
+        share: Arc<RudpShare<S>>,
+        close: watch::Receiver<bool>,
+        pkt_tx: mpsc::UnboundedSender<InPkt>,
+    ) -> Self {
+        Self {
+            udp_rx,
+            share,
+            close,
+            pkt_tx,
+            chans: Arc::new(
+                (0..NUM_CHANS as u8)
+                    .map(|num| {
+                        Mutex::new(RecvChan {
+                            num,
+                            packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
+                            seqnum: INIT_SEQNUM,
+                            splits: HashMap::new(),
+                        })
+                    })
+                    .collect(),
+            ),
+        }
+    }
+
+    pub async fn run(&self) {
+        let cleanup_chans = Arc::clone(&self.chans);
+        let mut cleanup_close = self.close.clone();
+        self.share
+            .tasks
+            .lock()
+            .await
+            /*.build_task()
+            .name("cleanup_splits")*/
+            .spawn(async move {
+                let timeout = Duration::from_secs(TIMEOUT);
+
+                ticker!(timeout, cleanup_close, {
+                    for chan_mtx in cleanup_chans.iter() {
+                        let mut chan = chan_mtx.lock().await;
+                        chan.splits = chan
+                            .splits
+                            .drain_filter(
+                                |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
+                            )
+                            .collect();
+                    }
+                });
+            });
+
+        let mut close = self.close.clone();
+        loop {
+            if let Err(e) = self.handle(self.recv_pkt(&mut close).await) {
+                if let Error::LocalDisco = e {
+                    self.share
+                        .send(
+                            PktType::Ctl,
+                            Pkt {
+                                unrel: true,
+                                chan: 0,
+                                data: &[CtlType::Disco as u8],
+                            },
+                        )
+                        .await
+                        .ok();
+                }
+                break;
+            }
+        }
+    }
+
+    async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> {
+        use Error::*;
+
+        // TODO: reset timeout
+        let mut cursor = io::Cursor::new(tokio::select! {
+            pkt = self.udp_rx.recv() => pkt?,
+            _ = close.changed() => return Err(LocalDisco),
+        });
+
+        println!("recv");
+
+        let proto_id = cursor.read_u32::<BigEndian>()?;
+        if proto_id != PROTO_ID {
+            return Err(InvalidProtoId(proto_id));
+        }
+
+        let _peer_id = cursor.read_u16::<BigEndian>()?;
+
+        let n_chan = cursor.read_u8()?;
+        let mut chan = self
+            .chans
+            .get(n_chan as usize)
+            .ok_or(InvalidChannel(n_chan))?
+            .lock()
+            .await;
+
+        self.process_pkt(cursor, true, &mut chan).await
+    }
+
+    #[async_recursion]
+    async fn process_pkt(
+        &self,
+        mut cursor: io::Cursor<Vec<u8>>,
+        unrel: bool,
+        chan: &mut RecvChan,
+    ) -> Result<()> {
+        use Error::*;
+
+        match cursor.read_u8()?.try_into()? {
+            PktType::Ctl => match cursor.read_u8()?.try_into()? {
+                CtlType::Ack => {
+                    println!("Ack");
+
+                    let seqnum = cursor.read_u16::<BigEndian>()?;
+                    if let Some(ack) = self.share.chans[chan.num as usize]
+                        .lock()
+                        .await
+                        .acks
+                        .remove(&seqnum)
+                    {
+                        ack.tx.send(true).ok();
+                    }
+                }
+                CtlType::SetPeerID => {
+                    println!("SetPeerID");
+
+                    let mut id = self.share.remote_id.write().await;
+
+                    if *id != PeerID::Nil as u16 {
+                        return Err(PeerIDAlreadySet);
+                    }
+
+                    *id = cursor.read_u16::<BigEndian>()?;
+                }
+                CtlType::Ping => {
+                    println!("Ping");
+                }
+                CtlType::Disco => {
+                    println!("Disco");
+                    return Err(RemoteDisco);
+                }
+            },
+            PktType::Orig => {
+                println!("Orig");
+
+                self.pkt_tx.send(Ok(Pkt {
+                    chan: chan.num,
+                    unrel,
+                    data: cursor.remaining_slice().into(),
+                }))?;
+            }
+            PktType::Split => {
+                println!("Split");
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
+                let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
+
+                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
+                    got: 0,
+                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
+                    timestamp: None,
+                });
+
+                if split.chunks.len() != chunk_count {
+                    return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
+                }
+
+                if split
+                    .chunks
+                    .get(chunk_index)
+                    .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
+                    .set(cursor.remaining_slice().into())
+                    .is_ok()
+                {
+                    split.got += 1;
+                }
+
+                split.timestamp = if unrel { Some(Instant::now()) } else { None };
+
+                if split.got == chunk_count {
+                    self.pkt_tx.send(Ok(Pkt {
+                        chan: chan.num,
+                        unrel,
+                        data: split
+                            .chunks
+                            .iter()
+                            .flat_map(|chunk| chunk.get().unwrap().iter())
+                            .copied()
+                            .collect(),
+                    }))?;
+
+                    chan.splits.remove(&seqnum);
+                }
+            }
+            PktType::Rel => {
+                println!("Rel");
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
+
+                let mut ack_data = Vec::with_capacity(3);
+                ack_data.write_u8(CtlType::Ack as u8)?;
+                ack_data.write_u16::<BigEndian>(seqnum)?;
+
+                self.share
+                    .send(
+                        PktType::Ctl,
+                        Pkt {
+                            unrel: true,
+                            chan: chan.num,
+                            data: &ack_data,
+                        },
+                    )
+                    .await?;
+
+                fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
+                    chan.packets[to_seqnum(chan.seqnum)].take()
+                }
+
+                while let Some(pkt) = next_pkt(chan) {
+                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
+                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn handle(&self, res: Result<()>) -> Result<()> {
+        use Error::*;
+
+        match res {
+            Ok(v) => Ok(v),
+            Err(RemoteDisco) => Err(RemoteDisco),
+            Err(LocalDisco) => Err(LocalDisco),
+            Err(e) => Ok(self.pkt_tx.send(Err(e))?),
+        }
+    }
+}
diff --git a/src/recv_worker.rs b/src/recv_worker.rs
deleted file mode 100644
index 83b3273..0000000
--- a/src/recv_worker.rs
+++ /dev/null
@@ -1,256 +0,0 @@
-use crate::{error::Error, *};
-use async_recursion::async_recursion;
-use byteorder::{BigEndian, ReadBytesExt};
-use std::{
-    cell::{Cell, OnceCell},
-    collections::HashMap,
-    io,
-    sync::{Arc, Weak},
-    time,
-};
-use tokio::sync::{mpsc, Mutex};
-
-fn to_seqnum(seqnum: u16) -> usize {
-    (seqnum as usize) & (REL_BUFFER - 1)
-}
-
-type Result<T> = std::result::Result<T, Error>;
-
-struct Split {
-    timestamp: Option<time::Instant>,
-    chunks: Vec<OnceCell<Vec<u8>>>,
-    got: usize,
-}
-
-struct Chan {
-    packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛
-    splits: HashMap<u16, Split>,
-    seqnum: u16,
-    num: u8,
-}
-
-pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
-    share: Arc<RudpShare<S>>,
-    chans: Arc<Vec<Mutex<Chan>>>,
-    pkt_tx: mpsc::UnboundedSender<InPkt>,
-    udp_rx: R,
-}
-
-impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
-    pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
-        Self {
-            udp_rx,
-            share,
-            pkt_tx,
-            chans: Arc::new(
-                (0..NUM_CHANS as u8)
-                    .map(|num| {
-                        Mutex::new(Chan {
-                            num,
-                            packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
-                            seqnum: INIT_SEQNUM,
-                            splits: HashMap::new(),
-                        })
-                    })
-                    .collect(),
-            ),
-        }
-    }
-
-    pub async fn run(&self) {
-        let cleanup_chans = Arc::downgrade(&self.chans);
-        tokio::spawn(async move {
-            let timeout = time::Duration::from_secs(TIMEOUT);
-            let mut interval = tokio::time::interval(timeout);
-
-            while let Some(chans) = Weak::upgrade(&cleanup_chans) {
-                for chan in chans.iter() {
-                    let mut ch = chan.lock().await;
-                    ch.splits = ch
-                        .splits
-                        .drain_filter(
-                            |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
-                        )
-                        .collect();
-                }
-
-                interval.tick().await;
-            }
-        });
-
-        loop {
-            if let Err(e) = self.handle(self.recv_pkt().await) {
-                if let Error::LocalDisco = e {
-                    self.share
-                        .send(
-                            PktType::Ctl,
-                            Pkt {
-                                unrel: true,
-                                chan: 0,
-                                data: &[CtlType::Disco as u8],
-                            },
-                        )
-                        .await
-                        .ok();
-                }
-                break;
-            }
-        }
-    }
-
-    async fn recv_pkt(&self) -> Result<()> {
-        use Error::*;
-
-        // todo: reset timeout
-        let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
-
-        let proto_id = cursor.read_u32::<BigEndian>()?;
-        if proto_id != PROTO_ID {
-            return Err(InvalidProtoId(proto_id));
-        }
-
-        let _peer_id = cursor.read_u16::<BigEndian>()?;
-
-        let n_chan = cursor.read_u8()?;
-        let mut chan = self
-            .chans
-            .get(n_chan as usize)
-            .ok_or(InvalidChannel(n_chan))?
-            .lock()
-            .await;
-
-        self.process_pkt(cursor, true, &mut chan).await
-    }
-
-    #[async_recursion]
-    async fn process_pkt(
-        &self,
-        mut cursor: io::Cursor<Vec<u8>>,
-        unrel: bool,
-        chan: &mut Chan,
-    ) -> Result<()> {
-        use Error::*;
-
-        match cursor.read_u8()?.try_into()? {
-            PktType::Ctl => match cursor.read_u8()?.try_into()? {
-                CtlType::Ack => {
-                    let seqnum = cursor.read_u16::<BigEndian>()?;
-                    if let Some((tx, _)) = self.share.ack_chans.lock().await.remove(&seqnum) {
-                        tx.send(true).ok();
-                    }
-                }
-                CtlType::SetPeerID => {
-                    let mut id = self.share.remote_id.write().await;
-
-                    if *id != PeerID::Nil as u16 {
-                        return Err(PeerIDAlreadySet);
-                    }
-
-                    *id = cursor.read_u16::<BigEndian>()?;
-                }
-                CtlType::Ping => {}
-                CtlType::Disco => return Err(RemoteDisco),
-            },
-            PktType::Orig => {
-                println!("Orig");
-
-                self.pkt_tx.send(Ok(Pkt {
-                    chan: chan.num,
-                    unrel,
-                    data: cursor.remaining_slice().into(),
-                }))?;
-            }
-            PktType::Split => {
-                println!("Split");
-
-                let seqnum = cursor.read_u16::<BigEndian>()?;
-                let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
-                let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
-
-                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
-                    got: 0,
-                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
-                    timestamp: None,
-                });
-
-                if split.chunks.len() != chunk_count {
-                    return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
-                }
-
-                if split
-                    .chunks
-                    .get(chunk_index)
-                    .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
-                    .set(cursor.remaining_slice().into())
-                    .is_ok()
-                {
-                    split.got += 1;
-                }
-
-                split.timestamp = if unrel {
-                    Some(time::Instant::now())
-                } else {
-                    None
-                };
-
-                if split.got == chunk_count {
-                    self.pkt_tx.send(Ok(Pkt {
-                        chan: chan.num,
-                        unrel,
-                        data: split
-                            .chunks
-                            .iter()
-                            .flat_map(|chunk| chunk.get().unwrap().iter())
-                            .copied()
-                            .collect(),
-                    }))?;
-
-                    chan.splits.remove(&seqnum);
-                }
-            }
-            PktType::Rel => {
-                println!("Rel");
-
-                let seqnum = cursor.read_u16::<BigEndian>()?;
-                chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
-
-                let mut ack_data = Vec::with_capacity(3);
-                ack_data.write_u8(CtlType::Ack as u8)?;
-                ack_data.write_u16::<BigEndian>(seqnum)?;
-
-                self.share
-                    .send(
-                        PktType::Ctl,
-                        Pkt {
-                            unrel: true,
-                            chan: chan.num,
-                            data: &ack_data,
-                        },
-                    )
-                    .await?;
-
-                fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
-                    chan.packets[to_seqnum(chan.seqnum)].take()
-                }
-
-                while let Some(pkt) = next_pkt(chan) {
-                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
-                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    fn handle(&self, res: Result<()>) -> Result<()> {
-        use Error::*;
-
-        match res {
-            Ok(v) => Ok(v),
-            Err(RemoteDisco) => Err(RemoteDisco),
-            Err(LocalDisco) => Err(LocalDisco),
-            Err(e) => Ok(self.pkt_tx.send(Err(e))?),
-        }
-    }
-}
diff --git a/src/send.rs b/src/send.rs
new file mode 100644
index 0000000..89c15c7
--- /dev/null
+++ b/src/send.rs
@@ -0,0 +1,55 @@
+use crate::*;
+use tokio::sync::watch;
+
+type AckResult = io::Result<Option<watch::Receiver<bool>>>;
+
+impl<S: UdpSender> RudpSender<S> {
+    pub async fn send(&self, pkt: Pkt<&[u8]>) -> AckResult {
+        self.share.send(PktType::Orig, pkt).await // TODO: splits
+    }
+}
+
+impl<S: UdpSender> RudpShare<S> {
+    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> AckResult {
+        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len());
+        buf.write_u32::<BigEndian>(PROTO_ID)?;
+        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
+        buf.write_u8(pkt.chan as u8)?;
+
+        let mut chan = self.chans[pkt.chan as usize].lock().await;
+        let seqnum = chan.seqnum;
+
+        if !pkt.unrel {
+            buf.write_u8(PktType::Rel as u8)?;
+            buf.write_u16::<BigEndian>(seqnum)?;
+        }
+
+        buf.write_u8(tp as u8)?;
+        buf.write(pkt.data)?;
+
+        self.send_raw(&buf).await?;
+
+        if pkt.unrel {
+            Ok(None)
+        } else {
+            // TODO: reliable window
+            let (tx, rx) = watch::channel(false);
+            chan.acks.insert(
+                seqnum,
+                Ack {
+                    tx,
+                    rx: rx.clone(),
+                    data: buf,
+                },
+            );
+            chan.seqnum += 1;
+
+            Ok(Some(rx))
+        }
+    }
+
+    pub async fn send_raw(&self, data: &[u8]) -> io::Result<()> {
+        self.udp_tx.send(data).await
+        // TODO: reset ping timeout
+    }
+}
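
Note on "send reliables": `RudpSender::send` now returns `io::Result<Option<watch::Receiver<bool>>>`. For a reliable packet (`unrel: false`), the buffered datagram is retransmitted by the 500 ms resend task until the recv worker sees the matching CtlType::Ack, removes the Ack entry, and sends `true` on the watch channel. A minimal sketch of how a caller might wait for delivery (hypothetical usage, not part of this commit; error handling elided):

    // `tx` is the RudpSender<Client> returned by connect().
    let ack = tx
        .send(Pkt {
            unrel: false, // reliable => send() yields Some(receiver)
            chan: 0,
            data: &mtpkt,
        })
        .await?;

    if let Some(mut ack) = ack {
        // Resolves once recv.rs receives the ack for our seqnum
        // and fires `ack.tx.send(true)`.
        ack.changed().await.ok();
    }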
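Note on "clean shutdown": `close()` (generated for both `RudpSender` and `RudpReceiver` by `impl_share!`) sends `true` on `close_tx` and then drains the shared `JoinSet`, so the recv, resend, and split-cleanup tasks all exit before `close()` returns. Each periodic task is driven by the new `ticker!` macro, which expands to roughly the following (a sketch with the resend task's concrete 500 ms period; `close` stands for a `watch::Receiver<bool>` clone):

    let mut interval = tokio::time::interval(Duration::from_millis(500));

    // The `while` condition is the select! expression: it evaluates to
    // true on each tick and to false once close_tx sends a value,
    // which ends the loop and lets the task finish.
    while tokio::select! {
        _ = interval.tick() => true,
        _ = close.changed() => false,
    } {
        // body, e.g. retransmit all unacked reliable packets
    }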