]> git.lizzy.rs Git - mt_rudp.git/commitdiff
files
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Thu, 22 Dec 2022 22:02:33 +0000 (23:02 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Thu, 22 Dec 2022 22:02:33 +0000 (23:02 +0100)
src/client.rs [new file with mode: 0644]
src/error.rs [new file with mode: 0644]
src/main.rs
src/recv_worker.rs [new file with mode: 0644]

diff --git a/src/client.rs b/src/client.rs
new file mode 100644 (file)
index 0000000..e506a3e
--- /dev/null
@@ -0,0 +1,52 @@
+use crate::{PeerID, UdpReceiver, UdpSender};
+use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
+use std::{
+    cell::Cell,
+    fmt,
+    io::{self, Write},
+    net, ops,
+    sync::{mpsc, Arc},
+    thread,
+};
+
+pub struct Sender {
+    sock: Arc<net::UdpSocket>,
+}
+
+impl UdpSender for Sender {
+    fn send(&self, data: Vec<u8>) -> io::Result<()> {
+        self.sock.send(&data)?;
+        Ok(())
+    }
+}
+
+pub struct Receiver {
+    sock: Arc<net::UdpSocket>,
+}
+
+impl UdpReceiver for Receiver {
+    fn recv(&self) -> io::Result<Vec<u8>> {
+        let mut buffer = Vec::new();
+        buffer.resize(crate::UDP_PKT_SIZE, 0);
+
+        let len = self.sock.recv(&mut buffer)?;
+        buffer.truncate(len);
+
+        Ok(buffer)
+    }
+}
+
+pub fn connect(addr: &str) -> io::Result<(crate::RudpSender<Sender>, crate::RudpReceiver<Sender>)> {
+    let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0")?);
+    sock.connect(addr)?;
+
+    Ok(crate::new(
+        PeerID::Srv as u16,
+        PeerID::Nil as u16,
+        Sender {
+            sock: Arc::clone(&sock),
+        },
+        Receiver { sock },
+    ))
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644 (file)
index 0000000..02080c7
--- /dev/null
@@ -0,0 +1,57 @@
+use crate::{CtlType, InPkt, PktType};
+use num_enum::TryFromPrimitiveError;
+use std::{fmt, io, sync::mpsc};
+
+#[derive(Debug)]
+pub enum Error {
+    IoError(io::Error),
+    InvalidProtoId(u32),
+    InvalidPeerID,
+    InvalidChannel(u8),
+    InvalidType(u8),
+    InvalidCtlType(u8),
+    RemoteDisco,
+    LocalDisco,
+}
+
+impl From<io::Error> for Error {
+    fn from(err: io::Error) -> Self {
+        Self::IoError(err)
+    }
+}
+
+impl From<TryFromPrimitiveError<PktType>> for Error {
+    fn from(err: TryFromPrimitiveError<PktType>) -> Self {
+        Self::InvalidType(err.number)
+    }
+}
+
+impl From<TryFromPrimitiveError<CtlType>> for Error {
+    fn from(err: TryFromPrimitiveError<CtlType>) -> Self {
+        Self::InvalidType(err.number)
+    }
+}
+
+impl From<mpsc::SendError<InPkt>> for Error {
+    fn from(_err: mpsc::SendError<InPkt>) -> Self {
+        Self::LocalDisco // technically not a disconnect but a local drop
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        use Error::*;
+        write!(f, "RUDP Error: ")?;
+
+        match self {
+            IoError(err) => write!(f, "IO Error: {}", err),
+            InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
+            InvalidPeerID => write!(f, "Invalid Peer ID"),
+            InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
+            InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
+            InvalidCtlType(tp) => write!(f, "Invalid Control Type: {tp}"),
+            RemoteDisco => write!(f, "Remote Disconnected"),
+            LocalDisco => write!(f, "Local Disconnected"),
+        }
+    }
+}
index 61ca9830234e009dcc171c3e4cae41b37e560442..da535730877fb86cd37ddfac63a5e21f7950282d 100644 (file)
@@ -1,23 +1,36 @@
 #![feature(yeet_expr)]
 #![feature(cursor_remaining)]
+mod client;
+pub mod error;
+mod recv_worker;
+
 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+pub use client::{connect, Sender as Client};
 use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
 use std::{
-    cell::Cell,
-    fmt,
     io::{self, Write},
-    net,
+    ops,
     sync::{mpsc, Arc},
     thread,
 };
 
 pub const PROTO_ID: u32 = 0x4f457403;
 pub const UDP_PKT_SIZE: usize = 512;
-pub const NUM_CHANNELS: usize = 3;
+pub const NUM_CHANS: usize = 3;
 pub const REL_BUFFER: usize = 0x8000;
 pub const INIT_SEQNUM: u16 = 65500;
 
-#[derive(Debug, Copy, Clone, PartialEq)]
+pub type Error = error::Error;
+
+pub trait UdpSender: Send + Sync + 'static {
+    fn send(&self, data: Vec<u8>) -> io::Result<()>;
+}
+
+pub trait UdpReceiver: Send + Sync + 'static {
+    fn recv(&self) -> io::Result<Vec<u8>>;
+}
+
+#[derive(Debug, Copy, Clone)]
 pub enum PeerID {
     Nil = 0,
     Srv,
@@ -33,6 +46,15 @@ pub enum PktType {
     Rel,
 }
 
+#[derive(Debug, Copy, Clone, PartialEq, TryFromPrimitive)]
+#[repr(u8)]
+pub enum CtlType {
+    Ack = 0,
+    SetPeerID,
+    Ping,
+    Disco,
+}
+
 #[derive(Debug)]
 pub struct Pkt<T> {
     unrel: bool,
@@ -40,252 +62,121 @@ pub struct Pkt<T> {
     data: T,
 }
 
-#[derive(Debug)]
-pub enum Error {
-    IoError(io::Error),
-    InvalidProtoId(u32),
-    InvalidPeerID,
-    InvalidChannel(u8),
-    InvalidType(u8),
-    LocalHangup,
-}
-
-impl From<io::Error> for Error {
-    fn from(err: io::Error) -> Self {
-        Self::IoError(err)
-    }
-}
+pub type InPkt = Result<Pkt<Vec<u8>>, Error>;
 
-impl From<TryFromPrimitiveError<PktType>> for Error {
-    fn from(err: TryFromPrimitiveError<PktType>) -> Self {
-        Self::InvalidType(err.number)
-    }
-}
-
-impl From<mpsc::SendError<PktResult>> for Error {
-    fn from(_err: mpsc::SendError<PktResult>) -> Self {
-        Self::LocalHangup
-    }
-}
-
-impl fmt::Display for Error {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        use Error::*;
-        write!(f, "RUDP Error: ")?;
-
-        match self {
-            IoError(err) => write!(f, "IO Error: {}", err),
-            InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
-            InvalidPeerID => write!(f, "Invalid Peer ID"),
-            InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
-            InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
-            LocalHangup => write!(f, "Local packet receiver hung up"),
-        }
-    }
-}
+#[derive(Debug)]
+pub struct AckChan;
 
 #[derive(Debug)]
-struct Channel {
-    num: u8,
+pub struct RudpShare<S: UdpSender> {
+    pub id: u16,
+    pub remote_id: u16,
+    pub chans: Vec<AckChan>,
+    udp_tx: S,
 }
 
-type RelPkt = Cell<Option<Vec<u8>>>;
-
-struct RecvChannel<'a> {
-    packets: Vec<RelPkt>, // used to be called char **
-    seqnum: u16,
-    main: &'a Channel,
+#[derive(Debug)]
+pub struct RudpReceiver<S: UdpSender> {
+    share: Arc<RudpShare<S>>,
+    pkt_rx: mpsc::Receiver<InPkt>,
 }
 
-pub type PktResult = Result<Pkt<Vec<u8>>, Error>;
-type PktSender = mpsc::Sender<PktResult>;
-
-trait HandleError {
-    fn handle(&self, res: Result<(), Error>) -> bool;
+#[derive(Debug)]
+pub struct RudpSender<S: UdpSender> {
+    share: Arc<RudpShare<S>>,
 }
 
-impl HandleError for PktSender {
-    fn handle(&self, res: Result<(), Error>) -> bool {
-        if let Err(err) = res {
-            if !self.send(Err(err)).is_ok() {
-                return false;
-            }
+impl<S: UdpSender> RudpShare<S> {
+    pub fn new(id: u16, remote_id: u16, udp_tx: S) -> Self {
+        Self {
+            id,
+            remote_id,
+            udp_tx,
+            chans: (0..NUM_CHANS).map(|_| AckChan).collect(),
         }
-
-        true
     }
-}
-
-fn to_seqnum(seqnum: u16) -> usize {
-    (seqnum as usize) & (REL_BUFFER - 1)
-}
 
-#[derive(Debug)]
-struct ConnInner {
-    sock: net::UdpSocket,
-    id: u16,
-    remote_id: u16,
-    chans: Vec<Channel>,
-}
-
-impl ConnInner {
-    pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
+    pub fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
         buf.write_u32::<BigEndian>(PROTO_ID)?;
         buf.write_u16::<BigEndian>(self.remote_id)?;
         buf.write_u8(pkt.chan as u8)?;
-        buf.write_u8(PktType::Orig as u8)?;
+        buf.write_u8(tp as u8)?;
         buf.write(pkt.data)?;
 
-        self.sock.send(&buf)?;
+        self.udp_tx.send(buf)?;
 
         Ok(())
     }
+}
 
-    fn recv_loop(&self, tx: PktSender) {
-        let mut inbox = [0; UDP_PKT_SIZE];
-
-        let mut recv_chans = self
-            .chans
-            .iter()
-            .map(|main| RecvChannel {
-                main,
-                packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
-                seqnum: INIT_SEQNUM,
-            })
-            .collect();
-
-        while tx.handle(self.recv_pkt(&mut inbox, &mut recv_chans, &tx)) {}
-    }
-
-    fn recv_pkt(
-        &self,
-        buffer: &mut [u8],
-        chans: &mut Vec<RecvChannel>,
-        tx: &PktSender,
-    ) -> Result<(), Error> {
-        use Error::*;
-
-        // todo: reset timeout
-        let len = self.sock.recv(buffer)?;
-        let mut cursor = io::Cursor::new(&buffer[..len]);
-
-        let proto_id = cursor.read_u32::<BigEndian>()?;
-        if proto_id != PROTO_ID {
-            do yeet InvalidProtoId(proto_id);
-        }
-
-        let peer_id = cursor.read_u16::<BigEndian>()?;
-
-        let n_chan = cursor.read_u8()?;
-        let chan = chans
-            .get_mut(n_chan as usize)
-            .ok_or(InvalidChannel(n_chan))?;
-
-        self.process_pkt(cursor, chan, tx)
+impl<S: UdpSender> RudpSender<S> {
+    pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
+        self.share.send(PktType::Orig, pkt) // TODO
     }
+}
 
-    fn process_pkt(
-        &self,
-        mut cursor: io::Cursor<&[u8]>,
-        chan: &mut RecvChannel,
-        tx: &PktSender,
-    ) -> Result<(), Error> {
-        use PktType::*;
-
-        match cursor.read_u8()?.try_into()? {
-            Ctl => {
-                dbg!("Ctl");
-                dbg!(cursor.remaining_slice());
-            }
-            Orig => {
-                tx.send(Ok(Pkt {
-                    chan: chan.main.num,
-                    unrel: true,
-                    data: cursor.remaining_slice().into(),
-                }))?;
-            }
-            Split => {
-                dbg!("Split");
-                dbg!(cursor.remaining_slice());
-            }
-            Rel => {
-                let seqnum = cursor.read_u16::<BigEndian>()?;
-                chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
-
-                while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
-                    tx.handle(self.process_pkt(io::Cursor::new(&pkt), chan, tx));
-                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
-                }
-            }
-        }
+impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
+    type Target = mpsc::Receiver<InPkt>;
 
-        Ok(())
+    fn deref(&self) -> &Self::Target {
+        &self.pkt_rx
     }
 }
 
-#[derive(Debug)]
-pub struct Conn {
-    inner: Arc<ConnInner>,
-    rx: mpsc::Receiver<PktResult>,
-}
-
-impl Conn {
-    pub fn connect(addr: &str) -> io::Result<Self> {
-        let (tx, rx) = mpsc::channel();
-
-        let inner = Arc::new(ConnInner {
-            sock: net::UdpSocket::bind("0.0.0.0:0")?,
-            id: PeerID::Srv as u16,
-            remote_id: PeerID::Nil as u16,
-            chans: (0..NUM_CHANNELS as u8).map(|num| Channel { num }).collect(),
-        });
-
-        inner.sock.connect(addr)?;
-
-        let recv_inner = Arc::clone(&inner);
-        thread::spawn(move || {
-            recv_inner.recv_loop(tx);
-        });
+pub fn new<S: UdpSender, R: UdpReceiver>(
+    id: u16,
+    remote_id: u16,
+    udp_tx: S,
+    udp_rx: R,
+) -> (RudpSender<S>, RudpReceiver<S>) {
+    let (pkt_tx, pkt_rx) = mpsc::channel();
 
-        Ok(Conn { inner, rx })
-    }
+    let share = Arc::new(RudpShare::new(id, remote_id, udp_tx));
+    let recv_share = Arc::clone(&share);
 
-    pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
-        self.inner.send(pkt)
-    }
+    thread::spawn(move || {
+        recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx).run();
+    });
 
-    pub fn recv(&self) -> Result<PktResult, mpsc::RecvError> {
-        self.rx.recv()
-    }
+    (
+        RudpSender {
+            share: Arc::clone(&share),
+        },
+        RudpReceiver { share, pkt_rx },
+    )
 }
 
-fn main() {
+// connect
+
+fn main() -> io::Result<()> {
     //println!("{}", x.deep_size_of());
-    let conn = Conn::connect("127.0.0.1:30000").expect("the spanish inquisition");
+    let (tx, rx) = connect("127.0.0.1:30000")?;
 
     let mut mtpkt = vec![];
-    mtpkt.write_u16::<BigEndian>(2).unwrap(); // high level type
-    mtpkt.write_u8(29).unwrap(); // serialize ver
-    mtpkt.write_u16::<BigEndian>(0).unwrap(); // compression modes
-    mtpkt.write_u16::<BigEndian>(40).unwrap(); // MinProtoVer
-    mtpkt.write_u16::<BigEndian>(40).unwrap(); // MaxProtoVer
-    mtpkt.write_u16::<BigEndian>(3).unwrap(); // player name length
-    mtpkt.write(b"foo").unwrap(); // player name
-
-    conn.send(Pkt {
+    mtpkt.write_u16::<BigEndian>(2)?; // high level type
+    mtpkt.write_u8(29)?; // serialize ver
+    mtpkt.write_u16::<BigEndian>(0)?; // compression modes
+    mtpkt.write_u16::<BigEndian>(40)?; // MinProtoVer
+    mtpkt.write_u16::<BigEndian>(40)?; // MaxProtoVer
+    mtpkt.write_u16::<BigEndian>(3)?; // player name length
+    mtpkt.write(b"foo")?; // player name
+
+    tx.send(Pkt {
         unrel: true,
         chan: 1,
         data: &mtpkt,
-    })
-    .unwrap();
+    })?;
 
-    while let Ok(result) = conn.recv() {
+    while let Ok(result) = rx.recv() {
         match result {
             Ok(pkt) => {
-                io::stdout().write(pkt.data.as_slice()).unwrap();
+                io::stdout().write(pkt.data.as_slice())?;
             }
             Err(err) => eprintln!("Error: {}", err),
         }
     }
+    println!("disco");
+
+    Ok(())
 }
diff --git a/src/recv_worker.rs b/src/recv_worker.rs
new file mode 100644 (file)
index 0000000..d1ae5b1
--- /dev/null
@@ -0,0 +1,136 @@
+use crate::{error::Error, CtlType, InPkt, Pkt, PktType, RudpShare, UdpReceiver, UdpSender};
+use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
+use std::{
+    cell::Cell,
+    io, result,
+    sync::{mpsc, Arc},
+};
+
+fn to_seqnum(seqnum: u16) -> usize {
+    (seqnum as usize) & (crate::REL_BUFFER - 1)
+}
+
+struct RelChan {
+    packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
+    seqnum: u16,
+    num: u8,
+}
+
+type PktTx = mpsc::Sender<InPkt>;
+type Result = result::Result<(), Error>;
+
+pub struct RecvWorker<R: UdpReceiver, S: UdpSender> {
+    share: Arc<RudpShare<S>>,
+    pkt_tx: PktTx,
+    udp_rx: R,
+}
+
+impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
+    pub fn new(udp_rx: R, share: Arc<RudpShare<S>>, pkt_tx: PktTx) -> Self {
+        Self {
+            udp_rx,
+            share,
+            pkt_tx,
+        }
+    }
+
+    pub fn run(&self) {
+        let mut recv_chans = (0..crate::NUM_CHANS as u8)
+            .map(|num| RelChan {
+                num,
+                packets: (0..crate::REL_BUFFER).map(|_| Cell::new(None)).collect(),
+                seqnum: crate::INIT_SEQNUM,
+            })
+            .collect();
+
+        loop {
+            if let Err(e) = self.handle(self.recv_pkt(&mut recv_chans)) {
+                if let Error::LocalDisco = e {
+                    self.share
+                        .send(
+                            PktType::Ctl,
+                            Pkt {
+                                unrel: true,
+                                chan: 0,
+                                data: &[CtlType::Disco as u8],
+                            },
+                        )
+                        .ok();
+                }
+                break;
+            }
+        }
+    }
+
+    fn recv_pkt(&self, chans: &mut Vec<RelChan>) -> Result {
+        use Error::*;
+
+        // todo: reset timeout
+        let mut cursor = io::Cursor::new(self.udp_rx.recv()?);
+
+        let proto_id = cursor.read_u32::<BigEndian>()?;
+        if proto_id != crate::PROTO_ID {
+            do yeet InvalidProtoId(proto_id);
+        }
+
+        let peer_id = cursor.read_u16::<BigEndian>()?;
+
+        let n_chan = cursor.read_u8()?;
+        let chan = chans
+            .get_mut(n_chan as usize)
+            .ok_or(InvalidChannel(n_chan))?;
+
+        self.process_pkt(cursor, chan)
+    }
+
+    fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut RelChan) -> Result {
+        use CtlType::*;
+        use Error::*;
+        use PktType::*;
+
+        match cursor.read_u8()?.try_into()? {
+            Ctl => match cursor.read_u8()?.try_into()? {
+                Disco => return Err(RemoteDisco),
+                _ => {}
+            },
+            Orig => {
+                println!("Orig");
+
+                self.pkt_tx.send(Ok(Pkt {
+                    chan: chan.num,
+                    unrel: true,
+                    data: cursor.remaining_slice().into(),
+                }))?;
+            }
+            Split => {
+                println!("Split");
+                dbg!(cursor.remaining_slice());
+            }
+            Rel => {
+                println!("Rel");
+
+                let seqnum = cursor.read_u16::<BigEndian>()?;
+                chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
+
+                while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
+                    self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
+                    chan.seqnum = chan.seqnum.overflowing_add(1).0;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn handle(&self, res: Result) -> Result {
+        use Error::*;
+
+        match res {
+            Ok(v) => Ok(v),
+            Err(RemoteDisco) => Err(RemoteDisco),
+            Err(LocalDisco) => Err(LocalDisco),
+            Err(e) => Ok(self.pkt_tx.send(Err(e))?),
+        }
+    }
+}