]> git.lizzy.rs Git - rust.git/commitdiff
Add TCP functionality from net2
authorSteven Fackler <sfackler@gmail.com>
Sat, 27 Feb 2016 22:15:19 +0000 (14:15 -0800)
committerSteven Fackler <sfackler@gmail.com>
Sun, 28 Feb 2016 17:41:33 +0000 (09:41 -0800)
src/libstd/net/tcp.rs
src/libstd/sys/common/net.rs
src/libstd/sys/unix/net.rs
src/libstd/sys/windows/c.rs
src/libstd/sys/windows/net.rs

index f9c38c38458475661a87e298266849114991d7e7..b8530c98398c85985659a8b6654d3726ad80a091 100644 (file)
@@ -180,6 +180,117 @@ pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
     pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
         self.0.write_timeout()
     }
+
+    /// Sets the value of the `TCP_NODELAY` option on this socket.
+    ///
+    /// If set, this option disables the Nagle algorithm. This means that
+    /// segments are always sent as soon as possible, even if there is only a
+    /// small amount of data. When not set, data is buffered until there is a
+    /// sufficient amount to send out, thereby avoiding the frequent sending of
+    /// small packets.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
+        self.0.set_nodelay(nodelay)
+    }
+
+    /// Gets the value of the `TCP_NODELAY` option on this socket.
+    ///
+    /// For more information about this option, see [`set_nodelay`][link].
+    ///
+    /// [link]: #tymethod.set_nodelay
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn nodelay(&self) -> io::Result<bool> {
+        self.0.nodelay()
+    }
+
+    /// Sets whether keepalive messages are enabled to be sent on this socket.
+    ///
+    /// On Unix, this option will set the `SO_KEEPALIVE` as well as the
+    /// `TCP_KEEPALIVE` or `TCP_KEEPIDLE` option (depending on your platform).
+    /// On Windows, this will set the `SIO_KEEPALIVE_VALS` option.
+    ///
+    /// If `None` is specified then keepalive messages are disabled, otherwise
+    /// the duration specified will be the time to remain idle before sending a
+    /// TCP keepalive probe.
+    ///
+    /// Some platforms specify this value in seconds, so sub-second
+    /// specifications may be omitted.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
+        self.0.set_keepalive(keepalive)
+    }
+
+    /// Returns whether keepalive messages are enabled on this socket, and if so
+    /// the duration of time between them.
+    ///
+    /// For more information about this option, see [`set_keepalive`][link].
+    ///
+    /// [link]: #tymethod.set_keepalive
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn keepalive(&self) -> io::Result<Option<Duration>> {
+        self.0.keepalive()
+    }
+
+    /// Sets the value for the `IP_TTL` option on this socket.
+    ///
+    /// This value sets the time-to-live field that is used in every packet sent
+    /// from this socket.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
+        self.0.set_ttl(ttl)
+    }
+
+    /// Gets the value of the `IP_TTL` option for this socket.
+    ///
+    /// For more information about this option, see [`set_ttl`][link].
+    ///
+    /// [link]: #tymethod.set_ttl
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn ttl(&self) -> io::Result<u32> {
+        self.0.ttl()
+    }
+
+    /// Sets the value for the `IPV6_V6ONLY` option on this socket.
+    ///
+    /// If this is set to `true` then the socket is restricted to sending and
+    /// receiving IPv6 packets only. If this is the case, an IPv4 and an IPv6
+    /// application can each bind the same port at the same time.
+    ///
+    /// If this is set to `false` then the socket can be used to send and
+    /// receive packets from an IPv4-mapped IPv6 address.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
+        self.0.set_only_v6(only_v6)
+    }
+
+    /// Gets the value of the `IPV6_V6ONLY` option for this socket.
+    ///
+    /// For more information about this option, see [`set_only_v6`][link].
+    ///
+    /// [link]: #tymethod.set_only_v6
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn only_v6(&self) -> io::Result<bool> {
+        self.0.only_v6()
+    }
+
+    /// Get the value of the `SO_ERROR` option on this socket.
+    ///
+    /// This will retrieve the stored error in the underlying socket, clearing
+    /// the field in the process. This can be useful for checking errors between
+    /// calls.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
+        self.0.take_error()
+    }
+
+    /// Moves this TCP stream into or out of nonblocking mode.
+    ///
+    /// On Unix this corresponds to calling fcntl, and on Windows this
+    /// corresponds to calling ioctlsocket.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
+        self.0.set_nonblocking(nonblocking)
+    }
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -278,6 +389,67 @@ pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
     pub fn incoming(&self) -> Incoming {
         Incoming { listener: self }
     }
