Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: support batched subscription notifs #1332

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 36 additions & 6 deletions client/ws-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::{Id, WebSocketTestServer};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::error::ErrorObjectOwned;
use jsonrpsee_types::{Notification, SubscriptionId, SubscriptionPayload, SubscriptionResponse};
use serde_json::Value as JsonValue;

fn init_logger() {
Expand Down Expand Up @@ -152,7 +153,7 @@ async fn subscription_works() {
let server = WebSocketTestServer::with_hardcoded_subscription(
"127.0.0.1:0".parse().unwrap(),
server_subscription_id_response(Id::Num(0)),
server_subscription_response(JsonValue::String("hello my friend".to_owned())),
server_subscription_response("subscribe_hello", "hello my friend".into()),
)
.with_default_timeout()
.await
Expand Down Expand Up @@ -192,22 +193,51 @@ async fn notification_handler_works() {
}

#[tokio::test]
async fn batched_notification_handler_works() {
let server = WebSocketTestServer::with_hardcoded_notification(
async fn batched_notifs_works() {
init_logger();

let notifs = vec![
serde_json::to_value(&Notification::new("test".into(), "method_notif".to_string())).unwrap(),
serde_json::to_value(&Notification::new("sub".into(), "method_notif".to_string())).unwrap(),
serde_json::to_value(&SubscriptionResponse::new(
"sub".into(),
SubscriptionPayload {
subscription: SubscriptionId::Str("D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o".into()),
result: "sub_notif".to_string(),
},
))
.unwrap(),
];

let serialized_batch = serde_json::to_string(&notifs).unwrap();

let server = WebSocketTestServer::with_hardcoded_subscription(
"127.0.0.1:0".parse().unwrap(),
server_batched_notification("test", "batched server originated notification works".into()),
server_subscription_id_response(Id::Num(0)),
serialized_batch,
)
.with_default_timeout()
.await
.unwrap();

let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();

// Ensure that subscription is returned back to the correct handle
// and is handled separately from ordinary notifications.
{
let mut nh: Subscription<String> =
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It would be interesting to extend this for 2 subscriptions and 2 (or more) batched notifications and ensure that responses to go appropriate handlers:

	let server = WebSocketTestServer::with_hardcoded_subscription(
		"127.0.0.1:0".parse().unwrap(),
		server_subscription_id_response(Id::Num(0)),
		server_batched_subscription(&[("sub1", "batched_notif1"), ("sub2", "resp2")]),
	)

    let mut sub1 = client.subscribe("sub1", rpc_params![], "unsub");
    let mut sub2 = client.subscribe("sub2", rpc_params![], "unsub");

    assert_eq!(sub1.next(), "batched_notif1");
    assert_eq!(sub2.next(), "resp2");

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a nice extra check yup; right now we can see that the subscription response comes back ok but would be nice to know that eg a few batched subscription responses all end up in the right place :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a look at the updated test we now emit [notif1, notif2, sub] and check that it goes to the correct handler (notif=sub, sub_notif=sub)

Sure we could spawn a few more subscriptions to ensure the responses are propagated properly in the Vec

assert_eq!("sub_notif", response);
}

// Ensure that method notif is returned back to the correct handle.
{
let mut nh: Subscription<String> =
client.subscribe_to_method("sub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!("batched server originated notification works".to_owned(), response);
assert_eq!("method_notif", response);
}
}

Expand Down
28 changes: 19 additions & 9 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,14 +724,15 @@ fn handle_backend_messages<R: TransportReceiverT>(
message: Option<Result<ReceivedMessage, R::Error>>,
manager: &ThreadSafeRequestManager,
max_buffer_capacity_per_subscription: usize,
) -> Result<Option<FrontToBack>, Error> {
) -> Result<Vec<FrontToBack>, Error> {
// Handle raw messages of form `ReceivedMessage::Bytes` (Vec<u8>) or ReceivedMessage::Data` (String).
fn handle_recv_message(
raw: &[u8],
manager: &ThreadSafeRequestManager,
max_buffer_capacity_per_subscription: usize,
) -> Result<Option<FrontToBack>, Error> {
) -> Result<Vec<FrontToBack>, Error> {
let first_non_whitespace = raw.iter().find(|byte| !byte.is_ascii_whitespace());
let mut messages = Vec::new();

match first_non_whitespace {
Some(b'{') => {
Expand All @@ -741,13 +742,13 @@ fn handle_backend_messages<R: TransportReceiverT>(
process_single_response(&mut manager.lock(), single, max_buffer_capacity_per_subscription)?;

if let Some(unsub) = maybe_unsub {
return Ok(Some(FrontToBack::Request(unsub)));
return Ok(vec![FrontToBack::Request(unsub)]);
}
}
// Subscription response.
else if let Ok(response) = serde_json::from_slice::<SubscriptionResponse<_>>(raw) {
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
return Ok(Some(FrontToBack::SubscriptionClosed(sub_id)));
return Ok(vec![FrontToBack::SubscriptionClosed(sub_id)]);
}
}
// Subscription error response.
Expand Down Expand Up @@ -784,6 +785,14 @@ fn handle_backend_messages<R: TransportReceiverT>(
if id > r.end {
r.end = id;
}
} else if let Ok(response) = serde_json::from_str::<SubscriptionResponse<_>>(r.get()) {
got_notif = true;
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
messages.push(FrontToBack::SubscriptionClosed(sub_id));
}
} else if let Ok(response) = serde_json::from_slice::<SubscriptionError<_>>(raw) {
got_notif = true;
process_subscription_close_response(&mut manager.lock(), response);
} else if let Ok(notif) = serde_json::from_str::<Notification<_>>(r.get()) {
got_notif = true;
process_notification(&mut manager.lock(), notif);
Expand All @@ -808,13 +817,13 @@ fn handle_backend_messages<R: TransportReceiverT>(
}
};

Ok(None)
Ok(messages)
}

match message {
Some(Ok(ReceivedMessage::Pong)) => {
tracing::debug!(target: LOG_TARGET, "Received pong");
Ok(None)
Ok(vec![])
}
Some(Ok(ReceivedMessage::Bytes(raw))) => {
handle_recv_message(raw.as_ref(), manager, max_buffer_capacity_per_subscription)
Expand Down Expand Up @@ -1036,14 +1045,15 @@ where
let Some(msg) = maybe_msg else { break Ok(()) };

match handle_backend_messages::<R>(Some(msg), &manager, max_buffer_capacity_per_subscription) {
Ok(Some(msg)) => {
pending_unsubscribes.push(to_send_task.send(msg));
Ok(messages) => {
for msg in messages {
pending_unsubscribes.push(to_send_task.send(msg));
}
}
Err(e) => {
tracing::error!(target: LOG_TARGET, "Failed to read message: {e}");
break Err(e);
}
Ok(None) => (),
}
}
_ = inactivity_stream.next() => {
Expand Down
9 changes: 2 additions & 7 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ pub fn server_subscription_id_response(id: Id) -> String {
}

/// Server response to a hardcoded pending subscription
pub fn server_subscription_response(result: Value) -> String {
pub fn server_subscription_response(method: &str, result: Value) -> String {
format!(
r#"{{"jsonrpc":"2.0","method":"bar","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}"#,
r#"{{"jsonrpc":"2.0","method":"{method}","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}"#,
serde_json::to_string(&result).unwrap()
)
}
Expand All @@ -186,11 +186,6 @@ pub fn server_notification(method: &str, params: Value) -> String {
format!(r#"{{"jsonrpc":"2.0","method":"{}", "params":{} }}"#, method, serde_json::to_string(&params).unwrap())
}

/// Batched server originated notification
pub fn server_batched_notification(method: &str, params: Value) -> String {
format!(r#"[{{"jsonrpc":"2.0","method":"{}", "params":{} }}]"#, method, serde_json::to_string(&params).unwrap())
}

pub async fn http_request(body: Body, uri: Uri) -> Result<HttpResponse, String> {
let client = hyper::Client::new();
http_post(client, body, uri).await
Expand Down