Skip to content

Commit

Permalink
Use the correct protocol for SQS requests
Browse files Browse the repository at this point in the history
TL;DR - The use of boto3 in celery#1759 resulted in relying on blocking
(synchronous) HTTP requests, which caused the performance issue reported
in celery#1783.

`kombu` previously crafted AWS requests manually, as explained in
detail in celery#1726, which resulted in an outage when botocore temporarily
changed the default protocol to JSON (before rolling back due to the
impact on celery and airflow). To fix the issue, I submitted celery#1759,
which changed `kombu` to use `boto3` instead of manually crafting AWS
requests, so that a change of boto3's default protocol would no longer
impact kombu.

While working on celery#1759, I did extensive debugging to understand the
multi-threading nature of kombu. What I discovered is that there isn't
any actual multi-threading in the true sense of the word, but an event
loop that runs on the same thread and process and orchestrates the
communication with SQS. As such, it didn't appear to me that there was
anything to worry about with my change, and the testing I did didn't
uncover any issues. However, it turns out that while kombu's event loop
doesn't have actual multi-threading, its [reliance on
pycurl](https://github.com/celery/kombu/blob/main/kombu/asynchronous/http/curl.py#L48)
(and thus libcurl) meant that the requests to AWS were being done
asynchronously. On the other hand, boto3 requests are always done
synchronously, i.e. they are blocking requests.

The above meant that my change introduced blocking on the event loop of
kombu. This is fine in most of the cases, since the requests to SQS are
pretty fast. However, in the case of using long-polling, a call to SQS's
ReceiveMessage can last up to 20 seconds (depending on the user
configuration).

To solve this problem, I rolled back my earlier changes and, instead, to
address the issue reported in celery#1726, I now changed the
`AsyncSQSConnection` class such that it crafts either a `query` or a
`json` request depending on the protocol used by the SQS client. Thus,
when botocore changes the default protocol of SQS to JSON, kombu won't
be impacted, since it crafts its own request and, after my change, it
uses a hard-coded protocol based on the crafted requests.

This shouldn't be the final solution; it is more of a workaround that
does the job for now. The final solution should be to rely completely
on boto3 for any communication with AWS, and to ensure that all
requests are asynchronous (non-blocking). This, however, is a
fundamental change that requires a lot of testing, in particular
performance testing.
  • Loading branch information
rafidka committed Oct 23, 2023
1 parent 1dfe4f3 commit a49fe47
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 0 deletions.
75 changes: 75 additions & 0 deletions kombu/asynchronous/aws/sqs/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

from __future__ import annotations

import json

from vine import transform

from botocore.serialize import Serializer

from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
from kombu.asynchronous.aws.ext import AWSRequest

from .ext import boto3
from .message import AsyncMessage
Expand All @@ -25,6 +30,76 @@ def __init__(self, sqs_connection, debug=0, region=None, **kwargs):
**kwargs
)

def _create_query_request(self, operation, params, queue_url, method):
    """Build an unsigned query-protocol ``AWSRequest`` for ``queue_url``.

    A copy of ``params`` is sent, so the caller's mapping is left
    untouched. When ``operation`` is truthy it is added to the copy
    under the ``Action`` key.

    For a ``GET`` ``method`` (compared case-insensitively) the
    parameters are handed to ``AWSRequest`` as ``params``; for every
    other method they are handed over as ``data``.
    """
    query = params.copy()
    if operation:
        query['Action'] = operation

    # Only GET uses the query-based ``params`` keyword; any other verb
    # falls back to ``data``.
    payload_key = 'params' if method.lower() == 'get' else 'data'
    return AWSRequest(method=method, url=queue_url, **{payload_key: query})

def _create_json_request(self, operation, params, queue_url):
    """Build an unsigned JSON-protocol ``AWSRequest`` for ``operation``.

    The body is the JSON encoding of a copy of ``params`` with
    ``QueueUrl`` set to ``queue_url``. The request URL is the host of
    the SQS client's endpoint rather than the queue URL. The operation
    is named in the ``X-Amz-Target`` header, and ``Content-Type``
    carries the JSON version from the operation model's metadata.

    The HTTP method comes from the operation model's ``http`` mapping,
    falling back to ``Serializer.DEFAULT_METHOD`` when it has no
    ``method`` entry.
    """
    body = params.copy()
    body['QueueUrl'] = queue_url

    client = self.sqs_connection
    operation_model = client.meta.service_model.operation_model(operation)
    endpoint_url = client._endpoint.host

    metadata = operation_model.metadata
    headers = {
        'Content-Type': f"application/x-amz-json-{metadata['jsonVersion']}",
        'X-Amz-Target': f"{metadata['targetPrefix']}.{operation_model.name}",
    }

    payload = json.dumps(body)
    http_method = operation_model.http.get(
        'method', Serializer.DEFAULT_METHOD)
    return AWSRequest(
        method=http_method,
        url=endpoint_url,
        data=payload,
        headers=headers,
    )