+
+    /// Sets the value for the `IP_TTL` option on this socket.
+    ///
+    /// This value sets the time-to-live field that is used in every packet sent
+    /// from this socket.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
+        self.0.set_ttl(ttl)
+    }
+
+    /// Gets the value of the `IP_TTL` option for this socket.
+    ///
+    /// For more information about this option, see [`set_ttl`][link].
+    ///
+    /// [link]: #tymethod.set_ttl
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn ttl(&self) -> io::Result<u32> {
+        self.0.ttl()
+    }
+
+    /// Sets the value for the `IPV6_V6ONLY` option on this socket.
+    ///
+    /// If this is set to `true` then the socket is restricted to sending and
+    /// receiving IPv6 packets only. In this case two IPv4 and IPv6 applications
+    /// can bind the same port at the same time.
+    ///
+    /// If this is set to `false` then the socket can be used to send and
+    /// receive packets from an IPv4-mapped IPv6 address.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
+        self.0.set_only_v6(only_v6)
+    }
+
+    /// Gets the value of the `IPV6_V6ONLY` option for this socket.
+    ///
+    /// For more information about this option, see [`set_only_v6`][link].
+    ///
+    /// [link]: #tymethod.set_only_v6
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn only_v6(&self) -> io::Result<bool> {
+        self.0.only_v6()
+    }
+
+    /// Get the value of the `SO_ERROR` option on this socket.
+    ///
+    /// This will retrieve the stored error in the underlying socket, clearing
+    /// the field in the process. This can be useful for checking errors between
+    /// calls.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
+        self.0.take_error()
+    }
+
+    /// Moves this TCP stream into or out of nonblocking mode.
+    ///
+    /// On Unix this corresponds to calling fcntl, and on Windows this
+    /// corresponds to calling ioctlsocket.
+    #[stable(feature = "net2_mutators", since = "1.9.0")]
+    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
+        self.0.set_nonblocking(nonblocking)
+    }
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -969,4 +1141,63 @@ fn test_read_with_timeout() {
         assert!(start.elapsed() > Duration::from_millis(400));
         drop(listener);
     }
+
+    #[test]
+    fn nodelay() {
+        let addr = next_test_ip4();
+        let _listener = t!(TcpListener::bind(&addr));
+
+        let stream = t!(TcpStream::connect(&("localhost", addr.port())));
+
+        assert_eq!(false, t!(stream.nodelay()));
+        t!(stream.set_nodelay(true));
+        assert_eq!(true, t!(stream.nodelay()));
+        t!(stream.set_nodelay(false));
+        assert_eq!(false, t!(stream.nodelay()));
+    }
+
+    #[test]
+    fn keepalive() {
+        let addr = next_test_ip4();
+        let _listener = t!(TcpListener::bind(&addr));
+
+        let stream = t!(TcpStream::connect(&("localhost", addr.port())));
+        let dur = Duration::new(15410, 0);
+
+        assert_eq!(None, t!(stream.keepalive()));
+        t!(stream.set_keepalive(Some(dur)));
+        assert_eq!(Some(dur), t!(stream.keepalive()));
+        t!(stream.set_keepalive(None));
+        assert_eq!(None, t!(stream.keepalive()));
+    }
+
+    #[test]
+    fn ttl() {
+        let ttl = 100;
+
+        let addr = next_test_ip4();
+        let listener = t!(TcpListener::bind(&addr));
+
+        t!(listener.set_ttl(ttl));
+        assert_eq!(ttl, t!(listener.ttl()));
+
+        let stream = t!(TcpStream::connect(&("localhost", addr.port())));
+
+        t!(stream.set_ttl(ttl));
+        assert_eq!(ttl, t!(stream.ttl()));
+    }
+
+    #[test]
+    fn set_nonblocking() {
+        let addr = next_test_ip4();
+        let listener = t!(TcpListener::bind(&addr));
+
+        t!(listener.set_nonblocking(true));
+        t!(listener.set_nonblocking(false));
+
+        let stream = t!(TcpStream::connect(&("localhost", addr.port())));
+
+        t!(stream.set_nonblocking(true));
+        t!(stream.set_nonblocking(false));
+    }
 }
index 1cb9303a9fc4b5a984a1795942c4fff6d32ac63f..0ac4056de8e05b67c026466269ce733b77c67c32 100644 (file)
@@ -228,6 +228,54 @@ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
     pub fn duplicate(&self) -> io::Result<TcpStream> {
         self.inner.duplicate().map(|s| TcpStream { inner: s })
     }
