1use crate::{
2 BasicAuthCredentials, ErrorKind, FileType, Status, Uri,
3 chain::{Chain, ChainResult, ClientRequestChains, Handler, RequestChain},
4 quirks::Quirks,
5 ratelimit::{CacheableResponse, HostPool},
6 retry::RetryExt,
7 types::{redirect_history::RedirectHistory, uri::github::GithubUri},
8 utils::fragment_checker::{FragmentChecker, FragmentInput},
9};
10use async_trait::async_trait;
11use http::{Method, StatusCode};
12use octocrab::Octocrab;
13use reqwest::{Request, header::CONTENT_TYPE};
14use std::{collections::HashSet, path::Path, sync::Arc, time::Duration};
15use url::Url;
16
/// Checks website URIs by sending HTTP requests through a shared host pool.
///
/// Bundles the request configuration (method, retry policy, accepted status
/// codes) with the optional extras used while checking: the GitHub API
/// fallback, the HTTPS requirement, fragment verification, and redirect
/// tracking.
#[derive(Debug, Clone)]
pub(crate) struct WebsiteChecker {
    /// HTTP method used to build every outgoing request.
    method: reqwest::Method,

    /// GitHub API client used as a fallback for GitHub URIs whose regular
    /// request was unsuccessful; `None` makes that fallback report a missing
    /// token.
    github_client: Option<Octocrab>,

    /// Plugin-provided request chain, handed to `ClientRequestChains`
    /// together with the default chain for each request.
    plugin_request_chain: RequestChain,

    /// Maximum number of retries performed after the initial attempt.
    max_retries: u64,

    /// Wait time before the first retry; it is doubled (saturating) before
    /// each subsequent retry.
    retry_wait_time: Duration,

    /// Status codes passed to `Status::new` when classifying a response
    /// (presumably the codes to treat as success; confirm against `Status`).
    accepted: HashSet<StatusCode>,

    /// When set, an `http` URI that checks out `Ok` is re-checked over HTTPS
    /// and reported as an insecure URL if the HTTPS variant succeeds.
    require_https: bool,

    /// When set, the fragment of a successfully fetched `GET` URL is verified
    /// against the response body.
    include_fragments: bool,

    /// Verifies that a URL fragment exists in fetched HTML/Markdown content.
    fragment_checker: FragmentChecker,

    /// Redirect history consulted to post-process the final status of a
    /// checked URL.
    redirect_history: RedirectHistory,

