// lychee_lib/ratelimit/host/host.rs
use crate::{
2 ratelimit::{CacheableResponse, headers},
3 retry::RetryExt,
4};
5use dashmap::DashMap;
6use governor::{
7 Quota, RateLimiter,
8 clock::DefaultClock,
9 state::{InMemoryState, NotKeyed},
10};
11use http::StatusCode;
12use humantime_serde::re::humantime::format_duration;
13use log::warn;
14use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
15use std::time::{Duration, Instant};
16use std::{num::NonZeroU32, sync::Mutex};
17use tokio::sync::Semaphore;
18
19use super::key::HostKey;
20use super::stats::HostStats;
21use crate::Uri;
22use crate::types::Result;
23use crate::{
24 ErrorKind,
25 ratelimit::{HostConfig, RateLimitConfig},
26};
27
/// Upper bound on any backoff delay, regardless of what a server requests
/// (e.g. via a huge `Retry-After` header). Enforced in `increase_backoff`.
const MAXIMUM_BACKOFF: Duration = Duration::from_secs(60);

/// Per-host response cache, keyed by the full request URI.
/// Concurrent reads/writes are safe via `DashMap`'s internal sharding.
type HostCache = DashMap<Uri, CacheableResponse>;
33
/// Per-host state for throttled HTTP requests.
///
/// Bundles everything needed to pace, deduplicate, and track requests
/// against a single host: an optional interval-based rate limiter, a
/// concurrency semaphore, adaptive backoff state, request statistics,
/// and a response cache.
#[derive(Debug)]
pub struct Host {
    /// Identifier of the host this state belongs to.
    pub key: HostKey,

    /// Spaces requests by the effective request interval.
    /// `None` when no quota could be built from the configured interval
    /// (see `Host::new`), in which case no interval pacing is applied.
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,

    /// Caps the number of concurrent in-flight requests to this host.
    semaphore: Semaphore,

    /// HTTP client used to execute requests against this host.
    client: ReqwestClient,

    /// Per-host statistics: cache hits/misses, response codes, timings.
    stats: Mutex<HostStats>,

    /// Current adaptive backoff delay applied before each request.
    /// Grows on 429/5xx responses and rate-limit headers; resets on 2xx.
    backoff_duration: Mutex<Duration>,

