// lychee_lib/ratelimit/host.rs

1use crate::{
2    ratelimit::{CacheableResponse, headers},
3    retry::RetryExt,
4};
5use dashmap::DashMap;
6use governor::{
7    Quota, RateLimiter,
8    clock::DefaultClock,
9    state::{InMemoryState, NotKeyed},
10};
11use http::StatusCode;
12use humantime_serde::re::humantime::format_duration;
13use log::warn;
14use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
15use std::time::{Duration, Instant};
16use std::{num::NonZeroU32, sync::Mutex};
17use tokio::sync::Semaphore;
18
19use super::key::HostKey;
20use super::stats::HostStats;
21use crate::Uri;
22use crate::types::Result;
23use crate::{
24    ErrorKind,
25    ratelimit::{HostConfig, RateLimitConfig},
26};
27
/// Cap for server-requested backoff durations (e.g. via `Retry-After`),
/// so a misbehaving host cannot stall the checker for an unbounded time.
const MAXIMUM_BACKOFF: Duration = Duration::from_secs(60);

/// Per-host cache for storing request results, keyed by the requested URI
type HostCache = DashMap<Uri, CacheableResponse>;
33
/// Represents a single host with its own rate limiting, concurrency control,
/// HTTP client configuration, and request cache.
///
/// Each host maintains:
/// - A token bucket rate limiter using governor
/// - A semaphore for concurrency control
/// - A dedicated HTTP client with host-specific headers and cookies
/// - Statistics tracking for adaptive behavior
/// - A per-host cache to prevent duplicate requests
#[derive(Debug)]
pub struct Host {
    /// The hostname this instance manages
    pub key: HostKey,

    /// Rate limiter using token bucket algorithm.
    /// `None` disables rate limiting (when no valid quota could be built).
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,

    /// Controls maximum concurrent requests to this host
    semaphore: Semaphore,

    /// HTTP client configured for this specific host
    client: ReqwestClient,

    /// Request statistics and adaptive behavior tracking
    stats: Mutex<HostStats>,

    /// Current backoff duration for adaptive rate limiting
    /// (grows on 429/5xx responses, resets on success)
    backoff_duration: Mutex<Duration>,

