]> git.lizzy.rs Git - rust.git/blob - crates/hir-ty/src/consteval.rs
feat: support negative const generic parameters
[rust.git] / crates / hir-ty / src / consteval.rs
1 //! Constant evaluation details
2
3 use std::{
4     collections::HashMap,
5     convert::TryInto,
6     fmt::{Display, Write},
7 };
8
9 use chalk_ir::{BoundVar, DebruijnIndex, GenericArgData, IntTy, Scalar};
10 use hir_def::{
11     expr::{ArithOp, BinaryOp, Expr, ExprId, Literal, Pat, PatId},
12     path::ModPath,
13     resolver::{resolver_for_expr, ResolveValueResult, Resolver, ValueNs},
14     type_ref::ConstScalar,
15     ConstId, DefWithBodyId,
16 };
17 use la_arena::{Arena, Idx};
18 use stdx::never;
19
20 use crate::{
21     db::HirDatabase, infer::InferenceContext, lower::ParamLoweringMode, to_placeholder_idx,
22     utils::Generics, Const, ConstData, ConstValue, GenericArg, InferenceResult, Interner, Ty,
23     TyBuilder, TyKind,
24 };
25
/// Convenience queries on [`Const`] values.
pub trait ConstExt {
    /// Reports whether this constant's value is not known (i.e. it is the
    /// [`ConstScalar::Unknown`] marker).
    fn is_unknown(&self) -> bool;
}
31
32 impl ConstExt for Const {
33     fn is_unknown(&self) -> bool {
34         match self.data(Interner).value {
35             // interned Unknown
36             chalk_ir::ConstValue::Concrete(chalk_ir::ConcreteConst {
37                 interned: ConstScalar::Unknown,
38             }) => true,
39
40             // interned concrete anything else
41             chalk_ir::ConstValue::Concrete(..) => false,
42
43             _ => {
44                 tracing::error!(
45                     "is_unknown was called on a non-concrete constant value! {:?}",
46                     self
47                 );
48                 true
49             }
50         }
51     }
52 }
53
54 pub struct ConstEvalCtx<'a> {
55     pub db: &'a dyn HirDatabase,
56     pub owner: DefWithBodyId,
57     pub exprs: &'a Arena<Expr>,
58     pub pats: &'a Arena<Pat>,
59     pub local_data: HashMap<PatId, ComputedExpr>,
60     infer: &'a InferenceResult,
61 }
62
63 impl ConstEvalCtx<'_> {
64     fn expr_ty(&mut self, expr: ExprId) -> Ty {
65         self.infer[expr].clone()
66     }
67 }
68
/// Reasons why [`eval_const`] could not produce a value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConstEvalError {
    /// The expression uses a construct the evaluator does not handle; the payload names it.
    NotSupported(&'static str),
    /// The expression is ill-formed, e.g. an unresolved path or a logic operator on numbers.
    SemanticError(&'static str),
    /// Evaluation recursed into itself (returned by `const_eval_recover`).
    Loop,
    /// The expression or one of its operators is missing from the body.
    IncompleteExpr,
    /// Evaluation would panic at runtime, e.g. on arithmetic overflow.
    Panic(String),
}
77
78 #[derive(Debug, Clone, PartialEq, Eq)]
79 pub enum ComputedExpr {
80     Literal(Literal),
81     Tuple(Box<[ComputedExpr]>),
82 }
83
84 impl Display for ComputedExpr {
85     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86         match self {
87             ComputedExpr::Literal(l) => match l {
88                 Literal::Int(x, _) => {
89                     if *x >= 10 {
90                         write!(f, "{} ({:#X})", x, x)
91                     } else {
92                         x.fmt(f)
93                     }
94                 }
95                 Literal::Uint(x, _) => {
96                     if *x >= 10 {
97                         write!(f, "{} ({:#X})", x, x)
98                     } else {
99                         x.fmt(f)
100                     }
101                 }
102                 Literal::Float(x, _) => x.fmt(f),
103                 Literal::Bool(x) => x.fmt(f),
104                 Literal::Char(x) => std::fmt::Debug::fmt(x, f),
105                 Literal::String(x) => std::fmt::Debug::fmt(x, f),
106                 Literal::ByteString(x) => std::fmt::Debug::fmt(x, f),
107             },
108             ComputedExpr::Tuple(t) => {
109                 f.write_char('(')?;
110                 for x in &**t {
111                     x.fmt(f)?;
112                     f.write_str(", ")?;
113                 }
114                 f.write_char(')')
115             }
116         }
117     }
118 }
119
120 fn scalar_max(scalar: &Scalar) -> i128 {
121     match scalar {
122         Scalar::Bool => 1,
123         Scalar::Char => u32::MAX as i128,
124         Scalar::Int(x) => match x {
125             IntTy::Isize => isize::MAX as i128,
126             IntTy::I8 => i8::MAX as i128,
127             IntTy::I16 => i16::MAX as i128,
128             IntTy::I32 => i32::MAX as i128,
129             IntTy::I64 => i64::MAX as i128,
130             IntTy::I128 => i128::MAX as i128,
131         },
132         Scalar::Uint(x) => match x {
133             chalk_ir::UintTy::Usize => usize::MAX as i128,
134             chalk_ir::UintTy::U8 => u8::MAX as i128,
135             chalk_ir::UintTy::U16 => u16::MAX as i128,
136             chalk_ir::UintTy::U32 => u32::MAX as i128,
137             chalk_ir::UintTy::U64 => u64::MAX as i128,
138             chalk_ir::UintTy::U128 => i128::MAX as i128, // ignore too big u128 for now
139         },
140         Scalar::Float(_) => 0,
141     }
142 }
143
144 fn is_valid(scalar: &Scalar, value: i128) -> bool {
145     if value < 0 {
146         !matches!(scalar, Scalar::Uint(_)) && -scalar_max(scalar) - 1 <= value
147     } else {
148         value <= scalar_max(scalar)
149     }
150 }
151
152 pub fn eval_const(
153     expr_id: ExprId,
154     ctx: &mut ConstEvalCtx<'_>,
155 ) -> Result<ComputedExpr, ConstEvalError> {
156     let expr = &ctx.exprs[expr_id];
157     match expr {
158         Expr::Missing => Err(ConstEvalError::IncompleteExpr),
159         Expr::Literal(l) => Ok(ComputedExpr::Literal(l.clone())),
160         &Expr::UnaryOp { expr, op } => {
161             let ty = &ctx.expr_ty(expr);
162             let ev = eval_const(expr, ctx)?;
163             match op {
164                 hir_def::expr::UnaryOp::Deref => Err(ConstEvalError::NotSupported("deref")),
165                 hir_def::expr::UnaryOp::Not => {
166                     let v = match ev {
167                         ComputedExpr::Literal(Literal::Bool(b)) => {
168                             return Ok(ComputedExpr::Literal(Literal::Bool(!b)))
169                         }
170                         ComputedExpr::Literal(Literal::Int(v, _)) => v,
171                         ComputedExpr::Literal(Literal::Uint(v, _)) => v
172                             .try_into()
173                             .map_err(|_| ConstEvalError::NotSupported("too big u128"))?,
174                         _ => return Err(ConstEvalError::NotSupported("this kind of operator")),
175                     };
176                     let r = match ty.kind(Interner) {
177                         TyKind::Scalar(Scalar::Uint(x)) => match x {
178                             chalk_ir::UintTy::U8 => !(v as u8) as i128,
179                             chalk_ir::UintTy::U16 => !(v as u16) as i128,
180                             chalk_ir::UintTy::U32 => !(v as u32) as i128,
181                             chalk_ir::UintTy::U64 => !(v as u64) as i128,
182                             chalk_ir::UintTy::U128 => {
183                                 return Err(ConstEvalError::NotSupported("negation of u128"))
184                             }
185                             chalk_ir::UintTy::Usize => !(v as usize) as i128,
186                         },
187                         TyKind::Scalar(Scalar::Int(x)) => match x {
188                             chalk_ir::IntTy::I8 => !(v as i8) as i128,
189                             chalk_ir::IntTy::I16 => !(v as i16) as i128,
190                             chalk_ir::IntTy::I32 => !(v as i32) as i128,
191                             chalk_ir::IntTy::I64 => !(v as i64) as i128,
192                             chalk_ir::IntTy::I128 => !v,
193                             chalk_ir::IntTy::Isize => !(v as isize) as i128,
194                         },
195                         _ => return Err(ConstEvalError::NotSupported("unreachable?")),
196                     };
197                     Ok(ComputedExpr::Literal(Literal::Int(r, None)))
198                 }
199                 hir_def::expr::UnaryOp::Neg => {
200                     let v = match ev {
201                         ComputedExpr::Literal(Literal::Int(v, _)) => v,
202                         ComputedExpr::Literal(Literal::Uint(v, _)) => v
203                             .try_into()
204                             .map_err(|_| ConstEvalError::NotSupported("too big u128"))?,
205                         _ => return Err(ConstEvalError::NotSupported("this kind of operator")),
206                     };
207                     Ok(ComputedExpr::Literal(Literal::Int(
208                         v.checked_neg().ok_or_else(|| {
209                             ConstEvalError::Panic("overflow in negation".to_string())
210                         })?,
211                         None,
212                     )))
213                 }
214             }
215         }
216         &Expr::BinaryOp { lhs, rhs, op } => {
217             let ty = &ctx.expr_ty(lhs);
218             let lhs = eval_const(lhs, ctx)?;
219             let rhs = eval_const(rhs, ctx)?;
220             let op = op.ok_or(ConstEvalError::IncompleteExpr)?;
221             let v1 = match lhs {
222                 ComputedExpr::Literal(Literal::Int(v, _)) => v,
223                 ComputedExpr::Literal(Literal::Uint(v, _)) => {
224                     v.try_into().map_err(|_| ConstEvalError::NotSupported("too big u128"))?
225                 }
226                 _ => return Err(ConstEvalError::NotSupported("this kind of operator")),
227             };
228             let v2 = match rhs {
229                 ComputedExpr::Literal(Literal::Int(v, _)) => v,
230                 ComputedExpr::Literal(Literal::Uint(v, _)) => {
231                     v.try_into().map_err(|_| ConstEvalError::NotSupported("too big u128"))?
232                 }
233                 _ => return Err(ConstEvalError::NotSupported("this kind of operator")),
234             };
235             match op {
236                 BinaryOp::ArithOp(b) => {
237                     let panic_arith = ConstEvalError::Panic(
238                         "attempt to run invalid arithmetic operation".to_string(),
239                     );
240                     let r = match b {
241                         ArithOp::Add => v1.checked_add(v2).ok_or_else(|| panic_arith.clone())?,
242                         ArithOp::Mul => v1.checked_mul(v2).ok_or_else(|| panic_arith.clone())?,
243                         ArithOp::Sub => v1.checked_sub(v2).ok_or_else(|| panic_arith.clone())?,
244                         ArithOp::Div => v1.checked_div(v2).ok_or_else(|| panic_arith.clone())?,
245                         ArithOp::Rem => v1.checked_rem(v2).ok_or_else(|| panic_arith.clone())?,
246                         ArithOp::Shl => v1
247                             .checked_shl(v2.try_into().map_err(|_| panic_arith.clone())?)
248                             .ok_or_else(|| panic_arith.clone())?,
249                         ArithOp::Shr => v1
250                             .checked_shr(v2.try_into().map_err(|_| panic_arith.clone())?)
251                             .ok_or_else(|| panic_arith.clone())?,
252                         ArithOp::BitXor => v1 ^ v2,
253                         ArithOp::BitOr => v1 | v2,
254                         ArithOp::BitAnd => v1 & v2,
255                     };
256                     if let TyKind::Scalar(s) = ty.kind(Interner) {
257                         if !is_valid(s, r) {
258                             return Err(panic_arith);
259                         }
260                     }
261                     Ok(ComputedExpr::Literal(Literal::Int(r, None)))
262                 }
263                 BinaryOp::LogicOp(_) => Err(ConstEvalError::SemanticError("logic op on numbers")),
264                 _ => Err(ConstEvalError::NotSupported("bin op on this operators")),
265             }
266         }
267         Expr::Block { statements, tail, .. } => {
268             let mut prev_values = HashMap::<PatId, Option<ComputedExpr>>::default();
269             for statement in &**statements {
270                 match *statement {
271                     hir_def::expr::Statement::Let { pat: pat_id, initializer, .. } => {
272                         let pat = &ctx.pats[pat_id];
273                         match pat {
274                             Pat::Bind { subpat, .. } if subpat.is_none() => (),
275                             _ => {
276                                 return Err(ConstEvalError::NotSupported("complex patterns in let"))
277                             }
278                         };
279                         let value = match initializer {
280                             Some(x) => eval_const(x, ctx)?,
281                             None => continue,
282                         };
283                         if !prev_values.contains_key(&pat_id) {
284                             let prev = ctx.local_data.insert(pat_id, value);
285                             prev_values.insert(pat_id, prev);
286                         } else {
287                             ctx.local_data.insert(pat_id, value);
288                         }
289                     }
290                     hir_def::expr::Statement::Expr { .. } => {
291                         return Err(ConstEvalError::NotSupported("this kind of statement"))
292                     }
293                 }
294             }
295             let r = match tail {
296                 &Some(x) => eval_const(x, ctx),
297                 None => Ok(ComputedExpr::Tuple(Box::new([]))),
298             };
299             // clean up local data, so caller will receive the exact map that passed to us
300             for (name, val) in prev_values {
301                 match val {
302                     Some(x) => ctx.local_data.insert(name, x),
303                     None => ctx.local_data.remove(&name),
304                 };
305             }
306             r
307         }
308         Expr::Path(p) => {
309             let resolver = resolver_for_expr(ctx.db.upcast(), ctx.owner, expr_id);
310             let pr = resolver
311                 .resolve_path_in_value_ns(ctx.db.upcast(), p.mod_path())
312                 .ok_or(ConstEvalError::SemanticError("unresolved path"))?;
313             let pr = match pr {
314                 ResolveValueResult::ValueNs(v) => v,
315                 ResolveValueResult::Partial(..) => {
316                     return match ctx
317                         .infer
318                         .assoc_resolutions_for_expr(expr_id)
319                         .ok_or(ConstEvalError::SemanticError("unresolved assoc item"))?
320                     {
321                         hir_def::AssocItemId::FunctionId(_) => {
322                             Err(ConstEvalError::NotSupported("assoc function"))
323                         }
324                         hir_def::AssocItemId::ConstId(c) => ctx.db.const_eval(c),
325                         hir_def::AssocItemId::TypeAliasId(_) => {
326                             Err(ConstEvalError::NotSupported("assoc type alias"))
327                         }
328                     }
329                 }
330             };
331             match pr {
332                 ValueNs::LocalBinding(pat_id) => {
333                     let r = ctx
334                         .local_data
335                         .get(&pat_id)
336                         .ok_or(ConstEvalError::NotSupported("Unexpected missing local"))?;
337                     Ok(r.clone())
338                 }
339                 ValueNs::ConstId(id) => ctx.db.const_eval(id),
340                 ValueNs::GenericParam(_) => {
341                     Err(ConstEvalError::NotSupported("const generic without substitution"))
342                 }
343                 _ => Err(ConstEvalError::NotSupported("path that are not const or local")),
344             }
345         }
346         _ => Err(ConstEvalError::NotSupported("This kind of expression")),
347     }
348 }
349
350 pub(crate) fn path_to_const(
351     db: &dyn HirDatabase,
352     resolver: &Resolver,
353     path: &ModPath,
354     mode: ParamLoweringMode,
355     args_lazy: impl FnOnce() -> Generics,
356     debruijn: DebruijnIndex,
357 ) -> Option<Const> {
358     match resolver.resolve_path_in_value_ns_fully(db.upcast(), &path) {
359         Some(ValueNs::GenericParam(p)) => {
360             let ty = db.const_param_ty(p);
361             let args = args_lazy();
362             let value = match mode {
363                 ParamLoweringMode::Placeholder => {
364                     ConstValue::Placeholder(to_placeholder_idx(db, p.into()))
365                 }
366                 ParamLoweringMode::Variable => match args.param_idx(p.into()) {
367                     Some(x) => ConstValue::BoundVar(BoundVar::new(debruijn, x)),
368                     None => {
369                         never!(
370                             "Generic list doesn't contain this param: {:?}, {}, {:?}",
371                             args,
372                             path,
373                             p
374                         );
375                         return None;
376                     }
377                 },
378             };
379             Some(ConstData { ty, value }.intern(Interner))
380         }
381         _ => None,
382     }
383 }
384
385 pub fn unknown_const(ty: Ty) -> Const {
386     ConstData {
387         ty,
388         value: ConstValue::Concrete(chalk_ir::ConcreteConst { interned: ConstScalar::Unknown }),
389     }
390     .intern(Interner)
391 }
392
393 pub fn unknown_const_as_generic(ty: Ty) -> GenericArg {
394     GenericArgData::Const(unknown_const(ty)).intern(Interner)
395 }
396
397 /// Interns a constant scalar with the given type
398 pub fn intern_const_scalar_with_type(value: ConstScalar, ty: Ty) -> Const {
399     ConstData { ty, value: ConstValue::Concrete(chalk_ir::ConcreteConst { interned: value }) }
400         .intern(Interner)
401 }
402
403 /// Interns a possibly-unknown target usize
404 pub fn usize_const(value: Option<u128>) -> Const {
405     intern_const_scalar_with_type(
406         value.map(ConstScalar::UInt).unwrap_or(ConstScalar::Unknown),
407         TyBuilder::usize(),
408     )
409 }
410
411 /// Interns a constant scalar with the default type
412 pub fn intern_const_scalar(value: ConstScalar) -> Const {
413     intern_const_scalar_with_type(value, TyBuilder::builtin(value.builtin_type()))
414 }
415
416 pub(crate) fn const_eval_recover(
417     _: &dyn HirDatabase,
418     _: &[String],
419     _: &ConstId,
420 ) -> Result<ComputedExpr, ConstEvalError> {
421     Err(ConstEvalError::Loop)
422 }
423
424 pub(crate) fn const_eval_query(
425     db: &dyn HirDatabase,
426     const_id: ConstId,
427 ) -> Result<ComputedExpr, ConstEvalError> {
428     let def = const_id.into();
429     let body = db.body(def);
430     let infer = &db.infer(def);
431     let result = eval_const(
432         body.body_expr,
433         &mut ConstEvalCtx {
434             db,
435             owner: const_id.into(),
436             exprs: &body.exprs,
437             pats: &body.pats,
438             local_data: HashMap::default(),
439             infer,
440         },
441     );
442     result
443 }
444
445 pub(crate) fn eval_to_const<'a>(
446     expr: Idx<Expr>,
447     mode: ParamLoweringMode,
448     ctx: &mut InferenceContext<'a>,
449     args: impl FnOnce() -> Generics,
450     debruijn: DebruijnIndex,
451 ) -> Const {
452     if let Expr::Path(p) = &ctx.body.exprs[expr] {
453         let db = ctx.db;
454         let resolver = &ctx.resolver;
455         if let Some(c) = path_to_const(db, resolver, p.mod_path(), mode, args, debruijn) {
456             return c;
457         }
458     }
459     let body = ctx.body.clone();
460     let mut ctx = ConstEvalCtx {
461         db: ctx.db,
462         owner: ctx.owner,
463         exprs: &body.exprs,
464         pats: &body.pats,
465         local_data: HashMap::default(),
466         infer: &ctx.result,
467     };
468     let computed_expr = eval_const(expr, &mut ctx);
469     let const_scalar = match computed_expr {
470         Ok(ComputedExpr::Literal(literal)) => literal.into(),
471         _ => ConstScalar::Unknown,
472     };
473     intern_const_scalar_with_type(const_scalar, TyBuilder::usize())
474 }
475
476 #[cfg(test)]
477 mod tests;