]> git.lizzy.rs Git - mt_ser.git/blob - src/lib.rs
Implement deserialize for basic types
[mt_ser.git] / src / lib.rs
1 #![feature(array_try_from_fn)]
2 #![feature(associated_type_bounds)]
3 #![feature(iterator_try_collect)]
4
5 pub use enumset;
6 pub use flate2;
7 pub use paste;
8
9 #[cfg(feature = "random")]
10 pub use generate_random;
11
12 #[cfg(feature = "random")]
13 pub use rand;
14
15 #[cfg(feature = "serde")]
16 pub use serde;
17
18 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
19 use enumset::{EnumSet, EnumSetType, EnumSetTypeWithRepr};
20 use mt_data_derive::mt_derive;
21 pub use mt_data_derive::{MtDeserialize, MtSerialize};
22 use paste::paste as paste_macro;
23 use std::{
24     collections::{HashMap, HashSet},
25     convert::Infallible,
26     fmt,
27     io::{self, Read, Write},
28     num::TryFromIntError,
29     ops::Deref,
30 };
31 use thiserror::Error;
32
33 #[cfg(feature = "serde")]
34 use serde::{Deserialize, Serialize};
35
36 #[cfg(feature = "random")]
37 use generate_random::GenerateRandom;
38
39 #[cfg(any(feature = "client", feature = "server"))]
40 use crate as mt_data;
41
42 #[derive(Error, Debug)]
43 pub enum SerializeError {
44     #[error("io error: {0}")]
45     IoError(#[from] io::Error),
46     #[error("collection too big: {0}")]
47     TooBig(#[from] TryFromIntError),
48 }
49
50 impl From<Infallible> for SerializeError {
51     fn from(_err: Infallible) -> Self {
52         unreachable!("infallible")
53     }
54 }
55
56 #[derive(Error, Debug)]
57 pub enum DeserializeError {
58     #[error("io error: {0}")]
59     IoError(io::Error),
60     #[error("unexpected end of file")]
61     UnexpectedEof,
62     #[error("collection too big: {0}")]
63     TooBig(#[from] TryFromIntError),
64     #[error("invalid UTF-16: {0}")]
65     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
66     #[error("unimplemented")]
67     Unimplemented,
68 }
69
70 impl From<Infallible> for DeserializeError {
71     fn from(_err: Infallible) -> Self {
72         unreachable!("infallible")
73     }
74 }
75
76 impl From<io::Error> for DeserializeError {
77     fn from(err: io::Error) -> Self {
78         if err.kind() == io::ErrorKind::UnexpectedEof {
79             DeserializeError::UnexpectedEof
80         } else {
81             DeserializeError::IoError(err)
82         }
83     }
84 }
85
86 pub trait OrDefault<T> {
87     fn or_default(self) -> Self;
88 }
89
90 impl<T: MtDeserialize + Default> OrDefault<T> for Result<T, DeserializeError> {
91     fn or_default(self) -> Self {
92         match self {
93             Err(DeserializeError::UnexpectedEof) => Ok(T::default()),
94             x => x,
95         }
96     }
97 }
98
99 pub trait MtCfg:
100     Sized + MtSerialize + MtDeserialize + TryFrom<usize, Error: Into<SerializeError>>
101 {
102     type Range: Iterator<Item = usize> + 'static;
103
104     fn utf16() -> bool {
105         false
106     }
107
108     fn var_len() -> bool;
109
110     fn write_len(len: usize, writer: &mut impl Write) -> Result<(), SerializeError> {
111         Self::try_from(len)
112             .map_err(Into::into)?
113             .mt_serialize::<DefCfg>(writer)
114     }
115
116     fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError>;
117 }
118
119 trait MtCfgLen:
120     Sized
121     + MtSerialize
122     + MtDeserialize
123     + TryFrom<usize, Error: Into<SerializeError>>
124     + TryInto<usize, Error: Into<DeserializeError>>
125 {
126 }
127
128 impl<T: MtCfgLen> MtCfg for T {
129     type Range = std::ops::Range<usize>;
130
131     fn var_len() -> bool {
132         false
133     }
134
135     fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
136         let len = Self::mt_deserialize::<DefCfg>(reader)?
137             .try_into()
138             .map_err(Into::into)?;
139
140         Ok(0..len)
141     }
142 }
143
144 pub type DefCfg = u16;
145
146 pub trait MtSerialize: Sized {
147     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError>;
148 }
149
150 pub trait MtDeserialize: Sized {
151     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError>;
152 }
153
154 impl MtCfgLen for u8 {}
155 impl MtCfgLen for u16 {}
156 impl MtCfgLen for u32 {}
157 impl MtCfgLen for u64 {}
158
159 #[derive(Debug)]
160 pub struct NoLen;
161
162 impl MtSerialize for NoLen {
163     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
164         Ok(())
165     }
166 }
167
168 impl MtDeserialize for NoLen {
169     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
170         Ok(Self)
171     }
172 }
173
174 impl TryFrom<usize> for NoLen {
175     type Error = Infallible;
176
177     fn try_from(_x: usize) -> Result<Self, Self::Error> {
178         Ok(Self)
179     }
180 }
181
182 impl MtCfg for NoLen {
183     fn var_len() -> bool {
184         true
185     }
186
187     type Range = std::ops::RangeFrom<usize>;
188
189     fn read_len(_reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
190         Ok(0..)
191     }
192 }
193
194 pub struct Utf16<B: MtCfg>(pub B);
195
196 impl<B: MtCfg> MtSerialize for Utf16<B> {
197     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
198         self.0.mt_serialize::<DefCfg>(writer)
199     }
200 }
201
202 impl<B: MtCfg> MtDeserialize for Utf16<B> {
203     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
204         Ok(Self(B::mt_deserialize::<DefCfg>(reader)?))
205     }
206 }
207
208 impl<B: MtCfg> TryFrom<usize> for Utf16<B> {
209     type Error = <usize as TryInto<B>>::Error;
210
211     fn try_from(x: usize) -> Result<Self, Self::Error> {
212         Ok(Self(x.try_into()?))
213     }
214 }
215
216 impl<B: MtCfg> MtCfg for Utf16<B> {
217     type Range = B::Range;
218
219     fn utf16() -> bool {
220         true
221     }
222
223     fn var_len() -> bool {
224         B::var_len()
225     }
226
227     fn read_len(reader: &mut impl Read) -> Result<Self::Range, DeserializeError> {
228         B::read_len(reader)
229     }
230 }
231
232 impl MtSerialize for u8 {
233     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
234         writer.write_u8(*self)?;
235         Ok(())
236     }
237 }
238
239 impl MtDeserialize for u8 {
240     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
241         Ok(reader.read_u8()?)
242     }
243 }
244
245 impl MtSerialize for i8 {
246     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
247         writer.write_i8(*self)?;
248         Ok(())
249     }
250 }
251
252 impl MtDeserialize for i8 {
253     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
254         Ok(reader.read_i8()?)
255     }
256 }
257
258 macro_rules! impl_num {
259     ($T:ty) => {
260         impl MtSerialize for $T {
261             fn mt_serialize<C: MtCfg>(
262                 &self,
263                 writer: &mut impl Write,
264             ) -> Result<(), SerializeError> {
265                 paste_macro! {
266                     writer.[<write_ $T>]::<BigEndian>(*self)?;
267                 }
268                 Ok(())
269             }
270         }
271
272         impl MtDeserialize for $T {
273             fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
274                 paste_macro! {
275                     Ok(reader.[<read_ $T>]::<BigEndian>()?)
276                 }
277             }
278         }
279     };
280 }
281
282 impl_num!(u16);
283 impl_num!(i16);
284
285 impl_num!(u32);
286 impl_num!(i32);
287 impl_num!(f32);
288
289 impl_num!(u64);
290 impl_num!(i64);
291 impl_num!(f64);
292
293 impl MtSerialize for () {
294     fn mt_serialize<C: MtCfg>(&self, _writer: &mut impl Write) -> Result<(), SerializeError> {
295         Ok(())
296     }
297 }
298
299 impl MtDeserialize for () {
300     fn mt_deserialize<C: MtCfg>(_reader: &mut impl Read) -> Result<Self, DeserializeError> {
301         Ok(())
302     }
303 }
304
305 impl MtSerialize for bool {
306     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
307         (*self as u8).mt_serialize::<DefCfg>(writer)
308     }
309 }
310
311 impl MtDeserialize for bool {
312     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
313         Ok(u8::mt_deserialize::<DefCfg>(reader)? != 0)
314     }
315 }
316
317 impl<T: MtSerialize> MtSerialize for &T {
318     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
319         (*self).mt_serialize::<C>(writer)
320     }
321 }
322
323 fn mt_serialize_seq<C: MtCfg, T: MtSerialize>(
324     writer: &mut impl Write,
325     iter: impl ExactSizeIterator + IntoIterator<Item = T>,
326 ) -> Result<(), SerializeError> {
327     C::write_len(iter.len(), writer)?;
328
329     iter.into_iter()
330         .try_for_each(|item| item.mt_serialize::<DefCfg>(writer))
331 }
332
333 fn mt_deserialize_seq<'a, C: MtCfg, T: MtDeserialize>(
334     reader: &'a mut impl Read,
335 ) -> Result<impl Iterator<Item = Result<T, DeserializeError>> + 'a, DeserializeError> {
336     Ok(C::read_len(reader)?
337         .into_iter()
338         .map_while(|_| match T::mt_deserialize::<DefCfg>(reader) {
339             Err(DeserializeError::UnexpectedEof) if C::var_len() => None,
340             x => Some(x),
341         }))
342 }
343
344 impl<T: MtSerialize, const N: usize> MtSerialize for [T; N] {
345     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
346         mt_serialize_seq::<NoLen, _>(writer, self.iter())
347     }
348 }
349
350 impl<T: MtDeserialize, const N: usize> MtDeserialize for [T; N] {
351     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
352         std::array::try_from_fn(|_| T::mt_deserialize::<DefCfg>(reader))
353     }
354 }
355
356 impl<T: MtSerialize, E: EnumSetTypeWithRepr<Repr = T>> MtSerialize for EnumSet<E> {
357     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
358         self.as_repr().mt_serialize::<DefCfg>(writer)
359     }
360 }
361
362 impl<T: MtDeserialize, E: EnumSetTypeWithRepr<Repr = T>> MtDeserialize for EnumSet<E> {
363     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
364         Ok(Self::from_repr_truncated(T::mt_deserialize::<DefCfg>(
365             reader,
366         )?))
367     }
368 }
369
370 impl<T: MtSerialize> MtSerialize for Option<T> {
371     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
372         match self {
373             Some(item) => item.mt_serialize::<C>(writer),
374             None => Ok(()),
375         }
376     }
377 }
378
379 impl<T: MtDeserialize> MtDeserialize for Option<T> {
380     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
381         T::mt_deserialize::<C>(reader).map(Some).or_default()
382     }
383 }
384
385 impl<T: MtSerialize> MtSerialize for Vec<T> {
386     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
387         mt_serialize_seq::<C, _>(writer, self.iter())
388     }
389 }
390
391 impl<T: MtDeserialize> MtDeserialize for Vec<T> {
392     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
393         mt_deserialize_seq::<C, _>(reader)?.try_collect()
394     }
395 }
396
397 impl<T: MtSerialize + std::cmp::Eq + std::hash::Hash> MtSerialize for HashSet<T> {
398     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
399         mt_serialize_seq::<C, _>(writer, self.iter())
400     }
401 }
402
403 impl<T: MtDeserialize + std::cmp::Eq + std::hash::Hash> MtDeserialize for HashSet<T> {
404     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
405         mt_deserialize_seq::<C, _>(reader)?.try_collect()
406     }
407 }
408
409 impl<A: MtSerialize, B: MtSerialize> MtSerialize for (A, B) {
410     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
411         self.0.mt_serialize::<DefCfg>(writer)?;
412         self.1.mt_serialize::<DefCfg>(writer)?;
413
414         Ok(())
415     }
416 }
417
418 impl<A: MtDeserialize, B: MtDeserialize> MtDeserialize for (A, B) {
419     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
420         let a = A::mt_deserialize::<DefCfg>(reader)?;
421         let b = B::mt_deserialize::<DefCfg>(reader)?;
422
423         Ok((a, b))
424     }
425 }
426
427 impl<K, V> MtSerialize for HashMap<K, V>
428 where
429     K: MtSerialize + std::cmp::Eq + std::hash::Hash,
430     V: MtSerialize,
431 {
432     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
433         mt_serialize_seq::<C, _>(writer, self.iter())
434     }
435 }
436
437 impl<K, V> MtDeserialize for HashMap<K, V>
438 where
439     K: MtDeserialize + std::cmp::Eq + std::hash::Hash,
440     V: MtDeserialize,
441 {
442     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
443         mt_deserialize_seq::<C, _>(reader)?.try_collect()
444     }
445 }
446
447 impl MtSerialize for String {
448     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
449         if C::utf16() {
450             self.encode_utf16()
451                 .collect::<Vec<_>>() // FIXME: is this allocation necessary?
452                 .mt_serialize::<C>(writer)
453         } else {
454             mt_serialize_seq::<C, _>(writer, self.as_bytes().iter())
455         }
456     }
457 }
458
459 impl MtDeserialize for String {
460     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
461         if C::utf16() {
462             let mut err = None;
463
464             let res =
465                 char::decode_utf16(mt_deserialize_seq::<C, _>(reader)?.map_while(|x| match x {
466                     Ok(v) => Some(v),
467                     Err(e) => {
468                         err = Some(e);
469                         None
470                     }
471                 }))
472                 .try_collect();
473
474             match err {
475                 None => Ok(res?),
476                 Some(e) => Err(e),
477             }
478         } else {
479             // TODO: UTF-8 decode
480             Ok("".into())
481         }
482     }
483 }
484
485 impl<T: MtSerialize> MtSerialize for Box<T> {
486     fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
487         self.deref().mt_serialize::<C>(writer)
488     }
489 }
490
491 impl<T: MtDeserialize> MtDeserialize for Box<T> {
492     fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
493         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
494     }
495 }
496
497 mod to_clt;
498 mod to_srv;
499
500 pub use to_clt::*;
501 pub use to_srv::*;