]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
derive deserialize
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
12 use std::{
13     collections::{HashMap, HashSet},
14     convert::Infallible,
15     io::{self, Read, Write},
16     num::TryFromIntError,
17     ops::Deref,
18 };
19 use thiserror::Error;
20
21 #[cfg(test)]
22 mod tests;
23
24 #[derive(Error, Debug)]
25 pub enum SerializeError {
26     #[error("io error: {0}")]
27     IoError(#[from] io::Error),
28     #[error("collection too big: {0}")]
29     TooBig(#[from] TryFromIntError),
30 }
31
32 impl From<Infallible> for SerializeError {
33     fn from(_err: Infallible) -> Self {
34         unreachable!("infallible")
35     }
36 }
37
38 #[derive(Error, Debug)]
39 pub enum DeserializeError {
40     #[error("io error: {0}")]
41     IoError(io::Error),
42     #[error("unexpected end of file")]
43     UnexpectedEof,
44     #[error("collection too big: {0}")]
45     TooBig(#[from] TryFromIntError),
46     #[error("invalid UTF-16: {0}")]
47     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
48     #[error("invalid {0} enum variant {1}")]
49     InvalidEnumVariant(&'static str, u64),
50     #[error("invalid constant - wanted: {0} - got: {1}")]
51     InvalidConst(u64, u64),
52 }
53
54 impl From<Infallible> for DeserializeError {
55     fn from(_err: Infallible) -> Self {
56         unreachable!("infallible")
57     }
58 }
59
60 impl From<io::Error> for DeserializeError {
61     fn from(err: io::Error) -> Self {
62         if err.kind() == io::ErrorKind::UnexpectedEof {
63             DeserializeError::UnexpectedEof
64         } else {
65             DeserializeError::IoError(err)
66         }
67     }
68 }
69
70 pub trait OrDefault<T> {
71     fn or_default(self) -> Self;
72 }
73
74 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
75 impl<'a, R: Read> Read for WrapRead<'a, R> {
76     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
77         self.0.read(buf)
78     }
79
80     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
81         self.0.read_vectored(bufs)
82     }
83
84     /*
85
86     fn is_read_vectored(&self) -> bool {
87         self.0.is_read_vectored()
88     }
89
90     */
91
92     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
93         self.0.read_to_end(buf)
94     }
95
96     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
97         self.0.read_to_string(buf)
98     }
99
100     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
101         self.0.read_exact(buf)
102     }
103
104     /*
105
106     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
107         self.0.read_buf(buf)
108     }
109
110     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
111         self.0.read_buf_exact(cursor)
112     }
113
114     */
115 }
116
117 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
118     fn or_default(self) -> Self {
119         match self {
120             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
121             x => x,
122         }
123     }
124 }
125
126 pub trait MtLen {
127     fn option(&self) -> Option<usize>;
128
129     type Range: Iterator<Item = usize> + 'static;
130     fn range(&self) -> Self::Range;
131
132     type Take<R: Read>: Read;
133     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
134 }
135
136 pub trait MtCfg {
137     type Len: MtLen;
138
139     fn utf16() -> bool {
140         false
141     }
142
143     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
144     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
145 }
146
147 pub trait MtSerialize {
148     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
149 }
150
151 pub trait MtDeserialize: Sized {
152     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
153 }
154
155 impl MtLen for usize {
156     fn option(&self) -> Option<usize> {
157         Some(*self)
158     }
159
160     type Range = std::ops::Range<usize>;
161     fn range(&self) -> Self::Range {
162         0..*self
163     }
164
165     type Take<R: Read> = io::Take<R>;
166     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
167         reader.take(*self as u64)
168     }
169 }
170
171 trait MtCfgLen:
172     Sized
173     + MtSerialize
174     + MtDeserialize
175     + TryFrom<usize, Error: Into<SerializeError>>
176     + TryInto<usize, Error: Into<DeserializeError>>
177 {
178 }
179
180 impl<T: MtCfgLen> MtCfg for T {
181     type Len = usize;
182
183     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
184         Self::try_from(len)
185             .map_err(Into::into)?
186             .mt_serialize::<DefCfg>(writer)
187     }
188
189     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
190         Ok(Self::mt_deserialize::<DefCfg>(reader)?
191             .try_into()
192             .map_err(Into::into)?)
193     }
194 }
195
196 impl MtCfgLen for u8 {}
197 impl MtCfgLen for u16 {}
198 impl MtCfgLen for u32 {}
199 impl MtCfgLen for u64 {}
200
201 pub type DefCfg = u16;
202
203 impl MtCfg for () {
204     type Len = ();
205
206     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
207         Ok(())
208     }
209
210     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
211         Ok(())
212     }
213 }
214
215 impl MtLen for () {
216     fn option(&self) -> Option<usize> {
217         None
218     }
219
220     type Range = std::ops::RangeFrom<usize>;
221     fn range(&self) -> Self::Range {
222         0..
223     }
224
225     type Take<R: Read> = R;
226     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
227         reader
228     }
229 }
230
231 pub struct Utf16<B: MtCfg>(pub B);
232
233 impl<B: MtCfg> MtCfg for Utf16<B> {
234     type Len = B::Len;
235
236     fn utf16() -> bool {
237         true
238     }
239
240     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
241         B::write_len(len, writer)
242     }
243
244     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
245         B::read_len(reader)
246     }
247 }
248
249 impl MtSerialize for u8 {
250     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
251         writer.write_u8(*self)?;
252         Ok(())
253     }
254 }
255
256 impl MtDeserialize for u8 {
257     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
258         Ok(reader.read_u8()?)
259     }
260 }
261
262 impl MtSerialize for i8 {
263     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
264         writer.write_i8(*self)?;
265         Ok(())
266     }
267 }
268
269 impl MtDeserialize for i8 {
270     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
271         Ok(reader.read_i8()?)
272     }
273 }
274
275 macro_rules! impl_num {
276     ($T:ty) => {
277         impl MtSerialize for $T {
278             fn mt_serialize<C: MtCfg>(
279                 &self,
280                 writer: &mut impl Write,
281             ) -> Result<(), SerializeError> {
282                 paste_macro! {
283                     writer.[<write_ $T>]::<BigEndian>(*self)?;
284                 }
285                 Ok(())
286             }
287         }
288
289         impl MtDeserialize for $T {
290             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
291                 paste_macro! {
292                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
293                 }
294             }
295         }
296     };
297 }
298
299 impl_num!(u16);
300 impl_num!(i16);
301
302 impl_num!(u32);
303 impl_num!(i32);
304 impl_num!(f32);
305
306 impl_num!(u64);
307 impl_num!(i64);
308 impl_num!(f64);
309
310 impl MtSerialize for () {
311     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
312         Ok(())
313     }
314 }
315
316 impl MtDeserialize for () {
317     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
318         Ok(())
319     }
320 }
321
322 impl MtSerialize for bool {
323     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
324         (*self as u8).mt_serialize::<DefCfg>(writer)
325     }
326 }
327
328 impl MtDeserialize for bool {
329     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
330         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
331     }
332 }
333
334 impl<T: MtSerialize> MtSerialize for &T {
335     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
336         (*self).mt_serialize::<C>(writer)
337     }
338 }
339
340 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
341     writer: &mut impl Write,
342     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
343 ) -> Result<(), SerializeError> {
344     C::write_len(iter.len(), writer)?;
345
346     iter.into_iter()
347         .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
348 }
349
350 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
351     reader: &'a mut impl Read,
352 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
353     let len = C::read_len(reader)?;
354     mt_deserialize_sized_seq(&len, reader)
355 }
356
357 pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
358     len: &L,
359     reader: &'a mut impl Read,
360 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
361     let variable = len.option().is_none();
362
363     Ok(len
364         .range()
365         .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
366             Err(DeserializeError::UnexpectedEof) if variable => None,
367             x => Some(x),
368         }))
369 }
370
371 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
372     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
373         mt_serialize_seq::<(), _>(writer, self.iter())
374     }
375 }
376
377 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
378     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
379         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
380     }
381 }
382
383 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
384     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
385         self.as_repr().mt_serialize::<DefCfg>(writer)
386     }
387 }
388
389 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
390     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
391         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
392             reader,
393         )?))
394     }
395 }
396
397 impl<T: MtSerialize> MtSerialize for Option<T> {
398     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
399         match self {
400             Some(item) => item.mt_serialize::<C>(writer),
401             None => Ok(()),
402         }
403     }
404 }
405
406 impl<T: MtDeserialize> MtDeserialize for Option<T> {
407     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
408         T::mt_deserialize::<C>(reader).map(Some).or_default()
409     }
410 }
411
412 impl<T: MtSerialize> MtSerialize for Vec<T> {
413     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
414         mt_serialize_seq::<C, _>(writer, self.iter())
415     }
416 }
417
418 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
419     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
420         mt_deserialize_seq::<C, _>(reader)?.try_collect()
421     }
422 }
423
424 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
425     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
426         mt_serialize_seq::<C, _>(writer, self.iter())
427     }
428 }
429
430 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
431     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
432         mt_deserialize_seq::<C, _>(reader)?.try_collect()
433     }
434 }
435
436 // TODO: support more tuples
437 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
438     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
439         self.0.mt_serialize::<DefCfg>(writer)?;
440         self.1.mt_serialize::<DefCfg>(writer)?;
441
442         Ok(())
443     }
444 }
445
446 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
447     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
448         let a = A::mt_deserialize::<DefCfg>(reader)?;
449         let b = B::mt_deserialize::<DefCfg>(reader)?;
450
451         Ok((a, b))
452     }
453 }
454
455 impl<K, V> MtSerialize for HashMap<K, V>
456 where
457     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
458     V: MtSerialize,
459 {
460     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
461         mt_serialize_seq::<C, _>(writer, self.iter())
462     }
463 }
464
465 impl<K, V> MtDeserialize for HashMap<K, V>
466 where
467     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
468     V: MtDeserialize,
469 {
470     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
471         mt_deserialize_seq::<C, _>(reader)?.try_collect()
472     }
473 }
474
475 impl MtSerialize for String {
476     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
477         if C::utf16() {
478             self.encode_utf16()
479                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
480                 .mt_serialize::<C>(writer)
481         } else {
482             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
483         }
484     }
485 }
486
487 impl MtDeserialize for String {
488     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
489         if C::utf16() {
490             let mut err = None;
491
492             let res =
493                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
494                     Ok(v) => Some(v),
495                     Err(e) => {
496                         err = Some(e);
497                         None
498                     }
499                 }))
500                 .try_collect();
501
502             match err {
503                 None => Ok(res?),
504                 Some(e) => Err(e),
505             }
506         } else {
507             let len = C::read_len(reader)?;
508
509             // use capacity if available
510             let mut st = match len.option() {
511                 Some(x) => String::with_capacity(x),
512                 None => String::new(),
513             };
514
515             len.take(WrapRead(reader)).read_to_string(&mut st)?;
516
517             Ok(st)
518         }
519     }
520 }
521
522 impl<T: MtSerialize> MtSerialize for Box<T> {
523     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
524         self.deref().mt_serialize::<C>(writer)
525     }
526 }
527
528 impl<T: MtDeserialize> MtDeserialize for Box<T> {
529     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
530         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
531     }
532 }