]> git.lizzy.rs Git - rust.git/commitdiff
native: clone/close_accept for win32 pipes
authorAlex Crichton <alex@alexcrichton.com>
Tue, 15 Jul 2014 19:42:40 +0000 (12:42 -0700)
committerAlex Crichton <alex@alexcrichton.com>
Mon, 25 Aug 2014 00:08:14 +0000 (17:08 -0700)
This commits takes a similar strategy to the previous commit to implement
close_accept and clone for the native win32 pipes implementation.

Closes #15595

src/libnative/io/c_windows.rs
src/libnative/io/net.rs
src/libnative/io/pipe_unix.rs
src/libnative/io/pipe_windows.rs
src/libnative/io/util.rs
src/librustuv/net.rs
src/librustuv/pipe.rs
src/libstd/io/net/tcp.rs
src/libstd/io/net/unix.rs
src/test/run-pass/tcp-accept-stress.rs

index 3bd850b5aac7fcc2daeb8d2c8cd01dfcf90ce3f0..909b37895b7b5495036f46c027fa8ca0d8a7ccc8 100644 (file)
@@ -115,6 +115,12 @@ pub fn getsockopt(sockfd: libc::SOCKET,
                       optval: *mut libc::c_char,
                       optlen: *mut libc::c_int) -> libc::c_int;
 
+    pub fn SetEvent(hEvent: libc::HANDLE) -> libc::BOOL;
+    pub fn WaitForMultipleObjects(nCount: libc::DWORD,
+                                  lpHandles: *const libc::HANDLE,
+                                  bWaitAll: libc::BOOL,
+                                  dwMilliseconds: libc::DWORD) -> libc::DWORD;
+
     pub fn CancelIo(hFile: libc::HANDLE) -> libc::BOOL;
     pub fn CancelIoEx(hFile: libc::HANDLE,
                       lpOverlapped: libc::LPOVERLAPPED) -> libc::BOOL;
index bbfc8aff1b74175721c729cb43e2ce1e859a93e9..368b5914444ac9509f12fca2e28f23258b674cca 100644 (file)
@@ -15,7 +15,7 @@
 use std::rt::mutex;
 use std::rt::rtio;
 use std::rt::rtio::{IoResult, IoError};
-use std::sync::atomics;
+use std::sync::atomic;
 
 use super::{retry, keep_going};
 use super::c;
@@ -456,7 +456,7 @@ pub fn native_listen(self, backlog: int) -> IoResult<TcpAcceptor> {
                         listener: self,
                         reader: reader,
                         writer: writer,
-                        closed: atomics::AtomicBool::new(false),
+                        closed: atomic::AtomicBool::new(false),
                     }),
                     deadline: 0,
                 })
@@ -476,7 +476,7 @@ pub fn native_listen(self, backlog: int) -> IoResult<TcpAcceptor> {
                         listener: self,
                         abort: try!(os::Event::new()),
                         accept: accept,
-                        closed: atomics::AtomicBool::new(false),
+                        closed: atomic::AtomicBool::new(false),
                     }),
                     deadline: 0,
                 })
@@ -510,7 +510,7 @@ struct AcceptorInner {
     listener: TcpListener,
     reader: FileDesc,
     writer: FileDesc,
-    closed: atomics::AtomicBool,
+    closed: atomic::AtomicBool,
 }
 
 #[cfg(windows)]
