]>
Commit | Line | Data |
---|---|---|
f20569fa XL |
1 | // Inspired by Clang's clang-format-diff: |
2 | // | |
3 | // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py | |
4 | ||
5 | #![deny(warnings)] | |
6 | ||
7 | use env_logger; | |
8 | #[macro_use] | |
9 | extern crate log; | |
10 | use regex; | |
11 | use serde::{Deserialize, Serialize}; | |
12 | use serde_json as json; | |
13 | use thiserror::Error; | |
14 | ||
15 | use std::collections::HashSet; | |
16 | use std::env; | |
17 | use std::ffi::OsStr; | |
18 | use std::io::{self, BufRead}; | |
19 | use std::process; | |
20 | ||
21 | use regex::Regex; | |
22 | ||
23 | use structopt::clap::AppSettings; | |
24 | use structopt::StructOpt; | |
25 | ||
26 | /// The default pattern of files to format, used when `--filter` is not given. | |
27 | /// |
28 | /// We only want to format Rust files by default, so this matches paths ending in `.rs`. | |
29 | const DEFAULT_PATTERN: &str = r".*\.rs"; | |
30 | ||
31 | #[derive(Error, Debug)] | |
32 | enum FormatDiffError { | |
33 | #[error("{0}")] | |
34 | IncorrectOptions(#[from] getopts::Fail), | |
35 | #[error("{0}")] | |
36 | IncorrectFilter(#[from] regex::Error), | |
37 | #[error("{0}")] | |
38 | IoError(#[from] io::Error), | |
39 | } | |
40 | ||
41 | #[derive(StructOpt, Debug)] | |
42 | #[structopt( | |
43 | name = "rustfmt-format-diff", | |
44 | setting = AppSettings::DisableVersion, | |
45 | setting = AppSettings::NextLineHelp | |
46 | )] | |
47 | pub struct Opts { | |
48 | /// Skip the smallest prefix containing NUMBER slashes | |
49 | #[structopt( | |
50 | short = "p", | |
51 | long = "skip-prefix", | |
52 | value_name = "NUMBER", | |
53 | default_value = "0" | |
54 | )] | |
55 | skip_prefix: u32, | |
56 | ||
57 | /// Custom pattern selecting file paths to reformat | |
58 | #[structopt( | |
59 | short = "f", | |
60 | long = "filter", | |
61 | value_name = "PATTERN", | |
62 | default_value = DEFAULT_PATTERN | |
63 | )] | |
64 | filter: String, | |
65 | } | |
66 | ||
67 | fn main() { | |
68 | env_logger::init(); | |
69 | let opts = Opts::from_args(); | |
70 | if let Err(e) = run(opts) { | |
71 | println!("{}", e); | |
72 | Opts::clap().print_help().expect("cannot write to stdout"); | |
73 | process::exit(1); | |
74 | } | |
75 | } | |
76 | ||
77 | #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] | |
78 | struct Range { | |
79 | file: String, | |
80 | range: [u32; 2], | |
81 | } | |
82 | ||
83 | fn run(opts: Opts) -> Result<(), FormatDiffError> { | |
84 | let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?; | |
85 | run_rustfmt(&files, &ranges) | |
86 | } | |
87 | ||
88 | fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> { | |
89 | if files.is_empty() || ranges.is_empty() { | |
90 | debug!("No files to format found"); | |
91 | return Ok(()); | |
92 | } | |
93 | ||
94 | let ranges_as_json = json::to_string(ranges).unwrap(); | |
95 | ||
96 | debug!("Files: {:?}", files); | |
97 | debug!("Ranges: {:?}", ranges); | |
98 | ||
99 | let rustfmt_var = env::var_os("RUSTFMT"); | |
100 | let rustfmt = match &rustfmt_var { | |
101 | Some(rustfmt) => rustfmt, | |
102 | None => OsStr::new("rustfmt"), | |
103 | }; | |
104 | let exit_status = process::Command::new(rustfmt) | |
105 | .args(files) | |
106 | .arg("--file-lines") | |
107 | .arg(ranges_as_json) | |
108 | .status()?; | |
109 | ||
110 | if !exit_status.success() { | |
111 | return Err(FormatDiffError::IoError(io::Error::new( | |
112 | io::ErrorKind::Other, | |
113 | format!("rustfmt failed with {}", exit_status), | |
114 | ))); | |
115 | } | |
116 | Ok(()) | |
117 | } | |
118 | ||
119 | /// Scans a diff from `from`, and returns the set of files found, and the ranges | |
120 | /// in those files. | |
121 | fn scan_diff<R>( | |
122 | from: R, | |
123 | skip_prefix: u32, | |
124 | file_filter: &str, | |
125 | ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError> | |
126 | where | |
127 | R: io::Read, | |
128 | { | |
129 | let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix); | |
130 | let diff_pattern = Regex::new(&diff_pattern).unwrap(); | |
131 | ||
132 | let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap(); | |
133 | ||
134 | let file_filter = Regex::new(&format!("^{}$", file_filter))?; | |
135 | ||
136 | let mut current_file = None; | |
137 | ||
138 | let mut files = HashSet::new(); | |
139 | let mut ranges = vec![]; | |
140 | for line in io::BufReader::new(from).lines() { | |
141 | let line = line.unwrap(); | |
142 | ||
143 | if let Some(captures) = diff_pattern.captures(&line) { | |
144 | current_file = Some(captures.get(1).unwrap().as_str().to_owned()); | |
145 | } | |
146 | ||
147 | let file = match current_file { | |
148 | Some(ref f) => &**f, | |
149 | None => continue, | |
150 | }; | |
151 | ||
152 | // FIXME(emilio): We could avoid this most of the time if needed, but | |
153 | // it's not clear it's worth it. | |
154 | if !file_filter.is_match(file) { | |
155 | continue; | |
156 | } | |
157 | ||
158 | let lines_captures = match lines_pattern.captures(&line) { | |
159 | Some(captures) => captures, | |
160 | None => continue, | |
161 | }; | |
162 | ||
163 | let start_line = lines_captures | |
164 | .get(1) | |
165 | .unwrap() | |
166 | .as_str() | |
167 | .parse::<u32>() | |
168 | .unwrap(); | |
169 | let line_count = match lines_captures.get(3) { | |
170 | Some(line_count) => line_count.as_str().parse::<u32>().unwrap(), | |
171 | None => 1, | |
172 | }; | |
173 | ||
174 | if line_count == 0 { | |
175 | continue; | |
176 | } | |
177 | ||
178 | let end_line = start_line + line_count - 1; | |
179 | files.insert(file.to_owned()); | |
180 | ranges.push(Range { | |
181 | file: file.to_owned(), | |
182 | range: [start_line, end_line], | |
183 | }); | |
184 | } | |
185 | ||
186 | Ok((files, ranges)) | |
187 | } | |
188 | ||
/// Scans a real git diff and checks both the filter and the extracted ranges.
#[test]
fn scan_simple_git_diff() {
    const DIFF: &str = include_str!("test/bindgen.diff");
    let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");

    assert!(
        files.contains("src/ir/traversal.rs"),
        "Should've matched the filter"
    );

    assert!(
        !files.contains("tests/headers/anon_enum.hpp"),
        "Shouldn't have matched the filter"
    );

    // Expected hunks, in the order they appear in the diff.
    let expected: Vec<Range> = [
        ("src/ir/item.rs", [148, 158]),
        ("src/ir/item.rs", [160, 170]),
        ("src/ir/traversal.rs", [9, 16]),
        ("src/ir/traversal.rs", [35, 43]),
    ]
    .iter()
    .map(|&(file, range)| Range {
        file: file.to_owned(),
        range,
    })
    .collect();

    assert_eq!(ranges, expected);
}
226 | ||
/// Tests for command-line parsing of `Opts`.
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    /// Returns `true` when the argument parser refuses `args`.
    fn is_rejected(args: &[&str]) -> bool {
        Opts::clap().get_matches_from_safe(args).is_err()
    }

    #[test]
    fn default_options() {
        let o = Opts::from_iter(Vec::<String>::new());
        assert_eq!(DEFAULT_PATTERN, o.filter);
        assert_eq!(0, o.skip_prefix);
    }

    #[test]
    fn good_options() {
        let args = vec!["test", "-p", "10", "-f", r".*\.hs"];
        let o = Opts::from_iter(args);
        assert_eq!(r".*\.hs", o.filter);
        assert_eq!(10, o.skip_prefix);
    }

    #[test]
    fn unexpected_option() {
        assert!(is_rejected(&["test", "unexpected"]));
    }

    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    #[test]
    fn overridden_option() {
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    // Despite its name, this checks that a negative `-p` value is rejected.
    #[test]
    fn negative_filter() {
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}