1 # Copyright 2006 Google, Inc. All Rights Reserved.
2 # Licensed to PSF under a Contributor Agreement.
"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""
from __future__ import with_statement

__author__ = "Guido van Rossum <guido@python.org>"

# Python imports
import os
import sys
import logging
import operator
import collections
import StringIO
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from .fixer_util import find_root
from . import pytree, pygram
from . import btm_utils as bu
from . import btm_matcher as bm
def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package.

    Args:
        fixer_pkg: dotted name of the package that contains fix_*.py modules.
        remove_prefix: when true, strip the leading "fix_" from each name.
    """
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fixer_dir = os.path.dirname(pkg.__file__)
    fix_names = []
    for name in sorted(os.listdir(fixer_dir)):
        if name.startswith("fix_") and name.endswith(".py"):
            if remove_prefix:
                name = name[4:]
            # Drop the trailing ".py" so the result is a module name.
            fix_names.append(name[:-3])
    return fix_names
46 class _EveryNode(Exception):
def _get_head_types(pat):
    """ Accepts a pytree Pattern Node and returns a set
        of the pattern types which will match first.

    Raises _EveryNode when the pattern can match any node type.
    """

    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatters must either have no type and no content
        #   or a type and content -- so they don't get any farther
        # Always return leafs
        if pat.type is None:
            raise _EveryNode
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content
        r = set()
        for p in pat.content:
            for x in p:
                r.update(_get_head_types(x))
        return r

    raise Exception("Oh no! I don't understand pattern %s" %(pat))
def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
        of head node type --> fixer list.

    Fixers whose pattern can match any node (see _EveryNode) are added
    to every node-type bucket.
    """
    head_nodes = collections.defaultdict(list)
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    for node_type in chain(pygram.python_grammar.symbol2number.itervalues(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)
def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]
113 if sys
.version_info
< (3, 0):
115 _open_with_encoding
= codecs
.open
116 # codecs.open doesn't translate newlines sadly.
117 def _from_system_newlines(input):
118 return input.replace(u
"\r\n", u
"\n")
119 def _to_system_newlines(input):
120 if os
.linesep
!= "\n":
121 return input.replace(u
"\n", os
.linesep
)
125 _open_with_encoding
= open
126 _from_system_newlines
= _identity
127 _to_system_newlines
= _identity
def _detect_future_features(source):
    """Scan *source* for ``from __future__ import ...`` statements.

    Returns a frozenset of the imported feature names.  Scanning stops at
    the first statement that cannot be part of the __future__ prologue
    (anything past an optional docstring and future imports).
    """
    have_docstring = False
    gen = tokenize.generate_tokens(StringIO.StringIO(source).readline)
    def advance():
        tok = gen.next()
        return tok[0], tok[1]
    ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT))
    features = set()
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == u"from":
                tp, value = advance()
                if tp != token.NAME or value != u"__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != u"import":
                    break
                tp, value = advance()
                if tp == token.OP and value == u"(":
                    tp, value = advance()
                while tp == token.NAME:
                    features.add(value)
                    tp, value = advance()
                    if tp != token.OP or value != u",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
    return frozenset(features)
class FixerError(Exception):
    """A fixer could not be loaded."""
class RefactoringTool(object):

    _default_options = {"print_function" : False}

    CLASS_PREFIX = "Fix" # The prefix for fixer classes
    FILE_PREFIX = "fix_" # The prefix for modules with a fixer within

    # Doctest prompts: methods such as refactor_docstring() and gen_lines()
    # reference self.PS1 / self.PS2 to recognize ">>>" and "..." lines.
    PS1 = ">>> "
    PS2 = "... "
