1 # -*- coding: utf-8 -*-
from __future__ import absolute_import

import json
import threading
import time
import unittest
from collections import defaultdict
from functools import partial

from ..services.exception import serialize_dashboard_exception
from ..tools import NotificationQueue, TaskManager, TaskExecutor
class CallbackExecutor(TaskExecutor):
    """Executor that gives the task function a completion callback.

    ``init`` prepends ``self.callback`` to the task's positional
    arguments, so the task function reports its result by calling it.
    """

    def __init__(self, fail, progress):
        """Store the failure switch and the progress value to report.

        :param fail: truthy to finish the task with an exception
        :param progress: value passed to ``task.set_progress``
        """
        super(MyTask.CallbackExecutor, self).__init__()
        self.fail = fail
        self.progress = progress

    def init(self, task):
        """Attach to *task* and inject the callback as first argument."""
        super(MyTask.CallbackExecutor, self).init(task)
        # presumably the base ``init`` stores the task as ``self.task``
        args = [self.callback]
        args.extend(self.task.fn_args)
        self.task.fn_args = args

    def callback(self, result):
        """Report progress, then finish the task exactly once.

        Finishes with an exception when ``self.fail`` is truthy and
        with *result* otherwise.
        """
        self.task.set_progress(self.progress)
        if self.fail:
            self.finish(None, Exception("Task Unexpected Exception"))
        else:
            self.finish(result, None)
# pylint: disable=too-many-arguments
def __init__(self, op_seconds, wait=False, fail=False, progress=50,
             is_async=False, handle_ex=False):
    """Configure the dummy task used by the tests.

    :param op_seconds: seconds the operation sleeps before continuing
    :param wait: stored as ``self.wait``
    :param fail: stored as ``self.fail``; ``task_async_op`` compares it
        with ``"premature"`` and ``run`` forwards it to the executor
    :param progress: value reported through ``set_progress``
    :param is_async: stored as ``self.is_async``
    :param handle_ex: when truthy ``run`` installs an exception handler
    """
    self.op_seconds = op_seconds
    # Both flags were accepted but never stored, although ``self.fail``
    # is read by other methods.
    self.wait = wait
    self.fail = fail
    self.progress = progress
    self.is_async = is_async
    self.handle_ex = handle_ex
    self._event = threading.Event()
def run(self, ns, timeout=None):
    """Submit the task once to ``TaskManager`` and wait for it.

    :param ns: first argument of ``TaskManager.run`` (the tests pass the
        task name here)
    :param timeout: forwarded to ``task.wait``
    :return: the value of ``task.wait(timeout)``; the tests unpack it as
        ``(state, result)``
    """
    # ``assertTaskResult`` expects exactly these arguments to be echoed.
    args = ['dummy arg']
    kwargs = {'dummy': 'arg'}
    h_ex = None
    if self.handle_ex:
        h_ex = partial(serialize_dashboard_exception,
                       include_http_status=True)
    # Submit through one path only; running both would start two tasks.
    if self.is_async:
        task = TaskManager.run(
            ns, self.metadata(), self.task_async_op, args, kwargs,
            executor=MyTask.CallbackExecutor(self.fail, self.progress),
            exception_handler=h_ex)
    else:
        task = TaskManager.run(
            ns, self.metadata(), self.task_op, args, kwargs,
            exception_handler=h_ex)
    return task.wait(timeout)
def task_op(self, *args, **kwargs):
    """Synchronous task body: sleep, report progress, fail or return.

    :return: ``{'args': [...], 'kwargs': {...}}`` echoing the call
    :raises Exception: when ``self.fail`` is truthy
    """
    time.sleep(self.op_seconds)
    TaskManager.current_task().set_progress(self.progress)
    # Fail only on request; an unconditional raise made the return below
    # unreachable.
    if self.fail:
        raise Exception("Task Unexpected Exception")
    if self.wait:
        # keep the task executing until ``self._event`` is set
        self._event.wait()
    return {'args': list(args), 'kwargs': kwargs}
