]> git.lizzy.rs Git - mt_ser.git/blobdiff - src/lib.rs
Implement UTF-8 decode and move packets to different crate
[mt_ser.git] / src / lib.rs
index c3e5cee2dd9b4c67642a663968539596bca918c0..0019d1c0b63e023abc5befa8b21e824a8e972d6b 100644 (file)
@@ -2,42 +2,24 @@
 #![feature(associated_type_bounds)]
 #![feature(iterator_try_collect)]
 
-pub use enumset;
 pub use flate2;
+pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
 pub use paste;
 
-#[cfg(feature = "random")]
-pub use generate_random;
-
-#[cfg(feature = "random")]
-pub use rand;
-
-#[cfg(feature = "serde")]
-pub use serde;
-
 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
-use enumset::{EnumSet, EnumSetType, EnumSetTypeWithRepr};
-use mt_data_derive::mt_derive;
-pub use mt_data_derive::{MtDeserialize, MtSerialize};
+use enumset::{EnumSet, EnumSetTypeWithRepr};
 use paste::paste as paste_macro;
 use std::{
     collections::{HashMap, HashSet},
     convert::Infallible,
-    fmt,
     io::{self, Read, Write},
     num::TryFromIntError,
     ops::Deref,
 };
 use thiserror::Error;
 
-#[cfg(feature = "serde")]
-use serde::{Deserialize, Serialize};
-
-#[cfg(feature = "random")]
-use generate_random::GenerateRandom;
-
-#[cfg(any(feature = "client", feature = "server"))]
-use crate as mt_data;
+#[cfg(test)]
+mod tests;
 
 #[derive(Error, Debug)]
 pub enum SerializeError {
@@ -87,6 +69,49 @@ pub trait OrDefault<T> {
     fn or_default(self) -> Self;
 }
 
+pub struct WrapRead<'a, R: Read>(pub &'a mut R);
+impl<'a, R: Read> Read for WrapRead<'a, R> {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.0.read(buf)
+    }
+
+    fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
+        self.0.read_vectored(bufs)
+    }
+
+    /*
+
+    fn is_read_vectored(&self) -> bool {
+        self.0.is_read_vectored()
+    }
+
+    */
+
+    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
+        self.0.read_to_end(buf)
+    }
+
+    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
+        self.0.read_to_string(buf)
+    }
+
+    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.0.read_exact(buf)
+    }
+
+    /*
+
+    fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
+        self.0.read_buf(buf)
+    }
+
+    fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
+        self.0.read_buf_exact(cursor)
+    }
+
+    */
+}
+
 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
     fn or_default(self) -> Self {
         match self {
@@ -96,24 +121,49 @@ impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
     }
 }
 
-pub trait MtCfg:
-    Sized + MtSerialize + MtDeserialize + TryFrom<usize, Error: Into<SerializeError>>
-{
+pub trait MtLen {
+    fn option(&self) -> Option<usize>;
+
     type Range: Iterator<Item = usize> + 'static;
+    fn range(&self) -> Self::Range;
+
+    type Take<R: Read>: Read;
+    fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
+}
+
+pub trait MtCfg {
+    type Len: MtLen;
 
     fn utf16() -> bool {
         false
     }
 
-    fn var_len() -> bool;
+    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
+    fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
+}
 
-    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
-        Self::try_from(len)
-            .map_err(Into::into)?
-            .mt_serialize::<DefCfg>(writer)
+pub trait MtSerialize {
+    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
+}
+
+pub trait MtDeserialize: Sized {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
+}
+
+impl MtLen for usize {
+    fn option(&self) -> Option<usize> {
+        Some(*self)
     }
 
-    fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError>;
+    type Range = std::ops::Range<usize>;
+    fn range(&self) -> Self::Range {
+        0..*self
+    }
+
+    type Take<R: Read> = io::Take<R>;
+    fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
+        reader.take(*self as u64)
+    }
 }
 
 trait MtCfgLen:
@@ -126,105 +176,70 @@ trait MtCfgLen:
 }
 
 impl<T: MtCfgLen> MtCfg for T {
-    type Range = std::ops::Range<usize>;
+    type Len = usize;
 
-    fn var_len() -> bool {
-        false
+    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
+        Self::try_from(len)
+            .map_err(Into::into)?
+            .mt_serialize::<DefCfg>(writer)
     }
 
-    fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
-        let len = Self::mt_deserialize::<DefCfg>(reader)?
+    fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
+        Ok(Self::mt_deserialize::<DefCfg>(reader)?
             .try_into()
-            .map_err(Into::into)?;
-
-        Ok(0..len)
+            .map_err(Into::into)?)
     }
 }
 
