From 976be3aa3b50ef7721fce2d38ed5855ed64a719a Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Fri, 23 Dec 2022 18:00:28 +0100 Subject: [PATCH] async --- Cargo.toml | 1 + src/client.rs | 17 +++++++++-------- src/error.rs | 7 ++++--- src/main.rs | 38 ++++++++++++++++++++------------------ src/recv_worker.rs | 34 +++++++++++++++++++--------------- 5 files changed, 53 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bedb394..55cc8ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" [dependencies] byteorder = "1.4.3" num_enum = "0.5.7" +tokio = { version = "1.23.0", features = ["full"] } diff --git a/src/client.rs b/src/client.rs index 81c1bfb..97a18d7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,14 @@ use crate::*; -use std::{io, net, sync::Arc}; +use std::{io, sync::Arc}; +use tokio::net; pub struct Sender { sock: Arc, } impl UdpSender for Sender { - fn send(&self, data: Vec) -> io::Result<()> { - self.sock.send(&data)?; + async fn send(&self, data: Vec) -> io::Result<()> { + self.sock.send(&data).await?; Ok(()) } } @@ -17,20 +18,20 @@ pub struct Receiver { } impl UdpReceiver for Receiver { - fn recv(&self) -> io::Result> { + async fn recv(&self) -> io::Result> { let mut buffer = Vec::new(); buffer.resize(UDP_PKT_SIZE, 0); - let len = self.sock.recv(&mut buffer)?; + let len = self.sock.recv(&mut buffer).await?; buffer.truncate(len); Ok(buffer) } } -pub fn connect(addr: &str) -> io::Result<(RudpSender, RudpReceiver)> { - let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0")?); - sock.connect(addr)?; +pub async fn connect(addr: &str) -> io::Result<(RudpSender, RudpReceiver)> { + let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0").await?); + sock.connect(addr).await?; Ok(new( PeerID::Srv as u16, diff --git a/src/error.rs b/src/error.rs index 02080c7..9b84ace 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,7 @@ use crate::{CtlType, InPkt, PktType}; use num_enum::TryFromPrimitiveError; -use std::{fmt, io, sync::mpsc}; +use std::{fmt, io}; +use tokio::sync::mpsc::error::SendError; #[derive(Debug)] pub enum Error { @@ -32,8 +33,8 @@ impl From> for Error { } } -impl From> for Error { - fn from(_err: mpsc::SendError) -> Self { +impl From> for Error { + fn from(_err: SendError) -> Self { Self::LocalDisco // technically not a disconnect but a local drop } } diff --git a/src/main.rs b/src/main.rs index aadb5cc..241d324 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ #![feature(yeet_expr)] #![feature(cursor_remaining)] #![feature(hash_drain_filter)] +#![feature(async_fn_in_trait)] mod client; pub mod error; mod recv_worker; @@ -8,12 +9,13 @@ mod recv_worker; use byteorder::{BigEndian, WriteBytesExt}; pub use client::{connect, Sender as Client}; use num_enum::TryFromPrimitive; +use std::future::Future; use std::{ io::{self, Write}, ops, - sync::{mpsc, Arc}, - thread, + sync::Arc, }; +use tokio::sync::mpsc; pub const PROTO_ID: u32 = 0x4f457403; pub const UDP_PKT_SIZE: usize = 512; @@ -23,11 +25,11 @@ pub const INIT_SEQNUM: u16 = 65500; pub const TIMEOUT: u64 = 30; pub trait UdpSender: Send + Sync + 'static { - fn send(&self, data: Vec) -> io::Result<()>; + async fn send(&self, data: Vec) -> io::Result<()>; } pub trait UdpReceiver: Send + Sync + 'static { - fn recv(&self) -> io::Result>; + async fn recv(&self) -> io::Result>; } #[derive(Debug, Copy, Clone)] @@ -79,7 +81,7 @@ pub struct RudpShare { #[derive(Debug)] pub struct RudpReceiver { share: Arc>, - pkt_rx: mpsc::Receiver, + pkt_rx: mpsc::UnboundedReceiver, } #[derive(Debug)] @@ -88,7 +90,7 @@ pub struct RudpSender { } impl RudpShare { - pub fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> { + pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len()); buf.write_u32::(PROTO_ID)?; buf.write_u16::(self.remote_id)?; @@ -96,20 +98,20 @@ impl RudpShare { buf.write_u8(tp as u8)?; buf.write(pkt.data)?; - self.udp_tx.send(buf)?; + self.udp_tx.send(buf).await?; Ok(()) } } impl RudpSender { - pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> { - self.share.send(PktType::Orig, pkt) // TODO + pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> { + self.share.send(PktType::Orig, pkt).await // TODO } } impl ops::Deref for RudpReceiver { - type Target = mpsc::Receiver; + type Target = mpsc::UnboundedReceiver; fn deref(&self) -> &Self::Target { &self.pkt_rx @@ -122,7 +124,7 @@ pub fn new( udp_tx: S, udp_rx: R, ) -> (RudpSender, RudpReceiver) { - let (pkt_tx, pkt_rx) = mpsc::channel(); + let (pkt_tx, pkt_rx) = mpsc::unbounded_channel(); let share = Arc::new(RudpShare { id, @@ -132,9 +134,7 @@ pub fn new( }); let recv_share = Arc::clone(&share); - thread::spawn(|| { - recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx).run(); - }); + tokio::spawn(async { recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx).await }); ( RudpSender { @@ -146,9 +146,10 @@ pub fn new( // connect -fn main() -> io::Result<()> { +#[tokio::main] +async fn main() -> io::Result<()> { //println!("{}", x.deep_size_of()); - let (tx, rx) = connect("127.0.0.1:30000")?; + let (tx, rx) = connect("127.0.0.1:30000").await?; let mut mtpkt = vec![]; mtpkt.write_u16::(2)?; // high level type @@ -163,9 +164,10 @@ fn main() -> io::Result<()> { unrel: true, chan: 1, data: &mtpkt, - })?; + }) + .await?; - while let Ok(result) = rx.recv() { + while let Some(result) = rx.recv().await { match result { Ok(pkt) => { io::stdout().write(pkt.data.as_slice())?; diff --git a/src/recv_worker.rs b/src/recv_worker.rs index 2cd8197..60cadeb 100644 --- a/src/recv_worker.rs +++ b/src/recv_worker.rs @@ -3,17 +3,17 @@ use byteorder::{BigEndian, ReadBytesExt}; use std::{ cell::Cell, collections::HashMap, - io, result, - sync::{mpsc, Arc, Mutex, Weak}, - thread, time, + io, + sync::{Arc, Weak}, + time, }; +use tokio::sync::{mpsc, Mutex}; fn to_seqnum(seqnum: u16) -> usize { (seqnum as usize) & (REL_BUFFER - 1) } -type PktTx = mpsc::Sender; -type Result = result::Result<(), Error>; +type Result = std::result::Result<(), Error>; struct Split { timestamp: time::Instant, @@ -29,12 +29,12 @@ struct Chan { pub struct RecvWorker { share: Arc>, chans: Arc>>, - pkt_tx: PktTx, + pkt_tx: mpsc::UnboundedSender, udp_rx: R, } impl RecvWorker { - pub fn new(udp_rx: R, share: Arc>, pkt_tx: PktTx) -> Self { + pub async fn new(udp_rx: R, share: Arc>, pkt_tx: mpsc::UnboundedSender) { Self { udp_rx, share, @@ -52,28 +52,31 @@ impl RecvWorker { .collect(), ), } + .run() + .await } - pub fn run(&self) { + pub async fn run(&self) { let cleanup_chans = Arc::downgrade(&self.chans); - thread::spawn(move || { + tokio::spawn(async move { let timeout = time::Duration::from_secs(TIMEOUT); + let mut interval = tokio::time::interval(timeout); while let Some(chans) = Weak::upgrade(&cleanup_chans) { for chan in chans.iter() { - let mut ch = chan.lock().unwrap(); + let mut ch = chan.lock().await; ch.splits = ch .splits .drain_filter(|_k, v| v.timestamp.elapsed() < timeout) .collect(); } - thread::sleep(timeout); + interval.tick().await; } }); loop { - if let Err(e) = self.handle(self.recv_pkt()) { + if let Err(e) = self.handle(self.recv_pkt().await) { if let Error::LocalDisco = e { self.share .send( @@ -84,6 +87,7 @@ impl RecvWorker { data: &[CtlType::Disco as u8], }, ) + .await .ok(); } break; @@ -91,11 +95,11 @@ impl RecvWorker { } } - fn recv_pkt(&self) -> Result { + async fn recv_pkt(&self) -> Result { use Error::*; // todo: reset timeout - let mut cursor = io::Cursor::new(self.udp_rx.recv()?); + let mut cursor = io::Cursor::new(self.udp_rx.recv().await?); let proto_id = cursor.read_u32::()?; if proto_id != PROTO_ID { @@ -110,7 +114,7 @@ impl RecvWorker { .get(n_chan as usize) .ok_or(InvalidChannel(n_chan))? .lock() - .unwrap(); + .await; self.process_pkt(cursor, &mut chan) } -- 2.44.0