1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
10 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
11 use enumset::{EnumSet, EnumSetTypeWithRepr};
12 use paste::paste as paste_macro;
14 collections::{HashMap, HashSet},
17 io::{self, Read, Write},
19 ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
28 #[derive(Error, Debug)]
29 pub enum SerializeError {
30 #[error("io error: {0}")]
31 IoError(#[from] io::Error),
32 #[error("collection too big: {0}")]
33 TooBig(#[from] TryFromIntError),
38 impl From<Infallible> for SerializeError {
39 fn from(_err: Infallible) -> Self {
40 unreachable!("infallible")
44 #[derive(Error, Debug)]
45 pub enum DeserializeError {
46 #[error("io error: {0}")]
48 #[error("unexpected end of file")]
50 #[error("collection too big: {0}")]
51 TooBig(#[from] TryFromIntError),
52 #[error("invalid UTF-16: {0}")]
53 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
54 #[error("invalid {0} enum variant {1:?}")]
55 InvalidEnum(&'static str, Box<dyn Debug>),
56 #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
57 InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
62 impl From<Infallible> for DeserializeError {
63 fn from(_err: Infallible) -> Self {
64 unreachable!("infallible")
68 impl From<io::Error> for DeserializeError {
69 fn from(err: io::Error) -> Self {
70 if err.kind() == io::ErrorKind::UnexpectedEof {
71 DeserializeError::UnexpectedEof
73 DeserializeError::IoError(err)
78 pub trait OrDefault<T> {
79 fn or_default(self) -> Self;
82 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
83 impl<'a, R: Read> Read for WrapRead<'a, R> {
84 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
88 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
89 self.0.read_vectored(bufs)
94 fn is_read_vectored(&self) -> bool {
95 self.0.is_read_vectored()
100 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
101 self.0.read_to_end(buf)
104 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
105 self.0.read_to_string(buf)
108 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
109 self.0.read_exact(buf)
114 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
118 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
119 self.0.read_buf_exact(cursor)
125 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
126 fn or_default(self) -> Self {
128 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
135 fn option(&self) -> Option<usize>;
137 type Range: Iterator<Item = usize> + 'static;
138 fn range(&self) -> Self::Range;
140 type Take<R: Read>: Read;
141 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
152 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
153 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
156 pub trait MtSerialize {
157 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
160 pub trait MtDeserialize: Sized {
161 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
164 impl MtLen for usize {
165 fn option(&self) -> Option<usize> {
169 type Range = std::ops::Range<usize>;
170 fn range(&self) -> Self::Range {
174 type Take<R: Read> = io::Take<R>;
175 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
176 reader.take(*self as u64)
184 + TryFrom<usize, Error: Into<SerializeError>>
185 + TryInto<usize, Error: Into<DeserializeError>>
189 impl<T: MtCfgLen> MtCfg for T {
193 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
195 .map_err(Into::into)?
196 .mt_serialize::<DefCfg>(writer)
199 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
200 Self::mt_deserialize::<DefCfg>(reader)?
// Unsigned integer types that can serve directly as a length-prefix config.
// Through the blanket `impl<T: MtCfgLen> MtCfg for T`, a sequence length is
// converted to this integer type via `TryFrom<usize>` and then (de)serialized as
// that integer. NOTE(review): an oversized length presumably surfaces as
// `SerializeError::TooBig` via `TryFromIntError` — confirm against write_len.
206 impl MtCfgLen for u8 {}
207 impl MtCfgLen for u16 {}
208 impl MtCfgLen for u32 {}
209 impl MtCfgLen for u64 {}
// Default config used when none is specified: lengths are prefixed as a `u16`.
211 pub type DefCfg = u16;
217 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
221 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
227 fn option(&self) -> Option<usize> {
231 type Range = std::ops::RangeFrom<usize>;
232 fn range(&self) -> Self::Range {
236 type Take<R: Read> = R;
237 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
242 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
244 impl<B: MtCfg> MtCfg for Utf16<B> {
246 type Inner = B::Inner;
252 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
253 B::write_len(len, writer)
256 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
261 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
265 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
266 A::write_len(len, writer)
269 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
274 impl MtSerialize for u8 {
275 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
276 writer.write_u8(*self)?;
281 impl MtDeserialize for u8 {
282 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
283 Ok(reader.read_u8()?)
287 impl MtSerialize for i8 {
288 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
289 writer.write_i8(*self)?;
294 impl MtDeserialize for i8 {
295 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
296 Ok(reader.read_i8()?)
300 macro_rules! impl_num {
302 impl MtSerialize for $T {
303 fn mt_serialize<C: MtCfg>(
305 writer: &mut impl Write,
306 ) -> Result<(), SerializeError> {
308 writer.[<write_ $T>]::<BigEndian>(*self)?;
314 impl MtDeserialize for $T {
315 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
317 Ok(reader.[<read_ $T>]::<BigEndian>()?)
335 impl MtSerialize for () {
336 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
341 impl MtDeserialize for () {
342 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
347 impl MtSerialize for bool {
348 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
349 (*self as u8).mt_serialize::<DefCfg>(writer)
353 impl MtDeserialize for bool {
354 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
355 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
359 impl<T: MtSerialize> MtSerialize for &T {
360 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
361 (*self).mt_serialize::<C>(writer)
365 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
366 writer: &mut impl Write,
367 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
368 ) -> Result<(), SerializeError> {
369 C::write_len(iter.len(), writer)?;
372 .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
375 pub fn mt_deserialize_seq<C: MtCfg, T: MtDeserialize>(
376 reader: &mut impl Read,
377 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + '_, DeserializeError> {
378 let len = C::read_len(reader)?;
379 mt_deserialize_sized_seq::<C, _>(&len, reader)
382 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
384 reader: &'a mut impl Read,
385 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
386 let variable = len.option().is_none();
390 .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
391 Err(DeserializeError::UnexpectedEof) if variable => None,
396 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
397 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
398 mt_serialize_seq::<(), _>(writer, self.iter())
402 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
403 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
404 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
408 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
409 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
410 self.as_repr().mt_serialize::<DefCfg>(writer)
414 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
415 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
416 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
422 impl<T: MtSerialize> MtSerialize for Option<T> {
423 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
425 Some(item) => item.mt_serialize::<C>(writer),
431 impl<T: MtDeserialize> MtDeserialize for Option<T> {
432 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
433 T::mt_deserialize::<C>(reader).map(Some).or_default()
437 impl<T: MtSerialize> MtSerialize for Vec<T> {
438 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
439 mt_serialize_seq::<C, _>(writer, self.iter())
443 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
444 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
445 mt_deserialize_seq::<C, _>(reader)?.try_collect()
449 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
450 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
451 mt_serialize_seq::<C, _>(writer, self.iter())
455 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
456 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
457 mt_deserialize_seq::<C, _>(reader)?.try_collect()
461 // TODO: support more tuples
462 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
463 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
464 self.0.mt_serialize::<C>(writer)?;
465 self.1.mt_serialize::<C::Inner>(writer)?;
471 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
472 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
473 let a = A::mt_deserialize::<C>(reader)?;
474 let b = B::mt_deserialize::<C::Inner>(reader)?;
480 impl<K, V> MtSerialize for HashMap<K, V>
482 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
485 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
486 mt_serialize_seq::<C, _>(writer, self.iter())
490 impl<K, V> MtDeserialize for HashMap<K, V>
492 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
495 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
496 mt_deserialize_seq::<C, _>(reader)?.try_collect()
500 impl MtSerialize for &str {
501 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
504 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
505 .mt_serialize::<C>(writer)
507 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
512 impl MtSerialize for String {
513 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
514 self.as_str().mt_serialize::<C>(writer)
518 impl MtDeserialize for String {
519 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
524 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
538 let len = C::read_len(reader)?;
540 // use capacity if available
541 let mut st = match len.option() {
542 Some(x) => String::with_capacity(x),
543 None => String::new(),
546 len.take(WrapRead(reader)).read_to_string(&mut st)?;
553 impl<T: MtSerialize> MtSerialize for Box<T> {
554 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
555 self.deref().mt_serialize::<C>(writer)
559 impl<T: MtDeserialize> MtDeserialize for Box<T> {
560 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
561 Ok(Self::new(T::mt_deserialize::<C>(reader)?))
565 #[derive(MtSerialize, MtDeserialize)]
566 #[mt(typename = "Range")]
568 struct RemoteRange<T> {
573 #[derive(MtSerialize, MtDeserialize)]
574 #[mt(typename = "RangeFrom")]
576 struct RemoteRangeFrom<T> {
580 #[derive(MtSerialize, MtDeserialize)]
581 #[mt(typename = "RangeFull")]
583 struct RemoteRangeFull;
585 // RangeInclusive fields are private
586 impl<T: MtSerialize> MtSerialize for RangeInclusive<T> {
587 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
588 self.start().mt_serialize::<DefCfg>(writer)?;
589 self.end().mt_serialize::<DefCfg>(writer)?;
595 impl<T: MtDeserialize> MtDeserialize for RangeInclusive<T> {
596 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
597 let start = T::mt_deserialize::<DefCfg>(reader)?;
598 let end = T::mt_deserialize::<DefCfg>(reader)?;
604 #[derive(MtSerialize, MtDeserialize)]
605 #[mt(typename = "RangeTo")]
607 struct RemoteRangeTo<T> {
611 #[derive(MtSerialize, MtDeserialize)]
612 #[mt(typename = "RangeToInclusive")]
614 struct RemoteRangeToInclusive<T> {