+
+    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
+        setsockopt(&self.inner, c::IPPROTO_TCP, c::TCP_NODELAY, nodelay as c_int)
+    }
+
+    pub fn nodelay(&self) -> io::Result<bool> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_TCP, c::TCP_NODELAY));
+        Ok(raw != 0)
+    }
+
+    pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
+        self.inner.set_keepalive(keepalive)
+    }
+
+    pub fn keepalive(&self) -> io::Result<Option<Duration>> {
+        self.inner.keepalive()
+    }
+
+    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
+        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int)
+    }
+
+    pub fn ttl(&self) -> io::Result<u32> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL));
+        Ok(raw as u32)
+    }
+
+    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
+        setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int)
+    }
+
+    pub fn only_v6(&self) -> io::Result<bool> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY));
+        Ok(raw != 0)
+    }
+
+    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR));
+        if raw == 0 {
+            Ok(None)
+        } else {
+            Ok(Some(io::Error::from_raw_os_error(raw as i32)))
+        }
+    }
+
+    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
+        self.inner.set_nonblocking(nonblocking)
+    }
 }
 
 impl FromInner<Socket> for TcpStream {
@@ -307,6 +355,37 @@ pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
     pub fn duplicate(&self) -> io::Result<TcpListener> {
         self.inner.duplicate().map(|s| TcpListener { inner: s })
     }
+
+    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
+        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int)
+    }
+
+    pub fn ttl(&self) -> io::Result<u32> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL));
+        Ok(raw as u32)
+    }
+
+    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
+        setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int)
+    }
+
+    pub fn only_v6(&self) -> io::Result<bool> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY));
+        Ok(raw != 0)
+    }
+
+    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
+        let raw: c_int = try!(getsockopt(&self.inner, c::SOL_SOCKET, c::SO_ERROR));
+        if raw == 0 {
+            Ok(None)
+        } else {
+            Ok(Some(io::Error::from_raw_os_error(raw as i32)))
+        }
+    }
+
+    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
+        self.inner.set_nonblocking(nonblocking)
+    }
 }
 
 impl FromInner<Socket> for TcpListener {
index 16c369674f0a2e2bf90f8f80952135480fff1328..7a2ac7257afbcf084dc7b0e7691852dc7820619d 100644 (file)
 #[cfg(not(target_os = "linux"))]
 const SOCK_CLOEXEC: c_int = 0;
 
+#[cfg(any(target_os = "openbsd", taret_os = "freebsd"))]
+use libc::SO_KEEPALIVE as TCP_KEEPALIVE;
+#[cfg(any(target_os = "macos", taret_os = "ios"))]
+use libc::TCP_KEEPALIVE;
+#[cfg(not(any(target_os = "openbsd",
+              target_os = "freebsd",
+              target_os = "macos",
+              target_os = "ios")))]
+use libc::TCP_KEEPIDLE as TCP_KEEPALIVE;
+
 pub struct Socket(FileDesc);
 
 pub fn init() {}
@@ -168,6 +178,44 @@ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
         try!(cvt(unsafe { libc::shutdown(self.0.raw(), how) }));
         Ok(())
     }
