]> git.lizzy.rs Git - mt_ser.git/commitdiff
Implement deserialize for basic types
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Thu, 9 Feb 2023 20:22:01 +0000 (21:22 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Fri, 10 Feb 2023 02:18:22 +0000 (03:18 +0100)
Cargo.toml
derive/src/lib.rs
src/lib.rs
src/to_clt/hud.rs

index c735152ce5c8abd4ab7df0739573a66020a75da6..645100867b82c9dd66205607eabb02678deb9934 100644 (file)
@@ -4,6 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [features]
+all = ["client", "server", "random", "serde"]
 client = []
 random = ["dep:generate-random", "dep:rand"]
 serde = ["dep:serde", "dep:serde_arrays", "enumset/serde"]
index 9838ea17989d685cbcd316a3710e74d5becc15cb..75bc693b5515a3198219749ef7b2611dac046500 100644 (file)
@@ -305,9 +305,8 @@ fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> To
                        impl_size!(size32, u32);
                        impl_size!(size64, u64);
 
-
                        code
-               },
+               }
                Err(e) => return e.write_errors(),
        }
 }
@@ -315,10 +314,12 @@ fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> To
 fn serialize_fields(fields: &Fields) -> TokStr {
        fields
                .iter()
-               .map(|(ident, field)| serialize_args(MtArgs::from_field(field), |args| {
-                       let cfg = get_cfg(args);
-                       quote! { mt_data::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; }
-               }))
+               .map(|(ident, field)| {
+                       serialize_args(MtArgs::from_field(field), |args| {
+                               let cfg = get_cfg(args);
+                               quote! { mt_data::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; }
+                       })
+               })
                .collect()
 }
 