def make_request(self, operation_name, params, queue_url, verb, callback=None):  # noqa
    """Craft, sign and dispatch an SQS request for ``operation_name``.

    The request is built by ``_create_query_request`` or
    ``_create_json_request`` according to the protocol reported by the
    SQS client's service model; ``verb`` is only passed on to the query
    builder. The request is then signed with the client's request
    signer, prepared, and passed to ``_mexe`` together with
    ``callback``. The value returned by ``_mexe`` is returned.

    Raises ``Exception`` when the protocol is neither ``query`` nor
    ``json``.
    """
    client = self.sqs_connection
    signer = client._request_signer
    protocol = client.meta.service_model.protocol

    if protocol == 'query':
        request = self._create_query_request(
            operation_name, params, queue_url, verb)
    elif protocol == 'json':
        request = self._create_json_request(
            operation_name, params, queue_url)
    else:
        raise Exception(f'Unsupported protocol: {protocol}.')

    # GET requests are signed as 'presignurl'; everything else uses
    # the 'standard' signing type.
    if request.method.lower() == 'get':
        signing_type = 'presignurl'
    else:
        signing_type = 'standard'
    signer.sign(operation_name, request, signing_type=signing_type)

    return self._mexe(request.prepare(), callback=callback)

def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
Expand Down
141 changes: 141 additions & 0 deletions t/unit/asynchronous/aws/sqs/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import json

from unittest.mock import MagicMock, Mock

from kombu.asynchronous.aws.ext import AWSRequest
from kombu.asynchronous.aws.ext import boto3
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
from kombu.asynchronous.aws.sqs.message import AsyncMessage
Expand All @@ -12,6 +15,9 @@
from ..case import AWSCase


# Base SQS URL used by the request-construction tests in this module.
# NOTE(review): the value ends with '/', while the tests build queue URLs
# as f'{SQS_URL}/...', which yields a double slash -- confirm this is
# intended.
SQS_URL = 'https://sqs.us-west-2.amazonaws.com/'


class test_AsyncSQSConnection(AWSCase):

def setup(self):
Expand All @@ -31,6 +37,141 @@ def setup(self):
'QueueUrl': 'http://aws.com'
})

def MockRequest(self):
    """Return a minimal POST ``AWSRequest`` aimed at ``https://aws.com``."""
    request_kwargs = {'method': 'POST', 'url': 'https://aws.com'}
    return AWSRequest(**request_kwargs)

def MockOperationModel(self, operation_name, method):
    """Return a ``MagicMock`` standing in for an operation model.

    The mock exposes ``name`` (``operation_name``), an ``http`` object
    whose ``get`` always returns ``method``, and a ``metadata`` dict
    with a fixed ``jsonVersion`` and ``targetPrefix``.
    """
    http = MagicMock()
    http.get = MagicMock(return_value=method)

    model = MagicMock()
    model.http = http
    # ``name`` is assigned after construction: passing it to the
    # ``MagicMock`` constructor would name the mock itself instead of
    # creating a ``name`` attribute.
    model.name = operation_name
    model.metadata = {
        'jsonVersion': '1.0',
        'targetPrefix': 'sqs',
    }
    return model

def MockServiceModel(self, operation_name, method):
    """Return a ``MagicMock`` standing in for a JSON-protocol service model.

    ``protocol`` is the string ``'json'`` and ``operation_model(...)``
    always returns the mock built by ``MockOperationModel`` for
    ``operation_name`` and ``method``.
    """
    service_model = MagicMock()
    # Must be a plain string: a trailing comma here would store the
    # one-element tuple ('json',), which never equals 'json'.
    service_model.protocol = 'json'
    service_model.operation_model = MagicMock(
        return_value=self.MockOperationModel(operation_name, method)
    )
    return service_model

def assert_requests_equal(self, req1, req2):
    """Assert that two requests match field by field.

    Compares ``url``, ``method``, ``data`` and ``params`` directly, and
    ``headers`` after converting both sides to plain dicts.
    """
    for field in ('url', 'method', 'data', 'params'):
        assert getattr(req1, field) == getattr(req2, field)
    left_headers = dict(req1.headers)
    right_headers = dict(req2.headers)
    assert left_headers == right_headers

