]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
961e4169b549c5042c78eed2d3a115ce080bd0da
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8 pub use zstd;
9
10 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
11 use enumset::{EnumSet, EnumSetTypeWithRepr};
12 use paste::paste as paste_macro;
13 use std::{
14     collections::{HashMap, HashSet},
15     convert::Infallible,
16     fmt::Debug,
17     io::{self, Read, Write},
18     num::TryFromIntError,
19     ops::Deref,
20 };
21 use thiserror::Error;
22
23 #[cfg(test)]
24 mod tests;
25
26 #[derive(Error, Debug)]
27 pub enum SerializeError {
28     #[error("io error: {0}")]
29     IoError(#[from] io::Error),
30     #[error("collection too big: {0}")]
31     TooBig(#[from] TryFromIntError),
32     #[error("{0}")]
33     Other(String),
34 }
35
36 impl From<Infallible> for SerializeError {
37     fn from(_err: Infallible) -> Self {
38         unreachable!("infallible")
39     }
40 }
41
42 #[derive(Error, Debug)]
43 pub enum DeserializeError {
44     #[error("io error: {0}")]
45     IoError(io::Error),
46     #[error("unexpected end of file")]
47     UnexpectedEof,
48     #[error("collection too big: {0}")]
49     TooBig(#[from] TryFromIntError),
50     #[error("invalid UTF-16: {0}")]
51     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
52     #[error("invalid {0} enum variant {1:?}")]
53     InvalidEnum(&'static str, Box<dyn Debug>),
54     #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
55     InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
56     #[error("{0}")]
57     Other(String),
58 }
59
60 impl From<Infallible> for DeserializeError {
61     fn from(_err: Infallible) -> Self {
62         unreachable!("infallible")
63     }
64 }
65
66 impl From<io::Error> for DeserializeError {
67     fn from(err: io::Error) -> Self {
68         if err.kind() == io::ErrorKind::UnexpectedEof {
69             DeserializeError::UnexpectedEof
70         } else {
71             DeserializeError::IoError(err)
72         }
73     }
74 }
75
76 pub trait OrDefault<T> {
77     fn or_default(self) -> Self;
78 }
79
80 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
81 impl<'a, R: Read> Read for WrapRead<'a, R> {
82     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
83         self.0.read(buf)
84     }
85
86     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
87         self.0.read_vectored(bufs)
88     }
89
90     /*
91
92     fn is_read_vectored(&self) -> bool {
93         self.0.is_read_vectored()
94     }
95
96     */
97
98     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
99         self.0.read_to_end(buf)
100     }
101
102     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
103         self.0.read_to_string(buf)
104     }
105
106     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
107         self.0.read_exact(buf)
108     }
109
110     /*
111
112     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
113         self.0.read_buf(buf)
114     }
115
116     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
117         self.0.read_buf_exact(cursor)
118     }
119
120     */
121 }
122
123 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
124     fn or_default(self) -> Self {
125         match self {
126             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
127             x => x,
128         }
129     }
130 }
131
132 pub trait MtLen {
133     fn option(&self) -> Option<usize>;
134
135     type Range: Iterator<Item = usize> + 'static;
136     fn range(&self) -> Self::Range;
137
138     type Take<R: Read>: Read;
139     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
140 }
141
142 pub trait MtCfg {
143     type Len: MtLen;
144     type Inner: MtCfg;
145
146     fn utf16() -> bool {
147         false
148     }
149
150     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
151     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
152 }
153
154 pub trait MtSerialize {
155     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
156 }
157
158 pub trait MtDeserialize: Sized {
159     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
160 }
161
162 impl MtLen for usize {
163     fn option(&self) -> Option<usize> {
164         Some(*self)
165     }
166
167     type Range = std::ops::Range<usize>;
168     fn range(&self) -> Self::Range {
169         0..*self
170     }
171
172     type Take<R: Read> = io::Take<R>;
173     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
174         reader.take(*self as u64)
175     }
176 }
177
178 trait MtCfgLen:
179     Sized
180     + MtSerialize
181     + MtDeserialize
182     + TryFrom<usize, Error: Into<SerializeError>>
183     + TryInto<usize, Error: Into<DeserializeError>>
184 {
185 }
186
187 impl<T: MtCfgLen> MtCfg for T {
188     type Len = usize;
189     type Inner = DefCfg;
190
191     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
192         Self::try_from(len)
193             .map_err(Into::into)?
194             .mt_serialize::<DefCfg>(writer)
195     }
196
197     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
198         Self::mt_deserialize::<DefCfg>(reader)?
199             .try_into()
200             .map_err(Into::into)
201     }
202 }
203
204 impl MtCfgLen for u8 {}
205 impl MtCfgLen for u16 {}
206 impl MtCfgLen for u32 {}
207 impl MtCfgLen for u64 {}
208
209 pub type DefCfg = u16;
210
211 impl MtCfg for () {
212     type Len = ();
213     type Inner = DefCfg;
214
215     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
216         Ok(())
217     }
218
219     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
220         Ok(())
221     }
222 }
223
224 impl MtLen for () {
225     fn option(&self) -> Option<usize> {
226         None
227     }
228
229     type Range = std::ops::RangeFrom<usize>;
230     fn range(&self) -> Self::Range {
231         0..
232     }
233
234     type Take<R: Read> = R;
235     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
236         reader
237     }
238 }
239
240 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
241
242 impl<B: MtCfg> MtCfg for Utf16<B> {
243     type Len = B::Len;
244     type Inner = B::Inner;
245
246     fn utf16() -> bool {
247         true
248     }
249
250     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
251         B::write_len(len, writer)
252     }
253
254     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
255         B::read_len(reader)
256     }
257 }
258
259 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
260     type Len = A::Len;
261     type Inner = B;
262
263     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
264         A::write_len(len, writer)
265     }
266
267     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
268         A::read_len(reader)
269     }
270 }
271
272 impl MtSerialize for u8 {
273     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
274         writer.write_u8(*self)?;
275         Ok(())
276     }
277 }
278
279 impl MtDeserialize for u8 {
280     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
281         Ok(reader.read_u8()?)
282     }
283 }
284
285 impl MtSerialize for i8 {
286     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
287         writer.write_i8(*self)?;
288         Ok(())
289     }
290 }
291
292 impl MtDeserialize for i8 {
293     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
294         Ok(reader.read_i8()?)
295     }
296 }
297
298 macro_rules! impl_num {
299     ($T:ty) => {
300         impl MtSerialize for $T {
301             fn mt_serialize<C: MtCfg>(
302                 &self,
303                 writer: &mut impl Write,
304             ) -> Result<(), SerializeError> {
305                 paste_macro! {
306                     writer.[<write_ $T>]::<BigEndian>(*self)?;
307                 }
308                 Ok(())
309             }
310         }
311
312         impl MtDeserialize for $T {
313             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
314                 paste_macro! {
315                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
316                 }
317             }
318         }
319     };
320 }
321
322 impl_num!(u16);
323 impl_num!(i16);
324
325 impl_num!(u32);
326 impl_num!(i32);
327 impl_num!(f32);
328
329 impl_num!(u64);
330 impl_num!(i64);
331 impl_num!(f64);
332
333 impl MtSerialize for () {
334     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
335         Ok(())
336     }
337 }
338
339 impl MtDeserialize for () {
340     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
341         Ok(())
342     }
343 }
344
345 impl MtSerialize for bool {
346     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
347         (*self as u8).mt_serialize::<DefCfg>(writer)
348     }
349 }
350
351 impl MtDeserialize for bool {
352     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
353         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
354     }
355 }
356
357 impl<T: MtSerialize> MtSerialize for &T {
358     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
359         (*self).mt_serialize::<C>(writer)
360     }
361 }
362
363 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
364     writer: &mut impl Write,
365     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
366 ) -> Result<(), SerializeError> {
367     C::write_len(iter.len(), writer)?;
368
369     iter.into_iter()
370         .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
371 }
372
373 pub fn mt_deserialize_seq<C: MtCfg, T: MtDeserialize>(
374     reader: &mut impl Read,
375 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + '_, DeserializeError> {
376     let len = C::read_len(reader)?;
377     mt_deserialize_sized_seq::<C, _>(&len, reader)
378 }
379
380 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
381     len: &C::Len,
382     reader: &'a mut impl Read,
383 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
384     let variable = len.option().is_none();
385
386     Ok(len
387         .range()
388         .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
389             Err(DeserializeError::UnexpectedEof) if variable => None,
390             x => Some(x),
391         }))
392 }
393
394 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
395     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
396         mt_serialize_seq::<(), _>(writer, self.iter())
397     }
398 }
399
400 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
401     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
402         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
403     }
404 }
405
406 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
407     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
408         self.as_repr().mt_serialize::<DefCfg>(writer)
409     }
410 }
411
412 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
413     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
414         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
415             reader,
416         )?))
417     }
418 }
419
420 impl<T: MtSerialize> MtSerialize for Option<T> {
421     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
422         match self {
423             Some(item) => item.mt_serialize::<C>(writer),
424             None => Ok(()),
425         }
426     }
427 }
428
429 impl<T: MtDeserialize> MtDeserialize for Option<T> {
430     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
431         T::mt_deserialize::<C>(reader).map(Some).or_default()
432     }
433 }
434
435 impl<T: MtSerialize> MtSerialize for Vec<T> {
436     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
437         mt_serialize_seq::<C, _>(writer, self.iter())
438     }
439 }
440
441 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
442     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
443         mt_deserialize_seq::<C, _>(reader)?.try_collect()
444     }
445 }
446
447 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
448     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
449         mt_serialize_seq::<C, _>(writer, self.iter())
450     }
451 }
452
453 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
454     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
455         mt_deserialize_seq::<C, _>(reader)?.try_collect()
456     }
457 }
458
459 // TODO: support more tuples
460 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
461     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
462         self.0.mt_serialize::<C>(writer)?;
463         self.1.mt_serialize::<C::Inner>(writer)?;
464
465         Ok(())
466     }
467 }
468
469 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
470     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
471         let a = A::mt_deserialize::<C>(reader)?;
472         let b = B::mt_deserialize::<C::Inner>(reader)?;
473
474         Ok((a, b))
475     }
476 }
477
478 impl<K, V> MtSerialize for HashMap<K, V>
479 where
480     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
481     V: MtSerialize,
482 {
483     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
484         mt_serialize_seq::<C, _>(writer, self.iter())
485     }
486 }
487
488 impl<K, V> MtDeserialize for HashMap<K, V>
489 where
490     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
491     V: MtDeserialize,
492 {
493     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
494         mt_deserialize_seq::<C, _>(reader)?.try_collect()
495     }
496 }
497
498 impl MtSerialize for &str {
499     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
500         if C::utf16() {
501             self.encode_utf16()
502                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
503                 .mt_serialize::<C>(writer)
504         } else {
505             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
506         }
507     }
508 }
509
510 impl MtSerialize for String {
511     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
512         self.as_str().mt_serialize::<C>(writer)
513     }
514 }
515
516 impl MtDeserialize for String {
517     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
518         if C::utf16() {
519             let mut err = None;
520
521             let res =
522                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
523                     Ok(v) => Some(v),
524                     Err(e) => {
525                         err = Some(e);
526                         None
527                     }
528                 }))
529                 .try_collect();
530
531             match err {
532                 None => Ok(res?),
533                 Some(e) => Err(e),
534             }
535         } else {
536             let len = C::read_len(reader)?;
537
538             // use capacity if available
539             let mut st = match len.option() {
540                 Some(x) => String::with_capacity(x),
541                 None => String::new(),
542             };
543
544             len.take(WrapRead(reader)).read_to_string(&mut st)?;
545
546             Ok(st)
547         }
548     }
549 }
550
551 impl<T: MtSerialize> MtSerialize for Box<T> {
552     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
553         self.deref().mt_serialize::<C>(writer)
554     }
555 }
556
557 impl<T: MtDeserialize> MtDeserialize for Box<T> {
558     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
559         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
560     }
561 }