    /// Cache of responses for this host; only non-retryable results are
    /// stored (see `cache_result`).
    cache: HostCache,
}
67
impl Host {
    /// Builds the per-host state from the host-specific and global
    /// rate-limit configuration.
    ///
    /// The limiter is configured with a burst of 1 so consecutive
    /// requests are always separated by at least the effective request
    /// interval. `Quota::with_period` returns `None` for a zero interval,
    /// which disables interval pacing entirely.
    #[must_use]
    pub fn new(
        key: HostKey,
        host_config: &HostConfig,
        global_config: &RateLimitConfig,
        client: ReqwestClient,
    ) -> Self {
        // Burst of 1: never allow two requests back-to-back inside one interval.
        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
        let interval = host_config.effective_request_interval(global_config);
        let rate_limiter =
            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));

        let max_concurrent = host_config.effective_concurrency(global_config);
        let semaphore = Semaphore::new(max_concurrent);

        Host {
            key,
            rate_limiter,
            semaphore,
            client,
            stats: Mutex::new(HostStats::default()),
            backoff_duration: Mutex::new(Duration::from_millis(0)),
            cache: DashMap::new(),
        }
    }

    /// Returns a clone of the cached response for `uri`, if present.
    fn get_cached_status(&self, uri: &Uri) -> Option<CacheableResponse> {
        self.cache.get(uri).map(|v| v.clone())
    }

    /// Records a cache hit in the per-host statistics.
    fn record_cache_hit(&self) {
        self.stats.lock().unwrap().record_cache_hit();
    }

    /// Records a cache miss in the per-host statistics.
    fn record_cache_miss(&self) {
        self.stats.lock().unwrap().record_cache_miss();
    }

    /// Stores `response` in the cache unless its status is retryable —
    /// transient results must be re-requested, not replayed from cache.
    fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
        if !response.status.should_retry() {
            self.cache.insert(uri.clone(), response);
        }
    }

    /// Executes `request` against this host, applying concurrency limits,
    /// adaptive backoff, interval pacing, and the response cache.
    ///
    /// Order matters here: a concurrency permit is held for the whole
    /// call; the cache is checked once immediately and once more after
    /// the backoff/rate-limiter waits, because a concurrent task may
    /// have populated it for the same URI while this task was waiting.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying request fails (see
    /// `perform_request`).
    pub(crate) async fn execute_request(&self, request: Request) -> Result<CacheableResponse> {
        let uri = Uri::from(request.url().clone());
        // Held until this function returns, bounding in-flight requests.
        let _permit = self.acquire_semaphore().await;

        if let Some(cached) = self.get_cached_status(&uri) {
            self.record_cache_hit();
            return Ok(cached);
        }

        self.await_backoff().await;

        if let Some(rate_limiter) = &self.rate_limiter {
            rate_limiter.until_ready().await;
        }

        // Re-check: another task may have cached this URI while we waited.
        if let Some(cached) = self.get_cached_status(&uri) {
            self.record_cache_hit();
            return Ok(cached);
        }

        self.record_cache_miss();
        self.perform_request(request, uri).await
    }

    /// Returns the HTTP client associated with this host.
    pub(crate) const fn get_client(&self) -> &ReqwestClient {
        &self.client
    }

    /// Sends the request, then feeds the outcome back into the host
    /// state: statistics, adaptive backoff, rate-limit headers, and
    /// (for non-retryable results) the cache.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::NetworkRequest`] if the request itself
    /// fails, or propagates any error from converting the response
    /// into a [`CacheableResponse`].
    async fn perform_request(&self, request: Request, uri: Uri) -> Result<CacheableResponse> {
        let start_time = Instant::now();
        let response = match self.client.execute(request).await {
            Ok(response) => response,
            Err(e) => {
                return Err(ErrorKind::NetworkRequest(e));
            }
        };

        self.update_stats(response.status(), start_time.elapsed());
        self.update_backoff(response.status());
        self.handle_rate_limit_headers(&response);

        let response = CacheableResponse::try_from(response).await?;
        self.cache_result(&uri, response.clone());
        Ok(response)
    }

    /// Sleeps for the current backoff duration, if any.
    ///
    /// The duration is copied out inside a small scope so the
    /// `std::sync::Mutex` guard is dropped before the `.await` —
    /// holding it across the sleep would block other tasks.
    async fn await_backoff(&self) {
        let backoff_duration = {
            let backoff = self.backoff_duration.lock().unwrap();
            *backoff
        };
        if !backoff_duration.is_zero() {
            log::debug!(
                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
                self.key,
                backoff_duration.as_millis()
            );
            tokio::time::sleep(backoff_duration).await;
        }
    }

    /// Acquires a concurrency permit for this host.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed, which should never
    /// happen since this type never closes it.
    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
        self.semaphore
            .acquire()
            .await
            .expect("Semaphore was closed unexpectedly")
    }

    /// Adjusts the adaptive backoff based on the response status:
    /// - 2xx: reset to zero.
    /// - 429: exponential — start at 500ms, double each time, cap at 30s.
    /// - 5xx: linear — add 200ms per error, cap at 10s.
    /// - anything else: leave unchanged.
    fn update_backoff(&self, status: StatusCode) {
        let mut backoff = self.backoff_duration.lock().unwrap();
        match status.as_u16() {
            200..=299 => {
                *backoff = Duration::from_millis(0);
            }
            429 => {
                let new_backoff = std::cmp::min(
                    if backoff.is_zero() {
                        Duration::from_millis(500)
                    } else {
                        *backoff * 2
                    },
                    Duration::from_secs(30),
                );
                log::debug!(
                    "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
                    self.key,
                    backoff.as_millis(),
                    new_backoff.as_millis()
                );
                *backoff = new_backoff;
            }
            500..=599 => {
                *backoff = std::cmp::min(
                    *backoff + Duration::from_millis(200),
                    Duration::from_secs(10),
                );
            }
            _ => {}
        }
    }

    /// Records a response's status code and round-trip time in the stats.
    fn update_stats(&self, status: StatusCode, request_time: Duration) {
        self.stats
            .lock()
            .unwrap()
            .record_response(status.as_u16(), request_time);
    }

    /// Inspects rate-limit-related response headers and adjusts backoff.
    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
        let headers = response.headers();
        self.handle_retry_after_header(headers);
        self.handle_common_rate_limit_header_fields(headers);
    }

    /// Applies proactive backoff from common rate-limit header fields
    /// (remaining/limit pairs parsed by `headers::parse_common_rate_limit_header_fields`).
    ///
    /// Once more than 80% of the quota is consumed, backoff is increased
    /// linearly with usage, from 0ms at 80% up to 200ms at 100%.
    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
        if let (Some(remaining), Some(limit)) =
            headers::parse_common_rate_limit_header_fields(headers)
            && limit > 0
        {
            #[allow(clippy::cast_precision_loss)]
            let usage_ratio = (limit - remaining) as f64 / limit as f64;

            if usage_ratio > 0.8 {
                // Scale (0.8, 1.0] linearly onto (0ms, 200ms].
                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
                self.increase_backoff(duration);
            }
        }
    }

    /// Honors a `Retry-After` response header (RFC 7231 format) by
    /// raising the backoff accordingly. Unparseable values are logged
    /// and ignored rather than failing the request.
    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
        if let Some(retry_after_value) = headers.get("retry-after") {
            let duration = match headers::parse_retry_after(retry_after_value) {
                Ok(e) => e,
                Err(e) => {
                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
                    return;
                }
            };

            self.increase_backoff(duration);
        }
    }

    /// Raises the backoff to at least `increased_backoff`, never
    /// lowering it, and caps the value at [`MAXIMUM_BACKOFF`] so a
    /// misbehaving server cannot stall the checker indefinitely.
    fn increase_backoff(&self, mut increased_backoff: Duration) {
        if increased_backoff > MAXIMUM_BACKOFF {
            warn!(
                "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
                self.key,
                format_duration(increased_backoff),
                format_duration(MAXIMUM_BACKOFF)
            );
            increased_backoff = MAXIMUM_BACKOFF;
        }

        let mut backoff = self.backoff_duration.lock().unwrap();
        *backoff = std::cmp::max(*backoff, increased_backoff);
    }

    /// Returns a snapshot (clone) of this host's statistics.
    pub fn stats(&self) -> HostStats {
        self.stats.lock().unwrap().clone()
    }

    /// Counts a hit served from the persistent (on-disk) cache in this
    /// host's statistics.
    pub(crate) fn record_persistent_cache_hit(&self) {
        self.record_cache_hit();
    }

    /// Returns the number of entries in the in-memory response cache.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }
}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    /// A freshly constructed `Host` starts with default state:
    /// full concurrency budget, perfect success rate, empty cache.
    #[tokio::test]
    async fn test_host_creation() {
        let host_key = HostKey::from("example.com");
        let per_host = HostConfig::default();
        let global = RateLimitConfig::default();

        let host = Host::new(host_key.clone(), &per_host, &global, Client::default());

        assert_eq!(host.key, host_key);
        // Default effective concurrency yields 10 semaphore permits.
        assert_eq!(host.semaphore.available_permits(), 10);
        // No requests recorded yet, so the success rate starts at 100%.
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
        assert_eq!(host.cache_size(), 0);
    }
}