From fd23bb3a2b57d43c115005dcd70f1e18bb005032 Mon Sep 17 00:00:00 2001
From: Lizzy Fleckenstein <eliasfleckenstein@web.de>
Date: Fri, 6 Jan 2023 17:45:16 +0100
Subject: [PATCH] clean shutdown; send reliables

---
 Cargo.toml                      |   1 +
 src/client.rs                   |   7 +-
 src/main.rs                     | 169 +++++++++++++++++---------------
 src/new.rs                      |  63 ++++++++++++
 src/{recv_worker.rs => recv.rs} | 109 ++++++++++++--------
 src/send.rs                     |  55 +++++++++++
 6 files changed, 280 insertions(+), 124 deletions(-)
 create mode 100644 src/new.rs
 rename src/{recv_worker.rs => recv.rs} (72%)
 create mode 100644 src/send.rs

diff --git a/Cargo.toml b/Cargo.toml
index edcdf2e..7772e3b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,4 +8,5 @@ async-recursion = "1.0.0"
 async-trait = "0.1.60"
 byteorder = "1.4.3"
 num_enum = "0.5.7"
+pretty-hex = "0.3.0"
 tokio = { version = "1.23.0", features = ["full"] }
diff --git a/src/client.rs b/src/client.rs
index d416e53..172aa96 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -8,8 +8,8 @@ pub struct Sender {
 
 #[async_trait]
 impl UdpSender for Sender {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()> {
-        self.sock.send(&data).await?;
+    async fn send(&self, data: &[u8]) -> io::Result<()> {
+        self.sock.send(data).await?;
         Ok(())
     }
 }
@@ -42,5 +42,6 @@ pub async fn connect(addr: &str) -> io::Result<(RudpSender<Sender>, RudpReceiver<Sender>)> {
             sock: Arc::clone(&sock),
         },
         Receiver { sock },
-    ))
+    )
+    .await?)
 }
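Note on the client.rs change: UdpSender::send now borrows &[u8] instead of
taking Vec<u8>, because reliable packets are kept around after sending and
the same serialized buffer may go out more than once (see the resend task in
new.rs below). As a rough sketch of what another transport looks like under
the borrowed API (ServerSender and its fields are hypothetical, not part of
this patch):

    use async_trait::async_trait;
    use std::{io, net::SocketAddr, sync::Arc};
    use tokio::net::UdpSocket;

    // Hypothetical server-side sender: one shared, unconnected socket,
    // replying to a specific peer address.
    struct ServerSender {
        sock: Arc<UdpSocket>,
        addr: SocketAddr,
    }

    #[async_trait]
    impl UdpSender for ServerSender {
        async fn send(&self, data: &[u8]) -> io::Result<()> {
            self.sock.send_to(data, self.addr).await?;
            Ok(())
        }
    }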
diff --git a/src/main.rs b/src/main.rs
index 1f0aca0..0510db7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,19 +3,27 @@
 #![feature(once_cell)]
 mod client;
 pub mod error;
-mod recv_worker;
+mod new;
+mod recv;
+mod send;
 
 use async_trait::async_trait;
 use byteorder::{BigEndian, WriteBytesExt};
 pub use client::{connect, Sender as Client};
+pub use new::new;
 use num_enum::TryFromPrimitive;
+use pretty_hex::PrettyHex;
 use std::{
     collections::HashMap,
     io::{self, Write},
     ops,
     sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    sync::{mpsc, watch, Mutex, RwLock},
+    task::JoinSet,
 };
-use tokio::sync::{mpsc, watch, Mutex, RwLock};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -24,9 +32,25 @@ pub const REL_BUFFER: usize = 0x8000;
 pub const INIT_SEQNUM: u16 = 65500;
 pub const TIMEOUT: u64 = 30;
 
+mod ticker_mod {
+    #[macro_export]
+    macro_rules! ticker {
+        ($duration:expr, $close:expr, $body:block) => {
+            let mut interval = tokio::time::interval($duration);
+
+            while tokio::select! {
+                _ = interval.tick() => true,
+                _ = $close.changed() => false,
+            } $body
+        };
+    }
+
+    //pub(crate) use ticker;
+}
+
 #[async_trait]
 pub trait UdpSender: Send + Sync + 'static {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()>;
+    async fn send(&self, data: &[u8]) -> io::Result<()>;
 }
 
 #[async_trait]
@@ -69,14 +93,28 @@ pub struct Pkt<T> {
 
 pub type Error = error::Error;
 pub type InPkt = Result<Pkt<Vec<u8>>, Error>;
-type AckChan = (watch::Sender<bool>, watch::Receiver<bool>);
+
+#[derive(Debug)]
+struct Ack {
+    tx: watch::Sender<bool>,
+    rx: watch::Receiver<bool>,
+    data: Vec<u8>,
+}
+
+#[derive(Debug)]
+struct Chan {
+    acks: HashMap<u16, Ack>,
+    seqnum: u16,
+}
 
 #[derive(Debug)]
 pub struct RudpShare<S: UdpSender> {
-    pub id: u16,
-    pub remote_id: RwLock<u16>,
-    pub ack_chans: Mutex<HashMap<u16, AckChan>>,
+    id: u16,
+    remote_id: RwLock<u16>,
+    chans: Vec<Mutex<Chan>>,
     udp_tx: S,
+    close_tx: watch::Sender<bool>,
+    tasks: Mutex<JoinSet<()>>,
 }
 
 #[derive(Debug)]
@@ -90,44 +128,31 @@ pub struct RudpSender<S: UdpSender> {
     share: Arc<RudpShare<S>>,
 }
 
-impl<S: UdpSender> RudpShare<S> {
-    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
-        buf.write_u32::<BigEndian>(PROTO_ID)?;
-        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
-        buf.write_u8(pkt.chan as u8)?;
-        buf.write_u8(tp as u8)?;
-        buf.write(pkt.data)?;
-
-        self.udp_tx.send(buf).await?;
-
-        Ok(())
-    }
-}
+macro_rules! impl_share {
+    ($T:ident) => {
+        impl<S: UdpSender> $T<S> {
+            pub async fn peer_id(&self) -> u16 {
+                self.share.id
+            }
 
-impl<S: UdpSender> RudpSender<S> {
-    pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        self.share.send(PktType::Orig, pkt).await // TODO
-    }
+            pub async fn is_server(&self) -> bool {
+                self.share.id == PeerID::Srv as u16
+            }
 
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
+            pub async fn close(self) {
+                self.share.close_tx.send(true).ok();
 
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
+                let mut tasks = self.share.tasks.lock().await;
+                while let Some(res) = tasks.join_next().await {
+                    res.ok(); // TODO: handle error (?)
+                }
+            }
+        }
+    };
 }
 
-impl<S: UdpSender> RudpReceiver<S> {
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
-
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
-}
+impl_share!(RudpReceiver);
+impl_share!(RudpSender);
 
 impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
     type Target = mpsc::UnboundedReceiver<InPkt>;
@@ -143,49 +168,16 @@ impl<S: UdpSender> ops::DerefMut for RudpReceiver<S> {
     }
 }
 
-pub fn new<S: UdpSender, R: UdpReceiver>(
-    id: u16,
-    remote_id: u16,
-    udp_tx: S,
-    udp_rx: R,
-) -> (RudpSender<S>, RudpReceiver<S>) {
-    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
-
-    let share = Arc::new(RudpShare {
-        id,
-        remote_id: RwLock::new(remote_id),
-        udp_tx,
-        ack_chans: Mutex::new(HashMap::new()),
-    });
-    let recv_share = Arc::clone(&share);
-
-    tokio::spawn(async {
-        let worker = recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx);
-        worker.run().await;
-    });
-
-    (
-        RudpSender {
-            share: Arc::clone(&share),
-        },
-        RudpReceiver { share, pkt_rx },
-    )
-}
-
-// connect
-
-#[tokio::main]
-async fn main() -> io::Result<()> {
-    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
-
+async fn example(tx: &RudpSender<Client>, rx: &mut RudpReceiver<Client>) -> io::Result<()> {
+    // send hello packet
     let mut mtpkt = vec![];
     mtpkt.write_u16::<BigEndian>(2)?; // high level type
     mtpkt.write_u8(29)?; // serialize ver
     mtpkt.write_u16::<BigEndian>(0)?; // compression modes
     mtpkt.write_u16::<BigEndian>(40)?; // MinProtoVer
     mtpkt.write_u16::<BigEndian>(40)?; // MaxProtoVer
-    mtpkt.write_u16::<BigEndian>(3)?; // player name length
-    mtpkt.write(b"foo")?; // player name
+    mtpkt.write_u16::<BigEndian>(6)?; // player name length
+    mtpkt.write(b"foobar")?; // player name
 
     tx.send(Pkt {
         unrel: true,
@@ -194,17 +186,34 @@ async fn main() -> io::Result<()> {
     })
     .await?;
 
+    // handle incoming packets
     while let Some(result) = rx.recv().await {
         match result {
             Ok(pkt) => {
-                io::stdout().write(pkt.data.as_slice())?;
+                println!("{}", pkt.data.hex_dump());
            }
             Err(err) => eprintln!("Error: {}", err),
         }
     }
 
-    println!("disco");
-    // close()ing rx is not needed because it has been consumed to the end
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> io::Result<()> {
+    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
+
+    tokio::select! {
+        _ = tokio::signal::ctrl_c() => println!("canceled"),
+        res = example(&tx, &mut rx) => {
+            res?;
+            println!("disconnected");
+        }
+    }
+
+    // close either the receiver or the sender
+    // this shuts down associated tasks
+    rx.close().await;
 
     Ok(())
 }
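Note on the shutdown path added above: close() consumes either handle, sets
the close watch channel to true, and then drains the shared JoinSet, so every
task spawned for the connection (recv, resend, split cleanup) is joined
before close() returns. A minimal usage sketch:

    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
    // ... exchange packets ...
    rx.close().await; // or tx.close().await; either one stops the shared tasks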
diff --git a/src/new.rs b/src/new.rs
new file mode 100644
index 0000000..a70b117
--- /dev/null
+++ b/src/new.rs
@@ -0,0 +1,63 @@
+use crate::*;
+
+pub async fn new<S: UdpSender, R: UdpReceiver>(
+    id: u16,
+    remote_id: u16,
+    udp_tx: S,
+    udp_rx: R,
+) -> io::Result<(RudpSender<S>, RudpReceiver<S>)> {
+    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
+    let (close_tx, close_rx) = watch::channel(false);
+
+    let share = Arc::new(RudpShare {
+        id,
+        remote_id: RwLock::new(remote_id),
+        udp_tx,
+        close_tx,
+        chans: (0..NUM_CHANS)
+            .map(|_| {
+                Mutex::new(Chan {
+                    acks: HashMap::new(),
+                    seqnum: INIT_SEQNUM,
+                })
+            })
+            .collect(),
+        tasks: Mutex::new(JoinSet::new()),
+    });
+
+    let mut tasks = share.tasks.lock().await;
+
+    let recv_share = Arc::clone(&share);
+    let recv_close = close_rx.clone();
+    tasks
+        /*.build_task()
+        .name("recv")*/
+        .spawn(async move {
+            let worker = recv::RecvWorker::new(udp_rx, recv_share, recv_close, pkt_tx);
+            worker.run().await;
+        });
+
+    let resend_share = Arc::clone(&share);
+    let mut resend_close = close_rx.clone();
+    tasks
+        /*.build_task()
+        .name("resend")*/
+        .spawn(async move {
+            ticker!(Duration::from_millis(500), resend_close, {
+                for chan in resend_share.chans.iter() {
+                    for (_, ack) in chan.lock().await.acks.iter() {
+                        resend_share.send_raw(&ack.data).await.ok(); // TODO: handle error (?)
+                    }
+                }
+            });
+        });
+
+    drop(tasks);
+
+    Ok((
+        RudpSender {
+            share: Arc::clone(&share),
+        },
+        RudpReceiver { share, pkt_rx },
+    ))
+}
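Note on the resend task: each reliable packet is stored fully serialized in
its channel's acks map and re-sent every 500 ms until the matching Ack
control message removes it (see recv.rs below). For reference, the ticker!
invocation above expands to roughly:

    let mut interval = tokio::time::interval(Duration::from_millis(500));

    while tokio::select! {
        _ = interval.tick() => true,
        _ = resend_close.changed() => false,
    } {
        // re-send all still-unacked reliable packets
    }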
diff --git a/src/recv_worker.rs b/src/recv.rs
similarity index 72%
rename from src/recv_worker.rs
rename to src/recv.rs
index 83b3273..15811f2 100644
--- a/src/recv_worker.rs
+++ b/src/recv.rs
@@ -5,8 +5,8 @@ use std::{
     cell::{Cell, OnceCell},
     collections::HashMap,
     io,
-    sync::{Arc, Weak},
-    time,
+    sync::Arc,
+    time::{Duration, Instant},
 };
 use tokio::sync::{mpsc, Mutex};
 
@@ -17,12 +17,12 @@ fn to_seqnum(seqnum: u16) -> usize {
 
 type Result<T> = std::result::Result<T, Error>;
 
 struct Split {
-    timestamp: Option<time::Instant>,
+    timestamp: Option<Instant>,
     chunks: Vec<OnceCell<Vec<u8>>>,
     got: usize,
 }
 
-struct Chan {
+struct RecvChan {
     packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛
     splits: HashMap<u16, Split>,
     seqnum: u16,
@@ -31,21 +31,28 @@
 
 pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
     share: Arc<RudpShare<S>>,
-    chans: Arc<Vec<Mutex<Chan>>>,
+    close: watch::Receiver<bool>,
+    chans: Arc<Vec<Mutex<RecvChan>>>,
     pkt_tx: mpsc::UnboundedSender<InPkt>,
     udp_rx: R,
 }
 
 impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
-    pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
+    pub fn new(
+        udp_rx: R,
+        share: Arc<RudpShare<S>>,
+        close: watch::Receiver<bool>,
+        pkt_tx: mpsc::UnboundedSender<InPkt>,
+    ) -> Self {
         Self {
             udp_rx,
             share,
+            close,
             pkt_tx,
             chans: Arc::new(
                 (0..NUM_CHANS as u8)
                     .map(|num| {
-                        Mutex::new(Chan {
+                        Mutex::new(RecvChan {
                             num,
                             packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
                             seqnum: INIT_SEQNUM,
@@ -58,28 +65,33 @@
     }
 
     pub async fn run(&self) {
-        let cleanup_chans = Arc::downgrade(&self.chans);
-        tokio::spawn(async move {
-            let timeout = time::Duration::from_secs(TIMEOUT);
-            let mut interval = tokio::time::interval(timeout);
-
-            while let Some(chans) = Weak::upgrade(&cleanup_chans) {
-                for chan in chans.iter() {
-                    let mut ch = chan.lock().await;
-                    ch.splits = ch
-                        .splits
-                        .drain_filter(
-                            |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
-                        )
-                        .collect();
-                }
-
-                interval.tick().await;
-            }
-        });
+        let cleanup_chans = Arc::clone(&self.chans);
+        let mut cleanup_close = self.close.clone();
+        self.share
+            .tasks
+            .lock()
+            .await
+            /*.build_task()
+            .name("cleanup_splits")*/
+            .spawn(async move {
+                let timeout = Duration::from_secs(TIMEOUT);
+
+                ticker!(timeout, cleanup_close, {
+                    for chan_mtx in cleanup_chans.iter() {
+                        let mut chan = chan_mtx.lock().await;
+                        chan.splits = chan
+                            .splits
+                            .drain_filter(
+                                |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
+                            )
+                            .collect();
+                    }
+                });
+            });
 
+        let mut close = self.close.clone();
         loop {
-            if let Err(e) = self.handle(self.recv_pkt().await) {
+            if let Err(e) = self.handle(self.recv_pkt(&mut close).await) {
                 if let Error::LocalDisco = e {
                     self.share
                         .send(
@@ -98,11 +110,16 @@
         }
     }
 
-    async fn recv_pkt(&self) -> Result<()> {
+    async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> {
         use Error::*;
 
-        // todo: reset timeout
-        let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
+        // TODO: reset timeout
+        let mut cursor = io::Cursor::new(tokio::select! {
+            pkt = self.udp_rx.recv() => pkt?,
+            _ = close.changed() => return Err(LocalDisco),
+        });
+
+        println!("recv");
 
         let proto_id = cursor.read_u32::<BigEndian>()?;
         if proto_id != PROTO_ID {
@@ -127,19 +144,28 @@
         &self,
         mut cursor: io::Cursor<Vec<u8>>,
         unrel: bool,
-        chan: &mut Chan,
+        chan: &mut RecvChan,
     ) -> Result<()> {
         use Error::*;
 
         match cursor.read_u8()?.try_into()? {
             PktType::Ctl => match cursor.read_u8()?.try_into()? {
                 CtlType::Ack => {
+                    println!("Ack");
+
                     let seqnum = cursor.read_u16::<BigEndian>()?;
-                    if let Some((tx, _)) = self.share.ack_chans.lock().await.remove(&seqnum) {
-                        tx.send(true).ok();
+                    if let Some(ack) = self.share.chans[chan.num as usize]
+                        .lock()
+                        .await
+                        .acks
+                        .remove(&seqnum)
+                    {
+                        ack.tx.send(true).ok();
                     }
                 }
                 CtlType::SetPeerID => {
+                    println!("SetPeerID");
+
                     let mut id = self.share.remote_id.write().await;
 
                     if *id != PeerID::Nil as u16 {
@@ -148,8 +174,13 @@
                     *id = cursor.read_u16::<BigEndian>()?;
                 }
-                CtlType::Ping => {}
-                CtlType::Disco => return Err(RemoteDisco),
+                CtlType::Ping => {
+                    println!("Ping");
+                }
+                CtlType::Disco => {
+                    println!("Disco");
+                    return Err(RemoteDisco);
+                }
             },
             PktType::Orig => {
                 println!("Orig");
@@ -187,11 +218,7 @@
                 split.got += 1;
             }
 
-            split.timestamp = if unrel {
-                Some(time::Instant::now())
-            } else {
-                None
-            };
+            split.timestamp = if unrel { Some(Instant::now()) } else { None };
 
             if split.got == chunk_count {
                 self.pkt_tx.send(Ok(Pkt {
@@ -229,7 +256,7 @@
             )
             .await?;
 
-        fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
+        fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
             chan.packets[to_seqnum(chan.seqnum)].take()
         }
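Note on the Ack handling change: acks are now looked up in the per-channel
maps on RudpShare, since each channel has an independent seqnum space,
instead of the old connection-wide ack_chans map. For reference, the header
written by send.rs and parsed above is, in order (big-endian):

    u32 PROTO_ID (0x4f457403)
    u16 peer id
    u8  channel
    u8  packet type; reliable packets use PktType::Rel followed by a
        u16 seqnum and then the wrapped packet type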
diff --git a/src/send.rs b/src/send.rs
new file mode 100644
index 0000000..89c15c7
--- /dev/null
+++ b/src/send.rs
@@ -0,0 +1,55 @@
+use crate::*;
+use tokio::sync::watch;
+
+type AckResult = io::Result<Option<watch::Receiver<bool>>>;
+
+impl<S: UdpSender> RudpSender<S> {
+    pub async fn send(&self, pkt: Pkt<&[u8]>) -> AckResult {
+        self.share.send(PktType::Orig, pkt).await // TODO: splits
+    }
+}
+
+impl<S: UdpSender> RudpShare<S> {
+    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> AckResult {
+        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len());
+        buf.write_u32::<BigEndian>(PROTO_ID)?;
+        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
+        buf.write_u8(pkt.chan as u8)?;
+
+        let mut chan = self.chans[pkt.chan as usize].lock().await;
+        let seqnum = chan.seqnum;
+
+        if !pkt.unrel {
+            buf.write_u8(PktType::Rel as u8)?;
+            buf.write_u16::<BigEndian>(seqnum)?;
+        }
+
+        buf.write_u8(tp as u8)?;
+        buf.write(pkt.data)?;
+
+        self.send_raw(&buf).await?;
+
+        if pkt.unrel {
+            Ok(None)
+        } else {
+            // TODO: reliable window
+            let (tx, rx) = watch::channel(false);
+            chan.acks.insert(
+                seqnum,
+                Ack {
+                    tx,
+                    rx: rx.clone(),
+                    data: buf,
+                },
+            );
+            chan.seqnum += 1;
+
+            Ok(Some(rx))
+        }
+    }
+
+    pub async fn send_raw(&self, data: &[u8]) -> io::Result<()> {
+        self.udp_tx.send(data).await
+        // TODO: reset ping timeout
+    }
+}
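Usage sketch for the new return type: send() yields
io::Result<Option<watch::Receiver<bool>>>, with None for unreliable packets
and Some(rx) for reliable ones. A caller that wants to block until delivery
is confirmed (names here are illustrative) can do:

    if let Some(mut ack_rx) = tx
        .send(Pkt {
            unrel: false,
            chan: 0,
            data: &payload,
        })
        .await?
    {
        // resolves once the peer's CtlType::Ack arrives
        ack_rx.changed().await.ok();
    }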
-- 
2.44.0