    /// Per-host cache to prevent duplicate requests during a single link check invocation.
    /// Note that this cache has no direct relation to the inter-process persistable [`crate::CacheStatus`].
    cache: HostCache,
}
67
68impl Host {
69    /// Create a new Host instance for the given hostname
70    #[must_use]
71    pub fn new(
72        key: HostKey,
73        host_config: &HostConfig,
74        global_config: &RateLimitConfig,
75        client: ReqwestClient,
76    ) -> Self {
77        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
78        let interval = host_config.effective_request_interval(global_config);
79        let rate_limiter =
80            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));
81
82        // Create semaphore for concurrency control
83        let max_concurrent = host_config.effective_concurrency(global_config);
84        let semaphore = Semaphore::new(max_concurrent);
85
86        Host {
87            key,
88            rate_limiter,
89            semaphore,
90            client,
91            stats: Mutex::new(HostStats::default()),
92            backoff_duration: Mutex::new(Duration::from_millis(0)),
93            cache: DashMap::new(),
94        }
95    }
96
97    /// Check if a URI is cached and return the cached status if valid
98    ///
99    /// # Panics
100    ///
101    /// Panics if the statistics mutex is poisoned
102    fn get_cached_status(&self, uri: &Uri) -> Option<CacheableResponse> {
103        self.cache.get(uri).map(|v| v.clone())
104    }
105
106    fn record_cache_hit(&self) {
107        self.stats.lock().unwrap().record_cache_hit();
108    }
109
110    fn record_cache_miss(&self) {
111        self.stats.lock().unwrap().record_cache_miss();
112    }
113
114    /// Cache a request result
115    fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
116        // Do not cache responses that are potentially retried
117        if !response.status.should_retry() {
118            self.cache.insert(uri.clone(), response);
119        }
120    }
121
122    /// Execute a request with rate limiting, concurrency control, and caching
123    ///
124    /// # Errors
125    ///
126    /// Returns an error if the request fails or rate limiting is exceeded
127    ///
128    /// # Panics
129    ///
130    /// Panics if the statistics mutex is poisoned
131    pub(crate) async fn execute_request(&self, request: Request) -> Result<CacheableResponse> {
132        let uri = Uri::from(request.url().clone());
133        let _permit = self.acquire_semaphore().await;
134
135        if let Some(cached) = self.get_cached_status(&uri) {
136            self.record_cache_hit();
137            return Ok(cached);
138        }
139
140        self.await_backoff().await;
141
142        if let Some(rate_limiter) = &self.rate_limiter {
143            rate_limiter.until_ready().await;
144        }
145
146        if let Some(cached) = self.get_cached_status(&uri) {
147            self.record_cache_hit();
148            return Ok(cached);
149        }
150
151        self.record_cache_miss();
152        self.perform_request(request, uri).await
153    }
154
155    pub(crate) const fn get_client(&self) -> &ReqwestClient {
156        &self.client
157    }
158
159    async fn perform_request(&self, request: Request, uri: Uri) -> Result<CacheableResponse> {
160        let start_time = Instant::now();
161        let response = match self.client.execute(request).await {
162            Ok(response) => response,
163            Err(e) => {
164                // Wrap network/HTTP errors to preserve the original error
165                return Err(ErrorKind::NetworkRequest(e));
166            }
167        };
168
169        self.update_stats(response.status(), start_time.elapsed());
170        self.update_backoff(response.status());
171        self.handle_rate_limit_headers(&response);
172
173        let response = CacheableResponse::try_from(response).await?;
174        self.cache_result(&uri, response.clone());
175        Ok(response)
176    }
177
178    /// Await adaptive backoff if needed
179    async fn await_backoff(&self) {
180        let backoff_duration = {
181            let backoff = self.backoff_duration.lock().unwrap();
182            *backoff
183        };
184        if !backoff_duration.is_zero() {
185            log::debug!(
186                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
187                self.key,
188                backoff_duration.as_millis()
189            );
190            tokio::time::sleep(backoff_duration).await;
191        }
192    }
193
194    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
195        self.semaphore
196            .acquire()
197            .await
198            // SAFETY: this should not panic as we never close the semaphore
199            .expect("Semaphore was closed unexpectedly")
200    }
201
202    fn update_backoff(&self, status: StatusCode) {
203        let mut backoff = self.backoff_duration.lock().unwrap();
204        match status.as_u16() {
205            200..=299 => {
206                // Reset backoff on success
207                *backoff = Duration::from_millis(0);
208            }
209            429 => {
210                // Exponential backoff on rate limit, capped at 30 seconds
211                let new_backoff = std::cmp::min(
212                    if backoff.is_zero() {
213                        Duration::from_millis(500)
214                    } else {
215                        *backoff * 2
216                    },
217                    Duration::from_secs(30),
218                );
219                log::debug!(
220                    "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
221                    self.key,
222                    backoff.as_millis(),
223                    new_backoff.as_millis()
224                );
225                *backoff = new_backoff;
226            }
227            500..=599 => {
228                // Moderate backoff increase on server errors, capped at 10 seconds
229                *backoff = std::cmp::min(
230                    *backoff + Duration::from_millis(200),
231                    Duration::from_secs(10),
232                );
233            }
234            _ => {} // No backoff change for other status codes
235        }
236    }
237
238    fn update_stats(&self, status: StatusCode, request_time: Duration) {
239        self.stats
240            .lock()
241            .unwrap()
242            .record_response(status.as_u16(), request_time);
243    }
244
245    /// Parse rate limit headers from response and adjust behavior
246    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
247        // Implement basic parsing here rather than using the rate-limits crate to keep dependencies minimal
248        let headers = response.headers();
249        self.handle_retry_after_header(headers);
250        self.handle_common_rate_limit_header_fields(headers);
251    }
252
253    /// Handle the common "X-RateLimit" header fields.
254    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
255        if let (Some(remaining), Some(limit)) =
256            headers::parse_common_rate_limit_header_fields(headers)
257            && limit > 0
258        {
259            #[allow(clippy::cast_precision_loss)]
260            let usage_ratio = (limit - remaining) as f64 / limit as f64;
261
262            // If we've used more than 80% of our quota, apply preventive backoff
263            if usage_ratio > 0.8 {
264                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
265                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
266                self.increase_backoff(duration);
267            }
268        }
269    }
270
271    /// Handle the "Retry-After" header
272    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
273        if let Some(retry_after_value) = headers.get("retry-after") {
274            let duration = match headers::parse_retry_after(retry_after_value) {
275                Ok(e) => e,
276                Err(e) => {
277                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
278                    return;
279                }
280            };
281
282            self.increase_backoff(duration);
283        }
284    }
285
286    fn increase_backoff(&self, mut increased_backoff: Duration) {
287        if increased_backoff > MAXIMUM_BACKOFF {
288            warn!(
289                "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
290                self.key,
291                format_duration(increased_backoff),
292                format_duration(MAXIMUM_BACKOFF)
293            );
294            increased_backoff = MAXIMUM_BACKOFF;
295        }
296
297        let mut backoff = self.backoff_duration.lock().unwrap();
298        *backoff = std::cmp::max(*backoff, increased_backoff);
299    }
300
301    /// Get host statistics
302    ///
303    /// # Panics
304    ///
305    /// Panics if the statistics mutex is poisoned
306    pub fn stats(&self) -> HostStats {
307        self.stats.lock().unwrap().clone()
308    }
309
310    /// Record a cache hit from the persistent disk cache.
311    /// Cache misses are tracked internally, so we don't expose such a method.
312    pub(crate) fn record_persistent_cache_hit(&self) {
313        self.record_cache_hit();
314    }
315
316    /// Get the current cache size (number of cached entries)
317    pub fn cache_size(&self) -> usize {
318        self.cache.len()
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    #[tokio::test]
    async fn test_host_creation() {
        let host_key = HostKey::from("example.com");
        let per_host = HostConfig::default();
        let global = RateLimitConfig::default();

        let host = Host::new(host_key.clone(), &per_host, &global, Client::default());

        // A freshly created host exposes its key, the default concurrency
        // budget, a perfect success rate, and an empty request cache.
        assert_eq!(host.key, host_key);
        assert_eq!(host.semaphore.available_permits(), 10); // Default concurrency
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
        assert_eq!(host.cache_size(), 0);
    }
}
341}