+++ /dev/null
-"""Unit tests for contextlib.py, and other context managers."""\r
-\r
-import sys\r
-import tempfile\r
-import unittest\r
-from contextlib import * # Tests __all__\r
-from test import test_support\r
-try:\r
- import threading\r
-except ImportError:\r
- threading = None\r
-\r
-\r
-class ContextManagerTestCase(unittest.TestCase):\r
-\r
- def test_contextmanager_plain(self):\r
- state = []\r
- @contextmanager\r
- def woohoo():\r
- state.append(1)\r
- yield 42\r
- state.append(999)\r
- with woohoo() as x:\r
- self.assertEqual(state, [1])\r
- self.assertEqual(x, 42)\r
- state.append(x)\r
- self.assertEqual(state, [1, 42, 999])\r
-\r
- def test_contextmanager_finally(self):\r
- state = []\r
- @contextmanager\r
- def woohoo():\r
- state.append(1)\r
- try:\r
- yield 42\r
- finally:\r
- state.append(999)\r
- with self.assertRaises(ZeroDivisionError):\r
- with woohoo() as x:\r
- self.assertEqual(state, [1])\r
- self.assertEqual(x, 42)\r
- state.append(x)\r
- raise ZeroDivisionError()\r
- self.assertEqual(state, [1, 42, 999])\r
-\r
- def test_contextmanager_no_reraise(self):\r
- @contextmanager\r
- def whee():\r
- yield\r
- ctx = whee()\r
- ctx.__enter__()\r
- # Calling __exit__ should not result in an exception\r
- self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))\r
-\r
- def test_contextmanager_trap_yield_after_throw(self):\r
- @contextmanager\r
- def whoo():\r
- try:\r
- yield\r
- except:\r
- yield\r
- ctx = whoo()\r
- ctx.__enter__()\r
- self.assertRaises(\r
- RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None\r
- )\r
-\r
- def test_contextmanager_except(self):\r
- state = []\r
- @contextmanager\r
- def woohoo():\r
- state.append(1)\r
- try:\r
- yield 42\r
- except ZeroDivisionError, e:\r
- state.append(e.args[0])\r
- self.assertEqual(state, [1, 42, 999])\r
- with woohoo() as x:\r
- self.assertEqual(state, [1])\r
- self.assertEqual(x, 42)\r
- state.append(x)\r
- raise ZeroDivisionError(999)\r
- self.assertEqual(state, [1, 42, 999])\r
-\r
- def _create_contextmanager_attribs(self):\r
- def attribs(**kw):\r
- def decorate(func):\r
- for k,v in kw.items():\r
- setattr(func,k,v)\r
- return func\r
- return decorate\r
- @contextmanager\r
- @attribs(foo='bar')\r
- def baz(spam):\r
- """Whee!"""\r
- return baz\r
-\r
- def test_contextmanager_attribs(self):\r
- baz = self._create_contextmanager_attribs()\r
- self.assertEqual(baz.__name__,'baz')\r
- self.assertEqual(baz.foo, 'bar')\r
-\r
- @unittest.skipIf(sys.flags.optimize >= 2,\r
- "Docstrings are omitted with -O2 and above")\r
- def test_contextmanager_doc_attrib(self):\r
- baz = self._create_contextmanager_attribs()\r
- self.assertEqual(baz.__doc__, "Whee!")\r
-\r
-class NestedTestCase(unittest.TestCase):\r
-\r
- # XXX This needs more work\r
-\r
- def test_nested(self):\r
- @contextmanager\r
- def a():\r
- yield 1\r
- @contextmanager\r
- def b():\r
- yield 2\r
- @contextmanager\r
- def c():\r
- yield 3\r
- with nested(a(), b(), c()) as (x, y, z):\r
- self.assertEqual(x, 1)\r
- self.assertEqual(y, 2)\r
- self.assertEqual(z, 3)\r
-\r
- def test_nested_cleanup(self):\r
- state = []\r
- @contextmanager\r
- def a():\r
- state.append(1)\r
- try:\r
- yield 2\r
- finally:\r
- state.append(3)\r
- @contextmanager\r
- def b():\r
- state.append(4)\r
- try:\r
- yield 5\r
- finally:\r
- state.append(6)\r
- with self.assertRaises(ZeroDivisionError):\r
- with nested(a(), b()) as (x, y):\r
- state.append(x)\r
- state.append(y)\r
- 1 // 0\r
- self.assertEqual(state, [1, 4, 2, 5, 6, 3])\r
-\r
- def test_nested_right_exception(self):\r
- @contextmanager\r
- def a():\r
- yield 1\r
- class b(object):\r
- def __enter__(self):\r
- return 2\r
- def __exit__(self, *exc_info):\r
- try:\r
- raise Exception()\r
- except:\r
- pass\r
- with self.assertRaises(ZeroDivisionError):\r
- with nested(a(), b()) as (x, y):\r
- 1 // 0\r
- self.assertEqual((x, y), (1, 2))\r
-\r
- def test_nested_b_swallows(self):\r
- @contextmanager\r
- def a():\r
- yield\r
- @contextmanager\r
- def b():\r
- try:\r
- yield\r
- except:\r
- # Swallow the exception\r
- pass\r
- try:\r
- with nested(a(), b()):\r
- 1 // 0\r
- except ZeroDivisionError:\r
- self.fail("Didn't swallow ZeroDivisionError")\r
-\r
- def test_nested_break(self):\r
- @contextmanager\r
- def a():\r
- yield\r
- state = 0\r
- while True:\r
- state += 1\r
- with nested(a(), a()):\r
- break\r
- state += 10\r
- self.assertEqual(state, 1)\r
-\r
- def test_nested_continue(self):\r
- @contextmanager\r
- def a():\r
- yield\r
- state = 0\r
- while state < 3:\r
- state += 1\r
- with nested(a(), a()):\r
- continue\r
- state += 10\r
- self.assertEqual(state, 3)\r
-\r
- def test_nested_return(self):\r
- @contextmanager\r
- def a():\r
- try:\r
- yield\r
- except:\r
- pass\r
- def foo():\r
- with nested(a(), a()):\r
- return 1\r
- return 10\r
- self.assertEqual(foo(), 1)\r
-\r
-class ClosingTestCase(unittest.TestCase):\r
-\r
- # XXX This needs more work\r
-\r
- def test_closing(self):\r
- state = []\r
- class C:\r
- def close(self):\r
- state.append(1)\r
- x = C()\r
- self.assertEqual(state, [])\r
- with closing(x) as y:\r
- self.assertEqual(x, y)\r
- self.assertEqual(state, [1])\r
-\r
- def test_closing_error(self):\r
- state = []\r
- class C:\r
- def close(self):\r
- state.append(1)\r
- x = C()\r
- self.assertEqual(state, [])\r
- with self.assertRaises(ZeroDivisionError):\r
- with closing(x) as y:\r
- self.assertEqual(x, y)\r
- 1 // 0\r
- self.assertEqual(state, [1])\r
-\r
-class FileContextTestCase(unittest.TestCase):\r
-\r
- def testWithOpen(self):\r
- tfn = tempfile.mktemp()\r
- try:\r
- f = None\r
- with open(tfn, "w") as f:\r
- self.assertFalse(f.closed)\r
- f.write("Booh\n")\r
- self.assertTrue(f.closed)\r
- f = None\r
- with self.assertRaises(ZeroDivisionError):\r
- with open(tfn, "r") as f:\r
- self.assertFalse(f.closed)\r
- self.assertEqual(f.read(), "Booh\n")\r
- 1 // 0\r
- self.assertTrue(f.closed)\r
- finally:\r
- test_support.unlink(tfn)\r
-\r
-@unittest.skipUnless(threading, 'Threading required for this test.')\r
-class LockContextTestCase(unittest.TestCase):\r
-\r
- def boilerPlate(self, lock, locked):\r
- self.assertFalse(locked())\r
- with lock:\r
- self.assertTrue(locked())\r
- self.assertFalse(locked())\r
- with self.assertRaises(ZeroDivisionError):\r
- with lock:\r
- self.assertTrue(locked())\r
- 1 // 0\r
- self.assertFalse(locked())\r
-\r
- def testWithLock(self):\r
- lock = threading.Lock()\r
- self.boilerPlate(lock, lock.locked)\r
-\r
- def testWithRLock(self):\r
- lock = threading.RLock()\r
- self.boilerPlate(lock, lock._is_owned)\r
-\r
- def testWithCondition(self):\r
- lock = threading.Condition()\r
- def locked():\r
- return lock._is_owned()\r
- self.boilerPlate(lock, locked)\r
-\r
- def testWithSemaphore(self):\r
- lock = threading.Semaphore()\r
- def locked():\r
- if lock.acquire(False):\r
- lock.release()\r
- return False\r
- else:\r
- return True\r
- self.boilerPlate(lock, locked)\r
-\r
- def testWithBoundedSemaphore(self):\r
- lock = threading.BoundedSemaphore()\r
- def locked():\r
- if lock.acquire(False):\r
- lock.release()\r
- return False\r
- else:\r
- return True\r
- self.boilerPlate(lock, locked)\r
-\r
-# This is needed to make the test actually run under regrtest.py!\r
-def test_main():\r
- with test_support.check_warnings(("With-statements now directly support "\r
- "multiple context managers",\r
- DeprecationWarning)):\r
- test_support.run_unittest(__name__)\r
-\r
-if __name__ == "__main__":\r
- test_main()\r