From: Jethro Beekman Date: Tue, 18 Sep 2018 22:25:08 +0000 (-0700) Subject: Refactor net::each_addr/lookup_host to forward error from resolve X-Git-Url: https://git.lizzy.rs/?a=commitdiff_plain;h=22c43689937a81cf5ad6ecfe22d9e63e3cebed04;p=rust.git Refactor net::each_addr/lookup_host to forward error from resolve --- diff --git a/src/libstd/net/addr.rs b/src/libstd/net/addr.rs index ff35325ab4f..1ac0bdf922f 100644 --- a/src/libstd/net/addr.rs +++ b/src/libstd/net/addr.rs @@ -16,10 +16,11 @@ use option; use sys::net::netc as c; use sys_common::{FromInner, AsInner, IntoInner}; -use sys_common::net::lookup_host; +use sys_common::net::LookupHost; use vec; use iter; use slice; +use convert::TryInto; /// An internet socket address, either IPv4 or IPv6. /// @@ -863,9 +864,9 @@ fn to_socket_addrs(&self) -> io::Result> { } } -fn resolve_socket_addr(s: &str, p: u16) -> io::Result> { - let ips = lookup_host(s)?; - let v: Vec<_> = ips.map(|mut a| { a.set_port(p); a }).collect(); +fn resolve_socket_addr(lh: LookupHost) -> io::Result> { + let p = lh.port(); + let v: Vec<_> = lh.map(|mut a| { a.set_port(p); a }).collect(); Ok(v.into_iter()) } @@ -885,7 +886,7 @@ fn to_socket_addrs(&self) -> io::Result> { return Ok(vec![SocketAddr::V6(addr)].into_iter()) } - resolve_socket_addr(host, port) + resolve_socket_addr((host, port).try_into()?) } } @@ -899,22 +900,7 @@ fn to_socket_addrs(&self) -> io::Result> { return Ok(vec![addr].into_iter()); } - macro_rules! try_opt { - ($e:expr, $msg:expr) => ( - match $e { - Some(r) => r, - None => return Err(io::Error::new(io::ErrorKind::InvalidInput, - $msg)), - } - ) - } - - // split the string by ':' and convert the second part to u16 - let mut parts_iter = self.rsplitn(2, ':'); - let port_str = try_opt!(parts_iter.next(), "invalid socket address"); - let host = try_opt!(parts_iter.next(), "invalid socket address"); - let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value"); - resolve_socket_addr(host, port) + resolve_socket_addr(self.try_into()?) } } diff --git a/src/libstd/net/mod.rs b/src/libstd/net/mod.rs index be4bcee8a68..ff579a5feb1 100644 --- a/src/libstd/net/mod.rs +++ b/src/libstd/net/mod.rs @@ -112,11 +112,15 @@ fn hton(i: I) -> I { i.to_be() } fn ntoh(i: I) -> I { I::from_be(i) } fn each_addr(addr: A, mut f: F) -> io::Result - where F: FnMut(&SocketAddr) -> io::Result + where F: FnMut(io::Result<&SocketAddr>) -> io::Result { + let addrs = match addr.to_socket_addrs() { + Ok(addrs) => addrs, + Err(e) => return f(Err(e)) + }; let mut last_err = None; - for addr in addr.to_socket_addrs()? { - match f(&addr) { + for addr in addrs { + match f(Ok(&addr)) { Ok(l) => return Ok(l), Err(e) => last_err = Some(e), } diff --git a/src/libstd/sys/cloudabi/shims/net.rs b/src/libstd/sys/cloudabi/shims/net.rs index 93eaf6a9e7d..7229e71d175 100644 --- a/src/libstd/sys/cloudabi/shims/net.rs +++ b/src/libstd/sys/cloudabi/shims/net.rs @@ -13,13 +13,14 @@ use net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}; use time::Duration; use sys::{unsupported, Void}; +use convert::TryFrom; pub extern crate libc as netc; pub struct TcpStream(Void); impl TcpStream { - pub fn connect(_: &SocketAddr) -> io::Result { + pub fn connect(_: io::Result<&SocketAddr>) -> io::Result { unsupported() } @@ -105,7 +106,7 @@ fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { pub struct TcpListener(Void); impl TcpListener { - pub fn bind(_: &SocketAddr) -> io::Result { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result { unsupported() } @@ -155,7 +156,7 @@ fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { pub struct UdpSocket(Void); impl UdpSocket { - pub fn bind(_: &SocketAddr) -> io::Result { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result { unsupported() } @@ -271,7 +272,7 @@ pub fn send(&self, _: &[u8]) -> io::Result { match self.0 {} } - pub fn connect(&self, _: &SocketAddr) -> io::Result<()> { + pub fn connect(&self, _: io::Result<&SocketAddr>) -> io::Result<()> { match self.0 {} } } @@ -284,6 +285,12 @@ fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { pub struct LookupHost(Void); +impl LookupHost { + pub fn port(&self) -> u16 { + match self.0 {} + } +} + impl Iterator for LookupHost { type Item = SocketAddr; fn next(&mut self) -> Option { @@ -291,6 +298,18 @@ fn next(&mut self) -> Option { } } -pub fn lookup_host(_: &str) -> io::Result { - unsupported() +impl<'a> TryFrom<&'a str> for LookupHost { + type Error = io::Error; + + fn try_from(_v: &'a str) -> io::Result { + unsupported() + } +} + +impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from(_v: (&'a str, u16)) -> io::Result { + unsupported() + } } diff --git a/src/libstd/sys/redox/net/mod.rs b/src/libstd/sys/redox/net/mod.rs index 67f22231d5f..04a183f2417 100644 --- a/src/libstd/sys/redox/net/mod.rs +++ b/src/libstd/sys/redox/net/mod.rs @@ -9,7 +9,7 @@ // except according to those terms. use fs::File; -use io::{Error, Result, Read}; +use io::{Error, Read, self}; use iter::Iterator; use net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use str::FromStr; @@ -17,6 +17,7 @@ use sys::syscall::EINVAL; use time::{self, Duration}; use vec::{IntoIter, Vec}; +use convert::{TryFrom, TryInto}; use self::dns::{Dns, DnsQuery}; @@ -29,7 +30,13 @@ mod tcp; mod udp; -pub struct LookupHost(IntoIter); +pub struct LookupHost(IntoIter, u16); + +impl LookupHost { + pub fn port(&self) -> u16 { + self.1 + } +} impl Iterator for LookupHost { type Item = SocketAddr; @@ -38,65 +45,93 @@ fn next(&mut self) -> Option { } } -pub fn lookup_host(host: &str) -> Result { - let mut ip_string = String::new(); - File::open("/etc/net/ip")?.read_to_string(&mut ip_string)?; - let ip: Vec = ip_string.trim().split('.').map(|part| part.parse::() - .unwrap_or(0)).collect(); - - let mut dns_string = String::new(); - File::open("/etc/net/dns")?.read_to_string(&mut dns_string)?; - let dns: Vec = dns_string.trim().split('.').map(|part| part.parse::() - .unwrap_or(0)).collect(); - - if ip.len() == 4 && dns.len() == 4 { - let time = time::SystemTime::now().duration_since(time::UNIX_EPOCH).unwrap(); - let tid = (time.subsec_nanos() >> 16) as u16; - - let packet = Dns { - transaction_id: tid, - flags: 0x0100, - queries: vec![DnsQuery { - name: host.to_string(), - q_type: 0x0001, - q_class: 0x0001, - }], - answers: vec![] - }; - - let packet_data = packet.compile(); - - let my_ip = Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]); - let dns_ip = Ipv4Addr::new(dns[0], dns[1], dns[2], dns[3]); - let socket = UdpSocket::bind(&SocketAddr::V4(SocketAddrV4::new(my_ip, 0)))?; - socket.set_read_timeout(Some(Duration::new(5, 0)))?; - socket.set_write_timeout(Some(Duration::new(5, 0)))?; - socket.connect(&SocketAddr::V4(SocketAddrV4::new(dns_ip, 53)))?; - socket.send(&packet_data)?; - - let mut buf = [0; 65536]; - let count = socket.recv(&mut buf)?; - - match Dns::parse(&buf[.. count]) { - Ok(response) => { - let mut addrs = vec![]; - for answer in response.answers.iter() { - if answer.a_type == 0x0001 && answer.a_class == 0x0001 - && answer.data.len() == 4 - { - let answer_ip = Ipv4Addr::new(answer.data[0], - answer.data[1], - answer.data[2], - answer.data[3]); - addrs.push(SocketAddr::V4(SocketAddrV4::new(answer_ip, 0))); - } +impl<'a> TryFrom<&'a str> for LookupHost { + type Error = io::Error; + + fn try_from(s: &str) -> io::Result { + macro_rules! try_opt { + ($e:expr, $msg:expr) => ( + match $e { + Some(r) => r, + None => return Err(io::Error::new(io::ErrorKind::InvalidInput, + $msg)), } - Ok(LookupHost(addrs.into_iter())) - }, - Err(_err) => Err(Error::from_raw_os_error(EINVAL)) + ) + } + + // split the string by ':' and convert the second part to u16 + let mut parts_iter = s.rsplitn(2, ':'); + let port_str = try_opt!(parts_iter.next(), "invalid socket address"); + let host = try_opt!(parts_iter.next(), "invalid socket address"); + let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value"); + + (host, port).try_into() + } +} + +impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from((host, port): (&'a str, u16)) -> io::Result { + let mut ip_string = String::new(); + File::open("/etc/net/ip")?.read_to_string(&mut ip_string)?; + let ip: Vec = ip_string.trim().split('.').map(|part| part.parse::() + .unwrap_or(0)).collect(); + + let mut dns_string = String::new(); + File::open("/etc/net/dns")?.read_to_string(&mut dns_string)?; + let dns: Vec = dns_string.trim().split('.').map(|part| part.parse::() + .unwrap_or(0)).collect(); + + if ip.len() == 4 && dns.len() == 4 { + let time = time::SystemTime::now().duration_since(time::UNIX_EPOCH).unwrap(); + let tid = (time.subsec_nanos() >> 16) as u16; + + let packet = Dns { + transaction_id: tid, + flags: 0x0100, + queries: vec![DnsQuery { + name: host.to_string(), + q_type: 0x0001, + q_class: 0x0001, + }], + answers: vec![] + }; + + let packet_data = packet.compile(); + + let my_ip = Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]); + let dns_ip = Ipv4Addr::new(dns[0], dns[1], dns[2], dns[3]); + let socket = UdpSocket::bind(Ok(&SocketAddr::V4(SocketAddrV4::new(my_ip, 0))))?; + socket.set_read_timeout(Some(Duration::new(5, 0)))?; + socket.set_write_timeout(Some(Duration::new(5, 0)))?; + socket.connect(Ok(&SocketAddr::V4(SocketAddrV4::new(dns_ip, 53))))?; + socket.send(&packet_data)?; + + let mut buf = [0; 65536]; + let count = socket.recv(&mut buf)?; + + match Dns::parse(&buf[.. count]) { + Ok(response) => { + let mut addrs = vec![]; + for answer in response.answers.iter() { + if answer.a_type == 0x0001 && answer.a_class == 0x0001 + && answer.data.len() == 4 + { + let answer_ip = Ipv4Addr::new(answer.data[0], + answer.data[1], + answer.data[2], + answer.data[3]); + addrs.push(SocketAddr::V4(SocketAddrV4::new(answer_ip, 0))); + } + } + Ok(LookupHost(addrs.into_iter(), port)) + }, + Err(_err) => Err(Error::from_raw_os_error(EINVAL)) + } + } else { + Err(Error::from_raw_os_error(EINVAL)) } - } else { - Err(Error::from_raw_os_error(EINVAL)) } } diff --git a/src/libstd/sys/redox/net/tcp.rs b/src/libstd/sys/redox/net/tcp.rs index b5664908479..37457d87f33 100644 --- a/src/libstd/sys/redox/net/tcp.rs +++ b/src/libstd/sys/redox/net/tcp.rs @@ -24,8 +24,8 @@ pub struct TcpStream(File); impl TcpStream { - pub fn connect(addr: &SocketAddr) -> Result { - let path = format!("tcp:{}", addr); + pub fn connect(addr: Result<&SocketAddr>) -> Result { + let path = format!("tcp:{}", addr?); let mut options = OpenOptions::new(); options.read(true); options.write(true); @@ -180,8 +180,8 @@ fn into_inner(self) -> File { self.0 } pub struct TcpListener(File); impl TcpListener { - pub fn bind(addr: &SocketAddr) -> Result { - let path = format!("tcp:/{}", addr); + pub fn bind(addr: Result<&SocketAddr>) -> Result { + let path = format!("tcp:/{}", addr?); let mut options = OpenOptions::new(); options.read(true); options.write(true); diff --git a/src/libstd/sys/redox/net/udp.rs b/src/libstd/sys/redox/net/udp.rs index 22af02079e7..85bfd425924 100644 --- a/src/libstd/sys/redox/net/udp.rs +++ b/src/libstd/sys/redox/net/udp.rs @@ -25,8 +25,8 @@ pub struct UdpSocket(File, UnsafeCell>); impl UdpSocket { - pub fn bind(addr: &SocketAddr) -> Result { - let path = format!("udp:/{}", addr); + pub fn bind(addr: Result<&SocketAddr>) -> Result { + let path = format!("udp:/{}", addr?); let mut options = OpenOptions::new(); options.read(true); options.write(true); @@ -37,8 +37,8 @@ fn get_conn(&self) -> &mut Option { unsafe { &mut *(self.1.get()) } } - pub fn connect(&self, addr: &SocketAddr) -> Result<()> { - unsafe { *self.1.get() = Some(*addr) }; + pub fn connect(&self, addr: Result<&SocketAddr>) -> Result<()> { + unsafe { *self.1.get() = Some(*addr?) }; Ok(()) } diff --git a/src/libstd/sys/unix/l4re.rs b/src/libstd/sys/unix/l4re.rs index 21218489679..bbb0fd45ba3 100644 --- a/src/libstd/sys/unix/l4re.rs +++ b/src/libstd/sys/unix/l4re.rs @@ -21,7 +21,7 @@ pub mod net { use sys_common::{AsInner, FromInner, IntoInner}; use sys::fd::FileDesc; use time::Duration; - + use convert::TryFrom; pub extern crate libc as netc; @@ -118,7 +118,7 @@ pub struct TcpStream { } impl TcpStream { - pub fn connect(_: &SocketAddr) -> io::Result { + pub fn connect(_: io::Result<&SocketAddr>) -> io::Result { unimpl!(); } @@ -216,7 +216,7 @@ pub struct TcpListener { } impl TcpListener { - pub fn bind(_: &SocketAddr) -> io::Result { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result { unimpl!(); } @@ -278,7 +278,7 @@ pub struct UdpSocket { } impl UdpSocket { - pub fn bind(_: &SocketAddr) -> io::Result { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result { unimpl!(); } @@ -402,7 +402,7 @@ pub fn send(&self, _: &[u8]) -> io::Result { unimpl!(); } - pub fn connect(&self, _: &SocketAddr) -> io::Result<()> { + pub fn connect(&self, _: io::Result<&SocketAddr>) -> io::Result<()> { unimpl!(); } } @@ -431,11 +431,30 @@ fn next(&mut self) -> Option { } } + impl LookupHost { + pub fn port(&self) -> u16 { + unimpl!(); + } + } + unsafe impl Sync for LookupHost {} unsafe impl Send for LookupHost {} - pub fn lookup_host(_: &str) -> io::Result { - unimpl!(); + + impl<'a> TryFrom<&'a str> for LookupHost { + type Error = io::Error; + + fn try_from(_v: &'a str) -> io::Result { + unimpl!(); + } + } + + impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from(_v: (&'a str, u16)) -> io::Result { + unimpl!(); + } } } diff --git a/src/libstd/sys/wasm/net.rs b/src/libstd/sys/wasm/net.rs index 03a5b2d779e..e1c33b09cb4 100644 --- a/src/libstd/sys/wasm/net.rs +++ b/src/libstd/sys/wasm/net.rs @@ -13,11 +13,12 @@ use net::{SocketAddr, Shutdown, Ipv4Addr, Ipv6Addr}; use time::Duration; use sys::{unsupported, Void}; +use convert::TryFrom; pub struct TcpStream(Void); impl TcpStream { - pub fn connect(_: &SocketAddr) -> io::Result { + pub fn connect(_: io::Result<&SocketAddr>) -> io::Result { unsupported() } @@ -103,7 +104,7 @@ fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { pub struct TcpListener(Void); impl TcpListener { - pub fn bind(_: &SocketAddr) -> io::Result { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result { unsupported() } @@ -153,7 +154,7 @@ fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { pub struct UdpSocket(Void); impl UdpSocket { - pub fn bind(_: &SocketAddr) -> io::Result { + pub fn bind(_: io::Result<&SocketAddr>) -> io::Result { unsupported() } @@ -273,7 +274,7 @@ pub fn send(&self, _: &[u8]) -> io::Result { match self.0 {} } - pub fn connect(&self, _: &SocketAddr) -> io::Result<()> { + pub fn connect(&self, _: io::Result<&SocketAddr>) -> io::Result<()> { match self.0 {} } } @@ -286,6 +287,12 @@ fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { pub struct LookupHost(Void); +impl LookupHost { + pub fn port(&self) -> u16 { + match self.0 {} + } +} + impl Iterator for LookupHost { type Item = SocketAddr; fn next(&mut self) -> Option { @@ -293,8 +300,20 @@ fn next(&mut self) -> Option { } } -pub fn lookup_host(_: &str) -> io::Result { - unsupported() +impl<'a> TryFrom<&'a str> for LookupHost { + type Error = io::Error; + + fn try_from(_v: &'a str) -> io::Result { + unsupported() + } +} + +impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from(_v: (&'a str, u16)) -> io::Result { + unsupported() + } } #[allow(nonstandard_style)] diff --git a/src/libstd/sys_common/net.rs b/src/libstd/sys_common/net.rs index d09a233ed89..dce2bf71cec 100644 --- a/src/libstd/sys_common/net.rs +++ b/src/libstd/sys_common/net.rs @@ -20,6 +20,7 @@ use sys::net::netc as c; use sys_common::{AsInner, FromInner, IntoInner}; use time::Duration; +use convert::{TryFrom, TryInto}; #[cfg(any(target_os = "dragonfly", target_os = "freebsd", target_os = "ios", target_os = "macos", @@ -129,6 +130,13 @@ fn to_ipv6mr_interface(value: u32) -> ::libc::c_uint { pub struct LookupHost { original: *mut c::addrinfo, cur: *mut c::addrinfo, + port: u16 +} + +impl LookupHost { + pub fn port(&self) -> u16 { + self.port + } } impl Iterator for LookupHost { @@ -158,17 +166,45 @@ fn drop(&mut self) { } } -pub fn lookup_host(host: &str) -> io::Result { - init(); +impl<'a> TryFrom<&'a str> for LookupHost { + type Error = io::Error; - let c_host = CString::new(host)?; - let mut hints: c::addrinfo = unsafe { mem::zeroed() }; - hints.ai_socktype = c::SOCK_STREAM; - let mut res = ptr::null_mut(); - unsafe { - cvt_gai(c::getaddrinfo(c_host.as_ptr(), ptr::null(), &hints, &mut res)).map(|_| { - LookupHost { original: res, cur: res } - }) + fn try_from(s: &str) -> io::Result { + macro_rules! try_opt { + ($e:expr, $msg:expr) => ( + match $e { + Some(r) => r, + None => return Err(io::Error::new(io::ErrorKind::InvalidInput, + $msg)), + } + ) + } + + // split the string by ':' and convert the second part to u16 + let mut parts_iter = s.rsplitn(2, ':'); + let port_str = try_opt!(parts_iter.next(), "invalid socket address"); + let host = try_opt!(parts_iter.next(), "invalid socket address"); + let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value"); + + (host, port).try_into() + } +} + +impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from((host, port): (&'a str, u16)) -> io::Result { + init(); + + let c_host = CString::new(host)?; + let mut hints: c::addrinfo = unsafe { mem::zeroed() }; + hints.ai_socktype = c::SOCK_STREAM; + let mut res = ptr::null_mut(); + unsafe { + cvt_gai(c::getaddrinfo(c_host.as_ptr(), ptr::null(), &hints, &mut res)).map(|_| { + LookupHost { original: res, cur: res, port } + }) + } } } @@ -181,7 +217,9 @@ pub struct TcpStream { } impl TcpStream { - pub fn connect(addr: &SocketAddr) -> io::Result { + pub fn connect(addr: io::Result<&SocketAddr>) -> io::Result { + let addr = addr?; + init(); let sock = Socket::new(addr, c::SOCK_STREAM)?; @@ -317,7 +355,9 @@ pub struct TcpListener { } impl TcpListener { - pub fn bind(addr: &SocketAddr) -> io::Result { + pub fn bind(addr: io::Result<&SocketAddr>) -> io::Result { + let addr = addr?; + init(); let sock = Socket::new(addr, c::SOCK_STREAM)?; @@ -418,7 +458,9 @@ pub struct UdpSocket { } impl UdpSocket { - pub fn bind(addr: &SocketAddr) -> io::Result { + pub fn bind(addr: io::Result<&SocketAddr>) -> io::Result { + let addr = addr?; + init(); let sock = Socket::new(addr, c::SOCK_DGRAM)?; @@ -584,8 +626,8 @@ pub fn send(&self, buf: &[u8]) -> io::Result { Ok(ret as usize) } - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - let (addrp, len) = addr.into_inner(); + pub fn connect(&self, addr: io::Result<&SocketAddr>) -> io::Result<()> { + let (addrp, len) = addr?.into_inner(); cvt_r(|| unsafe { c::connect(*self.inner.as_inner(), addrp, len) }).map(|_| ()) } } @@ -618,7 +660,7 @@ mod tests { #[test] fn no_lookup_host_duplicates() { let mut addrs = HashMap::new(); - let lh = match lookup_host("localhost") { + let lh = match LookupHost::try_from(("localhost", 0)) { Ok(lh) => lh, Err(e) => panic!("couldn't resolve `localhost': {}", e) };