]>
git.proxmox.com Git - rustc.git/blob - src/tools/rustfmt/src/format-diff/main.rs
1 // Inspired by Clang's clang-format-diff:
3 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
10 use serde
::{Deserialize, Serialize}
;
11 use serde_json
as json
;
13 use tracing_subscriber
::EnvFilter
;
15 use std
::collections
::HashSet
;
18 use std
::io
::{self, BufRead}
;
23 use clap
::{CommandFactory, Parser}
;
25 /// The default pattern of files to format.
27 /// We only want to format rust files by default.
28 const DEFAULT_PATTERN
: &str = r
".*\.rs";
30 #[derive(Error, Debug)]
31 enum FormatDiffError
{
33 IncorrectOptions(#[from] getopts::Fail),
35 IncorrectFilter(#[from] regex::Error),
37 IoError(#[from] io::Error),
40 #[derive(Parser, Debug)]
42 name
= "rustfmt-format-diff",
43 disable_version_flag
= true,
47 /// Skip the smallest prefix containing NUMBER slashes
51 value_name
= "NUMBER",
56 /// Custom pattern selecting file paths to reformat
60 value_name
= "PATTERN",
61 default_value
= DEFAULT_PATTERN
67 tracing_subscriber
::fmt()
68 .with_env_filter(EnvFilter
::from_env("RUSTFMT_LOG"))
70 let opts
= Opts
::parse();
71 if let Err(e
) = run(opts
) {
75 .expect("cannot write to stdout");
80 #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
86 fn run(opts
: Opts
) -> Result
<(), FormatDiffError
> {
87 let (files
, ranges
) = scan_diff(io
::stdin(), opts
.skip_prefix
, &opts
.filter
)?
;
88 run_rustfmt(&files
, &ranges
)
91 fn run_rustfmt(files
: &HashSet
<String
>, ranges
: &[Range
]) -> Result
<(), FormatDiffError
> {
92 if files
.is_empty() || ranges
.is_empty() {
93 debug
!("No files to format found");
97 let ranges_as_json
= json
::to_string(ranges
).unwrap();
99 debug
!("Files: {:?}", files
);
100 debug
!("Ranges: {:?}", ranges
);
102 let rustfmt_var
= env
::var_os("RUSTFMT");
103 let rustfmt
= match &rustfmt_var
{
104 Some(rustfmt
) => rustfmt
,
105 None
=> OsStr
::new("rustfmt"),
107 let exit_status
= process
::Command
::new(rustfmt
)
113 if !exit_status
.success() {
114 return Err(FormatDiffError
::IoError(io
::Error
::new(
115 io
::ErrorKind
::Other
,
116 format
!("rustfmt failed with {exit_status}"),
122 /// Scans a diff from `from`, and returns the set of files found, and the ranges
128 ) -> Result
<(HashSet
<String
>, Vec
<Range
>), FormatDiffError
>
132 let diff_pattern
= format
!(r
"^\+\+\+\s(?:.*?/){{{skip_prefix}}}(\S*)");
133 let diff_pattern
= Regex
::new(&diff_pattern
).unwrap();
135 let lines_pattern
= Regex
::new(r
"^@@.*\+(\d+)(,(\d+))?").unwrap();
137 let file_filter
= Regex
::new(&format
!("^{file_filter}$"))?
;
139 let mut current_file
= None
;
141 let mut files
= HashSet
::new();
142 let mut ranges
= vec
![];
143 for line
in io
::BufReader
::new(from
).lines() {
144 let line
= line
.unwrap();
146 if let Some(captures
) = diff_pattern
.captures(&line
) {
147 current_file
= Some(captures
.get(1).unwrap().as_str().to_owned());
150 let file
= match current_file
{
155 // FIXME(emilio): We could avoid this most of the time if needed, but
156 // it's not clear it's worth it.
157 if !file_filter
.is_match(file
) {
161 let lines_captures
= match lines_pattern
.captures(&line
) {
162 Some(captures
) => captures
,
166 let start_line
= lines_captures
172 let line_count
= match lines_captures
.get(3) {
173 Some(line_count
) => line_count
.as_str().parse
::<u32>().unwrap(),
181 let end_line
= start_line
+ line_count
- 1;
182 files
.insert(file
.to_owned());
184 file
: file
.to_owned(),
185 range
: [start_line
, end_line
],
193 fn scan_simple_git_diff() {
194 const DIFF
: &str = include_str
!("test/bindgen.diff");
195 let (files
, ranges
) = scan_diff(DIFF
.as_bytes(), 1, r
".*\.rs").expect("scan_diff failed?");
198 files
.contains("src/ir/traversal.rs"),
199 "Should've matched the filter"
203 !files
.contains("tests/headers/anon_enum.hpp"),
204 "Shouldn't have matched the filter"
211 file
: "src/ir/item.rs".to_owned(),
215 file
: "src/ir/item.rs".to_owned(),
219 file
: "src/ir/traversal.rs".to_owned(),
223 file
: "src/ir/traversal.rs".to_owned(),
235 fn default_options() {
236 let empty
: Vec
<String
> = vec
![];
237 let o
= Opts
::parse_from(&empty
);
238 assert_eq
!(DEFAULT_PATTERN
, o
.filter
);
239 assert_eq
!(0, o
.skip_prefix
);
244 let o
= Opts
::parse_from(&["test", "-p", "10", "-f", r
".*\.hs"]);
245 assert_eq
!(r
".*\.hs", o
.filter
);
246 assert_eq
!(10, o
.skip_prefix
);
250 fn unexpected_option() {
253 .try_get_matches_from(&["test", "unexpected"])
259 fn unexpected_flag() {
262 .try_get_matches_from(&["test", "--flag"])
268 fn overridden_option() {
271 .try_get_matches_from(&["test", "-p", "10", "-p", "20"])
277 fn negative_filter() {
280 .try_get_matches_from(&["test", "-p", "-1"])