]> git.lizzy.rs Git - rust.git/commitdiff
Avoid unchecked casts in net parser
authorTamir Duberstein <tamird@google.com>
Sun, 4 Oct 2020 12:39:39 +0000 (12:39 +0000)
committerTamir Duberstein <tamird@google.com>
Sun, 4 Oct 2020 16:57:54 +0000 (16:57 +0000)
library/std/src/net/parser.rs

index 0570a7c41bfe60b62390d4fac0b3124e629779c2..da94a5735034fbc344485faead46eab6e38907c9 100644 (file)
@@ -6,11 +6,34 @@
 #[cfg(test)]
 mod tests;
 
+use crate::convert::TryInto as _;
 use crate::error::Error;
 use crate::fmt;
 use crate::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
 use crate::str::FromStr;
 
+trait ReadNumberHelper: crate::marker::Sized {
+    const ZERO: Self;
+    fn checked_mul(&self, other: u32) -> Option<Self>;
+    fn checked_add(&self, other: u32) -> Option<Self>;
+}
+
+macro_rules! impl_helper {
+    ($($t:ty)*) => ($(impl ReadNumberHelper for $t {
+        const ZERO: Self = 0;
+        #[inline]
+        fn checked_mul(&self, other: u32) -> Option<Self> {
+            Self::checked_mul(*self, other.try_into().ok()?)
+        }
+        #[inline]
+        fn checked_add(&self, other: u32) -> Option<Self> {
+            Self::checked_add(*self, other.try_into().ok()?)
+        }
+    })*)
+}
+
+impl_helper! { u8 u16 }
+
 struct Parser<'a> {
     // parsing as ASCII, so can use byte array
     state: &'a [u8],
@@ -59,7 +82,7 @@ fn parse_with<T, F>(&mut self, inner: F) -> Result<T, AddrParseError>
     fn read_char(&mut self) -> Option<char> {
         self.state.split_first().map(|(&b, tail)| {
             self.state = tail;
-            b as char
+            char::from(b)
         })
     }
 
@@ -84,25 +107,26 @@ fn read_separator<T, F>(&mut self, sep: char, index: usize, inner: F) -> Option<
         })
     }
 
-    // Read a single digit in the given radix. For instance, 0-9 in radix 10;
-    // 0-9A-F in radix 16.
-    fn read_digit(&mut self, radix: u32) -> Option<u32> {
-        self.read_atomically(move |p| p.read_char()?.to_digit(radix))
-    }
-
     // Read a number off the front of the input in the given radix, stopping
     // at the first non-digit character or eof. Fails if the number has more
-    // digits than max_digits, or the value is >= upto, or if there is no number.
-    fn read_number(&mut self, radix: u32, max_digits: u32, upto: u32) -> Option<u32> {
+    // digits than max_digits or if there is no number.
+    fn read_number<T: ReadNumberHelper>(
+        &mut self,
+        radix: u32,
+        max_digits: Option<usize>,
+    ) -> Option<T> {
         self.read_atomically(move |p| {
-            let mut result = 0;
+            let mut result = T::ZERO;
             let mut digit_count = 0;
 
-            while let Some(digit) = p.read_digit(radix) {
-                result = (result * radix) + digit;
+            while let Some(digit) = p.read_atomically(|p| p.read_char()?.to_digit(radix)) {
+                result = result.checked_mul(radix)?;
+                result = result.checked_add(digit)?;
                 digit_count += 1;
-                if digit_count > max_digits || result >= upto {
-                    return None;
+                if let Some(max_digits) = max_digits {
+                    if digit_count > max_digits {
+                        return None;
+                    }
                 }
             }
 
@@ -116,7 +140,7 @@ fn read_ipv4_addr(&mut self) -> Option<Ipv4Addr> {
             let mut groups = [0; 4];
 
             for (i, slot) in groups.iter_mut().enumerate() {
-                *slot = p.read_separator('.', i, |p| p.read_number(10, 3, 0x100))? as u8;
+                *slot = p.read_separator('.', i, |p| p.read_number(10, None))?;
             }
 
             Some(groups.into())
@@ -140,17 +164,17 @@ fn read_groups(p: &mut Parser<'_>, groups: &mut [u16]) -> (usize, bool) {
                     let ipv4 = p.read_separator(':', i, |p| p.read_ipv4_addr());
 
                     if let Some(v4_addr) = ipv4 {
-                        let octets = v4_addr.octets();
-                        groups[i + 0] = ((octets[0] as u16) << 8) | (octets[1] as u16);
-                        groups[i + 1] = ((octets[2] as u16) << 8) | (octets[3] as u16);
+                        let [one, two, three, four] = v4_addr.octets();
+                        groups[i + 0] = u16::from_be_bytes([one, two]);
+                        groups[i + 1] = u16::from_be_bytes([three, four]);
                         return (i + 2, true);
                     }
                 }
 
-                let group = p.read_separator(':', i, |p| p.read_number(16, 4, 0x10000));
+                let group = p.read_separator(':', i, |p| p.read_number(16, Some(4)));
 
                 match group {
-                    Some(g) => *slot = g as u16,
+                    Some(g) => *slot = g,
                     None => return (i, false),
                 }
             }
@@ -195,12 +219,11 @@ fn read_ip_addr(&mut self) -> Option<IpAddr> {
         self.read_ipv4_addr().map(IpAddr::V4).or_else(move || self.read_ipv6_addr().map(IpAddr::V6))
     }
 
-    /// Read a : followed by a port in base 10
+    /// Read a : followed by a port in base 10.
     fn read_port(&mut self) -> Option<u16> {
         self.read_atomically(|p| {
             let _ = p.read_given_char(':')?;
-            let port = p.read_number(10, 5, 0x10000)?;
-            Some(port as u16)
+            p.read_number(10, None)
         })
     }