]>
Commit | Line | Data |
---|---|---|
4710c53d | 1 | """Unittests for heapq."""\r |
2 | \r | |
3 | import sys\r | |
4 | import random\r | |
5 | \r | |
6 | from test import test_support\r | |
7 | from unittest import TestCase, skipUnless\r | |
8 | \r | |
# Two independent copies of heapq: one with the C accelerator blocked (pure
# Python) and one with _heapq freshly imported (C-accelerated).  c_heapq is
# used as a skipUnless condition elsewhere, so presumably it is falsy when
# _heapq is unavailable.
py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq'])
c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq'])

# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_nsmallest when
# _heapq is imported, so they are checked under those names.
func_names = ('heapify heappop heappush heappushpop '
              'heapreplace _nlargest _nsmallest').split()
16 | \r | |
class TestModules(TestCase):
    """Confirm each heapq copy is backed by the implementation it claims."""

    def test_py_functions(self):
        # With _heapq blocked, every checked function must come from heapq.py.
        for name in func_names:
            origin = getattr(py_heapq, name).__module__
            self.assertEqual(origin, 'heapq')

    @skipUnless(c_heapq, 'requires _heapq')
    def test_c_functions(self):
        # With _heapq imported, the accelerated versions must be the ones used.
        for name in func_names:
            origin = getattr(c_heapq, name).__module__
            self.assertEqual(origin, '_heapq')
26 | \r | |
27 | \r | |
class TestHeap(TestCase):
    """Implementation-independent heapq tests.

    Concrete subclasses bind ``module`` to the heapq variant under test; the
    base class leaves it as None and is not listed in test_main.
    """
    module = None

    def test_push_pop(self):
        # 1) Push 256 random numbers and pop them off, verifying all's OK.
        heap = []
        data = []
        self.check_invariant(heap)
        for _ in range(256):
            item = random.random()
            data.append(item)
            self.module.heappush(heap, item)
            self.check_invariant(heap)
        results = []
        while heap:
            item = self.module.heappop(heap)
            self.check_invariant(heap)
            results.append(item)
        self.assertEqual(sorted(data), results)
        # 2) Check that the invariant holds for a sorted array
        self.check_invariant(results)

        self.assertRaises(TypeError, self.module.heappush, [])
        # A non-list heap argument may be rejected with TypeError (argument
        # checking) or fail with AttributeError (calling a list method on
        # None).  Accept either, but always run *both* checks: wrapping them
        # in try/except AttributeError would skip the heappop check entirely
        # whenever heappush raised AttributeError.
        self.assertRaises((TypeError, AttributeError),
                          self.module.heappush, None, None)
        self.assertRaises((TypeError, AttributeError),
                          self.module.heappop, None)

    def check_invariant(self, heap):
        # Check the heap invariant: no element is smaller than its parent.
        for pos, item in enumerate(heap):
            if pos:  # pos 0 has no parent
                parentpos = (pos - 1) >> 1
                self.assertLessEqual(heap[parentpos], item)

    def test_heapify(self):
        for size in range(30):
            heap = [random.random() for dummy in range(size)]
            self.module.heapify(heap)
            self.check_invariant(heap)

        self.assertRaises(TypeError, self.module.heapify, None)

    def test_naive_nbest(self):
        data = [random.randrange(2000) for i in range(1000)]
        heap = []
        for item in data:
            self.module.heappush(heap, item)
            if len(heap) > 10:
                self.module.heappop(heap)
        heap.sort()
        self.assertEqual(heap, sorted(data)[-10:])

    def heapiter(self, heap):
        # An iterator returning a heap's elements, smallest-first.
        # It consumes *heap*: elements are popped until IndexError.
        try:
            while True:
                yield self.module.heappop(heap)
        except IndexError:
            pass

    def test_nbest(self):
        # Less-naive "N-best" algorithm, much faster (if len(data) is big
        # enough <wink>) than sorting all of data.  However, if we had a max
        # heap instead of a min heap, it could go faster still via
        # heapify'ing all of data (linear time), then doing 10 heappops
        # (10 log-time steps).
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        self.module.heapify(heap)
        for item in data[10:]:
            if item > heap[0]:  # this gets rarer the longer we run
                self.module.heapreplace(heap, item)
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])

        self.assertRaises(TypeError, self.module.heapreplace, None)
        self.assertRaises(TypeError, self.module.heapreplace, None, None)
        self.assertRaises(IndexError, self.module.heapreplace, [], None)

    def test_nbest_with_pushpop(self):
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        self.module.heapify(heap)
        for item in data[10:]:
            self.module.heappushpop(heap, item)
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
        self.assertEqual(self.module.heappushpop([], 'x'), 'x')

    def test_heappushpop(self):
        h = []
        x = self.module.heappushpop(h, 10)
        self.assertEqual((h, x), ([], 10))

        # An item equal to the top is returned as-is; the heap keeps its own.
        h = [10]
        x = self.module.heappushpop(h, 10.0)
        self.assertEqual((h, x), ([10], 10.0))
        self.assertEqual(type(h[0]), int)
        self.assertEqual(type(x), float)

        h = [10]
        x = self.module.heappushpop(h, 9)
        self.assertEqual((h, x), ([10], 9))

        h = [10]
        x = self.module.heappushpop(h, 11)
        self.assertEqual((h, x), ([11], 10))

    def test_heapsort(self):
        # Exercise everything with repeated heapsort checks
        for trial in xrange(100):
            size = random.randrange(50)
            data = [random.randrange(25) for i in range(size)]
            if trial & 1:  # Half of the time, use heapify
                heap = data[:]
                self.module.heapify(heap)
            else:  # The rest of the time, use heappush
                heap = []
                for item in data:
                    self.module.heappush(heap, item)
            heap_sorted = [self.module.heappop(heap) for i in range(size)]
            self.assertEqual(heap_sorted, sorted(data))

    def test_merge(self):
        inputs = []
        for i in xrange(random.randrange(5)):
            row = sorted(random.randrange(1000)
                         for j in range(random.randrange(10)))
            inputs.append(row)
        self.assertEqual(sorted(chain(*inputs)),
                         list(self.module.merge(*inputs)))
        self.assertEqual(list(self.module.merge()), [])

    def test_merge_stability(self):
        class Int(int):
            pass
        inputs = [[], [], [], []]
        for i in range(20000):
            stream = random.randrange(4)
            x = random.randrange(500)
            obj = Int(x)
            obj.pair = (x, stream)
            inputs[stream].append(obj)
        for stream in inputs:
            stream.sort()
        # Equal values must come out in stream order, so (value, stream)
        # pairs end up fully sorted.
        result = [i.pair for i in self.module.merge(*inputs)]
        self.assertEqual(result, sorted(result))

    def test_nsmallest(self):
        data = [(random.randrange(2000), i) for i in range(1000)]
        for f in (None, lambda x: x[0] * 547 % 2000):
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
                self.assertEqual(self.module.nsmallest(n, data),
                                 sorted(data)[:n])
                self.assertEqual(self.module.nsmallest(n, data, key=f),
                                 sorted(data, key=f)[:n])

    def test_nlargest(self):
        data = [(random.randrange(2000), i) for i in range(1000)]
        for f in (None, lambda x: x[0] * 547 % 2000):
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
                self.assertEqual(self.module.nlargest(n, data),
                                 sorted(data, reverse=True)[:n])
                self.assertEqual(self.module.nlargest(n, data, key=f),
                                 sorted(data, key=f, reverse=True)[:n])

    def test_comparison_operator(self):
        # Issue 3051: Make sure heapq works with both __lt__ and __le__
        def hsort(data, comp):
            data = map(comp, data)  # a list on Python 2
            self.module.heapify(data)
            return [self.module.heappop(data).x for i in range(len(data))]
        class LT:
            def __init__(self, x):
                self.x = x
            def __lt__(self, other):
                return self.x > other.x
        class LE:
            def __init__(self, x):
                self.x = x
            def __le__(self, other):
                return self.x >= other.x
        data = [random.random() for i in range(100)]
        target = sorted(data, reverse=True)
        self.assertEqual(hsort(data, LT), target)
        self.assertEqual(hsort(data, LE), target)
