]> git.lizzy.rs Git - rust.git/blobdiff - src/helpers.rs
rustup: more flexible write_bytes avoids allocations and removes itertools dependency
[rust.git] / src / helpers.rs
index 9e1fa34370589e40e3e6166a0c9d95ae6d55499a..f7be3de8e481227ba2d0077d757ec4d146bf9822 100644 (file)
@@ -1,4 +1,5 @@
-use std::mem;
+use std::{mem, iter};
+use std::ffi::{OsStr, OsString};
 
 use rustc::hir::def_id::{DefId, CRATE_DEF_INDEX};
 use rustc::mir;
@@ -402,4 +403,60 @@ fn try_unwrap_io_result<T: From<i32>>(
             }
         }
     }
+
+    /// Helper function to read an OsString from a null-terminated sequence of bytes, which is what
+    /// the Unix APIs usually handle.
+    fn read_os_string_from_c_string(&mut self, scalar: Scalar<Tag>) -> InterpResult<'tcx, OsString> {
+        let bytes = self.eval_context_mut().memory.read_c_str(scalar)?;
+        Ok(bytes_to_os_str(bytes)?.into())
+    }
+
+    /// Helper function to write an OsStr as a null-terminated sequence of bytes, which is what
+    /// the Unix APIs usually handle. This function returns `Ok(false)` without trying to write if
+    /// `size` is not large enough to fit the contents of `os_string` plus a null terminator. It
+    /// returns `Ok(true)` if the writing process was successful.
+    fn write_os_str_to_c_string(
+        &mut self,
+        os_str: &OsStr,
+        scalar: Scalar<Tag>,
+        size: u64
+    ) -> InterpResult<'tcx, bool> {
+        let bytes = os_str_to_bytes(os_str)?;
+        // If `size` is smaller or equal than `bytes.len()`, writing `bytes` plus the required null
+        // terminator to memory using the `ptr` pointer would cause an overflow.
+        if size <= bytes.len() as u64 {
+            return Ok(false);
+        }
+        // FIXME: We should use `Iterator::chain` instead when rust-lang/rust#65704 lands.
+        self.eval_context_mut().memory.write_bytes(scalar, bytes.iter().copied().chain(iter::once(0u8)))?;
+        Ok(true)
+    }
+}
+
+#[cfg(target_os = "unix")]
+fn os_str_to_bytes<'tcx, 'a>(os_str: &'a OsStr) -> InterpResult<'tcx, &'a [u8]> {
+    std::os::unix::ffi::OsStringExt::into_bytes(os_str)
+}
+
+#[cfg(target_os = "unix")]
+fn bytes_to_os_str<'tcx, 'a>(bytes: &'a[u8]) -> InterpResult<'tcx, &'a OsStr> {
+    Ok(std::os::unix::ffi::OsStringExt::from_bytes(bytes))
+}
+
+// On non-unix platforms the best we can do to transform bytes from/to OS strings is to do the
+// intermediate transformation into strings. Which invalidates non-utf8 paths that are actually
+// valid.
+#[cfg(not(target_os = "unix"))]
+fn os_str_to_bytes<'tcx, 'a>(os_str: &'a OsStr) -> InterpResult<'tcx, &'a [u8]> {
+    os_str
+        .to_str()
+        .map(|s| s.as_bytes())
+        .ok_or_else(|| err_unsup_format!("{:?} is not a valid utf-8 string", os_str).into())
+}
+
+#[cfg(not(target_os = "unix"))]
+fn bytes_to_os_str<'tcx, 'a>(bytes: &'a[u8]) -> InterpResult<'tcx, &'a OsStr> {
+    let s = std::str::from_utf8(bytes)
+        .map_err(|_| err_unsup_format!("{:?} is not a valid utf-8 string", bytes))?;
+    Ok(&OsStr::new(s))
 }