+
+    pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
+        try!(setsockopt(self,
+                        libc::SOL_SOCKET,
+                        libc::SO_KEEPALIVE,
+                        keepalive.is_some() as libc::c_int));
+        if let Some(dur) = keepalive {
+            let mut raw = dur.as_secs();
+            if dur.subsec_nanos() > 0 {
+                raw = raw.saturating_add(1);
+            }
+
+            let raw = if raw > libc::c_int::max_value() as u64 {
+                libc::c_int::max_value()
+            } else {
+                raw as libc::c_int
+            };
+
+            try!(setsockopt(self, libc::IPPROTO_TCP, TCP_KEEPALIVE, raw));
+        }
+
+        Ok(())
+    }
+
+    pub fn keepalive(&self) -> io::Result<Option<Duration>> {
+        let raw: c_int = try!(getsockopt(self, libc::SOL_SOCKET, libc::SO_KEEPALIVE));
+        if raw == 0 {
+            return Ok(None);
+        }
+
+        let raw: c_int = try!(getsockopt(self, libc::IPPROTO_TCP, TCP_KEEPALIVE));
+        Ok(Some(Duration::from_secs(raw as u64)))
+    }
+
+    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
+        let mut nonblocking = nonblocking as libc::c_ulong;
+        cvt(unsafe { libc::ioctl(*self.as_inner(), libc::FIONBIO, &mut nonblocking) }).map(|_| ())
+    }
 }
 
 impl AsInner<c_int> for Socket {
index 5cbfec01bedaa897f1885bf7b596bda55b663329..cc420763fd787306ad27383865ad1ee6102cf580 100644 (file)
 pub type socklen_t = c_int;
 pub type ADDRESS_FAMILY = USHORT;
 
+pub type LPWSAOVERLAPPED_COMPLETION_ROUTINE =
+    Option<unsafe extern "system" fn(dwError: DWORD,
+                                     cbTransferred: DWORD,
+                                     lpOverlapped: LPWSAOVERLAPPED,
+                                     dwFlags: DWORD)>;
+pub type LPWSAOVERLAPPED = *mut OVERLAPPED;
+
 pub const TRUE: BOOL = 1;
 pub const FALSE: BOOL = 0;
 
 pub const FILE_FLAG_BACKUP_SEMANTICS: DWORD = 0x02000000;
 pub const SECURITY_SQOS_PRESENT: DWORD = 0x00100000;
 
+pub const SIO_KEEPALIVE_VALS: DWORD = 0x98000004;
+pub const FIONBIO: c_ulong = 0x8004667e;
+
 #[repr(C)]
 #[derive(Copy)]
 pub struct WIN32_FIND_DATAW {
@@ -775,6 +785,13 @@ pub struct in6_addr {
     pub s6_addr: [u8; 16],
 }
 
+#[repr(C)]
+pub struct tcp_keepalive {
+    pub onoff: c_ulong,
+    pub keepalivetime: c_ulong,
+    pub keepaliveinterval: c_ulong,
+}
+
 #[cfg(all(target_arch = "x86_64", target_env = "gnu"))]
 pub enum UNWIND_HISTORY_TABLE {}
 
@@ -833,6 +850,17 @@ pub fn WSASocketW(af: c_int,
                       lpProtocolInfo: LPWSAPROTOCOL_INFO,
                       g: GROUP,
                       dwFlags: DWORD) -> SOCKET;
+    pub fn WSAIoctl(s: SOCKET,
+                    dwIoControlCode: DWORD,
+                    lpvInBuffer: LPVOID,
+                    cbInBuffer: DWORD,
+                    lpvOutBuffer: LPVOID,
+                    cbOutBuffer: DWORD,
+                    lpcbBytesReturned: LPDWORD,
+                    lpOverlapped: LPWSAOVERLAPPED,
+                    lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE)
+                    -> c_int;
+    pub fn ioctlsocket(s: SOCKET, cmd: c_long, argp: *mut u_long) -> c_int;
     pub fn InitializeCriticalSection(CriticalSection: *mut CRITICAL_SECTION);
     pub fn EnterCriticalSection(CriticalSection: *mut CRITICAL_SECTION);
     pub fn TryEnterCriticalSection(CriticalSection: *mut CRITICAL_SECTION) -> BOOLEAN;
index 49ba8e9c65990570c4192be868b7500cb8257b3c..be13657aaf487be1eb4e4af3d80d1795ee46b836 100644 (file)
@@ -185,6 +185,60 @@ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
         try!(cvt(unsafe { c::shutdown(self.0, how) }));
         Ok(())
     }
+
+    pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
+        let ms = keepalive.map(sys::dur2timeout).unwrap_or(c::INFINITE);
+        let ka = c::tcp_keepalive {
+            onoff: keepalive.is_some() as c::c_ulong,
+            keepalivetime: ms as c::c_ulong,
+            keepaliveinterval: ms as c::c_ulong,
+        };
+        sys::cvt(unsafe {
+            c::WSAIoctl(self.0,
+                        c::SIO_KEEPALIVE_VALS,
+                        &ka as *const _ as *mut _,
+                        mem::size_of_val(&ka) as c::DWORD,
+                        0 as *mut _,
+                        0,
+                        0 as *mut _,
+                        0 as *mut _,
+                        None)
+        }).map(|_| ())
+    }
+
+    pub fn keepalive(&self) -> io::Result<Option<Duration>> {
+        let mut ka = c::tcp_keepalive {
+            onoff: 0,
+            keepalivetime: 0,
+            keepaliveinterval: 0,
+        };
+        try!(sys::cvt(unsafe {
+            WSAIoctl(self.0,
+                     c::SIO_KEEPALIVE_VALS,
+                     0 as *mut _,
+                     0,
+                     &mut ka as *mut _ as *mut _,
+                     mem::size_of_val(&ka) as c::DWORD,
+                     0 as *mut _,
+                     0 as *mut _,
+                     None)
+        }));
+
+        if ka.onoff == 0 {
+            Ok(None)
+        } else {
+            let secs = ka.keepaliveinterval / 1000;
+            let nsec = (ka.keepaliveinterval % 1000) * 1000000;
+            Ok(Some(Duration::new(secs as u64, nsec as u32)))
+        }
+    }
+
+    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
+        let mut nonblocking = nonblocking as c::c_ulong;
+        sys::cvt(unsafe {
+            c::ioctlsocket(self.0, c::FIONBIO as c::c_int, &mut nonblocking)
+        }).map(|_| ())
+    }
 }
 
 impl Drop for Socket {