@@ -343,7 +344,9 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
                                        let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
 
                                        let ident_fn = match &v.fields {
-                                               syn::Fields::Unnamed(_) => |f| quote! { mt_data::paste! { [<field_ #f>] }},
+                                               syn::Fields::Unnamed(_) => |f| quote! {
+                                                       mt_data::paste::paste! { [<field_ #f>] }
+                                               },
                                                _ => |f| quote! { #f },
                                        };
 
index 99bd58a0c39190b13f3f34da49a4ba7a4b50d9c7..c3e5cee2dd9b4c67642a663968539596bca918c0 100644 (file)
@@ -1,6 +1,10 @@
+#![feature(array_try_from_fn)]
+#![feature(associated_type_bounds)]
+#![feature(iterator_try_collect)]
+
 pub use enumset;
 pub use flate2;
-pub use paste::paste;
+pub use paste;
 
 #[cfg(feature = "random")]
 pub use generate_random;
@@ -15,6 +19,7 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
 use enumset::{EnumSet, EnumSetType, EnumSetTypeWithRepr};
 use mt_data_derive::mt_derive;
 pub use mt_data_derive::{MtDeserialize, MtSerialize};
+use paste::paste as paste_macro;
 use std::{
     collections::{HashMap, HashSet},
     convert::Infallible,
@@ -31,23 +36,18 @@ use serde::{Deserialize, Serialize};
 #[cfg(feature = "random")]
 use generate_random::GenerateRandom;
 
+#[cfg(any(feature = "client", feature = "server"))]
 use crate as mt_data;
 
-#[derive(Error, Debug)]
-#[error("variable length")]
-pub struct VarLen;
-
 #[derive(Error, Debug)]
 pub enum SerializeError {
     #[error("io error: {0}")]
     IoError(#[from] io::Error),
     #[error("collection too big: {0}")]
     TooBig(#[from] TryFromIntError),
-    #[error("unimplemented")]
-    Unimplemented,
 }
 
-impl From<Infallible> for DeserializeError {
+impl From<Infallible> for SerializeError {
     fn from(_err: Infallible) -> Self {
         unreachable!("infallible")
     }
@@ -56,40 +56,88 @@ impl From<Infallible> for DeserializeError {
 #[derive(Error, Debug)]
 pub enum DeserializeError {
     #[error("io error: {0}")]
-    IoError(#[from] io::Error),
-    #[error("variable length not supported")]
-    NoVarlen(#[from] VarLen),
+    IoError(io::Error),
+    #[error("unexpected end of file")]
+    UnexpectedEof,
     #[error("collection too big: {0}")]
     TooBig(#[from] TryFromIntError),
+    #[error("invalid UTF-16: {0}")]
+    InvalidUtf16(#[from] std::char::DecodeUtf16Error),
     #[error("unimplemented")]
     Unimplemented,
 }
 
-impl From<Infallible> for SerializeError {
+impl From<Infallible> for DeserializeError {
     fn from(_err: Infallible) -> Self {
         unreachable!("infallible")
     }
 }
 
+impl From<io::Error> for DeserializeError {
+    fn from(err: io::Error) -> Self {
+        if err.kind() == io::ErrorKind::UnexpectedEof {
+            DeserializeError::UnexpectedEof
+        } else {
+            DeserializeError::IoError(err)
+        }
+    }
+}
+
+pub trait OrDefault<T> {
+    fn or_default(self) -> Self;
+}
+
+impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
+    fn or_default(self) -> Self {
+        match self {
+            Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
+            x => x,
+        }
+    }
+}
+
 pub trait MtCfg:
+    Sized + MtSerialize + MtDeserialize + TryFrom<usize, Error: Into<SerializeError>>
+{
+    type Range: Iterator<Item = usize> + 'static;
+
+    fn utf16() -> bool {
+        false
+    }
+
+    fn var_len() -> bool;
+
+    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
+        Self::try_from(len)
+            .map_err(Into::into)?
+            .mt_serialize::<DefCfg>(writer)
+    }
+
+    fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError>;
+}
+
+trait MtCfgLen:
     Sized
     + MtSerialize
     + MtDeserialize
-    + TryFrom<usize, Error = Self::TryFromError>
-    + TryInto<usize, Error = Self::TryIntoError>
+    + TryFrom<usize, Error: Into<SerializeError>>
+    + TryInto<usize, Error: Into<DeserializeError>>
 {
-    type TryFromError: Into<SerializeError>;
-    type TryIntoError: Into<DeserializeError>;
+}
 
-    #[inline]
-    fn utf16() -> bool {
+impl<T: MtCfgLen> MtCfg for T {
+    type Range = std::ops::Range<usize>;
+
+    fn var_len() -> bool {
         false
     }
 
-    fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
-        Ok(Self::try_from(len)
-            .map_err(|e| e.into())?
-            .mt_serialize::<DefCfg>(writer)?)
+    fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
+        let len = Self::mt_deserialize::<DefCfg>(reader)?
+            .try_into()
+            .map_err(Into::into)?;
+
+        Ok(0..len)
     }
 }
 
@@ -103,26 +151,12 @@ pub trait MtDeserialize: Sized {
     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
 }
 
-impl MtCfg for u8 {
-    type TryFromError = TryFromIntError;
-    type TryIntoError = Infallible;
-}
-
-impl MtCfg for u16 {
-    type TryFromError = TryFromIntError;
-    type TryIntoError = Infallible;
-}
-
-impl MtCfg for u32 {
-    type TryFromError = TryFromIntError;
-    type TryIntoError = TryFromIntError;
-}
-
-impl MtCfg for u64 {
-    type TryFromError = TryFromIntError;
-    type TryIntoError = TryFromIntError;
-}
+impl MtCfgLen for u8 {}
+impl MtCfgLen for u16 {}
+impl MtCfgLen for u32 {}
+impl MtCfgLen for u64 {}
 
+#[derive(Debug)]
 pub struct NoLen;
 
 impl MtSerialize for NoLen {
@@ -145,17 +179,16 @@ impl TryFrom<usize> for NoLen {
     }
 }
 
-impl TryInto<usize> for NoLen {
-    type Error = VarLen;
-
-    fn try_into(self) -> Result<usize, Self::Error> {
-        Err(VarLen)
+impl MtCfg for NoLen {
+    fn var_len() -> bool {
+        true
     }
-}
 
-impl MtCfg for NoLen {
-    type TryFromError = Infallible;
-    type TryIntoError = VarLen;
+    type Range = std::ops::RangeFrom<usize>;
+
+    fn read_len(_reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
+        Ok(0..)
+    }
 }
 
 pub struct Utf16<B: MtCfg>(pub B);
@@ -173,29 +206,27 @@ impl<B: MtCfg> MtDeserialize for Utf16<B> {
 }
 
 impl<B: MtCfg> TryFrom<usize> for Utf16<B> {
-    type Error = B::TryFromError;
+    type Error = <usize as TryInto<B>>::Error;
 
     fn try_from(x: usize) -> Result<Self, Self::Error> {
         Ok(Self(x.try_into()?))
     }
 }
 
-impl<B: MtCfg> TryInto<usize> for Utf16<B> {
-    type Error = B::TryIntoError;
-
-    fn try_into(self) -> Result<usize, Self::Error> {
-        self.0.try_into()
-    }
-}
-
 impl<B: MtCfg> MtCfg for Utf16<B> {
-    type TryFromError = B::TryFromError;
-    type TryIntoError = B::TryIntoError;
+    type Range = B::Range;
 
-    #[inline]
     fn utf16() -> bool {
         true
     }
+
+    fn var_len() -> bool {
+        B::var_len()
+    }
+
+    fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
+        B::read_len(reader)
+    }
 }
 
 impl MtSerialize for u8 {
@@ -231,7 +262,7 @@ macro_rules! impl_num {
                 &self,
                 writer: &mut impl Write,
             ) -> Result<(), SerializeError> {
-                paste! {
+                paste_macro! {
                     writer.[<write_ $T>]::<BigEndian>(*self)?;
                 }
                 Ok(())
@@ -240,7 +271,7 @@ macro_rules! impl_num {
 
         impl MtDeserialize for $T {
             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
-                paste! {
+                paste_macro! {
                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
                 }
             }
@@ -265,25 +296,60 @@ impl MtSerialize for () {
     }
 }
 
+impl MtDeserialize for () {
+    fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        Ok(())
+    }
+}
+
 impl MtSerialize for bool {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
         (*self as u8).mt_serialize::<DefCfg>(writer)
     }
 }
 
-impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
+impl MtDeserialize for bool {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
+    }
+}
+
+impl<T: MtSerialize> MtSerialize for &T {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        self.as_slice().mt_serialize::<NoLen>(writer)
+        (*self).mt_serialize::<C>(writer)
     }
 }
 
-impl<T: MtSerialize> MtSerialize for &[T] {
+fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
+    writer: &mut impl Write,
+    iter: impl ExactSizeIterator + IntoIterator<Item = T>,
+) -> Result<(), SerializeError> {
+    C::write_len(iter.len(), writer)?;
+
+    iter.into_iter()
+        .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
+}
+
+fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
+    reader: &'a mut impl Read,
+) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
+    Ok(C::read_len(reader)?
+        .into_iter()
+        .map_while(|_| match T::mt_deserialize::<DefCfg>(reader) {
+            Err(DeserializeError::UnexpectedEof) if C::var_len() => None,
+            x => Some(x),
+        }))
+}
+
+impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        C::write_len(self.len(), writer)?;
-        for item in self.iter() {
-            item.mt_serialize::<DefCfg>(writer)?;
-        }
-        Ok(())
+        mt_serialize_seq::<NoLen, _>(writer, self.iter())
+    }
+}
+
+impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
     }
 }
 
@@ -293,6 +359,14 @@ impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E
     }
 }
 
+impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
+            reader,
+        )?))
+    }
+}
+
 impl<T: MtSerialize> MtSerialize for Option<T> {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
         match self {
@@ -302,34 +376,71 @@ impl<T: MtSerialize> MtSerialize for Option<T> {
     }
 }
 
+impl<T: MtDeserialize> MtDeserialize for Option<T> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        T::mt_deserialize::<C>(reader).map(Some).or_default()
+    }
+}
+
 impl<T: MtSerialize> MtSerialize for Vec<T> {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        self.as_slice().mt_serialize::<C>(writer)
+        mt_serialize_seq::<C, _>(writer, self.iter())
+    }
+}
+
+impl<T: MtDeserialize> MtDeserialize for Vec<T> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        mt_deserialize_seq::<C, _>(reader)?.try_collect()
     }
 }
 
