]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
0019d1c0b63e023abc5befa8b21e824a8e972d6b
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use flate2;
6 pub use mt_ser_derive::{mt_derive, MtDeserialize, MtSerialize};
7 pub use paste;
8
9 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10 use enumset::{EnumSet, EnumSetTypeWithRepr};
11 use paste::paste as paste_macro;
12 use std::{
13     collections::{HashMap, HashSet},
14     convert::Infallible,
15     io::{self, Read, Write},
16     num::TryFromIntError,
17     ops::Deref,
18 };
19 use thiserror::Error;
20
21 #[cfg(test)]
22 mod tests;
23
24 #[derive(Error, Debug)]
25 pub enum SerializeError {
26     #[error("io error: {0}")]
27     IoError(#[from] io::Error),
28     #[error("collection too big: {0}")]
29     TooBig(#[from] TryFromIntError),
30 }
31
32 impl From<Infallible> for SerializeError {
33     fn from(_err: Infallible) -> Self {
34         unreachable!("infallible")
35     }
36 }
37
38 #[derive(Error, Debug)]
39 pub enum DeserializeError {
40     #[error("io error: {0}")]
41     IoError(io::Error),
42     #[error("unexpected end of file")]
43     UnexpectedEof,
44     #[error("collection too big: {0}")]
45     TooBig(#[from] TryFromIntError),
46     #[error("invalid UTF-16: {0}")]
47     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
48     #[error("unimplemented")]
49     Unimplemented,
50 }
51
52 impl From<Infallible> for DeserializeError {
53     fn from(_err: Infallible) -> Self {
54         unreachable!("infallible")
55     }
56 }
57
58 impl From<io::Error> for DeserializeError {
59     fn from(err: io::Error) -> Self {
60         if err.kind() == io::ErrorKind::UnexpectedEof {
61             DeserializeError::UnexpectedEof
62         } else {
63             DeserializeError::IoError(err)
64         }
65     }
66 }
67
68 pub trait OrDefault<T> {
69     fn or_default(self) -> Self;
70 }
71
72 pub struct WrapRead<'a, R: Read>(pub &'a mut R);
73 impl<'a, R: Read> Read for WrapRead<'a, R> {
74     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
75         self.0.read(buf)
76     }
77
78     fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
79         self.0.read_vectored(bufs)
80     }
81
82     /*
83
84     fn is_read_vectored(&self) -> bool {
85         self.0.is_read_vectored()
86     }
87
88     */
89
90     fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
91         self.0.read_to_end(buf)
92     }
93
94     fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
95         self.0.read_to_string(buf)
96     }
97
98     fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
99         self.0.read_exact(buf)
100     }
101
102     /*
103
104     fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
105         self.0.read_buf(buf)
106     }
107
108     fn read_buf_exact(&mut self, cursor: io::BorrowedCursor<'_>) -> io::Result<()> {
109         self.0.read_buf_exact(cursor)
110     }
111
112     */
113 }
114
115 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
116     fn or_default(self) -> Self {
117         match self {
118             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
119             x => x,
120         }
121     }
122 }
123
124 pub trait MtLen {
125     fn option(&self) -> Option<usize>;
126
127     type Range: Iterator<Item = usize> + 'static;
128     fn range(&self) -> Self::Range;
129
130     type Take<R: Read>: Read;
131     fn take<R: Read>(&self, reader: R) -> Self::Take<R>;
132 }
133
134 pub trait MtCfg {
135     type Len: MtLen;
136
137     fn utf16() -> bool {
138         false
139     }
140
141     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError>;
142     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError>;
143 }
144
145 pub trait MtSerialize {
146     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
147 }
148
149 pub trait MtDeserialize: Sized {
150     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
151 }
152
153 impl MtLen for usize {
154     fn option(&self) -> Option<usize> {
155         Some(*self)
156     }
157
158     type Range = std::ops::Range<usize>;
159     fn range(&self) -> Self::Range {
160         0..*self
161     }
162
163     type Take<R: Read> = io::Take<R>;
164     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
165         reader.take(*self as u64)
166     }
167 }
168
169 trait MtCfgLen:
170     Sized
171     + MtSerialize
172     + MtDeserialize
173     + TryFrom<usize, Error: Into<SerializeError>>
174     + TryInto<usize, Error: Into<DeserializeError>>
175 {
176 }
177
178 impl<T: MtCfgLen> MtCfg for T {
179     type Len = usize;
180
181     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
182         Self::try_from(len)
183             .map_err(Into::into)?
184             .mt_serialize::<DefCfg>(writer)
185     }
186
187     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
188         Ok(Self::mt_deserialize::<DefCfg>(reader)?
189             .try_into()
190             .map_err(Into::into)?)
191     }
192 }
193
194 impl MtCfgLen for u8 {}
195 impl MtCfgLen for u16 {}
196 impl MtCfgLen for u32 {}
197 impl MtCfgLen for u64 {}
198
199 pub type DefCfg = u16;
200
201 impl MtCfg for () {
202     type Len = ();
203
204     fn write_len(_len: usize, _writer: &mut impl Write) -> Result<(), SerializeError> {
205         Ok(())
206     }
207
208     fn read_len(_writer: &mut impl Read) -> Result<Self::Len, DeserializeError> {
209         Ok(())
210     }
211 }
212
213 impl MtLen for () {
214     fn option(&self) -> Option<usize> {
215         None
216     }
217
218     type Range = std::ops::RangeFrom<usize>;
219     fn range(&self) -> Self::Range {
220         0..
221     }
222
223     type Take<R: Read> = R;
224     fn take<R: Read>(&self, reader: R) -> Self::Take<R> {
225         reader
226     }
227 }
228
229 pub struct Utf16<B: MtCfg>(pub B);
230
231 impl<B: MtCfg> MtCfg for Utf16<B> {
232     type Len = B::Len;
233
234     fn utf16() -> bool {
235         true
236     }
237
238     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
239         B::write_len(len, writer)
240     }
241
242     fn read_len(reader: &mut impl Read) -> Result<Self::Len, DeserializeError> {
243         B::read_len(reader)
244     }
245 }
246
247 impl MtSerialize for u8 {
248     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
249         writer.write_u8(*self)?;
250         Ok(())
251     }
252 }
253
254 impl MtDeserialize for u8 {
255     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
256         Ok(reader.read_u8()?)
257     }
258 }
259
260 impl MtSerialize for i8 {
261     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
262         writer.write_i8(*self)?;
263         Ok(())
264     }
265 }
266
267 impl MtDeserialize for i8 {
268     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
269         Ok(reader.read_i8()?)
270     }
271 }
272
273 macro_rules! impl_num {
274     ($T:ty) => {
275         impl MtSerialize for $T {
276             fn mt_serialize<C: MtCfg>(
277                 &self,
278                 writer: &mut impl Write,
279             ) -> Result<(), SerializeError> {
280                 paste_macro! {
281                     writer.[<write_ $T>]::<BigEndian>(*self)?;
282                 }
283                 Ok(())
284             }
285         }
286
287         impl MtDeserialize for $T {
288             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
289                 paste_macro! {
290                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
291                 }
292             }
293         }
294     };
295 }
296
297 impl_num!(u16);
298 impl_num!(i16);
299
300 impl_num!(u32);
301 impl_num!(i32);
302 impl_num!(f32);
303
304 impl_num!(u64);
305 impl_num!(i64);
306 impl_num!(f64);
307
308 impl MtSerialize for () {
309     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
310         Ok(())
311     }
312 }
313
314 impl MtDeserialize for () {
315     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
316         Ok(())
317     }
318 }
319
320 impl MtSerialize for bool {
321     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
322         (*self as u8).mt_serialize::<DefCfg>(writer)
323     }
324 }
325
326 impl MtDeserialize for bool {
327     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
328         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
329     }
330 }
331
332 impl<T: MtSerialize> MtSerialize for &T {
333     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
334         (*self).mt_serialize::<C>(writer)
335     }
336 }
337
338 pub fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
339     writer: &mut impl Write,
340     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
341 ) -> Result<(), SerializeError> {
342     C::write_len(iter.len(), writer)?;
343
344     iter.into_iter()
345         .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
346 }
347
348 pub fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
349     reader: &'a mut impl Read,
350 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
351     let len = C::read_len(reader)?;
352     mt_deserialize_sized_seq(&len, reader)
353 }
354
355 pub fn mt_deserialize_sized_seq<'a, L: MtLen, T: MtDeserialize>(
356     len: &L,
357     reader: &'a mut impl Read,
358 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
359     let variable = len.option().is_none();
360
361     Ok(len
362         .range()
363         .map_while(move |_| match T::mt_deserialize::<DefCfg>(reader) {
364             Err(DeserializeError::UnexpectedEof) if variable => None,
365             x => Some(x),
366         }))
367 }
368
369 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
370     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
371         mt_serialize_seq::<(), _>(writer, self.iter())
372     }
373 }
374
375 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
376     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
377         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
378     }
379 }
380
381 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
382     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
383         self.as_repr().mt_serialize::<DefCfg>(writer)
384     }
385 }
386
387 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
388     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
389         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
390             reader,
391         )?))
392     }
393 }
394
395 impl<T: MtSerialize> MtSerialize for Option<T> {
396     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
397         match self {
398             Some(item) => item.mt_serialize::<C>(writer),
399             None => Ok(()),
400         }
401     }
402 }
403
404 impl<T: MtDeserialize> MtDeserialize for Option<T> {
405     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
406         T::mt_deserialize::<C>(reader).map(Some).or_default()
407     }
408 }
409
410 impl<T: MtSerialize> MtSerialize for Vec<T> {
411     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
412         mt_serialize_seq::<C, _>(writer, self.iter())
413     }
414 }
415
416 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
417     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
418         mt_deserialize_seq::<C, _>(reader)?.try_collect()
419     }
420 }
421
422 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
423     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
424         mt_serialize_seq::<C, _>(writer, self.iter())
425     }
426 }
427
428 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
429     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
430         mt_deserialize_seq::<C, _>(reader)?.try_collect()
431     }
432 }
433
434 // TODO: support more tuples
435 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
436     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
437         self.0.mt_serialize::<DefCfg>(writer)?;
438         self.1.mt_serialize::<DefCfg>(writer)?;
439
440         Ok(())
441     }
442 }
443
444 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
445     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
446         let a = A::mt_deserialize::<DefCfg>(reader)?;
447         let b = B::mt_deserialize::<DefCfg>(reader)?;
448
449         Ok((a, b))
450     }
451 }
452
453 impl<K, V> MtSerialize for HashMap<K, V>
454 where
455     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
456     V: MtSerialize,
457 {
458     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
459         mt_serialize_seq::<C, _>(writer, self.iter())
460     }
461 }
462
463 impl<K, V> MtDeserialize for HashMap<K, V>
464 where
465     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
466     V: MtDeserialize,
467 {
468     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
469         mt_deserialize_seq::<C, _>(reader)?.try_collect()
470     }
471 }
472
473 impl MtSerialize for String {
474     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
475         if C::utf16() {
476             self.encode_utf16()
477                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
478                 .mt_serialize::<C>(writer)
479         } else {
480             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
481         }
482     }
483 }
484
485 impl MtDeserialize for String {
486     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
487         if C::utf16() {
488             let mut err = None;
489
490             let res =
491                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
492                     Ok(v) => Some(v),
493                     Err(e) => {
494                         err = Some(e);
495                         None
496                     }
497                 }))
498                 .try_collect();
499
500             match err {
501                 None => Ok(res?),
502                 Some(e) => Err(e),
503             }
504         } else {
505             let len = C::read_len(reader)?;
506
507             // use capacity if available
508             let mut st = match len.option() {
509                 Some(x) => String::with_capacity(x),
510                 None => String::new(),
511             };
512
513             len.take(WrapRead(reader)).read_to_string(&mut st)?;
514
515             Ok(st)
516         }
517     }
518 }
519
520 impl<T: MtSerialize> MtSerialize for Box<T> {
521     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
522         self.deref().mt_serialize::<C>(writer)
523     }
524 }
525
526 impl<T: MtDeserialize> MtDeserialize for Box<T> {
527     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
528         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
529     }
530 }