    /// Shared host pool that builds and executes the HTTP requests.
    host_pool: Arc<HostPool>,
}
62
63impl WebsiteChecker {
64 #[must_use]
66 pub(crate) fn host_pool(&self) -> Arc<HostPool> {
67 self.host_pool.clone()
68 }
69
70 #[allow(clippy::too_many_arguments)]
71 pub(crate) fn new(
72 method: reqwest::Method,
73 retry_wait_time: Duration,
74 redirect_history: RedirectHistory,
75 max_retries: u64,
76 accepted: HashSet<StatusCode>,
77 github_client: Option<Octocrab>,
78 require_https: bool,
79 plugin_request_chain: RequestChain,
80 include_fragments: bool,
81 host_pool: Arc<HostPool>,
82 ) -> Self {
83 Self {
84 method,
85 github_client,
86 plugin_request_chain,
87 redirect_history,
88 max_retries,
89 retry_wait_time,
90 accepted,
91 require_https,
92 include_fragments,
93 fragment_checker: FragmentChecker::new(),
94 host_pool,
95 }
96 }
97
98 pub(crate) async fn retry_request(&self, request: Request) -> Status {
103 let mut retries: u64 = 0;
104 let mut wait_time = self.retry_wait_time;
105 let mut status = self.check_default(clone_unwrap(&request)).await;
106 while retries < self.max_retries {
107 if status.is_success() || !status.should_retry() {
108 return status;
109 }
110 retries += 1;
111 tokio::time::sleep(wait_time).await;
112 wait_time = wait_time.saturating_mul(2);
113 status = self.check_default(clone_unwrap(&request)).await;
114 }
115
116 status
117 }
118
119 async fn check_default(&self, request: Request) -> Status {
121 let method = request.method().clone();
122 let request_url = request.url().clone();
123
124 match self.host_pool.execute_request(request).await {
125 Ok(response) => {
126 let status = Status::new(&response, &self.accepted);
127 if self.include_fragments
130 && response.status.is_success()
131 && method == Method::GET
132 && request_url.fragment().is_some_and(|x| !x.is_empty())
133 {
134 let Some(content_type) = response
135 .headers
136 .get(CONTENT_TYPE)
137 .and_then(|header| header.to_str().ok())
138 else {
139 return status;
140 };
141
142 let file_type = match content_type {
143 ct if ct.starts_with("text/html") => FileType::Html,
144 ct if ct.starts_with("text/markdown") => FileType::Markdown,
145 ct if ct.starts_with("text/plain") => {
146 let path = Path::new(response.url.path());
147 match path.extension() {
148 Some(ext) if ext.eq_ignore_ascii_case("md") => FileType::Markdown,
149 _ => return status,
150 }
151 }
152 _ => return status,
153 };
154
155 self.check_html_fragment(request_url, status, response, file_type)
156 .await
157 } else {
158 status
159 }
160 }
161 Err(e) => e.into(),
162 }
163 }
164
165 async fn check_html_fragment(
166 &self,
167 url: Url,
168 status: Status,
169 response: CacheableResponse,
170 file_type: FileType,
171 ) -> Status {
172 let content = response.text;
173 match self
174 .fragment_checker
175 .check(FragmentInput { content, file_type }, &url)
176 .await
177 {
178 Ok(true) => status,
179 Ok(false) => Status::Error(ErrorKind::InvalidFragment(url.into())),
180 Err(e) => Status::Error(e),
181 }
182 }
183
184 pub(crate) async fn check_website(
194 &self,
195 uri: &Uri,
196 credentials: Option<BasicAuthCredentials>,
197 ) -> Result<Status, ErrorKind> {
198 let default_chain: RequestChain = Chain::new(vec![
199 Box::<Quirks>::default(),
200 Box::new(credentials),
201 Box::new(self.clone()),
202 ]);
203
204 let status = self.check_website_inner(uri, &default_chain).await;
205 let status = self
206 .handle_insecure_url(uri, &default_chain, status)
207 .await?;
208 Ok(self.redirect_history.handle_redirected(&uri.url, status))
209 }
210
211 async fn handle_insecure_url(
214 &self,
215 uri: &Uri,
216 default_chain: &Chain<Request, Status>,
217 status: Status,
218 ) -> Result<Status, ErrorKind> {
219 if self.require_https
220 && uri.scheme() == "http"
221 && let Status::Ok(_) = status
222 {
223 let https_uri = uri.to_https()?;
224 let is_https_available = self
225 .check_website_inner(&https_uri, default_chain)
226 .await
227 .is_success();
228
229 if is_https_available {
230 return Ok(Status::Error(ErrorKind::InsecureURL(https_uri)));
231 }
232 }
233
234 Ok(status)
235 }
236
237 async fn check_website_inner(&self, uri: &Uri, default_chain: &RequestChain) -> Status {
250 let request = self.host_pool.build_request(self.method.clone(), uri);
251
252 let request = match request {
253 Ok(r) => r,
254 Err(e) => return e.into(),
255 };
256
257 let status = ClientRequestChains::new(vec![&self.plugin_request_chain, default_chain])
258 .traverse(request)
259 .await;
260
261 self.handle_github(status, uri).await
262 }
263
264 async fn handle_github(&self, status: Status, uri: &Uri) -> Status {
268 if status.is_success() {
269 return status;
270 }
271
272 if let Ok(github_uri) = GithubUri::try_from(uri) {
273 let status = self.check_github(github_uri).await;
274 if status.is_success() {
275 return status;
276 }
277 }
278
279 status
280 }
281
282 async fn check_github(&self, uri: GithubUri) -> Status {
293 let Some(client) = &self.github_client else {
294 return ErrorKind::MissingGitHubToken.into();
295 };
296 let repo = match client.repos(&uri.owner, &uri.repo).get().await {
297 Ok(repo) => repo,
298 Err(e) => return ErrorKind::GithubRequest(Box::new(e)).into(),
299 };
300 if let Some(true) = repo.private {
301 return Status::Ok(StatusCode::OK);
302 } else if let Some(endpoint) = uri.endpoint {
303 return ErrorKind::InvalidGithubUrl(format!("{}/{}/{endpoint}", uri.owner, uri.repo))
304 .into();
305 }
306 Status::Ok(StatusCode::OK)
307 }
308}
309
310fn clone_unwrap(request: &Request) -> Request {
320 request.try_clone().expect("Failed to clone request: body was a stream, which should be impossible with `stream` feature disabled")
321}
322
323#[async_trait]
324impl Handler<Request, Status> for WebsiteChecker {
325 async fn handle(&mut self, input: Request) -> ChainResult<Request, Status> {
326 ChainResult::Done(self.retry_request(input).await)
327 }
328}