def test_create_query_request(self):
    """``_create_query_request`` sends params plus ``Action`` as data.

    For a POST verb the expected request targets the queue URL and
    carries the original parameters together with the operation name
    under ``Action``.
    """
    # A plain string: a trailing comma here would turn the operation
    # name into a one-element tuple and the test would then compare a
    # tuple-valued ``Action`` on both sides.
    operation_name = 'ReceiveMessage'
    params = {
        'MaxNumberOfMessages': 10,
        'AttributeName.1': 'ApproximateReceiveCount',
        'WaitTimeSeconds': 20
    }
    queue_url = f'{SQS_URL}/123456789012/celery-test'
    verb = 'POST'
    req = self.x._create_query_request(operation_name, params, queue_url,
                                       verb)
    self.assert_requests_equal(req, AWSRequest(
        url=queue_url,
        method=verb,
        data={
            'Action': operation_name,
            **params
        },
        headers={},
    ))

def test_create_json_request(self):
    """``_create_json_request`` targets the endpoint with JSON headers.

    The expected request is sent to the endpoint host, carries the
    parameters plus ``QueueUrl`` as a JSON body, and sets the
    ``Content-Type`` and ``X-Amz-Target`` headers from the operation
    model's metadata.
    """
    operation_name = 'ReceiveMessage'
    method = 'POST'
    params = {
        'MaxNumberOfMessages': 10,
        'AttributeName.1': 'ApproximateReceiveCount',
        'WaitTimeSeconds': 20
    }
    queue_url = f'{SQS_URL}/123456789012/celery-test'

    self.x.sqs_connection = Mock()
    self.x.sqs_connection._request_signer = Mock()
    self.x.sqs_connection._endpoint.host = SQS_URL
    self.x.sqs_connection.meta.service_model = Mock()
    # A plain string: a trailing comma here would silently store the
    # one-element tuple ('json',) instead.
    self.x.sqs_connection.meta.service_model.protocol = 'json'
    self.x.sqs_connection.meta.service_model.operation_model = MagicMock(
        return_value=self.MockOperationModel(operation_name, method)
    )

    req = self.x._create_json_request(operation_name, params, queue_url)
    self.assert_requests_equal(req, AWSRequest(
        url=SQS_URL,
        method=method,
        data=json.dumps({
            **params,
            "QueueUrl": queue_url
        }),
        headers={
            'Content-Type': 'application/x-amz-json-1.0',
            'X-Amz-Target': f'sqs.{operation_name}'
        },
    ))

def test_make_request__with_query_protocol(self):
    """``make_request`` uses the query builder for the query protocol."""
    # Do the necessary mocking.
    self.x.sqs_connection = Mock()
    self.x.sqs_connection._request_signer = Mock()
    self.x.sqs_connection.meta.service_model.protocol = 'query'
    self.x._create_query_request = Mock(return_value=self.MockRequest())

    # Execute the make_request call and confirm we are creating a
    # query request.
    # The operation is a plain string; a trailing comma here would pass
    # a one-element tuple as the operation name.
    operation = 'ReceiveMessage'
    params = {
        'MaxNumberOfMessages': 10,
        'AttributeName.1': 'ApproximateReceiveCount',
        'WaitTimeSeconds': 20
    }
    queue_url = f'{SQS_URL}/123456789012/celery-test'
    verb = 'POST'
    self.x.make_request(operation, params, queue_url, verb)
    self.x._create_query_request.assert_called_with(
        operation, params, queue_url, verb
    )

def test_make_request__with_json_protocol(self):
    """``make_request`` uses the JSON builder for the json protocol."""
    # Do the necessary mocking.
    self.x.sqs_connection = Mock()
    self.x.sqs_connection._request_signer = Mock()
    self.x.sqs_connection.meta.service_model.protocol = 'json'
    self.x._create_json_request = Mock(return_value=self.MockRequest())

    # Execute the make_request call and confirm we are creating a
    # JSON request.
    # The operation is a plain string; a trailing comma here would pass
    # a one-element tuple as the operation name.
    operation = 'ReceiveMessage'
    params = {
        'MaxNumberOfMessages': 10,
        'AttributeName.1': 'ApproximateReceiveCount',
        'WaitTimeSeconds': 20
    }
    queue_url = f'{SQS_URL}/123456789012/celery-test'
    verb = 'POST'
    self.x.make_request(operation, params, queue_url, verb)
    self.x._create_json_request.assert_called_with(
        operation, params, queue_url
    )

def test_create_queue(self):
self.x.create_queue('foo', callback=self.callback)
self.x.get_object.assert_called_with(
Expand Down

0 comments on commit a49fe47

Please sign in to comment.