]> git.proxmox.com Git - ceph.git/blob - ceph/src/boost/libs/mpi/test/python/nonblocking_test.py
import quincy beta 17.1.0
[ceph.git] / ceph / src / boost / libs / mpi / test / python / nonblocking_test.py
1 # (C) Copyright 2007
2 # Andreas Kloeckner <inform -at- tiker.net>
3 #
4 # Use, modification and distribution is subject to the Boost Software
5 # License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
6 # http://www.boost.org/LICENSE_1_0.txt)
7 #
8 # Authors: Andreas Kloeckner
9
10 from __future__ import print_function
11 import mpi
12 import random
13 import sys
14
# Number of hops a history makes before it is sent back to rank 0.
MAX_GENERATIONS = 20

# Message tags used for point-to-point traffic on mpi.world.
(TAG_DEBUG,
 TAG_DATA,
 TAG_TERMINATE,
 TAG_PROGRESS_REPORT) = range(4)
20
21
22
23
class TagGroupListener:
    """Class to help listen for only a given set of tags.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter.

    One non-blocking receive (``comm.irecv``) is kept outstanding per tag.
    ``wait`` blocks until any of them completes; receives that did not
    complete are kept and reused on the next call.
    """

    def __init__(self, comm, tags):
        """
        :param comm: communicator to receive on (e.g. ``mpi.world``).
        :param tags: iterable of message tags to listen for.
        """
        self.tags = tags
        self.comm = comm
        # tag -> outstanding irecv request posted for that tag
        self.active_requests = {}

    def wait(self):
        """Block until a message carrying one of ``self.tags`` arrives.

        :returns: ``(status, data)`` for the receive that completed.
        """
        # Post a receive for every tag that does not currently have one.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        requests = mpi.RequestList(self.active_requests.values())
        data, status, _index = mpi.wait_any(requests)
        # Each request was posted for a single tag, so the received tag
        # identifies the request that has just completed.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget all of them."""
        # .values() instead of the Python-2-only .itervalues(), so this
        # also works under Python 3.
        for request in self.active_requests.values():
            request.cancel()
            # NOTE(review): the original kept ``r.wait()`` commented out here.
            # MPI normally expects a cancelled request to be completed
            # (wait/test); confirm how Boost.MPI handles this.
        self.active_requests = {}
48
49
50
def rank0():
    """Coordinate the test from rank 0.

    Sends ``(mpi.size - 1) * 15`` empty histories to randomly chosen worker
    ranks. It then collects completed histories, progress reports, debug
    messages and termination acknowledgements until ``is_complete()``
    holds, and finally prints "OK".
    """
    sent_histories = (mpi.size-1)*15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    mpi.wait_all(send_reqs)

    completed_histories = []
    # history length -> number of progress reports seen for that length
    progress_reports = {}
    # ranks that acknowledged TAG_TERMINATE
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        """True once every recorded history length has been reported
        ``sent_histories`` times and every worker acknowledged termination."""
        if any(count != sent_histories for count in progress_reports.values()):
            return False
        return len(dead_kids) == mpi.size-1

    # Check completion before blocking. With a single process there are no
    # workers, so no message would ever arrive and tgl.wait() would block
    # forever. With workers, the condition is initially False (no
    # acknowledgements yet), so the loop body still runs as before.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            #print ("received completed history %s from %d" % (data, status.source))
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            # Defensive: tgl only listens for the four tags handled above.
            print("unexpected tag %d from %d" % (status.tag, status.source))

    print("OK")
97
def comm_rank():
    """Worker loop run on every rank other than 0.

    Each TAG_DATA message is a list of visited ranks. The worker first
    reports the list as received to rank 0 (TAG_PROGRESS_REPORT), then
    appends its own rank. It forwards the list to a random worker or, once
    the list holds at least MAX_GENERATIONS entries, back to rank 0.
    On TAG_TERMINATE it sends an acknowledgement to rank 0 and returns.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # Report before appending this rank.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge so rank 0 can count this worker as finished.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))
115
116
def main():
    """Run this process's part of the random-walk test.

    Messages are lists of visited ranks that hop between randomly chosen
    ranks; after MAX_GENERATIONS hops they are returned to rank 0. Rank 0
    coordinates the run and every other rank acts as a worker.
    """
    role = rank0 if mpi.rank == 0 else comm_rank
    role()


if __name__ == "__main__":
    main()