]>
Commit | Line | Data |
---|---|---|
3c9b3702 | 1 | use std::sync::{Arc, Mutex}; |
ae3cfa8f WB |
2 | use std::thread::JoinHandle; |
3 | ||
55bee048 | 4 | use anyhow::{bail, format_err, Error}; |
ae3cfa8f | 5 | use crossbeam_channel::{bounded, Sender}; |
3c9b3702 | 6 | |
ae3cfa8f | 7 | /// A handle to send data to the worker thread (implements clone) |
3c9b3702 DM |
8 | pub struct SendHandle<I> { |
9 | input: Sender<I>, | |
10 | abort: Arc<Mutex<Option<String>>>, | |
11 | } | |
12 | ||
ae3cfa8f | 13 | impl<I: Send> SendHandle<I> { |
3c9b3702 DM |
14 | /// Returns the first error happened, if any |
15 | pub fn check_abort(&self) -> Result<(), Error> { | |
16 | let guard = self.abort.lock().unwrap(); | |
17 | if let Some(err_msg) = &*guard { | |
18 | return Err(format_err!("{}", err_msg)); | |
19 | } | |
20 | Ok(()) | |
21 | } | |
22 | ||
23 | /// Send data to the worker threads | |
24 | pub fn send(&self, input: I) -> Result<(), Error> { | |
25 | self.check_abort()?; | |
55bee048 DM |
26 | match self.input.send(input) { |
27 | Ok(()) => Ok(()), | |
28 | Err(_) => bail!("send failed - channel closed"), | |
29 | } | |
3c9b3702 DM |
30 | } |
31 | } | |
32 | ||
1c13afa8 DM |
33 | /// A thread pool which run the supplied closure |
34 | /// | |
35 | /// The send command sends data to the worker threads. If one handler | |
36 | /// returns an error, we mark the channel as failed and it is no | |
37 | /// longer possible to send data. | |
38 | /// | |
39 | /// When done, the 'complete()' method needs to be called to check for | |
40 | /// outstanding errors. | |
b02b374b | 41 | pub struct ParallelHandler<'a, I> { |
1c13afa8 DM |
42 | handles: Vec<JoinHandle<()>>, |
43 | name: String, | |
44 | input: Option<SendHandle<I>>, | |
b02b374b | 45 | _marker: std::marker::PhantomData<&'a ()>, |
1c13afa8 DM |
46 | } |
47 | ||
ae3cfa8f | 48 | impl<I> Clone for SendHandle<I> { |
3c9b3702 | 49 | fn clone(&self) -> Self { |
ae3cfa8f WB |
50 | Self { |
51 | input: self.input.clone(), | |
52 | abort: self.abort.clone(), | |
53 | } | |
3c9b3702 DM |
54 | } |
55 | } | |
56 | ||
ae3cfa8f | 57 | impl<'a, I: Send + 'static> ParallelHandler<'a, I> { |
3c9b3702 DM |
58 | /// Create a new thread pool, each thread processing incoming data |
59 | /// with 'handler_fn'. | |
ae3cfa8f | 60 | pub fn new<F>(name: &str, threads: usize, handler_fn: F) -> Self |
b02b374b | 61 | where F: Fn(I) -> Result<(), Error> + Send + Clone + 'a, |
3c9b3702 DM |
62 | { |
63 | let mut handles = Vec::new(); | |
64 | let (input_tx, input_rx) = bounded::<I>(threads); | |
65 | ||
66 | let abort = Arc::new(Mutex::new(None)); | |
67 | ||
68 | for i in 0..threads { | |
69 | let input_rx = input_rx.clone(); | |
70 | let abort = abort.clone(); | |
b02b374b DM |
71 | |
72 | // Erase the 'a lifetime bound. This is safe because we | |
73 | // join all thread in the drop handler. | |
74 | let handler_fn: Box<dyn Fn(I) -> Result<(), Error> + Send + 'a> = | |
75 | Box::new(handler_fn.clone()); | |
76 | let handler_fn: Box<dyn Fn(I) -> Result<(), Error> + Send + 'static> = | |
77 | unsafe { std::mem::transmute(handler_fn) }; | |
78 | ||
3c9b3702 DM |
79 | handles.push( |
80 | std::thread::Builder::new() | |
81 | .name(format!("{} ({})", name, i)) | |
ae3cfa8f WB |
82 | .spawn(move || loop { |
83 | let data = match input_rx.recv() { | |
84 | Ok(data) => data, | |
85 | Err(_) => return, | |
86 | }; | |
87 | match (handler_fn)(data) { | |
88 | Ok(()) => (), | |
89 | Err(err) => { | |
90 | let mut guard = abort.lock().unwrap(); | |
91 | if guard.is_none() { | |
92 | *guard = Some(err.to_string()); | |
3c9b3702 DM |
93 | } |
94 | } | |
95 | } | |
96 | }) | |
97 | .unwrap() | |
98 | ); | |
99 | } | |
100 | Self { | |
101 | handles, | |
102 | name: name.to_string(), | |
1c13afa8 | 103 | input: Some(SendHandle { |
3c9b3702 DM |
104 | input: input_tx, |
105 | abort, | |
1c13afa8 | 106 | }), |
b02b374b | 107 | _marker: std::marker::PhantomData, |
3c9b3702 DM |
108 | } |
109 | } | |
110 | ||
111 | /// Returns a cloneable channel to send data to the worker threads | |
112 | pub fn channel(&self) -> SendHandle<I> { | |
1c13afa8 | 113 | self.input.as_ref().unwrap().clone() |
3c9b3702 DM |
114 | } |
115 | ||
116 | /// Send data to the worker threads | |
117 | pub fn send(&self, input: I) -> Result<(), Error> { | |
1c13afa8 | 118 | self.input.as_ref().unwrap().send(input)?; |
3c9b3702 DM |
119 | Ok(()) |
120 | } | |
121 | ||
122 | /// Wait for worker threads to complete and check for errors | |
1c13afa8 DM |
123 | pub fn complete(mut self) -> Result<(), Error> { |
124 | self.input.as_ref().unwrap().check_abort()?; | |
125 | drop(self.input.take()); | |
126 | ||
127 | let msg_list = self.join_threads(); | |
128 | ||
129 | if msg_list.is_empty() { | |
130 | return Ok(()); | |
131 | } | |
132 | Err(format_err!("{}", msg_list.join("\n"))) | |
133 | } | |
134 | ||
135 | fn join_threads(&mut self) -> Vec<String> { | |
136 | ||
137 | let mut msg_list = Vec::new(); | |
138 | ||
139 | let mut i = 0; | |
140 | loop { | |
141 | let handle = match self.handles.pop() { | |
142 | Some(handle) => handle, | |
143 | None => break, | |
144 | }; | |
3c9b3702 DM |
145 | if let Err(panic) = handle.join() { |
146 | match panic.downcast::<&str>() { | |
1c13afa8 DM |
147 | Ok(panic_msg) => msg_list.push( |
148 | format!("thread {} ({}) paniced: {}", self.name, i, panic_msg) | |
149 | ), | |
150 | Err(_) => msg_list.push( | |
151 | format!("thread {} ({}) paniced", self.name, i) | |
152 | ), | |
3c9b3702 DM |
153 | } |
154 | } | |
1c13afa8 | 155 | i += 1; |
3c9b3702 | 156 | } |
1c13afa8 DM |
157 | msg_list |
158 | } | |
159 | } | |
160 | ||
161 | // Note: We make sure that all threads will be joined | |
ae3cfa8f | 162 | impl<'a, I> Drop for ParallelHandler<'a, I> { |
1c13afa8 DM |
163 | fn drop(&mut self) { |
164 | drop(self.input.take()); | |
165 | loop { | |
166 | match self.handles.pop() { | |
167 | Some(handle) => { | |
168 | let _ = handle.join(); | |
169 | } | |
170 | None => break, | |
171 | } | |
3c9b3702 | 172 | } |
3c9b3702 DM |
173 | } |
174 | } |