Skip to content

Commit

Permalink
feat(wasm): support request timeout
Browse files Browse the repository at this point in the history
fixes #1135 #1274
  • Loading branch information
flisky committed Apr 8, 2023
1 parent 7047669 commit c442286
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/wasm/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ async fn fetch(req: Request) -> crate::Result<Response> {
}
}

let abort = AbortGuard::new()?;
let mut abort = AbortGuard::new()?;
if let Some(timeout) = req.timeout() {
abort.timeout(*timeout);
}
init.signal(Some(&abort.signal()));

let js_req = web_sys::Request::new_with_str_and_init(req.url().as_str(), &init)
Expand Down
33 changes: 32 additions & 1 deletion src/wasm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use wasm_bindgen::JsCast;
use std::convert::TryInto;
use std::time::Duration;

use js_sys::Function;
use wasm_bindgen::prelude::{wasm_bindgen, Closure};
use wasm_bindgen::{JsCast, JsValue};
use web_sys::{AbortController, AbortSignal};

mod body;
Expand All @@ -14,6 +19,15 @@ pub use self::client::{Client, ClientBuilder};
pub use self::request::{Request, RequestBuilder};
pub use self::response::Response;

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_name = "setTimeout")]
fn set_timeout(handler: &Function, timeout: i32) -> JsValue;

#[wasm_bindgen(js_name = "clearTimeout")]
fn clear_timeout(handle: JsValue) -> JsValue;
}

async fn promise<T>(promise: js_sys::Promise) -> Result<T, crate::error::BoxError>
where
T: JsCast,
Expand All @@ -30,6 +44,7 @@ where
/// A guard that cancels a fetch request when dropped.
struct AbortGuard {
ctrl: AbortController,
timeout: Option<(JsValue, Closure<dyn FnMut()>)>,
}

impl AbortGuard {
Expand All @@ -38,16 +53,32 @@ impl AbortGuard {
ctrl: AbortController::new()
.map_err(crate::error::wasm)
.map_err(crate::error::builder)?,
timeout: None,
})
}

fn signal(&self) -> AbortSignal {
self.ctrl.signal()
}

fn timeout(&mut self, timeout: Duration) {
let ctrl = self.ctrl.clone();
let abort = Closure::once(move || ctrl.abort());
let timeout = set_timeout(
abort.as_ref().unchecked_ref::<js_sys::Function>(),
timeout.as_millis().try_into().expect("timeout"),
);
if let Some((id, _)) = self.timeout.replace((timeout, abort)) {
clear_timeout(id);
}
}
}

impl Drop for AbortGuard {
fn drop(&mut self) {
self.ctrl.abort();
if let Some((id, _)) = self.timeout.take() {
clear_timeout(id);
}
}
}
25 changes: 25 additions & 0 deletions src/wasm/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::convert::TryFrom;
use std::fmt;
use std::time::Duration;

use bytes::Bytes;
use http::{request::Parts, Method, Request as HttpRequest};
Expand All @@ -18,6 +19,7 @@ pub struct Request {
url: Url,
headers: HeaderMap,
body: Option<Body>,
timeout: Option<Duration>,
pub(super) cors: bool,
pub(super) credentials: Option<RequestCredentials>,
}
Expand All @@ -37,6 +39,7 @@ impl Request {
url,
headers: HeaderMap::new(),
body: None,
timeout: None,
cors: true,
credentials: None,
}
Expand Down Expand Up @@ -90,6 +93,18 @@ impl Request {
&mut self.body
}

/// Get the timeout.
#[inline]
pub fn timeout(&self) -> Option<&Duration> {
self.timeout.as_ref()
}

/// Get a mutable reference to the timeout.
#[inline]
pub fn timeout_mut(&mut self) -> &mut Option<Duration> {
&mut self.timeout
}

/// Attempts to clone the `Request`.
///
/// None is returned if a body is which can not be cloned.
Expand All @@ -104,6 +119,7 @@ impl Request {
url: self.url.clone(),
headers: self.headers.clone(),
body,
timeout: self.timeout.clone(),
cors: self.cors,
credentials: self.credentials,
})
Expand Down Expand Up @@ -233,6 +249,14 @@ impl RequestBuilder {
self
}

/// Enables a request timeout.
pub fn timeout(mut self, timeout: Duration) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.timeout_mut() = Some(timeout);
}
self
}

/// TODO
#[cfg(feature = "multipart")]
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
Expand Down Expand Up @@ -449,6 +473,7 @@ where
url,
headers,
body: Some(body.into()),
timeout: None,
cors: true,
credentials: None,
})
Expand Down
1 change: 0 additions & 1 deletion tests/timeouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ async fn request_timeout() {
assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str()));
}

#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn connect_timeout() {
let _ = env_logger::try_init();
Expand Down
15 changes: 15 additions & 0 deletions tests/wasm_simple.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg(target_arch = "wasm32")]
use std::time::Duration;

use wasm_bindgen::prelude::*;
use wasm_bindgen_test::*;
Expand All @@ -22,3 +23,17 @@ async fn simple_example() {
let body = res.text().await.expect("response to utf-8 text");
log(&format!("Body:\n\n{}", body));
}

#[wasm_bindgen_test]
async fn request_with_timeout() {
let client = reqwest::Client::new();
let err = client
.get("https://hyper.rs")
.timeout(Duration::from_millis(10))
.send()
.await
.expect_err("Expected error from aborted request");

assert!(err.is_request());
assert!(format!("{:?}", err).contains("The user aborted a request."));
}

0 comments on commit c442286

Please sign in to comment.