-pub type DefCfg = u16;
-
-pub trait MtSerialize: Sized {
-    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
-}
-
-pub trait MtDeserialize: Sized {
-    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
-}
-
 impl MtCfgLen for u8 {}
 impl MtCfgLen for u16 {}
 impl MtCfgLen for u32 {}
 impl MtCfgLen for u64 {}
 
-#[derive(Debug)]
-pub struct NoLen;
+pub type DefCfg = u16;
 
-impl MtSerialize for NoLen {
-    fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
-        Ok(())
-    }
-}
+impl MtCfg for () {
+    type Len = ();
 
-impl MtDeserialize for NoLen {
-    fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
-        Ok(Self)
+    fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
+        Ok(())
     }
-}
 
-impl TryFrom<usize> for NoLen {
-    type Error = Infallible;
-
-    fn try_from(_x: usize) -> Result<Self, Self::Error> {
-        Ok(Self)
+    fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
+        Ok(())
     }
 }
 
-impl MtCfg for NoLen {
-    fn var_len() -> bool {
-        true
+impl MtLen for () {
+    fn option(&self) -> Option<usize> {
+        None
     }
 
     type Range = std::ops::RangeFrom<usize>;
-
-    fn read_len(_reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
-        Ok(0..)
-    }
-}
-
-pub struct Utf16<B: MtCfg>(pub B);
-
-impl<B: MtCfg> MtSerialize for Utf16<B> {
-    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        self.0.mt_serialize::<DefCfg>(writer)
+    fn range(&self) -> Self::Range {
+        0..
     }
-}
 
-impl<B: MtCfg> MtDeserialize for Utf16<B> {
-    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
-        Ok(Self(B::mt_deserialize::<DefCfg>(reader)?))
+    type Take<R: Read> = R;
+    fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
+        reader
     }
 }
 
-impl<B: MtCfg> TryFrom<usize> for Utf16<B> {
-    type Error = <usize as TryInto<B>>::Error;
-
-    fn try_from(x: usize) -> Result<Self, Self::Error> {
-        Ok(Self(x.try_into()?))
-    }
-}
+pub struct Utf16<B: MtCfg>(pub B);
 
 impl<B: MtCfg> MtCfg for Utf16<B> {
-    type Range = B::Range;
+    type Len = B::Len;
 
     fn utf16() -> bool {
         true
     }
 
-    fn var_len() -> bool {
-        B::var_len()
+    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
+        B::write_len(len, writer)
     }
 
-    fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
+    fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
         B::read_len(reader)
     }
 }
@@ -320,7 +335,7 @@ impl<T: MtSerialize> MtSerialize for &T {
     }
 }
 
-fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
+pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
     writer: &mut impl Write,
     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
 ) -> Result<(), SerializeError> {
@@ -330,20 +345,30 @@ fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
         .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
 }
 
-fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
+pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
+    reader: &'a mut impl Read,
+) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
+    let len = C::read_len(reader)?;
+    mt_deserialize_sized_seq(&len, reader)
+}
+
+pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
+    len: &L,
     reader: &'a mut impl Read,
 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
-    Ok(C::read_len(reader)?
-        .into_iter()
-        .map_while(|_| match T::mt_deserialize::<DefCfg>(reader) {
-            Err(DeserializeError::UnexpectedEof) if C::var_len() => None,
+    let variable = len.option().is_none();
+
+    Ok(len
+        .range()
+        .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
+            Err(DeserializeError::UnexpectedEof) if variable => None,
             x => Some(x),
         }))
 }
 
 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        mt_serialize_seq::<NoLen, _>(writer, self.iter())
+        mt_serialize_seq::<(), _>(writer, self.iter())
     }
 }
 
@@ -406,6 +431,7 @@ impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSe
     }
 }
 
+// TODO: support more tuples
 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
         self.0.mt_serialize::<DefCfg>(writer)?;
@@ -476,8 +502,17 @@ impl MtDeserialize for String {
                 Some(e) => Err(e),
             }
         } else {
-            // TODO: UTF-8 decode
-            Ok("".into())
+            let len = C::read_len(reader)?;
+
+            // use capacity if available
+            let mut st = match len.option() {
+                Some(x) => String::with_capacity(x),
+                None => String::new(),
+            };
+
+            len.take(WrapRead(reader)).read_to_string(&mut st)?;
+
+            Ok(st)
         }
     }
 }
@@ -493,9 +528,3 @@ impl<T: MtDeserialize> MtDeserialize for Box<T> {
         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
     }
 }
-
-mod to_clt;
-mod to_srv;
-
-pub use to_clt::*;
-pub use to_srv::*;