1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
13 collections::{HashMap, HashSet},
15 io::{self, Read, Write},
24 #[derive(Error, Debug)]
25 pub enum SerializeError {
26 #[error("io error: {0}")]
27 IoError(#[from] io::Error),
28 #[error("collection too big: {0}")]
29 TooBig(#[from] TryFromIntError),
32 impl From<Infallible> for SerializeError {
33 fn from(_err: Infallible) -> Self {
34 unreachable!("infallible")
38 #[derive(Error, Debug)]
39 pub enum DeserializeError {
40 #[error("io error: {0}")]
42 #[error("unexpected end of file")]
44 #[error("collection too big: {0}")]
45 TooBig(#[from] TryFromIntError),
46 #[error("invalid UTF-16: {0}")]
47 InvalidUtf16(#[from] std::char::DecodeUtf16Error),
48 #[error("unimplemented")]
52 impl From<Infallible> for DeserializeError {
53 fn from(_err: Infallible) -> Self {
54 unreachable!("infallible")
58 impl From<io::Error> for DeserializeError {
59 fn from(err: io::Error) -> Self {
60 if err.kind() == io::ErrorKind::UnexpectedEof {
61 DeserializeError::UnexpectedEof
63 DeserializeError::IoError(err)
68 pub trait OrDefault<T> {
69 fn or_default(self) -> Self;
72 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
73 impl<'a, R: Read> Read for WrapRead<'a, R> {
74 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
78 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
79 self.0.read_vectored(bufs)
84 fn is_read_vectored(&self) -> bool {
85 self.0.is_read_vectored()
90 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
91 self.0.read_to_end(buf)
94 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
95 self.0.read_to_string(buf)
98 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
99 self.0.read_exact(buf)
104 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
108 fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
109 self.0.read_buf_exact(cursor)
115 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
116 fn or_default(self) -> Self {
118 Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
125 fn option(&self) -> Option<usize>;
127 type Range: Iterator<Item = usize> + 'static;
128 fn range(&self) -> Self::Range;
130 type Take<R: Read>: Read;
131 fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
141 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
142 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
145 pub trait MtSerialize {
146 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
149 pub trait MtDeserialize: Sized {
150 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
153 impl MtLen for usize {
154 fn option(&self) -> Option<usize> {
158 type Range = std::ops::Range<usize>;
159 fn range(&self) -> Self::Range {
163 type Take<R: Read> = io::Take<R>;
164 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
165 reader.take(*self as u64)
173 + TryFrom<usize, Error: Into<SerializeError>>
174 + TryInto<usize, Error: Into<DeserializeError>>
178 impl<T: MtCfgLen> MtCfg for T {
181 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
183 .map_err(Into::into)?
184 .mt_serialize::<DefCfg>(writer)
187 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
188 Ok(Self::mt_deserialize::<DefCfg>(reader)?
190 .map_err(Into::into)?)
194 impl MtCfgLen for u8 {}
195 impl MtCfgLen for u16 {}
196 impl MtCfgLen for u32 {}
197 impl MtCfgLen for u64 {}
199 pub type DefCfg = u16;
204 fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
208 fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
214 fn option(&self) -> Option<usize> {
218 type Range = std::ops::RangeFrom<usize>;
219 fn range(&self) -> Self::Range {
223 type Take<R: Read> = R;
224 fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
229 pub struct Utf16<B: MtCfg>(pub B);
231 impl<B: MtCfg> MtCfg for Utf16<B> {
238 fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
239 B::write_len(len, writer)
242 fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
247 impl MtSerialize for u8 {
248 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
249 writer.write_u8(*self)?;
254 impl MtDeserialize for u8 {
255 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
256 Ok(reader.read_u8()?)
260 impl MtSerialize for i8 {
261 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
262 writer.write_i8(*self)?;
267 impl MtDeserialize for i8 {
268 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
269 Ok(reader.read_i8()?)
273 macro_rules! impl_num {
275 impl MtSerialize for $T {
276 fn mt_serialize<C: MtCfg>(
278 writer: &mut impl Write,
279 ) -> Result<(), SerializeError> {
281 writer.[<write_ $T>]::<BigEndian>(*self)?;
287 impl MtDeserialize for $T {
288 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
290 Ok(reader.[<read_ $T>]::<BigEndian>()?)
308 impl MtSerialize for () {
309 fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
314 impl MtDeserialize for () {
315 fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
320 impl MtSerialize for bool {
321 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
322 (*self as u8).mt_serialize::<DefCfg>(writer)
326 impl MtDeserialize for bool {
327 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
328 Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
332 impl<T: MtSerialize> MtSerialize for &T {
333 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
334 (*self).mt_serialize::<C>(writer)
338 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
339 writer: &mut impl Write,
340 iter: impl ExactSizeIterator + IntoIterator<Item = T>,
341 ) -> Result<(), SerializeError> {
342 C::write_len(iter.len(), writer)?;
345 .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
348 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
349 reader: &'a mut impl Read,
350 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
351 let len = C::read_len(reader)?;
352 mt_deserialize_sized_seq(&len, reader)
355 pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
357 reader: &'a mut impl Read,
358 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
359 let variable = len.option().is_none();
363 .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
364 Err(DeserializeError::UnexpectedEof) if variable => None,
369 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
370 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
371 mt_serialize_seq::<(), _>(writer, self.iter())
375 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
376 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
377 std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
381 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
382 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
383 self.as_repr().mt_serialize::<DefCfg>(writer)
387 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
388 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
389 Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
395 impl<T: MtSerialize> MtSerialize for Option<T> {
396 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
398 Some(item) => item.mt_serialize::<C>(writer),
404 impl<T: MtDeserialize> MtDeserialize for Option<T> {
405 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
406 T::mt_deserialize::<C>(reader).map(Some).or_default()
410 impl<T: MtSerialize> MtSerialize for Vec<T> {
411 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
412 mt_serialize_seq::<C, _>(writer, self.iter())
416 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
417 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
418 mt_deserialize_seq::<C, _>(reader)?.try_collect()
422 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
423 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
424 mt_serialize_seq::<C, _>(writer, self.iter())
428 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
429 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
430 mt_deserialize_seq::<C, _>(reader)?.try_collect()
434 // TODO: support more tuples
435 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
436 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
437 self.0.mt_serialize::<DefCfg>(writer)?;
438 self.1.mt_serialize::<DefCfg>(writer)?;
444 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
445 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
446 let a = A::mt_deserialize::<DefCfg>(reader)?;
447 let b = B::mt_deserialize::<DefCfg>(reader)?;
453 impl<K, V> MtSerialize for HashMap<K, V>
455 K: MtSerialize + std::cmp::Eq + std::hash::Hash,
458 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
459 mt_serialize_seq::<C, _>(writer, self.iter())
463 impl<K, V> MtDeserialize for HashMap<K, V>
465 K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
468 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
469 mt_deserialize_seq::<C, _>(reader)?.try_collect()
473 impl MtSerialize for String {
474 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
477 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
478 .mt_serialize::<C>(writer)
480 mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
485 impl MtDeserialize for String {
486 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
491 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
505 let len = C::read_len(reader)?;
507 // use capacity if available
508 let mut st = match len.option() {
509 Some(x) => String::with_capacity(x),
510 None => String::new(),
513 len.take(WrapRead(reader)).read_to_string(&mut st)?;
520 impl<T: MtSerialize> MtSerialize for Box<T> {
521 fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
522 self.deref().mt_serialize::<C>(writer)
526 impl<T: MtDeserialize> MtDeserialize for Box<T> {
527 fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
528 Ok(Self::new(T::mt_deserialize::<C>(reader)?))