213 | \r | |
214 | \r | |
class TestHeapPython(TestHeap):
    """Shared heap tests, run against the pure-Python heapq copy."""
    module = py_heapq
217 | \r | |
218 | \r | |
@skipUnless(c_heapq, 'requires _heapq')
class TestHeapC(TestHeap):
    """Shared heap tests, run against the _heapq-accelerated copy."""
    module = c_heapq
222 | \r | |
223 | \r | |
224 | #==============================================================================\r | |
225 | \r | |
class LenOnly:
    """Fake sequence: supports len() but defines no __getitem__."""

    def __len__(self):
        length = 10
        return length
230 | \r | |
class GetOnly:
    """Fake sequence: any index yields 10, but there is no __len__."""

    def __getitem__(self, ndx):
        value = 10
        return value
235 | \r | |
class CmpErr:
    """Element whose three-way comparison always raises ZeroDivisionError."""

    def __cmp__(self, other):
        raise ZeroDivisionError()
240 | \r | |
def R(seqn):
    """Plain generator that re-yields each element of *seqn* in order."""
    source = iter(seqn)
    while True:
        try:
            element = next(source)
        except StopIteration:
            return
        yield element
245 | \r | |
class G:
    """Sequence wrapper reachable only through __getitem__ indexing."""

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        item = self.seqn[i]
        return item
252 | \r | |
class I:
    """Hand-written iterator over *seqn* (__iter__ plus a Python 2 next())."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        position = self.i
        if position >= len(self.seqn):
            raise StopIteration
        value = self.seqn[position]
        self.i = position + 1
        return value
265 | \r | |
class Ig:
    """Iterable whose __iter__ is itself a generator over *seqn*."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # not read by this class

    def __iter__(self):
        for element in self.seqn:
            yield element
274 | \r | |
class X:
    """Has a next() method but neither __iter__ nor __getitem__."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def next(self):
        position = self.i
        if position >= len(self.seqn):
            raise StopIteration
        value = self.seqn[position]
        self.i = position + 1
        return value
285 | \r | |
class N:
    """Claims to be its own iterator (__iter__) but defines no next()."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self
