Skip to content

Commit

Permalink
fix(client): disconnect_reason/read_error is cancel-safe (#1347)
Browse files Browse the repository at this point in the history
* client: `disconnect_reason/read_error` cancel-safe

If/when the connection is closed, the cause is fetched by `read_error` from the background task.
It was not cancel-safe, which could have side-effects and mutate the state,
such as the internal Option could be `None` and cause a panic.

* fix wasm build

* remove async_lock dependency
  • Loading branch information
niklasad1 committed Apr 8, 2024
1 parent 166bccc commit 5c65daf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 60 deletions.
3 changes: 0 additions & 3 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ serde_json = { version = "1", features = ["raw_value"] }
tracing = "0.1.34"

# optional deps
async-lock = { version = "3.0", optional = true }
futures-util = { version = "0.3.14", default-features = false, optional = true }
hyper = { version = "0.14.10", default-features = false, features = ["stream"], optional = true }
rustc-hash = { version = "1", optional = true }
Expand All @@ -51,7 +50,6 @@ server = [
]
client = ["futures-util/sink", "tokio/sync"]
async-client = [
"async-lock",
"client",
"futures-util/alloc",
"rustc-hash",
Expand All @@ -63,7 +61,6 @@ async-client = [
"pin-project",
]
async-wasm-client = [
"async-lock",
"client",
"futures-util/alloc",
"wasm-bindgen-futures",
Expand Down
86 changes: 29 additions & 57 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ use jsonrpsee_types::{InvalidRequestId, ResponseSuccess, TwoPointZero};
use manager::RequestManager;
use std::sync::Arc;

use async_lock::RwLock as AsyncRwLock;
use async_trait::async_trait;
use futures_timer::Delay;
use futures_util::future::{self, Either};
Expand All @@ -69,6 +68,7 @@ use self::utils::{InactivityCheck, IntervalStream};
use super::{generate_batch_id_range, subscription_channel, FrontToBack, IdKind, RequestIdManager};

const LOG_TARGET: &str = "jsonrpsee-client";
const NOT_POISONED: &str = "Not poisoned; qed";

/// Configuration for WebSocket ping/pong mechanism and it may be used to disconnect
/// an inactive connection.
Expand Down Expand Up @@ -142,67 +142,39 @@ impl ThreadSafeRequestManager {
}

pub(crate) fn lock(&self) -> std::sync::MutexGuard<RequestManager> {
self.0.lock().expect("Not poisoned; qed")
self.0.lock().expect(NOT_POISONED)
}
}

pub(crate) type SharedDisconnectReason = Arc<std::sync::RwLock<Option<Arc<Error>>>>;

/// If the background thread is terminated, this type
/// can be used to read the error cause.
///
// NOTE: This is an AsyncRwLock to be &self.
#[derive(Debug)]
struct ErrorFromBack(AsyncRwLock<Option<ReadErrorOnce>>);
struct ErrorFromBack {
conn: mpsc::Sender<FrontToBack>,
disconnect_reason: SharedDisconnectReason,
}

impl ErrorFromBack {
fn new(unread: oneshot::Receiver<Error>) -> Self {
Self(AsyncRwLock::new(Some(ReadErrorOnce::Unread(unread))))
fn new(conn: mpsc::Sender<FrontToBack>, disconnect_reason: SharedDisconnectReason) -> Self {
Self { conn, disconnect_reason }
}

async fn read_error(&self) -> Error {
const PROOF: &str = "Option is only is used to workaround ownership issue and is always Some; qed";
// When the background task is closed the error is written to `disconnect_reason`.
self.conn.closed().await;

if let ReadErrorOnce::Read(ref err) = self.0.read().await.as_ref().expect(PROOF) {
return Error::RestartNeeded(err.clone());
};

let mut write = self.0.write().await;
let state = write.take();

let err = match state.expect(PROOF) {
ReadErrorOnce::Unread(rx) => {
let arc_err = Arc::new(match rx.await {
Ok(err) => err,
// This should never happen because the receiving end is still alive.
// Before shutting down the background task a error message should
// be emitted.
Err(_) => Error::Custom(
"Error reason could not be found. This is a bug. Please open an issue.".to_string(),
),
});
*write = Some(ReadErrorOnce::Read(arc_err.clone()));
arc_err
}
ReadErrorOnce::Read(arc_err) => {
*write = Some(ReadErrorOnce::Read(arc_err.clone()));
arc_err
}
};

Error::RestartNeeded(err)
if let Some(err) = self.disconnect_reason.read().expect(NOT_POISONED).as_ref() {
Error::RestartNeeded(err.clone())
} else {
Error::Custom("Error reason could not be found. This is a bug. Please open an issue.".to_string())
}
}
}

/// Wrapper over a [`oneshot::Receiver`] that reads
/// the underlying channel once and then stores the result in String.
/// It is possible that the error is read more than once if several calls are made
/// when the background thread has been terminated.
#[derive(Debug)]
enum ReadErrorOnce {
/// Error message is already read.
Read(Arc<Error>),
/// Error message is unread.
Unread(oneshot::Receiver<Error>),
}

/// Builder for [`Client`].
#[derive(Debug, Copy, Clone)]
pub struct ClientBuilder {
Expand Down Expand Up @@ -318,7 +290,7 @@ impl ClientBuilder {
R: TransportReceiverT + Send,
{
let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests);
let (err_to_front, err_from_back) = oneshot::channel::<Error>();
let disconnect_reason = SharedDisconnectReason::default();
let max_buffer_capacity_per_subscription = self.max_buffer_capacity_per_subscription;
let (client_dropped_tx, client_dropped_rx) = oneshot::channel();
let (send_receive_task_sync_tx, send_receive_task_sync_rx) = mpsc::channel(1);
Expand Down Expand Up @@ -366,12 +338,12 @@ impl ClientBuilder {
inactivity_stream,
}));

tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, err_to_front));
tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, disconnect_reason.clone()));

