]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
4eaf119fdebb59f8386f316af30f554e560655e5
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
12 use std::{
13     collections::{HashMap, HashSet},
14     convert::Infallible,
15     fmt::Debug,
16     io::{self, Read, Write},
17     num::TryFromIntError,
18     ops::Deref,
19 };
20 use thiserror::Error;
21
22 #[cfg(test)]
23 mod tests;
24
25 #[derive(Error, Debug)]
26 pub enum SerializeError {
27     #[error("io error: {0}")]
28     IoError(#[from] io::Error),
29     #[error("collection too big: {0}")]
30     TooBig(#[from] TryFromIntError),
31 }
32
33 impl From<Infallible> for SerializeError {
34     fn from(_err: Infallible) -> Self {
35         unreachable!("infallible")
36     }
37 }
38
39 #[derive(Error, Debug)]
40 pub enum DeserializeError {
41     #[error("io error: {0}")]
42     IoError(io::Error),
43     #[error("unexpected end of file")]
44     UnexpectedEof,
45     #[error("collection too big: {0}")]
46     TooBig(#[from] TryFromIntError),
47     #[error("invalid UTF-16: {0}")]
48     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
49     #[error("invalid {0} enum variant {1}")]
50     InvalidEnumVariant(&'static str, u64),
51     #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
52     InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
53 }
54
55 impl From<Infallible> for DeserializeError {
56     fn from(_err: Infallible) -> Self {
57         unreachable!("infallible")
58     }
59 }
60
61 impl From<io::Error> for DeserializeError {
62     fn from(err: io::Error) -> Self {
63         if err.kind() == io::ErrorKind::UnexpectedEof {
64             DeserializeError::UnexpectedEof
65         } else {
66             DeserializeError::IoError(err)
67         }
68     }
69 }
70
71 pub trait OrDefault<T> {
72     fn or_default(self) -> Self;
73 }
74
75 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
76 impl<'a, R: Read> Read for WrapRead<'a, R> {
77     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
78         self.0.read(buf)
79     }
80
81     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
82         self.0.read_vectored(bufs)
83     }
84
85     /*
86
87     fn is_read_vectored(&self) -> bool {
88         self.0.is_read_vectored()
89     }
90
91     */
92
93     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
94         self.0.read_to_end(buf)
95     }
96
97     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
98         self.0.read_to_string(buf)
99     }
100
101     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
102         self.0.read_exact(buf)
103     }
104
105     /*
106
107     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
108         self.0.read_buf(buf)
109     }
110
111     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
112         self.0.read_buf_exact(cursor)
113     }
114
115     */
116 }
117
118 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
119     fn or_default(self) -> Self {
120         match self {
121             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
122             x => x,
123         }
124     }
125 }
126
127 pub trait MtLen {
128     fn option(&self) -> Option<usize>;
129
130     type Range: Iterator<Item = usize> + 'static;
131     fn range(&self) -> Self::Range;
132
133     type Take<R: Read>: Read;
134     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
135 }
136
137 pub trait MtCfg {
138     type Len: MtLen;
139
140     fn utf16() -> bool {
141         false
142     }
143
144     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
145     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
146 }
147
148 pub trait MtSerialize {
149     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
150 }
151
152 pub trait MtDeserialize: Sized {
153     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
154 }
155
156 impl MtLen for usize {
157     fn option(&self) -> Option<usize> {
158         Some(*self)
159     }
160
161     type Range = std::ops::Range<usize>;
162     fn range(&self) -> Self::Range {
163         0..*self
164     }
165
166     type Take<R: Read> = io::Take<R>;
167     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
168         reader.take(*self as u64)
169     }
170 }
171
172 trait MtCfgLen:
173     Sized
174     + MtSerialize
175     + MtDeserialize
176     + TryFrom<usize, Error: Into<SerializeError>>
177     + TryInto<usize, Error: Into<DeserializeError>>
178 {
179 }
180
181 impl<T: MtCfgLen> MtCfg for T {
182     type Len = usize;
183
184     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
185         Self::try_from(len)
186             .map_err(Into::into)?
187             .mt_serialize::<DefCfg>(writer)
188     }
189
190     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
191         Ok(Self::mt_deserialize::<DefCfg>(reader)?
192             .try_into()
193             .map_err(Into::into)?)
194     }
195 }
196
197 impl MtCfgLen for u8 {}
198 impl MtCfgLen for u16 {}
199 impl MtCfgLen for u32 {}
200 impl MtCfgLen for u64 {}
201
202 pub type DefCfg = u16;
203
204 impl MtCfg for () {
205     type Len = ();
206
207     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
208         Ok(())
209     }
210
211     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
212         Ok(())
213     }
214 }
215
216 impl MtLen for () {
217     fn option(&self) -> Option<usize> {
218         None
219     }
220
221     type Range = std::ops::RangeFrom<usize>;
222     fn range(&self) -> Self::Range {
223         0..
224     }
225
226     type Take<R: Read> = R;
227     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
228         reader
229     }
230 }
231
232 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
233
234 impl<B: MtCfg> MtCfg for Utf16<B> {
235     type Len = B::Len;
236
237     fn utf16() -> bool {
238         true
239     }
240
241     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
242         B::write_len(len, writer)
243     }
244
245     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
246         B::read_len(reader)
247     }
248 }
249
250 impl MtSerialize for u8 {
251     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
252         writer.write_u8(*self)?;
253         Ok(())
254     }
255 }
256
257 impl MtDeserialize for u8 {
258     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
259         Ok(reader.read_u8()?)
260     }
261 }
262
263 impl MtSerialize for i8 {
264     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
265         writer.write_i8(*self)?;
266         Ok(())
267     }
268 }
269
270 impl MtDeserialize for i8 {
271     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
272         Ok(reader.read_i8()?)
273     }
274 }
275
276 macro_rules! impl_num {
277     ($T:ty) => {
278         impl MtSerialize for $T {
279             fn mt_serialize<C: MtCfg>(
280                 &self,
281                 writer: &mut impl Write,
282             ) -> Result<(), SerializeError> {
283                 paste_macro! {
284                     writer.[<write_ $T>]::<BigEndian>(*self)?;
285                 }
286                 Ok(())
287             }
288         }
289
290         impl MtDeserialize for $T {
291             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
292                 paste_macro! {
293                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
294                 }
295             }
296         }
297     };
298 }
299
300 impl_num!(u16);
301 impl_num!(i16);
302
303 impl_num!(u32);
304 impl_num!(i32);
305 impl_num!(f32);
306
307 impl_num!(u64);
308 impl_num!(i64);
309 impl_num!(f64);
310
311 impl MtSerialize for () {
312     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
313         Ok(())
314     }
315 }
316
317 impl MtDeserialize for () {
318     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
319         Ok(())
320     }
321 }
322
323 impl MtSerialize for bool {
324     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
325         (*self as u8).mt_serialize::<DefCfg>(writer)
326     }
327 }
328
329 impl MtDeserialize for bool {
330     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
331         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
332     }
333 }
334
335 impl<T: MtSerialize> MtSerialize for &T {
336     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
337         (*self).mt_serialize::<C>(writer)
338     }
339 }
340
341 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
342     writer: &mut impl Write,
343     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
344 ) -> Result<(), SerializeError> {
345     C::write_len(iter.len(), writer)?;
346
347     iter.into_iter()
348         .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
349 }
350
351 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
352     reader: &'a mut impl Read,
353 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
354     let len = C::read_len(reader)?;
355     mt_deserialize_sized_seq(&len, reader)
356 }
357
358 pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
359     len: &L,
360     reader: &'a mut impl Read,
361 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
362     let variable = len.option().is_none();
363
364     Ok(len
365         .range()
366         .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
367             Err(DeserializeError::UnexpectedEof) if variable => None,
368             x => Some(x),
369         }))
370 }
371
372 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
373     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
374         mt_serialize_seq::<(), _>(writer, self.iter())
375     }
376 }
377
378 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
379     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
380         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
381     }
382 }
383
384 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
385     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
386         self.as_repr().mt_serialize::<DefCfg>(writer)
387     }
388 }
389
390 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
391     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
392         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
393             reader,
394         )?))
395     }
396 }
397
398 impl<T: MtSerialize> MtSerialize for Option<T> {
399     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
400         match self {
401             Some(item) => item.mt_serialize::<C>(writer),
402             None => Ok(()),
403         }
404     }
405 }
406
407 impl<T: MtDeserialize> MtDeserialize for Option<T> {
408     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
409         T::mt_deserialize::<C>(reader).map(Some).or_default()
410     }
411 }
412
413 impl<T: MtSerialize> MtSerialize for Vec<T> {
414     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
415         mt_serialize_seq::<C, _>(writer, self.iter())
416     }
417 }
418
419 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
420     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
421         mt_deserialize_seq::<C, _>(reader)?.try_collect()
422     }
423 }
424
425 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
426     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
427         mt_serialize_seq::<C, _>(writer, self.iter())
428     }
429 }
430
431 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
432     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
433         mt_deserialize_seq::<C, _>(reader)?.try_collect()
434     }
435 }
436
437 // TODO: support more tuples
438 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
439     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
440         self.0.mt_serialize::<DefCfg>(writer)?;
441         self.1.mt_serialize::<DefCfg>(writer)?;
442
443         Ok(())
444     }
445 }
446
447 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
448     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
449         let a = A::mt_deserialize::<DefCfg>(reader)?;
450         let b = B::mt_deserialize::<DefCfg>(reader)?;
451
452         Ok((a, b))
453     }
454 }
455
456 impl<K, V> MtSerialize for HashMap<K, V>
457 where
458     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
459     V: MtSerialize,
460 {
461     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
462         mt_serialize_seq::<C, _>(writer, self.iter())
463     }
464 }
465
466 impl<K, V> MtDeserialize for HashMap<K, V>
467 where
468     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
469     V: MtDeserialize,
470 {
471     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
472         mt_deserialize_seq::<C, _>(reader)?.try_collect()
473     }
474 }
475
476 impl MtSerialize for String {
477     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
478         if C::utf16() {
479             self.encode_utf16()
480                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
481                 .mt_serialize::<C>(writer)
482         } else {
483             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
484         }
485     }
486 }
487
488 impl MtDeserialize for String {
489     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
490         if C::utf16() {
491             let mut err = None;
492
493             let res =
494                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
495                     Ok(v) => Some(v),
496                     Err(e) => {
497                         err = Some(e);
498                         None
499                     }
500                 }))
501                 .try_collect();
502
503             match err {
504                 None => Ok(res?),
505                 Some(e) => Err(e),
506             }
507         } else {
508             let len = C::read_len(reader)?;
509
510             // use capacity if available
511             let mut st = match len.option() {
512                 Some(x) => String::with_capacity(x),
513                 None => String::new(),
514             };
515
516             len.take(WrapRead(reader)).read_to_string(&mut st)?;
517
518             Ok(st)
519         }
520     }
521 }
522
523 impl<T: MtSerialize> MtSerialize for Box<T> {
524     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
525         self.deref().mt_serialize::<C>(writer)
526     }
527 }
528
529 impl<T: MtDeserialize> MtDeserialize for Box<T> {
530     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
531         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
532     }
533 }