]>
Commit | Line | Data |
---|---|---|
3c9b3702 | 1 | use std::sync::{Arc, Mutex}; |
ae3cfa8f WB |
2 | use std::thread::JoinHandle; |
3 | ||
55bee048 | 4 | use anyhow::{bail, format_err, Error}; |
ae3cfa8f | 5 | use crossbeam_channel::{bounded, Sender}; |
3c9b3702 | 6 | |
ae3cfa8f | 7 | /// A handle to send data to the worker thread (implements clone) |
3c9b3702 DM |
8 | pub struct SendHandle<I> { |
9 | input: Sender<I>, | |
10 | abort: Arc<Mutex<Option<String>>>, | |
11 | } | |
12 | ||
dec00364 | 13 | /// Returns the first error happened, if any |
0a8f3ae0 | 14 | pub fn check_abort(abort: &Mutex<Option<String>>) -> Result<(), Error> { |
dec00364 SR |
15 | let guard = abort.lock().unwrap(); |
16 | if let Some(err_msg) = &*guard { | |
17 | return Err(format_err!("{}", err_msg)); | |
3c9b3702 | 18 | } |
dec00364 SR |
19 | Ok(()) |
20 | } | |
3c9b3702 | 21 | |
dec00364 | 22 | impl<I: Send> SendHandle<I> { |
3c9b3702 DM |
23 | /// Send data to the worker threads |
24 | pub fn send(&self, input: I) -> Result<(), Error> { | |
0a8f3ae0 | 25 | check_abort(&self.abort)?; |
55bee048 DM |
26 | match self.input.send(input) { |
27 | Ok(()) => Ok(()), | |
28 | Err(_) => bail!("send failed - channel closed"), | |
29 | } | |
3c9b3702 DM |
30 | } |
31 | } | |
32 | ||
1c13afa8 DM |
33 | /// A thread pool which run the supplied closure |
34 | /// | |
35 | /// The send command sends data to the worker threads. If one handler | |
36 | /// returns an error, we mark the channel as failed and it is no | |
37 | /// longer possible to send data. | |
38 | /// | |
39 | /// When done, the 'complete()' method needs to be called to check for | |
40 | /// outstanding errors. | |
b02b374b | 41 | pub struct ParallelHandler<'a, I> { |
1c13afa8 DM |
42 | handles: Vec<JoinHandle<()>>, |
43 | name: String, | |
44 | input: Option<SendHandle<I>>, | |
b02b374b | 45 | _marker: std::marker::PhantomData<&'a ()>, |
1c13afa8 DM |
46 | } |
47 | ||
ae3cfa8f | 48 | impl<I> Clone for SendHandle<I> { |
3c9b3702 | 49 | fn clone(&self) -> Self { |
ae3cfa8f WB |
50 | Self { |
51 | input: self.input.clone(), | |
db24c011 | 52 | abort: Arc::clone(&self.abort), |
ae3cfa8f | 53 | } |
3c9b3702 DM |
54 | } |
55 | } | |
56 | ||
ae3cfa8f | 57 | impl<'a, I: Send + 'static> ParallelHandler<'a, I> { |
3c9b3702 DM |
58 | /// Create a new thread pool, each thread processing incoming data |
59 | /// with 'handler_fn'. | |
ae3cfa8f | 60 | pub fn new<F>(name: &str, threads: usize, handler_fn: F) -> Self |
b02b374b | 61 | where F: Fn(I) -> Result<(), Error> + Send + Clone + 'a, |
3c9b3702 DM |
62 | { |
63 | let mut handles = Vec::new(); | |
64 | let (input_tx, input_rx) = bounded::<I>(threads); | |
65 | ||
66 | let abort = Arc::new(Mutex::new(None)); | |
67 | ||
68 | for i in 0..threads { | |
69 | let input_rx = input_rx.clone(); | |
db24c011 | 70 | let abort = Arc::clone(&abort); |
b02b374b DM |
71 | |
72 | // Erase the 'a lifetime bound. This is safe because we | |
73 | // join all thread in the drop handler. | |
74 | let handler_fn: Box<dyn Fn(I) -> Result<(), Error> + Send + 'a> = | |
75 | Box::new(handler_fn.clone()); | |
76 | let handler_fn: Box<dyn Fn(I) -> Result<(), Error> + Send + 'static> = | |
77 | unsafe { std::mem::transmute(handler_fn) }; | |
78 | ||
3c9b3702 DM |
79 | handles.push( |
80 | std::thread::Builder::new() | |
81 | .name(format!("{} ({})", name, i)) | |
ae3cfa8f WB |
82 | .spawn(move || loop { |
83 | let data = match input_rx.recv() { | |
84 | Ok(data) => data, | |
85 | Err(_) => return, | |
86 | }; | |
87 | match (handler_fn)(data) { | |
88 | Ok(()) => (), | |
89 | Err(err) => { | |
90 | let mut guard = abort.lock().unwrap(); | |
91 | if guard.is_none() { | |
92 | *guard = Some(err.to_string()); | |
3c9b3702 DM |
93 | } |
94 | } | |
95 | } | |
96 | }) | |
97 | .unwrap() | |
98 | ); | |
99 | } | |
100 | Self { | |
101 | handles, | |
102 | name: name.to_string(), | |
1c13afa8 | 103 | input: Some(SendHandle { |
3c9b3702 DM |
104 | input: input_tx, |
105 | abort, | |
1c13afa8 | 106 | }), |
b02b374b | 107 | _marker: std::marker::PhantomData, |
3c9b3702 DM |
108 | } |
109 | } | |
110 | ||
111 | /// Returns a cloneable channel to send data to the worker threads | |
112 | pub fn channel(&self) -> SendHandle<I> { | |
1c13afa8 | 113 | self.input.as_ref().unwrap().clone() |
3c9b3702 DM |
114 | } |
115 | ||
116 | /// Send data to the worker threads | |
117 | pub fn send(&self, input: I) -> Result<(), Error> { | |
1c13afa8 | 118 | self.input.as_ref().unwrap().send(input)?; |
3c9b3702 DM |
119 | Ok(()) |
120 | } | |
121 | ||
122 | /// Wait for worker threads to complete and check for errors | |
1c13afa8 | 123 | pub fn complete(mut self) -> Result<(), Error> { |
dec00364 SR |
124 | let input = self.input.take().unwrap(); |
125 | let abort = Arc::clone(&input.abort); | |
0a8f3ae0 | 126 | check_abort(&abort)?; |
dec00364 | 127 | drop(input); |
1c13afa8 DM |
128 | |
129 | let msg_list = self.join_threads(); | |
130 | ||
0a8f3ae0 DM |
131 | // an error might be encountered while waiting for the join |
132 | check_abort(&abort)?; | |
133 | ||
1c13afa8 DM |
134 | if msg_list.is_empty() { |
135 | return Ok(()); | |
136 | } | |
137 | Err(format_err!("{}", msg_list.join("\n"))) | |
138 | } | |
139 | ||
140 | fn join_threads(&mut self) -> Vec<String> { | |
141 | ||
142 | let mut msg_list = Vec::new(); | |
143 | ||
144 | let mut i = 0; | |
145 | loop { | |
146 | let handle = match self.handles.pop() { | |
147 | Some(handle) => handle, | |
148 | None => break, | |
149 | }; | |
3c9b3702 DM |
150 | if let Err(panic) = handle.join() { |
151 | match panic.downcast::<&str>() { | |
1c13afa8 DM |
152 | Ok(panic_msg) => msg_list.push( |
153 | format!("thread {} ({}) paniced: {}", self.name, i, panic_msg) | |
154 | ), | |
155 | Err(_) => msg_list.push( | |
156 | format!("thread {} ({}) paniced", self.name, i) | |
157 | ), | |
3c9b3702 DM |
158 | } |
159 | } | |
1c13afa8 | 160 | i += 1; |
3c9b3702 | 161 | } |
1c13afa8 DM |
162 | msg_list |
163 | } | |
164 | } | |
165 | ||
166 | // Note: We make sure that all threads will be joined | |
ae3cfa8f | 167 | impl<'a, I> Drop for ParallelHandler<'a, I> { |
1c13afa8 DM |
168 | fn drop(&mut self) { |
169 | drop(self.input.take()); | |
ee1a9c32 WB |
170 | while let Some(handle) = self.handles.pop() { |
171 | let _ = handle.join(); | |
3c9b3702 | 172 | } |
3c9b3702 DM |
173 | } |
174 | } |