def task_async_op(self, callback, *args, **kwargs):
    """Asynchronous task body: do the work on a background thread.

    :param callback: called with ``{'args': [...], 'kwargs': {...}}``
        once the background work is done
    :raises Exception: immediately when ``self.fail == "premature"``
    """
    if self.fail == "premature":
        raise Exception("Task Unexpected Exception")

    def _run_bg():
        time.sleep(self.op_seconds)
        if self.wait:
            # hold the result back until ``self._event`` is set
            self._event.wait()
        callback({'args': list(args), 'kwargs': kwargs})

    worker = threading.Thread(target=_run_bg)
    # Without start() the callback would never fire.
    worker.start()
def metadata(self):
    """Return the task configuration as a plain dict.

    ``run`` passes this dict to ``TaskManager.run`` and the tests compare
    it with the ``'metadata'`` entry of serialized tasks.
    """
    # NOTE(review): only the four entries visible in the source are
    # reproduced; confirm whether further keys belonged here.
    return {
        'op_seconds': self.op_seconds,
        'progress': self.progress,
        'is_async': self.is_async,
        'handle_ex': self.handle_ex
    }
class TaskTest(unittest.TestCase):
    """Test case exercising ``TaskManager`` through ``MyTask``."""

    # One event per task name, created on first access; ``_handle_task``
    # sets it and ``wait_for_task`` blocks on it.
    TASK_FINISHED_MAP = defaultdict(threading.Event)
@classmethod
def _handle_task(cls, task):
    """Notification callback: set the event stored under ``task.name``.

    Registered as ``cls._handle_task``, so it must be a classmethod to
    be callable with the task as its only argument.
    """
    cls.TASK_FINISHED_MAP[task.name].set()
@classmethod
def wait_for_task(cls, name):
    """Block until the event stored under *name* has been set.

    :param name: task name used as key of ``TASK_FINISHED_MAP``
    """
    cls.TASK_FINISHED_MAP[name].wait()
@classmethod
def setUpClass(cls):
    """Start the notification queue and subscribe to finished tasks."""
    NotificationQueue.start_queue()
    # NOTE(review): the source has gaps around this call (it ended on a
    # trailing comma) - confirm whether ``register`` took further
    # arguments and whether ``TaskManager`` needs explicit initialisation.
    NotificationQueue.register(cls._handle_task, 'cd_task_finished')
@classmethod
def tearDownClass(cls):
    """Unsubscribe from finished-task notifications and stop the queue."""
    # unittest invokes this hook on the class, so it must be a classmethod
    NotificationQueue.deregister(cls._handle_task, 'cd_task_finished')
    NotificationQueue.stop()
def setUp(self):
    """Reset the ``TaskManager`` limits that individual tests override.

    ``test_finished_cleanup`` lowers both values, so they are restored
    before every test.
    """
    TaskManager.FINISHED_TASK_SIZE = 10
    TaskManager.FINISHED_TASK_TTL = 60.0
def assertTaskResult(self, result):  # noqa: N802
    """Assert that *result* echoes the dummy arguments of ``MyTask.run``."""
    expected = {'args': ['dummy arg'], 'kwargs': {'dummy': 'arg'}}
    self.assertEqual(result, expected)
def test_fast_task(self):
    """A task awaited without timeout reports DONE and its result."""
    # ``task1`` was never created; a plain one-second task matches the
    # expectations below.
    task1 = MyTask(1)
    state, result = task1.run('test1/task1')
    self.assertEqual(state, TaskManager.VALUE_DONE)
    self.assertTaskResult(result)
    self.wait_for_task('test1/task1')
    _, fn_t = TaskManager.list('test1/*')
    self.assertEqual(len(fn_t), 1)
    self.assertIsNone(fn_t[0].exception)
    self.assertTaskResult(fn_t[0].ret_value)
    self.assertEqual(fn_t[0].progress, 100)
def test_slow_task(self):
    """A task outliving the 0.5 s wait is EXECUTING, then finishes."""
    # ``task1`` was never created; one second of work outlasts the wait.
    task1 = MyTask(1)
    state, result = task1.run('test2/task1', 0.5)
    self.assertEqual(state, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(result)
    self.wait_for_task('test2/task1')
    _, fn_t = TaskManager.list('test2/*')
    self.assertEqual(len(fn_t), 1)
    self.assertIsNone(fn_t[0].exception)
    self.assertTaskResult(fn_t[0].ret_value)
    self.assertEqual(fn_t[0].progress, 100)
def test_fast_task_with_failure(self):
    """A failing task awaited without timeout raises from ``run``."""
    name = 'test3/task1'
    failing = MyTask(1, fail=True, progress=40)

    with self.assertRaises(Exception) as caught:
        failing.run(name)

    self.assertEqual(str(caught.exception), "Task Unexpected Exception")
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test3/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.ret_value)
    self.assertEqual(str(record.exception), "Task Unexpected Exception")
    self.assertEqual(record.progress, 40)
