]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
b2aebc87e8b1e3bb634b6cdcdf5b51d460679bc0
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8 pub use zstd;
9
10 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
11 use enumset::{EnumSet, EnumSetTypeWithRepr};
12 use paste::paste as paste_macro;
13 use std::{
14     collections::{HashMap, HashSet},
15     convert::Infallible,
16     fmt::Debug,
17     io::{self, Read, Write},
18     num::TryFromIntError,
19     ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
20 };
21 use thiserror::Error;
22
23 #[cfg(test)]
24 mod tests;
25
26 use crate as mt_ser;
27
28 #[derive(Error, Debug)]
29 pub enum SerializeError {
30     #[error("io error: {0}")]
31     IoError(#[from] io::Error),
32     #[error("collection too big: {0}")]
33     TooBig(#[from] TryFromIntError),
34     #[error("{0}")]
35     Other(String),
36 }
37
38 impl From<Infallible> for SerializeError {
39     fn from(_err: Infallible) -> Self {
40         unreachable!("infallible")
41     }
42 }
43
44 #[derive(Error, Debug)]
45 pub enum DeserializeError {
46     #[error("io error: {0}")]
47     IoError(io::Error),
48     #[error("unexpected end of file")]
49     UnexpectedEof,
50     #[error("collection too big: {0}")]
51     TooBig(#[from] TryFromIntError),
52     #[error("invalid UTF-16: {0}")]
53     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
54     #[error("invalid {0} enum variant {1:?}")]
55     InvalidEnum(&'static str, Box<dyn Debug>),
56     #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
57     InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
58     #[error("{0}")]
59     Other(String),
60 }
61
62 impl From<Infallible> for DeserializeError {
63     fn from(_err: Infallible) -> Self {
64         unreachable!("infallible")
65     }
66 }
67
68 impl From<io::Error> for DeserializeError {
69     fn from(err: io::Error) -> Self {
70         if err.kind() == io::ErrorKind::UnexpectedEof {
71             DeserializeError::UnexpectedEof
72         } else {
73             DeserializeError::IoError(err)
74         }
75     }
76 }
77
78 pub trait OrDefault<T> {
79     fn or_default(self) -> Self;
80 }
81
82 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
83 impl<'a, R: Read> Read for WrapRead<'a, R> {
84     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
85         self.0.read(buf)
86     }
87
88     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
89         self.0.read_vectored(bufs)
90     }
91
92     /*
93
94     fn is_read_vectored(&self) -> bool {
95         self.0.is_read_vectored()
96     }
97
98     */
99
100     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
101         self.0.read_to_end(buf)
102     }
103
104     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
105         self.0.read_to_string(buf)
106     }
107
108     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
109         self.0.read_exact(buf)
110     }
111
112     /*
113
114     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
115         self.0.read_buf(buf)
116     }
117
118     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
119         self.0.read_buf_exact(cursor)
120     }
121
122     */
123 }
124
125 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
126     fn or_default(self) -> Self {
127         match self {
128             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
129             x => x,
130         }
131     }
132 }
133
134 pub trait MtLen {
135     fn option(&self) -> Option<usize>;
136
137     type Range: Iterator<Item = usize> + 'static;
138     fn range(&self) -> Self::Range;
139
140     type Take<R: Read>: Read;
141     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
142 }
143
144 pub trait MtCfg {
145     type Len: MtLen;
146     type Inner: MtCfg;
147
148     fn utf16() -> bool {
149         false
150     }
151
152     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
153     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
154 }
155
156 pub trait MtSerialize {
157     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
158 }
159
160 pub trait MtDeserialize: Sized {
161     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
162 }
163
164 impl MtLen for usize {
165     fn option(&self) -> Option<usize> {
166         Some(*self)
167     }
168
169     type Range = std::ops::Range<usize>;
170     fn range(&self) -> Self::Range {
171         0..*self
172     }
173
174     type Take<R: Read> = io::Take<R>;
175     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
176         reader.take(*self as u64)
177     }
178 }
179
180 trait MtCfgLen:
181     Sized
182     + MtSerialize
183     + MtDeserialize
184     + TryFrom<usize, Error: Into<SerializeError>>
185     + TryInto<usize, Error: Into<DeserializeError>>
186 {
187 }
188
189 impl<T: MtCfgLen> MtCfg for T {
190     type Len = usize;
191     type Inner = DefCfg;
192
193     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
194         Self::try_from(len)
195             .map_err(Into::into)?
196             .mt_serialize::<DefCfg>(writer)
197     }
198
199     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
200         Self::mt_deserialize::<DefCfg>(reader)?
201             .try_into()
202             .map_err(Into::into)
203     }
204 }
205
206 impl MtCfgLen for u8 {}
207 impl MtCfgLen for u16 {}
208 impl MtCfgLen for u32 {}
209 impl MtCfgLen for u64 {}
210
211 pub type DefCfg = u16;
212
213 impl MtCfg for () {
214     type Len = ();
215     type Inner = DefCfg;
216
217     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
218         Ok(())
219     }
220
221     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
222         Ok(())
223     }
224 }
225
226 impl MtLen for () {
227     fn option(&self) -> Option<usize> {
228         None
229     }
230
231     type Range = std::ops::RangeFrom<usize>;
232     fn range(&self) -> Self::Range {
233         0..
234     }
235
236     type Take<R: Read> = R;
237     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
238         reader
239     }
240 }
241
242 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
243
244 impl<B: MtCfg> MtCfg for Utf16<B> {
245     type Len = B::Len;
246     type Inner = B::Inner;
247
248     fn utf16() -> bool {
249         true
250     }
251
252     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
253         B::write_len(len, writer)
254     }
255
256     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
257         B::read_len(reader)
258     }
259 }
260
261 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
262     type Len = A::Len;
263     type Inner = B;
264
265     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
266         A::write_len(len, writer)
267     }
268
269     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
270         A::read_len(reader)
271     }
272 }
273
274 impl MtSerialize for u8 {
275     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
276         writer.write_u8(*self)?;
277         Ok(())
278     }
279 }
280
281 impl MtDeserialize for u8 {
282     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
283         Ok(reader.read_u8()?)
284     }
285 }
286
287 impl MtSerialize for i8 {
288     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
289         writer.write_i8(*self)?;
290         Ok(())
291     }
292 }
293
294 impl MtDeserialize for i8 {
295     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
296         Ok(reader.read_i8()?)
297     }
298 }
299
300 macro_rules! impl_num {
301     ($T:ty) => {
302         impl MtSerialize for $T {
303             fn mt_serialize<C: MtCfg>(
304                 &self,
305                 writer: &mut impl Write,
306             ) -> Result<(), SerializeError> {
307                 paste_macro! {
308                     writer.[<write_ $T>]::<BigEndian>(*self)?;
309                 }
310                 Ok(())
311             }
312         }
313
314         impl MtDeserialize for $T {
315             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
316                 paste_macro! {
317                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
318                 }
319             }
320         }
321     };
322 }
323
324 impl_num!(u16);
325 impl_num!(i16);
326
327 impl_num!(u32);
328 impl_num!(i32);
329 impl_num!(f32);
330
331 impl_num!(u64);
332 impl_num!(i64);
333 impl_num!(f64);
334
335 impl MtSerialize for () {
336     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
337         Ok(())
338     }
339 }
340
341 impl MtDeserialize for () {
342     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
343         Ok(())
344     }
345 }
346
347 impl MtSerialize for bool {
348     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
349         (*self as u8).mt_serialize::<DefCfg>(writer)
350     }
351 }
352
353 impl MtDeserialize for bool {
354     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
355         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
356     }
357 }
358
359 impl<T: MtSerialize> MtSerialize for &T {
360     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
361         (*self).mt_serialize::<C>(writer)
362     }
363 }
364
365 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
366     writer: &mut impl Write,
367     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
368 ) -> Result<(), SerializeError> {
369     C::write_len(iter.len(), writer)?;
370
371     iter.into_iter()
372         .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
373 }
374
375 pub fn mt_deserialize_seq<C: MtCfg, T: MtDeserialize>(
376     reader: &mut impl Read,
377 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + '_, DeserializeError> {
378     let len = C::read_len(reader)?;
379     mt_deserialize_sized_seq::<C, _>(&len, reader)
380 }
381
382 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
383     len: &C::Len,
384     reader: &'a mut impl Read,
385 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
386     let variable = len.option().is_none();
387
388     Ok(len
389         .range()
390         .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
391             Err(DeserializeError::UnexpectedEof) if variable => None,
392             x => Some(x),
393         }))
394 }
395
396 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
397     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
398         mt_serialize_seq::<(), _>(writer, self.iter())
399     }
400 }
401
402 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
403     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
404         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
405     }
406 }
407
408 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
409     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
410         self.as_repr().mt_serialize::<DefCfg>(writer)
411     }
412 }
413
414 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
415     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
416         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
417             reader,
418         )?))
419     }
420 }
421
422 impl<T: MtSerialize> MtSerialize for Option<T> {
423     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
424         match self {
425             Some(item) => item.mt_serialize::<C>(writer),
426             None => Ok(()),
427         }
428     }
429 }
430
431 impl<T: MtDeserialize> MtDeserialize for Option<T> {
432     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
433         T::mt_deserialize::<C>(reader).map(Some).or_default()
434     }
435 }
436
437 impl<T: MtSerialize> MtSerialize for Vec<T> {
438     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
439         mt_serialize_seq::<C, _>(writer, self.iter())
440     }
441 }
442
443 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
444     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
445         mt_deserialize_seq::<C, _>(reader)?.try_collect()
446     }
447 }
448
449 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
450     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
451         mt_serialize_seq::<C, _>(writer, self.iter())
452     }
453 }
454
455 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
456     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
457         mt_deserialize_seq::<C, _>(reader)?.try_collect()
458     }
459 }
460
461 // TODO: support more tuples
462 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
463     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
464         self.0.mt_serialize::<C>(writer)?;
465         self.1.mt_serialize::<C::Inner>(writer)?;
466
467         Ok(())
468     }
469 }
470
471 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
472     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
473         let a = A::mt_deserialize::<C>(reader)?;
474         let b = B::mt_deserialize::<C::Inner>(reader)?;
475
476         Ok((a, b))
477     }
478 }
479
480 impl<K, V> MtSerialize for HashMap<K, V>
481 where
482     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
483     V: MtSerialize,
484 {
485     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
486         mt_serialize_seq::<C, _>(writer, self.iter())
487     }
488 }
489
490 impl<K, V> MtDeserialize for HashMap<K, V>
491 where
492     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
493     V: MtDeserialize,
494 {
495     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
496         mt_deserialize_seq::<C, _>(reader)?.try_collect()
497     }
498 }
499
500 impl MtSerialize for &str {
501     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
502         if C::utf16() {
503             self.encode_utf16()
504                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
505                 .mt_serialize::<C>(writer)
506         } else {
507             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
508         }
509     }
510 }
511
512 impl MtSerialize for String {
513     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
514         self.as_str().mt_serialize::<C>(writer)
515     }
516 }
517
518 impl MtDeserialize for String {
519     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
520         if C::utf16() {
521             let mut err = None;
522
523             let res =
524                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
525                     Ok(v) => Some(v),
526                     Err(e) => {
527                         err = Some(e);
528                         None
529                     }
530                 }))
531                 .try_collect();
532
533             match err {
534                 None => Ok(res?),
535                 Some(e) => Err(e),
536             }
537         } else {
538             let len = C::read_len(reader)?;
539
540             // use capacity if available
541             let mut st = match len.option() {
542                 Some(x) => String::with_capacity(x),
543                 None => String::new(),
544             };
545
546             len.take(WrapRead(reader)).read_to_string(&mut st)?;
547
548             Ok(st)
549         }
550     }
551 }
552
553 impl<T: MtSerialize> MtSerialize for Box<T> {
554     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
555         self.deref().mt_serialize::<C>(writer)
556     }
557 }
558
559 impl<T: MtDeserialize> MtDeserialize for Box<T> {
560     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
561         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
562     }
563 }
564
565 #[derive(MtSerialize, MtDeserialize)]
566 #[mt(typename = "Range")]
567 #[allow(unused)]
568 struct RemoteRange<T> {
569     start: T,
570     end: T,
571 }
572
573 #[derive(MtSerialize, MtDeserialize)]
574 #[mt(typename = "RangeFrom")]
575 #[allow(unused)]
576 struct RemoteRangeFrom<T> {
577     start: T,
578 }
579
580 #[derive(MtSerialize, MtDeserialize)]
581 #[mt(typename = "RangeFull")]
582 #[allow(unused)]
583 struct RemoteRangeFull;
584
585 // RangeInclusive fields are private
586 impl<T: MtSerialize> MtSerialize for RangeInclusive<T> {
587     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
588         self.start().mt_serialize::<DefCfg>(writer)?;
589         self.end().mt_serialize::<DefCfg>(writer)?;
590
591         Ok(())
592     }
593 }
594
595 impl<T: MtDeserialize> MtDeserialize for RangeInclusive<T> {
596     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
597         let start = T::mt_deserialize::<DefCfg>(reader)?;
598         let end = T::mt_deserialize::<DefCfg>(reader)?;
599
600         Ok(start..=end)
601     }
602 }
603
604 #[derive(MtSerialize, MtDeserialize)]
605 #[mt(typename = "RangeTo")]
606 #[allow(unused)]
607 struct RemoteRangeTo<T> {
608     end: T,
609 }
610
611 #[derive(MtSerialize, MtDeserialize)]
612 #[mt(typename = "RangeToInclusive")]
613 #[allow(unused)]
614 struct RemoteRangeToInclusive<T> {
615     end: T,
616 }