git.proxmox.com Git - rustc.git/blob - src/tools/rustfmt/src/format-diff/main.rs
bump version to 1.80.1+dfsg1-1~bpo12+pve1
[rustc.git] / src / tools / rustfmt / src / format-diff / main.rs
1 // Inspired by Clang's clang-format-diff:
2 //
3 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
4
5 #![deny(warnings)]
6
7 #[macro_use]
8 extern crate tracing;
9
10 use serde::{Deserialize, Serialize};
11 use serde_json as json;
12 use thiserror::Error;
13 use tracing_subscriber::EnvFilter;
14
15 use std::collections::HashSet;
16 use std::env;
17 use std::ffi::OsStr;
18 use std::io::{self, BufRead};
19 use std::process;
20
21 use regex::Regex;
22
23 use clap::{CommandFactory, Parser};
24
/// The default pattern of files to format.
///
/// We only want to format Rust files by default. The value is a regular
/// expression; `scan_diff` wraps the filter in `^` and `$`, so it has to match
/// the whole file path rather than a substring of it.
const DEFAULT_PATTERN: &str = r".*\.rs";
29
/// Errors that make a run of the tool fail.
///
/// Each variant displays as the message of the error it wraps (`"{0}"`), and
/// `#[from]` lets `?` convert the underlying error types automatically.
#[derive(Error, Debug)]
enum FormatDiffError {
    /// An option-parsing failure reported by `getopts`.
    // NOTE(review): nothing visible in this file produces a `getopts::Fail`
    // (the command line is parsed with clap) -- confirm whether this variant
    // is still needed.
    #[error("{0}")]
    IncorrectOptions(#[from] getopts::Fail),
    /// The file filter could not be compiled as a regular expression.
    #[error("{0}")]
    IncorrectFilter(#[from] regex::Error),
    /// An I/O failure; also used to report an unsuccessful `rustfmt` exit
    /// status.
    #[error("{0}")]
    IoError(#[from] io::Error),
}
39
// Command-line options, parsed through clap's derive API.
//
// Plain `//` comments are used for these notes on purpose: clap's derive turns
// `///` doc comments on the struct and its fields into help text, so adding or
// editing doc comments here would change what the program prints.
#[derive(Parser, Debug)]
#[clap(
    name = "rustfmt-format-diff",
    disable_version_flag = true,
    next_line_help = true
)]
pub struct Opts {
    /// Skip the smallest prefix containing NUMBER slashes
    // Passed to `scan_diff`, which interpolates it into the `+++` header
    // pattern as a repetition count.
    #[clap(
        short = 'p',
        long = "skip-prefix",
        value_name = "NUMBER",
        default_value = "0"
    )]
    skip_prefix: u32,

    /// Custom pattern selecting file paths to reformat
    // A regular expression; `scan_diff` anchors it with `^` and `$` before
    // matching it against each file path found in the diff.
    #[clap(
        short = 'f',
        long = "filter",
        value_name = "PATTERN",
        default_value = DEFAULT_PATTERN
    )]
    filter: String,
}
65
66 fn main() {
67 tracing_subscriber::fmt()
68 .with_env_filter(EnvFilter::from_env("RUSTFMT_LOG"))
69 .init();
70 let opts = Opts::parse();
71 if let Err(e) = run(opts) {
72 println!("{e}");
73 Opts::command()
74 .print_help()
75 .expect("cannot write to stdout");
76 process::exit(1);
77 }
78 }
79
/// A span of lines in one file that should be reformatted.
///
/// A slice of these is serialized to JSON and handed to rustfmt as the value
/// of its `--file-lines` argument, so the field names are part of that JSON.
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
struct Range {
    /// Path of the file, as captured from the diff's `+++` header.
    file: String,
    /// `[first, last]` line numbers taken from a hunk header's `+` range;
    /// `last` is computed as `first + count - 1`, so both bounds are inclusive.
    range: [u32; 2],
}
85
86 fn run(opts: Opts) -> Result<(), FormatDiffError> {
87 let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
88 run_rustfmt(&files, &ranges)
89 }
90
91 fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
92 if files.is_empty() || ranges.is_empty() {
93 debug!("No files to format found");
94 return Ok(());
95 }
96
97 let ranges_as_json = json::to_string(ranges).unwrap();
98
99 debug!("Files: {:?}", files);
100 debug!("Ranges: {:?}", ranges);
101
102 let rustfmt_var = env::var_os("RUSTFMT");
103 let rustfmt = match &rustfmt_var {
104 Some(rustfmt) => rustfmt,
105 None => OsStr::new("rustfmt"),
106 };
107 let exit_status = process::Command::new(rustfmt)
108 .args(files)
109 .arg("--file-lines")
110 .arg(ranges_as_json)
111 .status()?;
112
113 if !exit_status.success() {
114 return Err(FormatDiffError::IoError(io::Error::new(
115 io::ErrorKind::Other,
116 format!("rustfmt failed with {exit_status}"),
117 )));
118 }
119 Ok(())
120 }
121
122 /// Scans a diff from `from`, and returns the set of files found, and the ranges
123 /// in those files.
124 fn scan_diff<R>(
125 from: R,
126 skip_prefix: u32,
127 file_filter: &str,
128 ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
129 where
130 R: io::Read,
131 {
132 let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{skip_prefix}}}(\S*)");
133 let diff_pattern = Regex::new(&diff_pattern).unwrap();
134
135 let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
136
137 let file_filter = Regex::new(&format!("^{file_filter}$"))?;
138
139 let mut current_file = None;
140
141 let mut files = HashSet::new();
142 let mut ranges = vec![];
143 for line in io::BufReader::new(from).lines() {
144 let line = line.unwrap();
145
146 if let Some(captures) = diff_pattern.captures(&line) {
147 current_file = Some(captures.get(1).unwrap().as_str().to_owned());
148 }
149
150 let file = match current_file {
151 Some(ref f) => &**f,
152 None => continue,
153 };
154
155 // FIXME(emilio): We could avoid this most of the time if needed, but
156 // it's not clear it's worth it.
157 if !file_filter.is_match(file) {
158 continue;
159 }
160
161 let lines_captures = match lines_pattern.captures(&line) {
162 Some(captures) => captures,
163 None => continue,
164 };
165
166 let start_line = lines_captures
167 .get(1)
168 .unwrap()
169 .as_str()
170 .parse::<u32>()
171 .unwrap();
172 let line_count = match lines_captures.get(3) {
173 Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
174 None => 1,
175 };
176
177 if line_count == 0 {
178 continue;
179 }
180
181 let end_line = start_line + line_count - 1;
182 files.insert(file.to_owned());
183 ranges.push(Range {
184 file: file.to_owned(),
185 range: [start_line, end_line],
186 });
187 }
188
189 Ok((files, ranges))
190 }
191
// Scans the `test/bindgen.diff` fixture with one path segment stripped and the
// Rust-only filter, then checks both the reported file set and the exact
// ranges, in order.
#[test]
fn scan_simple_git_diff() {
    let diff_text: &str = include_str!("test/bindgen.diff");
    let (files, ranges) =
        scan_diff(diff_text.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");

    let has_rust_file = files.contains("src/ir/traversal.rs");
    assert!(has_rust_file, "Should've matched the filter");

    let has_header_file = files.contains("tests/headers/anon_enum.hpp");
    assert!(!has_header_file, "Shouldn't have matched the filter");

    let expected_spans: [(&str, [u32; 2]); 4] = [
        ("src/ir/item.rs", [148, 158]),
        ("src/ir/item.rs", [160, 170]),
        ("src/ir/traversal.rs", [9, 16]),
        ("src/ir/traversal.rs", [35, 43]),
    ];
    let expected: Vec<Range> = expected_spans
        .iter()
        .map(|&(file, range)| Range {
            file: file.to_owned(),
            range,
        })
        .collect();

    assert_eq!(ranges, expected);
}
229
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    // Reports whether clap refuses to parse `args`.
    fn is_rejected(args: &[&str]) -> bool {
        Opts::command().try_get_matches_from(args).is_err()
    }

    // With no arguments at all, both options take their default values.
    #[test]
    fn default_options() {
        let no_args = Vec::<String>::new();
        let opts = Opts::parse_from(&no_args);
        assert_eq!(DEFAULT_PATTERN, opts.filter);
        assert_eq!(0, opts.skip_prefix);
    }

    // Both short options are accepted and their values are stored.
    #[test]
    fn good_options() {
        let args = ["test", "-p", "10", "-f", r".*\.hs"];
        let opts = Opts::parse_from(&args);
        assert_eq!(r".*\.hs", opts.filter);
        assert_eq!(10, opts.skip_prefix);
    }

    // A stray positional argument is an error.
    #[test]
    fn unexpected_option() {
        assert!(is_rejected(&["test", "unexpected"]));
    }

    // A flag that is not declared is an error.
    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    // Giving the same option twice is an error.
    #[test]
    fn overridden_option() {
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    // Despite its name, this checks that a negative `-p` (skip-prefix) value
    // is rejected.
    #[test]
    fn negative_filter() {
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}