]> git.lizzy.rs Git - rust.git/blob - src/bin/rustfmt-format-diff.rs
Do not print usage when rustfmt failed
[rust.git] / src / bin / rustfmt-format-diff.rs
1 // Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 // Inspired by Clang's clang-format-diff:
12 //
13 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
14
15 #![deny(warnings)]
16
17 extern crate env_logger;
18 extern crate getopts;
19 #[macro_use]
20 extern crate log;
21 extern crate regex;
22 #[macro_use]
23 extern crate serde_derive;
24 extern crate serde_json as json;
25
26 use std::{env, fmt, process};
27 use std::collections::HashSet;
28 use std::error::Error;
29 use std::io::{self, BufRead};
30
31 use regex::Regex;
32
33 /// The default pattern of files to format.
34 ///
35 /// We only want to format rust files by default.
36 const DEFAULT_PATTERN: &str = r".*\.rs";
37
38 #[derive(Debug)]
39 enum FormatDiffError {
40     IncorrectOptions(getopts::Fail),
41     IncorrectFilter(regex::Error),
42     IoError(io::Error),
43 }
44
45 impl fmt::Display for FormatDiffError {
46     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
47         fmt::Display::fmt(self.cause().unwrap(), f)
48     }
49 }
50
51 impl Error for FormatDiffError {
52     fn description(&self) -> &str {
53         self.cause().unwrap().description()
54     }
55
56     fn cause(&self) -> Option<&Error> {
57         Some(match *self {
58             FormatDiffError::IoError(ref e) => e,
59             FormatDiffError::IncorrectFilter(ref e) => e,
60             FormatDiffError::IncorrectOptions(ref e) => e,
61         })
62     }
63 }
64
65 impl From<getopts::Fail> for FormatDiffError {
66     fn from(fail: getopts::Fail) -> Self {
67         FormatDiffError::IncorrectOptions(fail)
68     }
69 }
70
71 impl From<regex::Error> for FormatDiffError {
72     fn from(err: regex::Error) -> Self {
73         FormatDiffError::IncorrectFilter(err)
74     }
75 }
76
77 impl From<io::Error> for FormatDiffError {
78     fn from(fail: io::Error) -> Self {
79         FormatDiffError::IoError(fail)
80     }
81 }
82
83 fn main() {
84     let _ = env_logger::init();
85
86     let mut opts = getopts::Options::new();
87     opts.optflag("h", "help", "show this message");
88     opts.optopt(
89         "p",
90         "skip-prefix",
91         "skip the smallest prefix containing NUMBER slashes",
92         "NUMBER",
93     );
94     opts.optopt(
95         "f",
96         "filter",
97         "custom pattern selecting file paths to reformat",
98         "PATTERN",
99     );
100
101     if let Err(e) = run(&opts) {
102         println!("{}", opts.usage(e.description()));
103         process::exit(1);
104     }
105 }
106
107 #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
108 struct Range {
109     file: String,
110     range: [u32; 2],
111 }
112
113 fn run(opts: &getopts::Options) -> Result<(), FormatDiffError> {
114     let matches = opts.parse(env::args().skip(1))?;
115
116     if matches.opt_present("h") {
117         println!("{}", opts.usage("usage: "));
118         return Ok(());
119     }
120
121     let filter = matches
122         .opt_str("f")
123         .unwrap_or_else(|| DEFAULT_PATTERN.to_owned());
124
125     let skip_prefix = matches
126         .opt_str("p")
127         .and_then(|p| p.parse::<u32>().ok())
128         .unwrap_or(0);
129
130     let (files, ranges) = scan_diff(io::stdin(), skip_prefix, &filter)?;
131
132     run_rustfmt(&files, &ranges)
133 }
134
135 fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
136     if files.is_empty() || ranges.is_empty() {
137         debug!("No files to format found");
138         return Ok(());
139     }
140
141     let ranges_as_json = json::to_string(ranges).unwrap();
142
143     debug!("Files: {:?}", files);
144     debug!("Ranges: {:?}", ranges);
145
146     let exit_status = process::Command::new("rustfmt")
147         .args(files)
148         .arg("--file-lines")
149         .arg(ranges_as_json)
150         .status()?;
151
152     if !exit_status.success() {
153         return Err(FormatDiffError::IoError(io::Error::new(
154             io::ErrorKind::Other,
155             format!("rustfmt failed with {}", exit_status),
156         )));
157     }
158     Ok(())
159 }
160
161 /// Scans a diff from `from`, and returns the set of files found, and the ranges
162 /// in those files.
163 fn scan_diff<R>(
164     from: R,
165     skip_prefix: u32,
166     file_filter: &str,
167 ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
168 where
169     R: io::Read,
170 {
171     let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix);
172     let diff_pattern = Regex::new(&diff_pattern).unwrap();
173
174     let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
175
176     let file_filter = Regex::new(&format!("^{}$", file_filter))?;
177
178     let mut current_file = None;
179
180     let mut files = HashSet::new();
181     let mut ranges = vec![];
182     for line in io::BufReader::new(from).lines() {
183         let line = line.unwrap();
184
185         if let Some(captures) = diff_pattern.captures(&line) {
186             current_file = Some(captures.get(1).unwrap().as_str().to_owned());
187         }
188
189         let file = match current_file {
190             Some(ref f) => &**f,
191             None => continue,
192         };
193
194         // TODO(emilio): We could avoid this most of the time if needed, but
195         // it's not clear it's worth it.
196         if !file_filter.is_match(file) {
197             continue;
198         }
199
200         let lines_captures = match lines_pattern.captures(&line) {
201             Some(captures) => captures,
202             None => continue,
203         };
204
205         let start_line = lines_captures
206             .get(1)
207             .unwrap()
208             .as_str()
209             .parse::<u32>()
210             .unwrap();
211         let line_count = match lines_captures.get(3) {
212             Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
213             None => 1,
214         };
215
216         if line_count == 0 {
217             continue;
218         }
219
220         let end_line = start_line + line_count - 1;
221         files.insert(file.to_owned());
222         ranges.push(Range {
223             file: file.to_owned(),
224             range: [start_line, end_line],
225         });
226     }
227
228     Ok((files, ranges))
229 }
230
231 #[test]
232 fn scan_simple_git_diff() {
233     const DIFF: &'static str = include_str!("test/bindgen.diff");
234     let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");
235
236     assert!(
237         files.contains("src/ir/traversal.rs"),
238         "Should've matched the filter"
239     );
240
241     assert!(
242         !files.contains("tests/headers/anon_enum.hpp"),
243         "Shouldn't have matched the filter"
244     );
245
246     assert_eq!(
247         &ranges,
248         &[
249             Range {
250                 file: "src/ir/item.rs".to_owned(),
251                 range: [148, 158],
252             },
253             Range {
254                 file: "src/ir/item.rs".to_owned(),
255                 range: [160, 170],
256             },
257             Range {
258                 file: "src/ir/traversal.rs".to_owned(),
259                 range: [9, 16],
260             },
261             Range {
262                 file: "src/ir/traversal.rs".to_owned(),
263                 range: [35, 43],
264             }
265         ]
266     );
267 }