]> git.lizzy.rs Git - rust.git/blob - library/std/src/os/unix/net/ancillary.rs
Rollup merge of #94839 - TaKO8Ki:suggest-using-double-colon-for-struct-field-type...
[rust.git] / library / std / src / os / unix / net / ancillary.rs
1 use super::{sockaddr_un, SocketAddr};
2 use crate::convert::TryFrom;
3 use crate::io::{self, IoSlice, IoSliceMut};
4 use crate::marker::PhantomData;
5 use crate::mem::{size_of, zeroed};
6 use crate::os::unix::io::RawFd;
7 use crate::path::Path;
8 use crate::ptr::{eq, read_unaligned};
9 use crate::slice::from_raw_parts;
10 use crate::sys::net::Socket;
11
12 // FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here?
13 #[cfg(all(doc, not(target_os = "linux"), not(target_os = "android")))]
14 #[allow(non_camel_case_types)]
15 mod libc {
16     pub use libc::c_int;
17     pub struct ucred;
18     pub struct cmsghdr;
19     pub type pid_t = i32;
20     pub type gid_t = u32;
21     pub type uid_t = u32;
22 }
23
24 pub(super) fn recv_vectored_with_ancillary_from(
25     socket: &Socket,
26     bufs: &mut [IoSliceMut<'_>],
27     ancillary: &mut SocketAncillary<'_>,
28 ) -> io::Result<(usize, bool, io::Result<SocketAddr>)> {
29     unsafe {
30         let mut msg_name: libc::sockaddr_un = zeroed();
31         let mut msg: libc::msghdr = zeroed();
32         msg.msg_name = &mut msg_name as *mut _ as *mut _;
33         msg.msg_namelen = size_of::<libc::sockaddr_un>() as libc::socklen_t;
34         msg.msg_iov = bufs.as_mut_ptr().cast();
35         msg.msg_iovlen = bufs.len() as _;
36         msg.msg_controllen = ancillary.buffer.len() as _;
37         // macos requires that the control pointer is null when the len is 0.
38         if msg.msg_controllen > 0 {
39             msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
40         }
41
42         let count = socket.recv_msg(&mut msg)?;
43
44         ancillary.length = msg.msg_controllen as usize;
45         ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC;
46
47         let truncated = msg.msg_flags & libc::MSG_TRUNC == libc::MSG_TRUNC;
48         let addr = SocketAddr::from_parts(msg_name, msg.msg_namelen);
49
50         Ok((count, truncated, addr))
51     }
52 }
53
54 pub(super) fn send_vectored_with_ancillary_to(
55     socket: &Socket,
56     path: Option<&Path>,
57     bufs: &[IoSlice<'_>],
58     ancillary: &mut SocketAncillary<'_>,
59 ) -> io::Result<usize> {
60     unsafe {
61         let (mut msg_name, msg_namelen) =
62             if let Some(path) = path { sockaddr_un(path)? } else { (zeroed(), 0) };
63
64         let mut msg: libc::msghdr = zeroed();
65         msg.msg_name = &mut msg_name as *mut _ as *mut _;
66         msg.msg_namelen = msg_namelen;
67         msg.msg_iov = bufs.as_ptr() as *mut _;
68         msg.msg_iovlen = bufs.len() as _;
69         msg.msg_controllen = ancillary.length as _;
70         // macos requires that the control pointer is null when the len is 0.
71         if msg.msg_controllen > 0 {
72             msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
73         }
74
75         ancillary.truncated = false;
76
77         socket.send_msg(&mut msg)
78     }
79 }
80
81 fn add_to_ancillary_data<T>(
82     buffer: &mut [u8],
83     length: &mut usize,
84     source: &[T],
85     cmsg_level: libc::c_int,
86     cmsg_type: libc::c_int,
87 ) -> bool {
88     let source_len = if let Some(source_len) = source.len().checked_mul(size_of::<T>()) {
89         if let Ok(source_len) = u32::try_from(source_len) {
90             source_len
91         } else {
92             return false;
93         }
94     } else {
95         return false;
96     };
97
98     unsafe {
99         let additional_space = libc::CMSG_SPACE(source_len) as usize;
100
101         let new_length = if let Some(new_length) = additional_space.checked_add(*length) {
102             new_length
103         } else {
104             return false;
105         };
106
107         if new_length > buffer.len() {
108             return false;
109         }
110
111         buffer[*length..new_length].fill(0);
112
113         *length = new_length;
114
115         let mut msg: libc::msghdr = zeroed();
116         msg.msg_control = buffer.as_mut_ptr().cast();
117         msg.msg_controllen = *length as _;
118
119         let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
120         let mut previous_cmsg = cmsg;
121         while !cmsg.is_null() {
122             previous_cmsg = cmsg;
123             cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
124
125             // Most operating systems, but not Linux or emscripten, return the previous pointer
126             // when its length is zero. Therefore, check if the previous pointer is the same as
127             // the current one.
128             if eq(cmsg, previous_cmsg) {
129                 break;
130             }
131         }
132
133         if previous_cmsg.is_null() {
134             return false;
135         }
136
137         (*previous_cmsg).cmsg_level = cmsg_level;
138         (*previous_cmsg).cmsg_type = cmsg_type;
139         (*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as _;
140
141         let data = libc::CMSG_DATA(previous_cmsg).cast();
142
143         libc::memcpy(data, source.as_ptr().cast(), source_len as usize);
144     }
145     true
146 }
147
148 struct AncillaryDataIter<'a, T> {
149     data: &'a [u8],
150     phantom: PhantomData<T>,
151 }
152
153 impl<'a, T> AncillaryDataIter<'a, T> {
154     /// Create `AncillaryDataIter` struct to iterate through the data unit in the control message.
155     ///
156     /// # Safety
157     ///
158     /// `data` must contain a valid control message.
159     unsafe fn new(data: &'a [u8]) -> AncillaryDataIter<'a, T> {
160         AncillaryDataIter { data, phantom: PhantomData }
161     }
162 }
163
164 impl<'a, T> Iterator for AncillaryDataIter<'a, T> {
165     type Item = T;
166
167     fn next(&mut self) -> Option<T> {
168         if size_of::<T>() <= self.data.len() {
169             unsafe {
170                 let unit = read_unaligned(self.data.as_ptr().cast());
171                 self.data = &self.data[size_of::<T>()..];
172                 Some(unit)
173             }
174         } else {
175             None
176         }
177     }
178 }
179
180 /// Unix credential.
181 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
182 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
183 #[derive(Clone)]
184 pub struct SocketCred(libc::ucred);
185
186 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
187 impl SocketCred {
188     /// Create a Unix credential struct.
189     ///
190     /// PID, UID and GID is set to 0.
191     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
192     #[must_use]
193     pub fn new() -> SocketCred {
194         SocketCred(libc::ucred { pid: 0, uid: 0, gid: 0 })
195     }
196
197     /// Set the PID.
198     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
199     pub fn set_pid(&mut self, pid: libc::pid_t) {
200         self.0.pid = pid;
201     }
202
203     /// Get the current PID.
204     #[must_use]
205     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
206     pub fn get_pid(&self) -> libc::pid_t {
207         self.0.pid
208     }
209
210     /// Set the UID.
211     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
212     pub fn set_uid(&mut self, uid: libc::uid_t) {
213         self.0.uid = uid;
214     }
215
216     /// Get the current UID.
217     #[must_use]
218     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
219     pub fn get_uid(&self) -> libc::uid_t {
220         self.0.uid
221     }
222
223     /// Set the GID.
224     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
225     pub fn set_gid(&mut self, gid: libc::gid_t) {
226         self.0.gid = gid;
227     }
228
229     /// Get the current GID.
230     #[must_use]
231     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
232     pub fn get_gid(&self) -> libc::gid_t {
233         self.0.gid
234     }
235 }
236
237 /// This control message contains file descriptors.
238 ///
239 /// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`.
240 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
241 pub struct ScmRights<'a>(AncillaryDataIter<'a, RawFd>);
242
243 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
244 impl<'a> Iterator for ScmRights<'a> {
245     type Item = RawFd;
246
247     fn next(&mut self) -> Option<RawFd> {
248         self.0.next()
249     }
250 }
251
252 /// This control message contains unix credentials.
253 ///
254 /// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_CREDENTIALS` or `SCM_CREDS`.
255 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
256 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
257 pub struct ScmCredentials<'a>(AncillaryDataIter<'a, libc::ucred>);
258
259 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
260 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
261 impl<'a> Iterator for ScmCredentials<'a> {
262     type Item = SocketCred;
263
264     fn next(&mut self) -> Option<SocketCred> {
265         Some(SocketCred(self.0.next()?))
266     }
267 }
268
269 /// The error type which is returned from parsing the type a control message.
270 #[non_exhaustive]
271 #[derive(Debug)]
272 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
273 pub enum AncillaryError {
274     Unknown { cmsg_level: i32, cmsg_type: i32 },
275 }
276
277 /// This enum represent one control message of variable type.
278 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
279 pub enum AncillaryData<'a> {
280     ScmRights(ScmRights<'a>),
281     #[cfg(any(doc, target_os = "android", target_os = "linux",))]
282     ScmCredentials(ScmCredentials<'a>),
283 }
284
285 impl<'a> AncillaryData<'a> {
286     /// Create an `AncillaryData::ScmRights` variant.
287     ///
288     /// # Safety
289     ///
290     /// `data` must contain a valid control message and the control message must be type of
291     /// `SOL_SOCKET` and level of `SCM_RIGHTS`.
292     unsafe fn as_rights(data: &'a [u8]) -> Self {
293         let ancillary_data_iter = AncillaryDataIter::new(data);
294         let scm_rights = ScmRights(ancillary_data_iter);
295         AncillaryData::ScmRights(scm_rights)
296     }
297
298     /// Create an `AncillaryData::ScmCredentials` variant.
299     ///
300     /// # Safety
301     ///
302     /// `data` must contain a valid control message and the control message must be type of
303     /// `SOL_SOCKET` and level of `SCM_CREDENTIALS` or `SCM_CREDENTIALS`.
304     #[cfg(any(doc, target_os = "android", target_os = "linux",))]
305     unsafe fn as_credentials(data: &'a [u8]) -> Self {
306         let ancillary_data_iter = AncillaryDataIter::new(data);
307         let scm_credentials = ScmCredentials(ancillary_data_iter);
308         AncillaryData::ScmCredentials(scm_credentials)
309     }
310
311     fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Result<Self, AncillaryError> {
312         unsafe {
313             let cmsg_len_zero = libc::CMSG_LEN(0) as usize;
314             let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero;
315             let data = libc::CMSG_DATA(cmsg).cast();
316             let data = from_raw_parts(data, data_len);
317
318             match (*cmsg).cmsg_level {
319                 libc::SOL_SOCKET => match (*cmsg).cmsg_type {
320                     libc::SCM_RIGHTS => Ok(AncillaryData::as_rights(data)),
321                     #[cfg(any(target_os = "android", target_os = "linux",))]
322                     libc::SCM_CREDENTIALS => Ok(AncillaryData::as_credentials(data)),
323                     cmsg_type => {
324                         Err(AncillaryError::Unknown { cmsg_level: libc::SOL_SOCKET, cmsg_type })
325                     }
326                 },
327                 cmsg_level => {
328                     Err(AncillaryError::Unknown { cmsg_level, cmsg_type: (*cmsg).cmsg_type })
329                 }
330             }
331         }
332     }
333 }
334
335 /// This struct is used to iterate through the control messages.
336 #[must_use = "iterators are lazy and do nothing unless consumed"]
337 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
338 pub struct Messages<'a> {
339     buffer: &'a [u8],
340     current: Option<&'a libc::cmsghdr>,
341 }
342
343 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
344 impl<'a> Iterator for Messages<'a> {
345     type Item = Result<AncillaryData<'a>, AncillaryError>;
346
347     fn next(&mut self) -> Option<Self::Item> {
348         unsafe {
349             let mut msg: libc::msghdr = zeroed();
350             msg.msg_control = self.buffer.as_ptr() as *mut _;
351             msg.msg_controllen = self.buffer.len() as _;
352
353             let cmsg = if let Some(current) = self.current {
354                 libc::CMSG_NXTHDR(&msg, current)
355             } else {
356                 libc::CMSG_FIRSTHDR(&msg)
357             };
358
359             let cmsg = cmsg.as_ref()?;
360
361             // Most operating systems, but not Linux or emscripten, return the previous pointer
362             // when its length is zero. Therefore, check if the previous pointer is the same as
363             // the current one.
364             if let Some(current) = self.current {
365                 if eq(current, cmsg) {
366                     return None;
367                 }
368             }
369
370             self.current = Some(cmsg);
371             let ancillary_result = AncillaryData::try_from_cmsghdr(cmsg);
372             Some(ancillary_result)
373         }
374     }
375 }
376
377 /// A Unix socket Ancillary data struct.
378 ///
379 /// # Example
380 /// ```no_run
381 /// #![feature(unix_socket_ancillary_data)]
382 /// use std::os::unix::net::{UnixStream, SocketAncillary, AncillaryData};
383 /// use std::io::IoSliceMut;
384 ///
385 /// fn main() -> std::io::Result<()> {
386 ///     let sock = UnixStream::connect("/tmp/sock")?;
387 ///
388 ///     let mut fds = [0; 8];
389 ///     let mut ancillary_buffer = [0; 128];
390 ///     let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
391 ///
392 ///     let mut buf = [1; 8];
393 ///     let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..];
394 ///     sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
395 ///
396 ///     for ancillary_result in ancillary.messages() {
397 ///         if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() {
398 ///             for fd in scm_rights {
399 ///                 println!("receive file descriptor: {fd}");
400 ///             }
401 ///         }
402 ///     }
403 ///     Ok(())
404 /// }
405 /// ```
406 #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
407 #[derive(Debug)]
408 pub struct SocketAncillary<'a> {
409     buffer: &'a mut [u8],
410     length: usize,
411     truncated: bool,
412 }
413
414 impl<'a> SocketAncillary<'a> {
415     /// Create an ancillary data with the given buffer.
416     ///
417     /// # Example
418     ///
419     /// ```no_run
420     /// # #![allow(unused_mut)]
421     /// #![feature(unix_socket_ancillary_data)]
422     /// use std::os::unix::net::SocketAncillary;
423     /// let mut ancillary_buffer = [0; 128];
424     /// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
425     /// ```
426     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
427     pub fn new(buffer: &'a mut [u8]) -> Self {
428         SocketAncillary { buffer, length: 0, truncated: false }
429     }
430
431     /// Returns the capacity of the buffer.
432     #[must_use]
433     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
434     pub fn capacity(&self) -> usize {
435         self.buffer.len()
436     }
437
438     /// Returns `true` if the ancillary data is empty.
439     #[must_use]
440     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
441     pub fn is_empty(&self) -> bool {
442         self.length == 0
443     }
444
445     /// Returns the number of used bytes.
446     #[must_use]
447     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
448     pub fn len(&self) -> usize {
449         self.length
450     }
451
452     /// Returns the iterator of the control messages.
453     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
454     pub fn messages(&self) -> Messages<'_> {
455         Messages { buffer: &self.buffer[..self.length], current: None }
456     }
457
458     /// Is `true` if during a recv operation the ancillary was truncated.
459     ///
460     /// # Example
461     ///
462     /// ```no_run
463     /// #![feature(unix_socket_ancillary_data)]
464     /// use std::os::unix::net::{UnixStream, SocketAncillary};
465     /// use std::io::IoSliceMut;
466     ///
467     /// fn main() -> std::io::Result<()> {
468     ///     let sock = UnixStream::connect("/tmp/sock")?;
469     ///
470     ///     let mut ancillary_buffer = [0; 128];
471     ///     let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
472     ///
473     ///     let mut buf = [1; 8];
474     ///     let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..];
475     ///     sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
476     ///
477     ///     println!("Is truncated: {}", ancillary.truncated());
478     ///     Ok(())
479     /// }
480     /// ```
481     #[must_use]
482     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
483     pub fn truncated(&self) -> bool {
484         self.truncated
485     }
486
487     /// Add file descriptors to the ancillary data.
488     ///
489     /// The function returns `true` if there was enough space in the buffer.
490     /// If there was not enough space then no file descriptors was appended.
491     /// Technically, that means this operation adds a control message with the level `SOL_SOCKET`
492     /// and type `SCM_RIGHTS`.
493     ///
494     /// # Example
495     ///
496     /// ```no_run
497     /// #![feature(unix_socket_ancillary_data)]
498     /// use std::os::unix::net::{UnixStream, SocketAncillary};
499     /// use std::os::unix::io::AsRawFd;
500     /// use std::io::IoSlice;
501     ///
502     /// fn main() -> std::io::Result<()> {
503     ///     let sock = UnixStream::connect("/tmp/sock")?;
504     ///
505     ///     let mut ancillary_buffer = [0; 128];
506     ///     let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
507     ///     ancillary.add_fds(&[sock.as_raw_fd()][..]);
508     ///
509     ///     let mut buf = [1; 8];
510     ///     let mut bufs = &mut [IoSlice::new(&mut buf[..])][..];
511     ///     sock.send_vectored_with_ancillary(bufs, &mut ancillary)?;
512     ///     Ok(())
513     /// }
514     /// ```
515     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
516     pub fn add_fds(&mut self, fds: &[RawFd]) -> bool {
517         self.truncated = false;
518         add_to_ancillary_data(
519             &mut self.buffer,
520             &mut self.length,
521             fds,
522             libc::SOL_SOCKET,
523             libc::SCM_RIGHTS,
524         )
525     }
526
527     /// Add credentials to the ancillary data.
528     ///
529     /// The function returns `true` if there was enough space in the buffer.
530     /// If there was not enough space then no credentials was appended.
531     /// Technically, that means this operation adds a control message with the level `SOL_SOCKET`
532     /// and type `SCM_CREDENTIALS` or `SCM_CREDS`.
533     ///
534     #[cfg(any(doc, target_os = "android", target_os = "linux",))]
535     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
536     pub fn add_creds(&mut self, creds: &[SocketCred]) -> bool {
537         self.truncated = false;
538         add_to_ancillary_data(
539             &mut self.buffer,
540             &mut self.length,
541             creds,
542             libc::SOL_SOCKET,
543             libc::SCM_CREDENTIALS,
544         )
545     }
546
547     /// Clears the ancillary data, removing all values.
548     ///
549     /// # Example
550     ///
551     /// ```no_run
552     /// #![feature(unix_socket_ancillary_data)]
553     /// use std::os::unix::net::{UnixStream, SocketAncillary, AncillaryData};
554     /// use std::io::IoSliceMut;
555     ///
556     /// fn main() -> std::io::Result<()> {
557     ///     let sock = UnixStream::connect("/tmp/sock")?;
558     ///
559     ///     let mut fds1 = [0; 8];
560     ///     let mut fds2 = [0; 8];
561     ///     let mut ancillary_buffer = [0; 128];
562     ///     let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
563     ///
564     ///     let mut buf = [1; 8];
565     ///     let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..];
566     ///
567     ///     sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
568     ///     for ancillary_result in ancillary.messages() {
569     ///         if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() {
570     ///             for fd in scm_rights {
571     ///                 println!("receive file descriptor: {fd}");
572     ///             }
573     ///         }
574     ///     }
575     ///
576     ///     ancillary.clear();
577     ///
578     ///     sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
579     ///     for ancillary_result in ancillary.messages() {
580     ///         if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() {
581     ///             for fd in scm_rights {
582     ///                 println!("receive file descriptor: {fd}");
583     ///             }
584     ///         }
585     ///     }
586     ///     Ok(())
587     /// }
588     /// ```
589     #[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
590     pub fn clear(&mut self) {
591         self.length = 0;
592         self.truncated = false;
593     }
594 }