]> git.lizzy.rs Git - rust.git/blob - library/std/src/sys/windows/net.rs
Move most init to `sys::init`
[rust.git] / library / std / src / sys / windows / net.rs
1 #![unstable(issue = "none", feature = "windows_net")]
2
3 use crate::cmp;
4 use crate::io::{self, IoSlice, IoSliceMut, Read};
5 use crate::mem;
6 use crate::net::{Shutdown, SocketAddr};
7 use crate::ptr;
8 use crate::sync::Once;
9 use crate::sys;
10 use crate::sys::c;
11 use crate::sys_common::net;
12 use crate::sys_common::{AsInner, FromInner, IntoInner};
13 use crate::time::Duration;
14
15 use libc::{c_int, c_long, c_ulong, c_void};
16
17 pub type wrlen_t = i32;
18
19 pub mod netc {
20     pub use crate::sys::c::ADDRESS_FAMILY as sa_family_t;
21     pub use crate::sys::c::ADDRINFOA as addrinfo;
22     pub use crate::sys::c::SOCKADDR as sockaddr;
23     pub use crate::sys::c::SOCKADDR_STORAGE_LH as sockaddr_storage;
24     pub use crate::sys::c::*;
25 }
26
27 pub struct Socket(c::SOCKET);
28
29 /// Checks whether the Windows socket interface has been started already, and
30 /// if not, starts it.
31 pub fn init() {
32     static START: Once = Once::new();
33
34     START.call_once(|| unsafe {
35         let mut data: c::WSADATA = mem::zeroed();
36         let ret = c::WSAStartup(
37             0x202, // version 2.2
38             &mut data,
39         );
40         assert_eq!(ret, 0);
41     });
42 }
43
44 pub fn cleanup() {
45     unsafe {
46         c::WSACleanup();
47     }
48 }
49
50 /// Returns the last error from the Windows socket interface.
51 fn last_error() -> io::Error {
52     io::Error::from_raw_os_error(unsafe { c::WSAGetLastError() })
53 }
54
55 #[doc(hidden)]
56 pub trait IsMinusOne {
57     fn is_minus_one(&self) -> bool;
58 }
59
60 macro_rules! impl_is_minus_one {
61     ($($t:ident)*) => ($(impl IsMinusOne for $t {
62         fn is_minus_one(&self) -> bool {
63             *self == -1
64         }
65     })*)
66 }
67
68 impl_is_minus_one! { i8 i16 i32 i64 isize }
69
70 /// Checks if the signed integer is the Windows constant `SOCKET_ERROR` (-1)
71 /// and if so, returns the last error from the Windows socket interface. This
72 /// function must be called before another call to the socket API is made.
73 pub fn cvt<T: IsMinusOne>(t: T) -> io::Result<T> {
74     if t.is_minus_one() { Err(last_error()) } else { Ok(t) }
75 }
76
77 /// A variant of `cvt` for `getaddrinfo` which return 0 for a success.
78 pub fn cvt_gai(err: c_int) -> io::Result<()> {
79     if err == 0 { Ok(()) } else { Err(last_error()) }
80 }
81
82 /// Just to provide the same interface as sys/unix/net.rs
83 pub fn cvt_r<T, F>(mut f: F) -> io::Result<T>
84 where
85     T: IsMinusOne,
86     F: FnMut() -> T,
87 {
88     cvt(f())
89 }
90
91 impl Socket {
92     pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
93         let fam = match *addr {
94             SocketAddr::V4(..) => c::AF_INET,
95             SocketAddr::V6(..) => c::AF_INET6,
96         };
97         let socket = unsafe {
98             match c::WSASocketW(
99                 fam,
100                 ty,
101                 0,
102                 ptr::null_mut(),
103                 0,
104                 c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT,
105             ) {
106                 c::INVALID_SOCKET => match c::WSAGetLastError() {
107                     c::WSAEPROTOTYPE | c::WSAEINVAL => {
108                         match c::WSASocketW(fam, ty, 0, ptr::null_mut(), 0, c::WSA_FLAG_OVERLAPPED)
109                         {
110                             c::INVALID_SOCKET => Err(last_error()),
111                             n => {
112                                 let s = Socket(n);
113                                 s.set_no_inherit()?;
114                                 Ok(s)
115                             }
116                         }
117                     }
118                     n => Err(io::Error::from_raw_os_error(n)),
119                 },
120                 n => Ok(Socket(n)),
121             }
122         }?;
123         Ok(socket)
124     }
125
126     pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
127         self.set_nonblocking(true)?;
128         let r = unsafe {
129             let (addrp, len) = addr.into_inner();
130             cvt(c::connect(self.0, addrp, len))
131         };
132         self.set_nonblocking(false)?;
133
134         match r {
135             Ok(_) => return Ok(()),
136             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
137             Err(e) => return Err(e),
138         }
139
140         if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
141             return Err(io::Error::new_const(
142                 io::ErrorKind::InvalidInput,
143                 &"cannot set a 0 duration timeout",
144             ));
145         }
146
147         let mut timeout = c::timeval {
148             tv_sec: timeout.as_secs() as c_long,
149             tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
150         };
151         if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
152             timeout.tv_usec = 1;
153         }
154
155         let fds = unsafe {
156             let mut fds = mem::zeroed::<c::fd_set>();
157             fds.fd_count = 1;
158             fds.fd_array[0] = self.0;
159             fds
160         };
161
162         let mut writefds = fds;
163         let mut errorfds = fds;
164
165         let n =
166             unsafe { cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))? };
167
168         match n {
169             0 => Err(io::Error::new_const(io::ErrorKind::TimedOut, &"connection timed out")),
170             _ => {
171                 if writefds.fd_count != 1 {
172                     if let Some(e) = self.take_error()? {
173                         return Err(e);
174                     }
175                 }
176                 Ok(())
177             }
178         }
179     }
180
181     pub fn accept(&self, storage: *mut c::SOCKADDR, len: *mut c_int) -> io::Result<Socket> {
182         let socket = unsafe {
183             match c::accept(self.0, storage, len) {
184                 c::INVALID_SOCKET => Err(last_error()),
185                 n => Ok(Socket(n)),
186             }
187         }?;
188         Ok(socket)
189     }
190
191     pub fn duplicate(&self) -> io::Result<Socket> {
192         let socket = unsafe {
193             let mut info: c::WSAPROTOCOL_INFO = mem::zeroed();
194             cvt(c::WSADuplicateSocketW(self.0, c::GetCurrentProcessId(), &mut info))?;
195
196             match c::WSASocketW(
197                 info.iAddressFamily,
198                 info.iSocketType,
199                 info.iProtocol,
200                 &mut info,
201                 0,
202                 c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT,
203             ) {
204                 c::INVALID_SOCKET => match c::WSAGetLastError() {
205                     c::WSAEPROTOTYPE | c::WSAEINVAL => {
206                         match c::WSASocketW(
207                             info.iAddressFamily,
208                             info.iSocketType,
209                             info.iProtocol,
210                             &mut info,
211                             0,
212                             c::WSA_FLAG_OVERLAPPED,
213                         ) {
214                             c::INVALID_SOCKET => Err(last_error()),
215                             n => {
216                                 let s = Socket(n);
217                                 s.set_no_inherit()?;
218                                 Ok(s)
219                             }
220                         }
221                     }
222                     n => Err(io::Error::from_raw_os_error(n)),
223                 },
224                 n => Ok(Socket(n)),
225             }
226         }?;
227         Ok(socket)
228     }
229
230     fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
231         // On unix when a socket is shut down all further reads return 0, so we
232         // do the same on windows to map a shut down socket to returning EOF.
233         let len = cmp::min(buf.len(), i32::MAX as usize) as i32;
234         unsafe {
235             match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
236                 -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
237                 -1 => Err(last_error()),
238                 n => Ok(n as usize),
239             }
240         }
241     }
242
243     pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
244         self.recv_with_flags(buf, 0)
245     }
246
247     pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
248         // On unix when a socket is shut down all further reads return 0, so we
249         // do the same on windows to map a shut down socket to returning EOF.
250         let len = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD;
251         let mut nread = 0;
252         let mut flags = 0;
253         unsafe {
254             let ret = c::WSARecv(
255                 self.0,
256                 bufs.as_mut_ptr() as *mut c::WSABUF,
257                 len,
258                 &mut nread,
259                 &mut flags,
260                 ptr::null_mut(),
261                 ptr::null_mut(),
262             );
263             match ret {
264                 0 => Ok(nread as usize),
265                 _ if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
266                 _ => Err(last_error()),
267             }
268         }
269     }
270
271     #[inline]
272     pub fn is_read_vectored(&self) -> bool {
273         true
274     }
275
276     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
277         self.recv_with_flags(buf, c::MSG_PEEK)
278     }
279
280     fn recv_from_with_flags(
281         &self,
282         buf: &mut [u8],
283         flags: c_int,
284     ) -> io::Result<(usize, SocketAddr)> {
285         let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
286         let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
287         let len = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t;
288
289         // On unix when a socket is shut down all further reads return 0, so we
290         // do the same on windows to map a shut down socket to returning EOF.
291         unsafe {
292             match c::recvfrom(
293                 self.0,
294                 buf.as_mut_ptr() as *mut c_void,
295                 len,
296                 flags,
297                 &mut storage as *mut _ as *mut _,
298                 &mut addrlen,
299             ) {
300                 -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
301                     Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
302                 }
303                 -1 => Err(last_error()),
304                 n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
305             }
306         }
307     }
308
309     pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
310         self.recv_from_with_flags(buf, 0)
311     }
312
313     pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
314         self.recv_from_with_flags(buf, c::MSG_PEEK)
315     }
316
317     pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
318         let len = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD;
319         let mut nwritten = 0;
320         unsafe {
321             cvt(c::WSASend(
322                 self.0,
323                 bufs.as_ptr() as *const c::WSABUF as *mut c::WSABUF,
324                 len,
325                 &mut nwritten,
326                 0,
327                 ptr::null_mut(),
328                 ptr::null_mut(),
329             ))?;
330         }
331         Ok(nwritten as usize)
332     }
333
334     #[inline]
335     pub fn is_write_vectored(&self) -> bool {
336         true
337     }
338
339     pub fn set_timeout(&self, dur: Option<Duration>, kind: c_int) -> io::Result<()> {
340         let timeout = match dur {
341             Some(dur) => {
342                 let timeout = sys::dur2timeout(dur);
343                 if timeout == 0 {
344                     return Err(io::Error::new_const(
345                         io::ErrorKind::InvalidInput,
346                         &"cannot set a 0 duration timeout",
347                     ));
348                 }
349                 timeout
350             }
351             None => 0,
352         };
353         net::setsockopt(self, c::SOL_SOCKET, kind, timeout)
354     }
355
356     pub fn timeout(&self, kind: c_int) -> io::Result<Option<Duration>> {
357         let raw: c::DWORD = net::getsockopt(self, c::SOL_SOCKET, kind)?;
358         if raw == 0 {
359             Ok(None)
360         } else {
361             let secs = raw / 1000;
362             let nsec = (raw % 1000) * 1000000;
363             Ok(Some(Duration::new(secs as u64, nsec as u32)))
364         }
365     }
366
367     #[cfg(not(target_vendor = "uwp"))]
368     fn set_no_inherit(&self) -> io::Result<()> {
369         sys::cvt(unsafe { c::SetHandleInformation(self.0 as c::HANDLE, c::HANDLE_FLAG_INHERIT, 0) })
370             .map(drop)
371     }
372
373     #[cfg(target_vendor = "uwp")]
374     fn set_no_inherit(&self) -> io::Result<()> {
375         Err(io::Error::new_const(io::ErrorKind::Unsupported, &"Unavailable on UWP"))
376     }
377
378     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
379         let how = match how {
380             Shutdown::Write => c::SD_SEND,
381             Shutdown::Read => c::SD_RECEIVE,
382             Shutdown::Both => c::SD_BOTH,
383         };
384         cvt(unsafe { c::shutdown(self.0, how) })?;
385         Ok(())
386     }
387
388     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
389         let mut nonblocking = nonblocking as c_ulong;
390         let r = unsafe { c::ioctlsocket(self.0, c::FIONBIO as c_int, &mut nonblocking) };
391         if r == 0 { Ok(()) } else { Err(io::Error::last_os_error()) }
392     }
393
394     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
395         net::setsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY, nodelay as c::BYTE)
396     }
397
398     pub fn nodelay(&self) -> io::Result<bool> {
399         let raw: c::BYTE = net::getsockopt(self, c::IPPROTO_TCP, c::TCP_NODELAY)?;
400         Ok(raw != 0)
401     }
402
403     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
404         let raw: c_int = net::getsockopt(self, c::SOL_SOCKET, c::SO_ERROR)?;
405         if raw == 0 { Ok(None) } else { Ok(Some(io::Error::from_raw_os_error(raw as i32))) }
406     }
407 }
408
409 #[unstable(reason = "not public", issue = "none", feature = "fd_read")]
410 impl<'a> Read for &'a Socket {
411     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
412         (**self).read(buf)
413     }
414 }
415
416 impl Drop for Socket {
417     fn drop(&mut self) {
418         let _ = unsafe { c::closesocket(self.0) };
419     }
420 }
421
422 impl AsInner<c::SOCKET> for Socket {
423     fn as_inner(&self) -> &c::SOCKET {
424         &self.0
425     }
426 }
427
428 impl FromInner<c::SOCKET> for Socket {
429     fn from_inner(sock: c::SOCKET) -> Socket {
430         Socket(sock)
431     }
432 }
433
434 impl IntoInner<c::SOCKET> for Socket {
435     fn into_inner(self) -> c::SOCKET {
436         let ret = self.0;
437         mem::forget(self);
438         ret
439     }
440 }