]>
Commit | Line | Data |
---|---|---|
5446bfbb SR |
1 | //! An 'async'-safe layer on the existing sync LruCache implementation. Supports multiple |
2 | //! concurrent requests to the same key. | |
3 | ||
4 | use anyhow::Error; | |
5 | ||
6 | use std::collections::HashMap; | |
7 | use std::future::Future; | |
8 | use std::sync::{Arc, Mutex}; | |
9 | ||
10 | use super::lru_cache::LruCache; | |
11 | use super::BroadcastFuture; | |
12 | ||
13 | /// Interface for asynchronously getting values on cache misses. | |
14 | pub trait AsyncCacher<K, V: Clone>: Sync + Send { | |
15 | /// Fetch a value for key on cache miss. | |
16 | /// | |
17 | /// Works similar to non-async lru_cache::Cacher, but if the key has already been requested | |
18 | /// and the result is not cached yet, the 'fetch' function will not be called and instead the | |
19 | /// result of the original request cloned and returned upon completion. | |
20 | /// | |
21 | /// The underlying LRU structure is not modified until the returned future resolves to an | |
22 | /// Ok(Some(_)) value. | |
23 | fn fetch(&self, key: K) -> Box<dyn Future<Output = Result<Option<V>, Error>> + Send>; | |
24 | } | |
25 | ||
26 | /// See tools::lru_cache::LruCache, this implements an async-safe variant of that with the help of | |
27 | /// AsyncCacher. | |
28 | #[derive(Clone)] | |
29 | pub struct AsyncLruCache<K, V> { | |
30 | maps: Arc<Mutex<(LruCache<K, V>, HashMap<K, BroadcastFuture<Option<V>>>)>>, | |
31 | } | |
32 | ||
33 | impl<K: std::cmp::Eq + std::hash::Hash + Copy, V: Clone + Send + 'static> AsyncLruCache<K, V> { | |
34 | /// Create a new AsyncLruCache with the given maximum capacity. | |
35 | pub fn new(capacity: usize) -> Self { | |
36 | Self { | |
37 | maps: Arc::new(Mutex::new((LruCache::new(capacity), HashMap::new()))), | |
38 | } | |
39 | } | |
40 | ||
41 | /// Access an item either via the cache or by calling cacher.fetch. A return value of Ok(None) | |
42 | /// means the item requested has no representation, Err(_) means a call to fetch() failed, | |
43 | /// regardless of whether it was initiated by this call or a previous one. | |
44 | pub async fn access(&self, key: K, cacher: &dyn AsyncCacher<K, V>) -> Result<Option<V>, Error> { | |
45 | let (owner, result_fut) = { | |
46 | // check if already requested | |
47 | let mut maps = self.maps.lock().unwrap(); | |
48 | if let Some(fut) = maps.1.get(&key) { | |
49 | // wait for the already scheduled future to resolve | |
50 | (false, fut.listen()) | |
51 | } else { | |
52 | // check if value is cached in LRU | |
53 | if let Some(val) = maps.0.get_mut(key) { | |
54 | return Ok(Some(val.clone())); | |
55 | } | |
56 | ||
57 | // if neither, start broadcast future and put into map while we still have lock | |
58 | let fut = cacher.fetch(key); | |
59 | let broadcast = BroadcastFuture::new(fut); | |
60 | let result_fut = broadcast.listen(); | |
61 | maps.1.insert(key, broadcast); | |
62 | (true, result_fut) | |
63 | } | |
64 | // drop Mutex before awaiting any future | |
65 | }; | |
66 | ||
67 | let result = result_fut.await; | |
68 | match result { | |
69 | Ok(Some(ref value)) if owner => { | |
70 | // this call was the one initiating the request, put into LRU and remove from map | |
71 | let mut maps = self.maps.lock().unwrap(); | |
72 | maps.0.insert(key, value.clone()); | |
73 | maps.1.remove(&key); | |
74 | } | |
75 | _ => {} | |
76 | } | |
77 | result | |
78 | } | |
79 | } | |
80 | ||
#[cfg(test)]
mod test {
    use super::*;

    /// Test cacher that resolves every key to `Some("{prefix}{key}")`.
    struct TestAsyncCacher {
        prefix: &'static str,
    }

    impl AsyncCacher<i32, String> for TestAsyncCacher {
        fn fetch(
            &self,
            key: i32,
        ) -> Box<dyn Future<Output = Result<Option<String>, Error>> + Send> {
            let x = self.prefix;
            Box::new(async move { Ok(Some(format!("{}{}", x, key))) })
        }
    }

    #[test]
    fn test_async_lru_cache() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async move {
            let cacher = TestAsyncCacher { prefix: "x" };
            let cache: AsyncLruCache<i32, String> = AsyncLruCache::new(2);

            assert_eq!(
                cache.access(10, &cacher).await.unwrap(),
                Some("x10".to_string())
            );
            assert_eq!(
                cache.access(20, &cacher).await.unwrap(),
                Some("x20".to_string())
            );
            assert_eq!(
                cache.access(30, &cacher).await.unwrap(),
                Some("x30".to_string())
            );

            // Many concurrent requests for the same key. The handles are awaited so that a failed
            // assertion inside a task fails this test instead of being swallowed by the runtime.
            let handles: Vec<_> = (0..10)
                .map(|_| {
                    let c = cache.clone();
                    tokio::spawn(async move {
                        let cacher = TestAsyncCacher { prefix: "y" };
                        assert_eq!(
                            c.access(40, &cacher).await.unwrap(),
                            Some("y40".to_string())
                        );
                    })
                })
                .collect();
            for handle in handles {
                handle.await.unwrap();
            }

            assert_eq!(
                cache.access(20, &cacher).await.unwrap(),
                Some("x20".to_string())
            );

            // Key 40 was cached by the concurrent requests above (it is one of the two most
            // recently used keys), so the "x" cacher is not consulted for it.
            assert_eq!(
                cache.access(40, &cacher).await.unwrap(),
                Some("y40".to_string())
            );
        });
    }
}