]> git.lizzy.rs Git - connect-rs.git/blob - src/protocol.rs
don't let users construct datagrams with a message larger than 100MB
[connect-rs.git] / src / protocol.rs
1 use std::array::TryFromSliceError;
2 use std::convert::TryInto;
3 use std::error::Error;
4
5 const VERSION: u16 = 1;
6
7 /// Encountered when there is an issue constructing, serializing, or deserializing a [`ConnectDatagram`].
8 ///
9 #[derive(Debug, Clone)]
10 pub enum DatagramError {
11     /// Tried to construct a [`ConnectDatagram`] with an empty message body.
12     EmptyMessage,
13
14     /// Tried to construct a [`ConnectDatagram`] with a message body larger than 100MB.
15     TooLargeMessage,
16
17     /// Did not provide the complete byte-string necessary to deserialize the [`ConnectDatagram`].
18     IncompleteBytes,
19
20     /// Wraps a [`TryFromSliceError`] encountered when the version or recipient tags cannot be
21     /// parsed from the provided bytes.
22     BytesParseFail(TryFromSliceError),
23 }
24
25 impl Error for DatagramError {}
26
27 impl std::fmt::Display for DatagramError {
28     fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
29         match self {
30             DatagramError::EmptyMessage => formatter.write_str("tried to construct a `ConnectDatagram` with an empty message body"),
31             DatagramError::TooLargeMessage => formatter.write_str("tried to construct a `ConnectDatagram` with a message body larger than 100MB"),
32             DatagramError::IncompleteBytes => formatter.write_str("did not provide the complete byte-string necessary to deserialize the `ConnectDatagram`"),
33             DatagramError::BytesParseFail(err) => std::fmt::Display::fmt(err, formatter),
34         }
35     }
36 }
37
38 /// A simple size-prefixed packet format containing a version tag, recipient tag, and message body.
39 ///
40 /// The version tag is decided by the library version and used to maintain backwards
41 /// compatibility with previous datagram formats.
42 ///
43 #[derive(Clone)]
44 pub struct ConnectDatagram {
45     version: u16,
46     recipient: u16,
47     data: Option<Vec<u8>>,
48 }
49
50 impl ConnectDatagram {
51     /// Creates a new [`ConnectDatagram`] based on an intended recipient and message body.
52     ///
53     /// The version tag is decided by the library version and used to maintain backwards
54     /// compatibility with previous datagram formats.
55     ///
56     /// This will return a [EmptyMessage](`DatagramError::EmptyMessage`) error if the `data`
57     /// parameter contains no bytes, or in other words, when there is no message body.
58     ///
59     /// This will return a [TooLargeMessage](`DatagramError::TooLargeMessage`) error if the `data`
60     /// parameter contains a buffer size greater than 100,000,000 (bytes), or 100MB.
61     ///
62     pub fn new(recipient: u16, data: Vec<u8>) -> Result<Self, DatagramError> {
63         if data.len() > 100_000_000 {
64             Err(DatagramError::TooLargeMessage)
65         } else if data.len() > 0 {
66             Ok(Self {
67                 version: VERSION,
68                 recipient,
69                 data: Some(data),
70             })
71         } else {
72             Err(DatagramError::EmptyMessage)
73         }
74     }
75
76     /// Gets the version number of the datagram.
77     ///
78     pub fn version(&self) -> u16 {
79         self.version
80     }
81
82     /// Gets the recipient of the datagram.
83     ///
84     pub fn recipient(&self) -> u16 {
85         self.recipient
86     }
87
88     /// Gets the message body of the datagram.
89     ///
90     pub fn data(&self) -> Option<&Vec<u8>> {
91         self.data.as_ref()
92     }
93
94     /// Takes ownership of the message body of the datagram.
95     ///
96     pub fn take_data(&mut self) -> Option<Vec<u8>> {
97         self.data.take()
98     }
99
100     /// Calculates the size-prefixed serialized byte-size of the datagram.
101     ///
102     /// This will include the byte-size of the size-prefix.
103     ///
104     pub fn size(&self) -> usize {
105         let data_len = if let Some(data) = self.data() {
106             data.len()
107         } else {
108             0
109         };
110
111         8 + data_len
112     }
113
114     /// Constructs a serialized representation of the datagram contents.
115     ///
116     pub(crate) fn bytes(&self) -> Vec<u8> {
117         let mut bytes = Vec::with_capacity(self.size());
118
119         bytes.extend(&self.version.to_be_bytes());
120         bytes.extend(&self.recipient.to_be_bytes());
121
122         if let Some(data) = self.data() {
123             bytes.extend(data.as_slice());
124         }
125
126         return bytes;
127     }
128
129     /// Serializes the datagram.
130     ///
131     pub fn encode(self) -> Vec<u8> {
132         let content_encoded = self.bytes();
133         let size: u32 = (content_encoded.len()) as u32;
134
135         let mut bytes = Vec::from(size.to_be_bytes());
136         bytes.extend(content_encoded);
137
138         return bytes;
139     }
140
141     /// Deserializes the datagram from a buffer.
142     ///
143     /// The buffer **should not** contain the size-prefix, and only contain the byte contents of the
144     /// struct (version, recipient, and message body).
145     ///
146     pub fn decode(mut buffer: Vec<u8>) -> Result<Self, DatagramError> {
147         if buffer.len() > 4 {
148             let mem_size = std::mem::size_of::<u16>();
149             let data = buffer.split_off(mem_size * 2);
150
151             let (version_bytes, recipient_bytes) = buffer.split_at(mem_size);
152
153             match version_bytes.try_into() {
154                 Ok(version_slice) => match recipient_bytes.try_into() {
155                     Ok(recipient_slice) => {
156                         let version = u16::from_be_bytes(version_slice);
157                         let recipient = u16::from_be_bytes(recipient_slice);
158
159                         Ok(Self {
160                             version,
161                             recipient,
162                             data: Some(data),
163                         })
164                     }
165
166                     Err(err) => Err(DatagramError::BytesParseFail(err)),
167                 },
168
169                 Err(err) => Err(DatagramError::BytesParseFail(err)),
170             }
171         } else {
172             Err(DatagramError::IncompleteBytes)
173         }
174     }
175 }
176
177 #[cfg(test)]
178 mod tests {
179     use crate::protocol::ConnectDatagram;
180
181     #[test]
182     fn serialized_size() -> anyhow::Result<()> {
183         let mut data = Vec::new();
184         for _ in 0..5 {
185             data.push(1);
186         }
187         assert_eq!(5, data.len());
188
189         let sample = ConnectDatagram::new(1, data)?;
190         assert_eq!(8 + 5, sample.encode().len());
191
192         Ok(())
193     }
194
195     #[test]
196     fn take_data() -> anyhow::Result<()> {
197         let mut data = Vec::new();
198         for _ in 0..5 {
199             data.push(1);
200         }
201
202         let mut sample = ConnectDatagram::new(1, data)?;
203
204         let taken_data = sample.take_data().unwrap();
205         assert!(sample.data().is_none());
206         assert_eq!(5, taken_data.len());
207
208         Ok(())
209     }
210
211     #[async_std::test]
212     async fn encode_and_decode() -> anyhow::Result<()> {
213         let mut data = Vec::new();
214         for _ in 0..5 {
215             data.push(1);
216         }
217         assert_eq!(5, data.len());
218
219         let sample = ConnectDatagram::new(1, data)?;
220         let serialized_size = sample.size();
221         assert_eq!(8 + 5, serialized_size);
222
223         let mut payload = sample.encode();
224         assert_eq!(serialized_size, payload.len());
225
226         let payload = payload.split_off(std::mem::size_of::<u32>());
227         let sample_back_res = ConnectDatagram::decode(payload);
228         assert!(sample_back_res.is_ok());
229
230         let sample_back = sample_back_res.unwrap();
231         assert_eq!(sample_back.version(), 1);
232         assert_eq!(sample_back.recipient(), 1);
233         assert_eq!(sample_back.data().unwrap().len(), 5);
234
235         Ok(())
236     }
237 }