]> git.lizzy.rs Git - mt_rudp.git/commitdiff
clean shutdown; send reliables
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Fri, 6 Jan 2023 16:45:16 +0000 (17:45 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Fri, 6 Jan 2023 16:45:16 +0000 (17:45 +0100)
Cargo.toml
src/client.rs
src/main.rs
src/new.rs [new file with mode: 0644]
src/recv.rs [new file with mode: 0644]
src/recv_worker.rs [deleted file]
src/send.rs [new file with mode: 0644]

index edcdf2ed01a0e4caa3ed78d97fa14eeb247a3443..7772e3b602b807ac700de8c216b397b3cdf5e3fa 100644 (file)
@@ -8,4 +8,5 @@ async-recursion = "1.0.0"
 async-trait = "0.1.60"
 byteorder = "1.4.3"
 num_enum = "0.5.7"
+pretty-hex = "0.3.0"
 tokio = { version = "1.23.0", features = ["full"] }
index d416e53a51e0801c9057b123cca983f3208d4133..172aa9649278771689a549c9ca44dc10d6a62dfd 100644 (file)
@@ -8,8 +8,8 @@ pub struct Sender {
 
 #[async_trait]
 impl UdpSender for Sender {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()> {
-        self.sock.send(&data).await?;
+    async fn send(&self, data: &[u8]) -> io::Result<()> {
+        self.sock.send(data).await?;
         Ok(())
     }
 }
@@ -42,5 +42,6 @@ pub async fn connect(addr: &str) -> io::Result<(RudpSender<Sender>, RudpReceiver
             sock: Arc::clone(&sock),
         },
         Receiver { sock },
-    ))
+    )
+    .await?)
 }
index 1f0aca01f1a53a87b617be8cfc1cff2a0be890d6..0510db765d994ce6e5dad09d3aafe5b93ed40282 100644 (file)
@@ -3,19 +3,27 @@
 #![feature(once_cell)]
 mod client;
 pub mod error;
-mod recv_worker;
+mod new;
+mod recv;
+mod send;
 
 use async_trait::async_trait;
 use byteorder::{BigEndian, WriteBytesExt};
 pub use client::{connect, Sender as Client};
+pub use new::new;
 use num_enum::TryFromPrimitive;
+use pretty_hex::PrettyHex;
 use std::{
     collections::HashMap,
     io::{self, Write},
     ops,
     sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    sync::{mpsc, watch, Mutex, RwLock},
+    task::JoinSet,
 };
-use tokio::sync::{mpsc, watch, Mutex, RwLock};
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
@@ -24,9 +32,25 @@ pub const REL_BUFFER: usize = 0x8000;
 pub const INIT_SEQNUM: u16 = 65500;
 pub const TIMEOUT: u64 = 30;
 
+mod ticker_mod {
+    #[macro_export]
+    macro_rules! ticker {
+               ($duration:expr, $close:expr, $body:block) => {
+                       let mut interval = tokio::time::interval($duration);
+
+                       while tokio::select!{
+                               _ = interval.tick() => true,
+                               _ = $close.changed() => false,
+                       } $body
+               };
+       }
+
+    //pub(crate) use ticker;
+}
+
 #[async_trait]
 pub trait UdpSender: Send + Sync + 'static {
-    async fn send(&self, data: Vec<u8>) -> io::Result<()>;
+    async fn send(&self, data: &[u8]) -> io::Result<()>;
 }
 
 #[async_trait]
