]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
2aa5c5ad5d07131b107a0b56cd6fdac96e0572ee
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
12 use std::{
13     collections::{HashMap, HashSet},
14     convert::Infallible,
15     fmt::Debug,
16     io::{self, Read, Write},
17     num::TryFromIntError,
18     ops::Deref,
19 };
20 use thiserror::Error;
21
22 #[cfg(test)]
23 mod tests;
24
25 #[derive(Error, Debug)]
26 pub enum SerializeError {
27     #[error("io error: {0}")]
28     IoError(#[from] io::Error),
29     #[error("collection too big: {0}")]
30     TooBig(#[from] TryFromIntError),
31     #[error("{0}")]
32     Other(String),
33 }
34
35 impl From<Infallible> for SerializeError {
36     fn from(_err: Infallible) -> Self {
37         unreachable!("infallible")
38     }
39 }
40
41 #[derive(Error, Debug)]
42 pub enum DeserializeError {
43     #[error("io error: {0}")]
44     IoError(io::Error),
45     #[error("unexpected end of file")]
46     UnexpectedEof,
47     #[error("collection too big: {0}")]
48     TooBig(#[from] TryFromIntError),
49     #[error("invalid UTF-16: {0}")]
50     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
51     #[error("invalid {0} enum variant {1}")]
52     InvalidEnumVariant(&'static str, u64),
53     #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
54     InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
55     #[error("{0}")]
56     Other(String),
57 }
58
59 impl From<Infallible> for DeserializeError {
60     fn from(_err: Infallible) -> Self {
61         unreachable!("infallible")
62     }
63 }
64
65 impl From<io::Error> for DeserializeError {
66     fn from(err: io::Error) -> Self {
67         if err.kind() == io::ErrorKind::UnexpectedEof {
68             DeserializeError::UnexpectedEof
69         } else {
70             DeserializeError::IoError(err)
71         }
72     }
73 }
74
75 pub trait OrDefault<T> {
76     fn or_default(self) -> Self;
77 }
78
79 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
80 impl<'a, R: Read> Read for WrapRead<'a, R> {
81     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
82         self.0.read(buf)
83     }
84
85     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
86         self.0.read_vectored(bufs)
87     }
88
89     /*
90
91     fn is_read_vectored(&self) -> bool {
92         self.0.is_read_vectored()
93     }
94
95     */
96
97     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
98         self.0.read_to_end(buf)
99     }
100
101     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
102         self.0.read_to_string(buf)
103     }
104
105     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
106         self.0.read_exact(buf)
107     }
108
109     /*
110
111     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
112         self.0.read_buf(buf)
113     }
114
115     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
116         self.0.read_buf_exact(cursor)
117     }
118
119     */
120 }
121
122 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
123     fn or_default(self) -> Self {
124         match self {
125             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
126             x => x,
127         }
128     }
129 }
130
131 pub trait MtLen {
132     fn option(&self) -> Option<usize>;
133
134     type Range: Iterator<Item = usize> + 'static;
135     fn range(&self) -> Self::Range;
136
137     type Take<R: Read>: Read;
138     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
139 }
140
141 pub trait MtCfg {
142     type Len: MtLen;
143     type Inner: MtCfg;
144
145     fn utf16() -> bool {
146         false
147     }
148
149     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
150     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
151 }
152
153 pub trait MtSerialize {
154     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
155 }
156
157 pub trait MtDeserialize: Sized {
158     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
159 }
160
161 impl MtLen for usize {
162     fn option(&self) -> Option<usize> {
163         Some(*self)
164     }
165
166     type Range = std::ops::Range<usize>;
167     fn range(&self) -> Self::Range {
168         0..*self
169     }
170
171     type Take<R: Read> = io::Take<R>;
172     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
173         reader.take(*self as u64)
174     }
175 }
176
177 trait MtCfgLen:
178     Sized
179     + MtSerialize
180     + MtDeserialize
181     + TryFrom<usize, Error: Into<SerializeError>>
182     + TryInto<usize, Error: Into<DeserializeError>>
183 {
184 }
185
186 impl<T: MtCfgLen> MtCfg for T {
187     type Len = usize;
188     type Inner = DefCfg;
189
190     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
191         Self::try_from(len)
192             .map_err(Into::into)?
193             .mt_serialize::<DefCfg>(writer)
194     }
195
196     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
197         Ok(Self::mt_deserialize::<DefCfg>(reader)?
198             .try_into()
199             .map_err(Into::into)?)
200     }
201 }
202
203 impl MtCfgLen for u8 {}
204 impl MtCfgLen for u16 {}
205 impl MtCfgLen for u32 {}
206 impl MtCfgLen for u64 {}
207
208 pub type DefCfg = u16;
209
210 impl MtCfg for () {
211     type Len = ();
212     type Inner = DefCfg;
213
214     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
215         Ok(())
216     }
217
218     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
219         Ok(())
220     }
221 }
222
223 impl MtLen for () {
224     fn option(&self) -> Option<usize> {
225         None
226     }
227
228     type Range = std::ops::RangeFrom<usize>;
229     fn range(&self) -> Self::Range {
230         0..
231     }
232
233     type Take<R: Read> = R;
234     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
235         reader
236     }
237 }
238
239 pub struct Utf16<B: MtCfg = DefCfg>(pub B);
240
241 impl<B: MtCfg> MtCfg for Utf16<B> {
242     type Len = B::Len;
243     type Inner = B::Inner;
244
245     fn utf16() -> bool {
246         true
247     }
248
249     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
250         B::write_len(len, writer)
251     }
252
253     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
254         B::read_len(reader)
255     }
256 }
257
258 impl<A: MtCfg, B: MtCfg> MtCfg for (A, B) {
259     type Len = A::Len;
260     type Inner = B;
261
262     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
263         A::write_len(len, writer)
264     }
265
266     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
267         A::read_len(reader)
268     }
269 }
270
271 impl MtSerialize for u8 {
272     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
273         writer.write_u8(*self)?;
274         Ok(())
275     }
276 }
277
278 impl MtDeserialize for u8 {
279     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
280         Ok(reader.read_u8()?)
281     }
282 }
283
284 impl MtSerialize for i8 {
285     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
286         writer.write_i8(*self)?;
287         Ok(())
288     }
289 }
290
291 impl MtDeserialize for i8 {
292     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
293         Ok(reader.read_i8()?)
294     }
295 }
296
297 macro_rules! impl_num {
298     ($T:ty) => {
299         impl MtSerialize for $T {
300             fn mt_serialize<C: MtCfg>(
301                 &self,
302                 writer: &mut impl Write,
303             ) -> Result<(), SerializeError> {
304                 paste_macro! {
305                     writer.[<write_ $T>]::<BigEndian>(*self)?;
306                 }
307                 Ok(())
308             }
309         }
310
311         impl MtDeserialize for $T {
312             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
313                 paste_macro! {
314                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
315                 }
316             }
317         }
318     };
319 }
320
321 impl_num!(u16);
322 impl_num!(i16);
323
324 impl_num!(u32);
325 impl_num!(i32);
326 impl_num!(f32);
327
328 impl_num!(u64);
329 impl_num!(i64);
330 impl_num!(f64);
331
332 impl MtSerialize for () {
333     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
334         Ok(())
335     }
336 }
337
338 impl MtDeserialize for () {
339     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
340         Ok(())
341     }
342 }
343
344 impl MtSerialize for bool {
345     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
346         (*self as u8).mt_serialize::<DefCfg>(writer)
347     }
348 }
349
350 impl MtDeserialize for bool {
351     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
352         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
353     }
354 }
355
356 impl<T: MtSerialize> MtSerialize for &T {
357     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
358         (*self).mt_serialize::<C>(writer)
359     }
360 }
361
362 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
363     writer: &mut impl Write,
364     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
365 ) -> Result<(), SerializeError> {
366     C::write_len(iter.len(), writer)?;
367
368     iter.into_iter()
369         .try_for_each(|item| item.mt_serialize::<C::Inner>(writer))
370 }
371
372 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
373     reader: &'a mut impl Read,
374 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
375     let len = C::read_len(reader)?;
376     mt_deserialize_sized_seq::<C, _>(&len, reader)
377 }
378
379 pub fn mt_deserialize_sized_seq<'a, C: MtCfg, T: MtDeserialize>(
380     len: &C::Len,
381     reader: &'a mut impl Read,
382 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
383     let variable = len.option().is_none();
384
385     Ok(len
386         .range()
387         .map_while(move |_| match T::mt_deserialize::<C::Inner>(reader) {
388             Err(DeserializeError::UnexpectedEof) if variable => None,
389             x => Some(x),
390         }))
391 }
392
393 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
394     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
395         mt_serialize_seq::<(), _>(writer, self.iter())
396     }
397 }
398
399 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
400     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
401         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
402     }
403 }
404
405 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
406     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
407         self.as_repr().mt_serialize::<DefCfg>(writer)
408     }
409 }
410
411 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
412     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
413         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
414             reader,
415         )?))
416     }
417 }
418
419 impl<T: MtSerialize> MtSerialize for Option<T> {
420     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
421         match self {
422             Some(item) => item.mt_serialize::<C>(writer),
423             None => Ok(()),
424         }
425     }
426 }
427
428 impl<T: MtDeserialize> MtDeserialize for Option<T> {
429     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
430         T::mt_deserialize::<C>(reader).map(Some).or_default()
431     }
432 }
433
434 impl<T: MtSerialize> MtSerialize for Vec<T> {
435     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
436         mt_serialize_seq::<C, _>(writer, self.iter())
437     }
438 }
439
440 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
441     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
442         mt_deserialize_seq::<C, _>(reader)?.try_collect()
443     }
444 }
445
446 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
447     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
448         mt_serialize_seq::<C, _>(writer, self.iter())
449     }
450 }
451
452 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
453     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
454         mt_deserialize_seq::<C, _>(reader)?.try_collect()
455     }
456 }
457
458 // TODO: support more tuples
459 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
460     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
461         self.0.mt_serialize::<C>(writer)?;
462         self.1.mt_serialize::<C::Inner>(writer)?;
463
464         Ok(())
465     }
466 }
467
468 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
469     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
470         let a = A::mt_deserialize::<C>(reader)?;
471         let b = B::mt_deserialize::<C::Inner>(reader)?;
472
473         Ok((a, b))
474     }
475 }
476
477 impl<K, V> MtSerialize for HashMap<K, V>
478 where
479     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
480     V: MtSerialize,
481 {
482     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
483         mt_serialize_seq::<C, _>(writer, self.iter())
484     }
485 }
486
487 impl<K, V> MtDeserialize for HashMap<K, V>
488 where
489     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
490     V: MtDeserialize,
491 {
492     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
493         mt_deserialize_seq::<C, _>(reader)?.try_collect()
494     }
495 }
496
497 impl MtSerialize for String {
498     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
499         if C::utf16() {
500             self.encode_utf16()
501                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
502                 .mt_serialize::<C>(writer)
503         } else {
504             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
505         }
506     }
507 }
508
509 impl MtDeserialize for String {
510     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
511         if C::utf16() {
512             let mut err = None;
513
514             let res =
515                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
516                     Ok(v) => Some(v),
517                     Err(e) => {
518                         err = Some(e);
519                         None
520                     }
521                 }))
522                 .try_collect();
523
524             match err {
525                 None => Ok(res?),
526                 Some(e) => Err(e),
527             }
528         } else {
529             let len = C::read_len(reader)?;
530
531             // use capacity if available
532             let mut st = match len.option() {
533                 Some(x) => String::with_capacity(x),
534                 None => String::new(),
535             };
536
537             len.take(WrapRead(reader)).read_to_string(&mut st)?;
538
539             Ok(st)
540         }
541     }
542 }
543
544 impl<T: MtSerialize> MtSerialize for Box<T> {
545     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
546         self.deref().mt_serialize::<C>(writer)
547     }
548 }
549
550 impl<T: MtDeserialize> MtDeserialize for Box<T> {
551     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
552         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
553     }
554 }