]> git.lizzy.rs Git - mt_rudp.git/commitdiff
Don't spawn tasks
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Sat, 18 Feb 2023 02:03:40 +0000 (03:03 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Sat, 18 Feb 2023 02:05:06 +0000 (03:05 +0100)
Cargo.toml
examples/example.rs
src/client.rs
src/common.rs
src/error.rs
src/lib.rs
src/recv.rs
src/send.rs
src/share.rs

index 8e5f7ef186d2b67403b04f12347ec6e7bca9cff0..254a0db1b7328bff8825b9b54d48f829da1f6bb0 100644 (file)
@@ -7,12 +7,10 @@ edition = "2021"
 async-recursion = "1.0.0"
 async-trait = "0.1.60"
 byteorder = "1.4.3"
-delegate = "0.9.0"
-drop_bomb = "0.1.5"
 num_enum = "0.5.7"
 thiserror = "1.0.38"
-tokio = { version = "1.23.0", features = ["sync", "time", "net", "signal", "macros", "rt"] }
+tokio = { version = "1.23.0", features = ["sync", "time", "net", "macros"] }
 
 [dev-dependencies]
 pretty-hex = "0.3.0"
-tokio = { version = "1.23.0", features = ["rt-multi-thread"] }
+tokio = { version = "1.23.0", features = ["rt-multi-thread", "rt", "signal", "tracing"] }
index 8625243ebb7cd5290bdc812056fbf613b9e372a1..afb9d5baf8ca3f9847e56332bf66763eb35f15f2 100644 (file)
@@ -1,9 +1,9 @@
 use byteorder::{BigEndian, WriteBytesExt};
-use mt_rudp::{RudpReceiver, RudpSender, ToSrv};
+use mt_rudp::{RemoteSrv, RudpReceiver, RudpSender};
 use pretty_hex::PrettyHex;
 use std::io::{self, Write};
 
-async fn example(tx: &RudpSender<ToSrv>, rx: &mut RudpReceiver<ToSrv>) -> io::Result<()> {
+async fn example(tx: &RudpSender<RemoteSrv>, rx: &mut RudpReceiver<RemoteSrv>) -> io::Result<()> {
     // send hello packet
     let mut pkt = vec![];
     pkt.write_u16::<BigEndian>(2)?; // high level type
@@ -17,7 +17,7 @@ async fn example(tx: &RudpSender<ToSrv>, rx: &mut RudpReceiver<ToSrv>) -> io::Re
     tx.send(mt_rudp::Pkt {
         unrel: true,
         chan: 1,
-        data: &pkt,
+        data: pkt.into(),
     })
     .await?;
 
index c4922ec9c7d45c7bc45440f2f3bd733efb7906de..56db92a9ae060f8e6bfce7ce38017bda6c3199d0 100644 (file)
@@ -19,7 +19,7 @@ impl UdpSender for ToSrv {
 
 #[async_trait]
 impl UdpReceiver for FromSrv {
-    async fn recv(&self) -> io::Result<Vec<u8>> {
+    async fn recv(&mut self) -> io::Result<Vec<u8>> {
         let mut buffer = Vec::new();
         buffer.resize(UDP_PKT_SIZE, 0);
 
@@ -30,7 +30,13 @@ impl UdpReceiver for FromSrv {
     }
 }
 
-pub async fn connect(addr: &str) -> io::Result<(RudpSender<ToSrv>, RudpReceiver<ToSrv>)> {
+pub struct RemoteSrv;
+impl UdpPeer for RemoteSrv {
+    type Sender = ToSrv;
+    type Receiver = FromSrv;
+}
+
+pub async fn connect(addr: &str) -> io::Result<(RudpSender<RemoteSrv>, RudpReceiver<RemoteSrv>)> {
     let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0").await?);
     sock.connect(addr).await?;
 
index 4c0bc08600af0ee038ea9f65e15bf1291310a125..0ed08f3c3b09c8f3e28647710c3c688f764e65a9 100644 (file)
@@ -1,9 +1,6 @@
-use super::*;
 use async_trait::async_trait;
-use delegate::delegate;
 use num_enum::TryFromPrimitive;
-use std::{borrow::Cow, io, sync::Arc};
-use tokio::sync::mpsc;
+use std::{borrow::Cow, fmt::Debug, io};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -14,13 +11,18 @@ pub const TIMEOUT: u64 = 30;
 pub const PING_TIMEOUT: u64 = 5;
 
 #[async_trait]
-pub trait UdpSender: Send + Sync + 'static {
+pub trait UdpSender: Send + Sync {
     async fn send(&self, data: &[u8]) -> io::Result<()>;
 }
 
 #[async_trait]
-pub trait UdpReceiver: Send + Sync + 'static {
-    async fn recv(&self) -> io::Result<Vec<u8>>;
+pub trait UdpReceiver: Send {
+    async fn recv(&mut self) -> io::Result<Vec<u8>>;
+}
+
+pub trait UdpPeer {
+    type Sender: UdpSender;
+    type Receiver: UdpReceiver;
 }
 
 #[derive(Debug, Copy, Clone, PartialEq)]
@@ -55,60 +57,3 @@ pub struct Pkt<'a> {
     pub chan: u8,
     pub data: Cow<'a, [u8]>,
 }
-
-pub type InPkt = Result<Pkt<'static>, Error>;
-
-#[derive(Debug)]
-pub struct RudpReceiver<S: UdpSender> {
-    pub(crate) share: Arc<RudpShare<S>>,
-    pub(crate) pkt_rx: mpsc::UnboundedReceiver<InPkt>,
-}
-
-#[derive(Debug)]
-pub struct RudpSender<S: UdpSender> {
-    pub(crate) share: Arc<RudpShare<S>>,
-}
-
-// derive(Clone) adds unwanted Clone trait bound to S parameter
-impl<S: UdpSender> Clone for RudpSender<S> {
-    fn clone(&self) -> Self {
-        Self {
-            share: Arc::clone(&self.share),
-        }
-    }
-}
-
-macro_rules! impl_share {
-    ($T:ident) => {
-        impl<S: UdpSender> $T<S> {
-            pub async fn peer_id(&self) -> u16 {
-                self.share.id
-            }
-
-            pub async fn is_server(&self) -> bool {
-                self.share.id == PeerID::Srv as u16
-            }
-
-            pub async fn close(self) {
-                self.share.bomb.lock().await.defuse();
-                self.share.close_tx.send(true).ok();
-
-                let mut tasks = self.share.tasks.lock().await;
-                while let Some(res) = tasks.join_next().await {
-                    res.ok(); // TODO: handle error (?)
-                }
-            }
-        }
-    };
-}
-
-impl_share!(RudpReceiver);
-impl_share!(RudpSender);
-
-impl<S: UdpSender> RudpReceiver<S> {
-    delegate! {
-        to self.pkt_rx {
-            pub async fn recv(&mut self) -> Option<InPkt>;
-        }
-    }
-}
index 7cfc0578572766d985e5ac6869cfa2c0a1c1d8bb..28233e869f99a5541c688d0a61280e63a15db7c1 100644 (file)
@@ -1,7 +1,6 @@
 use super::*;
 use num_enum::TryFromPrimitiveError;
 use thiserror::Error;
-use tokio::sync::mpsc::error::SendError;
 
 #[derive(Error, Debug)]
 pub enum Error {
@@ -38,9 +37,3 @@ impl From<TryFromPrimitiveError<CtlType>> for Error {
         Self::InvalidCtlType(err.number)
     }
 }
-
-impl From<SendError<InPkt>> for Error {
-    fn from(_err: SendError<InPkt>) -> Self {
-        Self::LocalDisco
-    }
-}
index a02eb206363e9f86673d6aa8b86be0f735c510d3..e7a8ebe5f978e3ba0912cb186127e9ec93583bcb 100644 (file)
@@ -1,6 +1,6 @@
 #![feature(cursor_remaining)]
 #![feature(hash_drain_filter)]
-#![feature(once_cell)]
+
 mod client;
 mod common;
 mod error;
@@ -11,21 +11,6 @@ mod share;
 pub use client::*;
 pub use common::*;
 pub use error::*;
-use recv::*;
+pub use recv::*;
 pub use send::*;
 pub use share::*;
-pub use ticker_mod::*;
-
-mod ticker_mod {
-    #[macro_export]
-    macro_rules! ticker {
-               ($duration:expr, $close:expr, $body:block) => {
-                       let mut interval = tokio::time::interval($duration);
-
-                       while tokio::select!{
-                               _ = interval.tick() => true,
-                               _ = $close.changed() => false,
-                       } $body
-               };
-       }
-}
index fd6f299681b062baf5980a5055219da9823e6582..34e273c08f3d3cc6e03eb0d355691ab847756c4f 100644 (file)
@@ -3,14 +3,16 @@ use async_recursion::async_recursion;
 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
 use std::{
     borrow::Cow,
-    cell::OnceCell,
-    collections::HashMap,
+    collections::{HashMap, VecDeque},
     io,
     pin::Pin,
     sync::Arc,
     time::{Duration, Instant},
 };
-use tokio::sync::{mpsc, watch, Mutex};
+use tokio::{
+    sync::watch,
+    time::{interval, sleep, Interval, Sleep},
+};
 
 fn to_seqnum(seqnum: u16) -> usize {
     (seqnum as usize) & (REL_BUFFER - 1)
@@ -21,7 +23,7 @@ type Result<T> = std::result::Result<T, Error>;
 #[derive(Debug)]
 struct Split {
     timestamp: Option<Instant>,
-    chunks: Vec<OnceCell<Vec<u8>>>,
+    chunks: Vec<Option<Vec<u8>>>,
     got: usize,
 }
 
@@ -29,61 +31,82 @@ struct RecvChan {
     packets: Vec<Option<Vec<u8>>>, // char ** ðŸ˜›
     splits: HashMap<u16, Split>,
     seqnum: u16,
-    num: u8,
 }
 
-pub(crate) struct RecvWorker<R: UdpReceiver, S: UdpSender> {
-    share: Arc<RudpShare<S>>,
+pub struct RudpReceiver<P: UdpPeer> {
+    pub(crate) share: Arc<RudpShare<P>>,
+    chans: [RecvChan; NUM_CHANS],
+    udp: P::Receiver,
     close: watch::Receiver<bool>,
-    chans: Arc<Vec<Mutex<RecvChan>>>,
-    pkt_tx: mpsc::UnboundedSender<InPkt>,
-    udp_rx: R,
+    closed: bool,
+    resend: Interval,
+    ping: Interval,
+    cleanup: Interval,
+    timeout: Pin<Box<Sleep>>,
+    queue: VecDeque<Result<Pkt<'static>>>,
 }
 
-impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
-    pub fn new(
-        udp_rx: R,
-        share: Arc<RudpShare<S>>,
+impl<P: UdpPeer> RudpReceiver<P> {
+    pub(crate) fn new(
+        udp: P::Receiver,
+        share: Arc<RudpShare<P>>,
         close: watch::Receiver<bool>,
-        pkt_tx: mpsc::UnboundedSender<InPkt>,
     ) -> Self {
         Self {
-            udp_rx,
+            udp,
             share,
             close,
-            pkt_tx,
-            chans: Arc::new(
-                (0..NUM_CHANS as u8)
-                    .map(|num| {
-                        Mutex::new(RecvChan {
-                            num,
-                            packets: (0..REL_BUFFER).map(|_| None).collect(),
-                            seqnum: INIT_SEQNUM,
-                            splits: HashMap::new(),
-                        })
-                    })
-                    .collect(),
-            ),
+            closed: false,
+            resend: interval(Duration::from_millis(500)),
+            ping: interval(Duration::from_secs(PING_TIMEOUT)),
+            cleanup: interval(Duration::from_secs(TIMEOUT)),
+            timeout: Box::pin(sleep(Duration::from_secs(TIMEOUT))),
+            chans: std::array::from_fn(|_| RecvChan {
+                packets: (0..REL_BUFFER).map(|_| None).collect(),
+                seqnum: INIT_SEQNUM,
+                splits: HashMap::new(),
+            }),
+            queue: VecDeque::new(),
+        }
+    }
+
+    fn handle_err(&mut self, res: Result<()>) -> Result<()> {
+        use Error::*;
+
+        match res {
+            Err(RemoteDisco(_)) | Err(LocalDisco) => {
+                self.closed = true;
+                res
+            }
+            Ok(_) => res,
+            Err(e) => {
+                self.queue.push_back(Err(e));
+                Ok(())
+            }
         }
     }
 
-    pub async fn run(&self) {
+    pub async fn recv(&mut self) -> Option<Result<Pkt<'static>>> {
         use Error::*;
 
-        let cleanup_chans = Arc::clone(&self.chans);
-        let mut cleanup_close = self.close.clone();
-        self.share
-            .tasks
-            .lock()
-            .await
-            /*.build_task()
-            .name("cleanup_splits")*/
-            .spawn(async move {
-                let timeout = Duration::from_secs(TIMEOUT);
-
-                ticker!(timeout, cleanup_close, {
-                    for chan_mtx in cleanup_chans.iter() {
-                        let mut chan = chan_mtx.lock().await;
+        if self.closed {
+            return None;
+        }
+
+        loop {
+            if let Some(x) = self.queue.pop_front() {
+                return Some(x);
+            }
+
+            tokio::select! {
+                _ = self.close.changed() => {
+                    self.closed = true;
+                    return Some(Err(LocalDisco));
+                },
+                _ = self.cleanup.tick() => {
+                    let timeout = Duration::from_secs(TIMEOUT);
+
+                    for chan in self.chans.iter_mut() {
                         chan.splits = chan
                             .splits
                             .drain_filter(
@@ -91,58 +114,48 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                             )
                             .collect();
                     }
-                });
-            });
-
-        let mut close = self.close.clone();
-        let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT));
-        tokio::pin!(timeout);
-
-        loop {
-            if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) {
-                // TODO: figure out whether this is a good idea
-                if let RemoteDisco(to) = e {
-                    self.pkt_tx.send(Err(RemoteDisco(to))).ok();
+                },
+                _ = self.resend.tick() => {
+                    for chan in self.share.chans.iter() {
+                        for (_, ack) in chan.lock().await.acks.iter() {
+                            self.share.send_raw(&ack.data).await.ok(); // TODO: handle error (?)
+                        }
+                    }
+                },
+                _ = self.ping.tick() => {
+                    self.share
+                        .send(
+                            PktType::Ctl,
+                            Pkt {
+                                chan: 0,
+                                unrel: false,
+                                data: Cow::Borrowed(&[CtlType::Ping as u8]),
+                            },
+                        )
+                        .await
+                        .ok();
+                }
+                _ = &mut self.timeout => {
+                    self.closed = true;
+                    return Some(Err(RemoteDisco(true)));
+                },
+                pkt = self.udp.recv() => {
+                    if let Err(e) = self.handle_pkt(pkt).await {
+                        return Some(Err(e));
+                    }
                 }
-
-                #[allow(clippy::single_match)]
-                match e {
-                                       // anon5's mt notifies the peer on timeout, C++ MT does not
-                                       LocalDisco /*| RemoteDisco(true)*/ => drop(
-                                               self.share
-                                                       .send(
-                                                               PktType::Ctl,
-                                                               Pkt {
-                                                                       unrel: true,
-                                                                       chan: 0,
-                                                                       data: Cow::Borrowed(&[CtlType::Disco as u8]),
-                                                               },
-                                                       )
-                                                       .await
-                                                       .ok(),
-                                       ),
-                                       _ => {}
-                               }
-
-                break;
             }
         }
     }
 
-    async fn recv_pkt(
-        &self,
-        close: &mut watch::Receiver<bool>,
-        timeout: Pin<&mut tokio::time::Sleep>,
-    ) -> Result<()> {
+    async fn handle_pkt(&mut self, pkt: io::Result<Vec<u8>>) -> Result<()> {
         use Error::*;
 
-        let mut cursor = io::Cursor::new(tokio::select! {
-            pkt = self.udp_rx.recv() => pkt?,
-            _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)),
-            _ = close.changed() => return Err(LocalDisco),
-        });
+        let mut cursor = io::Cursor::new(pkt?);
 
-        timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
+        self.timeout
+            .as_mut()
+            .reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
 
         let proto_id = cursor.read_u32::<BigEndian>()?;
         if proto_id != PROTO_ID {
@@ -151,38 +164,34 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
 
         let _peer_id = cursor.read_u16::<BigEndian>()?;
 
-        let n_chan = cursor.read_u8()?;
-        let mut chan = self
-            .chans
-            .get(n_chan as usize)
-            .ok_or(InvalidChannel(n_chan))?
-            .lock()
-            .await;
+        let chan = cursor.read_u8()?;
+        if chan >= NUM_CHANS as u8 {
+            return Err(InvalidChannel(chan));
+        }
+
+        let res = self.process_pkt(cursor, true, chan).await;
+        self.handle_err(res)?;
 
-        self.process_pkt(cursor, true, &mut chan).await
+        Ok(())
     }
 
     #[async_recursion]
     async fn process_pkt(
-        &self,
+        &mut self,
         mut cursor: io::Cursor<Vec<u8>>,
         unrel: bool,
-        chan: &mut RecvChan,
+        chan: u8,
     ) -> Result<()> {
         use Error::*;
 
+        let ch = chan as usize;
         match cursor.read_u8()?.try_into()? {
             PktType::Ctl => match cursor.read_u8()?.try_into()? {
                 CtlType::Ack => {
                     // println!("Ack");
 
                     let seqnum = cursor.read_u16::<BigEndian>()?;
-                    if let Some(ack) = self.share.chans[chan.num as usize]
-                        .lock()
-                        .await
-                        .acks
-                        .remove(&seqnum)
-                    {
+                    if let Some(ack) = self.share.chans[ch].lock().await.acks.remove(&seqnum) {
                         ack.tx.send(true).ok();
                     }
                 }
@@ -208,11 +217,11 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
             PktType::Orig => {
                 // println!("Orig");
 
-                self.pkt_tx.send(Ok(Pkt {
-                    chan: chan.num,
+                self.queue.push_back(Ok(Pkt {
+                    chan,
                     unrel,
                     data: Cow::Owned(cursor.remaining_slice().into()),
-                }))?;
+                }));
             }
             PktType::Split => {
                 // println!("Split");
@@ -221,11 +230,14 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                 let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
                 let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
 
-                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
-                    got: 0,
-                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
-                    timestamp: None,
-                });
+                let mut split = self.chans[ch]
+                    .splits
+                    .entry(seqnum)
+                    .or_insert_with(|| Split {
+                        got: 0,
+                        chunks: (0..chunk_count).map(|_| None).collect(),
+                        timestamp: None,
+                    });
 
                 if split.chunks.len() != chunk_count {
                     return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
@@ -233,10 +245,10 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
 
                 if split
                     .chunks
-                    .get(chunk_index)
+                    .get_mut(chunk_index)
                     .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
-                    .set(cursor.remaining_slice().into())
-                    .is_ok()
+                    .replace(cursor.remaining_slice().into())
+                    .is_none()
                 {
                     split.got += 1;
                 }
@@ -244,25 +256,29 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                 split.timestamp = if unrel { Some(Instant::now()) } else { None };
 
                 if split.got == chunk_count {
-                    self.pkt_tx.send(Ok(Pkt {
-                        chan: chan.num,
+                    let split = self.chans[ch].splits.remove(&seqnum).unwrap();
+
+                    self.queue.push_back(Ok(Pkt {
+                        chan,
                         unrel,
                         data: split
                             .chunks
-                            .iter()
-                            .flat_map(|chunk| chunk.get().unwrap().iter())
-                            .copied()
-                            .collect(),
-                    }))?;
-
-                    chan.splits.remove(&seqnum);
+                            .into_iter()
+                            .map(|x| x.unwrap())
+                            .reduce(|mut a, mut b| {
+                                a.append(&mut b);
+                                a
+                            })
+                            .unwrap_or_default()
+                            .into(),
+                    }));
                 }
             }
             PktType::Rel => {
                 // println!("Rel");
 
                 let seqnum = cursor.read_u16::<BigEndian>()?;
-                chan.packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
+                self.chans[ch].packets[to_seqnum(seqnum)].replace(cursor.remaining_slice().into());
 
                 let mut ack_data = Vec::with_capacity(3);
                 ack_data.write_u8(CtlType::Ack as u8)?;
@@ -272,35 +288,22 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                     .send(
                         PktType::Ctl,
                         Pkt {
+                            chan,
                             unrel: true,
-                            chan: chan.num,
-                            data: Cow::Borrowed(&ack_data),
+                            data: ack_data.into(),
                         },
                     )
                     .await?;
 
-                fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
-                    chan.packets[to_seqnum(chan.seqnum)].take()
-                }
-
-                while let Some(pkt) = next_pkt(chan) {
-                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
-                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
+                let next_pkt = |chan: &mut RecvChan| chan.packets[to_seqnum(chan.seqnum)].take();
+                while let Some(pkt) = next_pkt(&mut self.chans[ch]) {
+                    let res = self.process_pkt(io::Cursor::new(pkt), false, chan).await;
+                    self.handle_err(res)?;
+                    self.chans[ch].seqnum = self.chans[ch].seqnum.overflowing_add(1).0;
                 }
             }
         }
 
         Ok(())
     }
-
-    fn handle(&self, res: Result<()>) -> Result<()> {
-        use Error::*;
-
-        match res {
-            Ok(v) => Ok(v),
-            Err(RemoteDisco(to)) => Err(RemoteDisco(to)),
-            Err(LocalDisco) => Err(LocalDisco),
-            Err(e) => Ok(self.pkt_tx.send(Err(e))?),
-        }
-    }
 }
index a3a7f036e5e2ea4639ed6267cbacf5e4b7441b25..2c449e15ce4e8fb94cad9fe99e47b798aecd1aaf 100644 (file)
@@ -1,17 +1,33 @@
 use super::*;
 use byteorder::{BigEndian, WriteBytesExt};
-use std::io::{self, Write};
+use std::{
+    io::{self, Write},
+    sync::Arc,
+};
 use tokio::sync::watch;
 
 pub type AckResult = io::Result<Option<watch::Receiver<bool>>>;
 
-impl<S: UdpSender> RudpSender<S> {
+pub struct RudpSender<P: UdpPeer> {
+    pub(crate) share: Arc<RudpShare<P>>,
+}
+
+// derive(Clone) adds unwanted Clone trait bound to P parameter
+impl<P: UdpPeer> Clone for RudpSender<P> {
+    fn clone(&self) -> Self {
+        Self {
+            share: Arc::clone(&self.share),
+        }
+    }
+}
+
+impl<P: UdpPeer> RudpSender<P> {
     pub async fn send(&self, pkt: Pkt<'_>) -> AckResult {
         self.share.send(PktType::Orig, pkt).await // TODO: splits
     }
 }
 
-impl<S: UdpSender> RudpShare<S> {
+impl<P: UdpPeer> RudpShare<P> {
     pub async fn send(&self, tp: PktType, pkt: Pkt<'_>) -> AckResult {
         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len());
         buf.write_u32::<BigEndian>(PROTO_ID)?;
index a2afc4c2f9c3bd40ff7c4e91aee5224e61fcbb21..02e37b2dc76de444ffd8770dfc3665b373ab142d 100644 (file)
@@ -1,10 +1,6 @@
 use super::*;
-use drop_bomb::DropBomb;
-use std::{borrow::Cow, collections::HashMap, io, sync::Arc, time::Duration};
-use tokio::{
-    sync::{mpsc, watch, Mutex, RwLock},
-    task::JoinSet,
-};
+use std::{borrow::Cow, collections::HashMap, io, sync::Arc};
+use tokio::sync::{watch, Mutex, RwLock};
 
 #[derive(Debug)]
 pub(crate) struct Ack {
@@ -20,96 +16,72 @@ pub(crate) struct Chan {
 }
 
 #[derive(Debug)]
-pub(crate) struct RudpShare<S: UdpSender> {
+pub(crate) struct RudpShare<P: UdpPeer> {
     pub(crate) id: u16,
     pub(crate) remote_id: RwLock<u16>,
-    pub(crate) chans: Vec<Mutex<Chan>>,
-    pub(crate) udp_tx: S,
-    pub(crate) close_tx: watch::Sender<bool>,
-    pub(crate) tasks: Mutex<JoinSet<()>>,
-    pub(crate) bomb: Mutex<DropBomb>,
+    pub(crate) chans: [Mutex<Chan>; NUM_CHANS],
+    pub(crate) udp_tx: P::Sender,
+    pub(crate) close: watch::Sender<bool>,
 }
 
-pub async fn new<S: UdpSender, R: UdpReceiver>(
+pub async fn new<P: UdpPeer>(
     id: u16,
     remote_id: u16,
-    udp_tx: S,
-    udp_rx: R,
-) -> io::Result<(RudpSender<S>, RudpReceiver<S>)> {
-    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
+    udp_tx: P::Sender,
+    udp_rx: P::Receiver,
+) -> io::Result<(RudpSender<P>, RudpReceiver<P>)> {
     let (close_tx, close_rx) = watch::channel(false);
 
     let share = Arc::new(RudpShare {
         id,
         remote_id: RwLock::new(remote_id),
         udp_tx,
-        close_tx,
-        chans: (0..NUM_CHANS)
-            .map(|_| {
-                Mutex::new(Chan {
-                    acks: HashMap::new(),
-                    seqnum: INIT_SEQNUM,
-                })
+        close: close_tx,
+        chans: std::array::from_fn(|_| {
+            Mutex::new(Chan {
+                acks: HashMap::new(),
+                seqnum: INIT_SEQNUM,
             })
-            .collect(),
-        tasks: Mutex::new(JoinSet::new()),
-        bomb: Mutex::new(DropBomb::new("rudp connection must be explicitly closed")),
+        }),
     });
 
-    let mut tasks = share.tasks.lock().await;
+    Ok((
+        RudpSender {
+            share: Arc::clone(&share),
+        },
+        RudpReceiver::new(udp_rx, share, close_rx),
+    ))
+}
 
-    let recv_share = Arc::clone(&share);
-    let recv_close = close_rx.clone();
-    tasks
-        /*.build_task()
-        .name("recv")*/
-        .spawn(async move {
-            let worker = RecvWorker::new(udp_rx, recv_share, recv_close, pkt_tx);
-            worker.run().await;
-        });
+macro_rules! impl_share {
+    ($T:ident) => {
+        impl<P: UdpPeer> $T<P> {
+            pub async fn peer_id(&self) -> u16 {
+                self.share.id
+            }
 
-    let resend_share = Arc::clone(&share);
-    let mut resend_close = close_rx.clone();
-    tasks
-        /*.build_task()
-        .name("resend")*/
-        .spawn(async move {
-            ticker!(Duration::from_millis(500), resend_close, {
-                for chan in resend_share.chans.iter() {
-                    for (_, ack) in chan.lock().await.acks.iter() {
-                        resend_share.send_raw(&ack.data).await.ok(); // TODO: handle error (?)
-                    }
-                }
-            });
-        });
+            pub async fn is_server(&self) -> bool {
+                self.share.id == PeerID::Srv as u16
+            }
 
-    let ping_share = Arc::clone(&share);
-    let mut ping_close = close_rx.clone();
-    tasks
-        /*.build_task()
-        .name("ping")*/
-        .spawn(async move {
-            ticker!(Duration::from_secs(PING_TIMEOUT), ping_close, {
-                ping_share
+            pub async fn close(self) {
+                self.share.close.send(true).ok(); // FIXME: handle err?
+
+                self.share
                     .send(
                         PktType::Ctl,
                         Pkt {
+                            unrel: true,
                             chan: 0,
-                            unrel: false,
-                            data: Cow::Borrowed(&[CtlType::Ping as u8]),
+                            data: Cow::Borrowed(&[CtlType::Disco as u8]),
                         },
                     )
                     .await
                     .ok();
-            });
-        });
-
-    drop(tasks);
-
-    Ok((
-        RudpSender {
-            share: Arc::clone(&share),
-        },
-        RudpReceiver { share, pkt_rx },
-    ))
+            }
+        }
+    };
 }
+
+impl_share!(RudpReceiver);
+impl_share!(RudpSender);