]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
3a84f78145a21e28875637462346fcf0f0c86508
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
12 use std::{
13     collections::{HashMap, HashSet},
14     convert::Infallible,
15     fmt::Debug,
16     io::{self, Read, Write},
17     num::TryFromIntError,
18     ops::Deref,
19 };
20 use thiserror::Error;
21
22 #[cfg(test)]
23 mod tests;
24
25 #[derive(Error, Debug)]
26 pub enum SerializeError {
27     #[error("io error: {0}")]
28     IoError(#[from] io::Error),
29     #[error("collection too big: {0}")]
30     TooBig(#[from] TryFromIntError),
31 }
32
33 impl From<Infallible> for SerializeError {
34     fn from(_err: Infallible) -> Self {
35         unreachable!("infallible")
36     }
37 }
38
39 #[derive(Error, Debug)]
40 pub enum DeserializeError {
41     #[error("io error: {0}")]
42     IoError(io::Error),
43     #[error("unexpected end of file")]
44     UnexpectedEof,
45     #[error("collection too big: {0}")]
46     TooBig(#[from] TryFromIntError),
47     #[error("invalid UTF-16: {0}")]
48     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
49     #[error("invalid {0} enum variant {1}")]
50     InvalidEnumVariant(&'static str, u64),
51     #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
52     InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
53 }
54
55 impl From<Infallible> for DeserializeError {
56     fn from(_err: Infallible) -> Self {
57         unreachable!("infallible")
58     }
59 }
60
61 impl From<io::Error> for DeserializeError {
62     fn from(err: io::Error) -> Self {
63         if err.kind() == io::ErrorKind::UnexpectedEof {
64             DeserializeError::UnexpectedEof
65         } else {
66             DeserializeError::IoError(err)
67         }
68     }
69 }
70
71 pub trait OrDefault<T> {
72     fn or_default(self) -> Self;
73 }
74
75 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
76 impl<'a, R: Read> Read for WrapRead<'a, R> {
77     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
78         self.0.read(buf)
79     }
80
81     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
82         self.0.read_vectored(bufs)
83     }
84
85     /*
86
87     fn is_read_vectored(&self) -> bool {
88         self.0.is_read_vectored()
89     }
90
91     */
92
93     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
94         self.0.read_to_end(buf)
95     }
96
97     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
98         self.0.read_to_string(buf)
99     }
100
101     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
102         self.0.read_exact(buf)
103     }
104
105     /*
106
107     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
108         self.0.read_buf(buf)
109     }
110
111     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
112         self.0.read_buf_exact(cursor)
113     }
114
115     */
116 }
117
118 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
119     fn or_default(self) -> Self {
120         match self {
121             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
122             x => x,
123         }
124     }
125 }
126
127 pub trait MtLen {
128     fn option(&self) -> Option<usize>;
129
130     type Range: Iterator<Item = usize> + 'static;
131     fn range(&self) -> Self::Range;
132
133     type Take<R: Read>: Read;
134     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
135 }
136
137 pub trait MtCfg {
138     type Len: MtLen;
139     type Inner: MtCfg;
140
141     fn utf16() -> bool {
142         false
143     }
144
145     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
146     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
147 }
148
149 pub trait MtSerialize {
150     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
151 }
152
153 pub trait MtDeserialize: Sized {
154     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
155 }
156
157 impl MtLen for usize {
158     fn option(&self) -> Option<usize> {
159         Some(*self)
160     }
161
162     type Range = std::ops::Range<usize>;
163     fn range(&self) -> Self::Range {
164         0..*self
165     }
166
167     type Take<R: Read> = io::Take<R>;
168     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
169         reader.take(*self as u64)
170     }
171 }
172
173 trait MtCfgLen:
174     Sized
175     + MtSerialize
176     + MtDeserialize
177     + TryFrom<usize, Error: Into<SerializeError>>
178     + TryInto<usize, Error: Into<DeserializeError>>
179 {
180 }
181
182 impl<T: MtCfgLen> MtCfg for T {
183     type Len = usize;
184     type Inner = DefCfg;
185
186     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
187         Self::try_from(len)
188             .map_err(Into::into)?
189             .mt_serialize::<DefCfg>(writer)
190     }
191
192     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
193         Ok(Self::mt_deserialize::<DefCfg>(reader)?
194             .try_into()
195             .map_err(Into::into)?)
196     }
197 }
198
199 impl MtCfgLen for u8 {}
200 impl MtCfgLen for u16 {}
201 impl MtCfgLen for u32 {}
202 impl MtCfgLen for u64 {}
203
204 pub type DefCfg = u16;
205
206 impl MtCfg for () {
207     type Len = ();
208     type Inner = DefCfg;
209
210     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
211         Ok(())
212     }
213
214     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
215         Ok(())
216     }
217 }
218
219 impl MtLen for () {
220     fn option(&self) -> Option<usize> {
221         None
222     }
223
224     type Range = std::ops::RangeFrom<usize>;
225     fn range(&self) -> Self::Range {
226         0..
227     }
228
229     type Take<R: Read> = R;
230     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
231         reader
232     }
233 }
234
235 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
236
237 impl<B: MtCfg> MtCfg for Utf16<B> {
238     type Len = B::Len;
239     type Inner = B::Inner;
240
241     fn utf16() -> bool {
242         true
243     }
244
245     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
246         B::write_len(len, writer)
247     }
248
249     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
250         B::read_len(reader)
251     }
252 }
253
254 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
255     type Len = A::Len;
256     type Inner = B;
257
258     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
259         A::write_len(len, writer)
260     }
261
262     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
263         A::read_len(reader)
264     }
265 }
266
267 impl MtSerialize for u8 {
268     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
269         writer.write_u8(*self)?;
270         Ok(())
271     }
272 }
273
274 impl MtDeserialize for u8 {
275     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
276         Ok(reader.read_u8()?)
277     }
278 }
279
280 impl MtSerialize for i8 {
281     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
282         writer.write_i8(*self)?;
283         Ok(())
284     }
285 }
286
287 impl MtDeserialize for i8 {
288     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
289         Ok(reader.read_i8()?)
290     }
291 }
292
293 macro_rules! impl_num {
294     ($T:ty) => {
295         impl MtSerialize for $T {
296             fn mt_serialize<C: MtCfg>(
297                 &self,
298                 writer: &mut impl Write,
299             ) -> Result<(), SerializeError> {
300                 paste_macro! {
301                     writer.[<write_ $T>]::<BigEndian>(*self)?;
302                 }
303                 Ok(())
304             }
305         }
306
307         impl MtDeserialize for $T {
308             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
309                 paste_macro! {
310                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
311                 }
312             }
313         }
314     };
315 }
316
317 impl_num!(u16);
318 impl_num!(i16);
319
320 impl_num!(u32);
321 impl_num!(i32);
322 impl_num!(f32);
323
324 impl_num!(u64);
325 impl_num!(i64);
326 impl_num!(f64);
327
328 impl MtSerialize for () {
329     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
330         Ok(())
331     }
332 }
333
334 impl MtDeserialize for () {
335     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
336         Ok(())
337     }
338 }
339
340 impl MtSerialize for bool {
341     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
342         (*self as u8).mt_serialize::<DefCfg>(writer)
343     }
344 }
345
346 impl MtDeserialize for bool {
347     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
348         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
349     }
350 }
351
352 impl<T: MtSerialize> MtSerialize for &T {
353     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
354         (*self).mt_serialize::<C>(writer)
355     }
356 }
357
358 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
359     writer: &mut impl Write,
360     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
361 ) -> Result<(), SerializeError> {
362     C::write_len(iter.len(), writer)?;
363
364     iter.into_iter()
365         .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
366 }
367
368 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
369     reader: &'a mut impl Read,
370 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
371     let len = C::read_len(reader)?;
372     mt_deserialize_sized_seq::<C, _>(&len, reader)
373 }
374
375 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
376     len: &C::Len,
377     reader: &'a mut impl Read,
378 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
379     let variable = len.option().is_none();
380
381     Ok(len
382         .range()
383         .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
384             Err(DeserializeError::UnexpectedEof) if variable => None,
385             x => Some(x),
386         }))
387 }
388
389 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
390     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
391         mt_serialize_seq::<(), _>(writer, self.iter())
392     }
393 }
394
395 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
396     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
397         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
398     }
399 }
400
401 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
402     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
403         self.as_repr().mt_serialize::<DefCfg>(writer)
404     }
405 }
406
407 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
408     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
409         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
410             reader,
411         )?))
412     }
413 }
414
415 impl<T: MtSerialize> MtSerialize for Option<T> {
416     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
417         match self {
418             Some(item) => item.mt_serialize::<C>(writer),
419             None => Ok(()),
420         }
421     }
422 }
423
424 impl<T: MtDeserialize> MtDeserialize for Option<T> {
425     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
426         T::mt_deserialize::<C>(reader).map(Some).or_default()
427     }
428 }
429
430 impl<T: MtSerialize> MtSerialize for Vec<T> {
431     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
432         mt_serialize_seq::<C, _>(writer, self.iter())
433     }
434 }
435
436 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
437     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
438         mt_deserialize_seq::<C, _>(reader)?.try_collect()
439     }
440 }
441
442 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
443     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
444         mt_serialize_seq::<C, _>(writer, self.iter())
445     }
446 }
447
448 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
449     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
450         mt_deserialize_seq::<C, _>(reader)?.try_collect()
451     }
452 }
453
454 // TODO: support more tuples
455 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
456     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
457         self.0.mt_serialize::<C>(writer)?;
458         self.1.mt_serialize::<C::Inner>(writer)?;
459
460         Ok(())
461     }
462 }
463
464 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
465     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
466         let a = A::mt_deserialize::<C>(reader)?;
467         let b = B::mt_deserialize::<C::Inner>(reader)?;
468
469         Ok((a, b))
470     }
471 }
472
473 impl<K, V> MtSerialize for HashMap<K, V>
474 where
475     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
476     V: MtSerialize,
477 {
478     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
479         mt_serialize_seq::<C, _>(writer, self.iter())
480     }
481 }
482
483 impl<K, V> MtDeserialize for HashMap<K, V>
484 where
485     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
486     V: MtDeserialize,
487 {
488     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
489         mt_deserialize_seq::<C, _>(reader)?.try_collect()
490     }
491 }
492
493 impl MtSerialize for String {
494     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
495         if C::utf16() {
496             self.encode_utf16()
497                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
498                 .mt_serialize::<C>(writer)
499         } else {
500             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
501         }
502     }
503 }
504
505 impl MtDeserialize for String {
506     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
507         if C::utf16() {
508             let mut err = None;
509
510             let res =
511                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
512                     Ok(v) => Some(v),
513                     Err(e) => {
514                         err = Some(e);
515                         None
516                     }
517                 }))
518                 .try_collect();
519
520             match err {
521                 None => Ok(res?),
522                 Some(e) => Err(e),
523             }
524         } else {
525             let len = C::read_len(reader)?;
526
527             // use capacity if available
528             let mut st = match len.option() {
529                 Some(x) => String::with_capacity(x),
530                 None => String::new(),
531             };
532
533             len.take(WrapRead(reader)).read_to_string(&mut st)?;
534
535             Ok(st)
536         }
537     }
538 }
539
540 impl<T: MtSerialize> MtSerialize for Box<T> {
541     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
542         self.deref().mt_serialize::<C>(writer)
543     }
544 }
545
546 impl<T: MtDeserialize> MtDeserialize for Box<T> {
547     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
548         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
549     }
550 }