]> git.lizzy.rs Git - mt_rudp.git/blob - src/main.rs
61ca9830234e009dcc171c3e4cae41b37e560442
[mt_rudp.git] / src / main.rs
1 #![feature(yeet_expr)]
2 #![feature(cursor_remaining)]
3 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4 use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
5 use std::{
6     cell::Cell,
7     fmt,
8     io::{self, Write},
9     net,
10     sync::{mpsc, Arc},
11     thread,
12 };
13
14 pub const PROTO_ID: u32 = 0x4f457403;
15 pub const UDP_PKT_SIZE: usize = 512;
16 pub const NUM_CHANNELS: usize = 3;
17 pub const REL_BUFFER: usize = 0x8000;
18 pub const INIT_SEQNUM: u16 = 65500;
19
20 #[derive(Debug, Copy, Clone, PartialEq)]
21 pub enum PeerID {
22     Nil = 0,
23     Srv,
24     CltMin,
25 }
26
27 #[derive(Debug, Copy, Clone, PartialEq, TryFromPrimitive)]
28 #[repr(u8)]
29 pub enum PktType {
30     Ctl = 0,
31     Orig,
32     Split,
33     Rel,
34 }
35
36 #[derive(Debug)]
37 pub struct Pkt<T> {
38     unrel: bool,
39     chan: u8,
40     data: T,
41 }
42
43 #[derive(Debug)]
44 pub enum Error {
45     IoError(io::Error),
46     InvalidProtoId(u32),
47     InvalidPeerID,
48     InvalidChannel(u8),
49     InvalidType(u8),
50     LocalHangup,
51 }
52
53 impl From<io::Error> for Error {
54     fn from(err: io::Error) -> Self {
55         Self::IoError(err)
56     }
57 }
58
59 impl From<TryFromPrimitiveError<PktType>> for Error {
60     fn from(err: TryFromPrimitiveError<PktType>) -> Self {
61         Self::InvalidType(err.number)
62     }
63 }
64
65 impl From<mpsc::SendError<PktResult>> for Error {
66     fn from(_err: mpsc::SendError<PktResult>) -> Self {
67         Self::LocalHangup
68     }
69 }
70
71 impl fmt::Display for Error {
72     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73         use Error::*;
74         write!(f, "RUDP Error: ")?;
75
76         match self {
77             IoError(err) => write!(f, "IO Error: {}", err),
78             InvalidProtoId(id) => write!(f, "Invalid Protocol ID: {id}"),
79             InvalidPeerID => write!(f, "Invalid Peer ID"),
80             InvalidChannel(ch) => write!(f, "Invalid Channel: {ch}"),
81             InvalidType(tp) => write!(f, "Invalid Type: {tp}"),
82             LocalHangup => write!(f, "Local packet receiver hung up"),
83         }
84     }
85 }
86
87 #[derive(Debug)]
88 struct Channel {
89     num: u8,
90 }
91
92 type RelPkt = Cell<Option<Vec<u8>>>;
93
94 struct RecvChannel<'a> {
95     packets: Vec<RelPkt>, // used to be called char **
96     seqnum: u16,
97     main: &'a Channel,
98 }
99
100 pub type PktResult = Result<Pkt<Vec<u8>>, Error>;
101 type PktSender = mpsc::Sender<PktResult>;
102
103 trait HandleError {
104     fn handle(&self, res: Result<(), Error>) -> bool;
105 }
106
107 impl HandleError for PktSender {
108     fn handle(&self, res: Result<(), Error>) -> bool {
109         if let Err(err) = res {
110             if !self.send(Err(err)).is_ok() {
111                 return false;
112             }
113         }
114
115         true
116     }
117 }
118
119 fn to_seqnum(seqnum: u16) -> usize {
120     (seqnum as usize) & (REL_BUFFER - 1)
121 }
122
123 #[derive(Debug)]
124 struct ConnInner {
125     sock: net::UdpSocket,
126     id: u16,
127     remote_id: u16,
128     chans: Vec<Channel>,
129 }
130
131 impl ConnInner {
132     pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
133         let mut buf = Vec::with_capacity(4 + 2 + 1 + 1 + pkt.data.len());
134         buf.write_u32::<BigEndian>(PROTO_ID)?;
135         buf.write_u16::<BigEndian>(self.remote_id)?;
136         buf.write_u8(pkt.chan as u8)?;
137         buf.write_u8(PktType::Orig as u8)?;
138         buf.write(pkt.data)?;
139
140         self.sock.send(&buf)?;
141
142         Ok(())
143     }
144
145     fn recv_loop(&self, tx: PktSender) {
146         let mut inbox = [0; UDP_PKT_SIZE];
147
148         let mut recv_chans = self
149             .chans
150             .iter()
151             .map(|main| RecvChannel {
152                 main,
153                 packets: (0..REL_BUFFER).map(|_| Cell::new(None)).collect(),
154                 seqnum: INIT_SEQNUM,
155             })
156             .collect();
157
158         while tx.handle(self.recv_pkt(&mut inbox, &mut recv_chans, &tx)) {}
159     }
160
161     fn recv_pkt(
162         &self,
163         buffer: &mut [u8],
164         chans: &mut Vec<RecvChannel>,
165         tx: &PktSender,
166     ) -> Result<(), Error> {
167         use Error::*;
168
169         // todo: reset timeout
170         let len = self.sock.recv(buffer)?;
171         let mut cursor = io::Cursor::new(&buffer[..len]);
172
173         let proto_id = cursor.read_u32::<BigEndian>()?;
174         if proto_id != PROTO_ID {
175             do yeet InvalidProtoId(proto_id);
176         }
177
178         let peer_id = cursor.read_u16::<BigEndian>()?;
179
180         let n_chan = cursor.read_u8()?;
181         let chan = chans
182             .get_mut(n_chan as usize)
183             .ok_or(InvalidChannel(n_chan))?;
184
185         self.process_pkt(cursor, chan, tx)
186     }
187
188     fn process_pkt(
189         &self,
190         mut cursor: io::Cursor<&[u8]>,
191         chan: &mut RecvChannel,
192         tx: &PktSender,
193     ) -> Result<(), Error> {
194         use PktType::*;
195
196         match cursor.read_u8()?.try_into()? {
197             Ctl => {
198                 dbg!("Ctl");
199                 dbg!(cursor.remaining_slice());
200             }
201             Orig => {
202                 tx.send(Ok(Pkt {
203                     chan: chan.main.num,
204                     unrel: true,
205                     data: cursor.remaining_slice().into(),
206                 }))?;
207             }
208             Split => {
209                 dbg!("Split");
210                 dbg!(cursor.remaining_slice());
211             }
212             Rel => {
213                 let seqnum = cursor.read_u16::<BigEndian>()?;
214                 chan.packets[to_seqnum(seqnum)].set(Some(cursor.remaining_slice().into()));
215
216                 while let Some(pkt) = chan.packets[to_seqnum(chan.seqnum)].take() {
217                     tx.handle(self.process_pkt(io::Cursor::new(&pkt), chan, tx));
218                     chan.seqnum = chan.seqnum.overflowing_add(1).0;
219                 }
220             }
221         }
222
223         Ok(())
224     }
225 }
226
227 #[derive(Debug)]
228 pub struct Conn {
229     inner: Arc<ConnInner>,
230     rx: mpsc::Receiver<PktResult>,
231 }
232
233 impl Conn {
234     pub fn connect(addr: &str) -> io::Result<Self> {
235         let (tx, rx) = mpsc::channel();
236
237         let inner = Arc::new(ConnInner {
238             sock: net::UdpSocket::bind("0.0.0.0:0")?,
239             id: PeerID::Srv as u16,
240             remote_id: PeerID::Nil as u16,
241             chans: (0..NUM_CHANNELS as u8).map(|num| Channel { num }).collect(),
242         });
243
244         inner.sock.connect(addr)?;
245
246         let recv_inner = Arc::clone(&inner);
247         thread::spawn(move || {
248             recv_inner.recv_loop(tx);
249         });
250
251         Ok(Conn { inner, rx })
252     }
253
254     pub fn send(&self, pkt: Pkt<&[u8]>) -> io::Result<()> {
255         self.inner.send(pkt)
256     }
257
258     pub fn recv(&self) -> Result<PktResult, mpsc::RecvError> {
259         self.rx.recv()
260     }
261 }
262
263 fn main() {
264     //println!("{}", x.deep_size_of());
265     let conn = Conn::connect("127.0.0.1:30000").expect("the spanish inquisition");
266
267     let mut mtpkt = vec![];
268     mtpkt.write_u16::<BigEndian>(2).unwrap(); // high level type
269     mtpkt.write_u8(29).unwrap(); // serialize ver
270     mtpkt.write_u16::<BigEndian>(0).unwrap(); // compression modes
271     mtpkt.write_u16::<BigEndian>(40).unwrap(); // MinProtoVer
272     mtpkt.write_u16::<BigEndian>(40).unwrap(); // MaxProtoVer
273     mtpkt.write_u16::<BigEndian>(3).unwrap(); // player name length
274     mtpkt.write(b"foo").unwrap(); // player name
275
276     conn.send(Pkt {
277         unrel: true,
278         chan: 1,
279         data: &mtpkt,
280     })
281     .unwrap();
282
283     while let Ok(result) = conn.recv() {
284         match result {
285             Ok(pkt) => {
286                 io::stdout().write(pkt.data.as_slice()).unwrap();
287             }
288             Err(err) => eprintln!("Error: {}", err),
289         }
290     }
291 }