]> git.lizzy.rs Git - rust.git/blob - library/std/src/sys/windows/net.rs
Rollup merge of #88201 - spastorino:tait-incomplete-inference-test, r=oli-obk
[rust.git] / library / std / src / sys / windows / net.rs
1 #![unstable(issue = "none", feature = "windows_net")]
2
3 use crate::cmp;
4 use crate::io::{self, IoSlice, IoSliceMut, Read};
5 use crate::mem;
6 use crate::net::{Shutdown, SocketAddr};
7 use crate::os::windows::io::{
8     AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
9 };
10 use crate::ptr;
11 use crate::sync::Once;
12 use crate::sys;
13 use crate::sys::c;
14 use crate::sys_common::net;
15 use crate::sys_common::{AsInner, FromInner, IntoInner};
16 use crate::time::Duration;
17
18 use libc::{c_int, c_long, c_ulong};
19
20 pub type wrlen_t = i32;
21
22 pub mod netc {
23     pub use crate::sys::c::ADDRESS_FAMILY as sa_family_t;
24     pub use crate::sys::c::ADDRINFOA as addrinfo;
25     pub use crate::sys::c::SOCKADDR as sockaddr;
26     pub use crate::sys::c::SOCKADDR_STORAGE_LH as sockaddr_storage;
27     pub use crate::sys::c::*;
28 }
29
30 pub struct Socket(OwnedSocket);
31
32 static INIT: Once = Once::new();
33
34 /// Checks whether the Windows socket interface has been started already, and
35 /// if not, starts it.
36 pub fn init() {
37     INIT.call_once(|| unsafe {
38         let mut data: c::WSADATA = mem::zeroed();
39         let ret = c::WSAStartup(
40             0x202, // version 2.2
41             &mut data,
42         );
43         assert_eq!(ret, 0);
44     });
45 }
46
47 pub fn cleanup() {
48     if INIT.is_completed() {
49         // only close the socket interface if it has actually been started
50         unsafe {
51             c::WSACleanup();
52         }
53     }
54 }
55
56 /// Returns the last error from the Windows socket interface.
57 fn last_error() -> io::Error {
58     io::Error::from_raw_os_error(unsafe { c::WSAGetLastError() })
59 }
60
61 #[doc(hidden)]
62 pub trait IsMinusOne {
63     fn is_minus_one(&self) -> bool;
64 }
65
66 macro_rules! impl_is_minus_one {
67     ($($t:ident)*) => ($(impl IsMinusOne for $t {
68         fn is_minus_one(&self) -> bool {
69             *self == -1
70         }
71     })*)
72 }
73
74 impl_is_minus_one! { i8 i16 i32 i64 isize }
75
76 /// Checks if the signed integer is the Windows constant `SOCKET_ERROR` (-1)
77 /// and if so, returns the last error from the Windows socket interface. This
78 /// function must be called before another call to the socket API is made.
79 pub fn cvt<T: IsMinusOne>(t: T) -> io::Result<T> {
80     if t.is_minus_one() { Err(last_error()) } else { Ok(t) }
81 }
82
83 /// A variant of `cvt` for `getaddrinfo` which return 0 for a success.
84 pub fn cvt_gai(err: c_int) -> io::Result<()> {
85     if err == 0 { Ok(()) } else { Err(last_error()) }
86 }
87
88 /// Just to provide the same interface as sys/unix/net.rs
89 pub fn cvt_r<T, F>(mut f: F) -> io::Result<T>
90 where
91     T: IsMinusOne,
92     F: FnMut() -> T,
93 {
94     cvt(f())
95 }
96
97 impl Socket {
98     pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
99         let family = match *addr {
100             SocketAddr::V4(..) => c::AF_INET,
101             SocketAddr::V6(..) => c::AF_INET6,
102         };
103         let socket = unsafe {
104             c::WSASocketW(
105                 family,
106                 ty,
107                 0,
108                 ptr::null_mut(),
109                 0,
110                 c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT,
111             )
112         };
113
114         if socket != c::INVALID_SOCKET {
115             unsafe { Ok(Self::from_raw_socket(socket)) }
116         } else {
117             let error = unsafe { c::WSAGetLastError() };
118
119             if error != c::WSAEPROTOTYPE && error != c::WSAEINVAL {
120                 return Err(io::Error::from_raw_os_error(error));
121             }
122
123             let socket =
124                 unsafe { c::WSASocketW(family, ty, 0, ptr::null_mut(), 0, c::WSA_FLAG_OVERLAPPED) };
125
126             if socket == c::INVALID_SOCKET {
127                 return Err(last_error());
128             }
129
130             unsafe {
131                 let socket = Self::from_raw_socket(socket);
132                 socket.set_no_inherit()?;
133                 Ok(socket)
134             }
135         }
136     }
137
138     pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
139         self.set_nonblocking(true)?;
140         let result = {
141             let (addrp, len) = addr.into_inner();
142             let result = unsafe { c::connect(self.as_raw_socket(), addrp, len) };
143             cvt(result).map(drop)
144         };
145         self.set_nonblocking(false)?;
146
147         match result {
148             Err(ref error) if error.kind() == io::ErrorKind::WouldBlock => {
149                 if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
150                     return Err(io::Error::new_const(
151                         io::ErrorKind::InvalidInput,
152                         &"cannot set a 0 duration timeout",
153                     ));
154                 }
155
156                 let mut timeout = c::timeval {
157                     tv_sec: timeout.as_secs() as c_long,
158                     tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
159                 };
160
161                 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
162                     timeout.tv_usec = 1;
163                 }
164
165                 let fds = {
166                     let mut fds = unsafe { mem::zeroed::<c::fd_set>() };
167                     fds.fd_count = 1;
168                     fds.fd_array[0] = self.as_raw_socket();
169                     fds
170                 };
171
172                 let mut writefds = fds;
173                 let mut errorfds = fds;
174
175                 let count = {
176                     let result = unsafe {
177                         c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout)
178                     };
179                     cvt(result)?
180                 };
181
182                 match count {
183                     0 => {
184                         Err(io::Error::new_const(io::ErrorKind::TimedOut, &"connection timed out"))
185                     }
186                     _ => {
187                         if writefds.fd_count != 1 {
188                             if let Some(e) = self.take_error()? {
189                                 return Err(e);
190                             }
191                         }
192
193                         Ok(())
194                     }
195                 }
196             }
197             _ => result,
198         }
199     }
200
201     pub fn accept(&self, storage: *mut c::SOCKADDR, len: *mut c_int) -> io::Result<Socket> {
202         let socket = unsafe { c::accept(self.as_raw_socket(), storage, len) };
203
204         match socket {
205             c::INVALID_SOCKET => Err(last_error()),
206             _ => unsafe { Ok(Self::from_raw_socket(socket)) },
207         }
208     }
209
210     pub fn duplicate(&self) -> io::Result<Socket> {
211         let mut info = unsafe { mem::zeroed::<c::WSAPROTOCOL_INFO>() };
212         let result = unsafe {
213             c::WSADuplicateSocketW(self.as_raw_socket(), c::GetCurrentProcessId(), &mut info)
214         };
215         cvt(result)?;
216         let socket = unsafe {
217             c::WSASocketW(
218                 info.iAddressFamily,
219                 info.iSocketType,
220                 info.iProtocol,
221                 &mut info,
222                 0,
223                 c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT,
224             )
225         };
226
227         if socket != c::INVALID_SOCKET {
228             unsafe { Ok(Self::from_inner(OwnedSocket::from_raw_socket(socket))) }
229         } else {
230             let error = unsafe { c::WSAGetLastError() };
231
232             if error != c::WSAEPROTOTYPE && error != c::WSAEINVAL {
233                 return Err(io::Error::from_raw_os_error(error));
234             }
235
236             let socket = unsafe {
237                 c::WSASocketW(
238                     info.iAddressFamily,
239                     info.iSocketType,
240                     info.iProtocol,
241                     &mut info,
242                     0,
243                     c::WSA_FLAG_OVERLAPPED,
244                 )
245             };
246
247             if socket == c::INVALID_SOCKET {
248                 return Err(last_error());
249             }
250
251             unsafe {
252                 let socket = Self::from_inner(OwnedSocket::from_raw_socket(socket));
253                 socket.set_no_inherit()?;
254                 Ok(socket)
255             }
256         }
257     }
258
259     fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
260         // On unix when a socket is shut down all further reads return 0, so we
261         // do the same on windows to map a shut down socket to returning EOF.
262         let length = cmp::min(buf.len(), i32::MAX as usize) as i32;
263         let result =
264             unsafe { c::recv(self.as_raw_socket(), buf.as_mut_ptr() as *mut _, length, flags) };
265
266         match result {
267             c::SOCKET_ERROR => {
268                 let error = unsafe { c::WSAGetLastError() };
269
270                 if error == c::WSAESHUTDOWN {
271                     Ok(0)
272                 } else {
273                     Err(io::Error::from_raw_os_error(error))
274                 }
275             }
276             _ => Ok(result as usize),
277         }
278     }
279
280     pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
281         self.recv_with_flags(buf, 0)
282     }
283
284     pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
285         // On unix when a socket is shut down all further reads return 0, so we
286         // do the same on windows to map a shut down socket to returning EOF.
287         let length = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD;
288         let mut nread = 0;
289         let mut flags = 0;
290         let result = unsafe {
291             c::WSARecv(
292                 self.as_raw_socket(),
293                 bufs.as_mut_ptr() as *mut c::WSABUF,
294                 length,
295                 &mut nread,
296                 &mut flags,
297                 ptr::null_mut(),
298                 ptr::null_mut(),
299             )
300         };
301
302         match result {
303             0 => Ok(nread as usize),
304             _ => {
305                 let error = unsafe { c::WSAGetLastError() };
306
307                 if error == c::WSAESHUTDOWN {
308                     Ok(0)
309                 } else {
310                     Err(io::Error::from_raw_os_error(error))
311                 }
312             }
313         }
314     }
315
316     #[inline]
317     pub fn is_read_vectored(&self) -> bool {
318         true
319     }
320
321     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
322         self.recv_with_flags(buf, c::MSG_PEEK)
323     }
324
325     fn recv_from_with_flags(
326         &self,
327         buf: &mut [u8],
328         flags: c_int,
329     ) -> io::Result<(usize, SocketAddr)> {
330         let mut storage = unsafe { mem::zeroed::<c::SOCKADDR_STORAGE_LH>() };
331         let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
332         let length = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t;
333
334         // On unix when a socket is shut down all further reads return 0, so we
335         // do the same on windows to map a shut down socket to returning EOF.
336         let result = unsafe {
337             c::recvfrom(
338                 self.as_raw_socket(),
339                 buf.as_mut_ptr() as *mut _,
340                 length,
341                 flags,
342                 &mut storage as *mut _ as *mut _,
343                 &mut addrlen,
344             )
345         };
346
347         match result {
348             c::SOCKET_ERROR => {
349                 let error = unsafe { c::WSAGetLastError() };
350
351                 if error == c::WSAESHUTDOWN {
352                     Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
353                 } else {
354                     Err(io::Error::from_raw_os_error(error))
355                 }
356             }
357             _ => Ok((result as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
358         }
359     }
360
361     pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
362         self.recv_from_with_flags(buf, 0)
363     }
364
365     pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
366         self.recv_from_with_flags(buf, c::MSG_PEEK)
367     }
368
369     pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
370         let length = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD;
371         let mut nwritten = 0;
372         let result = unsafe {
373             c::WSASend(
374                 self.as_raw_socket(),
375                 bufs.as_ptr() as *const c::WSABUF as *mut _,
376                 length,
377                 &mut nwritten,
378                 0,
379                 ptr::null_mut(),
380                 ptr::null_mut(),
381             )
382         };
383         cvt(result).map(|_| nwritten as usize)
384     }
385
386     #[inline]
387     pub fn is_write_vectored(&self) -> bool {
388         true
389     }
390
391     pub fn set_timeout(&self, dur: Option<Duration>, kind: c_int) -> io::Result<()> {
392         let timeout = match dur {
393             Some(dur) => {
394                 let timeout = sys::dur2timeout(dur);
395                 if timeout == 0 {
396                     return Err(io::Error::new_const(
397                         io::ErrorKind::InvalidInput,
398                         &"cannot set a 0 duration timeout",
399                     ));
400                 }
401                 timeout
402             }
403             None => 0,
404         };
405         net::setsockopt(self, c::SOL_SOCKET, kind, timeout)
406     }
407
408     pub fn timeout(&self, kind: c_int) -> io::Result<Option<Duration>> {
409         let raw: c::DWORD = net::getsockopt(self, c::SOL_SOCKET, kind)?;
410         if raw == 0 {
411             Ok(None)
412         } else {
413             let secs = raw / 1000;
414             let nsec = (raw % 1000) * 1000000;
415             Ok(Some(Duration::new(secs as u64, nsec as u32)))
416         }
417     }
418
419     #[cfg(not(target_vendor = "uwp"))]
420     fn set_no_inherit(&self) -> io::Result<()> {
421         sys::cvt(unsafe {
422             c::SetHandleInformation(self.as_raw_socket() as c::HANDLE, c::HANDLE_FLAG_INHERIT, 0)
423         })
424         .map(drop)
425     }
426
427     #[cfg(target_vendor = "uwp")]
428     fn set_no_inherit(&self) -> io::Result<()> {
429         Err(io::Error::new_const(io::ErrorKind::Unsupported, &"Unavailable on UWP"))
430     }
431
432     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
433         let how = match how {
434             Shutdown::Write => c::SD_SEND,
435             Shutdown::Read => c::SD_RECEIVE,
436             Shutdown::Both => c::SD_BOTH,
437         };
438         let result = unsafe { c::shutdown(self.as_raw_socket(), how) };
439         cvt(result).map(drop)
440     }
441
442     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
443         let mut nonblocking = nonblocking as c_ulong;
444         let result =
445             unsafe { c::ioctlsocket(self.as_raw_socket(), c::FIONBIO as c_int, &mut nonblocking) };
446         cvt(result).map(drop)
447     }
448
449     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
450         net::setsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY, nodelay as c::BYTE)
451     }
452
453     pub fn nodelay(&self) -> io::Result<bool> {
454         let raw: c::BYTE = net::getsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY)?;
455         Ok(raw != 0)
456     }
457
458     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
459         let raw: c_int = net::getsockopt(self, c::SOL_SOCKET, c::SO_ERROR)?;
460         if raw == 0 { Ok(None) } else { Ok(Some(io::Error::from_raw_os_error(raw as i32))) }
461     }
462
463     // This is used by sys_common code to abstract over Windows and Unix.
464     pub fn as_raw(&self) -> RawSocket {
465         self.as_inner().as_raw_socket()
466     }
467 }
468
469 #[unstable(reason = "not public", issue = "none", feature = "fd_read")]
470 impl<'a> Read for &'a Socket {
471     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
472         (**self).read(buf)
473     }
474 }
475
476 impl AsInner<OwnedSocket> for Socket {
477     fn as_inner(&self) -> &OwnedSocket {
478         &self.0
479     }
480 }
481
482 impl FromInner<OwnedSocket> for Socket {
483     fn from_inner(sock: OwnedSocket) -> Socket {
484         Socket(sock)
485     }
486 }
487
488 impl IntoInner<OwnedSocket> for Socket {
489     fn into_inner(self) -> OwnedSocket {
490         self.0
491     }
492 }
493
494 impl AsSocket for Socket {
495     fn as_socket(&self) -> BorrowedSocket<'_> {
496         self.0.as_socket()
497     }
498 }
499
500 impl AsRawSocket for Socket {
501     fn as_raw_socket(&self) -> RawSocket {
502         self.0.as_raw_socket()
503     }
504 }
505
506 impl IntoRawSocket for Socket {
507     fn into_raw_socket(self) -> RawSocket {
508         self.0.into_raw_socket()
509     }
510 }
511
512 impl FromRawSocket for Socket {
513     unsafe fn from_raw_socket(raw_socket: RawSocket) -> Self {
514         Self(FromRawSocket::from_raw_socket(raw_socket))
515     }
516 }