1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
13 collections::{HashMap, HashSet},
16 io::{self, Read, Write},
25 #[derive(Error, Debug)]
26 pub enum SerializeError {
27 #[error("io error: {0}")]
28 IoError(#[from] io::Error),
29 #[error("collection too big: {0}")]
30 TooBig(#[from] TryFromIntError),
35 impl From<Infallible> for SerializeError {
36 fn from(_err: Infallible) -> Self {
37 unreachable!("infallible")
41 #[derive(Error, Debug)]
42 pub enum DeserializeError {
43 #[error("io error: {0}")]
45 #[error("unexpected end of file")]
47 #[error("collection too big: {0}")]
48 TooBig(#[from] TryFromIntError),
49 #[error("invalid UTF-16: {0}")]
50 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
51 #[error("invalid {0} enum variant {1:?}")]
52 InvalidEnum(&'static str, Box<dyn Debug>),
53 #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
54 InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
59 impl From<Infallible> for DeserializeError {
60 fn from(_err: Infallible) -> Self {
61 unreachable!("infallible")
65 impl From<io::Error> for DeserializeError {
66 fn from(err: io::Error) -> Self {
67 if err.kind() == io::ErrorKind::UnexpectedEof {
68 DeserializeError::UnexpectedEof
70 DeserializeError::IoError(err)
75 pub trait OrDefault<T> {
76 fn or_default(self) -> Self;
/// Newtype holding a mutable borrow of a reader.
///
/// Its `Read` impl forwards calls such as `read_exact`, `read_to_end` and
/// `read_to_string` to the wrapped reader (`self.0`). It lets a borrowed
/// reader be passed where a `Read` value is taken by value, e.g.
/// `len.take(WrapRead(reader))` in the `String` deserializer.
79 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
80 impl<'a, R: Read> Read for WrapRead<'a, R> {
81 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
85 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
86 self.0.read_vectored(bufs)
91 fn is_read_vectored(&self) -> bool {
92 self.0.is_read_vectored()
97 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
98 self.0.read_to_end(buf)
101 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
102 self.0.read_to_string(buf)
105 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
106 self.0.read_exact(buf)
111 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
115 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
116 self.0.read_buf_exact(cursor)
122 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
123 fn or_default(self) -> Self {
125 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
132 fn option(&self) -> Option<usize>;
134 type Range: Iterator<Item = usize> + 'static;
135 fn range(&self) -> Self::Range;
137 type Take<R: Read>: Read;
138 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
149 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
150 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
153 pub trait MtSerialize {
154 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
157 pub trait MtDeserialize: Sized {
158 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
161 impl MtLen for usize {
162 fn option(&self) -> Option<usize> {
166 type Range = std::ops::Range<usize>;
167 fn range(&self) -> Self::Range {
171 type Take<R: Read> = io::Take<R>;
172 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
173 reader.take(*self as u64)
181 + TryFrom<usize, Error: Into<SerializeError>>
182 + TryInto<usize, Error: Into<DeserializeError>>
186 impl<T: MtCfgLen> MtCfg for T {
190 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
192 .map_err(Into::into)?
193 .mt_serialize::<DefCfg>(writer)
196 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
197 Self::mt_deserialize::<DefCfg>(reader)?
// Unsigned integer types usable as collection length prefixes. Each gets
// `MtCfg` from the blanket `impl<T: MtCfgLen> MtCfg for T`. With e.g.
// `C = u32`, a sequence's length is converted to `u32` and serialized
// before its items.
203 impl MtCfgLen for u8 {}
204 impl MtCfgLen for u16 {}
205 impl MtCfgLen for u32 {}
206 impl MtCfgLen for u64 {}
/// Default serialization config: sequence lengths are prefixed with a `u16`.
///
/// It is also the config used for values that ignore the caller's `C`,
/// e.g. `bool`, fixed-size array elements and `EnumSet` reprs.
/// NOTE(review): lengths above `u16::MAX` presumably fail the `usize -> u16`
/// conversion and surface as `SerializeError::TooBig`. The conversion line
/// is not visible in this excerpt; confirm.
208 pub type DefCfg = u16;
214 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
218 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
224 fn option(&self) -> Option<usize> {
228 type Range = std::ops::RangeFrom<usize>;
229 fn range(&self) -> Self::Range {
233 type Take<R: Read> = R;
234 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
/// Config wrapper for strings encoded as UTF-16 code units.
///
/// Length-prefix writing is delegated to `B` (default `DefCfg`), and the
/// inner config is `B::Inner`.
/// NOTE(review): the flag that switches `&str`/`String` to the UTF-16 path is
/// not visible in this excerpt; confirm its definition.
239 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
241 impl<B: MtCfg> MtCfg for Utf16<B> {
243 type Inner = B::Inner;
249 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
250 B::write_len(len, writer)
253 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
258 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
262 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
263 A::write_len(len, writer)
266 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
271 impl MtSerialize for u8 {
272 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
273 writer.write_u8(*self)?;
278 impl MtDeserialize for u8 {
279 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
280 Ok(reader.read_u8()?)
284 impl MtSerialize for i8 {
285 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
286 writer.write_i8(*self)?;
291 impl MtDeserialize for i8 {
292 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
293 Ok(reader.read_i8()?)
297 macro_rules! impl_num {
299 impl MtSerialize for $T {
300 fn mt_serialize<C: MtCfg>(
302 writer: &mut impl Write,
303 ) -> Result<(), SerializeError> {
305 writer.[<write_ $T>]::<BigEndian>(*self)?;
311 impl MtDeserialize for $T {
312 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
314 Ok(reader.[<read_ $T>]::<BigEndian>()?)
332 impl MtSerialize for () {
333 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
338 impl MtDeserialize for () {
339 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
344 impl MtSerialize for bool {
345 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
346 (*self as u8).mt_serialize::<DefCfg>(writer)
350 impl MtDeserialize for bool {
351 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
352 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
356 impl<T: MtSerialize> MtSerialize for &T {
357 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
358 (*self).mt_serialize::<C>(writer)
362 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
363 writer: &mut impl Write,
364 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
365 ) -> Result<(), SerializeError> {
366 C::write_len(iter.len(), writer)?;
369 .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
372 pub fn mt_deserialize_seq<C: MtCfg, T: MtDeserialize>(
373 reader: &mut impl Read,
374 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + '_, DeserializeError> {
375 let len = C::read_len(reader)?;
376 mt_deserialize_sized_seq::<C, _>(&len, reader)
379 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
381 reader: &'a mut impl Read,
382 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
383 let variable = len.option().is_none();
387 .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
388 Err(DeserializeError::UnexpectedEof) if variable => None,
393 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
394 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
395 mt_serialize_seq::<(), _>(writer, self.iter())
399 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
400 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
401 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
405 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
406 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
407 self.as_repr().mt_serialize::<DefCfg>(writer)
411 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
412 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
413 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
419 impl<T: MtSerialize> MtSerialize for Option<T> {
420 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
422 Some(item) => item.mt_serialize::<C>(writer),
428 impl<T: MtDeserialize> MtDeserialize for Option<T> {
429 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
430 T::mt_deserialize::<C>(reader).map(Some).or_default()
434 impl<T: MtSerialize> MtSerialize for Vec<T> {
435 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
436 mt_serialize_seq::<C, _>(writer, self.iter())
440 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
441 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
442 mt_deserialize_seq::<C, _>(reader)?.try_collect()
446 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
447 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
448 mt_serialize_seq::<C, _>(writer, self.iter())
452 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
453 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
454 mt_deserialize_seq::<C, _>(reader)?.try_collect()
458 // TODO: support more tuples
459 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
460 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
461 self.0.mt_serialize::<C>(writer)?;
462 self.1.mt_serialize::<C::Inner>(writer)?;
468 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
469 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
470 let a = A::mt_deserialize::<C>(reader)?;
471 let b = B::mt_deserialize::<C::Inner>(reader)?;
477 impl<K, V> MtSerialize for HashMap<K, V>
479 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
482 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
483 mt_serialize_seq::<C, _>(writer, self.iter())
487 impl<K, V> MtDeserialize for HashMap<K, V>
489 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
492 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
493 mt_deserialize_seq::<C, _>(reader)?.try_collect()
497 impl MtSerialize for &str {
498 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
501 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
502 .mt_serialize::<C>(writer)
504 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
509 impl MtSerialize for String {
510 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
511 self.as_str().mt_serialize::<C>(writer)
515 impl MtDeserialize for String {
516 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
521 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
535 let len = C::read_len(reader)?;
537 // use capacity if available
538 let mut st = match len.option() {
539 Some(x) => String::with_capacity(x),
540 None => String::new(),
543 len.take(WrapRead(reader)).read_to_string(&mut st)?;
550 impl<T: MtSerialize> MtSerialize for Box<T> {
551 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
552 self.deref().mt_serialize::<C>(writer)
556 impl<T: MtDeserialize> MtDeserialize for Box<T> {
557 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
558 Ok(Self::new(T::mt_deserialize::<C>(reader)?))