Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: support batched subscription notifs #1332

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 23 additions & 0 deletions client/ws-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,29 @@ async fn batched_notification_handler_works() {
}
}

#[tokio::test]
async fn batched_subscription_notif_works() {
init_logger();

let server = WebSocketTestServer::with_hardcoded_subscription(
"127.0.0.1:0".parse().unwrap(),
server_subscription_id_response(Id::Num(0)),
server_batched_subscription("sub", "batched_notif".into()),
)
.with_default_timeout()
.await
.unwrap();

let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
{
let mut nh: Subscription<String> =
client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It would be interesting to extend this for 2 subscriptions and 2 (or more) batched notifications and ensure that responses to go appropriate handlers:

	let server = WebSocketTestServer::with_hardcoded_subscription(
		"127.0.0.1:0".parse().unwrap(),
		server_subscription_id_response(Id::Num(0)),
		server_batched_subscription(&[("sub1", "batched_notif1"), ("sub2", "resp2")]),
	)

    let mut sub1 = client.subscribe("sub1", rpc_params![], "unsub");
    let mut sub2 = client.subscribe("sub2", rpc_params![], "unsub");

    assert_eq!(sub1.next(), "batched_notif1");
    assert_eq!(sub2.next(), "resp2");

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a nice extra check yup; right now we can see that the subscription response comes back ok but would be nice to know that eg a few batched subscription responses all end up in the right place :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a look at the updated test we now emit [notif1, notif2, sub] and check that it goes to the correct handler (notif=sub, sub_notif=sub)

Sure we could spawn a few more subscriptions to ensure the responses are propagated properly in the Vec

assert_eq!("batched_notif", response);
}
}

#[tokio::test]
async fn notification_close_on_lagging() {
init_logger();
Expand Down
28 changes: 19 additions & 9 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,14 +724,15 @@ fn handle_backend_messages<R: TransportReceiverT>(
message: Option<Result<ReceivedMessage, R::Error>>,
manager: &ThreadSafeRequestManager,
max_buffer_capacity_per_subscription: usize,
) -> Result<Option<FrontToBack>, Error> {
) -> Result<Vec<FrontToBack>, Error> {
// Handle raw messages of form `ReceivedMessage::Bytes` (Vec<u8>) or ReceivedMessage::Data` (String).
fn handle_recv_message(
raw: &[u8],
manager: &ThreadSafeRequestManager,
max_buffer_capacity_per_subscription: usize,
) -> Result<Option<FrontToBack>, Error> {
) -> Result<Vec<FrontToBack>, Error> {
let first_non_whitespace = raw.iter().find(|byte| !byte.is_ascii_whitespace());
let mut messages = Vec::new();

match first_non_whitespace {
Some(b'{') => {
Expand All @@ -741,13 +742,13 @@ fn handle_backend_messages<R: TransportReceiverT>(
process_single_response(&mut manager.lock(), single, max_buffer_capacity_per_subscription)?;

if let Some(unsub) = maybe_unsub {
return Ok(Some(FrontToBack::Request(unsub)));
return Ok(vec![FrontToBack::Request(unsub)]);
}
}
// Subscription response.
else if let Ok(response) = serde_json::from_slice::<SubscriptionResponse<_>>(raw) {
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
return Ok(Some(FrontToBack::SubscriptionClosed(sub_id)));
return Ok(vec![FrontToBack::SubscriptionClosed(sub_id)]);
}
}
// Subscription error response.
Expand Down Expand Up @@ -784,6 +785,14 @@ fn handle_backend_messages<R: TransportReceiverT>(
if id > r.end {
r.end = id;
}
} else if let Ok(response) = serde_json::from_str::<SubscriptionResponse<_>>(r.get()) {
got_notif = true;
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
messages.push(FrontToBack::SubscriptionClosed(sub_id));
}
} else if let Ok(response) = serde_json::from_slice::<SubscriptionError<_>>(raw) {
got_notif = true;
process_subscription_close_response(&mut manager.lock(), response);
} else if let Ok(notif) = serde_json::from_str::<Notification<_>>(r.get()) {
got_notif = true;
process_notification(&mut manager.lock(), notif);
Expand All @@ -808,13 +817,13 @@ fn handle_backend_messages<R: TransportReceiverT>(
}
};

Ok(None)
Ok(messages)
}

match message {
Some(Ok(ReceivedMessage::Pong)) => {
tracing::debug!(target: LOG_TARGET, "Received pong");
Ok(None)
Ok(vec![])
}
Some(Ok(ReceivedMessage::Bytes(raw))) => {
handle_recv_message(raw.as_ref(), manager, max_buffer_capacity_per_subscription)
Expand Down Expand Up @@ -1036,14 +1045,15 @@ where
let Some(msg) = maybe_msg else { break Ok(()) };

match handle_backend_messages::<R>(Some(msg), &manager, max_buffer_capacity_per_subscription) {
Ok(Some(msg)) => {
pending_unsubscribes.push(to_send_task.send(msg));
Ok(messages) => {
for msg in messages {
pending_unsubscribes.push(to_send_task.send(msg));
}
}
Err(e) => {
tracing::error!(target: LOG_TARGET, "Failed to read message: {e}");
break Err(e);
}
Ok(None) => (),
}
}
_ = inactivity_stream.next() => {
Expand Down
8 changes: 8 additions & 0 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ pub fn server_batched_notification(method: &str, params: Value) -> String {
format!(r#"[{{"jsonrpc":"2.0","method":"{}", "params":{} }}]"#, method, serde_json::to_string(&params).unwrap())
}

/// Batched server originated notification
pub fn server_batched_subscription(method: &str, result: Value) -> String {
format!(
r#"[{{"jsonrpc":"2.0","method":"{method}","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}]"#,
serde_json::to_string(&result).unwrap()
)
}

pub async fn http_request(body: Body, uri: Uri) -> Result<HttpResponse, String> {
let client = hyper::Client::new();
http_post(client, body, uri).await
Expand Down