From 827be2de0d753afb3e5a00e66afe6e3c3ac79494 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 27 Feb 2016 14:15:19 -0800 Subject: [PATCH] Add TCP functionality from net2 --- src/libstd/net/tcp.rs | 231 ++++++++++++++++++++++++++++++++++ src/libstd/sys/common/net.rs | 79 ++++++++++++ src/libstd/sys/unix/net.rs | 48 +++++++ src/libstd/sys/windows/c.rs | 28 +++++ src/libstd/sys/windows/net.rs | 54 ++++++++ 5 files changed, 440 insertions(+) diff --git a/src/libstd/net/tcp.rs b/src/libstd/net/tcp.rs index f9c38c38458..b8530c98398 100644 --- a/src/libstd/net/tcp.rs +++ b/src/libstd/net/tcp.rs @@ -180,6 +180,117 @@ pub fn read_timeout(&self) -> io::Result> { pub fn write_timeout(&self) -> io::Result> { self.0.write_timeout() } + + /// Sets the value of the `TCP_NODELAY` option on this socket. + /// + /// If set, this option disables the Nagle algorithm. This means that + /// segments are always sent as soon as possible, even if there is only a + /// small amount of data. When not set, data is buffered until there is a + /// sufficient amount to send out, thereby avoiding the frequent sending of + /// small packets. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { + self.0.set_nodelay(nodelay) + } + + /// Gets the value of the `TCP_NODELAY` option on this socket. + /// + /// For more information about this option, see [`set_nodelay`][link]. + /// + /// [link]: #tymethod.set_nodelay + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn nodelay(&self) -> io::Result { + self.0.nodelay() + } + + /// Sets whether keepalive messages are enabled to be sent on this socket. + /// + /// On Unix, this option will set the `SO_KEEPALIVE` as well as the + /// `TCP_KEEPALIVE` or `TCP_KEEPIDLE` option (depending on your platform). + /// On Windows, this will set the `SIO_KEEPALIVE_VALS` option. + /// + /// If `None` is specified then keepalive messages are disabled, otherwise + /// the duration specified will be the time to remain idle before sending a + /// TCP keepalive probe. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_keepalive(&self, keepalive: Option) -> io::Result<()> { + self.0.set_keepalive(keepalive) + } + + /// Returns whether keepalive messages are enabled on this socket, and if so + /// the duration of time between them. + /// + /// For more information about this option, see [`set_keepalive`][link]. + /// + /// [link]: #tymethod.set_keepalive + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn keepalive(&self) -> io::Result> { + self.0.keepalive() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.0.set_ttl(ttl) + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`][link]. + /// + /// [link]: #tymethod.set_ttl + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn ttl(&self) -> io::Result { + self.0.ttl() + } + + /// Sets the value for the `IPV6_V6ONLY` option on this socket. + /// + /// If this is set to `true` then the socket is restricted to sending and + /// receiving IPv6 packets only. If this is the case, an IPv4 and an IPv6 + /// application can each bind the same port at the same time. + /// + /// If this is set to `false` then the socket can be used to send and + /// receive packets from an IPv4-mapped IPv6 address. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + self.0.set_only_v6(only_v6) + } + + /// Gets the value of the `IPV6_V6ONLY` option for this socket. + /// + /// For more information about this option, see [`set_only_v6`][link]. + /// + /// [link]: #tymethod.set_only_v6 + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn only_v6(&self) -> io::Result { + self.0.only_v6() + } + + /// Get the value of the `SO_ERROR` option on this socket. + /// + /// This will retrieve the stored error in the underlying socket, clearing + /// the field in the process. This can be useful for checking errors between + /// calls. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Moves this TCP stream into or out of nonblocking mode. + /// + /// On Unix this corresponds to calling fcntl, and on Windows this + /// corresponds to calling ioctlsocket. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -278,6 +389,67 @@ pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { pub fn incoming(&self) -> Incoming { Incoming { listener: self } } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.0.set_ttl(ttl) + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`][link]. + /// + /// [link]: #tymethod.set_ttl + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn ttl(&self) -> io::Result { + self.0.ttl() + } + + /// Sets the value for the `IPV6_V6ONLY` option on this socket. + /// + /// If this is set to `true` then the socket is restricted to sending and + /// receiving IPv6 packets only. In this case two IPv4 and IPv6 applications + /// can bind the same port at the same time. + /// + /// If this is set to `false` then the socket can be used to send and + /// receive packets from an IPv4-mapped IPv6 address. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + self.0.set_only_v6(only_v6) + } + + /// Gets the value of the `IPV6_V6ONLY` option for this socket. + /// + /// For more information about this option, see [`set_only_v6`][link]. + /// + /// [link]: #tymethod.set_only_v6 + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn only_v6(&self) -> io::Result { + self.0.only_v6() + } + + /// Get the value of the `SO_ERROR` option on this socket. + /// + /// This will retrieve the stored error in the underlying socket, clearing + /// the field in the process. This can be useful for checking errors between + /// calls. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Moves this TCP stream into or out of nonblocking mode. + /// + /// On Unix this corresponds to calling fcntl, and on Windows this + /// corresponds to calling ioctlsocket. + #[stable(feature = "net2_mutators", since = "1.9.0")] + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -969,4 +1141,63 @@ fn test_read_with_timeout() { assert!(start.elapsed() > Duration::from_millis(400)); drop(listener); } + + #[test] + fn nodelay() { + let addr = next_test_ip4(); + let _listener = t!(TcpListener::bind(&addr)); + + let stream = t!(TcpStream::connect(&("localhost", addr.port()))); + + assert_eq!(false, t!(stream.nodelay())); + t!(stream.set_nodelay(true)); + assert_eq!(true, t!(stream.nodelay())); + t!(stream.set_nodelay(false)); + assert_eq!(false, t!(stream.nodelay())); + } + + #[test] + fn keepalive() { + let addr = next_test_ip4(); + let _listener = t!(TcpListener::bind(&addr)); + + let stream = t!(TcpStream::connect(&("localhost", addr.port()))); + let dur = Duration::new(15410, 0); + + assert_eq!(None, t!(stream.keepalive())); + t!(stream.set_keepalive(Some(dur))); + assert_eq!(Some(dur), t!(stream.keepalive())); + t!(stream.set_keepalive(None)); + assert_eq!(None, t!(stream.keepalive())); + } + + #[test] + fn ttl() { + let ttl = 100; + + let addr = next_test_ip4(); + let listener = t!(TcpListener::bind(&addr)); + + t!(listener.set_ttl(ttl)); + assert_eq!(ttl, t!(listener.ttl())); + + let stream = t!(TcpStream::connect(&("localhost", addr.port()))); + + t!(stream.set_ttl(ttl)); + assert_eq!(ttl, t!(stream.ttl())); + } + + #[test] + fn set_nonblocking() { + let addr = next_test_ip4(); + let listener = t!(TcpListener::bind(&addr)); + + t!(listener.set_nonblocking(true)); + t!(listener.set_nonblocking(false)); + + let stream = t!(TcpStream::connect(&("localhost", addr.port()))); + + t!(stream.set_nonblocking(true)); + t!(stream.set_nonblocking(false)); + } } diff --git a/src/libstd/sys/common/net.rs b/src/libstd/sys/common/net.rs index 1cb9303a9fc..0ac4056de8e 100644 --- a/src/libstd/sys/common/net.rs +++ b/src/libstd/sys/common/net.rs @@ -228,6 +228,54 @@ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { pub fn duplicate(&self) -> io::Result { self.inner.duplicate().map(|s| TcpStream { inner: s }) } + + pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_TCP, c::TCP_NODELAY, nodelay as c_int) + } + + pub fn nodelay(&self) -> io::Result { + let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_TCP, c::TCP_NODELAY)); + Ok(raw != 0) + } + + pub fn set_keepalive(&self, keepalive: Option) -> io::Result<()> { + self.inner.set_keepalive(keepalive) + } + + pub fn keepalive(&self) -> io::Result> { + self.inner.keepalive() + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) + } + + pub fn ttl(&self) -> io::Result { + let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)); + Ok(raw as u32) + } + + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int) + } + + pub fn only_v6(&self) -> io::Result { + let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY)); + Ok(raw != 0) + } + + pub fn take_error(&self) -> io::Result> { + let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR)); + if raw == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(raw as i32))) + } + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } } impl FromInner for TcpStream { @@ -307,6 +355,37 @@ pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { pub fn duplicate(&self) -> io::Result { self.inner.duplicate().map(|s| TcpListener { inner: s }) } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) + } + + pub fn ttl(&self) -> io::Result { + let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)); + Ok(raw as u32) + } + + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int) + } + + pub fn only_v6(&self) -> io::Result { + let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY)); + Ok(raw != 0) + } + + pub fn take_error(&self) -> io::Result> { + let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR)); + if raw == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(raw as i32))) + } + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } } impl FromInner for TcpListener { diff --git a/src/libstd/sys/unix/net.rs b/src/libstd/sys/unix/net.rs index 16c369674f0..7a2ac7257af 100644 --- a/src/libstd/sys/unix/net.rs +++ b/src/libstd/sys/unix/net.rs @@ -35,6 +35,16 @@ #[cfg(not(target_os = "linux"))] const SOCK_CLOEXEC: c_int = 0; +#[cfg(any(target_os = "openbsd", taret_os = "freebsd"))] +use libc::SO_KEEPALIVE as TCP_KEEPALIVE; +#[cfg(any(target_os = "macos", taret_os = "ios"))] +use libc::TCP_KEEPALIVE; +#[cfg(not(any(target_os = "openbsd", + target_os = "freebsd", + target_os = "macos", + target_os = "ios")))] +use libc::TCP_KEEPIDLE as TCP_KEEPALIVE; + pub struct Socket(FileDesc); pub fn init() {} @@ -168,6 +178,44 @@ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { try!(cvt(unsafe { libc::shutdown(self.0.raw(), how) })); Ok(()) } + + pub fn set_keepalive(&self, keepalive: Option) -> io::Result<()> { + try!(setsockopt(self, + libc::SOL_SOCKET, + libc::SO_KEEPALIVE, + keepalive.is_some() as libc::c_int)); + if let Some(dur) = keepalive { + let mut raw = dur.as_secs(); + if dur.subsec_nanos() > 0 { + raw = raw.saturating_add(1); + } + + let raw = if raw > libc::c_int::max_value() as u64 { + libc::c_int::max_value() + } else { + raw as libc::c_int + }; + + try!(setsockopt(self, libc::IPPROTO_TCP, TCP_KEEPALIVE, raw)); + } + + Ok(()) + } + + pub fn keepalive(&self) -> io::Result> { + let raw: c_int = try!(getsockopt(self, libc::SOL_SOCKET, libc::SO_KEEPALIVE)); + if raw == 0 { + return Ok(None); + } + + let raw: c_int = try!(getsockopt(self, libc::IPPROTO_TCP, TCP_KEEPALIVE)); + Ok(Some(Duration::from_secs(raw as u64))) + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = nonblocking as libc::c_ulong; + cvt(unsafe { libc::ioctl(*self.as_inner(), libc::FIONBIO, &mut nonblocking) }).map(|_| ()) + } } impl AsInner for Socket { diff --git a/src/libstd/sys/windows/c.rs b/src/libstd/sys/windows/c.rs index 5cbfec01bed..cc420763fd7 100644 --- a/src/libstd/sys/windows/c.rs +++ b/src/libstd/sys/windows/c.rs @@ -78,6 +78,13 @@ pub type socklen_t = c_int; pub type ADDRESS_FAMILY = USHORT; +pub type LPWSAOVERLAPPED_COMPLETION_ROUTINE = + Option; +pub type LPWSAOVERLAPPED = *mut OVERLAPPED; + pub const TRUE: BOOL = 1; pub const FALSE: BOOL = 0; @@ -114,6 +121,9 @@ pub const FILE_FLAG_BACKUP_SEMANTICS: DWORD = 0x02000000; pub const SECURITY_SQOS_PRESENT: DWORD = 0x00100000; +pub const SIO_KEEPALIVE_VALS: DWORD = 0x98000004; +pub const FIONBIO: c_ulong = 0x8004667e; + #[repr(C)] #[derive(Copy)] pub struct WIN32_FIND_DATAW { @@ -775,6 +785,13 @@ pub struct in6_addr { pub s6_addr: [u8; 16], } +#[repr(C)] +pub struct tcp_keepalive { + pub onoff: c_ulong, + pub keepalivetime: c_ulong, + pub keepaliveinterval: c_ulong, +} + #[cfg(all(target_arch = "x86_64", target_env = "gnu"))] pub enum UNWIND_HISTORY_TABLE {} @@ -833,6 +850,17 @@ pub fn WSASocketW(af: c_int, lpProtocolInfo: LPWSAPROTOCOL_INFO, g: GROUP, dwFlags: DWORD) -> SOCKET; + pub fn WSAIoctl(s: SOCKET, + dwIoControlCode: DWORD, + lpvInBuffer: LPVOID, + cbInBuffer: DWORD, + lpvOutBuffer: LPVOID, + cbOutBuffer: DWORD, + lpcbBytesReturned: LPDWORD, + lpOverlapped: LPWSAOVERLAPPED, + lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE) + -> c_int; + pub fn ioctlsocket(s: SOCKET, cmd: c_long, argp: *mut u_long) -> c_int; pub fn InitializeCriticalSection(CriticalSection: *mut CRITICAL_SECTION); pub fn EnterCriticalSection(CriticalSection: *mut CRITICAL_SECTION); pub fn TryEnterCriticalSection(CriticalSection: *mut CRITICAL_SECTION) -> BOOLEAN; diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs index 49ba8e9c659..be13657aaf4 100644 --- a/src/libstd/sys/windows/net.rs +++ b/src/libstd/sys/windows/net.rs @@ -185,6 +185,60 @@ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { try!(cvt(unsafe { c::shutdown(self.0, how) })); Ok(()) } + + pub fn set_keepalive(&self, keepalive: Option) -> io::Result<()> { + let ms = keepalive.map(sys::dur2timeout).unwrap_or(c::INFINITE); + let ka = c::tcp_keepalive { + onoff: keepalive.is_some() as c::c_ulong, + keepalivetime: ms as c::c_ulong, + keepaliveinterval: ms as c::c_ulong, + }; + sys::cvt(unsafe { + c::WSAIoctl(self.0, + c::SIO_KEEPALIVE_VALS, + &ka as *const _ as *mut _, + mem::size_of_val(&ka) as c::DWORD, + 0 as *mut _, + 0, + 0 as *mut _, + 0 as *mut _, + None) + }).map(|_| ()) + } + + pub fn keepalive(&self) -> io::Result> { + let mut ka = c::tcp_keepalive { + onoff: 0, + keepalivetime: 0, + keepaliveinterval: 0, + }; + try!(sys::cvt(unsafe { + WSAIoctl(self.0, + c::SIO_KEEPALIVE_VALS, + 0 as *mut _, + 0, + &mut ka as *mut _ as *mut _, + mem::size_of_val(&ka) as c::DWORD, + 0 as *mut _, + 0 as *mut _, + None) + })); + + if ka.onoff == 0 { + Ok(None) + } else { + let secs = ka.keepaliveinterval / 1000; + let nsec = (ka.keepaliveinterval % 1000) * 1000000; + Ok(Some(Duration::new(secs as u64, nsec as u32))) + } + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = nonblocking as c::c_ulong; + sys::cvt(unsafe { + c::ioctlsocket(self.0, c::FIONBIO as c::c_int, &mut nonblocking) + }).map(|_| ()) + } } impl Drop for Socket { -- 2.44.0