From 718e8618544c4cdde78138655305eee7c08058ee Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Sat, 18 Feb 2023 03:03:40 +0100 Subject: [PATCH] Don't spawn tasks --- Cargo.toml | 6 +- examples/example.rs | 6 +- src/client.rs | 10 +- src/common.rs | 73 ++--------- src/error.rs | 7 -- src/lib.rs | 19 +-- src/recv.rs | 293 ++++++++++++++++++++++---------------------- src/send.rs | 22 +++- src/share.rs | 116 +++++++----------- 9 files changed, 235 insertions(+), 317 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8e5f7ef..254a0db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,12 +7,10 @@ edition = "2021" async-recursion = "1.0.0" async-trait = "0.1.60" byteorder = "1.4.3" -delegate = "0.9.0" -drop_bomb = "0.1.5" num_enum = "0.5.7" thiserror = "1.0.38" -tokio = { version = "1.23.0", features = ["sync", "time", "net", "signal", "macros", "rt"] } +tokio = { version = "1.23.0", features = ["sync", "time", "net", "macros"] } [dev-dependencies] pretty-hex = "0.3.0" -tokio = { version = "1.23.0", features = ["rt-multi-thread"] } +tokio = { version = "1.23.0", features = ["rt-multi-thread", "rt", "signal", "tracing"] } diff --git a/examples/example.rs b/examples/example.rs index 8625243..afb9d5b 100644 --- a/examples/example.rs +++ b/examples/example.rs @@ -1,9 +1,9 @@ use byteorder::{BigEndian, WriteBytesExt}; -use mt_rudp::{RudpReceiver, RudpSender, ToSrv}; +use mt_rudp::{RemoteSrv, RudpReceiver, RudpSender}; use pretty_hex::PrettyHex; use std::io::{self, Write}; -async fn example(tx: &RudpSender, rx: &mut RudpReceiver) -> io::Result<()> { +async fn example(tx: &RudpSender, rx: &mut RudpReceiver) -> io::Result<()> { // send hello packet let mut pkt = vec![]; pkt.write_u16::(2)?; // high level type @@ -17,7 +17,7 @@ async fn example(tx: &RudpSender, rx: &mut RudpReceiver) -> io::Re tx.send(mt_rudp::Pkt { unrel: true, chan: 1, - data: &pkt, + data: pkt.into(), }) .await?; diff --git a/src/client.rs b/src/client.rs index c4922ec..56db92a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,7 +19,7 @@ impl UdpSender for ToSrv { #[async_trait] impl UdpReceiver for FromSrv { - async fn recv(&self) -> io::Result> { + async fn recv(&mut self) -> io::Result> { let mut buffer = Vec::new(); buffer.resize(UDP_PKT_SIZE, 0); @@ -30,7 +30,13 @@ impl UdpReceiver for FromSrv { } } -pub async fn connect(addr: &str) -> io::Result<(RudpSender, RudpReceiver)> { +pub struct RemoteSrv; +impl UdpPeer for RemoteSrv { + type Sender = ToSrv; + type Receiver = FromSrv; +} + +pub async fn connect(addr: &str) -> io::Result<(RudpSender, RudpReceiver)> { let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0").await?); sock.connect(addr).await?; diff --git a/src/common.rs b/src/common.rs index 4c0bc08..0ed08f3 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,9 +1,6 @@ -use super::*; use async_trait::async_trait; -use delegate::delegate; use num_enum::TryFromPrimitive; -use std::{borrow::Cow, io, sync::Arc}; -use tokio::sync::mpsc; +use std::{borrow::Cow, fmt::Debug, io}; pub const PROTO_ID: u32 = 0x4f457403; pub const UDP_PKT_SIZE: usize = 512; @@ -14,13 +11,18 @@ pub const TIMEOUT: u64 = 30; pub const PING_TIMEOUT: u64 = 5; #[async_trait] -pub trait UdpSender: Send + Sync + 'static { +pub trait UdpSender: Send + Sync { async fn send(&self, data: &[u8]) -> io::Result<()>; } #[async_trait] -pub trait UdpReceiver: Send + Sync + 'static { - async fn recv(&self) -> io::Result>; +pub trait UdpReceiver: Send { + async fn recv(&mut self) -> io::Result>; +} + +pub trait UdpPeer { + type Sender: UdpSender; + type Receiver: UdpReceiver; } #[derive(Debug, Copy, Clone, PartialEq)] @@ -55,60 +57,3 @@ pub struct Pkt<'a> { pub chan: u8, pub data: Cow<'a, [u8]>, } - -pub type InPkt = Result, Error>; - -#[derive(Debug)] -pub struct RudpReceiver { - pub(crate) share: Arc>, - pub(crate) pkt_rx: mpsc::UnboundedReceiver, -} - -#[derive(Debug)] -pub struct RudpSender { - pub(crate) share: Arc>, -} - -// derive(Clone) adds unwanted Clone trait bound to S parameter -impl Clone for RudpSender { - fn clone(&self) -> Self { - Self { - share: Arc::clone(&self.share), - } - } -} - -macro_rules! impl_share { - ($T:ident) => { - impl $T { - pub async fn peer_id(&self) -> u16 { - self.share.id - } - - pub async fn is_server(&self) -> bool { - self.share.id == PeerID::Srv as u16 - } - - pub async fn close(self) { - self.share.bomb.lock().await.defuse(); - self.share.close_tx.send(true).ok(); - - let mut tasks = self.share.tasks.lock().await; - while let Some(res) = tasks.join_next().await { - res.ok(); // TODO: handle error (?) - } - } - } - }; -} - -impl_share!(RudpReceiver); -impl_share!(RudpSender); - -impl RudpReceiver { - delegate! { - to self.pkt_rx { - pub async fn recv(&mut self) -> Option; - } - } -} diff --git a/src/error.rs b/src/error.rs index 7cfc057..28233e8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,6 @@ use super::*; use num_enum::TryFromPrimitiveError; use thiserror::Error; -use tokio::sync::mpsc::error::SendError; #[derive(Error, Debug)] pub enum Error { @@ -38,9 +37,3 @@ impl From> for Error { Self::InvalidCtlType(err.number) } } - -impl From> for Error { - fn from(_err: SendError) -> Self { - Self::LocalDisco - } -} diff --git a/src/lib.rs b/src/lib.rs index a02eb20..e7a8ebe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![feature(cursor_remaining)] #![feature(hash_drain_filter)] -#![feature(once_cell)] + mod client; mod common; mod error; @@ -11,21 +11,6 @@ mod share; pub use client::*; pub use common::*; pub use error::*; -use recv::*; +pub use recv::*; pub use send::*; pub use share::*; -pub use ticker_mod::*; - -mod ticker_mod { - #[macro_export] - macro_rules! ticker { - ($duration:expr, $close:expr, $body:block) => { - let mut interval = tokio::time::interval($duration); - - while tokio::select!{ - _ = interval.tick() => true, - _ = $close.changed() => false, - } $body - }; - } -} diff --git a/src/recv.rs b/src/recv.rs index fd6f299..34e273c 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -3,14 +3,16 @@ use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use std::{ borrow::Cow, - cell::OnceCell, - collections::HashMap, + collections::{HashMap, VecDeque}, io, pin::Pin, sync::Arc, time::{Duration, Instant}, }; -use tokio::sync::{mpsc, watch, Mutex}; +use tokio::{ + sync::watch, + time::{interval, sleep, Interval, Sleep}, +}; fn to_seqnum(seqnum: u16) -> usize { (seqnum as usize) & (REL_BUFFER - 1) @@ -21,7 +23,7 @@ type Result = std::result::Result; #[derive(Debug)] struct Split { timestamp: Option, - chunks: Vec>>, + chunks: Vec>>, got: usize, } @@ -29,61 +31,82 @@ struct RecvChan { packets: Vec>>, // char ** 😛 splits: HashMap, seqnum: u16, - num: u8, } -pub(crate) struct RecvWorker { - share: Arc>, +pub struct RudpReceiver { + pub(crate) share: Arc>, + chans: [RecvChan; NUM_CHANS], + udp: P::Receiver, close: watch::Receiver, - chans: Arc>>, - pkt_tx: mpsc::UnboundedSender, - udp_rx: R, + closed: bool, + resend: Interval, + ping: Interval, + cleanup: Interval, + timeout: Pin>, + queue: VecDeque>>, } -impl RecvWorker { - pub fn new( - udp_rx: R, - share: Arc>, +impl RudpReceiver

{ + pub(crate) fn new( + udp: P::Receiver, + share: Arc>, close: watch::Receiver, - pkt_tx: mpsc::UnboundedSender, ) -> Self { Self { - udp_rx, + udp, share, close, - pkt_tx, - chans: Arc::new( - (0..NUM_CHANS as u8) - .map(|num| { - Mutex::new(RecvChan { - num, - packets: (0..REL_BUFFER).map(|_| None).collect(), - seqnum: INIT_SEQNUM, - splits: HashMap::new(), - }) - }) - .collect(), - ), + closed: false, + resend: interval(Duration::from_millis(500)), + ping: interval(Duration::from_secs(PING_TIMEOUT)), + cleanup: interval(Duration::from_secs(TIMEOUT)), + timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))), + chans: std::array::from_fn(|_| RecvChan { + packets: (0..REL_BUFFER).map(|_| None).collect(), + seqnum: INIT_SEQNUM, + splits: HashMap::new(), + }), + queue: VecDeque::new(), + } + } + + fn handle_err(&mut self, res: Result<()>) -> Result<()> { + use Error::*; + + match res { + Err(RemoteDisco(_)) | Err(LocalDisco) => { + self.closed = true; + res + } + Ok(_) => res, + Err(e) => { + self.queue.push_back(Err(e)); + Ok(()) + } } } - pub async fn run(&self) { + pub async fn recv(&mut self) -> Option>> { use Error::*; - let cleanup_chans = Arc::clone(&self.chans); - let mut cleanup_close = self.close.clone(); - self.share - .tasks - .lock() - .await - /*.build_task() - .name("cleanup_splits")*/ - .spawn(async move { - let timeout = Duration::from_secs(TIMEOUT); - - ticker!(timeout, cleanup_close, { - for chan_mtx in cleanup_chans.iter() { - let mut chan = chan_mtx.lock().await; + if self.closed { + return None; + } + + loop { + if let Some(x) = self.queue.pop_front() { + return Some(x); + } + + tokio::select! { + _ = self.close.changed() => { + self.closed = true; + return Some(Err(LocalDisco)); + }, + _ = self.cleanup.tick() => { + let timeout = Duration::from_secs(TIMEOUT); + + for chan in self.chans.iter_mut() { chan.splits = chan .splits .drain_filter( @@ -91,58 +114,48 @@ impl RecvWorker { ) .collect(); } - }); - }); - - let mut close = self.close.clone(); - let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT)); - tokio::pin!(timeout); - - loop { - if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) { - // TODO: figure out whether this is a good idea - if let RemoteDisco(to) = e { - self.pkt_tx.send(Err(RemoteDisco(to))).ok(); + }, + _ = self.resend.tick() => { + for chan in self.share.chans.iter() { + for (_, ack) in chan.lock().await.acks.iter() { + self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) + } + } + }, + _ = self.ping.tick() => { + self.share + .send( + PktType::Ctl, + Pkt { + chan: 0, + unrel: false, + data: Cow::Borrowed(&[CtlType::Ping as u8]), + }, + ) + .await + .ok(); + } + _ = &mut self.timeout => { + self.closed = true; + return Some(Err(RemoteDisco(true))); + }, + pkt = self.udp.recv() => { + if let Err(e) = self.handle_pkt(pkt).await { + return Some(Err(e)); + } } - - #[allow(clippy::single_match)] - match e { - // anon5's mt notifies the peer on timeout, C++ MT does not - LocalDisco /*| RemoteDisco(true)*/ => drop( - self.share - .send( - PktType::Ctl, - Pkt { - unrel: true, - chan: 0, - data: Cow::Borrowed(&[CtlType::Disco as u8]), - }, - ) - .await - .ok(), - ), - _ => {} - } - - break; } } } - async fn recv_pkt( - &self, - close: &mut watch::Receiver, - timeout: Pin<&mut tokio::time::Sleep>, - ) -> Result<()> { + async fn handle_pkt(&mut self, pkt: io::Result>) -> Result<()> { use Error::*; - let mut cursor = io::Cursor::new(tokio::select! { - pkt = self.udp_rx.recv() => pkt?, - _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)), - _ = close.changed() => return Err(LocalDisco), - }); + let mut cursor = io::Cursor::new(pkt?); - timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT)); let proto_id = cursor.read_u32::()?; if proto_id != PROTO_ID { @@ -151,38 +164,34 @@ impl RecvWorker { let _peer_id = cursor.read_u16::()?; - let n_chan = cursor.read_u8()?; - let mut chan = self - .chans - .get(n_chan as usize) - .ok_or(InvalidChannel(n_chan))? - .lock() - .await; + let chan = cursor.read_u8()?; + if chan >= NUM_CHANS as u8 { + return Err(InvalidChannel(chan)); + } + + let res = self.process_pkt(cursor, true, chan).await; + self.handle_err(res)?; - self.process_pkt(cursor, true, &mut chan).await + Ok(()) } #[async_recursion] async fn process_pkt( - &self, + &mut self, mut cursor: io::Cursor>, unrel: bool, - chan: &mut RecvChan, + chan: u8, ) -> Result<()> { use Error::*; + let ch = chan as usize; match cursor.read_u8()?.try_into()? { PktType::Ctl => match cursor.read_u8()?.try_into()? { CtlType::Ack => { // println!("Ack"); let seqnum = cursor.read_u16::()?; - if let Some(ack) = self.share.chans[chan.num as usize] - .lock() - .await - .acks - .remove(&seqnum) - { + if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) { ack.tx.send(true).ok(); } } @@ -208,11 +217,11 @@ impl RecvWorker { PktType::Orig => { // println!("Orig"); - self.pkt_tx.send(Ok(Pkt { - chan: chan.num, + self.queue.push_back(Ok(Pkt { + chan, unrel, data: Cow::Owned(cursor.remaining_slice().into()), - }))?; + })); } PktType::Split => { // println!("Split"); @@ -221,11 +230,14 @@ impl RecvWorker { let chunk_index = cursor.read_u16::()? as usize; let chunk_count = cursor.read_u16::()? as usize; - let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split { - got: 0, - chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(), - timestamp: None, - }); + let mut split = self.chans[ch] + .splits + .entry(seqnum) + .or_insert_with(|| Split { + got: 0, + chunks: (0..chunk_count).map(|_| None).collect(), + timestamp: None, + }); if split.chunks.len() != chunk_count { return Err(InvalidChunkCount(split.chunks.len(), chunk_count)); @@ -233,10 +245,10 @@ impl RecvWorker { if split .chunks - .get(chunk_index) + .get_mut(chunk_index) .ok_or(InvalidChunkIndex(chunk_index, chunk_count))? - .set(cursor.remaining_slice().into()) - .is_ok() + .replace(cursor.remaining_slice().into()) + .is_none() { split.got += 1; } @@ -244,25 +256,29 @@ impl RecvWorker { split.timestamp = if unrel { Some(Instant::now()) } else { None }; if split.got == chunk_count { - self.pkt_tx.send(Ok(Pkt { - chan: chan.num, + let split = self.chans[ch].splits.remove(&seqnum).unwrap(); + + self.queue.push_back(Ok(Pkt { + chan, unrel, data: split .chunks - .iter() - .flat_map(|chunk| chunk.get().unwrap().iter()) - .copied() - .collect(), - }))?; - - chan.splits.remove(&seqnum); + .into_iter() + .map(|x| x.unwrap()) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() + .into(), + })); } } PktType::Rel => { // println!("Rel"); let seqnum = cursor.read_u16::()?; - chan.packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); + self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); let mut ack_data = Vec::with_capacity(3); ack_data.write_u8(CtlType::Ack as u8)?; @@ -272,35 +288,22 @@ impl RecvWorker { .send( PktType::Ctl, Pkt { + chan, unrel: true, - chan: chan.num, - data: Cow::Borrowed(&ack_data), + data: ack_data.into(), }, ) .await?; - fn next_pkt(chan: &mut RecvChan) -> Option> { - chan.packets[to_seqnum(chan.seqnum)].take() - } - - while let Some(pkt) = next_pkt(chan) { - self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?; - chan.seqnum = chan.seqnum.overflowing_add(1).0; + let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); + while let Some(pkt) = next_pkt(&mut self.chans[ch]) { + let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await; + self.handle_err(res)?; + self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; } } } Ok(()) } - - fn handle(&self, res: Result<()>) -> Result<()> { - use Error::*; - - match res { - Ok(v) => Ok(v), - Err(RemoteDisco(to)) => Err(RemoteDisco(to)), - Err(LocalDisco) => Err(LocalDisco), - Err(e) => Ok(self.pkt_tx.send(Err(e))?), - } - } } diff --git a/src/send.rs b/src/send.rs index a3a7f03..2c449e1 100644 --- a/src/send.rs +++ b/src/send.rs @@ -1,17 +1,33 @@ use super::*; use byteorder::{BigEndian, WriteBytesExt}; -use std::io::{self, Write}; +use std::{ + io::{self, Write}, + sync::Arc, +}; use tokio::sync::watch; pub type AckResult = io::Result>>; -impl RudpSender { +pub struct RudpSender { + pub(crate) share: Arc>, +} + +// derive(Clone) adds unwanted Clone trait bound to P parameter +impl Clone for RudpSender

{ + fn clone(&self) -> Self { + Self { + share: Arc::clone(&self.share), + } + } +} + +impl RudpSender

{ pub async fn send(&self, pkt: Pkt<'_>) -> AckResult { self.share.send(PktType::Orig, pkt).await // TODO: splits } } -impl RudpShare { +impl RudpShare

{ pub async fn send(&self, tp: PktType, pkt: Pkt<'_>) -> AckResult { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len()); buf.write_u32::(PROTO_ID)?; diff --git a/src/share.rs b/src/share.rs index a2afc4c..02e37b2 100644 --- a/src/share.rs +++ b/src/share.rs @@ -1,10 +1,6 @@ use super::*; -use drop_bomb::DropBomb; -use std::{borrow::Cow, collections::HashMap, io, sync::Arc, time::Duration}; -use tokio::{ - sync::{mpsc, watch, Mutex, RwLock}, - task::JoinSet, -}; +use std::{borrow::Cow, collections::HashMap, io, sync::Arc}; +use tokio::sync::{watch, Mutex, RwLock}; #[derive(Debug)] pub(crate) struct Ack { @@ -20,96 +16,72 @@ pub(crate) struct Chan { } #[derive(Debug)] -pub(crate) struct RudpShare { +pub(crate) struct RudpShare { pub(crate) id: u16, pub(crate) remote_id: RwLock, - pub(crate) chans: Vec>, - pub(crate) udp_tx: S, - pub(crate) close_tx: watch::Sender, - pub(crate) tasks: Mutex>, - pub(crate) bomb: Mutex, + pub(crate) chans: [Mutex; NUM_CHANS], + pub(crate) udp_tx: P::Sender, + pub(crate) close: watch::Sender, } -pub async fn new( +pub async fn new( id: u16, remote_id: u16, - udp_tx: S, - udp_rx: R, -) -> io::Result<(RudpSender, RudpReceiver)> { - let (pkt_tx, pkt_rx) = mpsc::unbounded_channel(); + udp_tx: P::Sender, + udp_rx: P::Receiver, +) -> io::Result<(RudpSender

, RudpReceiver

)> { let (close_tx, close_rx) = watch::channel(false); let share = Arc::new(RudpShare { id, remote_id: RwLock::new(remote_id), udp_tx, - close_tx, - chans: (0..NUM_CHANS) - .map(|_| { - Mutex::new(Chan { - acks: HashMap::new(), - seqnum: INIT_SEQNUM, - }) + close: close_tx, + chans: std::array::from_fn(|_| { + Mutex::new(Chan { + acks: HashMap::new(), + seqnum: INIT_SEQNUM, }) - .collect(), - tasks: Mutex::new(JoinSet::new()), - bomb: Mutex::new(DropBomb::new("rudp connection must be explicitly closed")), + }), }); - let mut tasks = share.tasks.lock().await; + Ok(( + RudpSender { + share: Arc::clone(&share), + }, + RudpReceiver::new(udp_rx, share, close_rx), + )) +} - let recv_share = Arc::clone(&share); - let recv_close = close_rx.clone(); - tasks - /*.build_task() - .name("recv")*/ - .spawn(async move { - let worker = RecvWorker::new(udp_rx, recv_share, recv_close, pkt_tx); - worker.run().await; - }); +macro_rules! impl_share { + ($T:ident) => { + impl $T

{ + pub async fn peer_id(&self) -> u16 { + self.share.id + } - let resend_share = Arc::clone(&share); - let mut resend_close = close_rx.clone(); - tasks - /*.build_task() - .name("resend")*/ - .spawn(async move { - ticker!(Duration::from_millis(500), resend_close, { - for chan in resend_share.chans.iter() { - for (_, ack) in chan.lock().await.acks.iter() { - resend_share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) - } - } - }); - }); + pub async fn is_server(&self) -> bool { + self.share.id == PeerID::Srv as u16 + } - let ping_share = Arc::clone(&share); - let mut ping_close = close_rx.clone(); - tasks - /*.build_task() - .name("ping")*/ - .spawn(async move { - ticker!(Duration::from_secs(PING_TIMEOUT), ping_close, { - ping_share + pub async fn close(self) { + self.share.close.send(true).ok(); // FIXME: handle err? + + self.share .send( PktType::Ctl, Pkt { + unrel: true, chan: 0, - unrel: false, - data: Cow::Borrowed(&[CtlType::Ping as u8]), + data: Cow::Borrowed(&[CtlType::Disco as u8]), }, ) .await .ok(); - }); - }); - - drop(tasks); - - Ok(( - RudpSender { - share: Arc::clone(&share), - }, - RudpReceiver { share, pkt_rx }, - )) + } + } + }; } + +impl_share!(RudpReceiver); +impl_share!(RudpSender); -- 2.44.0