1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
13 collections::{HashMap, HashSet},
15 io::{self, Read, Write},
24 #[derive(Error, Debug)]
25 pub enum SerializeError {
26 #[error("io error: {0}")]
27 IoError(#[from] io::Error),
28 #[error("collection too big: {0}")]
29 TooBig(#[from] TryFromIntError),
32 impl From<Infallible> for SerializeError {
33 fn from(_err: Infallible) -> Self {
34 unreachable!("infallible")
38 #[derive(Error, Debug)]
39 pub enum DeserializeError {
40 #[error("io error: {0}")]
42 #[error("unexpected end of file")]
44 #[error("collection too big: {0}")]
45 TooBig(#[from] TryFromIntError),
46 #[error("invalid UTF-16: {0}")]
47 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
48 #[error("invalid {0} enum variant {1}")]
49 InvalidEnumVariant(&'static str, u64),
50 #[error("invalid constant - wanted: {0} - got: {1}")]
51 InvalidConst(u64, u64),
54 impl From<Infallible> for DeserializeError {
55 fn from(_err: Infallible) -> Self {
56 unreachable!("infallible")
60 impl From<io::Error> for DeserializeError {
61 fn from(err: io::Error) -> Self {
62 if err.kind() == io::ErrorKind::UnexpectedEof {
63 DeserializeError::UnexpectedEof
65 DeserializeError::IoError(err)
70 pub trait OrDefault<T> {
71 fn or_default(self) -> Self;
74 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
75 impl<'a, R: Read> Read for WrapRead<'a, R> {
76 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
80 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
81 self.0.read_vectored(bufs)
86 fn is_read_vectored(&self) -> bool {
87 self.0.is_read_vectored()
92 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
93 self.0.read_to_end(buf)
96 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
97 self.0.read_to_string(buf)
100 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
101 self.0.read_exact(buf)
106 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
110 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
111 self.0.read_buf_exact(cursor)
117 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
118 fn or_default(self) -> Self {
120 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
127 fn option(&self) -> Option<usize>;
129 type Range: Iterator<Item = usize> + 'static;
130 fn range(&self) -> Self::Range;
132 type Take<R: Read>: Read;
133 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
143 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
144 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
147 pub trait MtSerialize {
148 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
151 pub trait MtDeserialize: Sized {
152 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
155 impl MtLen for usize {
156 fn option(&self) -> Option<usize> {
160 type Range = std::ops::Range<usize>;
161 fn range(&self) -> Self::Range {
165 type Take<R: Read> = io::Take<R>;
166 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
167 reader.take(*self as u64)
175 + TryFrom<usize, Error: Into<SerializeError>>
176 + TryInto<usize, Error: Into<DeserializeError>>
180 impl<T: MtCfgLen> MtCfg for T {
183 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
185 .map_err(Into::into)?
186 .mt_serialize::<DefCfg>(writer)
189 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
190 Ok(Self::mt_deserialize::<DefCfg>(reader)?
192 .map_err(Into::into)?)
// Unsigned integer widths that can act as a length-prefix configuration.
// Through the blanket `impl<T: MtCfgLen> MtCfg for T`, a collection length
// is converted from `usize` into `T` and serialized as that integer. A
// length that does not fit in `T` is reported as a serialization error
// rather than being truncated.
196 impl MtCfgLen for u8 {}
197 impl MtCfgLen for u16 {}
198 impl MtCfgLen for u32 {}
199 impl MtCfgLen for u64 {}
/// Configuration used wherever the caller does not choose one explicitly.
/// For example, sequence elements and `bool` are serialized with `DefCfg`.
/// Nested collection lengths are therefore prefixed with a `u16`.
201 pub type DefCfg = u16;
206 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
210 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
216 fn option(&self) -> Option<usize> {
220 type Range = std::ops::RangeFrom<usize>;
221 fn range(&self) -> Self::Range {
225 type Take<R: Read> = R;
226 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
231 pub struct Utf16<B: MtCfg>(pub B);
233 impl<B: MtCfg> MtCfg for Utf16<B> {
240 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
241 B::write_len(len, writer)
244 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
249 impl MtSerialize for u8 {
250 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
251 writer.write_u8(*self)?;
256 impl MtDeserialize for u8 {
257 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
258 Ok(reader.read_u8()?)
262 impl MtSerialize for i8 {
263 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
264 writer.write_i8(*self)?;
269 impl MtDeserialize for i8 {
270 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
271 Ok(reader.read_i8()?)
275 macro_rules! impl_num {
277 impl MtSerialize for $T {
278 fn mt_serialize<C: MtCfg>(
280 writer: &mut impl Write,
281 ) -> Result<(), SerializeError> {
283 writer.[<write_ $T>]::<BigEndian>(*self)?;
289 impl MtDeserialize for $T {
290 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
292 Ok(reader.[<read_ $T>]::<BigEndian>()?)
310 impl MtSerialize for () {
311 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
316 impl MtDeserialize for () {
317 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
322 impl MtSerialize for bool {
323 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
324 (*self as u8).mt_serialize::<DefCfg>(writer)
328 impl MtDeserialize for bool {
329 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
330 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
334 impl<T: MtSerialize> MtSerialize for &T {
335 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
336 (*self).mt_serialize::<C>(writer)
340 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
341 writer: &mut impl Write,
342 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
343 ) -> Result<(), SerializeError> {
344 C::write_len(iter.len(), writer)?;
347 .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
350 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
351 reader: &'a mut impl Read,
352 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
353 let len = C::read_len(reader)?;
354 mt_deserialize_sized_seq(&len, reader)
357 pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
359 reader: &'a mut impl Read,
360 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
361 let variable = len.option().is_none();
365 .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
366 Err(DeserializeError::UnexpectedEof) if variable => None,
371 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
372 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
373 mt_serialize_seq::<(), _>(writer, self.iter())
377 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
378 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
379 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
383 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
384 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
385 self.as_repr().mt_serialize::<DefCfg>(writer)
389 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
390 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
391 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
397 impl<T: MtSerialize> MtSerialize for Option<T> {
398 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
400 Some(item) => item.mt_serialize::<C>(writer),
406 impl<T: MtDeserialize> MtDeserialize for Option<T> {
407 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
408 T::mt_deserialize::<C>(reader).map(Some).or_default()
412 impl<T: MtSerialize> MtSerialize for Vec<T> {
413 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
414 mt_serialize_seq::<C, _>(writer, self.iter())
418 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
419 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
420 mt_deserialize_seq::<C, _>(reader)?.try_collect()
424 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
425 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
426 mt_serialize_seq::<C, _>(writer, self.iter())
430 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
431 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
432 mt_deserialize_seq::<C, _>(reader)?.try_collect()
436 // TODO: support more tuples
437 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
438 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
439 self.0.mt_serialize::<DefCfg>(writer)?;
440 self.1.mt_serialize::<DefCfg>(writer)?;
446 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
447 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
448 let a = A::mt_deserialize::<DefCfg>(reader)?;
449 let b = B::mt_deserialize::<DefCfg>(reader)?;
455 impl<K, V> MtSerialize for HashMap<K, V>
457 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
460 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
461 mt_serialize_seq::<C, _>(writer, self.iter())
465 impl<K, V> MtDeserialize for HashMap<K, V>
467 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
470 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
471 mt_deserialize_seq::<C, _>(reader)?.try_collect()
475 impl MtSerialize for String {
476 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
479 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
480 .mt_serialize::<C>(writer)
482 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
487 impl MtDeserialize for String {
488 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
493 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
507 let len = C::read_len(reader)?;
509 // use capacity if available
510 let mut st = match len.option() {
511 Some(x) => String::with_capacity(x),
512 None => String::new(),
515 len.take(WrapRead(reader)).read_to_string(&mut st)?;
522 impl<T: MtSerialize> MtSerialize for Box<T> {
523 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
524 self.deref().mt_serialize::<C>(writer)
528 impl<T: MtDeserialize> MtDeserialize for Box<T> {
529 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
530 Ok(Self::new(T::mt_deserialize::<C>(reader)?))