"""Unit tests for refactor.py."""
from __future__ import with_statement

import codecs
import os
import shutil
import StringIO
import sys
import tempfile
import unittest

from lib2to3 import refactor, pygram, fixer_base
from lib2to3.pgen2 import token
# Directory holding the test fixtures used by these tests.
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
# Directory containing the "myfixes" fixer package used by most tests.
FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers")

# "myfixes" is only importable while FIXER_DIR is on sys.path.  Add it just
# long enough to look up the fixer names, then remove it again so that merely
# importing this module does not permanently change the import path.
sys.path.append(FIXER_DIR)
try:
    _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes")
finally:
    sys.path.pop()

# The full set of fixers shipped with lib2to3.
_2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes")
34 class TestRefactoringTool(unittest
.TestCase
):
37 sys
.path
.append(FIXER_DIR
)
42 def check_instances(self
, instances
, classes
):
43 for inst
, cls
in zip(instances
, classes
):
44 if not isinstance(inst
, cls
):
45 self
.fail("%s are not instances of %s" % instances
, classes
)
47 def rt(self
, options
=None, fixers
=_DEFAULT_FIXERS
, explicit
=None):
48 return refactor
.RefactoringTool(fixers
, options
, explicit
)
50 def test_print_function_option(self
):
51 rt
= self
.rt({"print_function" : True})
52 self
.assertTrue(rt
.grammar
is pygram
.python_grammar_no_print_statement
)
53 self
.assertTrue(rt
.driver
.grammar
is
54 pygram
.python_grammar_no_print_statement
)
56 def test_fixer_loading_helpers(self
):
57 contents
= ["explicit", "first", "last", "parrot", "preorder"]
58 non_prefixed
= refactor
.get_all_fix_names("myfixes")
59 prefixed
= refactor
.get_all_fix_names("myfixes", False)
60 full_names
= refactor
.get_fixers_from_package("myfixes")
61 self
.assertEqual(prefixed
, ["fix_" + name
for name
in contents
])
62 self
.assertEqual(non_prefixed
, contents
)
63 self
.assertEqual(full_names
,
64 ["myfixes.fix_" + name
for name
in contents
])
66 def test_detect_future_features(self
):
67 run
= refactor
._detect
_future
_features
70 self
.assertEqual(run(""), empty
)
71 self
.assertEqual(run("from __future__ import print_function"),
72 fs(("print_function",)))
73 self
.assertEqual(run("from __future__ import generators"),
75 self
.assertEqual(run("from __future__ import generators, feature"),
76 fs(("generators", "feature")))
77 inp
= "from __future__ import generators, print_function"
78 self
.assertEqual(run(inp
), fs(("generators", "print_function")))
79 inp
="from __future__ import print_function, generators"
80 self
.assertEqual(run(inp
), fs(("print_function", "generators")))
81 inp
= "from __future__ import (print_function,)"
82 self
.assertEqual(run(inp
), fs(("print_function",)))
83 inp
= "from __future__ import (generators, print_function)"
84 self
.assertEqual(run(inp
), fs(("generators", "print_function")))
85 inp
= "from __future__ import (generators, nested_scopes)"
86 self
.assertEqual(run(inp
), fs(("generators", "nested_scopes")))
87 inp
= """from __future__ import generators
88 from __future__ import print_function"""
89 self
.assertEqual(run(inp
), fs(("generators", "print_function")))
99 self
.assertEqual(run(inp
), empty
)
100 inp
= "'docstring'\nfrom __future__ import print_function"
101 self
.assertEqual(run(inp
), fs(("print_function",)))
102 inp
= "'docstring'\n'somng'\nfrom __future__ import print_function"
103 self
.assertEqual(run(inp
), empty
)
104 inp
= "# comment\nfrom __future__ import print_function"
105 self
.assertEqual(run(inp
), fs(("print_function",)))
106 inp
= "# comment\n'doc'\nfrom __future__ import print_function"
107 self
.assertEqual(run(inp
), fs(("print_function",)))
108 inp
= "class x: pass\nfrom __future__ import print_function"
109 self
.assertEqual(run(inp
), empty
)
111 def test_get_headnode_dict(self
):
112 class NoneFix(fixer_base
.BaseFix
):
115 class FileInputFix(fixer_base
.BaseFix
):
116 PATTERN
= "file_input< any * >"
118 class SimpleFix(fixer_base
.BaseFix
):
121 no_head
= NoneFix({}, [])
122 with_head
= FileInputFix({}, [])
123 simple
= SimpleFix({}, [])
124 d
= refactor
._get
_headnode
_dict
([no_head
, with_head
, simple
])
125 top_fixes
= d
.pop(pygram
.python_symbols
.file_input
)
126 self
.assertEqual(top_fixes
, [with_head
, no_head
])
127 name_fixes
= d
.pop(token
.NAME
)
128 self
.assertEqual(name_fixes
, [simple
, no_head
])
129 for fixes
in d
.itervalues():
130 self
.assertEqual(fixes
, [no_head
])
132 def test_fixer_loading(self
):
133 from myfixes
.fix_first
import FixFirst
134 from myfixes
.fix_last
import FixLast
135 from myfixes
.fix_parrot
import FixParrot
136 from myfixes
.fix_preorder
import FixPreorder
139 pre
, post
= rt
.get_fixers()
141 self
.check_instances(pre
, [FixPreorder
])
142 self
.check_instances(post
, [FixFirst
, FixParrot
, FixLast
])
144 def test_naughty_fixers(self
):
145 self
.assertRaises(ImportError, self
.rt
, fixers
=["not_here"])
146 self
.assertRaises(refactor
.FixerError
, self
.rt
, fixers
=["no_fixer_cls"])
147 self
.assertRaises(refactor
.FixerError
, self
.rt
, fixers
=["bad_order"])
149 def test_refactor_string(self
):
151 input = "def parrot(): pass\n\n"
152 tree
= rt
.refactor_string(input, "<test>")
153 self
.assertNotEqual(str(tree
), input)
155 input = "def f(): pass\n\n"
156 tree
= rt
.refactor_string(input, "<test>")
157 self
.assertEqual(str(tree
), input)
159 def test_refactor_stdin(self
):
161 class MyRT(refactor
.RefactoringTool
):
163 def print_output(self
, old_text
, new_text
, filename
, equal
):
164 results
.extend([old_text
, new_text
, filename
, equal
])
167 rt
= MyRT(_DEFAULT_FIXERS
)
169 sys
.stdin
= StringIO
.StringIO("def parrot(): pass\n\n")
174 expected
= ["def parrot(): pass\n\n",
175 "def cheese(): pass\n\n",
177 self
.assertEqual(results
, expected
)
179 def check_file_refactoring(self
, test_file
, fixers
=_2TO3_FIXERS
):
181 with
open(test_file
, "rb") as fp
:
183 old_contents
= read_file()
184 rt
= self
.rt(fixers
=fixers
)
186 rt
.refactor_file(test_file
)
187 self
.assertEqual(old_contents
, read_file())
190 rt
.refactor_file(test_file
, True)
191 new_contents
= read_file()
192 self
.assertNotEqual(old_contents
, new_contents
)
194 with
open(test_file
, "wb") as fp
:
195 fp
.write(old_contents
)
198 def test_refactor_file(self
):
199 test_file
= os
.path
.join(FIXER_DIR
, "parrot_example.py")
200 self
.check_file_refactoring(test_file
, _DEFAULT_FIXERS
)
202 def test_refactor_dir(self
):
203 def check(structure
, expected
):
204 def mock_refactor_file(self
, f
, *args
):
206 save_func
= refactor
.RefactoringTool
.refactor_file
207 refactor
.RefactoringTool
.refactor_file
= mock_refactor_file
210 dir = tempfile
.mkdtemp(prefix
="2to3-test_refactor")
212 os
.mkdir(os
.path
.join(dir, "a_dir"))
214 open(os
.path
.join(dir, fn
), "wb").close()
217 refactor
.RefactoringTool
.refactor_file
= save_func
219 self
.assertEqual(got
,
220 [os
.path
.join(dir, path
) for path
in expected
])
229 check(tree
, expected
)
231 os
.path
.join("a_dir", "stuff.py")]
234 def test_file_encoding(self
):
235 fn
= os
.path
.join(TEST_DATA_DIR
, "different_encoding.py")
236 self
.check_file_refactoring(fn
)
239 fn
= os
.path
.join(TEST_DATA_DIR
, "bom.py")
240 data
= self
.check_file_refactoring(fn
)
241 self
.assertTrue(data
.startswith(codecs
.BOM_UTF8
))
243 def test_crlf_newlines(self
):
247 fn
= os
.path
.join(TEST_DATA_DIR
, "crlf.py")
248 fixes
= refactor
.get_fixers_from_package("lib2to3.fixes")
249 self
.check_file_refactoring(fn
, fixes
)
253 def test_refactor_docstring(self
):
260 out
= rt
.refactor_docstring(doc
, "<test>")
261 self
.assertEqual(out
, doc
)
267 out
= rt
.refactor_docstring(doc
, "<test>")
268 self
.assertNotEqual(out
, doc
)
270 def test_explicit(self
):
271 from myfixes
.fix_explicit
import FixExplicit
273 rt
= self
.rt(fixers
=["myfixes.fix_explicit"])
274 self
.assertEqual(len(rt
.post_order
), 0)
276 rt
= self
.rt(explicit
=["myfixes.fix_explicit"])
277 for fix
in rt
.post_order
:
278 if isinstance(fix
, FixExplicit
):
281 self
.fail("explicit fixer not loaded")