From 89b1fc1d8d4bd886d80af0fe1d492cc877bce022 Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Sat, 25 Feb 2023 18:55:53 +0100 Subject: [PATCH] Use channels --- src/client.rs | 47 +++++++---- src/common.rs | 5 -- src/lib.rs | 6 +- src/send.rs | 77 ++++++++++++----- src/share.rs | 87 ------------------- src/{recv.rs => worker.rs} | 167 ++++++++++++++++--------------------- 6 files changed, 163 insertions(+), 226 deletions(-) delete mode 100644 src/share.rs rename src/{recv.rs => worker.rs} (69%) diff --git a/src/client.rs b/src/client.rs index 56db92a..29244d0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,16 +1,19 @@ use super::*; use async_trait::async_trait; use std::{io, sync::Arc}; -use tokio::net; +use tokio::{ + net, + sync::{mpsc, watch}, +}; #[derive(Debug)] -pub struct ToSrv(Arc); +pub struct UdpCltSender(Arc); #[derive(Debug)] -pub struct FromSrv(Arc); +pub struct UdpCltReceiver(Arc); #[async_trait] -impl UdpSender for ToSrv { +impl UdpSender for UdpCltSender { async fn send(&self, data: &[u8]) -> io::Result<()> { self.0.send(data).await?; Ok(()) @@ -18,7 +21,7 @@ impl UdpSender for ToSrv { } #[async_trait] -impl UdpReceiver for FromSrv { +impl UdpReceiver for UdpCltReceiver { async fn recv(&mut self) -> io::Result> { let mut buffer = Vec::new(); buffer.resize(UDP_PKT_SIZE, 0); @@ -30,21 +33,35 @@ impl UdpReceiver for FromSrv { } } -pub struct RemoteSrv; -impl UdpPeer for RemoteSrv { - type Sender = ToSrv; - type Receiver = FromSrv; +#[derive(Debug)] +pub struct CltReceiver(mpsc::UnboundedReceiver, Error>>); + +impl CltReceiver { + pub async fn recv_rudp(&mut self) -> Option, Error>> { + self.0.recv().await + } } -pub async fn connect(addr: &str) -> io::Result<(RudpSender, RudpReceiver)> { +pub type CltSender = Arc>; +pub type CltWorker = Worker; + +pub async fn connect(addr: &str) -> io::Result<(CltSender, CltReceiver, CltWorker)> { let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0").await?); sock.connect(addr).await?; - new( + let (close_tx, close_rx) = watch::channel(false); + let (pkt_tx, pkt_rx) = mpsc::unbounded_channel(); + + let sender = Sender::new( + UdpCltSender(Arc::clone(&sock)), + close_tx, PeerID::Srv as u16, PeerID::Nil as u16, - ToSrv(Arc::clone(&sock)), - FromSrv(sock), - ) - .await + ); + + Ok(( + Arc::clone(&sender), + CltReceiver(pkt_rx), + Worker::new(UdpCltReceiver(sock), close_rx, sender, pkt_tx), + )) } diff --git a/src/common.rs b/src/common.rs index 0ed08f3..bdae6d2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -20,11 +20,6 @@ pub trait UdpReceiver: Send { async fn recv(&mut self) -> io::Result>; } -pub trait UdpPeer { - type Sender: UdpSender; - type Receiver: UdpReceiver; -} - #[derive(Debug, Copy, Clone, PartialEq)] #[repr(u16)] pub enum PeerID { diff --git a/src/lib.rs b/src/lib.rs index e7a8ebe..b9a042d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,11 @@ mod client; mod common; mod error; -mod recv; mod send; -mod share; +mod worker; pub use client::*; pub use common::*; pub use error::*; -pub use recv::*; pub use send::*; -pub use share::*; +pub use worker::*; diff --git a/src/send.rs b/src/send.rs index 2c449e1..90bbe2d 100644 --- a/src/send.rs +++ b/src/send.rs @@ -1,34 +1,57 @@ use super::*; use byteorder::{BigEndian, WriteBytesExt}; use std::{ + collections::HashMap, io::{self, Write}, sync::Arc, }; -use tokio::sync::watch; +use tokio::sync::{watch, Mutex, RwLock}; -pub type AckResult = io::Result>>; +pub type Ack = Option>; -pub struct RudpSender { - pub(crate) share: Arc>, +#[derive(Debug)] +pub(crate) struct AckWait { + pub(crate) tx: watch::Sender, + pub(crate) rx: watch::Receiver, + pub(crate) data: Vec, } -// derive(Clone) adds unwanted Clone trait bound to P parameter -impl Clone for RudpSender

