]> git.lizzy.rs Git - rust.git/blob - clippy_lints/src/mul_add.rs
MutImmutable -> Immutable, MutMutable -> Mutable, CaptureClause -> CaptureBy
[rust.git] / clippy_lints / src / mul_add.rs
1 use rustc::hir::*;
2 use rustc::lint::{LateContext, LateLintPass, LintArray, LintPass};
3 use rustc::{declare_lint_pass, declare_tool_lint};
4 use rustc_errors::Applicability;
5
6 use crate::utils::*;
7
declare_clippy_lint! {
    /// **What it does:** Checks for expressions of the form `a * b + c`
    /// or `c + a * b` where `a`, `b`, `c` are floats and suggests using
    /// `a.mul_add(b, c)` instead.
    ///
    /// **Why is this bad?** Calculating `a * b + c` may lead to slight
    /// numerical inaccuracies as `a * b` is rounded before being added to
    /// `c`. Depending on the target architecture, `mul_add()` may be more
    /// performant.
    ///
    /// **Known problems:** This lint can emit semantically incorrect
    /// suggestions. For example, for `a * b * c + d` the suggestion
    /// `a * b.mul_add(c, d)` is emitted, which is equivalent to
    /// `a * (b * c + d)`. (#4735)
    ///
    /// **Example:**
    ///
    /// ```rust
    /// # let a = 0_f32;
    /// # let b = 0_f32;
    /// # let c = 0_f32;
    /// let foo = (a * b) + c;
    /// ```
    ///
    /// can be written as
    ///
    /// ```rust
    /// # let a = 0_f32;
    /// # let b = 0_f32;
    /// # let c = 0_f32;
    /// let foo = a.mul_add(b, c);
    /// ```
    pub MANUAL_MUL_ADD,
    nursery,
    "Using `a.mul_add(b, c)` for floating points has higher numerical precision than `a * b + c`"
}
43
// Declares the `MulAddCheck` lint pass and lists `MANUAL_MUL_ADD` as its lint.
declare_lint_pass!(MulAddCheck => [MANUAL_MUL_ADD]);
45
46 fn is_float<'a, 'tcx>(cx: &LateContext<'a, 'tcx>, expr: &Expr) -> bool {
47     cx.tables.expr_ty(expr).is_floating_point()
48 }
49
50 // Checks whether expression is multiplication of two floats
51 fn is_float_mult_expr<'a, 'tcx, 'b>(cx: &LateContext<'a, 'tcx>, expr: &'b Expr) -> Option<(&'b Expr, &'b Expr)> {
52     if let ExprKind::Binary(op, lhs, rhs) = &expr.kind {
53         if let BinOpKind::Mul = op.node {
54             if is_float(cx, &lhs) && is_float(cx, &rhs) {
55                 return Some((&lhs, &rhs));
56             }
57         }
58     }
59
60     None
61 }
62
63 impl<'a, 'tcx> LateLintPass<'a, 'tcx> for MulAddCheck {
64     fn check_expr(&mut self, cx: &LateContext<'a, 'tcx>, expr: &'tcx Expr) {
65         if let ExprKind::Binary(op, lhs, rhs) = &expr.kind {
66             if let BinOpKind::Add = op.node {
67                 //Converts mult_lhs * mult_rhs + rhs to mult_lhs.mult_add(mult_rhs, rhs)
68                 if let Some((mult_lhs, mult_rhs)) = is_float_mult_expr(cx, lhs) {
69                     if is_float(cx, rhs) {
70                         span_lint_and_sugg(
71                             cx,
72                             MANUAL_MUL_ADD,
73                             expr.span,
74                             "consider using mul_add() for better numerical precision",
75                             "try",
76                             format!(
77                                 "{}.mul_add({}, {})",
78                                 snippet(cx, mult_lhs.span, "_"),
79                                 snippet(cx, mult_rhs.span, "_"),
80                                 snippet(cx, rhs.span, "_"),
81                             ),
82                             Applicability::MaybeIncorrect,
83                         );
84                     }
85                 }
86                 //Converts lhs + mult_lhs * mult_rhs to mult_lhs.mult_add(mult_rhs, lhs)
87                 if let Some((mult_lhs, mult_rhs)) = is_float_mult_expr(cx, rhs) {
88                     if is_float(cx, lhs) {
89                         span_lint_and_sugg(
90                             cx,
91                             MANUAL_MUL_ADD,
92                             expr.span,
93                             "consider using mul_add() for better numerical precision",
94                             "try",
95                             format!(
96                                 "{}.mul_add({}, {})",
97                                 snippet(cx, mult_lhs.span, "_"),
98                                 snippet(cx, mult_rhs.span, "_"),
99                                 snippet(cx, lhs.span, "_"),
100                             ),
101                             Applicability::MaybeIncorrect,
102                         );
103                     }
104                 }
105             }
106         }
107     }
108 }