]> git.lizzy.rs Git - mt_rudp.git/blobdiff - src/main.rs
send acks
[mt_rudp.git] / src / main.rs
index a190bcd094447afb6bdd928b1e0c2d3023f12e43..1f0aca01f1a53a87b617be8cfc1cff2a0be890d6 100644 (file)
@@ -1,6 +1,6 @@
-#![feature(yeet_expr)]
 #![feature(cursor_remaining)]
 #![feature(hash_drain_filter)]
+#![feature(once_cell)]
 mod client;
 pub mod error;
 mod recv_worker;
@@ -9,13 +9,13 @@ use async_trait::async_trait;
 use byteorder::{BigEndian, WriteBytesExt};
 pub use client::{connect, Sender as Client};
 use num_enum::TryFromPrimitive;
-use std::future::Future;
 use std::{
+    collections::HashMap,
     io::{self, Write},
     ops,
     sync::Arc,
 };
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, watch, Mutex, RwLock};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -34,7 +34,8 @@ pub trait UdpReceiver: Send + Sync + 'static {
     async fn recv(&self) -> io::Result<Vec<u8>>;
 }
 
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[repr(u16)]
 pub enum PeerID {
     Nil = 0,
     Srv,
@@ -68,15 +69,13 @@ pub struct Pkt<T> {
 
 pub type Error = error::Error;
 pub type InPkt = Result<Pkt<Vec<u8>>, Error>;
-
-#[derive(Debug)]
-pub struct AckChan;
+type AckChan = (watch::Sender<bool>, watch::Receiver<bool>);
 
 #[derive(Debug)]
 pub struct RudpShare<S: UdpSender> {
     pub id: u16,
-    pub remote_id: u16,
-    pub chans: Vec<AckChan>,
+    pub remote_id: RwLock<u16>,
+    pub ack_chans: Mutex<HashMap<u16, AckChan>>,
     udp_tx: S,
 }
 
@@ -95,7 +94,7 @@ impl<S: UdpSender> RudpShare<S> {
     pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
         buf.write_u32::<BigEndian>(PROTO_ID)?;
-        buf.write_u16::<BigEndian>(self.remote_id)?;
+        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
         buf.write_u8(pkt.chan as u8)?;
         buf.write_u8(tp as u8)?;
         buf.write(pkt.data)?;
@@ -110,6 +109,24 @@ impl<S: UdpSender> RudpSender<S> {
     pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
         self.share.send(PktType::Orig, pkt).await // TODO
     }
+
+    pub async fn peer_id(&self) -> u16 {
+        self.share.id
+    }
+
+    pub async fn is_server(&self) -> bool {
+        self.share.id == PeerID::Srv as u16
+    }
+}
+
+impl<S: UdpSender> RudpReceiver<S> {
+    pub async fn peer_id(&self) -> u16 {
+        self.share.id
+    }
+
+    pub async fn is_server(&self) -> bool {
+        self.share.id == PeerID::Srv as u16
+    }
 }
 
 impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
@@ -136,9 +153,9 @@ pub fn new<S: UdpSender, R: UdpReceiver>(
 
     let share = Arc::new(RudpShare {
         id,
-        remote_id,
+        remote_id: RwLock::new(remote_id),
         udp_tx,
-        chans: (0..NUM_CHANS).map(|_| AckChan).collect(),
+        ack_chans: Mutex::new(HashMap::new()),
     });
     let recv_share = Arc::clone(&share);
 
@@ -159,7 +176,6 @@ pub fn new<S: UdpSender, R: UdpReceiver>(
 
 #[tokio::main]
 async fn main() -> io::Result<()> {
-    //println!("{}", x.deep_size_of());
     let (tx, mut rx) = connect("127.0.0.1:30000").await?;
 
     let mut mtpkt = vec![];
@@ -188,5 +204,7 @@ async fn main() -> io::Result<()> {
     }
     println!("disco");
 
+    // close()ing rx is not needed because it has been consumed to the end
+
     Ok(())
 }