{ - fn clone(&self) -> Self { - Self { - share: Arc::clone(&self.share), - } - } +#[derive(Debug)] +pub(crate) struct Chan { + pub(crate) acks: HashMap, + pub(crate) seqnum: u16, } -impl RudpSender

{ - pub async fn send(&self, pkt: Pkt<'_>) -> AckResult { - self.share.send(PktType::Orig, pkt).await // TODO: splits - } +#[derive(Debug)] +pub struct Sender { + pub(crate) id: u16, + pub(crate) remote_id: RwLock, + pub(crate) chans: [Mutex; NUM_CHANS], + udp: S, + close: watch::Sender, } -impl RudpShare

{ - pub async fn send(&self, tp: PktType, pkt: Pkt<'_>) -> AckResult { +impl Sender { + pub fn new(udp: S, close: watch::Sender, id: u16, remote_id: u16) -> Arc { + Arc::new(Self { + id, + remote_id: RwLock::new(remote_id), + udp, + close, + chans: std::array::from_fn(|_| { + Mutex::new(Chan { + acks: HashMap::new(), + seqnum: INIT_SEQNUM, + }) + }), + }) + } + + pub async fn send_rudp(&self, pkt: Pkt<'_>) -> io::Result { + self.send_rudp_type(PktType::Orig, pkt).await // TODO: splits + } + + pub async fn send_rudp_type(&self, tp: PktType, pkt: Pkt<'_>) -> io::Result { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len()); buf.write_u32::(PROTO_ID)?; buf.write_u16::(*self.remote_id.read().await)?; @@ -45,7 +68,7 @@ impl RudpShare

{ buf.write_u8(tp as u8)?; buf.write_all(pkt.data.as_ref())?; - self.send_raw(&buf).await?; + self.send_udp(&buf).await?; if pkt.unrel { Ok(None) @@ -54,7 +77,7 @@ impl RudpShare

{ let (tx, rx) = watch::channel(false); chan.acks.insert( seqnum, - Ack { + AckWait { tx, rx: rx.clone(), data: buf, @@ -66,11 +89,23 @@ impl RudpShare

{ } } - pub async fn send_raw(&self, data: &[u8]) -> io::Result<()> { + pub async fn send_udp(&self, data: &[u8]) -> io::Result<()> { if data.len() > UDP_PKT_SIZE { panic!("splitting packets is not implemented yet"); } - self.udp_tx.send(data).await + self.udp.send(data).await + } + + pub async fn peer_id(&self) -> u16 { + self.id + } + + pub async fn is_server(&self) -> bool { + self.id == PeerID::Srv as u16 + } + + pub fn close(&self) { + self.close.send(true).ok(); } } diff --git a/src/share.rs b/src/share.rs deleted file mode 100644 index 02e37b2..0000000 --- a/src/share.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::*; -use std::{borrow::Cow, collections::HashMap, io, sync::Arc}; -use tokio::sync::{watch, Mutex, RwLock}; - -#[derive(Debug)] -pub(crate) struct Ack { - pub(crate) tx: watch::Sender, - pub(crate) rx: watch::Receiver, - pub(crate) data: Vec, -} - -#[derive(Debug)] -pub(crate) struct Chan { - pub(crate) acks: HashMap, - pub(crate) seqnum: u16, -} - -#[derive(Debug)] -pub(crate) struct RudpShare { - pub(crate) id: u16, - pub(crate) remote_id: RwLock, - pub(crate) chans: [Mutex; NUM_CHANS], - pub(crate) udp_tx: P::Sender, - pub(crate) close: watch::Sender, -} - -pub async fn new( - id: u16, - remote_id: u16, - udp_tx: P::Sender, - udp_rx: P::Receiver, -) -> io::Result<(RudpSender

, RudpReceiver

)> { - let (close_tx, close_rx) = watch::channel(false); - - let share = Arc::new(RudpShare { - id, - remote_id: RwLock::new(remote_id), - udp_tx, - close: close_tx, - chans: std::array::from_fn(|_| { - Mutex::new(Chan { - acks: HashMap::new(), - seqnum: INIT_SEQNUM, - }) - }), - }); - - Ok(( - RudpSender { - share: Arc::clone(&share), - }, - RudpReceiver::new(udp_rx, share, close_rx), - )) -} - -macro_rules! impl_share { - ($T:ident) => { - impl $T

{ - pub async fn peer_id(&self) -> u16 { - self.share.id - } - - pub async fn is_server(&self) -> bool { - self.share.id == PeerID::Srv as u16 - } - - pub async fn close(self) { - self.share.close.send(true).ok(); // FIXME: handle err? - - self.share - .send( - PktType::Ctl, - Pkt { - unrel: true, - chan: 0, - data: Cow::Borrowed(&[CtlType::Disco as u8]), - }, - ) - .await - .ok(); - } - } - }; -} - -impl_share!(RudpReceiver); -impl_share!(RudpSender); diff --git a/src/recv.rs b/src/worker.rs similarity index 69% rename from src/recv.rs rename to src/worker.rs index 309bf94..72bf2b5 100644 --- a/src/recv.rs +++ b/src/worker.rs @@ -3,14 +3,14 @@ use async_recursion::async_recursion; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use std::{ borrow::Cow, - collections::{HashMap, VecDeque}, + collections::HashMap, io, pin::Pin, sync::Arc, time::{Duration, Instant}, }; use tokio::{ - sync::watch, + sync::{mpsc, watch}, time::{interval, sleep, Interval, Sleep}, }; @@ -27,36 +27,38 @@ struct Split { got: usize, } +#[derive(Debug)] struct RecvChan { packets: Vec>>, // char ** 😛 splits: HashMap, seqnum: u16, } -pub struct RudpReceiver { - pub(crate) share: Arc>, +#[derive(Debug)] +pub struct Worker { + sender: Arc>, chans: [RecvChan; NUM_CHANS], - udp: P::Receiver, + input: R, close: watch::Receiver, - closed: bool, resend: Interval, ping: Interval, cleanup: Interval, timeout: Pin>, - queue: VecDeque>>, + output: mpsc::UnboundedSender>>, } -impl RudpReceiver

{ +impl Worker { pub(crate) fn new( - udp: P::Receiver, - share: Arc>, + input: R, close: watch::Receiver, + sender: Arc>, + output: mpsc::UnboundedSender>>, ) -> Self { Self { - udp, - share, + input, + sender, close, - closed: false, + output, resend: interval(Duration::from_millis(500)), ping: interval(Duration::from_secs(PING_TIMEOUT)), cleanup: interval(Duration::from_secs(TIMEOUT)), @@ -66,42 +68,33 @@ impl RudpReceiver

{ seqnum: INIT_SEQNUM, splits: HashMap::new(), }), - queue: VecDeque::new(), } } - fn handle_err(&mut self, res: Result<()>) -> Result<()> { + pub async fn run(mut self) { use Error::*; - match res { - Err(RemoteDisco(_)) | Err(LocalDisco) => { - self.closed = true; - res - } - Ok(_) => res, - Err(e) => { - self.queue.push_back(Err(e)); - Ok(()) - } - } - } - - pub async fn recv(&mut self) -> Option>> { - use Error::*; - - if self.closed { - return None; - } - loop { - if let Some(x) = self.queue.pop_front() { - return Some(x); - } - tokio::select! { _ = self.close.changed() => { - self.closed = true; - return Some(Err(LocalDisco)); + self.sender + .send_rudp_type( + PktType::Ctl, + Pkt { + unrel: true, + chan: 0, + data: Cow::Borrowed(&[CtlType::Disco as u8]), + }, + ) + .await + .ok(); + + self.output.send(Err(LocalDisco)).ok(); + break; + }, + _ = &mut self.timeout => { + self.output.send(Err(RemoteDisco(true))).ok(); + break; }, _ = self.cleanup.tick() => { let timeout = Duration::from_secs(TIMEOUT); @@ -116,15 +109,15 @@ impl RudpReceiver

{ } }, _ = self.resend.tick() => { - for chan in self.share.chans.iter() { + for chan in self.sender.chans.iter() { for (_, ack) in chan.lock().await.acks.iter() { - self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?) + self.sender.send_udp(&ack.data).await.ok(); } } }, _ = self.ping.tick() => { - self.share - .send( + self.sender + .send_rudp_type( PktType::Ctl, Pkt { chan: 0, @@ -135,13 +128,9 @@ impl RudpReceiver

{ .await .ok(); } - _ = &mut self.timeout => { - self.closed = true; - return Some(Err(RemoteDisco(true))); - }, - pkt = self.udp.recv() => { + pkt = self.input.recv() => { if let Err(e) = self.handle_pkt(pkt).await { - return Some(Err(e)); + self.output.send(Err(e)).ok(); } } } @@ -169,10 +158,7 @@ impl RudpReceiver

{ return Err(InvalidChannel(chan)); } - let res = self.process_pkt(cursor, true, chan).await; - self.handle_err(res)?; - - Ok(()) + self.process_pkt(cursor, true, chan).await } #[async_recursion] @@ -188,17 +174,13 @@ impl RudpReceiver

{ match cursor.read_u8()?.try_into()? { PktType::Ctl => match cursor.read_u8()?.try_into()? { CtlType::Ack => { - // println!("Ack"); - let seqnum = cursor.read_u16::()?; - if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) { + if let Some(ack) = self.sender.chans[ch].lock().await.acks.remove(&seqnum) { ack.tx.send(true).ok(); } } CtlType::SetPeerID => { - // println!("SetPeerID"); - - let mut id = self.share.remote_id.write().await; + let mut id = self.sender.remote_id.write().await; if *id != PeerID::Nil as u16 { return Err(PeerIDAlreadySet); @@ -206,26 +188,21 @@ impl RudpReceiver

{ *id = cursor.read_u16::()?; } - CtlType::Ping => { - // println!("Ping"); - } + CtlType::Ping => {} CtlType::Disco => { - // println!("Disco"); return Err(RemoteDisco(false)); } }, PktType::Orig => { - // println!("Orig"); - - self.queue.push_back(Ok(Pkt { - chan, - unrel, - data: Cow::Owned(cursor.remaining_slice().into()), - })); + self.output + .send(Ok(Pkt { + chan, + unrel, + data: Cow::Owned(cursor.remaining_slice().into()), + })) + .ok(); } PktType::Split => { - // println!("Split"); - let seqnum = cursor.read_u16::()?; let chunk_count = cursor.read_u16::()? as usize; let chunk_index = cursor.read_u16::()? as usize; @@ -258,25 +235,25 @@ impl RudpReceiver

{ if split.got == chunk_count { let split = self.chans[ch].splits.remove(&seqnum).unwrap(); - self.queue.push_back(Ok(Pkt { - chan, - unrel, - data: split - .chunks - .into_iter() - .map(|x| x.unwrap()) - .reduce(|mut a, mut b| { - a.append(&mut b); - a - }) - .unwrap_or_default() - .into(), - })); + self.output + .send(Ok(Pkt { + chan, + unrel, + data: split + .chunks + .into_iter() + .map(|x| x.unwrap()) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() + .into(), + })) + .ok(); } } PktType::Rel => { - // println!("Rel"); - let seqnum = cursor.read_u16::()?; self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into()); @@ -284,8 +261,8 @@ impl RudpReceiver

{ ack_data.write_u8(CtlType::Ack as u8)?; ack_data.write_u16::(seqnum)?; - self.share - .send( + self.sender + .send_rudp_type( PktType::Ctl, Pkt { chan, @@ -297,8 +274,10 @@ impl RudpReceiver

{ let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take(); while let Some(pkt) = next_pkt(&mut self.chans[ch]) { - let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await; - self.handle_err(res)?; + if let Err(e) = self.process_pkt(io::Cursor::new(pkt), false, chan).await { + self.output.send(Err(e)).ok(); + } + self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0; } } -- 2.44.0