]> git.lizzy.rs Git - mt_rudp.git/blob - src/main.rs
a
[mt_rudp.git] / src / main.rs
1 #![feature(yeet_expr)]
2 #![feature(cursor_remaining)]
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
5 use std::{
6     cell::Cell,
7     fmt,
8     io::{self, Write},
9     net,
10     sync::{mpsc, Arc},
11     thread,
12 };
13
14 pub const PROTO_ID: u32 = 0x4f457403;
15 pub const UDP_PKT_SIZE: usize = 512;
16 pub const NUM_CHANNELS: usize = 3;
17 pub const REL_BUFFER: usize = 0x8000;
18 pub const INIT_SEQNUM: u16 = 65500;
19
20 #[derive(Debug, Copy, Clone, PartialEq)]
21 pub enum PeerID {
22     Nil = 0,
23     Srv,
24     CltMin,
25 }
26
27 #[derive(Debug, Copy, Clone, PartialEq, TryFromPrimitive)]
28 #[repr(u8)]
29 pub enum PktType {
30     Ctl = 0,
31     Orig,
32     Split,
33     Rel,
34 }
35
36 #[derive(Debug)]
37 pub struct Pkt<T> {
38     unrel: bool,
39     chan: u8,
40     data: T,
41 }
42
43 #[derive(Debug)]
44 pub enum Error {
45     IoError(io::Error),
46     InvalidProtoId(u32),
47     InvalidPeerID,
48     InvalidChannel(u8),
49     InvalidType(u8),
50     LocalHangup,
51 }
52
53 impl From<io::Error> for Error {
54     fn from(err: io::Error) -> Self {
55         Self::IoError(err)
56     }
57 }
58
59 impl From<TryFromPrimitiveError<PktType>> for Error {
60     fn from(err: TryFromPrimitiveError<PktType>) -> Self {
61         Self::InvalidType(err.number)
62     }
63 }
64
65 impl From<mpsc::SendError<PktResult>> for Error {
66     fn from(err: mpsc::SendError<PktResult>) -> Self {
67         Self::LocalHangup
68     }
69 }
70
71 impl fmt::Display for Error {
72     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73         use Error::*;
74         write!(f, "RUDP Error: ")?;
75
76         match self {
77             IoError(err) => write!(f, "IO Error: {}", err),
78             InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
79             InvalidPeerID => write!(f, "Invalid Peer ID"),
80             InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
81             InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
82             LocalHangup => write!(f, "Local packet receiver hung up"),
83         }
84     }
85 }
86
87 #[derive(Debug)]
88 struct Channel {}
89
90 #[derive(Debug)]
91 struct RecvChannel<'a> {
92     packets: Vec<Option<Vec<u8>>>, // used to be called char **
93     seqnum: u16,
94     chan: &'a Channel,
95 }
96
97 pub type PktResult = Result<Pkt<Vec<u8>>, Error>;
98 type PktSender = mpsc::Sender<PktResult>;
99
100 #[derive(Debug)]
101 struct ConnInner {
102     sock: net::UdpSocket,
103     id: u16,
104     remote_id: u16,
105     chans: Vec<Channel>,
106 }
107
108 impl ConnInner {
109     pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
110         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
111         buf.write_u32::<BigEndian>(PROTO_ID)?;
112         buf.write_u16::<BigEndian>(self.remote_id)?;
113         buf.write_u8(pkt.chan as u8)?;
114         buf.write_u8(PktType::Orig as u8)?;
115         buf.write(pkt.data)?;
116
117         self.sock.send(&buf)?;
118
119         Ok(())
120     }
121
122     fn recv_loop(&self, tx: PktSender) {
123         let mut inbox = [0; UDP_PKT_SIZE];
124
125         let mut recv_chans = self.channels.map(|chan| RecvChannel {
126             chan,
127             packets: (0..REL_BUFFER).map(|_| Cell::new(None)),
128             seqnum: INIT_SEQNUM,
129         });
130
131         loop {
132             if let Err(err) = self.recv_pkt(&mut inbox, &mut recv_chans, &tx) {
133                 if !tx.send(Err(err)).is_ok() {
134                     break;
135                 }
136             }
137         }
138     }
139
140     fn recv_pkt(
141         &self,
142         buffer: &mut [u8],
143         chans: &mut Vec<RecvChannel>,
144         tx: &PktSender,
145     ) -> Result<(), Error> {
146         use Error::*;
147         use PktType::*;
148
149         // todo: reset timeout
150         let len = self.sock.recv(buffer)?;
151         let mut cursor = io::Cursor::new(&buffer[..len]);
152
153         let proto_id = cursor.read_u32::<BigEndian>()?;
154         if proto_id != PROTO_ID {
155             do yeet InvalidProtoId(proto_id);
156         }
157
158         let peer_id = cursor.read_u16::<BigEndian>()?;
159
160         let n_channel = cursor.read_u8()?;
161         let mut channel = self
162             .chans
163             .get_mut(n_channel as usize)
164             .ok_or(InvalidChannel(n_channel))?;
165
166         self.process_pkt(cursor, channel);
167     }
168
169     fn process_pkt(
170         &self,
171         mut cursor: io::Cursor<&[u8]>,
172         chan: &mut RecvChannel,
173     ) -> Result<(), Error> {
174         match cursor.read_u8()?.try_into()? {
175             Ctl => {
176                 dbg!(cursor.remaining_slice());
177             }
178             Orig => {
179                 tx.send(Ok(Pkt {
180                     chan: n_channel,
181                     unrel: true,
182                     data: cursor.remaining_slice().into(),
183                 }))?;
184             }
185             Split => {
186                 dbg!(cursor.remaining_slice());
187             }
188             Rel => {
189                 let seqnum = cursor.read_u16::<BigEndian>()?;
190                 chan.packets[seqnum].set(cursor.remaining_slice().into());
191
192                 while Some(pkt) = chan.packets[chan.seqnum].take() {
193                     self.process_pkt(io::Cursor::new(&pkt), chan)?;
194                     chan.seqnum.overflowing_add(1);
195                 }
196             }
197         }
198
199         Ok(())
200     }
201 }
202
203 #[derive(Debug)]
204 pub struct Conn {
205     inner: Arc<ConnInner>,
206     rx: mpsc::Receiver<PktResult>,
207 }
208
209 impl Conn {
210     pub fn connect(addr: &str) -> io::Result<Self> {
211         let (tx, rx) = mpsc::channel();
212
213         let inner = Arc::new(ConnInner {
214             sock: net::UdpSocket::bind("0.0.0.0:0")?,
215             id: PeerID::Srv as u16,
216             remote_id: PeerID::Nil as u16,
217             chans: (0..NUM_CHANNELS).map(|_| Channel {}).collect(),
218         });
219
220         inner.sock.connect(addr)?;
221
222         let recv_inner = Arc::clone(&inner);
223         thread::spawn(move || {
224             recv_inner.recv_loop(tx);
225         });
226
227         Ok(Conn { inner, rx })
228     }
229
230     pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
231         self.inner.send(pkt)
232     }
233
234     pub fn recv(&self) -> Result<PktResult, mpsc::RecvError> {
235         self.rx.recv()
236     }
237 }
238
239 fn main() {
240     //println!("{}", x.deep_size_of());
241     let conn = Conn::connect("127.0.0.1:30000").expect("the spanish inquisition");
242
243     let mut mtpkt = vec![];
244     mtpkt.write_u16::<BigEndian>(2).unwrap(); // high level type
245     mtpkt.write_u8(29).unwrap(); // serialize ver
246     mtpkt.write_u16::<BigEndian>(0).unwrap(); // compression modes
247     mtpkt.write_u16::<BigEndian>(40).unwrap(); // MinProtoVer
248     mtpkt.write_u16::<BigEndian>(40).unwrap(); // MaxProtoVer
249     mtpkt.write_u16::<BigEndian>(3).unwrap(); // player name length
250     mtpkt.write(b"foo").unwrap(); // player name
251
252     conn.send(Pkt {
253         unrel: true,
254         chan: 1,
255         data: &mtpkt,
256     })
257     .unwrap();
258
259     while let Ok(result) = conn.recv() {
260         match result {
261             Ok(pkt) => {
262                 io::stdout().write(pkt.data.as_slice()).unwrap();
263             }
264             Err(err) => eprintln!("Error: {}", err),
265         }
266     }
267 }