]>
Commit | Line | Data |
---|---|---|
260147bd DM |
//! A thread pool which runs a closure in parallel.
2 | ||
3c9b3702 | 3 | use std::sync::{Arc, Mutex}; |
ae3cfa8f WB |
4 | use std::thread::JoinHandle; |
5 | ||
55bee048 | 6 | use anyhow::{bail, format_err, Error}; |
ae3cfa8f | 7 | use crossbeam_channel::{bounded, Sender}; |
3c9b3702 | 8 | |
ae3cfa8f | 9 | /// A handle to send data to the worker thread (implements clone) |
3c9b3702 DM |
10 | pub struct SendHandle<I> { |
11 | input: Sender<I>, | |
12 | abort: Arc<Mutex<Option<String>>>, | |
13 | } | |
14 | ||
dec00364 | 15 | /// Returns the first error happened, if any |
0a8f3ae0 | 16 | pub fn check_abort(abort: &Mutex<Option<String>>) -> Result<(), Error> { |
dec00364 SR |
17 | let guard = abort.lock().unwrap(); |
18 | if let Some(err_msg) = &*guard { | |
19 | return Err(format_err!("{}", err_msg)); | |
3c9b3702 | 20 | } |
dec00364 SR |
21 | Ok(()) |
22 | } | |
3c9b3702 | 23 | |
dec00364 | 24 | impl<I: Send> SendHandle<I> { |
3c9b3702 DM |
25 | /// Send data to the worker threads |
26 | pub fn send(&self, input: I) -> Result<(), Error> { | |
0a8f3ae0 | 27 | check_abort(&self.abort)?; |
55bee048 DM |
28 | match self.input.send(input) { |
29 | Ok(()) => Ok(()), | |
30 | Err(_) => bail!("send failed - channel closed"), | |
31 | } | |
3c9b3702 DM |
32 | } |
33 | } | |
34 | ||
1c13afa8 DM |
35 | /// A thread pool which run the supplied closure |
36 | /// | |
37 | /// The send command sends data to the worker threads. If one handler | |
38 | /// returns an error, we mark the channel as failed and it is no | |
39 | /// longer possible to send data. | |
40 | /// | |
41 | /// When done, the 'complete()' method needs to be called to check for | |
42 | /// outstanding errors. | |
a71bc08f | 43 | pub struct ParallelHandler<I> { |
1c13afa8 DM |
44 | handles: Vec<JoinHandle<()>>, |
45 | name: String, | |
46 | input: Option<SendHandle<I>>, | |
47 | } | |
48 | ||
ae3cfa8f | 49 | impl<I> Clone for SendHandle<I> { |
3c9b3702 | 50 | fn clone(&self) -> Self { |
ae3cfa8f WB |
51 | Self { |
52 | input: self.input.clone(), | |
db24c011 | 53 | abort: Arc::clone(&self.abort), |
ae3cfa8f | 54 | } |
3c9b3702 DM |
55 | } |
56 | } | |
57 | ||
a71bc08f | 58 | impl<I: Send + 'static> ParallelHandler<I> { |
3c9b3702 DM |
59 | /// Create a new thread pool, each thread processing incoming data |
60 | /// with 'handler_fn'. | |
ae3cfa8f | 61 | pub fn new<F>(name: &str, threads: usize, handler_fn: F) -> Self |
9531d2c5 TL |
62 | where |
63 | F: Fn(I) -> Result<(), Error> + Send + Clone + 'static, | |
3c9b3702 DM |
64 | { |
65 | let mut handles = Vec::new(); | |
66 | let (input_tx, input_rx) = bounded::<I>(threads); | |
67 | ||
68 | let abort = Arc::new(Mutex::new(None)); | |
69 | ||
70 | for i in 0..threads { | |
71 | let input_rx = input_rx.clone(); | |
db24c011 | 72 | let abort = Arc::clone(&abort); |
a71bc08f | 73 | let handler_fn = handler_fn.clone(); |
b02b374b | 74 | |
3c9b3702 DM |
75 | handles.push( |
76 | std::thread::Builder::new() | |
77 | .name(format!("{} ({})", name, i)) | |
ae3cfa8f WB |
78 | .spawn(move || loop { |
79 | let data = match input_rx.recv() { | |
80 | Ok(data) => data, | |
81 | Err(_) => return, | |
82 | }; | |
6aff2de5 MS |
83 | if let Err(err) = (handler_fn)(data) { |
84 | let mut guard = abort.lock().unwrap(); | |
85 | if guard.is_none() { | |
86 | *guard = Some(err.to_string()); | |
3c9b3702 DM |
87 | } |
88 | } | |
89 | }) | |
9531d2c5 | 90 | .unwrap(), |
3c9b3702 DM |
91 | ); |
92 | } | |
93 | Self { | |
94 | handles, | |
95 | name: name.to_string(), | |
1c13afa8 | 96 | input: Some(SendHandle { |
3c9b3702 DM |
97 | input: input_tx, |
98 | abort, | |
1c13afa8 | 99 | }), |
3c9b3702 DM |
100 | } |
101 | } | |
102 | ||
103 | /// Returns a cloneable channel to send data to the worker threads | |
104 | pub fn channel(&self) -> SendHandle<I> { | |
1c13afa8 | 105 | self.input.as_ref().unwrap().clone() |
3c9b3702 DM |
106 | } |
107 | ||
108 | /// Send data to the worker threads | |
109 | pub fn send(&self, input: I) -> Result<(), Error> { | |
1c13afa8 | 110 | self.input.as_ref().unwrap().send(input)?; |
3c9b3702 DM |
111 | Ok(()) |
112 | } | |
113 | ||
114 | /// Wait for worker threads to complete and check for errors | |
1c13afa8 | 115 | pub fn complete(mut self) -> Result<(), Error> { |
dec00364 SR |
116 | let input = self.input.take().unwrap(); |
117 | let abort = Arc::clone(&input.abort); | |
0a8f3ae0 | 118 | check_abort(&abort)?; |
dec00364 | 119 | drop(input); |
1c13afa8 DM |
120 | |
121 | let msg_list = self.join_threads(); | |
122 | ||
0a8f3ae0 DM |
123 | // an error might be encountered while waiting for the join |
124 | check_abort(&abort)?; | |
125 | ||
1c13afa8 DM |
126 | if msg_list.is_empty() { |
127 | return Ok(()); | |
128 | } | |
129 | Err(format_err!("{}", msg_list.join("\n"))) | |
130 | } | |
131 | ||
132 | fn join_threads(&mut self) -> Vec<String> { | |
1c13afa8 DM |
133 | let mut msg_list = Vec::new(); |
134 | ||
135 | let mut i = 0; | |
0d2133db | 136 | while let Some(handle) = self.handles.pop() { |
3c9b3702 DM |
137 | if let Err(panic) = handle.join() { |
138 | match panic.downcast::<&str>() { | |
9531d2c5 TL |
139 | Ok(panic_msg) => msg_list.push(format!( |
140 | "thread {} ({}) panicked: {}", | |
141 | self.name, i, panic_msg | |
142 | )), | |
143 | Err(_) => msg_list.push(format!("thread {} ({}) panicked", self.name, i)), | |
3c9b3702 DM |
144 | } |
145 | } | |
1c13afa8 | 146 | i += 1; |
3c9b3702 | 147 | } |
1c13afa8 DM |
148 | msg_list |
149 | } | |
150 | } | |
151 | ||
152 | // Note: We make sure that all threads will be joined | |
a71bc08f | 153 | impl<I> Drop for ParallelHandler<I> { |
1c13afa8 DM |
154 | fn drop(&mut self) { |
155 | drop(self.input.take()); | |
ee1a9c32 WB |
156 | while let Some(handle) = self.handles.pop() { |
157 | let _ = handle.join(); | |
3c9b3702 | 158 | } |
3c9b3702 DM |
159 | } |
160 | } |