181 def __init__(self
, fixer_names
, options
=None, explicit
=None):
185 fixer_names: a list of fixers to import
186 options: an dict with configuration.
187 explicit: a list of fixers to run even if they are explicit.
189 self
.fixers
= fixer_names
190 self
.explicit
= explicit
or []
191 self
.options
= self
._default
_options
.copy()
192 if options
is not None:
193 self
.options
.update(options
)
194 if self
.options
["print_function"]:
195 self
.grammar
= pygram
.python_grammar_no_print_statement
197 self
.grammar
= pygram
.python_grammar
199 self
.logger
= logging
.getLogger("RefactoringTool")
202 self
.driver
= driver
.Driver(self
.grammar
,
203 convert
=pytree
.convert
,
205 self
.pre_order
, self
.post_order
= self
.get_fixers()
208 self
.files
= [] # List of files that were or should be modified
210 self
.BM
= bm
.BottomMatcher()
211 self
.bmi_pre_order
= [] # Bottom Matcher incompatible fixers
212 self
.bmi_post_order
= []
214 for fixer
in chain(self
.post_order
, self
.pre_order
):
215 if fixer
.BM_compatible
:
216 self
.BM
.add_fixer(fixer
)
217 # remove fixers that will be handled by the bottom-up
219 elif fixer
in self
.pre_order
:
220 self
.bmi_pre_order
.append(fixer
)
221 elif fixer
in self
.post_order
:
222 self
.bmi_post_order
.append(fixer
)
224 self
.bmi_pre_order_heads
= _get_headnode_dict(self
.bmi_pre_order
)
225 self
.bmi_post_order_heads
= _get_headnode_dict(self
.bmi_post_order
)
229 def get_fixers(self
):
230 """Inspects the options to load the requested patterns and handlers.
233 (pre_order, post_order), where pre_order is the list of fixers that
234 want a pre-order AST traversal, and post_order is the list that want
235 post-order traversal.
237 pre_order_fixers
= []
238 post_order_fixers
= []
239 for fix_mod_path
in self
.fixers
:
240 mod
= __import__(fix_mod_path
, {}, {}, ["*"])
241 fix_name
= fix_mod_path
.rsplit(".", 1)[-1]
242 if fix_name
.startswith(self
.FILE_PREFIX
):
243 fix_name
= fix_name
[len(self
.FILE_PREFIX
):]
244 parts
= fix_name
.split("_")
245 class_name
= self
.CLASS_PREFIX
+ "".join([p
.title() for p
in parts
])
247 fix_class
= getattr(mod
, class_name
)
248 except AttributeError:
249 raise FixerError("Can't find %s.%s" % (fix_name
, class_name
))
250 fixer
= fix_class(self
.options
, self
.fixer_log
)
251 if fixer
.explicit
and self
.explicit
is not True and \
252 fix_mod_path
not in self
.explicit
:
253 self
.log_message("Skipping implicit fixer: %s", fix_name
)
256 self
.log_debug("Adding transformation: %s", fix_name
)
257 if fixer
.order
== "pre":
258 pre_order_fixers
.append(fixer
)
259 elif fixer
.order
== "post":
260 post_order_fixers
.append(fixer
)
262 raise FixerError("Illegal fixer order: %r" % fixer
.order
)
264 key_func
= operator
.attrgetter("run_order")
265 pre_order_fixers
.sort(key
=key_func
)
266 post_order_fixers
.sort(key
=key_func
)
267 return (pre_order_fixers
, post_order_fixers
)
269 def log_error(self
, msg
, *args
, **kwds
):
270 """Called when an error occurs."""
273 def log_message(self
, msg
, *args
):
274 """Hook to log a message."""
277 self
.logger
.info(msg
)
279 def log_debug(self
, msg
, *args
):
282 self
.logger
.debug(msg
)
284 def print_output(self
, old_text
, new_text
, filename
, equal
):
285 """Called with the old version, new version, and filename of a
289 def refactor(self
, items
, write
=False, doctests_only
=False):
290 """Refactor a list of files and directories."""
292 for dir_or_file
in items
:
293 if os
.path
.isdir(dir_or_file
):
294 self
.refactor_dir(dir_or_file
, write
, doctests_only
)
296 self
.refactor_file(dir_or_file
, write
, doctests_only
)
298 def refactor_dir(self
, dir_name
, write
=False, doctests_only
=False):
299 """Descends down a directory and refactor every Python file found.
301 Python files are assumed to have a .py extension.
303 Files and subdirectories starting with '.' are skipped.
305 py_ext
= os
.extsep
+ "py"
306 for dirpath
, dirnames
, filenames
in os
.walk(dir_name
):
307 self
.log_debug("Descending into %s", dirpath
)
310 for name
in filenames
:
311 if (not name
.startswith(".") and
312 os
.path
.splitext(name
)[1] == py_ext
):
313 fullname
= os
.path
.join(dirpath
, name
)
314 self
.refactor_file(fullname
, write
, doctests_only
)
315 # Modify dirnames in-place to remove subdirs with leading dots
316 dirnames
[:] = [dn
for dn
in dirnames
if not dn
.startswith(".")]
318 def _read_python_source(self
, filename
):
320 Do our best to decode a Python source file correctly.
323 f
= open(filename
, "rb")
324 except IOError as err
:
325 self
.log_error("Can't open %s: %s", filename
, err
)
328 encoding
= tokenize
.detect_encoding(f
.readline
)[0]
331 with
_open_with_encoding(filename
, "r", encoding
=encoding
) as f
:
332 return _from_system_newlines(f
.read()), encoding
334 def refactor_file(self
, filename
, write
=False, doctests_only
=False):
335 """Refactors a file."""
336 input, encoding
= self
._read
_python
_source
(filename
)
338 # Reading the file failed.
340 input += u
"\n" # Silence certain parse errors
342 self
.log_debug("Refactoring doctests in %s", filename
)
343 output
= self
.refactor_docstring(input, filename
)
345 self
.processed_file(output
, filename
, input, write
, encoding
)
347 self
.log_debug("No doctest changes in %s", filename
)
349 tree
= self
.refactor_string(input, filename
)
350 if tree
and tree
.was_changed
:
351 # The [:-1] is to take off the \n we added earlier
352 self
.processed_file(unicode(tree
)[:-1], filename
,
353 write
=write
, encoding
=encoding
)
355 self
.log_debug("No changes in %s", filename
)
357 def refactor_string(self
, data
, name
):
358 """Refactor a given input string.
361 data: a string holding the code to be refactored.
362 name: a human-readable name for use in error/log messages.
365 An AST corresponding to the refactored input stream; None if
366 there were errors during the parse.
368 features
= _detect_future_features(data
)
369 if "print_function" in features
:
370 self
.driver
.grammar
= pygram
.python_grammar_no_print_statement
372 tree
= self
.driver
.parse_string(data
)
373 except Exception as err
:
374 self
.log_error("Can't parse %s: %s: %s",
375 name
, err
.__class
__.__name
__, err
)
378 self
.driver
.grammar
= self
.grammar
379 tree
.future_features
= features
380 self
.log_debug("Refactoring %s", name
)
381 self
.refactor_tree(tree
, name
)
384 def refactor_stdin(self
, doctests_only
=False):
385 input = sys
.stdin
.read()
387 self
.log_debug("Refactoring doctests in stdin")
388 output
= self
.refactor_docstring(input, "<stdin>")
390 self
.processed_file(output
, "<stdin>", input)
392 self
.log_debug("No doctest changes in stdin")
394 tree
= self
.refactor_string(input, "<stdin>")
395 if tree
and tree
.was_changed
:
396 self
.processed_file(unicode(tree
), "<stdin>", input)
398 self
.log_debug("No changes in stdin")
400 def refactor_tree(self
, tree
, name
):
401 """Refactors a parse tree (modifying the tree in place).
403 For compatible patterns the bottom matcher module is
404 used. Otherwise the tree is traversed node-to-node for
408 tree: a pytree.Node instance representing the root of the tree
410 name: a human-readable name for this tree.
413 True if the tree was modified, False otherwise.
416 for fixer
in chain(self
.pre_order
, self
.post_order
):
417 fixer
.start_tree(tree
, name
)
419 #use traditional matching for the incompatible fixers
420 self
.traverse_by(self
.bmi_pre_order_heads
, tree
.pre_order())
421 self
.traverse_by(self
.bmi_post_order_heads
, tree
.post_order())
423 # obtain a set of candidate nodes
424 match_set
= self
.BM
.run(tree
.leaves())
426 while any(match_set
.values()):
427 for fixer
in self
.BM
.fixers
:
428 if fixer
in match_set
and match_set
[fixer
]:
429 #sort by depth; apply fixers from bottom(of the AST) to top
430 match_set
[fixer
].sort(key
=pytree
.Base
.depth
, reverse
=True)
432 if fixer
.keep_line_order
:
433 #some fixers(eg fix_imports) must be applied
434 #with the original file's line order
435 match_set
[fixer
].sort(key
=pytree
.Base
.get_lineno
)
437 for node
in list(match_set
[fixer
]):
438 if node
in match_set
[fixer
]:
439 match_set
[fixer
].remove(node
)
443 except AssertionError:
444 # this node has been cut off from a
445 # previous transformation ; skip
448 if node
.fixers_applied
and fixer
in node
.fixers_applied
:
449 # do not apply the same fixer again
452 results
= fixer
.match(node
)
455 new
= fixer
.transform(node
, results
)
458 #new.fixers_applied.append(fixer)
459 for node
in new
.post_order():
460 # do not apply the fixer again to
461 # this or any subnode
462 if not node
.fixers_applied
:
463 node
.fixers_applied
= []
464 node
.fixers_applied
.append(fixer
)
466 # update the original match set for
468 new_matches
= self
.BM
.run(new
.leaves())
469 for fxr
in new_matches
:
470 if not fxr
in match_set
:
473 match_set
[fxr
].extend(new_matches
[fxr
])
475 for fixer
in chain(self
.pre_order
, self
.post_order
):
476 fixer
.finish_tree(tree
, name
)
477 return tree
.was_changed
479 def traverse_by(self
, fixers
, traversal
):
480 """Traverse an AST, applying a set of fixers to each node.
482 This is a helper method for refactor_tree().
485 fixers: a list of fixer instances.
486 traversal: a generator that yields AST nodes.
493 for node
in traversal
:
494 for fixer
in fixers
[node
.type]:
495 results
= fixer
.match(node
)
497 new
= fixer
.transform(node
, results
)
502 def processed_file(self
, new_text
, filename
, old_text
=None, write
=False,
505 Called when a file has been refactored, and there are changes.
507 self
.files
.append(filename
)
509 old_text
= self
._read
_python
_source
(filename
)[0]
512 equal
= old_text
== new_text
513 self
.print_output(old_text
, new_text
, filename
, equal
)
515 self
.log_debug("No changes to %s", filename
)
518 self
.write_file(new_text
, filename
, old_text
, encoding
)
520 self
.log_debug("Not writing changes to %s", filename
)
522 def write_file(self
, new_text
, filename
, old_text
, encoding
=None):
523 """Writes a string to a file.
525 It first shows a unified diff between the old text and the new text, and
526 then rewrites the file; the latter is only done if the write option is
530 f
= _open_with_encoding(filename
, "w", encoding
=encoding
)
531 except os
.error
as err
:
532 self
.log_error("Can't create %s: %s", filename
, err
)
535 f
.write(_to_system_newlines(new_text
))
536 except os
.error
as err
:
537 self
.log_error("Can't write %s: %s", filename
, err
)
540 self
.log_debug("Wrote changes to %s", filename
)
546 def refactor_docstring(self
, input, filename
):
547 """Refactors a docstring, looking for doctests.
549 This returns a modified version of the input string. It looks
550 for doctests, which start with a ">>>" prompt, and may be
551 continued with "..." prompts, as long as the "..." is indented
552 the same as the ">>>".
554 (Unfortunately we can't use the doctest module's parser,
555 since, like most parsers, it is not geared towards preserving
556 the original source.)
563 for line
in input.splitlines(True):
565 if line
.lstrip().startswith(self
.PS1
):
566 if block
is not None:
567 result
.extend(self
.refactor_doctest(block
, block_lineno
,
569 block_lineno
= lineno
571 i
= line
.find(self
.PS1
)
573 elif (indent
is not None and
574 (line
.startswith(indent
+ self
.PS2
) or
575 line
== indent
+ self
.PS2
.rstrip() + u
"\n")):
578 if block
is not None:
579 result
.extend(self
.refactor_doctest(block
, block_lineno
,
584 if block
is not None:
585 result
.extend(self
.refactor_doctest(block
, block_lineno
,
587 return u
"".join(result
)
589 def refactor_doctest(self
, block
, lineno
, indent
, filename
):
590 """Refactors one doctest.
592 A doctest is given as a block of lines, the first of which starts
593 with ">>>" (possibly indented), while the remaining lines start
594 with "..." (identically indented).
598 tree
= self
.parse_block(block
, lineno
, indent
)
599 except Exception as err
:
600 if self
.logger
.isEnabledFor(logging
.DEBUG
):
602 self
.log_debug("Source: %s", line
.rstrip(u
"\n"))
603 self
.log_error("Can't parse docstring in %s line %s: %s: %s",
604 filename
, lineno
, err
.__class
__.__name
__, err
)
606 if self
.refactor_tree(tree
, filename
):
607 new
= unicode(tree
).splitlines(True)
608 # Undo the adjustment of the line numbers in wrap_toks() below.
609 clipped
, new
= new
[:lineno
-1], new
[lineno
-1:]
610 assert clipped
== [u
"\n"] * (lineno
-1), clipped
611 if not new
[-1].endswith(u
"\n"):
613 block
= [indent
+ self
.PS1
+ new
.pop(0)]
615 block
+= [indent
+ self
.PS2
+ line
for line
in new
]
624 self
.log_message("No files %s modified.", were
)
626 self
.log_message("Files that %s modified:", were
)
627 for file in self
.files
:
628 self
.log_message(file)
630 self
.log_message("Warnings/messages while refactoring:")
631 for message
in self
.fixer_log
:
632 self
.log_message(message
)
634 if len(self
.errors
) == 1:
635 self
.log_message("There was 1 error:")
637 self
.log_message("There were %d errors:", len(self
.errors
))
638 for msg
, args
, kwds
in self
.errors
:
639 self
.log_message(msg
, *args
, **kwds
)
641 def parse_block(self
, block
, lineno
, indent
):
642 """Parses a block into a tree.
644 This is necessary to get correct line number / offset information
645 in the parser diagnostics and embedded into the parse tree.
647 tree
= self
.driver
.parse_tokens(self
.wrap_toks(block
, lineno
, indent
))
648 tree
.future_features
= frozenset()
651 def wrap_toks(self
, block
, lineno
, indent
):
652 """Wraps a tokenize stream to systematically modify start/end."""
653 tokens
= tokenize
.generate_tokens(self
.gen_lines(block
, indent
).next
)
654 for type, value
, (line0
, col0
), (line1
, col1
), line_text
in tokens
:
657 # Don't bother updating the columns; this is too complicated
658 # since line_text would also have to be updated and it would
659 # still break for tokens spanning lines. Let the user guess
660 # that the column numbers for doctests are relative to the
661 # end of the prompt string (PS1 or PS2).
662 yield type, value
, (line0
, col0
), (line1
, col1
), line_text
665 def gen_lines(self
, block
, indent
):
666 """Generates lines as expected by tokenize from a list of lines.
668 This strips the first len(indent + self.PS1) characters off each line.
670 prefix1
= indent
+ self
.PS1
671 prefix2
= indent
+ self
.PS2
674 if line
.startswith(prefix
):
675 yield line
[len(prefix
):]
676 elif line
== prefix
.rstrip() + u
"\n":
679 raise AssertionError("line=%r, prefix=%r" % (line
, prefix
))
class MultiprocessingUnsupported(Exception):
    """Raised when the multiprocessing module is unavailable."""
    pass
class MultiprocessRefactoringTool(RefactoringTool):

    def __init__(self, *args, **kwargs):
        """Same as RefactoringTool, plus state for fan-out to worker processes."""
        super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs)
        # Task queue shared with worker processes; None while not refactoring.
        self.queue = None
        # Lock serializing output from multiple worker processes.
        self.output_lock = None
696 def refactor(self
, items
, write
=False, doctests_only
=False,
698 if num_processes
== 1:
699 return super(MultiprocessRefactoringTool
, self
).refactor(
700 items
, write
, doctests_only
)
702 import multiprocessing
704 raise MultiprocessingUnsupported
705 if self
.queue
is not None:
706 raise RuntimeError("already doing multiple processes")
707 self
.queue
= multiprocessing
.JoinableQueue()
708 self
.output_lock
= multiprocessing
.Lock()
709 processes
= [multiprocessing
.Process(target
=self
._child
)
710 for i
in xrange(num_processes
)]
714 super(MultiprocessRefactoringTool
, self
).refactor(items
, write
,
718 for i
in xrange(num_processes
):
726 task
= self
.queue
.get()
727 while task
is not None:
730 super(MultiprocessRefactoringTool
, self
).refactor_file(
733 self
.queue
.task_done()
734 task
= self
.queue
.get()
736 def refactor_file(self
, *args
, **kwargs
):
737 if self
.queue
is not None:
738 self
.queue
.put((args
, kwargs
))
740 return super(MultiprocessRefactoringTool
, self
).refactor_file(