lychee_lib/checker/
website.rs

1use crate::{
2    BasicAuthCredentials, ErrorKind, FileType, Status, Uri,
3    chain::{Chain, ChainResult, ClientRequestChains, Handler, RequestChain},
4    quirks::Quirks,
5    ratelimit::{CacheableResponse, HostPool},
6    retry::RetryExt,
7    types::{redirect_history::RedirectHistory, uri::github::GithubUri},
8    utils::fragment_checker::{FragmentChecker, FragmentInput},
9};
10use async_trait::async_trait;
11use http::{Method, StatusCode};
12use octocrab::Octocrab;
13use reqwest::{Request, header::CONTENT_TYPE};
14use std::{collections::HashSet, path::Path, sync::Arc, time::Duration};
15use url::Url;
16
17#[derive(Debug, Clone)]
18pub(crate) struct WebsiteChecker {
19    /// Request method used for making requests.
20    method: reqwest::Method,
21
22    /// GitHub client used for requests.
23    github_client: Option<Octocrab>,
24
25    /// The chain of plugins to be executed on each request.
26    plugin_request_chain: RequestChain,
27
28    /// Maximum number of retries per request before returning an error.
29    max_retries: u64,
30
31    /// Initial wait time between retries of failed requests. This doubles after
32    /// each failure.
33    retry_wait_time: Duration,
34
35    /// Set of accepted return codes / status codes.
36    ///
37    /// Unmatched return codes/ status codes are deemed as errors.
38    accepted: HashSet<StatusCode>,
39
40    /// Requires using HTTPS when it's available.
41    ///
42    /// This would treat unencrypted links as errors when HTTPS is available.
43    require_https: bool,
44
45    /// Whether to check the existence of fragments in the response HTML files.
46    ///
47    /// Will be disabled if the request method is `HEAD`.
48    include_fragments: bool,
49
50    /// Utility for performing fragment checks in HTML files.
51    fragment_checker: FragmentChecker,
52
53    /// Keep track of HTTP redirections for reporting
54    redirect_history: RedirectHistory,
55
56    /// Optional host pool for per-host rate limiting.
57    ///
58    /// When present, HTTP requests will be routed through this pool for
59    /// rate limiting. When None, requests go directly through `reqwest_client`.
60    host_pool: Arc<HostPool>,
61}
62
63impl WebsiteChecker {
64    /// Get a reference to `HostPool`
65    #[must_use]
66    pub(crate) fn host_pool(&self) -> Arc<HostPool> {
67        self.host_pool.clone()
68    }
69
70    #[allow(clippy::too_many_arguments)]
71    pub(crate) fn new(
72        method: reqwest::Method,
73        retry_wait_time: Duration,
74        redirect_history: RedirectHistory,
75        max_retries: u64,
76        accepted: HashSet<StatusCode>,
77        github_client: Option<Octocrab>,
78        require_https: bool,
79        plugin_request_chain: RequestChain,
80        include_fragments: bool,
81        host_pool: Arc<HostPool>,
82    ) -> Self {
83        Self {
84            method,
85            github_client,
86            plugin_request_chain,
87            redirect_history,
88            max_retries,
89            retry_wait_time,
90            accepted,
91            require_https,
92            include_fragments,
93            fragment_checker: FragmentChecker::new(),
94            host_pool,
95        }
96    }
97
98    /// Retry requests up to `max_retries` times
99    /// with an exponential backoff.
100    /// Note that, in addition, there also is a host-specific backoff
101    /// when host-specific rate limiting or errors are detected.
102    pub(crate) async fn retry_request(&self, request: Request) -> Status {
103        let mut retries: u64 = 0;
104        let mut wait_time = self.retry_wait_time;
105        let mut status = self.check_default(clone_unwrap(&request)).await;
106        while retries < self.max_retries {
107            if status.is_success() || !status.should_retry() {
108                return status;
109            }
110            retries += 1;
111            tokio::time::sleep(wait_time).await;
112            wait_time = wait_time.saturating_mul(2);
113            status = self.check_default(clone_unwrap(&request)).await;
114        }
115
116        status
117    }
118
119    /// Check a URI using [reqwest](https://github.com/seanmonstar/reqwest).
120    async fn check_default(&self, request: Request) -> Status {
121        let method = request.method().clone();
122        let request_url = request.url().clone();
123
124        match self.host_pool.execute_request(request).await {
125            Ok(response) => {
126                let status = Status::new(&response, &self.accepted);
127                // when `accept=200,429`, `status_code=429` will be treated as success
128                // but we are not able the check the fragment since it's inapplicable.
129                if self.include_fragments
130                    && response.status.is_success()
131                    && method == Method::GET
132                    && request_url.fragment().is_some_and(|x| !x.is_empty())
133                {
134                    let Some(content_type) = response
135                        .headers
136                        .get(CONTENT_TYPE)
137                        .and_then(|header| header.to_str().ok())
138                    else {
139                        return status;
140                    };
141
142                    let file_type = match content_type {
143                        ct if ct.starts_with("text/html") => FileType::Html,
144                        ct if ct.starts_with("text/markdown") => FileType::Markdown,
145                        ct if ct.starts_with("text/plain") => {
146                            let path = Path::new(response.url.path());
147                            match path.extension() {
148                                Some(ext) if ext.eq_ignore_ascii_case("md") => FileType::Markdown,
149                                _ => return status,
150                            }
151                        }
152                        _ => return status,
153                    };
154
155                    self.check_html_fragment(request_url, status, response, file_type)
156                        .await
157                } else {
158                    status
159                }
160            }
161            Err(e) => e.into(),
162        }
163    }
164
165    async fn check_html_fragment(
166        &self,
167        url: Url,
168        status: Status,
169        response: CacheableResponse,
170        file_type: FileType,
171    ) -> Status {
172        let content = response.text;
173        match self
174            .fragment_checker
175            .check(FragmentInput { content, file_type }, &url)
176            .await
177        {
178            Ok(true) => status,
179            Ok(false) => Status::Error(ErrorKind::InvalidFragment(url.into())),
180            Err(e) => Status::Error(e),
181        }
182    }
183
184    /// Checks the given URI of a website.
185    ///
186    /// # Errors
187    ///
188    /// This returns an `Err` if
189    /// - The URI is invalid.
190    /// - The request failed.
191    /// - The response status code is not accepted.
192    /// - The URI cannot be converted to HTTPS.
193    pub(crate) async fn check_website(
194        &self,
195        uri: &Uri,
196        credentials: Option<BasicAuthCredentials>,
197    ) -> Result<Status, ErrorKind> {
198        let default_chain: RequestChain = Chain::new(vec![
199            Box::<Quirks>::default(),
200            Box::new(credentials),
201            Box::new(self.clone()),
202        ]);
203
204        let status = self.check_website_inner(uri, &default_chain).await;
205        let status = self
206            .handle_insecure_url(uri, &default_chain, status)
207            .await?;
208        Ok(self.redirect_history.handle_redirected(&uri.url, status))
209    }
210
211    /// Mark HTTP URLs as insecure, if the user required HTTPS
212    /// and the URL is available under HTTPS.
213    async fn handle_insecure_url(
214        &self,
215        uri: &Uri,
216        default_chain: &Chain<Request, Status>,
217        status: Status,
218    ) -> Result<Status, ErrorKind> {
219        if self.require_https
220            && uri.scheme() == "http"
221            && let Status::Ok(_) = status
222        {
223            let https_uri = uri.to_https()?;
224            let is_https_available = self
225                .check_website_inner(&https_uri, default_chain)
226                .await
227                .is_success();
228
229            if is_https_available {
230                return Ok(Status::Error(ErrorKind::InsecureURL(https_uri)));
231            }
232        }
233
234        Ok(status)
235    }
236
237    /// Checks the given URI of a website.
238    ///
239    /// Unsupported schemes will be ignored
240    ///
241    /// Note: we use `inner` to improve compile times by avoiding monomorphization
242    ///
243    /// # Errors
244    ///
245    /// This returns an `Err` if
246    /// - The URI is invalid.
247    /// - The request failed.
248    /// - The response status code is not accepted.
249    async fn check_website_inner(&self, uri: &Uri, default_chain: &RequestChain) -> Status {
250        let request = self.host_pool.build_request(self.method.clone(), uri);
251
252        let request = match request {
253            Ok(r) => r,
254            Err(e) => return e.into(),
255        };
256
257        let status = ClientRequestChains::new(vec![&self.plugin_request_chain, default_chain])
258            .traverse(request)
259            .await;
260
261        self.handle_github(status, uri).await
262    }
263
264    // Pull out the heavy machinery in case of a failed normal request.
265    // This could be a GitHub URL and we ran into the rate limiter.
266    // TODO: We should try to parse the URI as GitHub URI first (Lucius, Jan 2023)
267    async fn handle_github(&self, status: Status, uri: &Uri) -> Status {
268        if status.is_success() {
269            return status;
270        }
271
272        if let Ok(github_uri) = GithubUri::try_from(uri) {
273            let status = self.check_github(github_uri).await;
274            if status.is_success() {
275                return status;
276            }
277        }
278
279        status
280    }
281
282    /// Check a `uri` hosted on `GitHub` via the GitHub API.
283    ///
284    /// # Caveats
285    ///
286    /// Files inside private repositories won't get checked and instead would
287    /// be reported as valid if the repository itself is reachable through the
288    /// API.
289    ///
290    /// A better approach would be to download the file through the API or
291    /// clone the repo, but we chose the pragmatic approach.
292    async fn check_github(&self, uri: GithubUri) -> Status {
293        let Some(client) = &self.github_client else {
294            return ErrorKind::MissingGitHubToken.into();
295        };
296        let repo = match client.repos(&uri.owner, &uri.repo).get().await {
297            Ok(repo) => repo,
298            Err(e) => return ErrorKind::GithubRequest(Box::new(e)).into(),
299        };
300        if let Some(true) = repo.private {
301            return Status::Ok(StatusCode::OK);
302        } else if let Some(endpoint) = uri.endpoint {
303            return ErrorKind::InvalidGithubUrl(format!("{}/{}/{endpoint}", uri.owner, uri.repo))
304                .into();
305        }
306        Status::Ok(StatusCode::OK)
307    }
308}
309
310/// Clones a `reqwest::Request`.
311///
312/// # Safety
313///
314/// This panics if the request cannot be cloned. This should only happen if the
315/// request body is a `reqwest` stream. We disable the `stream` feature, so the
316/// body should never be a stream.
317///
318/// See <https://github.com/seanmonstar/reqwest/blob/de5dbb1ab849cc301dcefebaeabdf4ce2e0f1e53/src/async_impl/body.rs#L168>
319fn clone_unwrap(request: &Request) -> Request {
320    request.try_clone().expect("Failed to clone request: body was a stream, which should be impossible with `stream` feature disabled")
321}
322
323#[async_trait]
324impl Handler<Request, Status> for WebsiteChecker {
325    async fn handle(&mut self, input: Request) -> ChainResult<Request, Status> {
326        ChainResult::Done(self.retry_request(input).await)
327    }
328}