]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_const_eval/src/const_eval/valtrees.rs
Rollup merge of #106897 - estebank:issue-99430, r=davidtwco
[rust.git] / compiler / rustc_const_eval / src / const_eval / valtrees.rs
1 use super::eval_queries::{mk_eval_cx, op_to_const};
2 use super::machine::CompileTimeEvalContext;
3 use super::{ValTreeCreationError, ValTreeCreationResult, VALTREE_MAX_NODES};
4 use crate::interpret::{
5     intern_const_alloc_recursive, ConstValue, ImmTy, Immediate, InternKind, MemPlaceMeta,
6     MemoryKind, PlaceTy, Scalar,
7 };
8 use crate::interpret::{MPlaceTy, Value};
9 use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
10 use rustc_span::source_map::DUMMY_SP;
11 use rustc_target::abi::{Align, VariantIdx};
12
13 #[instrument(skip(ecx), level = "debug")]
14 fn branches<'tcx>(
15     ecx: &CompileTimeEvalContext<'tcx, 'tcx>,
16     place: &MPlaceTy<'tcx>,
17     n: usize,
18     variant: Option<VariantIdx>,
19     num_nodes: &mut usize,
20 ) -> ValTreeCreationResult<'tcx> {
21     let place = match variant {
22         Some(variant) => ecx.mplace_downcast(&place, variant).unwrap(),
23         None => *place,
24     };
25     let variant = variant.map(|variant| Some(ty::ValTree::Leaf(ScalarInt::from(variant.as_u32()))));
26     debug!(?place, ?variant);
27
28     let mut fields = Vec::with_capacity(n);
29     for i in 0..n {
30         let field = ecx.mplace_field(&place, i).unwrap();
31         let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
32         fields.push(Some(valtree));
33     }
34
35     // For enums, we prepend their variant index before the variant's fields so we can figure out
36     // the variant again when just seeing a valtree.
37     let branches = variant
38         .into_iter()
39         .chain(fields.into_iter())
40         .collect::<Option<Vec<_>>>()
41         .expect("should have already checked for errors in ValTree creation");
42
43     // Have to account for ZSTs here
44     if branches.len() == 0 {
45         *num_nodes += 1;
46     }
47
48     Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(branches)))
49 }
50
51 #[instrument(skip(ecx), level = "debug")]
52 fn slice_branches<'tcx>(
53     ecx: &CompileTimeEvalContext<'tcx, 'tcx>,
54     place: &MPlaceTy<'tcx>,
55     num_nodes: &mut usize,
56 ) -> ValTreeCreationResult<'tcx> {
57     let n = place
58         .len(&ecx.tcx.tcx)
59         .unwrap_or_else(|_| panic!("expected to use len of place {:?}", place));
60
61     let mut elems = Vec::with_capacity(n as usize);
62     for i in 0..n {
63         let place_elem = ecx.mplace_index(place, i).unwrap();
64         let valtree = const_to_valtree_inner(ecx, &place_elem, num_nodes)?;
65         elems.push(valtree);
66     }
67
68     Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(elems)))
69 }
70
71 #[instrument(skip(ecx), level = "debug")]
72 pub(crate) fn const_to_valtree_inner<'tcx>(
73     ecx: &CompileTimeEvalContext<'tcx, 'tcx>,
74     place: &MPlaceTy<'tcx>,
75     num_nodes: &mut usize,
76 ) -> ValTreeCreationResult<'tcx> {
77     let ty = place.layout.ty;
78     debug!("ty kind: {:?}", ty.kind());
79
80     if *num_nodes >= VALTREE_MAX_NODES {
81         return Err(ValTreeCreationError::NodesOverflow);
82     }
83
84     match ty.kind() {
85         ty::FnDef(..) => {
86             *num_nodes += 1;
87             Ok(ty::ValTree::zst())
88         }
89         ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
90             let Ok(val) = ecx.read_immediate(&place.into()) else {
91                 return Err(ValTreeCreationError::Other);
92             };
93             let val = val.to_scalar();
94             *num_nodes += 1;
95
96             Ok(ty::ValTree::Leaf(val.assert_int()))
97         }
98
99         // Raw pointers are not allowed in type level constants, as we cannot properly test them for
100         // equality at compile-time (see `ptr_guaranteed_cmp`).
101         // Technically we could allow function pointers (represented as `ty::Instance`), but this is not guaranteed to
102         // agree with runtime equality tests.
103         ty::FnPtr(_) | ty::RawPtr(_) => Err(ValTreeCreationError::NonSupportedType),
104
105         ty::Ref(_, _, _)  => {
106             let Ok(derefd_place)= ecx.deref_operand(&place.into()) else {
107                 return Err(ValTreeCreationError::Other);
108             };
109             debug!(?derefd_place);
110
111             const_to_valtree_inner(ecx, &derefd_place, num_nodes)
112         }
113
114         ty::Str | ty::Slice(_) | ty::Array(_, _) => {
115             slice_branches(ecx, place, num_nodes)
116         }
117         // Trait objects are not allowed in type level constants, as we have no concept for
118         // resolving their backing type, even if we can do that at const eval time. We may
119         // hypothetically be able to allow `dyn StructuralEq` trait objects in the future,
120         // but it is unclear if this is useful.
121         ty::Dynamic(..) => Err(ValTreeCreationError::NonSupportedType),
122
123         ty::Tuple(elem_tys) => {
124             branches(ecx, place, elem_tys.len(), None, num_nodes)
125         }
126
127         ty::Adt(def, _) => {
128             if def.is_union() {
129                 return Err(ValTreeCreationError::NonSupportedType);
130             } else if def.variants().is_empty() {
131                 bug!("uninhabited types should have errored and never gotten converted to valtree")
132             }
133
134             let Ok((_, variant)) = ecx.read_discriminant(&place.into()) else {
135                 return Err(ValTreeCreationError::Other);
136             };
137             branches(ecx, place, def.variant(variant).fields.len(), def.is_enum().then_some(variant), num_nodes)
138         }
139
140         ty::Never
141         | ty::Error(_)
142         | ty::Foreign(..)
143         | ty::Infer(ty::FreshIntTy(_))
144         | ty::Infer(ty::FreshFloatTy(_))
145         // FIXME(oli-obk): we could look behind opaque types
146         | ty::Alias(..)
147         | ty::Param(_)
148         | ty::Bound(..)
149         | ty::Placeholder(..)
150         | ty::Infer(_)
151         // FIXME(oli-obk): we can probably encode closures just like structs
152         | ty::Closure(..)
153         | ty::Generator(..)
154         | ty::GeneratorWitness(..) => Err(ValTreeCreationError::NonSupportedType),
155     }
156 }
157
158 #[instrument(skip(ecx), level = "debug")]
159 fn create_mplace_from_layout<'tcx>(
160     ecx: &mut CompileTimeEvalContext<'tcx, 'tcx>,
161     ty: Ty<'tcx>,
162 ) -> MPlaceTy<'tcx> {
163     let tcx = ecx.tcx;
164     let param_env = ecx.param_env;
165     let layout = tcx.layout_of(param_env.and(ty)).unwrap();
166     debug!(?layout);
167
168     ecx.allocate(layout, MemoryKind::Stack).unwrap()
169 }
170
171 // Walks custom DSTs and gets the type of the unsized field and the number of elements
172 // in the unsized field.
173 fn get_info_on_unsized_field<'tcx>(
174     ty: Ty<'tcx>,
175     valtree: ty::ValTree<'tcx>,
176     tcx: TyCtxt<'tcx>,
177 ) -> (Ty<'tcx>, usize) {
178     let mut last_valtree = valtree;
179     let tail = tcx.struct_tail_with_normalize(
180         ty,
181         |ty| ty,
182         || {
183             let branches = last_valtree.unwrap_branch();
184             last_valtree = branches[branches.len() - 1];
185             debug!(?branches, ?last_valtree);
186         },
187     );
188     let unsized_inner_ty = match tail.kind() {
189         ty::Slice(t) => *t,
190         ty::Str => tail,
191         _ => bug!("expected Slice or Str"),
192     };
193
194     // Have to adjust type for ty::Str
195     let unsized_inner_ty = match unsized_inner_ty.kind() {
196         ty::Str => tcx.mk_ty(ty::Uint(ty::UintTy::U8)),
197         _ => unsized_inner_ty,
198     };
199
200     // Get the number of elements in the unsized field
201     let num_elems = last_valtree.unwrap_branch().len();
202
203     (unsized_inner_ty, num_elems)
204 }
205
206 #[instrument(skip(ecx), level = "debug", ret)]
207 fn create_pointee_place<'tcx>(
208     ecx: &mut CompileTimeEvalContext<'tcx, 'tcx>,
209     ty: Ty<'tcx>,
210     valtree: ty::ValTree<'tcx>,
211 ) -> MPlaceTy<'tcx> {
212     let tcx = ecx.tcx.tcx;
213
214     if !ty.is_sized(*ecx.tcx, ty::ParamEnv::empty()) {
215         // We need to create `Allocation`s for custom DSTs
216
217         let (unsized_inner_ty, num_elems) = get_info_on_unsized_field(ty, valtree, tcx);
218         let unsized_inner_ty = match unsized_inner_ty.kind() {
219             ty::Str => tcx.mk_ty(ty::Uint(ty::UintTy::U8)),
220             _ => unsized_inner_ty,
221         };
222         let unsized_inner_ty_size =
223             tcx.layout_of(ty::ParamEnv::empty().and(unsized_inner_ty)).unwrap().layout.size();
224         debug!(?unsized_inner_ty, ?unsized_inner_ty_size, ?num_elems);
225
226         // for custom DSTs only the last field/element is unsized, but we need to also allocate
227         // space for the other fields/elements
228         let layout = tcx.layout_of(ty::ParamEnv::empty().and(ty)).unwrap();
229         let size_of_sized_part = layout.layout.size();
230
231         // Get the size of the memory behind the DST
232         let dst_size = unsized_inner_ty_size.checked_mul(num_elems as u64, &tcx).unwrap();
233
234         let size = size_of_sized_part.checked_add(dst_size, &tcx).unwrap();
235         let align = Align::from_bytes(size.bytes().next_power_of_two()).unwrap();
236         let ptr = ecx.allocate_ptr(size, align, MemoryKind::Stack).unwrap();
237         debug!(?ptr);
238
239         MPlaceTy::from_aligned_ptr_with_meta(
240             ptr.into(),
241             layout,
242             MemPlaceMeta::Meta(Scalar::from_machine_usize(num_elems as u64, &tcx)),
243         )
244     } else {
245         create_mplace_from_layout(ecx, ty)
246     }
247 }
248
249 /// Converts a `ValTree` to a `ConstValue`, which is needed after mir
250 /// construction has finished.
251 // FIXME Merge `valtree_to_const_value` and `valtree_into_mplace` into one function
252 #[instrument(skip(tcx), level = "debug", ret)]
253 pub fn valtree_to_const_value<'tcx>(
254     tcx: TyCtxt<'tcx>,
255     param_env_ty: ty::ParamEnvAnd<'tcx, Ty<'tcx>>,
256     valtree: ty::ValTree<'tcx>,
257 ) -> ConstValue<'tcx> {
258     // Basic idea: We directly construct `Scalar` values from trivial `ValTree`s
259     // (those for constants with type bool, int, uint, float or char).
260     // For all other types we create an `MPlace` and fill that by walking
261     // the `ValTree` and using `place_projection` and `place_field` to
262     // create inner `MPlace`s which are filled recursively.
263     // FIXME Does this need an example?
264
265     let (param_env, ty) = param_env_ty.into_parts();
266     let mut ecx = mk_eval_cx(tcx, DUMMY_SP, param_env, false);
267
268     match ty.kind() {
269         ty::FnDef(..) => {
270             assert!(valtree.unwrap_branch().is_empty());
271             ConstValue::ZeroSized
272         }
273         ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => match valtree {
274             ty::ValTree::Leaf(scalar_int) => ConstValue::Scalar(Scalar::Int(scalar_int)),
275             ty::ValTree::Branch(_) => bug!(
276                 "ValTrees for Bool, Int, Uint, Float or Char should have the form ValTree::Leaf"
277             ),
278         },
279         ty::Ref(_, _, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Adt(..) => {
280             let mut place = match ty.kind() {
281                 ty::Ref(_, inner_ty, _) => {
282                     // Need to create a place for the pointee to fill for Refs
283                     create_pointee_place(&mut ecx, *inner_ty, valtree)
284                 }
285                 _ => create_mplace_from_layout(&mut ecx, ty),
286             };
287             debug!(?place);
288
289             valtree_into_mplace(&mut ecx, &mut place, valtree);
290             dump_place(&ecx, place.into());
291             intern_const_alloc_recursive(&mut ecx, InternKind::Constant, &place).unwrap();
292
293             match ty.kind() {
294                 ty::Ref(_, _, _) => {
295                     let ref_place = place.to_ref(&tcx);
296                     let imm =
297                         ImmTy::from_immediate(ref_place, tcx.layout_of(param_env_ty).unwrap());
298
299                     op_to_const(&ecx, &imm.into())
300                 }
301                 _ => op_to_const(&ecx, &place.into()),
302             }
303         }
304         ty::Never
305         | ty::Error(_)
306         | ty::Foreign(..)
307         | ty::Infer(ty::FreshIntTy(_))
308         | ty::Infer(ty::FreshFloatTy(_))
309         | ty::Alias(..)
310         | ty::Param(_)
311         | ty::Bound(..)
312         | ty::Placeholder(..)
313         | ty::Infer(_)
314         | ty::Closure(..)
315         | ty::Generator(..)
316         | ty::GeneratorWitness(..)
317         | ty::FnPtr(_)
318         | ty::RawPtr(_)
319         | ty::Str
320         | ty::Slice(_)
321         | ty::Dynamic(..) => bug!("no ValTree should have been created for type {:?}", ty.kind()),
322     }
323 }
324
325 #[instrument(skip(ecx), level = "debug")]
326 fn valtree_into_mplace<'tcx>(
327     ecx: &mut CompileTimeEvalContext<'tcx, 'tcx>,
328     place: &mut MPlaceTy<'tcx>,
329     valtree: ty::ValTree<'tcx>,
330 ) {
331     // This will match on valtree and write the value(s) corresponding to the ValTree
332     // inside the place recursively.
333
334     let tcx = ecx.tcx.tcx;
335     let ty = place.layout.ty;
336
337     match ty.kind() {
338         ty::FnDef(_, _) => {
339             ecx.write_immediate(Immediate::Uninit, &place.into()).unwrap();
340         }
341         ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
342             let scalar_int = valtree.unwrap_leaf();
343             debug!("writing trivial valtree {:?} to place {:?}", scalar_int, place);
344             ecx.write_immediate(Immediate::Scalar(scalar_int.into()), &place.into()).unwrap();
345         }
346         ty::Ref(_, inner_ty, _) => {
347             let mut pointee_place = create_pointee_place(ecx, *inner_ty, valtree);
348             debug!(?pointee_place);
349
350             valtree_into_mplace(ecx, &mut pointee_place, valtree);
351             dump_place(ecx, pointee_place.into());
352             intern_const_alloc_recursive(ecx, InternKind::Constant, &pointee_place).unwrap();
353
354             let imm = match inner_ty.kind() {
355                 ty::Slice(_) | ty::Str => {
356                     let len = valtree.unwrap_branch().len();
357                     let len_scalar = Scalar::from_machine_usize(len as u64, &tcx);
358
359                     Immediate::ScalarPair(
360                         Scalar::from_maybe_pointer((*pointee_place).ptr, &tcx),
361                         len_scalar,
362                     )
363                 }
364                 _ => pointee_place.to_ref(&tcx),
365             };
366             debug!(?imm);
367
368             ecx.write_immediate(imm, &place.into()).unwrap();
369         }
370         ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Str | ty::Slice(_) => {
371             let branches = valtree.unwrap_branch();
372
373             // Need to downcast place for enums
374             let (place_adjusted, branches, variant_idx) = match ty.kind() {
375                 ty::Adt(def, _) if def.is_enum() => {
376                     // First element of valtree corresponds to variant
377                     let scalar_int = branches[0].unwrap_leaf();
378                     let variant_idx = VariantIdx::from_u32(scalar_int.try_to_u32().unwrap());
379                     let variant = def.variant(variant_idx);
380                     debug!(?variant);
381
382                     (
383                         place.project_downcast(ecx, variant_idx).unwrap(),
384                         &branches[1..],
385                         Some(variant_idx),
386                     )
387                 }
388                 _ => (*place, branches, None),
389             };
390             debug!(?place_adjusted, ?branches);
391
392             // Create the places (by indexing into `place`) for the fields and fill
393             // them recursively
394             for (i, inner_valtree) in branches.iter().enumerate() {
395                 debug!(?i, ?inner_valtree);
396
397                 let mut place_inner = match ty.kind() {
398                     ty::Str | ty::Slice(_) => ecx.mplace_index(&place, i as u64).unwrap(),
399                     _ if !ty.is_sized(*ecx.tcx, ty::ParamEnv::empty())
400                         && i == branches.len() - 1 =>
401                     {
402                         // Note: For custom DSTs we need to manually process the last unsized field.
403                         // We created a `Pointer` for the `Allocation` of the complete sized version of
404                         // the Adt in `create_pointee_place` and now we fill that `Allocation` with the
405                         // values in the ValTree. For the unsized field we have to additionally add the meta
406                         // data.
407
408                         let (unsized_inner_ty, num_elems) =
409                             get_info_on_unsized_field(ty, valtree, tcx);
410                         debug!(?unsized_inner_ty);
411
412                         let inner_ty = match ty.kind() {
413                             ty::Adt(def, substs) => {
414                                 def.variant(VariantIdx::from_u32(0)).fields[i].ty(tcx, substs)
415                             }
416                             ty::Tuple(inner_tys) => inner_tys[i],
417                             _ => bug!("unexpected unsized type {:?}", ty),
418                         };
419
420                         let inner_layout =
421                             tcx.layout_of(ty::ParamEnv::empty().and(inner_ty)).unwrap();
422                         debug!(?inner_layout);
423
424                         let offset = place_adjusted.layout.fields.offset(i);
425                         place
426                             .offset_with_meta(
427                                 offset,
428                                 MemPlaceMeta::Meta(Scalar::from_machine_usize(
429                                     num_elems as u64,
430                                     &tcx,
431                                 )),
432                                 inner_layout,
433                                 &tcx,
434                             )
435                             .unwrap()
436                     }
437                     _ => ecx.mplace_field(&place_adjusted, i).unwrap(),
438                 };
439
440                 debug!(?place_inner);
441                 valtree_into_mplace(ecx, &mut place_inner, *inner_valtree);
442                 dump_place(&ecx, place_inner.into());
443             }
444
445             debug!("dump of place_adjusted:");
446             dump_place(ecx, place_adjusted.into());
447
448             if let Some(variant_idx) = variant_idx {
449                 // don't forget filling the place with the discriminant of the enum
450                 ecx.write_discriminant(variant_idx, &place.into()).unwrap();
451             }
452
453             debug!("dump of place after writing discriminant:");
454             dump_place(ecx, place.into());
455         }
456         _ => bug!("shouldn't have created a ValTree for {:?}", ty),
457     }
458 }
459
460 fn dump_place<'tcx>(ecx: &CompileTimeEvalContext<'tcx, 'tcx>, place: PlaceTy<'tcx>) {
461     trace!("{:?}", ecx.dump_place(*place));
462 }