293 | \r | |
class E:
    """Iterator whose next() always raises ZeroDivisionError."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        # Deliberate division by zero; the exception is what callers test for.
        return 3 // 0
303 | \r | |
class S:
    """Iterator that is exhausted from the start; its argument is ignored."""

    def __init__(self, seqn):
        pass

    def __iter__(self):
        return self

    def next(self):
        raise StopIteration
312 | \r | |
313 | from itertools import chain, imap\r | |
def L(seqn):
    """Wrap *seqn* in several iterator tiers: G, Ig, R, imap, then chain."""
    layered = R(Ig(G(seqn)))
    passthrough = imap(lambda value: value, layered)
    return chain(passthrough)
317 | \r | |
class TestErrorHandling(TestCase):
    """Implementation-independent checks that heapq rejects bad arguments.

    Concrete subclasses bind ``module`` to the heapq variant under test; the
    base class itself leaves it as None.
    """
    module = None

    def test_non_sequence(self):
        # A bare int is neither a list-like heap nor an iterable.
        mod = self.module
        one_arg = (mod.heapify, mod.heappop)
        two_args = (mod.heappush, mod.heapreplace,
                    mod.nlargest, mod.nsmallest)
        for func in one_arg:
            self.assertRaises((TypeError, AttributeError), func, 10)
        for func in two_args:
            self.assertRaises((TypeError, AttributeError), func, 10, 10)

    def test_len_only(self):
        # An object with only __len__ cannot act as a heap or an iterable.
        mod = self.module
        for func in (mod.heapify, mod.heappop):
            self.assertRaises((TypeError, AttributeError), func, LenOnly())
        for func in (mod.heappush, mod.heapreplace):
            self.assertRaises((TypeError, AttributeError),
                              func, LenOnly(), 10)
        for func in (mod.nlargest, mod.nsmallest):
            self.assertRaises(TypeError, func, 2, LenOnly())

    def test_get_only(self):
        # NOTE: despite its name, this test uses CmpErr elements (whose
        # comparison raises ZeroDivisionError), not GetOnly.  It checks that
        # an error raised while comparing elements propagates out of every
        # function.  All calls share one list, which earlier calls may mutate.
        mod = self.module
        seq = [CmpErr(), CmpErr(), CmpErr()]
        for func in (mod.heapify, mod.heappop):
            self.assertRaises(ZeroDivisionError, func, seq)
        for func in (mod.heappush, mod.heapreplace):
            self.assertRaises(ZeroDivisionError, func, seq, 10)
        for func in (mod.nlargest, mod.nsmallest):
            self.assertRaises(ZeroDivisionError, func, 2, seq)

    def test_arg_parsing(self):
        # Every function rejects being called with a single int argument.
        mod = self.module
        candidates = (mod.heapify, mod.heappop, mod.heappush,
                      mod.heapreplace, mod.nlargest, mod.nsmallest)
        for func in candidates:
            self.assertRaises((TypeError, AttributeError), func, 10)

    def test_iterable_args(self):
        # nlargest/nsmallest must give the same answer for the data however
        # it is wrapped, and must surface errors from broken iterators.
        samples = ("123", "", range(1000), ('do', 1.2),
                   xrange(2000, 2200, 5))
        wrappers = (G, I, Ig, L, R)
        py3k_warning = ("comparing unequal types not supported",
                        DeprecationWarning)
        for func in (self.module.nlargest, self.module.nsmallest):
            for seq in samples:
                for wrap in wrappers:
                    with test_support.check_py3k_warnings(py3k_warning,
                                                          quiet=True):
                        self.assertEqual(func(2, wrap(seq)), func(2, seq))
                self.assertEqual(func(2, S(seq)), [])
                self.assertRaises(TypeError, func, 2, X(seq))
                self.assertRaises(TypeError, func, 2, N(seq))
                self.assertRaises(ZeroDivisionError, func, 2, E(seq))
363 | \r | |
364 | \r | |
class TestErrorHandlingPython(TestErrorHandling):
    """Shared error-handling tests, run against the pure-Python heapq copy."""
    module = py_heapq
367 | \r | |
368 | \r | |
@skipUnless(c_heapq, 'requires _heapq')
class TestErrorHandlingC(TestErrorHandling):
    """Shared error-handling tests, run against the _heapq-accelerated copy."""
    module = c_heapq
372 | \r | |
373 | \r | |
374 | #==============================================================================\r | |
375 | \r | |
376 | \r | |
def test_main(verbose=None):
    """Run every concrete test class, plus an optional refcount report.

    When *verbose* is true and sys.gettotalrefcount exists (CPython debug
    builds), the suite is re-run five times and the total refcount after
    each run is printed, so a steady climb can be spotted.
    """
    test_classes = [TestModules, TestHeapPython, TestHeapC,
                    TestErrorHandlingPython, TestErrorHandlingC]
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in xrange(5):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
391 | \r | |
if __name__ == '__main__':
    # Running the file directly also enables the verbose refcount report.
    test_main(verbose=True)