1 // Inspired by Clang's clang-format-diff:
3 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
10 use serde::{Deserialize, Serialize};
11 use serde_json as json;
14 use std::collections::HashSet;
17 use std::io::{self, BufRead};
22 use structopt::clap::AppSettings;
23 use structopt::StructOpt;
25 /// The default pattern of files to format.
27 /// We only want to format rust files by default.
// scan_diff wraps the filter as `^...$`, so this pattern must match the entire
// path: it selects any path ending in `.rs`.
28 const DEFAULT_PATTERN: &str = r".*\.rs";
// Errors that abort a format-diff run. Each variant converts from its source
// error via `#[from]`, so `?` can propagate it directly.
30 #[derive(Error, Debug)]
31 enum FormatDiffError {
// NOTE(review): wraps getopts::Fail, but in the visible code options are parsed
// with structopt (Opts::from_args); this variant may be unused — confirm.
33 IncorrectOptions(#[from] getopts::Fail),
// Returned by scan_diff when the user-supplied file filter is not a valid regex.
35 IncorrectFilter(#[from] regex::Error),
// I/O failures; run_rustfmt also uses this variant to report a rustfmt process
// that exits unsuccessfully.
37 IoError(#[from] io::Error),
// Command-line options for rustfmt-format-diff, derived with structopt and
// parsed in main via Opts::from_args.
40 #[derive(StructOpt, Debug)]
42 name = "rustfmt-format-diff",
43 setting = AppSettings::DisableVersion,
44 setting = AppSettings::NextLineHelp
// Presumably the `skip_prefix` field (read as `opts.skip_prefix` in run); the
// tests set it with `-p` and expect a default of 0.
47 /// Skip the smallest prefix containing NUMBER slashes
51 value_name = "NUMBER",
// Presumably the `filter` field (read as `opts.filter` in run); the tests set
// it with `-f`. It defaults to DEFAULT_PATTERN, i.e. Rust files only.
56 /// Custom pattern selecting file paths to reformat
60 value_name = "PATTERN",
61 default_value = DEFAULT_PATTERN
// Logging is configured from the RUSTFMT_LOG environment variable.
67 env_logger::Builder::from_env("RUSTFMT_LOG").init();
68 let opts = Opts::from_args();
// Any error from `run` (invalid filter regex, I/O error, rustfmt failure)
// enters this branch, which prints the CLI help text.
// NOTE(review): printing usage for non-usage errors (e.g. a rustfmt failure)
// can mislead; the error reporting and exit code are on lines not visible here.
69 if let Err(e) = run(opts) {
71 Opts::clap().print_help().expect("cannot write to stdout");
// Presumably the derive for `Range` (its struct lines are not visible): one
// changed region, with a `file` path and an inclusive `[start_line, end_line]`
// `range`, as built in scan_diff. Serialize lets run_rustfmt encode the list
// as JSON.
76 #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
// Reads a unified diff from stdin and collects the files matching `filter`,
// with their changed line ranges (after stripping `skip_prefix` path
// components). Then runs rustfmt on them.
82 fn run(opts: Opts) -> Result<(), FormatDiffError> {
83 let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
84 run_rustfmt(&files, &ranges)
// Invokes rustfmt for `files`, with `ranges` serialized to JSON (presumably
// passed as rustfmt's line-range argument — the Command arguments are not
// visible here; confirm). A non-successful rustfmt exit status is reported as
// FormatDiffError::IoError.
87 fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
// Nothing to format; the early return presumably sits on lines not visible here.
88 if files.is_empty() || ranges.is_empty() {
89 debug!("No files to format found");
// unwrap: each Range holds a String and two u32 values (as built in scan_diff),
// so JSON serialization is not expected to fail.
93 let ranges_as_json = json::to_string(ranges).unwrap();
95 debug!("Files: {:?}", files);
96 debug!("Ranges: {:?}", ranges);
// The rustfmt binary can be overridden with the RUSTFMT environment variable.
// Otherwise the plain name `rustfmt` is resolved by process::Command.
98 let rustfmt_var = env::var_os("RUSTFMT");
99 let rustfmt = match &rustfmt_var {
100 Some(rustfmt) => rustfmt,
101 None => OsStr::new("rustfmt"),
// NOTE(review): the command arguments and spawn/wait calls are on lines not
// visible here.
103 let exit_status = process::Command::new(rustfmt)
109 if !exit_status.success() {
110 return Err(FormatDiffError::IoError(io::Error::new(
111 io::ErrorKind::Other,
112 format!("rustfmt failed with {}", exit_status),
118 /// Scans a diff from `from`, and returns the set of files found, and the ranges
// The `fn scan_diff(...)` signature lines are not visible here. From the body,
// the parameters are `from` (a reader), `skip_prefix` and `file_filter`.
124 ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
// Matches `+++ <path>` headers and drops exactly `skip_prefix` leading
// `<component>/` segments. Capture 1 is the rest of the path, up to whitespace.
128 let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix);
// unwrap: the pattern is built from a fixed template plus a number, so it is
// always a valid regex.
129 let diff_pattern = Regex::new(&diff_pattern).unwrap();
// Matches hunk headers `@@ ... +start[,count] ...`. Capture 1 is the start line
// in the new file; capture 3 is the optional line count.
131 let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
// Anchored so the filter must match the whole path. An invalid user regex is
// returned as FormatDiffError::IncorrectFilter via `?`.
133 let file_filter = Regex::new(&format!("^{}$", file_filter))?;
135 let mut current_file = None;
137 let mut files = HashSet::new();
138 let mut ranges = vec![];
139 for line in io::BufReader::new(from).lines() {
// NOTE(review): unwrap panics on a read error, including non-UTF-8 input,
// instead of returning FormatDiffError::IoError; consider `line?`.
140 let line = line.unwrap();
// A `+++` header sets the file that subsequent hunk headers belong to.
142 if let Some(captures) = diff_pattern.captures(&line) {
143 current_file = Some(captures.get(1).unwrap().as_str().to_owned());
146 let file = match current_file {
151 // FIXME(emilio): We could avoid this most of the time if needed, but
152 // it's not clear it's worth it.
153 if !file_filter.is_match(file) {
157 let lines_captures = match lines_pattern.captures(&line) {
158 Some(captures) => captures,
162 let start_line = lines_captures
// NOTE(review): the regex guarantees digits, so this unwrap panics only if the
// count overflows u32.
168 let line_count = match lines_captures.get(3) {
169 Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
// Ranges are inclusive. NOTE(review): when line_count == 0 (a pure-deletion
// hunk) this gives start_line - 1: an inverted range, and a u32 underflow when
// start_line is 0. This holds unless lines not visible here skip that case —
// confirm.
177 let end_line = start_line + line_count - 1;
178 files.insert(file.to_owned());
180 file: file.to_owned(),
181 range: [start_line, end_line],
// Scans the bundled test/bindgen.diff with skip_prefix 1 and the `.rs` filter.
// Rust paths are kept and other files (e.g. `.hpp`) are filtered out; the
// resulting ranges are then compared against expected Range values.
189 fn scan_simple_git_diff() {
190 const DIFF: &str = include_str!("test/bindgen.diff");
191 let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");
194 files.contains("src/ir/traversal.rs"),
195 "Should've matched the filter"
199 !files.contains("tests/headers/anon_enum.hpp"),
200 "Shouldn't have matched the filter"
207 file: "src/ir/item.rs".to_owned(),
211 file: "src/ir/item.rs".to_owned(),
215 file: "src/ir/traversal.rs".to_owned(),
219 file: "src/ir/traversal.rs".to_owned(),
// With no arguments, `filter` defaults to DEFAULT_PATTERN and `skip_prefix`
// to 0.
231 fn default_options() {
232 let empty: Vec<String> = vec![];
233 let o = Opts::from_iter(&empty);
234 assert_eq!(DEFAULT_PATTERN, o.filter);
235 assert_eq!(0, o.skip_prefix);
// Fragment of a separate test whose header is not visible: `-p` sets
// skip_prefix and `-f` sets filter.
240 let o = Opts::from_iter(&["test", "-p", "10", "-f", r".*\.hs"]);
241 assert_eq!(r".*\.hs", o.filter);
242 assert_eq!(10, o.skip_prefix);
// These tests pass argument lists to clap via get_matches_from_safe: a stray
// positional argument, an unknown flag, a repeated `-p`, and a negative `-p`.
// The expected outcomes are on lines not visible here (presumably errors).
246 fn unexpected_option() {
249 .get_matches_from_safe(&["test", "unexpected"])
255 fn unexpected_flag() {
258 .get_matches_from_safe(&["test", "--flag"])
264 fn overridden_option() {
267 .get_matches_from_safe(&["test", "-p", "10", "-p", "20"])
273 fn negative_filter() {
276 .get_matches_from_safe(&["test", "-p", "-1"])