]> git.proxmox.com Git - proxmox-backup.git/blob - pbs-runtime/src/lib.rs
tfa: handle incompatible challenge data
[proxmox-backup.git] / pbs-runtime / src / lib.rs
1 //! Helpers for quirks of the current tokio runtime.
2
3 use std::cell::RefCell;
4 use std::future::Future;
5 use std::sync::{Arc, Weak, Mutex};
6 use std::task::{Context, Poll, RawWaker, Waker};
7 use std::thread::{self, Thread};
8
9 use lazy_static::lazy_static;
10 use pin_utils::pin_mut;
11 use tokio::runtime::{self, Runtime};
12
thread_local! {
    /// Per-thread flag telling whether this thread is currently inside one of our blocking
    /// sections (while a `BlockingGuard` is alive), so that nested `block_on` /
    /// `block_in_place` calls do not try to exit the tokio context a second time.
    static BLOCKING: RefCell<bool> = RefCell::new(Default::default());
}
16
17 fn is_in_tokio() -> bool {
18 tokio::runtime::Handle::try_current()
19 .is_ok()
20 }
21
22 fn is_blocking() -> bool {
23 BLOCKING.with(|v| *v.borrow())
24 }
25
26 struct BlockingGuard(bool);
27
28 impl BlockingGuard {
29 fn set() -> Self {
30 Self(BLOCKING.with(|v| {
31 let old = *v.borrow();
32 *v.borrow_mut() = true;
33 old
34 }))
35 }
36 }
37
38 impl Drop for BlockingGuard {
39 fn drop(&mut self) {
40 BLOCKING.with(|v| {
41 *v.borrow_mut() = self.0;
42 });
43 }
44 }
45
46 lazy_static! {
47 // avoid openssl bug: https://github.com/openssl/openssl/issues/6214
48 // by dropping the runtime as early as possible
49 static ref RUNTIME: Mutex<Weak<Runtime>> = Mutex::new(Weak::new());
50 }
51
// FFI binding into libcrypto (OpenSSL). Used from the tokio `on_thread_stop` hook installed by
// `get_runtime_with_builder`.
#[link(name = "crypto")]
extern "C" {
    /// Called from the worker thread-stop hook to avoid racing with openssl's own cleanup
    /// handlers (see https://github.com/openssl/openssl/issues/6214).
    fn OPENSSL_thread_stop();
}
56
57 /// Get or create the current main tokio runtime.
58 ///
59 /// This makes sure that tokio's worker threads are marked for us so that we know whether we
60 /// can/need to use `block_in_place` in our `block_on` helper.
61 pub fn get_runtime_with_builder<F: Fn() -> runtime::Builder>(get_builder: F) -> Arc<Runtime> {
62
63 let mut guard = RUNTIME.lock().unwrap();
64
65 if let Some(rt) = guard.upgrade() { return rt; }
66
67 let mut builder = get_builder();
68 builder.on_thread_stop(|| {
69 // avoid openssl bug: https://github.com/openssl/openssl/issues/6214
70 // call OPENSSL_thread_stop to avoid race with openssl cleanup handlers
71 unsafe { OPENSSL_thread_stop(); }
72 });
73
74 let runtime = builder.build().expect("failed to spawn tokio runtime");
75 let rt = Arc::new(runtime);
76
77 *guard = Arc::downgrade(&rt);
78
79 rt
80 }
81
82 /// Get or create the current main tokio runtime.
83 ///
84 /// This calls get_runtime_with_builder() using the tokio default threaded scheduler
85 pub fn get_runtime() -> Arc<Runtime> {
86
87 get_runtime_with_builder(|| {
88 let mut builder = runtime::Builder::new_multi_thread();
89 builder.enable_all();
90 builder
91 })
92 }
93
94
95 /// Block on a synchronous piece of code.
96 pub fn block_in_place<R>(fut: impl FnOnce() -> R) -> R {
97 // don't double-exit the context (tokio doesn't like that)
98 // also, if we're not actually in a tokio-worker we must not use block_in_place() either
99 if is_blocking() || !is_in_tokio() {
100 fut()
101 } else {
102 // we are in an actual tokio worker thread, block it:
103 tokio::task::block_in_place(move || {
104 let _guard = BlockingGuard::set();
105 fut()
106 })
107 }
108 }
109
110 /// Block on a future in this thread.
111 pub fn block_on<F: Future>(fut: F) -> F::Output {
112 // don't double-exit the context (tokio doesn't like that)
113 if is_blocking() {
114 block_on_local_future(fut)
115 } else if is_in_tokio() {
116 // inside a tokio worker we need to tell tokio that we're about to really block:
117 tokio::task::block_in_place(move || {
118 let _guard = BlockingGuard::set();
119 block_on_local_future(fut)
120 })
121 } else {
122 // not a worker thread, not associated with a runtime, make sure we have a runtime (spawn
123 // it on demand if necessary), then enter it
124 let _guard = BlockingGuard::set();
125 let _enter_guard = get_runtime().enter();
126 get_runtime().block_on(fut)
127 }
128 }
129
130 /*
131 fn block_on_impl<F>(mut fut: F) -> F::Output
132 where
133 F: Future + Send,
134 F::Output: Send + 'static,
135 {
136 let (tx, rx) = tokio::sync::oneshot::channel();
137 let fut_ptr = &mut fut as *mut F as usize; // hack to not require F to be 'static
138 tokio::spawn(async move {
139 let fut: F = unsafe { std::ptr::read(fut_ptr as *mut F) };
140 tx
141 .send(fut.await)
142 .map_err(drop)
143 .expect("failed to send block_on result to channel")
144 });
145
146 futures::executor::block_on(async move {
147 rx.await.expect("failed to receive block_on result from channel")
148 })
149 std::mem::forget(fut);
150 }
151 */
152
153 /// This used to be our tokio main entry point. Now this just calls out to `block_on` for
154 /// compatibility, which will perform all the necessary tasks on-demand anyway.
155 pub fn main<F: Future>(fut: F) -> F::Output {
156 block_on(fut)
157 }
158
159 fn block_on_local_future<F: Future>(fut: F) -> F::Output {
160 pin_mut!(fut);
161
162 let waker = Arc::new(thread::current());
163 let waker = thread_waker_clone(Arc::into_raw(waker) as *const ());
164 let waker = unsafe { Waker::from_raw(waker) };
165 let mut context = Context::from_waker(&waker);
166 loop {
167 match fut.as_mut().poll(&mut context) {
168 Poll::Ready(out) => return out,
169 Poll::Pending => thread::park(),
170 }
171 }
172 }
173
174 const THREAD_WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
175 thread_waker_clone,
176 thread_waker_wake,
177 thread_waker_wake_by_ref,
178 thread_waker_drop,
179 );
180
181 fn thread_waker_clone(this: *const ()) -> RawWaker {
182 let this = unsafe { Arc::from_raw(this as *const Thread) };
183 let cloned = Arc::clone(&this);
184 let _ = Arc::into_raw(this);
185
186 RawWaker::new(Arc::into_raw(cloned) as *const (), &THREAD_WAKER_VTABLE)
187 }
188
/// Vtable `wake`: unpark the thread behind `this`, consuming the strong reference the pointer
/// stands for.
fn thread_waker_wake(this: *const ()) {
    // SAFETY: `this` came from `Arc::<Thread>::into_raw`; `wake` owns that reference, so the
    // reconstructed `Arc` is released when it goes out of scope here.
    let thread: Arc<Thread> = unsafe { Arc::from_raw(this.cast::<Thread>()) };
    thread.unpark();
}
193
/// Vtable `wake_by_ref`: unpark the thread behind `this` without giving up the reference.
fn thread_waker_wake_by_ref(this: *const ()) {
    // SAFETY: `this` came from `Arc::<Thread>::into_raw` and stays owned by the caller;
    // `ManuallyDrop` keeps the reference count untouched.
    let thread = std::mem::ManuallyDrop::new(unsafe { Arc::from_raw(this as *const Thread) });
    thread.unpark();
}
199
/// Vtable `drop`: release the strong reference to the `Arc<Thread>` behind `this`.
fn thread_waker_drop(this: *const ()) {
    // SAFETY: `this` came from `Arc::<Thread>::into_raw` and the waker being dropped owns that
    // reference, so it is reclaimed exactly once here.
    let owned: Arc<Thread> = unsafe { Arc::from_raw(this.cast()) };
    std::mem::drop(owned);
}