]> git.lizzy.rs Git - mt_ser.git/blobdiff - src/lib.rs
Support Ranges
[mt_ser.git] / src / lib.rs
index f57e487eaa8c635e4eddee7196fe4fe85613fd20..b2aebc87e8b1e3bb634b6cdcdf5b51d460679bc0 100644 (file)
@@ -5,6 +5,7 @@
 pub use flate2;
 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
 pub use paste;
+pub use zstd;
 
 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
 use enumset::{EnumSet, EnumSetTypeWithRepr};
@@ -12,21 +13,26 @@ use paste::paste as paste_macro;
 use std::{
     collections::{HashMap, HashSet},
     convert::Infallible,
+    fmt::Debug,
     io::{self, Read, Write},
     num::TryFromIntError,
-    ops::Deref,
+    ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
 };
 use thiserror::Error;
 
 #[cfg(test)]
 mod tests;
 
+use crate as mt_ser;
+
 #[derive(Error, Debug)]
 pub enum SerializeError {
     #[error("io error: {0}")]
     IoError(#[from] io::Error),
     #[error("collection too big: {0}")]
     TooBig(#[from] TryFromIntError),
+    #[error("{0}")]
+    Other(String),
 }
 
 impl From<Infallible> for SerializeError {
@@ -45,10 +51,12 @@ pub enum DeserializeError {
     TooBig(#[from] TryFromIntError),
     #[error("invalid UTF-16: {0}")]
     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
-    #[error("invalid {0} enum variant {1}")]
-    InvalidEnumVariant(&'static str, u64),
-    #[error("invalid constant - wanted: {0} - got: {1}")]
-    InvalidConst(u64, u64),
+    #[error("invalid {0} enum variant {1:?}")]
+    InvalidEnum(&'static str, Box<dyn Debug>),
+    #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
+    InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
+    #[error("{0}")]
+    Other(String),
 }
 
 impl From<Infallible> for DeserializeError {
@@ -135,6 +143,7 @@ pub trait MtLen {
 
 pub trait MtCfg {
     type Len: MtLen;
+    type Inner: MtCfg;
 
     fn utf16() -> bool {
         false
@@ -179,6 +188,7 @@ trait MtCfgLen:
 
 impl<T: MtCfgLen> MtCfg for T {
     type Len = usize;
+    type Inner = DefCfg;
 
     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
         Self::try_from(len)
@@ -187,9 +197,9 @@ impl<T: MtCfgLen> MtCfg for T {
     }
 
     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
-        Ok(Self::mt_deserialize::<DefCfg>(reader)?
+        Self::mt_deserialize::<DefCfg>(reader)?
             .try_into()
-            .map_err(Into::into)?)
+            .map_err(Into::into)
     }
 }
 
@@ -202,6 +212,7 @@ pub type DefCfg = u16;
 
 impl MtCfg for () {
     type Len = ();
+    type Inner = DefCfg;
 
     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
         Ok(())
@@ -228,10 +239,11 @@ impl MtLen for () {
     }
 }
 
-pub struct Utf16<B: MtCfg>(pub B);
+pub struct Utf16<B: MtCfg = DefCfg>(pub B);
 
 impl<B: MtCfg> MtCfg for Utf16<B> {
     type Len = B::Len;
+    type Inner = B::Inner;
 
     fn utf16() -> bool {
         true
@@ -246,6 +258,19 @@ impl<B: MtCfg> MtCfg for Utf16<B> {
     }
 }
 
+impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
+    type Len = A::Len;
+    type Inner = B;
+
+    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
+        A::write_len(len, writer)
+    }
+
+    fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
+        A::read_len(reader)
+    }
+}
+
 impl MtSerialize for u8 {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
         writer.write_u8(*self)?;
@@ -344,25 +369,25 @@ pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
     C::write_len(iter.len(), writer)?;
 
     iter.into_iter()
-        .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
+        .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
 }
 
-pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
-    reader: &'a mut impl Read,
-) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
+pub fn mt_deserialize_seq<C: MtCfg, T: MtDeserialize>(
+    reader: &mut impl Read,
+) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + '_, DeserializeError> {
     let len = C::read_len(reader)?;
-    mt_deserialize_sized_seq(&len, reader)
+    mt_deserialize_sized_seq::<C, _>(&len, reader)
 }
 
-pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
-    len: &L,
+pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
+    len: &C::Len,
     reader: &'a mut impl Read,
 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
     let variable = len.option().is_none();
 
     Ok(len
         .range()
-        .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
+        .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
             Err(DeserializeError::UnexpectedEof) if variable => None,
             x => Some(x),
         }))
@@ -436,8 +461,8 @@ impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSe
 // TODO: support more tuples
 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        self.0.mt_serialize::<DefCfg>(writer)?;
-        self.1.mt_serialize::<DefCfg>(writer)?;
+        self.0.mt_serialize::<C>(writer)?;
+        self.1.mt_serialize::<C::Inner>(writer)?;
 
         Ok(())
     }
@@ -445,8 +470,8 @@ impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
 
 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
-        let a = A::mt_deserialize::<DefCfg>(reader)?;
-        let b = B::mt_deserialize::<DefCfg>(reader)?;
+        let a = A::mt_deserialize::<C>(reader)?;
+        let b = B::mt_deserialize::<C::Inner>(reader)?;
 
         Ok((a, b))
     }
@@ -472,7 +497,7 @@ where
     }
 }
 
-impl MtSerialize for String {
+impl MtSerialize for &str {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
         if C::utf16() {
             self.encode_utf16()
@@ -484,6 +509,12 @@ impl MtSerialize for String {
     }
 }
 
+impl MtSerialize for String {
+    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
+        self.as_str().mt_serialize::<C>(writer)
+    }
+}
+
 impl MtDeserialize for String {
     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
         if C::utf16() {
@@ -530,3 +561,56 @@ impl<T: MtDeserialize> MtDeserialize for Box<T> {
         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
     }
 }
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "Range")]
+#[allow(unused)]
+struct RemoteRange<T> {
+    start: T,
+    end: T,
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeFrom")]
+#[allow(unused)]
+struct RemoteRangeFrom<T> {
+    start: T,
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeFull")]
+#[allow(unused)]
+struct RemoteRangeFull;
+
+// RangeInclusive fields are private
+impl<T: MtSerialize> MtSerialize for RangeInclusive<T> {
+    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
+        self.start().mt_serialize::<DefCfg>(writer)?;
+        self.end().mt_serialize::<DefCfg>(writer)?;
+
+        Ok(())
+    }
+}
+
+impl<T: MtDeserialize> MtDeserialize for RangeInclusive<T> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        let start = T::mt_deserialize::<DefCfg>(reader)?;
+        let end = T::mt_deserialize::<DefCfg>(reader)?;
+
+        Ok(start..=end)
+    }
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeTo")]
+#[allow(unused)]
+struct RemoteRangeTo<T> {
+    end: T,
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeToInclusive")]
+#[allow(unused)]
+struct RemoteRangeToInclusive<T> {
+    end: T,
+}