def test_slow_task_with_failure(self):
    """A failing task outliving the wait records its exception."""
    name = 'test4/task1'
    failing = MyTask(1, fail=True, progress=70)
    status, outcome = failing.run(name, 0.5)
    self.assertEqual(status, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(outcome)
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test4/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.ret_value)
    self.assertEqual(str(record.exception), "Task Unexpected Exception")
    self.assertEqual(record.progress, 70)
def test_executing_tasks_list(self):
    """Waiting tasks stay in the executing list until released."""
    task1 = MyTask(0, wait=True, progress=30)
    task2 = MyTask(0, wait=True, progress=60)
    state, result = task1.run('test5/task1', 0.5)
    self.assertEqual(state, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(result)
    ex_t, _ = TaskManager.list('test5/*')
    self.assertEqual(len(ex_t), 1)
    self.assertEqual(ex_t[0].name, 'test5/task1')
    self.assertEqual(ex_t[0].progress, 30)
    state, result = task2.run('test5/task2', 0.5)
    self.assertEqual(state, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(result)
    ex_t, _ = TaskManager.list('test5/*')
    self.assertEqual(len(ex_t), 2)
    for task in ex_t:
        if task.name == 'test5/task1':
            self.assertEqual(task.progress, 30)
        elif task.name == 'test5/task2':
            self.assertEqual(task.progress, 60)
    # Release task2; otherwise the wait below never returns.
    task2._event.set()  # pylint: disable=protected-access
    self.wait_for_task('test5/task2')
    ex_t, _ = TaskManager.list('test5/*')
    self.assertEqual(len(ex_t), 1)
    self.assertEqual(ex_t[0].name, 'test5/task1')
    # Release task1 as well before waiting for it.
    task1._event.set()  # pylint: disable=protected-access
    self.wait_for_task('test5/task1')
    ex_t, _ = TaskManager.list('test5/*')
    self.assertEqual(len(ex_t), 0)
def test_task_idempotent(self):
    """Running the same task name twice keeps a single executing task."""
    task1 = MyTask(0, wait=True)
    task1_clone = MyTask(0, wait=True)
    state, result = task1.run('test6/task1', 0.5)
    self.assertEqual(state, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(result)
    ex_t, _ = TaskManager.list('test6/*')
    self.assertEqual(len(ex_t), 1)
    self.assertEqual(ex_t[0].name, 'test6/task1')
    state, result = task1_clone.run('test6/task1', 0.5)
    self.assertEqual(state, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(result)
    ex_t, _ = TaskManager.list('test6/*')
    self.assertEqual(len(ex_t), 1)
    self.assertEqual(ex_t[0].name, 'test6/task1')
    # Release the task that is actually running; otherwise the wait
    # below never returns.
    task1._event.set()  # pylint: disable=protected-access
    self.wait_for_task('test6/task1')
    ex_t, fn_t = TaskManager.list('test6/*')
    self.assertEqual(len(ex_t), 0)
    self.assertEqual(len(fn_t), 1)
def test_finished_cleanup(self):
    """Finished tasks beyond the size limit are dropped after listing."""
    TaskManager.FINISHED_TASK_SIZE = 2
    TaskManager.FINISHED_TASK_TTL = 0.5
    # The three tasks were never created; no-delay tasks keep this fast.
    task1 = MyTask(0)
    task2 = MyTask(0)
    state, result = task1.run('test7/task1')
    self.assertEqual(state, TaskManager.VALUE_DONE)
    self.assertTaskResult(result)
    self.wait_for_task('test7/task1')
    state, result = task2.run('test7/task2')
    self.assertEqual(state, TaskManager.VALUE_DONE)
    self.assertTaskResult(result)
    self.wait_for_task('test7/task2')
    # NOTE(review): sleep added so the 0.5 s TTL has elapsed - confirm
    # against TaskManager's cleanup rules.
    time.sleep(1)
    _, fn_t = TaskManager.list('test7/*')
    self.assertEqual(len(fn_t), 2)
    for idx, task in enumerate(fn_t):
        self.assertEqual(task.name,
                         "test7/task{}".format(len(fn_t)-idx))
    task3 = MyTask(0)
    state, result = task3.run('test7/task3')
    self.assertEqual(state, TaskManager.VALUE_DONE)
    self.assertTaskResult(result)
    self.wait_for_task('test7/task3')
    # NOTE(review): same TTL assumption as above.
    time.sleep(1)
    _, fn_t = TaskManager.list('test7/*')
    self.assertEqual(len(fn_t), 3)
    for idx, task in enumerate(fn_t):
        self.assertEqual(task.name,
                         "test7/task{}".format(len(fn_t)-idx))
    _, fn_t = TaskManager.list('test7/*')
    self.assertEqual(len(fn_t), 2)
    for idx, task in enumerate(fn_t):
        self.assertEqual(task.name,
                         "test7/task{}".format(len(fn_t)-idx+1))
def test_task_serialization_format(self):
    """Serializable task listings are JSON-encodable and complete."""
    task1 = MyTask(0, wait=True, progress=20)
    # ``task2`` was never created; one second of work satisfies the
    # ``duration >= 1.0`` check below.
    task2 = MyTask(1)
    task1.run('test8/task1', 0.5)
    task2.run('test8/task2', 0.5)
    self.wait_for_task('test8/task2')
    ex_t, fn_t = TaskManager.list_serializable('test8/*')
    self.assertEqual(len(ex_t), 1)
    self.assertEqual(len(fn_t), 1)

    # json.dumps raises TypeError for unsupported objects and ValueError
    # for e.g. circular references, so both are caught.
    try:
        json.dumps(ex_t)
    except (TypeError, ValueError) as ex:
        self.fail("Failed to serialize executing tasks: {}".format(str(ex)))

    try:
        json.dumps(fn_t)
    except (TypeError, ValueError) as ex:
        self.fail("Failed to serialize finished tasks: {}".format(str(ex)))

    # validate executing tasks attributes
    self.assertEqual(len(ex_t[0].keys()), 4)
    self.assertEqual(ex_t[0]['name'], 'test8/task1')
    self.assertEqual(ex_t[0]['metadata'], task1.metadata())
    self.assertIsNotNone(ex_t[0]['begin_time'])
    self.assertEqual(ex_t[0]['progress'], 20)
    # validate finished tasks attributes
    self.assertEqual(len(fn_t[0].keys()), 9)
    self.assertEqual(fn_t[0]['name'], 'test8/task2')
    self.assertEqual(fn_t[0]['metadata'], task2.metadata())
    self.assertIsNotNone(fn_t[0]['begin_time'])
    self.assertIsNotNone(fn_t[0]['end_time'])
    self.assertGreaterEqual(fn_t[0]['duration'], 1.0)
    self.assertEqual(fn_t[0]['progress'], 100)
    self.assertTrue(fn_t[0]['success'])
    self.assertTaskResult(fn_t[0]['ret_value'])
    self.assertIsNone(fn_t[0]['exception'])
    # Release task1; otherwise the wait below never returns.
    task1._event.set()  # pylint: disable=protected-access
    self.wait_for_task('test8/task1')
def test_fast_async_task(self):
    """An async task awaited without timeout reports DONE."""
    name = 'test9/task1'
    job = MyTask(1, is_async=True)
    status, outcome = job.run(name)
    self.assertEqual(status, TaskManager.VALUE_DONE)
    self.assertTaskResult(outcome)
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test9/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.exception)
    self.assertTaskResult(record.ret_value)
    self.assertEqual(record.progress, 100)
def test_slow_async_task(self):
    """An async task outliving the wait is EXECUTING, then finishes."""
    name = 'test10/task1'
    job = MyTask(1, is_async=True)
    status, outcome = job.run(name, 0.5)
    self.assertEqual(status, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(outcome)
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test10/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.exception)
    self.assertTaskResult(record.ret_value)
    self.assertEqual(record.progress, 100)
def test_fast_async_task_with_failure(self):
    """A failing async task awaited without timeout raises from ``run``."""
    name = 'test11/task1'
    failing = MyTask(1, fail=True, progress=40, is_async=True)

    with self.assertRaises(Exception) as caught:
        failing.run(name)

    self.assertEqual(str(caught.exception), "Task Unexpected Exception")
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test11/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.ret_value)
    self.assertEqual(str(record.exception), "Task Unexpected Exception")
    self.assertEqual(record.progress, 40)
def test_slow_async_task_with_failure(self):
    """A failing async task outliving the wait records its exception."""
    name = 'test12/task1'
    failing = MyTask(1, fail=True, progress=70, is_async=True)
    status, outcome = failing.run(name, 0.5)
    self.assertEqual(status, TaskManager.VALUE_EXECUTING)
    self.assertIsNone(outcome)
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test12/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.ret_value)
    self.assertEqual(str(record.exception), "Task Unexpected Exception")
    self.assertEqual(record.progress, 70)
def test_fast_async_task_with_premature_failure(self):
    """An async task created with ``fail="premature"`` raises from ``run``."""
    name = 'test13/task1'
    failing = MyTask(1, fail="premature", progress=40, is_async=True)

    with self.assertRaises(Exception) as caught:
        failing.run(name)

    self.assertEqual(str(caught.exception), "Task Unexpected Exception")
    self.wait_for_task(name)
    _executing, finished = TaskManager.list('test13/*')
    self.assertEqual(len(finished), 1)
    record = finished[0]
    self.assertIsNone(record.ret_value)
    self.assertEqual(str(record.exception), "Task Unexpected Exception")
def test_task_serialization_format_on_failure(self):
    """A failed task serializes to JSON with a ``detail`` exception."""
    task1 = MyTask(1, fail=True)
    task1.run('test14/task1', 0.5)
    self.wait_for_task('test14/task1')
    ex_t, fn_t = TaskManager.list_serializable('test14/*')
    self.assertEqual(len(ex_t), 0)
    self.assertEqual(len(fn_t), 1)
    # validate finished tasks attributes

    # The handler below had no ``try``; json.dumps can raise TypeError
    # or ValueError, so both are caught.
    try:
        json.dumps(fn_t)
    except (TypeError, ValueError) as ex:
        self.fail("Failed to serialize finished tasks: {}".format(str(ex)))

    self.assertEqual(len(fn_t[0].keys()), 9)
    self.assertEqual(fn_t[0]['name'], 'test14/task1')
    self.assertEqual(fn_t[0]['metadata'], task1.metadata())
    self.assertIsNotNone(fn_t[0]['begin_time'])
    self.assertIsNotNone(fn_t[0]['end_time'])
    self.assertGreaterEqual(fn_t[0]['duration'], 1.0)
    self.assertEqual(fn_t[0]['progress'], 50)
    self.assertFalse(fn_t[0]['success'])
    self.assertIsNotNone(fn_t[0]['exception'])
    self.assertEqual(fn_t[0]['exception'],
                     {"detail": "Task Unexpected Exception"})
396 def test_task_serialization_format_on_failure_with_handler(self
):
397 task1
= MyTask(1, fail
=True, handle_ex
=True)
398 task1
.run('test15/task1', 0.5)
399 self
.wait_for_task('test15/task1')
400 ex_t
, fn_t
= TaskManager
.list_serializable('test15/*')
401 self
.assertEqual(len(ex_t
), 0)
402 self
.assertEqual(len(fn_t
), 1)
403 # validate finished tasks attributes
407 except TypeError as ex
:
408 self
.fail("Failed to serialize finished tasks: {}".format(str(ex
)))
410 self
.assertEqual(len(fn_t
[0].keys()), 9)
411 self
.assertEqual(fn_t
[0]['name'], 'test15/task1')
412 self
.assertEqual(fn_t
[0]['metadata'], task1
.metadata())
413 self
.assertIsNotNone(fn_t
[0]['begin_time'])
414 self
.assertIsNotNone(fn_t
[0]['end_time'])
415 self
.assertGreaterEqual(fn_t
[0]['duration'], 1.0)
416 self
.assertEqual(fn_t
[0]['progress'], 50)
417 self
.assertFalse(fn_t
[0]['success'])
418 self
.assertIsNotNone(fn_t
[0]['exception'])
419 self
.assertEqual(fn_t
[0]['exception'], {
421 'detail': 'Task Unexpected Exception',
431 'name': 'test15/task1'