From 3ab7ae39ec77a39df27ac6c3fbde03cd3b763542 Mon Sep 17 00:00:00 2001 From: Nathan West Date: Tue, 19 May 2020 23:26:49 -0400 Subject: [PATCH] Bring net/parser.rs up to modern up to date with modern rust patterns Made the following changes throughout the IP address parser: - Replaced all uses of `is_some()` / `is_none()` with `?`. - "Upgraded" loops wherever possible; ie, replace `while` with `for`, etc. - Removed all cases of manual index tracking / incrementing. - Renamed several single-character variables with more expressive names. - Replaced several manual control flow segments with equivalent adapters (such as `Option::filter`). - Removed `read_seq_3`; replaced with simple sequences of `?`. - Parser now reslices its state when consuming, rather than carrying a separate state and index variable. - `read_digit` now uses `char::to_digit`. - Removed unnecessary casts back and forth between u8 and u32 - Added comments throughout, especially in the complex IPv6 parsing logic. - Added comprehensive local unit tests for the parser to validate these changes. --- src/libstd/net/parser.rs | 482 +++++++++++++++++++++++---------------- 1 file changed, 289 insertions(+), 193 deletions(-) diff --git a/src/libstd/net/parser.rs b/src/libstd/net/parser.rs index 3f38ee54710..12d3baf6333 100644 --- a/src/libstd/net/parser.rs +++ b/src/libstd/net/parser.rs @@ -10,163 +10,132 @@ struct Parser<'a> { // parsing as ASCII, so can use byte array - s: &'a [u8], - pos: usize, + state: &'a [u8], } impl<'a> Parser<'a> { - fn new(s: &'a str) -> Parser<'a> { - Parser { s: s.as_bytes(), pos: 0 } + fn new(input: &'a str) -> Parser<'a> { + Parser { state: input.as_bytes() } } fn is_eof(&self) -> bool { - self.pos == self.s.len() + self.state.is_empty() } - // Commit only if parser returns Some - fn read_atomically(&mut self, cb: F) -> Option + /// Run a parser, and restore the pre-parse state if it fails + fn read_atomically(&mut self, inner: F) -> Option where F: FnOnce(&mut Parser<'_>) -> Option, { - let pos = self.pos; - let r = cb(self); - if r.is_none() { - self.pos = pos; + let state = self.state; + let result = inner(self); + if result.is_none() { + self.state = state; } - r + result } - // Commit only if parser read till EOF - fn read_till_eof(&mut self, cb: F) -> Option + /// Run a parser, but fail if the entire input wasn't consumed. + /// Doesn't run atomically. + fn read_till_eof(&mut self, inner: F) -> Option where F: FnOnce(&mut Parser<'_>) -> Option, { - self.read_atomically(move |p| cb(p).filter(|_| p.is_eof())) + inner(self).filter(|_| self.is_eof()) } - // Apply 3 parsers sequentially - fn read_seq_3(&mut self, pa: PA, pb: PB, pc: PC) -> Option<(A, B, C)> + /// Same as read_till_eof, but returns a Result on failure + fn parse_with(&mut self, inner: F) -> Result where - PA: FnOnce(&mut Parser<'_>) -> Option, - PB: FnOnce(&mut Parser<'_>) -> Option, - PC: FnOnce(&mut Parser<'_>) -> Option, + F: FnOnce(&mut Parser<'_>) -> Option, { - self.read_atomically(move |p| { - let a = pa(p); - let b = if a.is_some() { pb(p) } else { None }; - let c = if b.is_some() { pc(p) } else { None }; - match (a, b, c) { - (Some(a), Some(b), Some(c)) => Some((a, b, c)), - _ => None, - } - }) + self.read_till_eof(inner).ok_or(AddrParseError(())) } - // Read next char + /// Read the next character from the input fn read_char(&mut self) -> Option { - if self.is_eof() { - None - } else { - let r = self.s[self.pos] as char; - self.pos += 1; - Some(r) - } - } - - // Return char and advance iff next char is equal to requested - fn read_given_char(&mut self, c: char) -> Option { - self.read_atomically(|p| match p.read_char() { - Some(next) if next == c => Some(next), - _ => None, + self.state.split_first().map(|(&b, tail)| { + self.state = tail; + b as char }) } - // Read digit - fn read_digit(&mut self, radix: u8) -> Option { - fn parse_digit(c: char, radix: u8) -> Option { - let c = c as u8; - // assuming radix is either 10 or 16 - if c >= b'0' && c <= b'9' { - Some(c - b'0') - } else if radix > 10 && c >= b'a' && c < b'a' + (radix - 10) { - Some(c - b'a' + 10) - } else if radix > 10 && c >= b'A' && c < b'A' + (radix - 10) { - Some(c - b'A' + 10) - } else { - None - } - } - - self.read_atomically(|p| p.read_char().and_then(|c| parse_digit(c, radix))) + /// Read the next character from the input if it matches the target + fn read_given_char(&mut self, target: char) -> Option { + self.read_atomically(|p| p.read_char().filter(|&c| c == target)) } - fn read_number_impl(&mut self, radix: u8, max_digits: u32, upto: u32) -> Option { - let mut r = 0; - let mut digit_count = 0; - loop { - match self.read_digit(radix) { - Some(d) => { - r = r * (radix as u32) + (d as u32); - digit_count += 1; - if digit_count > max_digits || r >= upto { - return None; - } - } - None => { - if digit_count == 0 { - return None; - } else { - return Some(r); - } - } - }; - } + /// Helper for reading separators in an indexed loop. Reads the separator + /// character iff index > 0, then runs the parser. When used in a loop, + /// the separator character will only be read on index > 0 (see + /// read_ipv4_addr for an example) + fn read_separator(&mut self, sep: char, index: usize, inner: F) -> Option + where + F: FnOnce(&mut Parser<'_>) -> Option, + { + self.read_atomically(move |p| { + if index > 0 { + let _ = p.read_given_char(sep)?; + } + inner(p) + }) } - // Read number, failing if max_digits of number value exceeded - fn read_number(&mut self, radix: u8, max_digits: u32, upto: u32) -> Option { - self.read_atomically(|p| p.read_number_impl(radix, max_digits, upto)) + // Read a single digit in the given radix. For instance, 0-9 in radix 10; + // 0-9A-F in radix 16. + fn read_digit(&mut self, radix: u32) -> Option { + self.read_atomically(move |p| p.read_char()?.to_digit(radix)) } - fn read_ipv4_addr_impl(&mut self) -> Option { - let mut bs = [0; 4]; - let mut i = 0; - while i < 4 { - if i != 0 && self.read_given_char('.').is_none() { - return None; + // Read a number off the front of the input in the given radix, stopping + // at the first non-digit character or eof. Fails if the number has more + // digits than max_digits, or the value is >= upto, or if there is no number. + fn read_number(&mut self, radix: u32, max_digits: u32, upto: u32) -> Option { + self.read_atomically(move |p| { + let mut result = 0; + let mut digit_count = 0; + + while let Some(digit) = p.read_digit(radix) { + result = (result * radix) + digit; + digit_count += 1; + if digit_count > max_digits || result >= upto { + return None; + } } - bs[i] = self.read_number(10, 3, 0x100).map(|n| n as u8)?; - i += 1; - } - Some(Ipv4Addr::new(bs[0], bs[1], bs[2], bs[3])) + if digit_count == 0 { None } else { Some(result) } + }) } - // Read IPv4 address + /// Read an IPv4 address fn read_ipv4_addr(&mut self) -> Option { - self.read_atomically(|p| p.read_ipv4_addr_impl()) - } + self.read_atomically(|p| { + let mut groups = [0; 4]; - fn read_ipv6_addr_impl(&mut self) -> Option { - fn ipv6_addr_from_head_tail(head: &[u16], tail: &[u16]) -> Ipv6Addr { - assert!(head.len() + tail.len() <= 8); - let mut gs = [0; 8]; - gs[..head.len()].copy_from_slice(head); - gs[(8 - tail.len())..8].copy_from_slice(tail); - Ipv6Addr::new(gs[0], gs[1], gs[2], gs[3], gs[4], gs[5], gs[6], gs[7]) - } + for (i, slot) in groups.iter_mut().enumerate() { + *slot = p.read_separator('.', i, |p| p.read_number(10, 3, 0x100))? as u8; + } + + Some(groups.into()) + }) + } - fn read_groups(p: &mut Parser<'_>, groups: &mut [u16; 8], limit: usize) -> (usize, bool) { - let mut i = 0; - while i < limit { + /// Read an IPV6 Address + fn read_ipv6_addr(&mut self) -> Option { + /// Read a chunk of an ipv6 address into `groups`. Returns the number + /// of groups read, along with a bool indicating if an embedded + /// trailing ipv4 address was read. Specifically, read a series of + /// colon-separated ipv6 groups (0x0000 - 0xFFFF), with an optional + /// trailing embedded ipv4 address. + fn read_groups(p: &mut Parser<'_>, groups: &mut [u16]) -> (usize, bool) { + let limit = groups.len(); + + for (i, slot) in groups.iter_mut().enumerate() { + // Try to read a trailing embedded ipv4 address. There must be + // at least two groups left. if i < limit - 1 { - let ipv4 = p.read_atomically(|p| { - if i == 0 || p.read_given_char(':').is_some() { - p.read_ipv4_addr() - } else { - None - } - }); + let ipv4 = p.read_separator(':', i, |p| p.read_ipv4_addr()); + if let Some(v4_addr) = ipv4 { let octets = v4_addr.octets(); groups[i + 0] = ((octets[0] as u16) << 8) | (octets[1] as u16); @@ -175,83 +144,85 @@ fn ipv6_addr_from_head_tail(head: &[u16], tail: &[u16]) -> Ipv6Addr { } } - let group = p.read_atomically(|p| { - if i == 0 || p.read_given_char(':').is_some() { - p.read_number(16, 4, 0x10000).map(|n| n as u16) - } else { - None - } - }); + let group = p.read_separator(':', i, |p| p.read_number(16, 4, 0x10000)); + match group { - Some(g) => groups[i] = g, + Some(g) => *slot = g as u16, None => return (i, false), } - i += 1; } - (i, false) + (groups.len(), false) } - let mut head = [0; 8]; - let (head_size, head_ipv4) = read_groups(self, &mut head, 8); + self.read_atomically(|p| { + // Read the front part of the address; either the whole thing, or up + // to the first :: + let mut head = [0; 8]; + let (head_size, head_ipv4) = read_groups(p, &mut head); - if head_size == 8 { - return Some(Ipv6Addr::new( - head[0], head[1], head[2], head[3], head[4], head[5], head[6], head[7], - )); - } + if head_size == 8 { + return Some(head.into()); + } - // IPv4 part is not allowed before `::` - if head_ipv4 { - return None; - } + // IPv4 part is not allowed before `::` + if head_ipv4 { + return None; + } - // read `::` if previous code parsed less than 8 groups - if self.read_given_char(':').is_none() || self.read_given_char(':').is_none() { - return None; - } + // read `::` if previous code parsed less than 8 groups + // `::` indicates one or more groups of 16 bits of zeros + let _ = p.read_given_char(':')?; + let _ = p.read_given_char(':')?; - let mut tail = [0; 8]; - // `::` indicates one or more groups of 16 bits of zeros - let limit = 8 - (head_size + 1); - let (tail_size, _) = read_groups(self, &mut tail, limit); - Some(ipv6_addr_from_head_tail(&head[..head_size], &tail[..tail_size])) - } + // Read the back part of the address. The :: must contain at least one + // set of zeroes, so our max length is 7. + let mut tail = [0; 7]; + let limit = 8 - (head_size + 1); + let (tail_size, _) = read_groups(p, &mut tail[..limit]); - fn read_ipv6_addr(&mut self) -> Option { - self.read_atomically(|p| p.read_ipv6_addr_impl()) + // Concat the head and tail of the IP address + head[(8 - tail_size)..8].copy_from_slice(&tail[..tail_size]); + + Some(head.into()) + }) } + /// Read an IP Address, either IPV4 or IPV6. fn read_ip_addr(&mut self) -> Option { - self.read_ipv4_addr().map(IpAddr::V4).or_else(|| self.read_ipv6_addr().map(IpAddr::V6)) + self.read_ipv4_addr().map(IpAddr::V4).or_else(move || self.read_ipv6_addr().map(IpAddr::V6)) } - fn read_socket_addr_v4(&mut self) -> Option { - let ip_addr = |p: &mut Parser<'_>| p.read_ipv4_addr(); - let colon = |p: &mut Parser<'_>| p.read_given_char(':'); - let port = |p: &mut Parser<'_>| p.read_number(10, 5, 0x10000).map(|n| n as u16); + /// Read a : followed by a port in base 10 + fn read_port(&mut self) -> Option { + self.read_atomically(|p| { + let _ = p.read_given_char(':')?; + let port = p.read_number(10, 5, 0x10000)?; + Some(port as u16) + }) + } - self.read_seq_3(ip_addr, colon, port).map(|t| { - let (ip, _, port): (Ipv4Addr, char, u16) = t; - SocketAddrV4::new(ip, port) + /// Read an IPV4 address with a port + fn read_socket_addr_v4(&mut self) -> Option { + self.read_atomically(|p| { + let ip = p.read_ipv4_addr()?; + let port = p.read_port()?; + Some(SocketAddrV4::new(ip, port)) }) } + /// Read an IPV6 address with a port fn read_socket_addr_v6(&mut self) -> Option { - let ip_addr = |p: &mut Parser<'_>| { - let open_br = |p: &mut Parser<'_>| p.read_given_char('['); - let ip_addr = |p: &mut Parser<'_>| p.read_ipv6_addr(); - let clos_br = |p: &mut Parser<'_>| p.read_given_char(']'); - p.read_seq_3(open_br, ip_addr, clos_br).map(|t| t.1) - }; - let colon = |p: &mut Parser<'_>| p.read_given_char(':'); - let port = |p: &mut Parser<'_>| p.read_number(10, 5, 0x10000).map(|n| n as u16); - - self.read_seq_3(ip_addr, colon, port).map(|t| { - let (ip, _, port): (Ipv6Addr, char, u16) = t; - SocketAddrV6::new(ip, port, 0, 0) + self.read_atomically(|p| { + let _ = p.read_given_char('[')?; + let ip = p.read_ipv6_addr()?; + let _ = p.read_given_char(']')?; + + let port = p.read_port()?; + Some(SocketAddrV6::new(ip, port, 0, 0)) }) } + /// Read an IP address with a port fn read_socket_addr(&mut self) -> Option { self.read_socket_addr_v4() .map(SocketAddr::V4) @@ -263,10 +234,7 @@ fn read_socket_addr(&mut self) -> Option { impl FromStr for IpAddr { type Err = AddrParseError; fn from_str(s: &str) -> Result { - match Parser::new(s).read_till_eof(|p| p.read_ip_addr()) { - Some(s) => Ok(s), - None => Err(AddrParseError(())), - } + Parser::new(s).parse_with(|p| p.read_ip_addr()) } } @@ -274,10 +242,7 @@ fn from_str(s: &str) -> Result { impl FromStr for Ipv4Addr { type Err = AddrParseError; fn from_str(s: &str) -> Result { - match Parser::new(s).read_till_eof(|p| p.read_ipv4_addr()) { - Some(s) => Ok(s), - None => Err(AddrParseError(())), - } + Parser::new(s).parse_with(|p| p.read_ipv4_addr()) } } @@ -285,10 +250,7 @@ fn from_str(s: &str) -> Result { impl FromStr for Ipv6Addr { type Err = AddrParseError; fn from_str(s: &str) -> Result { - match Parser::new(s).read_till_eof(|p| p.read_ipv6_addr()) { - Some(s) => Ok(s), - None => Err(AddrParseError(())), - } + Parser::new(s).parse_with(|p| p.read_ipv6_addr()) } } @@ -296,10 +258,7 @@ fn from_str(s: &str) -> Result { impl FromStr for SocketAddrV4 { type Err = AddrParseError; fn from_str(s: &str) -> Result { - match Parser::new(s).read_till_eof(|p| p.read_socket_addr_v4()) { - Some(s) => Ok(s), - None => Err(AddrParseError(())), - } + Parser::new(s).parse_with(|p| p.read_socket_addr_v4()) } } @@ -307,10 +266,7 @@ fn from_str(s: &str) -> Result { impl FromStr for SocketAddrV6 { type Err = AddrParseError; fn from_str(s: &str) -> Result { - match Parser::new(s).read_till_eof(|p| p.read_socket_addr_v6()) { - Some(s) => Ok(s), - None => Err(AddrParseError(())), - } + Parser::new(s).parse_with(|p| p.read_socket_addr_v6()) } } @@ -318,10 +274,7 @@ fn from_str(s: &str) -> Result { impl FromStr for SocketAddr { type Err = AddrParseError; fn from_str(s: &str) -> Result { - match Parser::new(s).read_till_eof(|p| p.read_socket_addr()) { - Some(s) => Ok(s), - None => Err(AddrParseError(())), - } + Parser::new(s).parse_with(|p| p.read_socket_addr()) } } @@ -376,3 +329,146 @@ fn description(&self) -> &str { "invalid IP address syntax" } } + +#[cfg(test)] +mod tests { + // FIXME: These tests are all excellent candidates for AFL fuzz testing + use crate::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + use crate::str::FromStr; + + const PORT: u16 = 8080; + + const IPV4: Ipv4Addr = Ipv4Addr::new(192, 168, 0, 1); + const IPV4_STR: &str = "192.168.0.1"; + const IPV4_STR_PORT: &str = "192.168.0.1:8080"; + + const IPV6: Ipv6Addr = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xc0a8, 0x1); + const IPV6_STR_FULL: &str = "2001:db8:0:0:0:0:c0a8:1"; + const IPV6_STR_COMPRESS: &str = "2001:db8::c0a8:1"; + const IPV6_STR_V4: &str = "2001:db8::192.168.0.1"; + const IPV6_STR_PORT: &str = "[2001:db8::c0a8:1]:8080"; + + #[test] + fn parse_ipv4() { + let result: Ipv4Addr = IPV4_STR.parse().unwrap(); + assert_eq!(result, IPV4); + + assert!(Ipv4Addr::from_str(IPV4_STR_PORT).is_err()); + assert!(Ipv4Addr::from_str(IPV6_STR_FULL).is_err()); + assert!(Ipv4Addr::from_str(IPV6_STR_COMPRESS).is_err()); + assert!(Ipv4Addr::from_str(IPV6_STR_V4).is_err()); + assert!(Ipv4Addr::from_str(IPV6_STR_PORT).is_err()); + } + + #[test] + fn parse_ipv6() { + let result: Ipv6Addr = IPV6_STR_FULL.parse().unwrap(); + assert_eq!(result, IPV6); + + let result: Ipv6Addr = IPV6_STR_COMPRESS.parse().unwrap(); + assert_eq!(result, IPV6); + + let result: Ipv6Addr = IPV6_STR_V4.parse().unwrap(); + assert_eq!(result, IPV6); + + assert!(Ipv6Addr::from_str(IPV4_STR).is_err()); + assert!(Ipv6Addr::from_str(IPV4_STR_PORT).is_err()); + assert!(Ipv6Addr::from_str(IPV6_STR_PORT).is_err()); + } + + #[test] + fn parse_ip() { + let result: IpAddr = IPV4_STR.parse().unwrap(); + assert_eq!(result, IpAddr::from(IPV4)); + + let result: IpAddr = IPV6_STR_FULL.parse().unwrap(); + assert_eq!(result, IpAddr::from(IPV6)); + + let result: IpAddr = IPV6_STR_COMPRESS.parse().unwrap(); + assert_eq!(result, IpAddr::from(IPV6)); + + let result: IpAddr = IPV6_STR_V4.parse().unwrap(); + assert_eq!(result, IpAddr::from(IPV6)); + + assert!(IpAddr::from_str(IPV4_STR_PORT).is_err()); + assert!(IpAddr::from_str(IPV6_STR_PORT).is_err()); + } + + #[test] + fn parse_socket_v4() { + let result: SocketAddrV4 = IPV4_STR_PORT.parse().unwrap(); + assert_eq!(result, SocketAddrV4::new(IPV4, PORT)); + + assert!(SocketAddrV4::from_str(IPV4_STR).is_err()); + assert!(SocketAddrV4::from_str(IPV6_STR_FULL).is_err()); + assert!(SocketAddrV4::from_str(IPV6_STR_COMPRESS).is_err()); + assert!(SocketAddrV4::from_str(IPV6_STR_V4).is_err()); + assert!(SocketAddrV4::from_str(IPV6_STR_PORT).is_err()); + } + + #[test] + fn parse_socket_v6() { + let result: SocketAddrV6 = IPV6_STR_PORT.parse().unwrap(); + assert_eq!(result, SocketAddrV6::new(IPV6, PORT, 0, 0)); + + assert!(SocketAddrV6::from_str(IPV4_STR).is_err()); + assert!(SocketAddrV6::from_str(IPV4_STR_PORT).is_err()); + assert!(SocketAddrV6::from_str(IPV6_STR_FULL).is_err()); + assert!(SocketAddrV6::from_str(IPV6_STR_COMPRESS).is_err()); + assert!(SocketAddrV6::from_str(IPV6_STR_V4).is_err()); + } + + #[test] + fn parse_socket() { + let result: SocketAddr = IPV4_STR_PORT.parse().unwrap(); + assert_eq!(result, SocketAddr::from((IPV4, PORT))); + + let result: SocketAddr = IPV6_STR_PORT.parse().unwrap(); + assert_eq!(result, SocketAddr::from((IPV6, PORT))); + + assert!(SocketAddr::from_str(IPV4_STR).is_err()); + assert!(SocketAddr::from_str(IPV6_STR_FULL).is_err()); + assert!(SocketAddr::from_str(IPV6_STR_COMPRESS).is_err()); + assert!(SocketAddr::from_str(IPV6_STR_V4).is_err()); + } + + #[test] + fn ipv6_corner_cases() { + let result: Ipv6Addr = "1::".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(1, 0, 0, 0, 0, 0, 0, 0)); + + let result: Ipv6Addr = "1:1::".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(1, 1, 0, 0, 0, 0, 0, 0)); + + let result: Ipv6Addr = "::1".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); + + let result: Ipv6Addr = "::1:1".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(0, 0, 0, 0, 0, 0, 1, 1)); + + let result: Ipv6Addr = "::".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + let result: Ipv6Addr = "::192.168.0.1".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0xc0a8, 0x1)); + + let result: Ipv6Addr = "::1:192.168.0.1".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(0, 0, 0, 0, 0, 1, 0xc0a8, 0x1)); + + let result: Ipv6Addr = "1:1:1:1:1:1:192.168.0.1".parse().unwrap(); + assert_eq!(result, Ipv6Addr::new(1, 1, 1, 1, 1, 1, 0xc0a8, 0x1)); + } + + // Things that might not seem like failures but are + #[test] + fn ipv6_corner_failures() { + // No IP address before the :: + assert!(Ipv6Addr::from_str("1:192.168.0.1::").is_err()); + + // :: must have at least 1 set of zeroes + assert!(Ipv6Addr::from_str("1:1:1:1::1:1:1:1").is_err()); + + // Need brackets for a port + assert!(SocketAddrV6::from_str("1:1:1:1:1:1:1:1:8080").is_err()); + } +} -- 2.44.0