-impl<T: MtSerialize> MtSerialize for HashSet<T> {
+impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        C::write_len(self.len(), writer)?;
-        for item in self.iter() {
-            item.mt_serialize::<DefCfg>(writer)?;
-        }
+        mt_serialize_seq::<C, _>(writer, self.iter())
+    }
+}
+
+impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        mt_deserialize_seq::<C, _>(reader)?.try_collect()
+    }
+}
+
+impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
+    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
+        self.0.mt_serialize::<DefCfg>(writer)?;
+        self.1.mt_serialize::<DefCfg>(writer)?;
+
         Ok(())
     }
 }
 
+impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        let a = A::mt_deserialize::<DefCfg>(reader)?;
+        let b = B::mt_deserialize::<DefCfg>(reader)?;
+
+        Ok((a, b))
+    }
+}
+
 impl<K, V> MtSerialize for HashMap<K, V>
 where
     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
     V: MtSerialize,
 {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        C::write_len(self.len(), writer)?;
-        for (key, value) in self.iter() {
-            key.mt_serialize::<DefCfg>(writer)?;
-            value.mt_serialize::<DefCfg>(writer)?;
-        }
-        Ok(())
+        mt_serialize_seq::<C, _>(writer, self.iter())
+    }
+}
+
+impl<K, V> MtDeserialize for HashMap<K, V>
+where
+    K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
+    V: MtDeserialize,
+{
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        mt_deserialize_seq::<C, _>(reader)?.try_collect()
     }
 }
 