@@ -69,14 +93,28 @@ pub struct Pkt<T> {
 
 pub type Error = error::Error;
 pub type InPkt = Result<Pkt<Vec<u8>>, Error>;
-type AckChan = (watch::Sender<bool>, watch::Receiver<bool>);
+
+#[derive(Debug)]
+struct Ack {
+    tx: watch::Sender<bool>,
+    rx: watch::Receiver<bool>,
+    data: Vec<u8>,
+}
+
+#[derive(Debug)]
+struct Chan {
+    acks: HashMap<u16, Ack>,
+    seqnum: u16,
+}
 
 #[derive(Debug)]
 pub struct RudpShare<S: UdpSender> {
-    pub id: u16,
-    pub remote_id: RwLock<u16>,
-    pub ack_chans: Mutex<HashMap<u16, AckChan>>,
+    id: u16,
+    remote_id: RwLock<u16>,
+    chans: Vec<Mutex<Chan>>,
     udp_tx: S,
+    close_tx: watch::Sender<bool>,
+    tasks: Mutex<JoinSet<()>>,
 }
 
 #[derive(Debug)]
@@ -90,44 +128,31 @@ pub struct RudpSender<S: UdpSender> {
     share: Arc<RudpShare<S>>,
 }
 
-impl<S: UdpSender> RudpShare<S> {
-    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
-        buf.write_u32::<BigEndian>(PROTO_ID)?;
-        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
-        buf.write_u8(pkt.chan as u8)?;
-        buf.write_u8(tp as u8)?;
-        buf.write(pkt.data)?;
-
-        self.udp_tx.send(buf).await?;
-
-        Ok(())
-    }
-}
+macro_rules! impl_share {
+    ($T:ident) => {
+        impl<S: UdpSender> $T<S> {
+            pub async fn peer_id(&self) -> u16 {
+                self.share.id
+            }
 
-impl<S: UdpSender> RudpSender<S> {
-    pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        self.share.send(PktType::Orig, pkt).await // TODO
-    }
+            pub async fn is_server(&self) -> bool {
+                self.share.id == PeerID::Srv as u16
+            }
 
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
+            pub async fn close(self) {
+                self.share.close_tx.send(true).ok();
 
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
+                let mut tasks = self.share.tasks.lock().await;
+                while let Some(res) = tasks.join_next().await {
+                    res.ok(); // TODO: handle error (?)
+                }
+            }
+        }
+    };
 }
 
-impl<S: UdpSender> RudpReceiver<S> {
-    pub async fn peer_id(&self) -> u16 {
-        self.share.id
-    }
-
-    pub async fn is_server(&self) -> bool {
-        self.share.id == PeerID::Srv as u16
-    }
-}
+impl_share!(RudpReceiver);
+impl_share!(RudpSender);
 
 impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
     type Target = mpsc::UnboundedReceiver<InPkt>;
@@ -143,49 +168,16 @@ impl<S: UdpSender> ops::DerefMut for RudpReceiver<S> {
     }
 }
 
