1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
13 collections::{HashMap, HashSet},
16 io::{self, Read, Write},
25 #[derive(Error, Debug)]
26 pub enum SerializeError {
27 #[error("io error: {0}")]
28 IoError(#[from] io::Error),
29 #[error("collection too big: {0}")]
30 TooBig(#[from] TryFromIntError),
33 impl From<Infallible> for SerializeError {
34 fn from(_err: Infallible) -> Self {
35 unreachable!("infallible")
39 #[derive(Error, Debug)]
40 pub enum DeserializeError {
41 #[error("io error: {0}")]
43 #[error("unexpected end of file")]
45 #[error("collection too big: {0}")]
46 TooBig(#[from] TryFromIntError),
47 #[error("invalid UTF-16: {0}")]
48 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
49 #[error("invalid {0} enum variant {1}")]
50 InvalidEnumVariant(&'static str, u64),
51 #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
52 InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
55 impl From<Infallible> for DeserializeError {
56 fn from(_err: Infallible) -> Self {
57 unreachable!("infallible")
61 impl From<io::Error> for DeserializeError {
62 fn from(err: io::Error) -> Self {
63 if err.kind() == io::ErrorKind::UnexpectedEof {
64 DeserializeError::UnexpectedEof
66 DeserializeError::IoError(err)
71 pub trait OrDefault<T> {
72 fn or_default(self) -> Self;
75 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
76 impl<'a, R: Read> Read for WrapRead<'a, R> {
77 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
81 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
82 self.0.read_vectored(bufs)
87 fn is_read_vectored(&self) -> bool {
88 self.0.is_read_vectored()
93 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
94 self.0.read_to_end(buf)
97 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
98 self.0.read_to_string(buf)
101 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
102 self.0.read_exact(buf)
107 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
111 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
112 self.0.read_buf_exact(cursor)
118 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
119 fn or_default(self) -> Self {
121 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
128 fn option(&self) -> Option<usize>;
130 type Range: Iterator<Item = usize> + 'static;
131 fn range(&self) -> Self::Range;
133 type Take<R: Read>: Read;
134 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
144 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
145 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
148 pub trait MtSerialize {
149 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
152 pub trait MtDeserialize: Sized {
153 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
156 impl MtLen for usize {
157 fn option(&self) -> Option<usize> {
161 type Range = std::ops::Range<usize>;
162 fn range(&self) -> Self::Range {
166 type Take<R: Read> = io::Take<R>;
167 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
168 reader.take(*self as u64)
176 + TryFrom<usize, Error: Into<SerializeError>>
177 + TryInto<usize, Error: Into<DeserializeError>>
181 impl<T: MtCfgLen> MtCfg for T {
184 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
186 .map_err(Into::into)?
187 .mt_serialize::<DefCfg>(writer)
190 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
191 Ok(Self::mt_deserialize::<DefCfg>(reader)?
193 .map_err(Into::into)?)
197 impl MtCfgLen for u8 {}
198 impl MtCfgLen for u16 {}
199 impl MtCfgLen for u32 {}
200 impl MtCfgLen for u64 {}
202 pub type DefCfg = u16;
207 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
211 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
217 fn option(&self) -> Option<usize> {
221 type Range = std::ops::RangeFrom<usize>;
222 fn range(&self) -> Self::Range {
226 type Take<R: Read> = R;
227 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
232 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
234 impl<B: MtCfg> MtCfg for Utf16<B> {
241 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
242 B::write_len(len, writer)
245 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
250 impl MtSerialize for u8 {
251 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
252 writer.write_u8(*self)?;
257 impl MtDeserialize for u8 {
258 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
259 Ok(reader.read_u8()?)
263 impl MtSerialize for i8 {
264 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
265 writer.write_i8(*self)?;
270 impl MtDeserialize for i8 {
271 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
272 Ok(reader.read_i8()?)
276 macro_rules! impl_num {
278 impl MtSerialize for $T {
279 fn mt_serialize<C: MtCfg>(
281 writer: &mut impl Write,
282 ) -> Result<(), SerializeError> {
284 writer.[<write_ $T>]::<BigEndian>(*self)?;
290 impl MtDeserialize for $T {
291 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
293 Ok(reader.[<read_ $T>]::<BigEndian>()?)
311 impl MtSerialize for () {
312 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
317 impl MtDeserialize for () {
318 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
323 impl MtSerialize for bool {
324 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
325 (*self as u8).mt_serialize::<DefCfg>(writer)
329 impl MtDeserialize for bool {
330 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
331 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
335 impl<T: MtSerialize> MtSerialize for &T {
336 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
337 (*self).mt_serialize::<C>(writer)
341 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
342 writer: &mut impl Write,
343 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
344 ) -> Result<(), SerializeError> {
345 C::write_len(iter.len(), writer)?;
348 .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
351 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
352 reader: &'a mut impl Read,
353 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
354 let len = C::read_len(reader)?;
355 mt_deserialize_sized_seq(&len, reader)
358 pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
360 reader: &'a mut impl Read,
361 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
362 let variable = len.option().is_none();
366 .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
367 Err(DeserializeError::UnexpectedEof) if variable => None,
372 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
373 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
374 mt_serialize_seq::<(), _>(writer, self.iter())
378 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
379 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
380 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
384 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
385 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
386 self.as_repr().mt_serialize::<DefCfg>(writer)
390 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
391 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
392 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
398 impl<T: MtSerialize> MtSerialize for Option<T> {
399 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
401 Some(item) => item.mt_serialize::<C>(writer),
407 impl<T: MtDeserialize> MtDeserialize for Option<T> {
408 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
409 T::mt_deserialize::<C>(reader).map(Some).or_default()
413 impl<T: MtSerialize> MtSerialize for Vec<T> {
414 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
415 mt_serialize_seq::<C, _>(writer, self.iter())
419 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
420 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
421 mt_deserialize_seq::<C, _>(reader)?.try_collect()
425 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
426 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
427 mt_serialize_seq::<C, _>(writer, self.iter())
431 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
432 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
433 mt_deserialize_seq::<C, _>(reader)?.try_collect()
437 // TODO: support more tuples
438 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
439 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
440 self.0.mt_serialize::<DefCfg>(writer)?;
441 self.1.mt_serialize::<DefCfg>(writer)?;
447 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
448 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
449 let a = A::mt_deserialize::<DefCfg>(reader)?;
450 let b = B::mt_deserialize::<DefCfg>(reader)?;
456 impl<K, V> MtSerialize for HashMap<K, V>
458 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
461 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
462 mt_serialize_seq::<C, _>(writer, self.iter())
466 impl<K, V> MtDeserialize for HashMap<K, V>
468 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
471 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
472 mt_deserialize_seq::<C, _>(reader)?.try_collect()
476 impl MtSerialize for String {
477 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
480 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
481 .mt_serialize::<C>(writer)
483 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
488 impl MtDeserialize for String {
489 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
494 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
508 let len = C::read_len(reader)?;
510 // use capacity if available
511 let mut st = match len.option() {
512 Some(x) => String::with_capacity(x),
513 None => String::new(),
516 len.take(WrapRead(reader)).read_to_string(&mut st)?;
523 impl<T: MtSerialize> MtSerialize for Box<T> {
524 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
525 self.deref().mt_serialize::<C>(writer)
529 impl<T: MtDeserialize> MtDeserialize for Box<T> {
530 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
531 Ok(Self::new(T::mt_deserialize::<C>(reader)?))