]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/protocol.rs
refactor read/write for correctness and ordering of messages
[connect-rs.git] / src / protocol.rs
index 73da4bc46b1bab9ebeddfecca81e8e663fb01631..a2fe0b0c851fcad67dd364e29341ac8962efd1ed 100644 (file)
@@ -1,29 +1,39 @@
+use std::array::TryFromSliceError;
+use std::convert::TryInto;
 use std::error::Error;
-use std::io::Read;
 
-const VERSION: u8 = 1;
+const VERSION: u16 = 1;
 
-/// Encountered when trying to construct a [`ConnectDatagram`] with an empty message body.
+/// Encountered when there is an issue constructing, serializing, or deserializing a [`ConnectDatagram`].
 ///
 #[derive(Debug, Clone)]
-pub struct DatagramEmptyError;
+pub enum DatagramError {
+    /// Tried to construct a [`ConnectDatagram`] with an empty message body.
+    EmptyBody,
 
-impl Error for DatagramEmptyError {}
+    /// Did not provide the complete byte-string necessary to deserialize the [`ConnectDatagram`].
+    IncompleteBytes,
 
-impl std::fmt::Display for DatagramEmptyError {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        write!(
-            f,
-            "datagram cannot be constructed when provided payload is empty"
-        )
+    BytesParseFail(TryFromSliceError),
+}
+
+impl Error for DatagramError {}
+
+impl std::fmt::Display for DatagramError {
+    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            DatagramError::EmptyBody => formatter.write_str("tried to construct a `ConnectDatagram` with an empty message body"),
+            DatagramError::IncompleteBytes => formatter.write_str("did not provide the complete byte-string necessary to deserialize the `ConnectDatagram`"),
+            DatagramError::BytesParseFail(err) => std::fmt::Display::fmt(err, formatter),
+        }
     }
 }
 
-/// A simple packet format containing a version, recipient tag, and message body.
+/// A simple size-prefixed packet format containing a version tag, recipient tag, and message body.
 ///
 #[derive(Clone)]
 pub struct ConnectDatagram {
-    version: u8,
+    version: u16,
     recipient: u16,
     data: Option<Vec<u8>>,
 }