-pub fn new<S: UdpSender, R: UdpReceiver>(
-    id: u16,
-    remote_id: u16,
-    udp_tx: S,
-    udp_rx: R,
-) -> (RudpSender<S>, RudpReceiver<S>) {
-    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
-
-    let share = Arc::new(RudpShare {
-        id,
-        remote_id: RwLock::new(remote_id),
-        udp_tx,
-        ack_chans: Mutex::new(HashMap::new()),
-    });
-    let recv_share = Arc::clone(&share);
-
-    tokio::spawn(async {
-        let worker = recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx);
-        worker.run().await;
-    });
-
-    (
-        RudpSender {
-            share: Arc::clone(&share),
-        },
-        RudpReceiver { share, pkt_rx },
-    )
-}
-
-// connect
-
-#[tokio::main]
-async fn main() -> io::Result<()> {
-    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
-
+async fn example(tx: &RudpSender<Client>, rx: &mut RudpReceiver<Client>) -> io::Result<()> {
+    // send hello packet
     let mut mtpkt = vec![];
     mtpkt.write_u16::<BigEndian>(2)?; // high level type
     mtpkt.write_u8(29)?; // serialize ver
     mtpkt.write_u16::<BigEndian>(0)?; // compression modes
     mtpkt.write_u16::<BigEndian>(40)?; // MinProtoVer
     mtpkt.write_u16::<BigEndian>(40)?; // MaxProtoVer
-    mtpkt.write_u16::<BigEndian>(3)?; // player name length
-    mtpkt.write(b"foo")?; // player name
+    mtpkt.write_u16::<BigEndian>(6)?; // player name length
+    mtpkt.write(b"foobar")?; // player name
 
     tx.send(Pkt {
         unrel: true,
@@ -194,17 +186,34 @@ async fn main() -> io::Result<()> {
     })
     .await?;
 
+    // handle incoming packets
     while let Some(result) = rx.recv().await {
         match result {
             Ok(pkt) => {
-                io::stdout().write(pkt.data.as_slice())?;
+                println!("{}", pkt.data.hex_dump());
             }
             Err(err) => eprintln!("Error: {}", err),
         }
     }
-    println!("disco");
 
-    // close()ing rx is not needed because it has been consumed to the end
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> io::Result<()> {
+    let (tx, mut rx) = connect("127.0.0.1:30000").await?;
+
+    tokio::select! {
+        _ = tokio::signal::ctrl_c() => println!("canceled"),
+        res = example(&tx, &mut rx) => {
+            res?;
+            println!("disconnected");
+        }
+    }
+
+    // close either the receiver or the sender
+    // this shuts down associated tasks
+    rx.close().await;
 
     Ok(())
 }
diff --git a/src/new.rs b/src/new.rs
new file mode 100644 (file)
index 0000000..a70b117
--- /dev/null
@@ -0,0 +1,63 @@
+use crate::*;
+
+pub async fn new<S: UdpSender, R: UdpReceiver>(
+    id: u16,
+    remote_id: u16,
+    udp_tx: S,
+    udp_rx: R,
+) -> io::Result<(RudpSender<S>, RudpReceiver<S>)> {
+    let (pkt_tx, pkt_rx) = mpsc::unbounded_channel();
+    let (close_tx, close_rx) = watch::channel(false);
+
+    let share = Arc::new(RudpShare {
+        id,
+        remote_id: RwLock::new(remote_id),
+        udp_tx,
+        close_tx,
+        chans: (0..NUM_CHANS)
+            .map(|_| {
+                Mutex::new(Chan {
+                    acks: HashMap::new(),
+                    seqnum: INIT_SEQNUM,
+                })
+            })
+            .collect(),
+        tasks: Mutex::new(JoinSet::new()),
+    });
+
+    let mut tasks = share.tasks.lock().await;
+
+    let recv_share = Arc::clone(&share);
+    let recv_close = close_rx.clone();
+    tasks
+        /*.build_task()
+        .name("recv")*/
+        .spawn(async move {
+            let worker = recv::RecvWorker::new(udp_rx, recv_share, recv_close, pkt_tx);
+            worker.run().await;
+        });
+
+    let resend_share = Arc::clone(&share);
+    let mut resend_close = close_rx.clone();
+    tasks
+        /*.build_task()
+        .name("resend")*/
+        .spawn(async move {
+            ticker!(Duration::from_millis(500), resend_close, {
+                for chan in resend_share.chans.iter() {
+                    for (_, ack) in chan.lock().await.acks.iter() {
+                        resend_share.send_raw(&ack.data).await.ok(); // TODO: handle error (?)
+                    }
+                }
+            });
+        });
+
+    drop(tasks);
+
+    Ok((
+        RudpSender {
+            share: Arc::clone(&share),
+        },
+        RudpReceiver { share, pkt_rx },
+    ))
+}
diff --git a/src/recv.rs b/src/recv.rs
new file mode 100644 (file)
index 0000000..15811f2
--- /dev/null
@@ -0,0 +1,283 @@
+use crate::{error::Error, *};
+use async_recursion::async_recursion;
+use byteorder::{BigEndian, ReadBytesExt};
+use std::{
+    cell::{Cell, OnceCell},
+    collections::HashMap,
+    io,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use tokio::sync::{mpsc, Mutex};
+
+fn to_seqnum(seqnum: u16) -> usize {
+    (seqnum as usize) & (REL_BUFFER - 1)
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+struct Split {
+    timestamp: Option<Instant>,
+    chunks: Vec<OnceCell<Vec<u8>>>,
+    got: usize,
+}
+
+struct RecvChan {
+    packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
+    splits: HashMap<u16, Split>,
+    seqnum: u16,
+    num: u8,
+}
+
+pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
+    share: Arc<RudpShare<S>>,
+    close: watch::Receiver<bool>,
+    chans: Arc<Vec<Mutex<RecvChan>>>,
+    pkt_tx: mpsc::UnboundedSender<InPkt>,
+    udp_rx: R,
+}
+
+impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
+    pub fn new(
+        udp_rx: R,
+        share: Arc<RudpShare<S>>,
+        close: watch::Receiver<bool>,
+        pkt_tx: mpsc::UnboundedSender<InPkt>,
+    ) -> Self {
+        Self {
+            udp_rx,
+            share,
+            close,
+            pkt_tx,
+            chans: Arc::new(
+                (0..NUM_CHANS as u8)
+                    .map(|num| {
+                        Mutex::new(RecvChan {
+                            num,
+                            packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
+                            seqnum: INIT_SEQNUM,
+                            splits: HashMap::new(),
+                        })
+                    })
+                    .collect(),
+            ),
+        }
+    }
+
+    pub async fn run(&self) {
+        let cleanup_chans = Arc::clone(&self.chans);
+        let mut cleanup_close = self.close.clone();
+        self.share
+            .tasks
+            .lock()
+            .await
+            /*.build_task()
+            .name("cleanup_splits")*/
+            .spawn(async move {
+                let timeout = Duration::from_secs(TIMEOUT);
+
+                ticker!(timeout, cleanup_close, {
+                    for chan_mtx in cleanup_chans.iter() {
+                        let mut chan = chan_mtx.lock().await;
+                        chan.splits = chan
+                            .splits
+                            .drain_filter(
+                                |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
+                            )
+                            .collect();
+                    }
+                });
+            });
+
+        let mut close = self.close.clone();
+        loop {
+            if let Err(e) = self.handle(self.recv_pkt(&mut close).await) {
+                if let Error::LocalDisco = e {
+                    self.share
+                        .send(
+                            PktType::Ctl,
+                            Pkt {
+                                unrel: true,
+                                chan: 0,
+                                data: &[CtlType::Disco as u8],
+                            },
+                        )
+                        .await
+                        .ok();
+                }
+                break;
+            }
+        }
+    }
+
+    async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> {
+        use Error::*;
+
+        // TODO: reset timeout
+        let mut cursor = io::Cursor::new(tokio::select! {
+            pkt = self.udp_rx.recv() => pkt?,
+            _ = close.changed() => return Err(LocalDisco),
+        });
+
+        println!("recv");
+
+        let proto_id = cursor.read_u32::<BigEndian>()?;
+        if proto_id != PROTO_ID {
+            return Err(InvalidProtoId(proto_id));
+        }
+
+        let _peer_id = cursor.read_u16::<BigEndian>()?;
+
+        let n_chan = cursor.read_u8()?;
+        let mut chan = self
+            .chans
+            .get(n_chan as usize)
+            .ok_or(InvalidChannel(n_chan))?
+            .lock()
+            .await;
+
+        self.process_pkt(cursor, true, &mut chan).await
+    }
+
+    #[async_recursion]
+    async fn process_pkt(
+        &self,
+        mut cursor: io::Cursor<Vec<u8>>,
+        unrel: bool,
+        chan: &mut RecvChan,
+    ) -> Result<()> {
+        use Error::*;
+
+        match cursor.read_u8()?.try_into()? {
+            PktType::Ctl => match cursor.read_u8()?.try_into()? {
+                CtlType::Ack => {
+                    println!("Ack");
+
+                    let seqnum = cursor.read_u16::<BigEndian>()?;
+                    if let Some(ack) = self.share.chans[chan.num as usize]
+                        .lock()
+                        .await
+                        .acks
+                        .remove(&seqnum)
+                    {
+                        ack.tx.send(true).ok();
+                    }
+                }
+                CtlType::SetPeerID => {
+                    println!("SetPeerID");
+
+                    let mut id = self.share.remote_id.write().await;
+
+                    if *id != PeerID::Nil as u16 {
+                        return Err(PeerIDAlreadySet);
+                    }
+
+                    *id = cursor.read_u16::<BigEndian>()?;
+                }
+                CtlType::Ping => {
+                    println!("Ping");
+                }
+                CtlType::Disco => {
+                    println!("Disco");
+                    return Err(RemoteDisco);
+                }
+            },
+            PktType::Orig => {
+                println!("Orig");
+
+                self.pkt_tx.send(Ok(Pkt {
+                    chan: chan.num,
+                    unrel,
+                    data: cursor.remaining_slice().into(),
+                }))?;
+            }
+            PktType::Split => {
+                println!("Split");
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
+                let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
+
+                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
+                    got: 0,
+                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
+                    timestamp: None,
+                });
+
+                if split.chunks.len() != chunk_count {
+                    return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
+                }
+
+                if split
+                    .chunks
+                    .get(chunk_index)
+                    .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
+                    .set(cursor.remaining_slice().into())
+                    .is_ok()
+                {
+                    split.got += 1;
+                }
+
+                split.timestamp = if unrel { Some(Instant::now()) } else { None };
+
+                if split.got == chunk_count {
+                    self.pkt_tx.send(Ok(Pkt {
+                        chan: chan.num,
+                        unrel,
+                        data: split
+                            .chunks
+                            .iter()
+                            .flat_map(|chunk| chunk.get().unwrap().iter())
+                            .copied()
+                            .collect(),
+                    }))?;
+
+                    chan.splits.remove(&seqnum);
+                }
+            }
+            PktType::Rel => {
+                println!("Rel");
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
+
+                let mut ack_data = Vec::with_capacity(3);
+                ack_data.write_u8(CtlType::Ack as u8)?;
+                ack_data.write_u16::<BigEndian>(seqnum)?;
+
+                self.share
+                    .send(
+                        PktType::Ctl,
+                        Pkt {
+                            unrel: true,
+                            chan: chan.num,
+                            data: &ack_data,
+                        },
+                    )
+                    .await?;
+
+                fn next_pkt(chan: &mut RecvChan) -> Option<Vec<u8>> {
+                    chan.packets[to_seqnum(chan.seqnum)].take()
+                }
+
+                while let Some(pkt) = next_pkt(chan) {
+                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
+                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn handle(&self, res: Result<()>) -> Result<()> {
+        use Error::*;
+
+        match res {
+            Ok(v) => Ok(v),
+            Err(RemoteDisco) => Err(RemoteDisco),
+            Err(LocalDisco) => Err(LocalDisco),
+            Err(e) => Ok(self.pkt_tx.send(Err(e))?),
+        }
+    }
+}
diff --git a/src/recv_worker.rs b/src/recv_worker.rs
deleted file mode 100644 (file)
index 83b3273..0000000
+++ /dev/null
@@ -1,256 +0,0 @@
-use crate::{error::Error, *};
-use async_recursion::async_recursion;
-use byteorder::{BigEndian, ReadBytesExt};
-use std::{
-    cell::{Cell, OnceCell},
-    collections::HashMap,
-    io,
-    sync::{Arc, Weak},
-    time,
-};
-use tokio::sync::{mpsc, Mutex};
-
-fn to_seqnum(seqnum: u16) -> usize {
-    (seqnum as usize) & (REL_BUFFER - 1)
-}
-
-type Result<T> = std::result::Result<T, Error>;
-
-struct Split {
-    timestamp: Option<time::Instant>,
-    chunks: Vec<OnceCell<Vec<u8>>>,
-    got: usize,
-}
-
-struct Chan {
-    packets: Vec<Cell<Option<Vec<u8>>>>, // char ** ðŸ˜›
-    splits: HashMap<u16, Split>,
-    seqnum: u16,
-    num: u8,
-}
-
-pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
-    share: Arc<RudpShare<S>>,
-    chans: Arc<Vec<Mutex<Chan>>>,
-    pkt_tx: mpsc::UnboundedSender<InPkt>,
-    udp_rx: R,
-}
-
-impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
-    pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: mpsc::UnboundedSender<InPkt>) -> Self {
-        Self {
-            udp_rx,
-            share,
-            pkt_tx,
-            chans: Arc::new(
-                (0..NUM_CHANS as u8)
-                    .map(|num| {
-                        Mutex::new(Chan {
-                            num,
-                            packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
-                            seqnum: INIT_SEQNUM,
-                            splits: HashMap::new(),
-                        })
-                    })
-                    .collect(),
-            ),
-        }
-    }
-
-    pub async fn run(&self) {
-        let cleanup_chans = Arc::downgrade(&self.chans);
-        tokio::spawn(async move {
-            let timeout = time::Duration::from_secs(TIMEOUT);
-            let mut interval = tokio::time::interval(timeout);
-
-            while let Some(chans) = Weak::upgrade(&cleanup_chans) {
-                for chan in chans.iter() {
-                    let mut ch = chan.lock().await;
-                    ch.splits = ch
-                        .splits
-                        .drain_filter(
-                            |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
-                        )
-                        .collect();
-                }
-
-                interval.tick().await;
-            }
-        });
-
-        loop {
-            if let Err(e) = self.handle(self.recv_pkt().await) {
-                if let Error::LocalDisco = e {
-                    self.share
-                        .send(
-                            PktType::Ctl,
-                            Pkt {
-                                unrel: true,
-                                chan: 0,
-                                data: &[CtlType::Disco as u8],
-                            },
-                        )
-                        .await
-                        .ok();
-                }
-                break;
-            }
-        }
-    }
-
-    async fn recv_pkt(&self) -> Result<()> {
-        use Error::*;
-
-        // todo: reset timeout
-        let mut cursor = io::Cursor::new(self.udp_rx.recv().await?);
-
-        let proto_id = cursor.read_u32::<BigEndian>()?;
-        if proto_id != PROTO_ID {
-            return Err(InvalidProtoId(proto_id));
-        }
-
-        let _peer_id = cursor.read_u16::<BigEndian>()?;
-
-        let n_chan = cursor.read_u8()?;
-        let mut chan = self
-            .chans
-            .get(n_chan as usize)
-            .ok_or(InvalidChannel(n_chan))?
-            .lock()
-            .await;
-
-        self.process_pkt(cursor, true, &mut chan).await
-    }
-
-    #[async_recursion]
-    async fn process_pkt(
-        &self,
-        mut cursor: io::Cursor<Vec<u8>>,
-        unrel: bool,
-        chan: &mut Chan,
-    ) -> Result<()> {
-        use Error::*;
-
-        match cursor.read_u8()?.try_into()? {
-            PktType::Ctl => match cursor.read_u8()?.try_into()? {
-                CtlType::Ack => {
-                    let seqnum = cursor.read_u16::<BigEndian>()?;
-                    if let Some((tx, _)) = self.share.ack_chans.lock().await.remove(&seqnum) {
-                        tx.send(true).ok();
-                    }
-                }
-                CtlType::SetPeerID => {
-                    let mut id = self.share.remote_id.write().await;
-
-                    if *id != PeerID::Nil as u16 {
-                        return Err(PeerIDAlreadySet);
-                    }
-
-                    *id = cursor.read_u16::<BigEndian>()?;
-                }
-                CtlType::Ping => {}
-                CtlType::Disco => return Err(RemoteDisco),
-            },
-            PktType::Orig => {
-                println!("Orig");
-
-                self.pkt_tx.send(Ok(Pkt {
-                    chan: chan.num,
-                    unrel,
-                    data: cursor.remaining_slice().into(),
-                }))?;
-            }
-            PktType::Split => {
-                println!("Split");
-
-                let seqnum = cursor.read_u16::<BigEndian>()?;
-                let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
-                let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
-
-                let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
-                    got: 0,
-                    chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
-                    timestamp: None,
-                });
-
-                if split.chunks.len() != chunk_count {
-                    return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
-                }
-
-                if split
-                    .chunks
-                    .get(chunk_index)
-                    .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
-                    .set(cursor.remaining_slice().into())
-                    .is_ok()
-                {
-                    split.got += 1;
-                }
-
-                split.timestamp = if unrel {
-                    Some(time::Instant::now())
-                } else {
-                    None
-                };
-
-                if split.got == chunk_count {
-                    self.pkt_tx.send(Ok(Pkt {
-                        chan: chan.num,
-                        unrel,
-                        data: split
-                            .chunks
-                            .iter()
-                            .flat_map(|chunk| chunk.get().unwrap().iter())
-                            .copied()
-                            .collect(),
-                    }))?;
-
-                    chan.splits.remove(&seqnum);
-                }
-            }
-            PktType::Rel => {
-                println!("Rel");
-
-                let seqnum = cursor.read_u16::<BigEndian>()?;
-                chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
-
-                let mut ack_data = Vec::with_capacity(3);
-                ack_data.write_u8(CtlType::Ack as u8)?;
-                ack_data.write_u16::<BigEndian>(seqnum)?;
-
-                self.share
-                    .send(
-                        PktType::Ctl,
-                        Pkt {
-                            unrel: true,
-                            chan: chan.num,
-                            data: &ack_data,
-                        },
-                    )
-                    .await?;
-
-                fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
-                    chan.packets[to_seqnum(chan.seqnum)].take()
-                }
-
-                while let Some(pkt) = next_pkt(chan) {
-                    self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
-                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    fn handle(&self, res: Result<()>) -> Result<()> {
-        use Error::*;
-
-        match res {
-            Ok(v) => Ok(v),
-            Err(RemoteDisco) => Err(RemoteDisco),
-            Err(LocalDisco) => Err(LocalDisco),
-            Err(e) => Ok(self.pkt_tx.send(Err(e))?),
-        }
-    }
-}
diff --git a/src/send.rs b/src/send.rs
new file mode 100644 (file)
index 0000000..89c15c7
--- /dev/null
@@ -0,0 +1,55 @@
+use crate::*;
+use tokio::sync::watch;
+
+type AckResult = io::Result<Option<watch::Receiver<bool>>>;
+
+impl<S: UdpSender> RudpSender<S> {
+    pub async fn send(&self, pkt: Pkt<&[u8]>) -> AckResult {
+        self.share.send(PktType::Orig, pkt).await // TODO: splits
+    }
+}
+
+impl<S: UdpSender> RudpShare<S> {
+    pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> AckResult {
+        let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + 2 + 1 + pkt.data.len());
+        buf.write_u32::<BigEndian>(PROTO_ID)?;
+        buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
+        buf.write_u8(pkt.chan as u8)?;
+
+        let mut chan = self.chans[pkt.chan as usize].lock().await;
+        let seqnum = chan.seqnum;
+
+        if !pkt.unrel {
+            buf.write_u8(PktType::Rel as u8)?;
+            buf.write_u16::<BigEndian>(seqnum)?;
+        }
+
+        buf.write_u8(tp as u8)?;
+        buf.write(pkt.data)?;
+
+        self.send_raw(&buf).await?;
+
+        if pkt.unrel {
+            Ok(None)
+        } else {
+            // TODO: reliable window
+            let (tx, rx) = watch::channel(false);
+            chan.acks.insert(
+                seqnum,
+                Ack {
+                    tx,
+                    rx: rx.clone(),
+                    data: buf,
+                },
+            );
+            chan.seqnum += 1;
+
+            Ok(Some(rx))
+        }
+    }
+
+    pub async fn send_raw(&self, data: &[u8]) -> io::Result<()> {
+        self.udp_tx.send(data).await
+        // TODO: reset ping timeout
+    }
+}