@@ -337,10 +448,36 @@ impl MtSerialize for String {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
         if C::utf16() {
             self.encode_utf16()
-                .collect::<Vec<_>>()
+                .collect::<Vec<_>>() // FIXME: is this allocation necessary?
                 .mt_serialize::<C>(writer)
         } else {
-            self.as_bytes().mt_serialize::<C>(writer)
+            mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
+        }
+    }
+}
+
+impl MtDeserialize for String {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        if C::utf16() {
+            let mut err = None;
+
+            let res =
+                char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
+                    Ok(v) => Some(v),
+                    Err(e) => {
+                        err = Some(e);
+                        None
+                    }
+                }))
+                .try_collect();
+
+            match err {
+                None => Ok(res?),
+                Some(e) => Err(e),
+            }
+        } else {
+            // TODO: UTF-8 decode
+            Ok("".into())
         }
     }
 }
@@ -351,6 +488,12 @@ impl<T: MtSerialize> MtSerialize for Box<T> {
     }
 }
 
+impl<T: MtDeserialize> MtDeserialize for Box<T> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        Ok(Self::new(T::mt_deserialize::<C>(reader)?))
+    }
+}
+
 mod to_clt;
 mod to_srv;
 
index bdd7de7653df643a23a8c0fc71ac348042640388..3a29d7b50a2a4c2597124691d67f3f57ed31f069 100644 (file)
@@ -115,18 +115,32 @@ pub struct MinimapModePkt {
     modes: Vec<MinimapMode>,
 }
 
+#[cfg(feature = "server")]
 impl MtSerialize for MinimapModePkt {
     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
-        C::write_len(self.modes.len(), writer)?;
+        DefCfg::write_len(self.modes.len(), writer)?;
         self.current.mt_serialize::<DefCfg>(writer)?;
-        for item in self.modes.iter() {
-            item.mt_serialize::<DefCfg>(writer)?;
-        }
+        self.modes.mt_serialize::<NoLen>(writer)?;
+
         Ok(())
     }
 }
+
+#[cfg(feature = "client")]
+impl MtDeserialize for MinimapModePkt {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        let range = DefCfg::read_len(reader)?;
+        let current = MtDeserialize::mt_deserialize::<DefCfg>(reader)?;
+        let modes = range
+            .map(|_| MtDeserialize::mt_deserialize::<DefCfg>(reader))
+            .try_collect()?;
+
+        Ok(Self { current, modes })
+    }
+}
+
 /*
-TODO: rustify
+TODO: rustify this
 
 var DefaultMinimap = []MinimapMode{
     {Type: NoMinimap},