edition = "2021"
[dependencies]
+async-recursion = "1.0.0"
async-trait = "0.1.60"
byteorder = "1.4.3"
num_enum = "0.5.7"
pub enum Error {
IoError(io::Error),
InvalidProtoId(u32),
- InvalidPeerID,
InvalidChannel(u8),
InvalidType(u8),
InvalidCtlType(u8),
+ PeerIDAlreadySet,
+ InvalidChunkIndex(usize, usize),
+ InvalidChunkCount(usize, usize),
RemoteDisco,
LocalDisco,
}
impl From<SendError<InPkt>> for Error {
fn from(_err: SendError<InPkt>) -> Self {
- Self::LocalDisco // technically not a disconnect but a local drop
+ Self::LocalDisco
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use Error::*;
- write!(f, "RUDP Error: ")?;
+ write!(f, "RUDP error: ")?;
match self {
- IoError(err) => write!(f, "IO Error: {}", err),
- InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
- InvalidPeerID => write!(f, "Invalid Peer ID"),
- InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
- InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
- InvalidCtlType(tp) => write!(f, "Invalid Control Type: {tp}"),
- RemoteDisco => write!(f, "Remote Disconnected"),
- LocalDisco => write!(f, "Local Disconnected"),
+ IoError(err) => write!(f, "IO error: {}", err),
+ InvalidProtoId(id) => write!(f, "invalid protocol ID: {id}"),
+ InvalidChannel(ch) => write!(f, "invalid channel: {ch}"),
+ InvalidType(tp) => write!(f, "invalid type: {tp}"),
+ InvalidCtlType(tp) => write!(f, "invalid control type: {tp}"),
+ PeerIDAlreadySet => write!(f, "peer ID already set"),
+ InvalidChunkIndex(i, n) => write!(f, "chunk index {i} bigger than chunk count {n}"),
+ InvalidChunkCount(o, n) => write!(f, "chunk count changed from {o} to {n}"),
+ RemoteDisco => write!(f, "remote disconnected"),
+ LocalDisco => write!(f, "local disconnected"),
}
}
}
-#![feature(yeet_expr)]
#![feature(cursor_remaining)]
#![feature(hash_drain_filter)]
+#![feature(once_cell)]
mod client;
pub mod error;
mod recv_worker;
use byteorder::{BigEndian, WriteBytesExt};
pub use client::{connect, Sender as Client};
use num_enum::TryFromPrimitive;
-use std::future::Future;
use std::{
io::{self, Write},
ops,
sync::Arc,
};
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, RwLock};
pub const PROTO_ID: u32 = 0x4f457403;
pub const UDP_PKT_SIZE: usize = 512;
async fn recv(&self) -> io::Result<Vec<u8>>;
}
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[repr(u16)]
pub enum PeerID {
Nil = 0,
Srv,
#[derive(Debug)]
pub struct RudpShare<S: UdpSender> {
pub id: u16,
- pub remote_id: u16,
+ pub remote_id: RwLock<u16>,
pub chans: Vec<AckChan>,
udp_tx: S,
}
pub async fn send(&self, tp: PktType, pkt: Pkt<&[u8]>) -> io::Result<()> {
let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
buf.write_u32::<BigEndian>(PROTO_ID)?;
- buf.write_u16::<BigEndian>(self.remote_id)?;
+ buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
buf.write_u8(pkt.chan as u8)?;
buf.write_u8(tp as u8)?;
buf.write(pkt.data)?;
pub async fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
self.share.send(PktType::Orig, pkt).await // TODO
}
+
+ pub async fn peer_id(&self) -> u16 {
+ self.share.id
+ }
+
+ pub async fn is_server(&self) -> bool {
+ self.share.id == PeerID::Srv as u16
+ }
+}
+
+impl<S: UdpSender> RudpReceiver<S> {
+ pub async fn peer_id(&self) -> u16 {
+ self.share.id
+ }
+
+ pub async fn is_server(&self) -> bool {
+ self.share.id == PeerID::Srv as u16
+ }
}
impl<S: UdpSender> ops::Deref for RudpReceiver<S> {
let share = Arc::new(RudpShare {
id,
- remote_id,
+ remote_id: RwLock::new(remote_id),
udp_tx,
chans: (0..NUM_CHANS).map(|_| AckChan).collect(),
});
#[tokio::main]
async fn main() -> io::Result<()> {
- //println!("{}", x.deep_size_of());
let (tx, mut rx) = connect("127.0.0.1:30000").await?;
let mut mtpkt = vec![];
use crate::{error::Error, *};
+use async_recursion::async_recursion;
use byteorder::{BigEndian, ReadBytesExt};
use std::{
- cell::Cell,
+ cell::{Cell, OnceCell},
collections::HashMap,
io,
sync::{Arc, Weak},
(seqnum as usize) & (REL_BUFFER - 1)
}
-type Result = std::result::Result<(), Error>;
+type Result<T> = std::result::Result<T, Error>;
struct Split {
- timestamp: time::Instant,
+ timestamp: Option<time::Instant>,
+ chunks: Vec<OnceCell<Vec<u8>>>,
+ got: usize,
}
struct Chan {
- packets: Vec<Cell<Option<Vec<u8>>>>, // in the good old days this used to be called char **
+ packets: Vec<Cell<Option<Vec<u8>>>>, // char ** 😛
splits: HashMap<u16, Split>,
seqnum: u16,
num: u8,
let mut ch = chan.lock().await;
ch.splits = ch
.splits
- .drain_filter(|_k, v| v.timestamp.elapsed() < timeout)
+ .drain_filter(
+ |_k, v| !matches!(v.timestamp, Some(t) if t.elapsed() < timeout),
+ )
.collect();
}
}
}
- async fn recv_pkt(&self) -> Result {
+ async fn recv_pkt(&self) -> Result<()> {
use Error::*;
// todo: reset timeout
let proto_id = cursor.read_u32::<BigEndian>()?;
if proto_id != PROTO_ID {
- do yeet InvalidProtoId(proto_id);
+ return Err(InvalidProtoId(proto_id));
}
- let peer_id = cursor.read_u16::<BigEndian>()?;
+ let _peer_id = cursor.read_u16::<BigEndian>()?;
let n_chan = cursor.read_u8()?;
let mut chan = self
.lock()
.await;
- self.process_pkt(cursor, &mut chan)
+ self.process_pkt(cursor, true, &mut chan).await
}
- fn process_pkt(&self, mut cursor: io::Cursor<Vec<u8>>, chan: &mut Chan) -> Result {
- use CtlType::*;
+ #[async_recursion]
+ async fn process_pkt(
+ &self,
+ mut cursor: io::Cursor<Vec<u8>>,
+ unrel: bool,
+ chan: &mut Chan,
+ ) -> Result<()> {
use Error::*;
- use PktType::*;
match cursor.read_u8()?.try_into()? {
- Ctl => match cursor.read_u8()?.try_into()? {
- Disco => return Err(RemoteDisco),
- _ => {}
+ PktType::Ctl => match cursor.read_u8()?.try_into()? {
+ CtlType::Ack => { /* TODO */ }
+ CtlType::SetPeerID => {
+ let mut id = self.share.remote_id.write().await;
+
+ if *id != PeerID::Nil as u16 {
+ return Err(PeerIDAlreadySet);
+ }
+
+ *id = cursor.read_u16::<BigEndian>()?;
+ }
+ CtlType::Ping => {}
+ CtlType::Disco => return Err(RemoteDisco),
},
- Orig => {
+ PktType::Orig => {
println!("Orig");
self.pkt_tx.send(Ok(Pkt {
chan: chan.num,
- unrel: true,
+ unrel,
data: cursor.remaining_slice().into(),
}))?;
}
- Split => {
+ PktType::Split => {
println!("Split");
- dbg!(cursor.remaining_slice());
+
+ let seqnum = cursor.read_u16::<BigEndian>()?;
+ let chunk_index = cursor.read_u16::<BigEndian>()? as usize;
+ let chunk_count = cursor.read_u16::<BigEndian>()? as usize;
+
+ let mut split = chan.splits.entry(seqnum).or_insert_with(|| Split {
+ got: 0,
+ chunks: (0..chunk_count).map(|_| OnceCell::new()).collect(),
+ timestamp: None,
+ });
+
+ if split.chunks.len() != chunk_count {
+ return Err(InvalidChunkCount(split.chunks.len(), chunk_count));
+ }
+
+ if split
+ .chunks
+ .get(chunk_index)
+ .ok_or(InvalidChunkIndex(chunk_index, chunk_count))?
+ .set(cursor.remaining_slice().into())
+ .is_ok()
+ {
+ split.got += 1;
+ }
+
+ split.timestamp = if unrel {
+ Some(time::Instant::now())
+ } else {
+ None
+ };
+
+ if split.got == chunk_count {
+ self.pkt_tx.send(Ok(Pkt {
+ chan: chan.num,
+ unrel,
+ data: split
+ .chunks
+ .iter()
+ .flat_map(|chunk| chunk.get().unwrap().iter())
+ .copied()
+ .collect(),
+ }))?;
+
+ chan.splits.remove(&seqnum);
+ }
}
- Rel => {
+ PktType::Rel => {
println!("Rel");
let seqnum = cursor.read_u16::<BigEndian>()?;
chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
- while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
- self.handle(self.process_pkt(io::Cursor::new(pkt), chan))?;
+ fn next_pkt(chan: &mut Chan) -> Option<Vec<u8>> {
+ chan.packets[to_seqnum(chan.seqnum)].take()
+ }
+
+ while let Some(pkt) = next_pkt(chan) {
+ self.handle(self.process_pkt(io::Cursor::new(pkt), false, chan).await)?;
chan.seqnum = chan.seqnum.overflowing_add(1).0;
}
}
Ok(())
}
- fn handle(&self, res: Result) -> Result {
+ fn handle(&self, res: Result<()>) -> Result<()> {
use Error::*;
match res {