From 58100eb80dc20283e3b4de178082ef17f6213551 Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Thu, 22 Dec 2022 23:02:33 +0100 Subject: [PATCH] files --- src/client.rs | 52 ++++++++ src/error.rs | 57 +++++++++ src/main.rs | 305 +++++++++++++++------------------------------ src/recv_worker.rs | 136 ++++++++++++++++++++ 4 files changed, 343 insertions(+), 207 deletions(-) create mode 100644 src/client.rs create mode 100644 src/error.rs create mode 100644 src/recv_worker.rs diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..e506a3e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,52 @@ +use crate::{PeerID, UdpReceiver, UdpSender}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; +use std::{ + cell::Cell, + fmt, + io::{self, Write}, + net, ops, + sync::{mpsc, Arc}, + thread, +}; + +pub struct Sender { + sock: Arc, +} + +impl UdpSender for Sender { + fn send(&self, data: Vec) -> io::Result<()> { + self.sock.send(&data)?; + Ok(()) + } +} + +pub struct Receiver { + sock: Arc, +} + +impl UdpReceiver for Receiver { + fn recv(&self) -> io::Result> { + let mut buffer = Vec::new(); + buffer.resize(crate::UDP_PKT_SIZE, 0); + + let len = self.sock.recv(&mut buffer)?; + buffer.truncate(len); + + Ok(buffer) + } +} + +pub fn connect(addr: &str) -> io::Result<(crate::RudpSender, crate::RudpReceiver)> { + let sock = Arc::new(net::UdpSocket::bind("0.0.0.0:0")?); + sock.connect(addr)?; + + Ok(crate::new( + PeerID::Srv as u16, + PeerID::Nil as u16, + Sender { + sock: Arc::clone(&sock), + }, + Receiver { sock }, + )) +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..02080c7 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,57 @@ +use crate::{CtlType, InPkt, PktType}; +use num_enum::TryFromPrimitiveError; +use std::{fmt, io, sync::mpsc}; + +#[derive(Debug)] +pub enum Error { + IoError(io::Error), + InvalidProtoId(u32), + InvalidPeerID, + InvalidChannel(u8), + InvalidType(u8), + InvalidCtlType(u8), + RemoteDisco, + LocalDisco, +} + +impl From for Error { + fn from(err: io::Error) -> Self { + Self::IoError(err) + } +} + +impl From> for Error { + fn from(err: TryFromPrimitiveError) -> Self { + Self::InvalidType(err.number) + } +} + +impl From> for Error { + fn from(err: TryFromPrimitiveError) -> Self { + Self::InvalidType(err.number) + } +} + +impl From> for Error { + fn from(_err: mpsc::SendError) -> Self { + Self::LocalDisco // technically not a disconnect but a local drop + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use Error::*; + write!(f, "RUDP Error: ")?; + + match self { + IoError(err) => write!(f, "IO Error: {}", err), + InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"), + InvalidPeerID => write!(f, "Invalid Peer ID"), + InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"), + InvalidType(tp) => write!(f, "Invalid Type: {tp}"), + InvalidCtlType(tp) => write!(f, "Invalid Control Type: {tp}"), + RemoteDisco => write!(f, "Remote Disconnected"), + LocalDisco => write!(f, "Local Disconnected"), + } + } +} diff --git a/src/main.rs b/src/main.rs index 61ca983..da53573 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,23 +1,36 @@ #![feature(yeet_expr)] #![feature(cursor_remaining)] +mod client; +pub mod error; +mod recv_worker; + use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +pub use client::{connect, Sender as Client}; use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; use std::{ - cell::Cell, - fmt, io::{self, Write}, - net, + ops, sync::{mpsc, Arc}, thread, }; pub const PROTO_ID: u32 = 0x4f457403; pub const UDP_PKT_SIZE: usize = 512; -pub const NUM_CHANNELS: usize = 3; +pub const NUM_CHANS: usize = 3; pub const REL_BUFFER: usize = 0x8000; pub const INIT_SEQNUM: u16 = 65500; -#[derive(Debug, Copy, Clone, PartialEq)] +pub type Error = error::Error; + +pub trait UdpSender: Send + Sync + 'static { + fn send(&self, data: Vec) -> io::Result<()>; +} + +pub trait UdpReceiver: Send + Sync + 'static { + fn recv(&self) -> io::Result>; +} + +#[derive(Debug, Copy, Clone)] pub enum PeerID { Nil = 0, Srv, @@ -33,6 +46,15 @@ pub enum PktType { Rel, } +#[derive(Debug, Copy, Clone, PartialEq, TryFromPrimitive)] +#[repr(u8)] +pub enum CtlType { + Ack = 0, + SetPeerID, + Ping, + Disco, +} + #[derive(Debug)] pub struct Pkt { unrel: bool, @@ -40,252 +62,121 @@ pub struct Pkt { data: T, } -#[derive(Debug)] -pub enum Error { - IoError(io::Error), - InvalidProtoId(u32), - InvalidPeerID, - InvalidChannel(u8), - InvalidType(u8), - LocalHangup, -} - -impl From for Error { - fn from(err: io::Error) -> Self { - Self::IoError(err) - } -} +pub type InPkt = Result>, Error>; -impl From> for Error { - fn from(err: TryFromPrimitiveError) -> Self { - Self::InvalidType(err.number) - } -} - -impl From> for Error { - fn from(_err: mpsc::SendError) -> Self { - Self::LocalHangup - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use Error::*; - write!(f, "RUDP Error: ")?; - - match self { - IoError(err) => write!(f, "IO Error: {}", err), - InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"), - InvalidPeerID => write!(f, "Invalid Peer ID"), - InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"), - InvalidType(tp) => write!(f, "Invalid Type: {tp}"), - LocalHangup => write!(f, "Local packet receiver hung up"), - } - } -} +#[derive(Debug)] +pub struct AckChan; #[derive(Debug)] -struct Channel { - num: u8, +pub struct RudpShare { + pub id: u16, + pub remote_id: u16, + pub chans: Vec, + udp_tx: S, } -type RelPkt = Cell>>; - -struct RecvChannel<'a> { - packets: Vec, // used to be called char ** - seqnum: u16, - main: &'a Channel, +#[derive(Debug)] +pub struct RudpReceiver { + share: Arc>, + pkt_rx: mpsc::Receiver, } -pub type PktResult = Result>, Error>; -type PktSender = mpsc::Sender; - -trait HandleError { - fn handle(&self, res: Result<(), Error>) -> bool; +#[derive(Debug)] +pub struct RudpSender { + share: Arc>, } -impl HandleError for PktSender { - fn handle(&self, res: Result<(), Error>) -> bool { - if let Err(err) = res { - if !self.send(Err(err)).is_ok() { - return false; - } +impl RudpShare { + pub fn new(id: u16, remote_id: u16, udp_tx: S) -> Self { + Self { + id, + remote_id, + udp_tx, + chans: (0..NUM_CHANS).map(|_| AckChan).collect(), } - - true } -} - -fn to_seqnum(seqnum: u16) -> usize { - (seqnum as usize) & (REL_BUFFER - 1) -} -#[derive(Debug)] -struct ConnInner { - sock: net::UdpSocket, - id: u16, - remote_id: u16, - chans: Vec, -} - -impl ConnInner { - pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> { + pub fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> { let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len()); buf.write_u32::(PROTO_ID)?; buf.write_u16::(self.remote_id)?; buf.write_u8(pkt.chan as u8)?; - buf.write_u8(PktType::Orig as u8)?; + buf.write_u8(tp as u8)?; buf.write(pkt.data)?; - self.sock.send(&buf)?; + self.udp_tx.send(buf)?; Ok(()) } +} - fn recv_loop(&self, tx: PktSender) { - let mut inbox = [0; UDP_PKT_SIZE]; - - let mut recv_chans = self - .chans - .iter() - .map(|main| RecvChannel { - main, - packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(), - seqnum: INIT_SEQNUM, - }) - .collect(); - - while tx.handle(self.recv_pkt(&mut inbox, &mut recv_chans, &tx)) {} - } - - fn recv_pkt( - &self, - buffer: &mut [u8], - chans: &mut Vec, - tx: &PktSender, - ) -> Result<(), Error> { - use Error::*; - - // todo: reset timeout - let len = self.sock.recv(buffer)?; - let mut cursor = io::Cursor::new(&buffer[..len]); - - let proto_id = cursor.read_u32::()?; - if proto_id != PROTO_ID { - do yeet InvalidProtoId(proto_id); - } - - let peer_id = cursor.read_u16::()?; - - let n_chan = cursor.read_u8()?; - let chan = chans - .get_mut(n_chan as usize) - .ok_or(InvalidChannel(n_chan))?; - - self.process_pkt(cursor, chan, tx) +impl RudpSender { + pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> { + self.share.send(PktType::Orig, pkt) // TODO } +} - fn process_pkt( - &self, - mut cursor: io::Cursor<&[u8]>, - chan: &mut RecvChannel, - tx: &PktSender, - ) -> Result<(), Error> { - use PktType::*; - - match cursor.read_u8()?.try_into()? { - Ctl => { - dbg!("Ctl"); - dbg!(cursor.remaining_slice()); - } - Orig => { - tx.send(Ok(Pkt { - chan: chan.main.num, - unrel: true, - data: cursor.remaining_slice().into(), - }))?; - } - Split => { - dbg!("Split"); - dbg!(cursor.remaining_slice()); - } - Rel => { - let seqnum = cursor.read_u16::()?; - chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into())); - - while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() { - tx.handle(self.process_pkt(io::Cursor::new(&pkt), chan, tx)); - chan.seqnum = chan.seqnum.overflowing_add(1).0; - } - } - } +impl ops::Deref for RudpReceiver { + type Target = mpsc::Receiver; - Ok(()) + fn deref(&self) -> &Self::Target { + &self.pkt_rx } } -#[derive(Debug)] -pub struct Conn { - inner: Arc, - rx: mpsc::Receiver, -} - -impl Conn { - pub fn connect(addr: &str) -> io::Result { - let (tx, rx) = mpsc::channel(); - - let inner = Arc::new(ConnInner { - sock: net::UdpSocket::bind("0.0.0.0:0")?, - id: PeerID::Srv as u16, - remote_id: PeerID::Nil as u16, - chans: (0..NUM_CHANNELS as u8).map(|num| Channel { num }).collect(), - }); - - inner.sock.connect(addr)?; - - let recv_inner = Arc::clone(&inner); - thread::spawn(move || { - recv_inner.recv_loop(tx); - }); +pub fn new( + id: u16, + remote_id: u16, + udp_tx: S, + udp_rx: R, +) -> (RudpSender, RudpReceiver) { + let (pkt_tx, pkt_rx) = mpsc::channel(); - Ok(Conn { inner, rx }) - } + let share = Arc::new(RudpShare::new(id, remote_id, udp_tx)); + let recv_share = Arc::clone(&share); - pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> { - self.inner.send(pkt) - } + thread::spawn(move || { + recv_worker::RecvWorker::new(udp_rx, recv_share, pkt_tx).run(); + }); - pub fn recv(&self) -> Result { - self.rx.recv() - } + ( + RudpSender { + share: Arc::clone(&share), + }, + RudpReceiver { share, pkt_rx }, + ) } -fn main() { +// connect + +fn main() -> io::Result<()> { //println!("{}", x.deep_size_of()); - let conn = Conn::connect("127.0.0.1:30000").expect("the spanish inquisition"); + let (tx, rx) = connect("127.0.0.1:30000")?; let mut mtpkt = vec![]; - mtpkt.write_u16::(2).unwrap(); // high level type - mtpkt.write_u8(29).unwrap(); // serialize ver - mtpkt.write_u16::(0).unwrap(); // compression modes - mtpkt.write_u16::(40).unwrap(); // MinProtoVer - mtpkt.write_u16::(40).unwrap(); // MaxProtoVer - mtpkt.write_u16::(3).unwrap(); // player name length - mtpkt.write(b"foo").unwrap(); // player name - - conn.send(Pkt { + mtpkt.write_u16::(2)?; // high level type + mtpkt.write_u8(29)?; // serialize ver + mtpkt.write_u16::(0)?; // compression modes + mtpkt.write_u16::(40)?; // MinProtoVer + mtpkt.write_u16::(40)?; // MaxProtoVer + mtpkt.write_u16::(3)?; // player name length + mtpkt.write(b"foo")?; // player name + + tx.send(Pkt { unrel: true, chan: 1, data: &mtpkt, - }) - .unwrap(); + })?; - while let Ok(result) = conn.recv() { + while let Ok(result) = rx.recv() { match result { Ok(pkt) => { - io::stdout().write(pkt.data.as_slice()).unwrap(); + io::stdout().write(pkt.data.as_slice())?; } Err(err) => eprintln!("Error: {}", err), } } + println!("disco"); + + Ok(()) } diff --git a/src/recv_worker.rs b/src/recv_worker.rs new file mode 100644 index 0000000..d1ae5b1 --- /dev/null +++ b/src/recv_worker.rs @@ -0,0 +1,136 @@ +use crate::{error::Error, CtlType, InPkt, Pkt, PktType, RudpShare, UdpReceiver, UdpSender}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use num_enum::{TryFromPrimitive, TryFromPrimitiveError}; +use std::{ + cell::Cell, + io, result, + sync::{mpsc, Arc}, +}; + +fn to_seqnum(seqnum: u16) -> usize { + (seqnum as usize) & (crate::REL_BUFFER - 1) +} + +struct RelChan { + packets: Vec>>>, // in the good old days this used to be called char ** + seqnum: u16, + num: u8, +} + +type PktTx = mpsc::Sender; +type Result = result::Result<(), Error>; + +pub struct RecvWorker { + share: Arc>, + pkt_tx: PktTx, + udp_rx: R, +} + +impl RecvWorker { + pub fn new(udp_rx: R, share: Arc>, pkt_tx: PktTx) -> Self { + Self { + udp_rx, + share, + pkt_tx, + } + } + + pub fn run(&self) { + let mut recv_chans = (0..crate::NUM_CHANS as u8) + .map(|num| RelChan { + num, + packets: (0..crate::REL_BUFFER).map(|_| Cell::new(None)).collect(), + seqnum: crate::INIT_SEQNUM, + }) + .collect(); + + loop { + if let Err(e) = self.handle(self.recv_pkt(&mut recv_chans)) { + if let Error::LocalDisco = e { + self.share + .send( + PktType::Ctl, + Pkt { + unrel: true, + chan: 0, + data: &[CtlType::Disco as u8], + }, + ) + .ok(); + } + break; + } + } + } + + fn recv_pkt(&self, chans: &mut Vec) -> Result { + use Error::*; + + // todo: reset timeout + let mut cursor = io::Cursor::new(self.udp_rx.recv()?); + + let proto_id = cursor.read_u32::()?; + if proto_id != crate::PROTO_ID { + do yeet InvalidProtoId(proto_id); + } + + let peer_id = cursor.read_u16::()?; + + let n_chan = cursor.read_u8()?; + let chan = chans + .get_mut(n_chan as usize) + .ok_or(InvalidChannel(n_chan))?; + + self.process_pkt(cursor, chan) + } + + fn process_pkt(&self, mut cursor: io::Cursor>, chan: &mut RelChan) -> Result { + use CtlType::*; + use Error::*; + use PktType::*; + + match cursor.read_u8()?.try_into()? { + Ctl => match cursor.read_u8()?.try_into()? { + Disco => return Err(RemoteDisco), + _ => {} + }, + Orig => { + println!("Orig"); + + self.pkt_tx.send(Ok(Pkt { + chan: chan.num, + unrel: true, + data: cursor.remaining_slice().into(), + }))?; + } + Split => { + println!("Split"); + dbg!(cursor.remaining_slice()); + } + Rel => { + println!("Rel"); + + let seqnum = cursor.read_u16::()?; + chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into())); + + while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() { + self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?; + chan.seqnum = chan.seqnum.overflowing_add(1).0; + } + } + } + + Ok(()) + } + + fn handle(&self, res: Result) -> Result { + use Error::*; + + match res { + Ok(v) => Ok(v), + Err(RemoteDisco) => Err(RemoteDisco), + Err(LocalDisco) => Err(LocalDisco), + Err(e) => Ok(self.pkt_tx.send(Err(e))?), + } + } +} -- 2.44.0