Skip to content

Commit

Permalink
client: support batched subscription notifs (#1332)
Browse files Browse the repository at this point in the history
* client: support batched subscriptions notifs

* address grumbles
  • Loading branch information
niklasad1 committed Apr 5, 2024
1 parent 87a37a7 commit 0620f86
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 22 deletions.
42 changes: 36 additions & 6 deletions client/ws-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::{Id, WebSocketTestServer};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::error::ErrorObjectOwned;
use jsonrpsee_types::{Notification, SubscriptionId, SubscriptionPayload, SubscriptionResponse};
use serde_json::Value as JsonValue;

fn init_logger() {
Expand Down Expand Up @@ -152,7 +153,7 @@ async fn subscription_works() {
let server = WebSocketTestServer::with_hardcoded_subscription(
"127.0.0.1:0".parse().unwrap(),
server_subscription_id_response(Id::Num(0)),
server_subscription_response(JsonValue::String("hello my friend".to_owned())),
server_subscription_response("subscribe_hello", "hello my friend".into()),
)
.with_default_timeout()
.await
Expand Down Expand Up @@ -192,22 +193,51 @@ async fn notification_handler_works() {
}

#[tokio::test]
async fn batched_notification_handler_works() {
let server = WebSocketTestServer::with_hardcoded_notification(
async fn batched_notifs_works() {
init_logger();

let notifs = vec![
serde_json::to_value(&Notification::new("test".into(), "method_notif".to_string())).unwrap(),
serde_json::to_value(&Notification::new("sub".into(), "method_notif".to_string())).unwrap(),
serde_json::to_value(&SubscriptionResponse::new(
"sub".into(),
SubscriptionPayload {
subscription: SubscriptionId::Str("D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o".into()),
result: "sub_notif".to_string(),
},
))
.unwrap(),
];

let serialized_batch = serde_json::to_string(&notifs).unwrap();

let server = WebSocketTestServer::with_hardcoded_subscription(
"127.0.0.1:0".parse().unwrap(),
server_batched_notification("test", "batched server originated notification works".into()),
server_subscription_id_response(Id::Num(0)),
serialized_batch,
)
.with_default_timeout()
.await
.unwrap();

let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();

// Ensure that subscription is returned back to the correct handle
// and is handled separately from ordinary notifications.
{
let mut nh: Subscription<String> =
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!("sub_notif", response);
}

// Ensure that method notif is returned back to the correct handle.
{
let mut nh: Subscription<String> =
client.subscribe_to_method("sub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!("batched server originated notification works".to_owned(), response);
assert_eq!("method_notif", response);
}
}

Expand Down
28 changes: 19 additions & 9 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,14 +724,15 @@ fn handle_backend_messages<R: TransportReceiverT>(
message: Option<Result<ReceivedMessage, R::Error>>,
manager: &ThreadSafeRequestManager,
max_buffer_capacity_per_subscription: usize,
) -> Result<Option<FrontToBack>, Error> {
) -> Result<Vec<FrontToBack>, Error> {
// Handle raw messages of form `ReceivedMessage::Bytes` (Vec<u8>) or ReceivedMessage::Data` (String).
fn handle_recv_message(
raw: &[u8],
manager: &ThreadSafeRequestManager,
max_buffer_capacity_per_subscription: usize,
) -> Result<Option<FrontToBack>, Error> {
) -> Result<Vec<FrontToBack>, Error> {
let first_non_whitespace = raw.iter().find(|byte| !byte.is_ascii_whitespace());
let mut messages = Vec::new();

match first_non_whitespace {
Some(b'{') => {
Expand All @@ -741,13 +742,13 @@ fn handle_backend_messages<R: TransportReceiverT>(
process_single_response(&mut manager.lock(), single, max_buffer_capacity_per_subscription)?;

if let Some(unsub) = maybe_unsub {
return Ok(Some(FrontToBack::Request(unsub)));
return Ok(vec![FrontToBack::Request(unsub)]);
}
}
// Subscription response.
else if let Ok(response) = serde_json::from_slice::<SubscriptionResponse<_>>(raw) {
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
return Ok(Some(FrontToBack::SubscriptionClosed(sub_id)));
return Ok(vec![FrontToBack::SubscriptionClosed(sub_id)]);
}
}
// Subscription error response.
Expand Down Expand Up @@ -784,6 +785,14 @@ fn handle_backend_messages<R: TransportReceiverT>(
if id > r.end {
r.end = id;
}
} else if let Ok(response) = serde_json::from_str::<SubscriptionResponse<_>>(r.get()) {
got_notif = true;
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
messages.push(FrontToBack::SubscriptionClosed(sub_id));
}
} else if let Ok(response) = serde_json::from_slice::<SubscriptionError<_>>(raw) {
got_notif = true;
process_subscription_close_response(&mut manager.lock(), response);
} else if let Ok(notif) = serde_json::from_str::<Notification<_>>(r.get()) {
got_notif = true;
process_notification(&mut manager.lock(), notif);
Expand All @@ -808,13 +817,13 @@ fn handle_backend_messages<R: TransportReceiverT>(
}
};

Ok(None)
Ok(messages)
}

match message {
Some(Ok(ReceivedMessage::Pong)) => {
tracing::debug!(target: LOG_TARGET, "Received pong");
Ok(None)
Ok(vec![])
}
Some(Ok(ReceivedMessage::Bytes(raw))) => {
handle_recv_message(raw.as_ref(), manager, max_buffer_capacity_per_subscription)
Expand Down Expand Up @@ -1036,14 +1045,15 @@ where
let Some(msg) = maybe_msg else { break Ok(()) };

match handle_backend_messages::<R>(Some(msg), &manager, max_buffer_capacity_per_subscription) {
Ok(Some(msg)) => {
pending_unsubscribes.push(to_send_task.send(msg));
Ok(messages) => {
for msg in messages {
pending_unsubscribes.push(to_send_task.send(msg));
}
}
Err(e) => {
tracing::error!(target: LOG_TARGET, "Failed to read message: {e}");
break Err(e);
}
Ok(None) => (),
}
}
_ = inactivity_stream.next() => {
Expand Down
9 changes: 2 additions & 7 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ pub fn server_subscription_id_response(id: Id) -> String {
}

/// Server response to a hardcoded pending subscription
pub fn server_subscription_response(result: Value) -> String {
pub fn server_subscription_response(method: &str, result: Value) -> String {
format!(
r#"{{"jsonrpc":"2.0","method":"bar","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}"#,
r#"{{"jsonrpc":"2.0","method":"{method}","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}"#,
serde_json::to_string(&result).unwrap()
)
}
Expand All @@ -186,11 +186,6 @@ pub fn server_notification(method: &str, params: Value) -> String {
format!(r#"{{"jsonrpc":"2.0","method":"{}", "params":{} }}"#, method, serde_json::to_string(&params).unwrap())
}

/// Batched server originated notification
pub fn server_batched_notification(method: &str, params: Value) -> String {
format!(r#"[{{"jsonrpc":"2.0","method":"{}", "params":{} }}]"#, method, serde_json::to_string(&params).unwrap())
}

pub async fn http_request(body: Body, uri: Uri) -> Result<HttpResponse, String> {
let client = hyper::Client::new();
http_post(client, body, uri).await
Expand Down

0 comments on commit 0620f86

Please sign in to comment.