1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
10 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
11 use enumset::{EnumSet, EnumSetTypeWithRepr};
12 use paste::paste as paste_macro;
14 collections::{HashMap, HashSet},
17 io::{self, Read, Write},
26 #[derive(Error, Debug)]
27 pub enum SerializeError {
28 #[error("io error: {0}")]
29 IoError(#[from] io::Error),
30 #[error("collection too big: {0}")]
31 TooBig(#[from] TryFromIntError),
36 impl From<Infallible> for SerializeError {
37 fn from(_err: Infallible) -> Self {
38 unreachable!("infallible")
42 #[derive(Error, Debug)]
43 pub enum DeserializeError {
44 #[error("io error: {0}")]
46 #[error("unexpected end of file")]
48 #[error("collection too big: {0}")]
49 TooBig(#[from] TryFromIntError),
50 #[error("invalid UTF-16: {0}")]
51 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
52 #[error("invalid {0} enum variant {1:?}")]
53 InvalidEnum(&'static str, Box<dyn Debug>),
54 #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
55 InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
60 impl From<Infallible> for DeserializeError {
61 fn from(_err: Infallible) -> Self {
62 unreachable!("infallible")
66 impl From<io::Error> for DeserializeError {
67 fn from(err: io::Error) -> Self {
68 if err.kind() == io::ErrorKind::UnexpectedEof {
69 DeserializeError::UnexpectedEof
71 DeserializeError::IoError(err)
76 pub trait OrDefault<T> {
77 fn or_default(self) -> Self;
80 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
81 impl<'a, R: Read> Read for WrapRead<'a, R> {
82 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
86 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
87 self.0.read_vectored(bufs)
92 fn is_read_vectored(&self) -> bool {
93 self.0.is_read_vectored()
98 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
99 self.0.read_to_end(buf)
102 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
103 self.0.read_to_string(buf)
106 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
107 self.0.read_exact(buf)
112 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
116 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
117 self.0.read_buf_exact(cursor)
123 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
124 fn or_default(self) -> Self {
126 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
133 fn option(&self) -> Option<usize>;
135 type Range: Iterator<Item = usize> + 'static;
136 fn range(&self) -> Self::Range;
138 type Take<R: Read>: Read;
139 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
150 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
151 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
154 pub trait MtSerialize {
155 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
158 pub trait MtDeserialize: Sized {
159 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
162 impl MtLen for usize {
163 fn option(&self) -> Option<usize> {
167 type Range = std::ops::Range<usize>;
168 fn range(&self) -> Self::Range {
172 type Take<R: Read> = io::Take<R>;
173 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
174 reader.take(*self as u64)
182 + TryFrom<usize, Error: Into<SerializeError>>
183 + TryInto<usize, Error: Into<DeserializeError>>
187 impl<T: MtCfgLen> MtCfg for T {
191 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
193 .map_err(Into::into)?
194 .mt_serialize::<DefCfg>(writer)
197 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
198 Self::mt_deserialize::<DefCfg>(reader)?
204 impl MtCfgLen for u8 {}
205 impl MtCfgLen for u16 {}
206 impl MtCfgLen for u32 {}
207 impl MtCfgLen for u64 {}
209 pub type DefCfg = u16;
215 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
219 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
225 fn option(&self) -> Option<usize> {
229 type Range = std::ops::RangeFrom<usize>;
230 fn range(&self) -> Self::Range {
234 type Take<R: Read> = R;
235 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
240 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
242 impl<B: MtCfg> MtCfg for Utf16<B> {
244 type Inner = B::Inner;
250 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
251 B::write_len(len, writer)
254 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
259 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
263 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
264 A::write_len(len, writer)
267 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
272 impl MtSerialize for u8 {
273 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
274 writer.write_u8(*self)?;
279 impl MtDeserialize for u8 {
280 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
281 Ok(reader.read_u8()?)
285 impl MtSerialize for i8 {
286 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
287 writer.write_i8(*self)?;
292 impl MtDeserialize for i8 {
293 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
294 Ok(reader.read_i8()?)
298 macro_rules! impl_num {
300 impl MtSerialize for $T {
301 fn mt_serialize<C: MtCfg>(
303 writer: &mut impl Write,
304 ) -> Result<(), SerializeError> {
306 writer.[<write_ $T>]::<BigEndian>(*self)?;
312 impl MtDeserialize for $T {
313 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
315 Ok(reader.[<read_ $T>]::<BigEndian>()?)
333 impl MtSerialize for () {
334 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
339 impl MtDeserialize for () {
340 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
345 impl MtSerialize for bool {
346 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
347 (*self as u8).mt_serialize::<DefCfg>(writer)
351 impl MtDeserialize for bool {
352 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
353 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
357 impl<T: MtSerialize> MtSerialize for &T {
358 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
359 (*self).mt_serialize::<C>(writer)
363 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
364 writer: &mut impl Write,
365 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
366 ) -> Result<(), SerializeError> {
367 C::write_len(iter.len(), writer)?;
370 .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
373 pub fn mt_deserialize_seq<C: MtCfg, T: MtDeserialize>(
374 reader: &mut impl Read,
375 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + '_, DeserializeError> {
376 let len = C::read_len(reader)?;
377 mt_deserialize_sized_seq::<C, _>(&len, reader)
380 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
382 reader: &'a mut impl Read,
383 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
384 let variable = len.option().is_none();
388 .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
389 Err(DeserializeError::UnexpectedEof) if variable => None,
394 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
395 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
396 mt_serialize_seq::<(), _>(writer, self.iter())
400 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
401 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
402 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
406 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
407 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
408 self.as_repr().mt_serialize::<DefCfg>(writer)
412 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
413 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
414 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
420 impl<T: MtSerialize> MtSerialize for Option<T> {
421 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
423 Some(item) => item.mt_serialize::<C>(writer),
429 impl<T: MtDeserialize> MtDeserialize for Option<T> {
430 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
431 T::mt_deserialize::<C>(reader).map(Some).or_default()
435 impl<T: MtSerialize> MtSerialize for Vec<T> {
436 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
437 mt_serialize_seq::<C, _>(writer, self.iter())
441 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
442 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
443 mt_deserialize_seq::<C, _>(reader)?.try_collect()
447 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
448 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
449 mt_serialize_seq::<C, _>(writer, self.iter())
453 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
454 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
455 mt_deserialize_seq::<C, _>(reader)?.try_collect()
459 // TODO: support more tuples
460 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
461 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
462 self.0.mt_serialize::<C>(writer)?;
463 self.1.mt_serialize::<C::Inner>(writer)?;
469 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
470 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
471 let a = A::mt_deserialize::<C>(reader)?;
472 let b = B::mt_deserialize::<C::Inner>(reader)?;
478 impl<K, V> MtSerialize for HashMap<K, V>
480 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
483 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
484 mt_serialize_seq::<C, _>(writer, self.iter())
488 impl<K, V> MtDeserialize for HashMap<K, V>
490 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
493 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
494 mt_deserialize_seq::<C, _>(reader)?.try_collect()
498 impl MtSerialize for &str {
499 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
502 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
503 .mt_serialize::<C>(writer)
505 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
510 impl MtSerialize for String {
511 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
512 self.as_str().mt_serialize::<C>(writer)
516 impl MtDeserialize for String {
517 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
522 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
536 let len = C::read_len(reader)?;
538 // use capacity if available
539 let mut st = match len.option() {
540 Some(x) => String::with_capacity(x),
541 None => String::new(),
544 len.take(WrapRead(reader)).read_to_string(&mut st)?;
551 impl<T: MtSerialize> MtSerialize for Box<T> {
552 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
553 self.deref().mt_serialize::<C>(writer)
557 impl<T: MtDeserialize> MtDeserialize for Box<T> {
558 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
559 Ok(Self::new(T::mt_deserialize::<C>(reader)?))