]> git.lizzy.rs Git - mt_rudp.git/blob - src/main.rs
b
[mt_rudp.git] / src / main.rs
1 #![feature(yeet_expr)]
2 #![feature(cursor_remaining)]
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
5 use std::{
6     cell::Cell,
7     fmt,
8     io::{self, Write},
9     net,
10     sync::{mpsc, Arc},
11     thread,
12 };
13
14 pub const PROTO_ID: u32 = 0x4f457403;
15 pub const UDP_PKT_SIZE: usize = 512;
16 pub const NUM_CHANNELS: usize = 3;
17 pub const REL_BUFFER: usize = 0x8000;
18 pub const INIT_SEQNUM: u16 = 65500;
19
20 #[derive(Debug, Copy, Clone, PartialEq)]
21 pub enum PeerID {
22     Nil = 0,
23     Srv,
24     CltMin,
25 }
26
27 #[derive(Debug, Copy, Clone, PartialEq, TryFromPrimitive)]
28 #[repr(u8)]
29 pub enum PktType {
30     Ctl = 0,
31     Orig,
32     Split,
33     Rel,
34 }
35
36 #[derive(Debug)]
37 pub struct Pkt<T> {
38     unrel: bool,
39     chan: u8,
40     data: T,
41 }
42
43 #[derive(Debug)]
44 pub enum Error {
45     IoError(io::Error),
46     InvalidProtoId(u32),
47     InvalidPeerID,
48     InvalidChannel(u8),
49     InvalidType(u8),
50     LocalHangup,
51 }
52
53 impl From<io::Error> for Error {
54     fn from(err: io::Error) -> Self {
55         Self::IoError(err)
56     }
57 }
58
59 impl From<TryFromPrimitiveError<PktType>> for Error {
60     fn from(err: TryFromPrimitiveError<PktType>) -> Self {
61         Self::InvalidType(err.number)
62     }
63 }
64
65 impl From<mpsc::SendError<PktResult>> for Error {
66     fn from(err: mpsc::SendError<PktResult>) -> Self {
67         Self::LocalHangup
68     }
69 }
70
71 impl fmt::Display for Error {
72     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73         use Error::*;
74         write!(f, "RUDP Error: ")?;
75
76         match self {
77             IoError(err) => write!(f, "IO Error: {}", err),
78             InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
79             InvalidPeerID => write!(f, "Invalid Peer ID"),
80             InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
81             InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
82             LocalHangup => write!(f, "Local packet receiver hung up"),
83         }
84     }
85 }
86
87 #[derive(Debug)]
88 struct Channel {}
89
90 pub type PktResult = Result<Pkt<Vec<u8>>, Error>;
91 type PktSender = mpsc::Sender<PktResult>;
92
93 #[derive(Debug)]
94 struct ConnInner {
95     sock: net::UdpSocket,
96     id: u16,
97     remote_id: u16,
98     chans: Vec<Channel>,
99 }
100
101 struct RecvChannel {
102     packets: Vec<Cell<Option<Vec<u8>>>>, // used to be called char **
103     seqnum: u16,
104     num: u8,
105 }
106
107 struct ConnReceiver {
108     tx: PktSender,
109     inner: Arc<ConnInner>,
110     chans: Vec<RecvChannel>,
111     inbox: [u8; UDP_PKT_SIZE],
112 }
113
114 impl ConnReceiver {
115     pub fn run(inner: Arc<ConnInner>, tx: PktSender) {
116         Self {
117             tx,
118             inner,
119             chans: (0..NUM_CHANNELS as u8)
120                 .map(|num| RecvChannel {
121                     num,
122                     packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
123                     seqnum: INIT_SEQNUM,
124                 })
125                 .collect(),
126             inbox: [0; UDP_PKT_SIZE],
127         }
128         .do_loop();
129     }
130
131     fn handle_err(&self, res: Result<(), Error>) -> bool {
132         if let Err(err) = res {
133             if !self.tx.send(Err(err)).is_ok() {
134                 return false;
135             }
136         }
137
138         true
139     }
140
141     fn do_loop(&mut self) {
142         while self.handle_err(self.recv_pkt()) {}
143     }
144
145     fn recv_pkt(&mut self) -> Result<(), Error> {
146         use Error::*;
147
148         // todo: reset timeout
149         let len = self.inner.sock.recv(&mut self.inbox)?;
150         let mut cursor = io::Cursor::new(&self.inbox[..len]);
151
152         let proto_id = cursor.read_u32::<BigEndian>()?;
153         if proto_id != PROTO_ID {
154             do yeet InvalidProtoId(proto_id);
155         }
156
157         let peer_id = cursor.read_u16::<BigEndian>()?;
158
159         let n_channel = cursor.read_u8()?;
160         let mut channel = self
161             .chans
162             .get(n_channel as usize)
163             .ok_or(InvalidChannel(n_channel))?;
164
165         self.process_pkt(cursor, channel)
166     }
167
168     fn process_pkt(&self, mut cursor: io::Cursor<&[u8]>, chan: &RecvChannel) -> Result<(), Error> {
169         use PktType::*;
170
171         match cursor.read_u8()?.try_into()? {
172             Ctl => {
173                 dbg!(cursor.remaining_slice());
174             }
175             Orig => {
176                 self.tx.send(Ok(Pkt {
177                     chan: chan.num,
178                     unrel: true,
179                     data: cursor.remaining_slice().into(),
180                 }))?;
181             }
182             Split => {
183                 dbg!(cursor.remaining_slice());
184             }
185             Rel => {
186                 let seqnum = cursor.read_u16::<BigEndian>()?;
187                 chan.packets[seqnum as usize].set(Some(cursor.remaining_slice().into()));
188
189                 while let Some(pkt) = chan.packets[chan.seqnum as usize].take() {
190                     self.handle_err(self.process_pkt(io::Cursor::new(&pkt), chan));
191                     chan.seqnum.overflowing_add(1);
192                 }
193             }
194         }
195
196         Ok(())
197     }
198 }
199
200 impl ConnInner {
201     pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
202         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
203         buf.write_u32::<BigEndian>(PROTO_ID)?;
204         buf.write_u16::<BigEndian>(self.remote_id)?;
205         buf.write_u8(pkt.chan as u8)?;
206         buf.write_u8(PktType::Orig as u8)?;
207         buf.write(pkt.data)?;
208
209         self.sock.send(&buf)?;
210
211         Ok(())
212     }
213 }
214
215 #[derive(Debug)]
216 pub struct Conn {
217     inner: Arc<ConnInner>,
218     rx: mpsc::Receiver<PktResult>,
219 }
220
221 impl Conn {
222     pub fn connect(addr: &str) -> io::Result<Self> {
223         let (tx, rx) = mpsc::channel();
224
225         let inner = Arc::new(ConnInner {
226             sock: net::UdpSocket::bind("0.0.0.0:0")?,
227             id: PeerID::Srv as u16,
228             remote_id: PeerID::Nil as u16,
229             chans: (0..NUM_CHANNELS).map(|_| Channel {}).collect(),
230         });
231
232         inner.sock.connect(addr)?;
233
234         let recv_inner = Arc::clone(&inner);
235         thread::spawn(move || {
236             ConnReceiver::run(recv_inner, tx);
237         });
238
239         Ok(Conn { inner, rx })
240     }
241
242     pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
243         self.inner.send(pkt)
244     }
245
246     pub fn recv(&self) -> Result<PktResult, mpsc::RecvError> {
247         self.rx.recv()
248     }
249 }
250
251 fn main() {
252     //println!("{}", x.deep_size_of());
253     let conn = Conn::connect("127.0.0.1:30000").expect("the spanish inquisition");
254
255     let mut mtpkt = vec![];
256     mtpkt.write_u16::<BigEndian>(2).unwrap(); // high level type
257     mtpkt.write_u8(29).unwrap(); // serialize ver
258     mtpkt.write_u16::<BigEndian>(0).unwrap(); // compression modes
259     mtpkt.write_u16::<BigEndian>(40).unwrap(); // MinProtoVer
260     mtpkt.write_u16::<BigEndian>(40).unwrap(); // MaxProtoVer
261     mtpkt.write_u16::<BigEndian>(3).unwrap(); // player name length
262     mtpkt.write(b"foo").unwrap(); // player name
263
264     conn.send(Pkt {
265         unrel: true,
266         chan: 1,
267         data: &mtpkt,
268     })
269     .unwrap();
270
271     while let Ok(result) = conn.recv() {
272         match result {
273             Ok(pkt) => {
274                 io::stdout().write(pkt.data.as_slice()).unwrap();
275             }
276             Err(err) => eprintln!("Error: {}", err),
277         }
278     }
279 }