Client {
to_back,
to_back: to_back.clone(),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(err_from_back),
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
max_log_length: self.max_log_length,
on_exit: Some(client_dropped_tx),
Expand All @@ -391,7 +363,7 @@ impl ClientBuilder {
type PendingIntervalStream = IntervalStream<Pending<()>>;

let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests);
let (err_to_front, err_from_back) = oneshot::channel::<Error>();
let disconnect_reason = SharedDisconnectReason::default();
let max_buffer_capacity_per_subscription = self.max_buffer_capacity_per_subscription;
let (client_dropped_tx, client_dropped_rx) = oneshot::channel();
let (send_receive_task_sync_tx, send_receive_task_sync_rx) = mpsc::channel(1);
Expand Down Expand Up @@ -423,13 +395,13 @@ impl ClientBuilder {
wasm_bindgen_futures::spawn_local(wait_for_shutdown(
send_receive_task_sync_rx,
client_dropped_rx,
err_to_front,
disconnect_reason.clone(),
));

Client {
to_back,
to_back: to_back.clone(),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(err_from_back),
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
max_log_length: self.max_log_length,
on_exit: Some(client_dropped_tx),
Expand Down Expand Up @@ -474,7 +446,7 @@ impl Client {
///
/// # Cancel-safety
///
/// This method is not cancel-safe
/// This method is cancel-safe
pub async fn disconnect_reason(&self) -> Error {
self.error.read_error().await
}
Expand Down Expand Up @@ -1070,14 +1042,14 @@ where
async fn wait_for_shutdown(
mut close_rx: mpsc::Receiver<Result<(), Error>>,
client_dropped: oneshot::Receiver<()>,
err_to_front: oneshot::Sender<Error>,
err_to_front: SharedDisconnectReason,
) {
let rx_item = close_rx.recv();

tokio::pin!(rx_item);

// Send an error to the frontend if the send or receive task completed with an error.
if let Either::Left((Some(Err(err)), _)) = future::select(rx_item, client_dropped).await {
let _ = err_to_front.send(err);
*err_to_front.write().expect(NOT_POISONED) = Some(Arc::new(err));
}
}

0 comments on commit 5c65daf

Please sign in to comment.