@@ -518,7 +518,7 @@ struct AcceptorInner {
     listener: TcpListener,
     abort: os::Event,
     accept: os::Event,
-    closed: atomics::AtomicBool,
+    closed: atomic::AtomicBool,
 }
 
 impl TcpAcceptor {
@@ -542,7 +542,7 @@ pub fn native_accept(&mut self) -> IoResult<TcpStream> {
         // self-pipe is never written to unless close_accept() is called.
         let deadline = if self.deadline == 0 {None} else {Some(self.deadline)};
 
-        while !self.inner.closed.load(atomics::SeqCst) {
+        while !self.inner.closed.load(atomic::SeqCst) {
             match retry(|| unsafe {
                 libc::accept(self.fd(), ptr::mut_null(), ptr::mut_null())
             }) {
@@ -581,12 +581,12 @@ pub fn native_accept(&mut self) -> IoResult<TcpStream> {
         // stolen, so we do all of this in a loop as well.
         let events = [self.inner.abort.handle(), self.inner.accept.handle()];
 
-        while !self.inner.closed.load(atomics::SeqCst) {
+        while !self.inner.closed.load(atomic::SeqCst) {
             let ms = if self.deadline == 0 {
                 c::WSA_INFINITE as u64
             } else {
                 let now = ::io::timer::now();
-                if self.deadline < now {0} else {now - self.deadline}
+                if self.deadline < now {0} else {self.deadline - now}
             };
             let ret = unsafe {
                 c::WSAWaitForMultipleEvents(2, events.as_ptr(), libc::FALSE,
@@ -600,7 +600,6 @@ pub fn native_accept(&mut self) -> IoResult<TcpStream> {
                 c::WSA_WAIT_EVENT_0 => break,
                 n => assert_eq!(n, c::WSA_WAIT_EVENT_0 + 1),
             }
-            println!("woke up");
 
             let mut wsaevents: c::WSANETWORKEVENTS = unsafe { mem::zeroed() };
             let ret = unsafe {
@@ -614,7 +613,19 @@ pub fn native_accept(&mut self) -> IoResult<TcpStream> {
             } {
                 -1 if util::wouldblock() => {}
                 -1 => return Err(os::last_error()),
-                fd => return Ok(TcpStream::new(Inner::new(fd))),
+
+                // Accepted sockets inherit the same properties as the caller,
+                // so we need to deregister our event and switch the socket back
+                // to blocking mode
+                fd => {
+                    let stream = TcpStream::new(Inner::new(fd));
+                    let ret = unsafe {
+                        c::WSAEventSelect(fd, events[1], 0)
+                    };
+                    if ret != 0 { return Err(os::last_error()) }
+                    try!(util::set_nonblocking(fd, false));
+                    return Ok(stream)
+                }
             }
         }
 
@@ -648,7 +659,7 @@ fn clone(&self) -> Box<rtio::RtioTcpAcceptor + Send> {
 
     #[cfg(unix)]
     fn close_accept(&mut self) -> IoResult<()> {
-        self.inner.closed.store(true, atomics::SeqCst);
+        self.inner.closed.store(true, atomic::SeqCst);
         let mut fd = FileDesc::new(self.inner.writer.fd(), false);
         match fd.inner_write([0]) {
             Ok(..) => Ok(()),
@@ -659,7 +670,7 @@ fn close_accept(&mut self) -> IoResult<()> {
 
     #[cfg(windows)]
     fn close_accept(&mut self) -> IoResult<()> {
-        self.inner.closed.store(true, atomics::SeqCst);
+        self.inner.closed.store(true, atomic::SeqCst);
         let ret = unsafe { c::WSASetEvent(self.inner.abort.handle()) };
         if ret == libc::TRUE {
             Ok(())
index 4ad8383e6f8052ae10ba9a36a328928f960e5622..a3564dfe2cc9a5905cdc9b6464787a0f8102e5da 100644 (file)
@@ -15,7 +15,7 @@
 use std::rt::mutex;
 use std::rt::rtio;
 use std::rt::rtio::{IoResult, IoError};
-use std::sync::atomics;
+use std::sync::atomic;
 
 use super::retry;
 use super::net;
@@ -239,7 +239,7 @@ pub fn native_listen(self, backlog: int) -> IoResult<UnixAcceptor> {
                         listener: self,
                         reader: reader,
                         writer: writer,
-                        closed: atomics::AtomicBool::new(false),
+                        closed: atomic::AtomicBool::new(false),
                     }),
                     deadline: 0,
                 })
@@ -267,7 +267,7 @@ struct AcceptorInner {
     listener: UnixListener,
     reader: FileDesc,
     writer: FileDesc,
-    closed: atomics::AtomicBool,
+    closed: atomic::AtomicBool,
 }
 
 impl UnixAcceptor {
@@ -276,7 +276,7 @@ fn fd(&self) -> fd_t { self.inner.listener.fd() }
     pub fn native_accept(&mut self) -> IoResult<UnixStream> {
         let deadline = if self.deadline == 0 {None} else {Some(self.deadline)};
 
-        while !self.inner.closed.load(atomics::SeqCst) {
+        while !self.inner.closed.load(atomic::SeqCst) {
             unsafe {
                 let mut storage: libc::sockaddr_storage = mem::zeroed();
                 let storagep = &mut storage as *mut libc::sockaddr_storage;
@@ -317,7 +317,7 @@ fn clone(&self) -> Box<rtio::RtioUnixAcceptor + Send> {
 
     #[cfg(unix)]
     fn close_accept(&mut self) -> IoResult<()> {
-        self.inner.closed.store(true, atomics::SeqCst);
+        self.inner.closed.store(true, atomic::SeqCst);
         let mut fd = FileDesc::new(self.inner.writer.fd(), false);
         match fd.inner_write([0]) {
             Ok(..) => Ok(()),
index 4d01230cbd9771295c5042d37a554a1cece1c395..95afa11f4a9a0c8f72712d1524dc2678a2a41089 100644 (file)
@@ -169,23 +169,30 @@ unsafe fn pipe(name: *const u16, init: bool) -> libc::HANDLE {
 }
 
 pub fn await(handle: libc::HANDLE, deadline: u64,
-             overlapped: &mut libc::OVERLAPPED) -> bool {
-    if deadline == 0 { return true }
+             events: &[libc::HANDLE]) -> IoResult<uint> {
+    use libc::consts::os::extra::{WAIT_FAILED, WAIT_TIMEOUT, WAIT_OBJECT_0};
 
     // If we've got a timeout, use WaitForSingleObject in tandem with CancelIo
     // to figure out if we should indeed get the result.
-    let now = ::io::timer::now();
-    let timeout = deadline < now || unsafe {
-        let ms = (deadline - now) as libc::DWORD;
-        let r = libc::WaitForSingleObject(overlapped.hEvent,
-                                          ms);
-        r != libc::WAIT_OBJECT_0
-    };
-    if timeout {
-        unsafe { let _ = c::CancelIo(handle); }
-        false
+    let ms = if deadline == 0 {
+        libc::INFINITE as u64
     } else {
-        true
+        let now = ::io::timer::now();
+        if deadline < now {0} else {deadline - now}
+    };
+    let ret = unsafe {
+        c::WaitForMultipleObjects(events.len() as libc::DWORD,
+                                  events.as_ptr(),
+                                  libc::FALSE,
+                                  ms as libc::DWORD)
+    };
+    match ret {
+        WAIT_FAILED => Err(super::last_error()),
+        WAIT_TIMEOUT => unsafe {
+            let _ = c::CancelIo(handle);
+            Err(util::timeout("operation timed out"))
+        },
+        n => Ok((n - WAIT_OBJECT_0) as uint)
     }
 }
 
@@ -390,8 +397,8 @@ fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
         drop(guard);
         loop {
             // Process a timeout if one is pending
-            let succeeded = await(self.handle(), self.read_deadline,
-                                  &mut overlapped);
+            let wait_succeeded = await(self.handle(), self.read_deadline,
+                                       [overlapped.hEvent]);
 
             let ret = unsafe {
                 libc::GetOverlappedResult(self.handle(),
@@ -408,7 +415,7 @@ fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
 
             // If the reading half is now closed, then we're done. If we woke up
             // because the writing half was closed, keep trying.
-            if !succeeded {
+            if wait_succeeded.is_err() {
                 return Err(util::timeout("read timed out"))
             }
             if self.read_closed() {
@@ -458,8 +465,8 @@ fn write(&mut self, buf: &[u8]) -> IoResult<()> {
                     })
                 }
                 // Process a timeout if one is pending
-                let succeeded = await(self.handle(), self.write_deadline,
-                                      &mut overlapped);
+                let wait_succeeded = await(self.handle(), self.write_deadline,
+                                           [overlapped.hEvent]);
                 let ret = unsafe {
                     libc::GetOverlappedResult(self.handle(),
                                               &mut overlapped,
@@ -473,7 +480,7 @@ fn write(&mut self, buf: &[u8]) -> IoResult<()> {
                     if os::errno() != libc::ERROR_OPERATION_ABORTED as uint {
                         return Err(super::last_error())
                     }
-                    if !succeeded {
+                    if !wait_succeeded.is_ok() {
                         let amt = offset + bytes_written as uint;
                         return if amt > 0 {
                             Err(IoError {
@@ -577,6 +584,10 @@ pub fn native_listen(self) -> IoResult<UnixAcceptor> {
             listener: self,
             event: try!(Event::new(true, false)),
             deadline: 0,
+            inner: Arc::new(AcceptorState {
+                abort: try!(Event::new(true, false)),
+                closed: atomic::AtomicBool::new(false),
+            }),
         })
     }
 }
@@ -597,11 +608,17 @@ fn listen(self: Box<UnixListener>)
 }
 
 pub struct UnixAcceptor {
+    inner: Arc<AcceptorState>,
     listener: UnixListener,
     event: Event,
     deadline: u64,
 }
 
+struct AcceptorState {
+    abort: Event,
+    closed: atomic::AtomicBool,
+}
+
 impl UnixAcceptor {
     pub fn native_accept(&mut self) -> IoResult<UnixStream> {
         // This function has some funky implementation details when working with
@@ -638,6 +655,10 @@ pub fn native_accept(&mut self) -> IoResult<UnixStream> {
         // using the original server pipe.
         let handle = self.listener.handle;
 
+        // If we've had an artifical call to close_accept, be sure to never
+        // proceed in accepting new clients in the future
+        if self.inner.closed.load(atomic::SeqCst) { return Err(util::eof()) }
+
         let name = try!(to_utf16(&self.listener.name));
 
         // Once we've got a "server handle", we need to wait for a client to
@@ -652,7 +673,9 @@ pub fn native_accept(&mut self) -> IoResult<UnixStream> {
 
             if err == libc::ERROR_IO_PENDING as libc::DWORD {
                 // Process a timeout if one is pending
-                let _ = await(handle, self.deadline, &mut overlapped);
+                let wait_succeeded = await(handle, self.deadline,
+                                           [self.inner.abort.handle(),
+                                            overlapped.hEvent]);
 
                 // This will block until the overlapped I/O is completed. The
                 // timeout was previously handled, so this will either block in
@@ -665,7 +688,11 @@ pub fn native_accept(&mut self) -> IoResult<UnixStream> {
                                               libc::TRUE)
                 };
                 if ret == 0 {
-                    err = unsafe { libc::GetLastError() };
+                    if wait_succeeded.is_ok() {
+                        err = unsafe { libc::GetLastError() };
+                    } else {
+                        return Err(util::timeout("accept timed out"))
+                    }
                 } else {
                     // we succeeded, bypass the check below
                     err = libc::ERROR_PIPE_CONNECTED as libc::DWORD;
@@ -711,11 +738,32 @@ fn set_timeout(&mut self, timeout: Option<u64>) {
     }
 
     fn clone(&self) -> Box<rtio::RtioUnixAcceptor + Send> {
-        fail!()
+        let name = to_utf16(&self.listener.name).ok().unwrap();
+        box UnixAcceptor {
+            inner: self.inner.clone(),
+            event: Event::new(true, false).ok().unwrap(),
+            deadline: 0,
+            listener: UnixListener {
+                name: self.listener.name.clone(),
+                handle: unsafe {
+                    let p = pipe(name.as_ptr(), false) ;
+                    assert!(p != libc::INVALID_HANDLE_VALUE as libc::HANDLE);
+                    p
+                },
+            },
+        } as Box<rtio::RtioUnixAcceptor + Send>
     }
 
     fn close_accept(&mut self) -> IoResult<()> {
-        fail!()
+        self.inner.closed.store(true, atomic::SeqCst);
+        let ret = unsafe {
+            c::SetEvent(self.inner.abort.handle())
+        };
+        if ret == 0 {
+            Err(super::last_error())
+        } else {
+            Ok(())
+        }
     }
 }
 
index c5b1bbec4f1631988026fa98d25f770906440905..078989b058180328acc784722f031cabc2fe7b66 100644 (file)
@@ -175,6 +175,9 @@ pub fn await(fds: &[net::sock_t], deadline: Option<u64>,
         c::fd_set(&mut set, fd);
         max = cmp::max(max, fd + 1);
     }
+    if cfg!(windows) {
+        max = fds.len() as net::sock_t;
+    }
 
     let (read, write) = match status {
         Readable => (&mut set as *mut _, ptr::mut_null()),
index b13598402470575f84cde7e8d90a1fd8d1eda283..84ef9deaf922fd3555dfed54e55a9a76bf049e2d 100644 (file)
@@ -387,7 +387,7 @@ fn socket_name(&mut self) -> Result<rtio::SocketAddr, IoError> {
 }
 
 impl rtio::RtioTcpListener for TcpListener {
-    fn listen(self: Box<TcpListener>)
+    fn listen(mut self: Box<TcpListener>)
               -> Result<Box<rtio::RtioTcpAcceptor + Send>, IoError> {
         let _m = self.fire_homing_missile();
 
index aa89e5e5f034e24cdafaa8758760a03241e7796b..9ece6525e1e82b07239e342a0e54a6811d4c76fa 100644 (file)
@@ -245,7 +245,7 @@ pub fn bind(io: &mut UvIoFactory, name: &CString)
 }
 
 impl rtio::RtioUnixListener for PipeListener {
-    fn listen(self: Box<PipeListener>)
+    fn listen(mut self: Box<PipeListener>)
               -> IoResult<Box<rtio::RtioUnixAcceptor + Send>> {
         let _m = self.fire_homing_missile();
 
index ebc3940c16f6997a1bcfc89115cb3b35b715cfae..a6fdceaa3739fc75d23c435bc4ef82e2e185b980 100644 (file)
@@ -461,8 +461,7 @@ impl TcpAcceptor {
     ///
     /// ```
     /// # #![allow(experimental)]
-    /// use std::io::TcpListener;
-    /// use std::io::{Listener, Acceptor, TimedOut};
+    /// use std::io::{TcpListener, Listener, Acceptor, EndOfFile};
     ///
     /// let mut a = TcpListener::bind("127.0.0.1", 8482).listen().unwrap();
     /// let a2 = a.clone();
index 74f024a844e2c5f8f94c42b03d1aa01d484b424d..3bd31c6a839edf2c7e89528b9703ba4885672886 100644 (file)
@@ -731,6 +731,7 @@ pub fn smalltest(server: proc(UnixStream):Send, client: proc(UnixStream):Send) {
         rx2.recv();
     })
 
+    #[cfg(not(windows))]
     iotest!(fn clone_accept_smoke() {
         let addr = next_test_unix();
         let l = UnixListener::bind(&addr);
@@ -746,6 +747,7 @@ pub fn smalltest(server: proc(UnixStream):Send, client: proc(UnixStream):Send) {
         });
 
         assert!(a.accept().is_ok());
+        drop(a);
         assert!(a2.accept().is_ok());
     })
 
index 3e420e45cfce6b33b7a0694e2bdbf52f312f6538..b8470ef7b8fac1f5484b011dffa52c11fd049b03 100644 (file)
@@ -15,7 +15,7 @@
 extern crate native;
 
 use std::io::{TcpListener, Listener, Acceptor, EndOfFile, TcpStream};
-use std::sync::{atomics, Arc};
+use std::sync::{atomic, Arc};
 use std::task::TaskBuilder;
 use native::NativeTaskBuilder;
 
@@ -38,7 +38,7 @@ fn test() {
     let mut l = TcpListener::bind("127.0.0.1", 0).unwrap();
     let addr = l.socket_name().unwrap();
     let mut a = l.listen().unwrap();
-    let cnt = Arc::new(atomics::AtomicUint::new(0));
+    let cnt = Arc::new(atomic::AtomicUint::new(0));
 
     let (tx, rx) = channel();
     for _ in range(0, N) {
@@ -52,7 +52,7 @@ fn test() {
                 match a.accept() {
                     Ok(..) => {
                         mycnt += 1;
-                        if cnt.fetch_add(1, atomics::SeqCst) == N * M - 1 {
+                        if cnt.fetch_add(1, atomic::SeqCst) == N * M - 1 {
                             break
                         }
                     }
@@ -89,6 +89,6 @@ fn test() {
     assert_eq!(rx.iter().take(N - 1).count(), N - 1);
 
     // Everything should have been accepted.
-    assert_eq!(cnt.load(atomics::SeqCst), N * M);
+    assert_eq!(cnt.load(atomic::SeqCst), N * M);
 }