1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
13 collections::{HashMap, HashSet},
16 io::{self, Read, Write},
25 #[derive(Error, Debug)]
26 pub enum SerializeError {
27 #[error("io error: {0}")]
28 IoError(#[from] io::Error),
29 #[error("collection too big: {0}")]
30 TooBig(#[from] TryFromIntError),
33 impl From<Infallible> for SerializeError {
34 fn from(_err: Infallible) -> Self {
35 unreachable!("infallible")
39 #[derive(Error, Debug)]
40 pub enum DeserializeError {
41 #[error("io error: {0}")]
43 #[error("unexpected end of file")]
45 #[error("collection too big: {0}")]
46 TooBig(#[from] TryFromIntError),
47 #[error("invalid UTF-16: {0}")]
48 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
49 #[error("invalid {0} enum variant {1}")]
50 InvalidEnumVariant(&'static str, u64),
51 #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
52 InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
55 impl From<Infallible> for DeserializeError {
56 fn from(_err: Infallible) -> Self {
57 unreachable!("infallible")
61 impl From<io::Error> for DeserializeError {
62 fn from(err: io::Error) -> Self {
63 if err.kind() == io::ErrorKind::UnexpectedEof {
64 DeserializeError::UnexpectedEof
66 DeserializeError::IoError(err)
/// Replaces selected failures of a result-like value with a default.
///
/// The implementation for `Result<T, DeserializeError>` in this crate turns an end-of-input
/// error into `Ok(T::default())`.
pub trait OrDefault<T> {
    /// Returns `self`, possibly replaced by a default value (see the implementing type).
    fn or_default(self) -> Self;
}
75 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
76 impl<'a, R: Read> Read for WrapRead<'a, R> {
77 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
81 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
82 self.0.read_vectored(bufs)
87 fn is_read_vectored(&self) -> bool {
88 self.0.is_read_vectored()
93 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
94 self.0.read_to_end(buf)
97 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
98 self.0.read_to_string(buf)
101 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
102 self.0.read_exact(buf)
107 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
111 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
112 self.0.read_buf_exact(cursor)
118 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
119 fn or_default(self) -> Self {
121 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
128 fn option(&self) -> Option<usize>;
130 type Range: Iterator<Item = usize> + 'static;
131 fn range(&self) -> Self::Range;
133 type Take<R: Read>: Read;
134 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
145 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
146 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
149 pub trait MtSerialize {
150 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
153 pub trait MtDeserialize: Sized {
154 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
157 impl MtLen for usize {
158 fn option(&self) -> Option<usize> {
162 type Range = std::ops::Range<usize>;
163 fn range(&self) -> Self::Range {
167 type Take<R: Read> = io::Take<R>;
168 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
169 reader.take(*self as u64)
177 + TryFrom<usize, Error: Into<SerializeError>>
178 + TryInto<usize, Error: Into<DeserializeError>>
182 impl<T: MtCfgLen> MtCfg for T {
186 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
188 .map_err(Into::into)?
189 .mt_serialize::<DefCfg>(writer)
192 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
193 Ok(Self::mt_deserialize::<DefCfg>(reader)?
195 .map_err(Into::into)?)
199 impl MtCfgLen for u8 {}
200 impl MtCfgLen for u16 {}
201 impl MtCfgLen for u32 {}
202 impl MtCfgLen for u64 {}
204 pub type DefCfg = u16;
210 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
214 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
220 fn option(&self) -> Option<usize> {
224 type Range = std::ops::RangeFrom<usize>;
225 fn range(&self) -> Self::Range {
229 type Take<R: Read> = R;
230 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
235 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
237 impl<B: MtCfg> MtCfg for Utf16<B> {
239 type Inner = B::Inner;
245 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
246 B::write_len(len, writer)
249 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
254 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
258 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
259 A::write_len(len, writer)
262 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
267 impl MtSerialize for u8 {
268 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
269 writer.write_u8(*self)?;
274 impl MtDeserialize for u8 {
275 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
276 Ok(reader.read_u8()?)
280 impl MtSerialize for i8 {
281 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
282 writer.write_i8(*self)?;
287 impl MtDeserialize for i8 {
288 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
289 Ok(reader.read_i8()?)
293 macro_rules! impl_num {
295 impl MtSerialize for $T {
296 fn mt_serialize<C: MtCfg>(
298 writer: &mut impl Write,
299 ) -> Result<(), SerializeError> {
301 writer.[<write_ $T>]::<BigEndian>(*self)?;
307 impl MtDeserialize for $T {
308 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
310 Ok(reader.[<read_ $T>]::<BigEndian>()?)
328 impl MtSerialize for () {
329 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
334 impl MtDeserialize for () {
335 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
340 impl MtSerialize for bool {
341 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
342 (*self as u8).mt_serialize::<DefCfg>(writer)
346 impl MtDeserialize for bool {
347 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
348 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
352 impl<T: MtSerialize> MtSerialize for &T {
353 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
354 (*self).mt_serialize::<C>(writer)
358 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
359 writer: &mut impl Write,
360 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
361 ) -> Result<(), SerializeError> {
362 C::write_len(iter.len(), writer)?;
365 .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
368 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
369 reader: &'a mut impl Read,
370 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
371 let len = C::read_len(reader)?;
372 mt_deserialize_sized_seq::<C, _>(&len, reader)
375 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
377 reader: &'a mut impl Read,
378 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
379 let variable = len.option().is_none();
383 .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
384 Err(DeserializeError::UnexpectedEof) if variable => None,
389 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
390 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
391 mt_serialize_seq::<(), _>(writer, self.iter())
395 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
396 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
397 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
401 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
402 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
403 self.as_repr().mt_serialize::<DefCfg>(writer)
407 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
408 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
409 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
415 impl<T: MtSerialize> MtSerialize for Option<T> {
416 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
418 Some(item) => item.mt_serialize::<C>(writer),
424 impl<T: MtDeserialize> MtDeserialize for Option<T> {
425 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
426 T::mt_deserialize::<C>(reader).map(Some).or_default()
430 impl<T: MtSerialize> MtSerialize for Vec<T> {
431 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
432 mt_serialize_seq::<C, _>(writer, self.iter())
436 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
437 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
438 mt_deserialize_seq::<C, _>(reader)?.try_collect()
442 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
443 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
444 mt_serialize_seq::<C, _>(writer, self.iter())
448 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
449 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
450 mt_deserialize_seq::<C, _>(reader)?.try_collect()
454 // TODO: support more tuples
455 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
456 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
457 self.0.mt_serialize::<C>(writer)?;
458 self.1.mt_serialize::<C::Inner>(writer)?;
464 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
465 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
466 let a = A::mt_deserialize::<C>(reader)?;
467 let b = B::mt_deserialize::<C::Inner>(reader)?;
473 impl<K, V> MtSerialize for HashMap<K, V>
475 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
478 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
479 mt_serialize_seq::<C, _>(writer, self.iter())
483 impl<K, V> MtDeserialize for HashMap<K, V>
485 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
488 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
489 mt_deserialize_seq::<C, _>(reader)?.try_collect()
493 impl MtSerialize for String {
494 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
497 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
498 .mt_serialize::<C>(writer)
500 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
505 impl MtDeserialize for String {
506 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
511 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
525 let len = C::read_len(reader)?;
527 // use capacity if available
528 let mut st = match len.option() {
529 Some(x) => String::with_capacity(x),
530 None => String::new(),
533 len.take(WrapRead(reader)).read_to_string(&mut st)?;
540 impl<T: MtSerialize> MtSerialize for Box<T> {
541 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
542 self.deref().mt_serialize::<C>(writer)
546 impl<T: MtDeserialize> MtDeserialize for Box<T> {
547 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
548 Ok(Self::new(T::mt_deserialize::<C>(reader)?))