]>
Commit | Line | Data |
---|---|---|
5446bfbb SR |
1 | //! An 'async'-safe layer on the existing sync LruCache implementation. Supports multiple |
2 | //! concurrent requests to the same key. | |
3 | ||
4 | use anyhow::Error; | |
5 | ||
6 | use std::collections::HashMap; | |
7 | use std::future::Future; | |
8 | use std::sync::{Arc, Mutex}; | |
9 | ||
9a1b24b6 | 10 | use proxmox_async::broadcast_future::BroadcastFuture; |
6c221244 | 11 | use crate::lru_cache::LruCache; |
5446bfbb SR |
12 | |
13 | /// Interface for asynchronously getting values on cache misses. | |
14 | pub trait AsyncCacher<K, V: Clone>: Sync + Send { | |
15 | /// Fetch a value for key on cache miss. | |
16 | /// | |
17 | /// Works similar to non-async lru_cache::Cacher, but if the key has already been requested | |
18 | /// and the result is not cached yet, the 'fetch' function will not be called and instead the | |
19 | /// result of the original request cloned and returned upon completion. | |
20 | /// | |
21 | /// The underlying LRU structure is not modified until the returned future resolves to an | |
22 | /// Ok(Some(_)) value. | |
23 | fn fetch(&self, key: K) -> Box<dyn Future<Output = Result<Option<V>, Error>> + Send>; | |
24 | } | |
25 | ||
26 | /// See tools::lru_cache::LruCache, this implements an async-safe variant of that with the help of | |
27 | /// AsyncCacher. | |
28 | #[derive(Clone)] | |
29 | pub struct AsyncLruCache<K, V> { | |
30 | maps: Arc<Mutex<(LruCache<K, V>, HashMap<K, BroadcastFuture<Option<V>>>)>>, | |
31 | } | |
32 | ||
33 | impl<K: std::cmp::Eq + std::hash::Hash + Copy, V: Clone + Send + 'static> AsyncLruCache<K, V> { | |
34 | /// Create a new AsyncLruCache with the given maximum capacity. | |
35 | pub fn new(capacity: usize) -> Self { | |
36 | Self { | |
37 | maps: Arc::new(Mutex::new((LruCache::new(capacity), HashMap::new()))), | |
38 | } | |
39 | } | |
40 | ||
41 | /// Access an item either via the cache or by calling cacher.fetch. A return value of Ok(None) | |
42 | /// means the item requested has no representation, Err(_) means a call to fetch() failed, | |
43 | /// regardless of whether it was initiated by this call or a previous one. | |
44 | pub async fn access(&self, key: K, cacher: &dyn AsyncCacher<K, V>) -> Result<Option<V>, Error> { | |
45 | let (owner, result_fut) = { | |
46 | // check if already requested | |
47 | let mut maps = self.maps.lock().unwrap(); | |
48 | if let Some(fut) = maps.1.get(&key) { | |
49 | // wait for the already scheduled future to resolve | |
50 | (false, fut.listen()) | |
51 | } else { | |
52 | // check if value is cached in LRU | |
53 | if let Some(val) = maps.0.get_mut(key) { | |
54 | return Ok(Some(val.clone())); | |
55 | } | |
56 | ||
57 | // if neither, start broadcast future and put into map while we still have lock | |
58 | let fut = cacher.fetch(key); | |
59 | let broadcast = BroadcastFuture::new(fut); | |
60 | let result_fut = broadcast.listen(); | |
61 | maps.1.insert(key, broadcast); | |
62 | (true, result_fut) | |
63 | } | |
64 | // drop Mutex before awaiting any future | |
65 | }; | |
66 | ||
67 | let result = result_fut.await; | |
c48c38ab SR |
68 | |
69 | if owner { | |
70 | // this call was the one initiating the request, put into LRU and remove from map | |
71 | let mut maps = self.maps.lock().unwrap(); | |
72 | if let Ok(Some(ref value)) = result { | |
5446bfbb | 73 | maps.0.insert(key, value.clone()); |
5446bfbb | 74 | } |
c48c38ab | 75 | maps.1.remove(&key); |
5446bfbb | 76 | } |
c48c38ab | 77 | |
5446bfbb SR |
78 | result |
79 | } | |
80 | } | |
81 | ||
// Only compiled for test builds, so the helper cacher below is not part of regular builds.
#[cfg(test)]
mod test {
    use super::*;

    /// Cacher for the tests: every fetch resolves immediately to `prefix` followed by the key.
    struct TestAsyncCacher {
        prefix: &'static str,
    }

    impl AsyncCacher<i32, String> for TestAsyncCacher {
        fn fetch(
            &self,
            key: i32,
        ) -> Box<dyn Future<Output = Result<Option<String>, Error>> + Send> {
            let x = self.prefix;
            Box::new(async move { Ok(Some(format!("{}{}", x, key))) })
        }
    }

    /// Exercises sequential misses followed by ten concurrent requests for the same new key.
    #[test]
    fn test_async_lru_cache() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async move {
            let cacher = TestAsyncCacher { prefix: "x" };
            let cache: AsyncLruCache<i32, String> = AsyncLruCache::new(2);

            assert_eq!(
                cache.access(10, &cacher).await.unwrap(),
                Some("x10".to_string())
            );
            assert_eq!(
                cache.access(20, &cacher).await.unwrap(),
                Some("x20".to_string())
            );
            assert_eq!(
                cache.access(30, &cacher).await.unwrap(),
                Some("x30".to_string())
            );

            // Keep the join handles: a dropped handle detaches the task, so a failing
            // assertion inside it would go unnoticed and the task might not even run before
            // the runtime is shut down.
            let mut handles = Vec::with_capacity(10);
            for _ in 0..10 {
                let c = cache.clone();
                handles.push(tokio::spawn(async move {
                    let cacher = TestAsyncCacher { prefix: "y" };
                    assert_eq!(
                        c.access(40, &cacher).await.unwrap(),
                        Some("y40".to_string())
                    );
                }));
            }

            // Runs while the spawned tasks may still be in flight.
            assert_eq!(
                cache.access(20, &cacher).await.unwrap(),
                Some("x20".to_string())
            );

            // Awaiting each handle propagates a panic from the task into this test.
            for handle in handles {
                handle.await.expect("concurrent access task panicked");
            }
        });
    }
}