@@ -31,13 +41,13 @@ pub struct ConnectDatagram {
 impl ConnectDatagram {
     /// Creates a new [`ConnectDatagram`] based on an intended recipient and message body.
     ///
-    /// This will return a [`DatagramEmptyError`] if the `data` parameter contains no bytes, or
-    /// in other words, when there is no message body.
+    /// This will return a [EmptyBody](`DatagramError::EmptyBody`) error if the `data` parameter
+    /// contains no bytes, or in other words, when there is no message body.
     ///
     /// The version field is decided by the library version and used to maintain backwards
     /// compatibility with previous datagram formats.
     ///
-    pub fn new(recipient: u16, data: Vec<u8>) -> Result<Self, DatagramEmptyError> {
+    pub fn new(recipient: u16, data: Vec<u8>) -> Result<Self, DatagramError> {
         if data.len() > 0 {
             Ok(Self {
                 version: VERSION,
@@ -45,13 +55,13 @@ impl ConnectDatagram {
                 data: Some(data),
             })
         } else {
-            Err(DatagramEmptyError)
+            Err(DatagramError::EmptyBody)
         }
     }
 
     /// Gets the version number of the datagram.
     ///
-    pub fn version(&self) -> u8 {
+    pub fn version(&self) -> u16 {
         self.version
     }
 
@@ -73,7 +83,9 @@ impl ConnectDatagram {
         self.data.take()
     }
 
-    /// Calculates the serialized byte-size of the datagram.
+    /// Calculates the size-prefixed serialized byte-size of the datagram.
+    ///
+    /// This will include the byte-size of the size-prefix.
     ///
     pub fn size(&self) -> usize {
         let data_len = if let Some(data) = self.data() {
@@ -82,7 +94,7 @@ impl ConnectDatagram {
             0
         };
 
-        3 + data_len
+        8 + data_len
     }
 
     /// Constructs a serialized representation of the datagram contents.
@@ -103,48 +115,47 @@ impl ConnectDatagram {
     /// Serializes the datagram.
     ///
     pub fn encode(self) -> Vec<u8> {
-        let size: u32 = (self.size()) as u32;
+        let content_encoded = self.bytes();
+        let size: u32 = (content_encoded.len()) as u32;
 
         let mut bytes = Vec::from(size.to_be_bytes());
-        bytes.extend(self.bytes());
+        bytes.extend(content_encoded);
 
         return bytes;
     }
 
-    /// Deserializes the datagram from a `source`.
+    /// Deserializes the datagram from a buffer.
     ///
-    pub fn decode(source: &mut (dyn Read + Send + Sync)) -> anyhow::Result<Self> {
-        // payload size
-        let mut payload_size_bytes: [u8; 4] = [0; 4];
-        source.read_exact(&mut payload_size_bytes)?;
-        let payload_size = u32::from_be_bytes(payload_size_bytes);
-
-        // read whole payload
-        let mut payload_bytes = vec![0; payload_size as usize];
-        source.read_exact(payload_bytes.as_mut_slice())?;
-
-        // version
-        let version_bytes = payload_bytes.remove(0);
-        let version = u8::from_be(version_bytes);
-
-        // recipient
-        let mut recipient_bytes: [u8; 2] = [0; 2];
-        for i in 0..recipient_bytes.len() {
-            recipient_bytes[i] = payload_bytes.remove(0);
-        }
-        let recipient = u16::from_be_bytes(recipient_bytes);
-
-        // data
-        let data = payload_bytes;
-
-        if data.len() > 0 {
-            Ok(Self {
-                version,
-                recipient,
-                data: Some(data),
-            })
+    /// The buffer **should not** contain the size-prefix, and only contain the byte contents of the
+    /// struct (version, recipient, and message body).
+    ///
+    pub fn decode(mut buffer: Vec<u8>) -> Result<Self, DatagramError> {
+        if buffer.len() > 4 {
+            let mem_size = std::mem::size_of::<u16>();
+            let data = buffer.split_off(mem_size * 2);
+
+            let (version_bytes, recipient_bytes) = buffer.split_at(mem_size);
+
+            match version_bytes.try_into() {
+                Ok(version_slice) => match recipient_bytes.try_into() {
+                    Ok(recipient_slice) => {
+                        let version = u16::from_be_bytes(version_slice);
+                        let recipient = u16::from_be_bytes(recipient_slice);
+
+                        Ok(Self {
+                            version,
+                            recipient,
+                            data: Some(data),
+                        })
+                    }
+
+                    Err(err) => Err(DatagramError::BytesParseFail(err)),
+                },
+
+                Err(err) => Err(DatagramError::BytesParseFail(err)),
+            }
         } else {
-            Err(anyhow::Error::from(DatagramEmptyError))
+            Err(DatagramError::IncompleteBytes)
         }
     }
 }
@@ -152,10 +163,9 @@ impl ConnectDatagram {
 #[cfg(test)]
 mod tests {
     use crate::protocol::ConnectDatagram;
-    use std::io::Cursor;
 
     #[test]
-    fn encoded_size() -> anyhow::Result<()> {
+    fn serialized_size() -> anyhow::Result<()> {
         let mut data = Vec::new();
         for _ in 0..5 {
             data.push(1);
@@ -163,7 +173,7 @@ mod tests {
         assert_eq!(5, data.len());
 
         let sample = ConnectDatagram::new(1, data)?;
-        assert_eq!(7 + 5, sample.encode().len());
+        assert_eq!(8 + 5, sample.encode().len());
 
         Ok(())
     }
@@ -193,12 +203,14 @@ mod tests {
         assert_eq!(5, data.len());
 
         let sample = ConnectDatagram::new(1, data)?;
+        let serialized_size = sample.size();
+        assert_eq!(8 + 5, serialized_size);
 
         let mut payload = sample.encode();
-        assert_eq!(7 + 5, payload.len());
+        assert_eq!(serialized_size, payload.len());
 
-        let mut cursor: Cursor<&mut [u8]> = Cursor::new(payload.as_mut());
-        let sample_back_res = ConnectDatagram::decode(&mut cursor);
+        let payload = payload.split_off(std::mem::size_of::<u32>());
+        let sample_back_res = ConnectDatagram::decode(payload);
         assert!(sample_back_res.is_ok());
 
         let sample_back = sample_back_res.unwrap();