From 800bb04e808aa2881719857e5027d251afc047ac Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Thu, 9 Feb 2023 21:22:01 +0100 Subject: [PATCH] Implement deserialize for basic types --- Cargo.toml | 1 + derive/src/lib.rs | 17 +-- src/lib.rs | 315 +++++++++++++++++++++++++++++++++------------- src/to_clt/hud.rs | 24 +++- 4 files changed, 259 insertions(+), 98 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c735152..6451008 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [features] +all = ["client", "server", "random", "serde"] client = [] random = ["dep:generate-random", "dep:rand"] serde = ["dep:serde", "dep:serde_arrays", "enumset/serde"] diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 9838ea1..75bc693 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -305,9 +305,8 @@ fn serialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> To impl_size!(size32, u32); impl_size!(size64, u64); - code - }, + } Err(e) => return e.write_errors(), } } @@ -315,10 +314,12 @@ fn serialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> To fn serialize_fields(fields: &Fields) -> TokStr { fields .iter() - .map(|(ident, field)| serialize_args(MtArgs::from_field(field), |args| { - let cfg = get_cfg(args); - quote! { mt_data::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; } - })) + .map(|(ident, field)| { + serialize_args(MtArgs::from_field(field), |args| { + let cfg = get_cfg(args); + quote! { mt_data::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; } + }) + }) .collect() } @@ -343,7 +344,9 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr); let ident_fn = match &v.fields { - syn::Fields::Unnamed(_) => |f| quote! { mt_data::paste! { [] }}, + syn::Fields::Unnamed(_) => |f| quote! { + mt_data::paste::paste! { [] } + }, _ => |f| quote! { #f }, }; diff --git a/src/lib.rs b/src/lib.rs index 99bd58a..c3e5cee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,10 @@ +#![feature(array_try_from_fn)] +#![feature(associated_type_bounds)] +#![feature(iterator_try_collect)] + pub use enumset; pub use flate2; -pub use paste::paste; +pub use paste; #[cfg(feature = "random")] pub use generate_random; @@ -15,6 +19,7 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use enumset::{EnumSet, EnumSetType, EnumSetTypeWithRepr}; use mt_data_derive::mt_derive; pub use mt_data_derive::{MtDeserialize, MtSerialize}; +use paste::paste as paste_macro; use std::{ collections::{HashMap, HashSet}, convert::Infallible, @@ -31,23 +36,18 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "random")] use generate_random::GenerateRandom; +#[cfg(any(feature = "client", feature = "server"))] use crate as mt_data; -#[derive(Error, Debug)] -#[error("variable length")] -pub struct VarLen; - #[derive(Error, Debug)] pub enum SerializeError { #[error("io error: {0}")] IoError(#[from] io::Error), #[error("collection too big: {0}")] TooBig(#[from] TryFromIntError), - #[error("unimplemented")] - Unimplemented, } -impl From for DeserializeError { +impl From for SerializeError { fn from(_err: Infallible) -> Self { unreachable!("infallible") } @@ -56,40 +56,88 @@ impl From for DeserializeError { #[derive(Error, Debug)] pub enum DeserializeError { #[error("io error: {0}")] - IoError(#[from] io::Error), - #[error("variable length not supported")] - NoVarlen(#[from] VarLen), + IoError(io::Error), + #[error("unexpected end of file")] + UnexpectedEof, #[error("collection too big: {0}")] TooBig(#[from] TryFromIntError), + #[error("invalid UTF-16: {0}")] + InvalidUtf16(#[from] std::char::DecodeUtf16Error), #[error("unimplemented")] Unimplemented, } -impl From for SerializeError { +impl From for DeserializeError { fn from(_err: Infallible) -> Self { unreachable!("infallible") } } +impl From for DeserializeError { + fn from(err: io::Error) -> Self { + if err.kind() == io::ErrorKind::UnexpectedEof { + DeserializeError::UnexpectedEof + } else { + DeserializeError::IoError(err) + } + } +} + +pub trait OrDefault { + fn or_default(self) -> Self; +} + +impl OrDefault for Result { + fn or_default(self) -> Self { + match self { + Err(DeserializeError::UnexpectedEof) => Ok(T::default()), + x => x, + } + } +} + pub trait MtCfg: + Sized + MtSerialize + MtDeserialize + TryFrom> +{ + type Range: Iterator + 'static; + + fn utf16() -> bool { + false + } + + fn var_len() -> bool; + + fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { + Self::try_from(len) + .map_err(Into::into)? + .mt_serialize::(writer) + } + + fn read_len(reader: &mut impl Read) -> Result; +} + +trait MtCfgLen: Sized + MtSerialize + MtDeserialize - + TryFrom - + TryInto + + TryFrom> + + TryInto> { - type TryFromError: Into; - type TryIntoError: Into; +} - #[inline] - fn utf16() -> bool { +impl MtCfg for T { + type Range = std::ops::Range; + + fn var_len() -> bool { false } - fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> { - Ok(Self::try_from(len) - .map_err(|e| e.into())? - .mt_serialize::(writer)?) + fn read_len(reader: &mut impl Read) -> Result { + let len = Self::mt_deserialize::(reader)? + .try_into() + .map_err(Into::into)?; + + Ok(0..len) } } @@ -103,26 +151,12 @@ pub trait MtDeserialize: Sized { fn mt_deserialize(reader: &mut impl Read) -> Result; } -impl MtCfg for u8 { - type TryFromError = TryFromIntError; - type TryIntoError = Infallible; -} - -impl MtCfg for u16 { - type TryFromError = TryFromIntError; - type TryIntoError = Infallible; -} - -impl MtCfg for u32 { - type TryFromError = TryFromIntError; - type TryIntoError = TryFromIntError; -} - -impl MtCfg for u64 { - type TryFromError = TryFromIntError; - type TryIntoError = TryFromIntError; -} +impl MtCfgLen for u8 {} +impl MtCfgLen for u16 {} +impl MtCfgLen for u32 {} +impl MtCfgLen for u64 {} +#[derive(Debug)] pub struct NoLen; impl MtSerialize for NoLen { @@ -145,17 +179,16 @@ impl TryFrom for NoLen { } } -impl TryInto for NoLen { - type Error = VarLen; - - fn try_into(self) -> Result { - Err(VarLen) +impl MtCfg for NoLen { + fn var_len() -> bool { + true } -} -impl MtCfg for NoLen { - type TryFromError = Infallible; - type TryIntoError = VarLen; + type Range = std::ops::RangeFrom; + + fn read_len(_reader: &mut impl Read) -> Result { + Ok(0..) + } } pub struct Utf16(pub B); @@ -173,29 +206,27 @@ impl MtDeserialize for Utf16 { } impl TryFrom for Utf16 { - type Error = B::TryFromError; + type Error = >::Error; fn try_from(x: usize) -> Result { Ok(Self(x.try_into()?)) } } -impl TryInto for Utf16 { - type Error = B::TryIntoError; - - fn try_into(self) -> Result { - self.0.try_into() - } -} - impl MtCfg for Utf16 { - type TryFromError = B::TryFromError; - type TryIntoError = B::TryIntoError; + type Range = B::Range; - #[inline] fn utf16() -> bool { true } + + fn var_len() -> bool { + B::var_len() + } + + fn read_len(reader: &mut impl Read) -> Result { + B::read_len(reader) + } } impl MtSerialize for u8 { @@ -231,7 +262,7 @@ macro_rules! impl_num { &self, writer: &mut impl Write, ) -> Result<(), SerializeError> { - paste! { + paste_macro! { writer.[]::(*self)?; } Ok(()) @@ -240,7 +271,7 @@ macro_rules! impl_num { impl MtDeserialize for $T { fn mt_deserialize(reader: &mut impl Read) -> Result { - paste! { + paste_macro! { Ok(reader.[]::()?) } } @@ -265,25 +296,60 @@ impl MtSerialize for () { } } +impl MtDeserialize for () { + fn mt_deserialize(_reader: &mut impl Read) -> Result { + Ok(()) + } +} + impl MtSerialize for bool { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { (*self as u8).mt_serialize::(writer) } } -impl MtSerialize for [T; N] { +impl MtDeserialize for bool { + fn mt_deserialize(reader: &mut impl Read) -> Result { + Ok(u8::mt_deserialize::(reader)? != 0) + } +} + +impl MtSerialize for &T { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { - self.as_slice().mt_serialize::(writer) + (*self).mt_serialize::(writer) } } -impl MtSerialize for &[T] { +fn mt_serialize_seq( + writer: &mut impl Write, + iter: impl ExactSizeIterator + IntoIterator, +) -> Result<(), SerializeError> { + C::write_len(iter.len(), writer)?; + + iter.into_iter() + .try_for_each(|item| item.mt_serialize::(writer)) +} + +fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>( + reader: &'a mut impl Read, +) -> Result> + 'a, DeserializeError> { + Ok(C::read_len(reader)? + .into_iter() + .map_while(|_| match T::mt_deserialize::(reader) { + Err(DeserializeError::UnexpectedEof) if C::var_len() => None, + x => Some(x), + })) +} + +impl MtSerialize for [T; N] { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { - C::write_len(self.len(), writer)?; - for item in self.iter() { - item.mt_serialize::(writer)?; - } - Ok(()) + mt_serialize_seq::(writer, self.iter()) + } +} + +impl MtDeserialize for [T; N] { + fn mt_deserialize(reader: &mut impl Read) -> Result { + std::array::try_from_fn(|_| T::mt_deserialize::(reader)) } } @@ -293,6 +359,14 @@ impl> MtSerialize for EnumSet> MtDeserialize for EnumSet { + fn mt_deserialize(reader: &mut impl Read) -> Result { + Ok(Self::from_repr_truncated(T::mt_deserialize::( + reader, + )?)) + } +} + impl MtSerialize for Option { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { match self { @@ -302,34 +376,71 @@ impl MtSerialize for Option { } } +impl MtDeserialize for Option { + fn mt_deserialize(reader: &mut impl Read) -> Result { + T::mt_deserialize::(reader).map(Some).or_default() + } +} + impl MtSerialize for Vec { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { - self.as_slice().mt_serialize::(writer) + mt_serialize_seq::(writer, self.iter()) + } +} + +impl MtDeserialize for Vec { + fn mt_deserialize(reader: &mut impl Read) -> Result { + mt_deserialize_seq::(reader)?.try_collect() } } -impl MtSerialize for HashSet { +impl MtSerialize for HashSet { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { - C::write_len(self.len(), writer)?; - for item in self.iter() { - item.mt_serialize::(writer)?; - } + mt_serialize_seq::(writer, self.iter()) + } +} + +impl MtDeserialize for HashSet { + fn mt_deserialize(reader: &mut impl Read) -> Result { + mt_deserialize_seq::(reader)?.try_collect() + } +} + +impl MtSerialize for (A, B) { + fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { + self.0.mt_serialize::(writer)?; + self.1.mt_serialize::(writer)?; + Ok(()) } } +impl MtDeserialize for (A, B) { + fn mt_deserialize(reader: &mut impl Read) -> Result { + let a = A::mt_deserialize::(reader)?; + let b = B::mt_deserialize::(reader)?; + + Ok((a, b)) + } +} + impl MtSerialize for HashMap where K: MtSerialize + std::cmp::Eq + std::hash::Hash, V: MtSerialize, { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { - C::write_len(self.len(), writer)?; - for (key, value) in self.iter() { - key.mt_serialize::(writer)?; - value.mt_serialize::(writer)?; - } - Ok(()) + mt_serialize_seq::(writer, self.iter()) + } +} + +impl MtDeserialize for HashMap +where + K: MtDeserialize + std::cmp::Eq + std::hash::Hash, + V: MtDeserialize, +{ + fn mt_deserialize(reader: &mut impl Read) -> Result { + mt_deserialize_seq::(reader)?.try_collect() } } @@ -337,10 +448,36 @@ impl MtSerialize for String { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { if C::utf16() { self.encode_utf16() - .collect::>() + .collect::>() // FIXME: is this allocation necessary? .mt_serialize::(writer) } else { - self.as_bytes().mt_serialize::(writer) + mt_serialize_seq::(writer, self.as_bytes().iter()) + } + } +} + +impl MtDeserialize for String { + fn mt_deserialize(reader: &mut impl Read) -> Result { + if C::utf16() { + let mut err = None; + + let res = + char::decode_utf16(mt_deserialize_seq::(reader)?.map_while(|x| match x { + Ok(v) => Some(v), + Err(e) => { + err = Some(e); + None + } + })) + .try_collect(); + + match err { + None => Ok(res?), + Some(e) => Err(e), + } + } else { + // TODO: UTF-8 decode + Ok("".into()) } } } @@ -351,6 +488,12 @@ impl MtSerialize for Box { } } +impl MtDeserialize for Box { + fn mt_deserialize(reader: &mut impl Read) -> Result { + Ok(Self::new(T::mt_deserialize::(reader)?)) + } +} + mod to_clt; mod to_srv; diff --git a/src/to_clt/hud.rs b/src/to_clt/hud.rs index bdd7de7..3a29d7b 100644 --- a/src/to_clt/hud.rs +++ b/src/to_clt/hud.rs @@ -115,18 +115,32 @@ pub struct MinimapModePkt { modes: Vec, } +#[cfg(feature = "server")] impl MtSerialize for MinimapModePkt { fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { - C::write_len(self.modes.len(), writer)?; + DefCfg::write_len(self.modes.len(), writer)?; self.current.mt_serialize::(writer)?; - for item in self.modes.iter() { - item.mt_serialize::(writer)?; - } + self.modes.mt_serialize::(writer)?; + Ok(()) } } + +#[cfg(feature = "client")] +impl MtDeserialize for MinimapModePkt { + fn mt_deserialize(reader: &mut impl Read) -> Result { + let range = DefCfg::read_len(reader)?; + let current = MtDeserialize::mt_deserialize::(reader)?; + let modes = range + .map(|_| MtDeserialize::mt_deserialize::(reader)) + .try_collect()?; + + Ok(Self { current, modes }) + } +} + /* -TODO: rustify +TODO: rustify this var DefaultMinimap = []MinimapMode{ {Type: NoMinimap}, -- 2.44.0