1 use rustc_apfloat::Float;
2 use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
3 use rustc_middle::{mir, ty, ty::FloatTy};
4 use rustc_target::abi::{Endian, HasDataLayout, Size};
7 use helpers::check_arg_count;
9 impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} // Empty impl: makes the trait's provided (default-bodied) methods such as `emulate_simd_intrinsic` callable on `MiriInterpCx`.
10 pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
11 /// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed.
12 fn emulate_simd_intrinsic(
15 args: &[OpTy<'tcx, Provenance>],
16 dest: &PlaceTy<'tcx, Provenance>,
17 ) -> InterpResult<'tcx> {
18 let this = self.eval_context_mut();
19 match intrinsic_name {
28 let [op] = check_arg_count(args)?;
29 let (op, op_len) = this.operand_to_simd(op)?;
30 let (dest, dest_len) = this.place_to_simd(dest)?;
32 assert_eq!(dest_len, op_len);
34 #[derive(Copy, Clone)]
42 #[derive(Copy, Clone)]
48 let which = match intrinsic_name {
49 "neg" => Op::MirOp(mir::UnOp::Neg),
51 "ceil" => Op::HostOp(HostFloatOp::Ceil),
52 "floor" => Op::HostOp(HostFloatOp::Floor),
53 "round" => Op::HostOp(HostFloatOp::Round),
54 "trunc" => Op::HostOp(HostFloatOp::Trunc),
55 "fsqrt" => Op::HostOp(HostFloatOp::Sqrt),
59 for i in 0..dest_len {
60 let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
61 let dest = this.mplace_index(&dest, i)?;
62 let val = match which {
63 Op::MirOp(mir_op) => this.unary_op(mir_op, &op)?.to_scalar(),
65 // Works for f32 and f64.
66 let ty::Float(float_ty) = op.layout.ty.kind() else {
67 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
69 let op = op.to_scalar();
71 FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
72 FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
75 Op::HostOp(host_op) => {
76 let ty::Float(float_ty) = op.layout.ty.kind() else {
77 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
79 // FIXME using host floats
82 let f = f32::from_bits(op.to_scalar().to_u32()?);
83 let res = match host_op {
84 HostFloatOp::Ceil => f.ceil(),
85 HostFloatOp::Floor => f.floor(),
86 HostFloatOp::Round => f.round(),
87 HostFloatOp::Trunc => f.trunc(),
88 HostFloatOp::Sqrt => f.sqrt(),
90 Scalar::from_u32(res.to_bits())
93 let f = f64::from_bits(op.to_scalar().to_u64()?);
94 let res = match host_op {
95 HostFloatOp::Ceil => f.ceil(),
96 HostFloatOp::Floor => f.floor(),
97 HostFloatOp::Round => f.round(),
98 HostFloatOp::Trunc => f.trunc(),
99 HostFloatOp::Sqrt => f.sqrt(),
101 Scalar::from_u64(res.to_bits())
107 this.write_scalar(val, &dest.into())?;
131 | "arith_offset" => {
134 let [left, right] = check_arg_count(args)?;
135 let (left, left_len) = this.operand_to_simd(left)?;
136 let (right, right_len) = this.operand_to_simd(right)?;
137 let (dest, dest_len) = this.place_to_simd(dest)?;
139 assert_eq!(dest_len, left_len);
140 assert_eq!(dest_len, right_len);
149 let which = match intrinsic_name {
150 "add" => Op::MirOp(BinOp::Add),
151 "sub" => Op::MirOp(BinOp::Sub),
152 "mul" => Op::MirOp(BinOp::Mul),
153 "div" => Op::MirOp(BinOp::Div),
154 "rem" => Op::MirOp(BinOp::Rem),
155 "shl" => Op::MirOp(BinOp::Shl),
156 "shr" => Op::MirOp(BinOp::Shr),
157 "and" => Op::MirOp(BinOp::BitAnd),
158 "or" => Op::MirOp(BinOp::BitOr),
159 "xor" => Op::MirOp(BinOp::BitXor),
160 "eq" => Op::MirOp(BinOp::Eq),
161 "ne" => Op::MirOp(BinOp::Ne),
162 "lt" => Op::MirOp(BinOp::Lt),
163 "le" => Op::MirOp(BinOp::Le),
164 "gt" => Op::MirOp(BinOp::Gt),
165 "ge" => Op::MirOp(BinOp::Ge),
168 "saturating_add" => Op::SaturatingOp(BinOp::Add),
169 "saturating_sub" => Op::SaturatingOp(BinOp::Sub),
170 "arith_offset" => Op::WrappingOffset,
174 for i in 0..dest_len {
175 let left = this.read_immediate(&this.mplace_index(&left, i)?.into())?;
176 let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?;
177 let dest = this.mplace_index(&dest, i)?;
178 let val = match which {
179 Op::MirOp(mir_op) => {
180 let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
181 if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
182 // Shifts have extra UB as SIMD operations that the MIR binop does not have.
183 // See <https://github.com/rust-lang/rust/issues/91237>.
185 let r_val = right.to_scalar().to_bits(right.layout.size)?;
186 throw_ub_format!("overflowing shift by {r_val} in `simd_{intrinsic_name}` in SIMD lane {i}");
189 if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
190 // Special handling for boolean-returning operations
191 assert_eq!(ty, this.tcx.types.bool);
192 let val = val.to_bool().unwrap();
193 bool_to_simd_element(val, dest.layout.size)
195 assert_ne!(ty, this.tcx.types.bool);
196 assert_eq!(ty, dest.layout.ty);
200 Op::SaturatingOp(mir_op) => {
201 this.saturating_arith(mir_op, &left, &right)?
203 Op::WrappingOffset => {
204 let ptr = left.to_scalar().to_pointer(this)?;
205 let offset_count = right.to_scalar().to_machine_isize(this)?;
206 let pointee_ty = left.layout.ty.builtin_deref(true).unwrap().ty;
208 let pointee_size = i64::try_from(this.layout_of(pointee_ty)?.size.bytes()).unwrap();
209 let offset_bytes = offset_count.wrapping_mul(pointee_size);
210 let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
211 Scalar::from_maybe_pointer(offset_ptr, this)
214 fmax_op(&left, &right)?
217 fmin_op(&left, &right)?
220 this.write_scalar(val, &dest.into())?;
224 let [a, b, c] = check_arg_count(args)?;
225 let (a, a_len) = this.operand_to_simd(a)?;
226 let (b, b_len) = this.operand_to_simd(b)?;
227 let (c, c_len) = this.operand_to_simd(c)?;
228 let (dest, dest_len) = this.place_to_simd(dest)?;
230 assert_eq!(dest_len, a_len);
231 assert_eq!(dest_len, b_len);
232 assert_eq!(dest_len, c_len);
234 for i in 0..dest_len {
235 let a = this.read_scalar(&this.mplace_index(&a, i)?.into())?;
236 let b = this.read_scalar(&this.mplace_index(&b, i)?.into())?;
237 let c = this.read_scalar(&this.mplace_index(&c, i)?.into())?;
238 let dest = this.mplace_index(&dest, i)?;
240 // Works for f32 and f64.
241 // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
242 let ty::Float(float_ty) = dest.layout.ty.kind() else {
243 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
245 let val = match float_ty {
247 let a = f32::from_bits(a.to_u32()?);
248 let b = f32::from_bits(b.to_u32()?);
249 let c = f32::from_bits(c.to_u32()?);
250 let res = a.mul_add(b, c);
251 Scalar::from_u32(res.to_bits())
254 let a = f64::from_bits(a.to_u64()?);
255 let b = f64::from_bits(b.to_u64()?);
256 let c = f64::from_bits(c.to_u64()?);
257 let res = a.mul_add(b, c);
258 Scalar::from_u64(res.to_bits())
261 this.write_scalar(val, &dest.into())?;
274 let [op] = check_arg_count(args)?;
275 let (op, op_len) = this.operand_to_simd(op)?;
278 |b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
286 let which = match intrinsic_name {
287 "reduce_and" => Op::MirOp(BinOp::BitAnd),
288 "reduce_or" => Op::MirOp(BinOp::BitOr),
289 "reduce_xor" => Op::MirOp(BinOp::BitXor),
290 "reduce_any" => Op::MirOpBool(BinOp::BitOr),
291 "reduce_all" => Op::MirOpBool(BinOp::BitAnd),
292 "reduce_max" => Op::Max,
293 "reduce_min" => Op::Min,
297 // Initialize with first lane, then proceed with the rest.
298 let mut res = this.read_immediate(&this.mplace_index(&op, 0)?.into())?;
299 if matches!(which, Op::MirOpBool(_)) {
300 // Convert to `bool` scalar.
301 res = imm_from_bool(simd_element_to_bool(res)?);
304 let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
306 Op::MirOp(mir_op) => {
307 this.binary_op(mir_op, &res, &op)?
309 Op::MirOpBool(mir_op) => {
310 let op = imm_from_bool(simd_element_to_bool(op)?);
311 this.binary_op(mir_op, &res, &op)?
314 if matches!(res.layout.ty.kind(), ty::Float(_)) {
315 ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
317 // Just boring integers, so no NaNs to worry about
318 if this.binary_op(BinOp::Ge, &res, &op)?.to_scalar().to_bool()? {
326 if matches!(res.layout.ty.kind(), ty::Float(_)) {
327 ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
329 // Just boring integers, so no NaNs to worry about
330 if this.binary_op(BinOp::Le, &res, &op)?.to_scalar().to_bool()? {
339 this.write_immediate(*res, dest)?;
342 | "reduce_add_ordered"
343 | "reduce_mul_ordered" => {
346 let [op, init] = check_arg_count(args)?;
347 let (op, op_len) = this.operand_to_simd(op)?;
348 let init = this.read_immediate(init)?;
350 let mir_op = match intrinsic_name {
351 "reduce_add_ordered" => BinOp::Add,
352 "reduce_mul_ordered" => BinOp::Mul,
358 let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
359 res = this.binary_op(mir_op, &res, &op)?;
361 this.write_immediate(*res, dest)?;
364 let [mask, yes, no] = check_arg_count(args)?;
365 let (mask, mask_len) = this.operand_to_simd(mask)?;
366 let (yes, yes_len) = this.operand_to_simd(yes)?;
367 let (no, no_len) = this.operand_to_simd(no)?;
368 let (dest, dest_len) = this.place_to_simd(dest)?;
370 assert_eq!(dest_len, mask_len);
371 assert_eq!(dest_len, yes_len);
372 assert_eq!(dest_len, no_len);
374 for i in 0..dest_len {
375 let mask = this.read_immediate(&this.mplace_index(&mask, i)?.into())?;
376 let yes = this.read_immediate(&this.mplace_index(&yes, i)?.into())?;
377 let no = this.read_immediate(&this.mplace_index(&no, i)?.into())?;
378 let dest = this.mplace_index(&dest, i)?;
380 let val = if simd_element_to_bool(mask)? { yes } else { no };
381 this.write_immediate(*val, &dest.into())?;
384 "select_bitmask" => {
385 let [mask, yes, no] = check_arg_count(args)?;
386 let (yes, yes_len) = this.operand_to_simd(yes)?;
387 let (no, no_len) = this.operand_to_simd(no)?;
388 let (dest, dest_len) = this.place_to_simd(dest)?;
389 let bitmask_len = dest_len.max(8);
391 assert!(mask.layout.ty.is_integral());
392 assert!(bitmask_len <= 64);
393 assert_eq!(bitmask_len, mask.layout.size.bits());
394 assert_eq!(dest_len, yes_len);
395 assert_eq!(dest_len, no_len);
396 let dest_len = u32::try_from(dest_len).unwrap();
397 let bitmask_len = u32::try_from(bitmask_len).unwrap();
400 this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap();
401 for i in 0..dest_len {
404 .checked_shl(simd_bitmask_index(i, dest_len, this.data_layout().endian))
406 let yes = this.read_immediate(&this.mplace_index(&yes, i.into())?.into())?;
407 let no = this.read_immediate(&this.mplace_index(&no, i.into())?.into())?;
408 let dest = this.mplace_index(&dest, i.into())?;
410 let val = if mask != 0 { yes } else { no };
411 this.write_immediate(*val, &dest.into())?;
413 for i in dest_len..bitmask_len {
414 // If the mask is "padded", ensure that padding is all-zero.
415 let mask = mask & 1u64.checked_shl(i).unwrap();
418 "a SIMD bitmask less than 8 bits long must be filled with 0s for the remaining bits"
425 let [op] = check_arg_count(args)?;
426 let (op, op_len) = this.operand_to_simd(op)?;
427 let (dest, dest_len) = this.place_to_simd(dest)?;
429 assert_eq!(dest_len, op_len);
431 let safe_cast = intrinsic_name == "as";
433 for i in 0..dest_len {
434 let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
435 let dest = this.mplace_index(&dest, i)?;
437 let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
438 // Int-to-(int|float): always safe
439 (ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_)) =>
440 this.int_to_int_or_float(&op, dest.layout.ty)?,
441 // Float-to-float: always safe
442 (ty::Float(_), ty::Float(_)) =>
443 this.float_to_float_or_int(&op, dest.layout.ty)?,
444 // Float-to-int in safe mode
445 (ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast =>
446 this.float_to_float_or_int(&op, dest.layout.ty)?,
447 // Float-to-int in unchecked mode
448 (ty::Float(FloatTy::F32), ty::Int(_) | ty::Uint(_)) if !safe_cast =>
449 this.float_to_int_unchecked(op.to_scalar().to_f32()?, dest.layout.ty)?.into(),
450 (ty::Float(FloatTy::F64), ty::Int(_) | ty::Uint(_)) if !safe_cast =>
451 this.float_to_int_unchecked(op.to_scalar().to_f64()?, dest.layout.ty)?.into(),
454 "Unsupported SIMD cast from element type {from_ty} to {to_ty}",
455 from_ty = op.layout.ty,
456 to_ty = dest.layout.ty,
459 this.write_immediate(val, &dest.into())?;
463 let [left, right, index] = check_arg_count(args)?;
464 let (left, left_len) = this.operand_to_simd(left)?;
465 let (right, right_len) = this.operand_to_simd(right)?;
466 let (dest, dest_len) = this.place_to_simd(dest)?;
468 // `index` is an array, not a SIMD type
469 let ty::Array(_, index_len) = index.layout.ty.kind() else {
470 span_bug!(this.cur_span(), "simd_shuffle index argument has non-array type {}", index.layout.ty)
472 let index_len = index_len.eval_usize(*this.tcx, this.param_env());
474 assert_eq!(left_len, right_len);
475 assert_eq!(index_len, dest_len);
477 for i in 0..dest_len {
478 let src_index: u64 = this
479 .read_immediate(&this.operand_index(index, i)?)?
483 let dest = this.mplace_index(&dest, i)?;
485 let val = if src_index < left_len {
486 this.read_immediate(&this.mplace_index(&left, src_index)?.into())?
487 } else if src_index < left_len.checked_add(right_len).unwrap() {
488 let right_idx = src_index.checked_sub(left_len).unwrap();
489 this.read_immediate(&this.mplace_index(&right, right_idx)?.into())?
493 "simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
496 this.write_immediate(*val, &dest.into())?;
500 let [passthru, ptrs, mask] = check_arg_count(args)?;
501 let (passthru, passthru_len) = this.operand_to_simd(passthru)?;
502 let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
503 let (mask, mask_len) = this.operand_to_simd(mask)?;
504 let (dest, dest_len) = this.place_to_simd(dest)?;
506 assert_eq!(dest_len, passthru_len);
507 assert_eq!(dest_len, ptrs_len);
508 assert_eq!(dest_len, mask_len);
510 for i in 0..dest_len {
511 let passthru = this.read_immediate(&this.mplace_index(&passthru, i)?.into())?;
512 let ptr = this.read_immediate(&this.mplace_index(&ptrs, i)?.into())?;
513 let mask = this.read_immediate(&this.mplace_index(&mask, i)?.into())?;
514 let dest = this.mplace_index(&dest, i)?;
516 let val = if simd_element_to_bool(mask)? {
517 let place = this.deref_operand(&ptr.into())?;
518 this.read_immediate(&place.into())?
522 this.write_immediate(*val, &dest.into())?;
526 let [value, ptrs, mask] = check_arg_count(args)?;
527 let (value, value_len) = this.operand_to_simd(value)?;
528 let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
529 let (mask, mask_len) = this.operand_to_simd(mask)?;
531 assert_eq!(ptrs_len, value_len);
532 assert_eq!(ptrs_len, mask_len);
534 for i in 0..ptrs_len {
535 let value = this.read_immediate(&this.mplace_index(&value, i)?.into())?;
536 let ptr = this.read_immediate(&this.mplace_index(&ptrs, i)?.into())?;
537 let mask = this.read_immediate(&this.mplace_index(&mask, i)?.into())?;
539 if simd_element_to_bool(mask)? {
540 let place = this.deref_operand(&ptr.into())?;
541 this.write_immediate(*value, &place.into())?;
546 let [op] = check_arg_count(args)?;
547 let (op, op_len) = this.operand_to_simd(op)?;
548 let bitmask_len = op_len.max(8);
550 assert!(dest.layout.ty.is_integral());
551 assert!(bitmask_len <= 64);
552 assert_eq!(bitmask_len, dest.layout.size.bits());
553 let op_len = u32::try_from(op_len).unwrap();
557 let op = this.read_immediate(&this.mplace_index(&op, i.into())?.into())?;
558 if simd_element_to_bool(op)? {
560 .checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
564 this.write_int(res, dest)?;
567 name => throw_unsup_format!("unimplemented intrinsic: `simd_{name}`"),
573 fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
574 // SIMD uses all-1 as pattern for "true"
575 let val = if b { -1 } else { 0 };
576 Scalar::from_int(val, size)
579 fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
580 let val = elem.to_scalar().to_int(elem.layout.size)?;
584 _ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
588 fn simd_bitmask_index(idx: u32, vec_len: u32, endianess: Endian) -> u32 {
589 assert!(idx < vec_len);
591 Endian::Little => idx,
592 #[allow(clippy::integer_arithmetic)] // idx < vec_len
593 Endian::Big => vec_len - 1 - idx, // reverse order of bits
598 left: &ImmTy<'tcx, Provenance>,
599 right: &ImmTy<'tcx, Provenance>,
600 ) -> InterpResult<'tcx, Scalar<Provenance>> {
601 assert_eq!(left.layout.ty, right.layout.ty);
602 let ty::Float(float_ty) = left.layout.ty.kind() else {
603 bug!("fmax operand is not a float")
605 let left = left.to_scalar();
606 let right = right.to_scalar();
608 FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
609 FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
614 left: &ImmTy<'tcx, Provenance>,
615 right: &ImmTy<'tcx, Provenance>,
616 ) -> InterpResult<'tcx, Scalar<Provenance>> {
617 assert_eq!(left.layout.ty, right.layout.ty);
618 let ty::Float(float_ty) = left.layout.ty.kind() else {
619 bug!("fmin operand is not a float")
621 let left = left.to_scalar();
622 let right = right.to_scalar();
624 FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
625 FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),