lychee_lib/ratelimit/
pool.rs1use dashmap::DashMap;
2use http::Method;
3use reqwest::{Client, Request};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::ratelimit::{
8 CacheableResponse, Host, HostConfigs, HostKey, HostStats, HostStatsMap, RateLimitConfig,
9};
10use crate::types::Result;
11use crate::{ErrorKind, Uri};
12
13pub type ClientMap = HashMap<HostKey, reqwest::Client>;
15
16#[derive(Debug)]
28pub struct HostPool {
29 hosts: DashMap<HostKey, Arc<Host>>,
31
32 global_config: RateLimitConfig,
34
35 host_configs: HostConfigs,
37
38 default_client: Client,
40
41 client_map: ClientMap,
43}
44
45impl HostPool {
46 #[must_use]
48 pub fn new(
49 global_config: RateLimitConfig,
50 host_configs: HostConfigs,
51 default_client: Client,
52 client_map: ClientMap,
53 ) -> Self {
54 Self {
55 hosts: DashMap::new(),
56 global_config,
57 host_configs,
58 default_client,
59 client_map,
60 }
61 }
62
63 pub(crate) async fn execute_request(&self, request: Request) -> Result<CacheableResponse> {
71 let url = request.url();
72 let host_key = HostKey::try_from(url)?;
73 let host = self.get_or_create_host(host_key);
74 host.execute_request(request).await
75 }
76
77 pub fn build_request(&self, method: Method, uri: &Uri) -> Result<Request> {
85 let host_key = HostKey::try_from(uri)?;
86 let host = self.get_or_create_host(host_key);
87 host.get_client()
88 .request(method, uri.url.clone())
89 .build()
90 .map_err(ErrorKind::BuildRequestClient)
91 }
92
93 fn get_or_create_host(&self, host_key: HostKey) -> Arc<Host> {
95 self.hosts
96 .entry(host_key.clone())
97 .or_insert_with(|| {
98 let host_config = self
99 .host_configs
100 .get(&host_key)
101 .cloned()
102 .unwrap_or_default();
103
104 let client = self
105 .client_map
106 .get(&host_key)
107 .unwrap_or(&self.default_client)
108 .clone();
109
110 Arc::new(Host::new(
111 host_key,
112 &host_config,
113 &self.global_config,
114 client,
115 ))
116 })
117 .value()
118 .clone()
119 }
120
121 #[must_use]
124 pub fn host_stats(&self, hostname: &str) -> HostStats {
125 let host_key = HostKey::from(hostname);
126 self.hosts
127 .get(&host_key)
128 .map(|host| host.stats())
129 .unwrap_or_default()
130 }
131
132 #[must_use]
135 pub fn all_host_stats(&self) -> HostStatsMap {
136 HostStatsMap::from(
137 self.hosts
138 .iter()
139 .map(|entry| {
140 let hostname = entry.key().to_string();
141 let stats = entry.value().stats();
142 (hostname, stats)
143 })
144 .collect::<HashMap<_, _>>(),
145 )
146 }
147
148 #[must_use]
152 pub fn active_host_count(&self) -> usize {
153 self.hosts.len()
154 }
155
156 #[must_use]
159 pub fn host_configurations(&self) -> HostConfigs {
160 self.host_configs.clone()
161 }
162
163 #[must_use]
173 pub fn remove_host(&self, hostname: &str) -> bool {
174 let host_key = HostKey::from(hostname);
175 self.hosts.remove(&host_key).is_some()
176 }
177
178 #[must_use]
180 pub fn cache_stats(&self) -> HashMap<String, (usize, f64)> {
181 self.hosts
182 .iter()
183 .map(|entry| {
184 let hostname = entry.key().to_string();
185 let cache_size = entry.value().cache_size();
186 let hit_rate = entry.value().stats().cache_hit_rate();
187 (hostname, (cache_size, hit_rate))
188 })
189 .collect()
190 }
191
192 pub fn record_persistent_cache_hit(&self, uri: &crate::Uri) {
197 if !uri.is_file() && !uri.is_mail() {
198 match crate::ratelimit::HostKey::try_from(uri) {
199 Ok(key) => {
200 let host = self.get_or_create_host(key);
201 host.record_persistent_cache_hit();
202 }
203 Err(e) => {
204 log::debug!("Failed to record cache hit for {uri}: {e}");
205 }
206 }
207 }
208 }
209}
210
211impl Default for HostPool {
212 fn default() -> Self {
213 Self::new(
214 RateLimitConfig::default(),
215 HostConfigs::default(),
216 Client::default(),
217 HashMap::new(),
218 )
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::RateLimitConfig;

    use url::Url;

    #[test]
    fn test_host_pool_creation() {
        let pool = HostPool::new(
            RateLimitConfig::default(),
            HostConfigs::default(),
            Client::default(),
            HashMap::new(),
        );

        assert_eq!(pool.active_host_count(), 0);
    }

    #[test]
    fn test_host_pool_default() {
        let pool = HostPool::default();
        assert_eq!(pool.active_host_count(), 0);
    }

    #[tokio::test]
    async fn test_host_creation_on_demand() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path".parse().unwrap();
        let host_key = HostKey::try_from(&url).unwrap();

        // Nothing exists until a host is requested.
        assert_eq!(pool.active_host_count(), 0);
        assert_eq!(pool.host_stats("example.com").total_requests, 0);

        let host = pool.get_or_create_host(host_key);

        // The host is registered but has not handled any requests yet.
        assert_eq!(pool.active_host_count(), 1);
        assert_eq!(pool.host_stats("example.com").total_requests, 0);
        assert_eq!(host.key.as_str(), "example.com");
    }

    #[tokio::test]
    async fn test_host_reuse() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path1".parse().unwrap();
        let host_key1 = HostKey::try_from(&url).unwrap();

        let url: Url = "https://example.com/path2".parse().unwrap();
        let host_key2 = HostKey::try_from(&url).unwrap();

        let host1 = pool.get_or_create_host(host_key1);
        assert_eq!(pool.active_host_count(), 1);

        // Same host, different path: the existing entry is reused.
        let host2 = pool.get_or_create_host(host_key2);
        assert_eq!(pool.active_host_count(), 1);

        assert!(Arc::ptr_eq(&host1, &host2));
    }

    #[test]
    fn test_host_config_management() {
        let pool = HostPool::default();

        let configs = pool.host_configurations();
        assert_eq!(configs.len(), 0);
    }

    #[tokio::test]
    async fn test_host_removal() {
        let pool = HostPool::default();

        // Removing an unknown host reports that nothing was removed.
        assert!(!pool.remove_host("nonexistent.com"));

        let url: Url = "https://example.com/path".parse().unwrap();
        let host_key = HostKey::try_from(&url).unwrap();
        let _host = pool.get_or_create_host(host_key);
        assert_eq!(pool.active_host_count(), 1);

        // Removing a registered host succeeds and shrinks the pool.
        assert!(pool.remove_host("example.com"));
        assert_eq!(pool.active_host_count(), 0);

        // A second removal finds nothing.
        assert!(!pool.remove_host("example.com"));
    }
}