]> git.lizzy.rs Git - mt_rudp.git/blobdiff - src/main.rs
clean shutdown; send reliables
[mt_rudp.git] / src / main.rs
index 1f0aca01f1a53a87b617be8cfc1cff2a0be890d6..0510db765d994ce6e5dad09d3aafe5b93ed40282 100644 (file)
@@ -3,19 +3,27 @@
 #![feature(once_cell)]
 mod client;
 pub mod error;
-mod recv_worker;
+mod new;
+mod recv;
+mod send;
 
 use async_trait::async_trait;
 use byteorder::{BigEndian, WriteBytesExt};
 pub use client::{connect, Sender as Client};
+pub use new::new;
 use num_enum::TryFromPrimitive;
+use pretty_hex::PrettyHex;
 use std::{
     collections::HashMap,
     io::{self, Write},
     ops,
     sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    sync::{mpsc, watch, Mutex, RwLock},
+    task::JoinSet,
 };
-use tokio::sync::{mpsc, watch, Mutex, RwLock};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -24,9 +32,25 @@ pub const REL_BUFFER: usize = 0x8000;
 pub const INIT_SEQNUM: u16 = 65500;
 pub const TIMEOUT: u64 = 30;
 
+mod ticker_mod {
+    #[macro_export]
+    macro_rules! ticker {
+               ($duration:expr, $close:expr, $body:block) => {
+                       let mut interval = tokio::time::interval($duration);
+
+                       while tokio::select!{
+                               _ = interval.tick() => true,
+                               _ = $close.changed() => false,
+                       } $body
+               };
+       }
+
+    //pub(crate) use ticker;
+}
+
 #[async_trait]
 pub trait UdpSender: Send + Sync + 'static {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()>;
+    async fn send(&self, data: &[u8]) -> io::Result<()>;
 }
 
 #[async_trait]
@@ -69,14 +93,28 @@ pub struct Pkt<T> {
 
 pub type Error = error::Error;
 pub type InPkt = Result<Pkt<Vec<u8>>, Error>;
-type AckChan = (watch::Sender<bool>, watch::Receiver<bool>);
+
+#[derive(Debug)]
+struct Ack {
+    tx: watch::Sender<bool>,
+    rx: watch::Receiver<bool>,
+    data: Vec<u8>,
+}
+
+#[derive(Debug)]
+struct Chan {
+    acks: HashMap<u16, Ack>,
+    seqnum: u16,
+}
 
 #[derive(Debug)]
 pub struct RudpShare<S: UdpSender> {
-    pub id: u16,
-    pub remote_id: RwLock<u16>,
-    pub ack_chans: Mutex<HashMap<u16, AckChan>>,
+    id: u16,
+    remote_id: RwLock<u16>,
+    chans: Vec<Mutex<Chan>>,
     udp_tx: S,
+    close_tx: watch::Sender<bool>,
+    tasks: Mutex<JoinSet<()>>,
 }
 
 #[derive(Debug)]
@@ -90,44 +128,31 @@ pub struct RudpSender<S: UdpSender> {
     share: Arc<RudpShare<S>>,
 }
 
-impl<S: UdpSender> RudpShare<S> {
-    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
-        buf.write_u32::<BigEndian>(PROTO_ID)?;
-        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
-        buf.write_u8(pkt.chan as u8)?;
-        buf.write_u8(tp as u8)?;
-        buf.write(pkt.data)?;
-
-        self.udp_tx.send(buf).await?;
-
-        Ok(())
-    }
-}
+macro_rules! impl_share {
+    ($T:ident) => {
+        impl<S: UdpSender> $T<S> {
+            pub async fn peer_id(&self) -> u16 {
+                self.share.id
+            }
 
-impl<S: UdpSender> RudpSender<S> {
-    pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        self.share.send(PktType::Orig, pkt).await // TODO
-    }
+            pub async fn is_server(&self) -> bool {
+                self.share.id == PeerID::Srv as u16
+            }
 
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
+            pub async fn close(self) {
+                self.share.close_tx.send(true).ok();
 
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
+                let mut tasks = self.share.tasks.lock().await;
+                while let Some(res) = tasks.join_next().await {
+                    res.ok(); // TODO: handle error (?)
+                }
+            }
+        }
+    };
 }
 
-impl<S: UdpSender> RudpReceiver<S> {
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
-
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
-}
+impl_share!(RudpReceiver);
+impl_share!(RudpSender);
 
 impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
     type Target = mpsc::UnboundedReceiver<InPkt>;
@@ -143,49 +168,16 @@ impl<S: UdpSender> ops::DerefMut for RudpReceiver<S> {
     }
 }
 
-pub fn new<S: UdpSender, R: UdpReceiver>(
-    id: u16,
-    remote_id: u16,
-    udp_tx: S,
-    udp_rx: R,
-) -> (RudpSender<S>, RudpReceiver<S>) {
-    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
-
-    let share = Arc::new(RudpShare {
-        id,
-        remote_id: RwLock::new(remote_id),
-        udp_tx,
-        ack_chans: Mutex::new(HashMap::new()),
-    });
-    let recv_share = Arc::clone(&share);
-
-    tokio::spawn(async {
-        let worker = recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx);
-        worker.run().await;
-    });
-
-    (
-        RudpSender {
-            share: Arc::clone(&share),
-        },
-        RudpReceiver { share, pkt_rx },
-    )
-}
-
-// connect
-
-#[tokio::main]
-async fn main() -> io::Result<()> {
-    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
-
+async fn example(tx: &RudpSender<Client>, rx: &mut RudpReceiver<Client>) -> io::Result<()> {
+    // send hello packet
     let mut mtpkt = vec![];
     mtpkt.write_u16::<BigEndian>(2)?; // high level type
     mtpkt.write_u8(29)?; // serialize ver
     mtpkt.write_u16::<BigEndian>(0)?; // compression modes
     mtpkt.write_u16::<BigEndian>(40)?; // MinProtoVer
     mtpkt.write_u16::<BigEndian>(40)?; // MaxProtoVer
-    mtpkt.write_u16::<BigEndian>(3)?; // player name length
-    mtpkt.write(b"foo")?; // player name
+    mtpkt.write_u16::<BigEndian>(6)?; // player name length
+    mtpkt.write(b"foobar")?; // player name
 
     tx.send(Pkt {
         unrel: true,
@@ -194,17 +186,34 @@ async fn main() -> io::Result<()> {
     })
     .await?;
 
+    // handle incoming packets
     while let Some(result) = rx.recv().await {
         match result {
             Ok(pkt) => {
-                io::stdout().write(pkt.data.as_slice())?;
+                println!("{}", pkt.data.hex_dump());
             }
             Err(err) => eprintln!("Error: {}", err),
         }
     }
-    println!("disco");
 
-    // close()ing rx is not needed because it has been consumed to the end
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> io::Result<()> {
+    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
+
+    tokio::select! {
+        _ = tokio::signal::ctrl_c() => println!("canceled"),
+        res = example(&tx, &mut rx) => {
+            res?;
+            println!("disconnected");
+        }
+    }
+
+    // close either the receiver or the sender
+    // this shuts down associated tasks
+    rx.close().await;
 
     Ok(())
 }