Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the correct protocol for SQS requests #1807

Merged
merged 4 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion kombu/asynchronous/aws/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def make_request(self, operation, params_, path, verb, callback=None): # noqa
param_payload = {'data': params}
if verb.lower() == 'get':
# query-based opts
signing_type = 'presignurl'
signing_type = 'presign-url'
param_payload = {'params': params}

request = AWSRequest(method=verb, url=path, **param_payload)
Expand Down
82 changes: 82 additions & 0 deletions kombu/asynchronous/aws/sqs/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

from __future__ import annotations

import json
auvipy marked this conversation as resolved.
Show resolved Hide resolved

auvipy marked this conversation as resolved.
Show resolved Hide resolved
from botocore.serialize import Serializer
from vine import transform

from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
from kombu.asynchronous.aws.ext import AWSRequest

from .ext import boto3
from .message import AsyncMessage
Expand All @@ -25,6 +29,84 @@ def __init__(self, sqs_connection, debug=0, region=None, **kwargs):
**kwargs
)

def _create_query_request(self, operation, params, queue_url, method):
params = params.copy()
if operation:
params['Action'] = operation

# defaults for non-get
param_payload = {'data': params}
if method.lower() == 'get':
# query-based opts
param_payload = {'params': params}

return AWSRequest(method=method, url=queue_url, **param_payload)

def _create_json_request(self, operation, params, queue_url):
params = params.copy()
params['QueueUrl'] = queue_url

service_model = self.sqs_connection.meta.service_model
operation_model = service_model.operation_model(operation)

url = self.sqs_connection._endpoint.host

headers = {}
# Content-Type
json_version = operation_model.metadata['jsonVersion']
content_type = f'application/x-amz-json-{json_version}'
headers['Content-Type'] = content_type

# X-Amz-Target
target = '{}.{}'.format(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we try f string here? or there are some version restrictions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I think this is cleaner, since the placeholders would be long, making it hard to read:

f'{operation_model.metadata['targetPrefix']}.{operation_model.name}'

especially since that it would make the line longer than 80 chars, meaning that we will need to split the line, or otherwise define variables to make it shorter.

Let me know if you feel strongly about it.

operation_model.metadata['targetPrefix'],
operation_model.name,
)
headers['X-Amz-Target'] = target

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at a sample request in the docs I also see Content-Encoding: amz-1.0, X-Amz-Date: <date> and a few others in the header as well? I suppose those aren't required and we'll have no issues excluding them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which document are you referring to?

What I did was to check what botocore is doing in their JSONSerialiazer, and I don't see those fields in there.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll find the docs, I think they were in a slack thread that I will share with you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll find the docs, I think they were in a slack thread that I will share with you.

would you mind doing a last round of review?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll find the docs, I think they were in a slack thread that I will share with you.

would you mind doing a last round of review?

@auvipy , we discussed this actually internally and there is nothing to worry about here. I simply followed what the JSONSerializer class from botocore is doing.


param_payload = {
'data': json.dumps(params),
'headers': headers
}

method = operation_model.http.get('method', Serializer.DEFAULT_METHOD)
return AWSRequest(
method=method,
url=url,
**param_payload
)

def make_request(self, operation_name, params, queue_url, verb, callback=None): # noqa
"""
Overide make_request to support different protocols.

botocore is soon going to change the default protocol of communicating
with SQS backend from 'query' to 'json', so we need a special
implementation of make_request for SQS. More information on this can
be found in: https://github.com/celery/kombu/pull/1807.
"""
signer = self.sqs_connection._request_signer

service_model = self.sqs_connection.meta.service_model
protocol = service_model.protocol

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth leaving a comment here explaining why this method is being updated with this branching. Just set a little bit of the context with some links. I know it could be gotten from the git blame/commit, but it's nice to have a bit in the code as well to hint people to go looking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally. I will do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

if protocol == 'query':
request = self._create_query_request(
operation_name, params, queue_url, verb)
elif protocol == 'json':
request = self._create_json_request(
operation_name, params, queue_url)
else:
raise Exception(f'Unsupported protocol: {protocol}.')

signing_type = 'presign-url' if request.method.lower() == 'get' \
else 'standard'

signer.sign(operation_name, request, signing_type=signing_type)
prepared_request = request.prepare()

return self._mexe(prepared_request, callback=callback)

def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
Expand Down
140 changes: 139 additions & 1 deletion t/unit/asynchronous/aws/sqs/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import json
from unittest.mock import MagicMock, Mock

from kombu.asynchronous.aws.ext import boto3
from kombu.asynchronous.aws.ext import AWSRequest, boto3
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
from kombu.asynchronous.aws.sqs.message import AsyncMessage
from kombu.asynchronous.aws.sqs.queue import AsyncQueue
Expand All @@ -11,6 +12,8 @@

from ..case import AWSCase

SQS_URL = 'https://sqs.us-west-2.amazonaws.com/'


class test_AsyncSQSConnection(AWSCase):

Expand All @@ -31,6 +34,141 @@ def setup(self):
'QueueUrl': 'http://aws.com'
})

def MockRequest(self):
return AWSRequest(
method='POST',
url='https://aws.com',
)

def MockOperationModel(self, operation_name, method):
mock = MagicMock()
mock.configure_mock(
http=MagicMock(
get=MagicMock(
return_value=method,
)
),
name=operation_name,
metadata={
'jsonVersion': '1.0',
'targetPrefix': 'sqs',
}
)
return mock

def MockServiceModel(self, operation_name, method):
service_model = MagicMock()
service_model.protocol = 'json',
service_model.operation_model = MagicMock(
return_value=self.MockOperationModel(operation_name, method)
)
return service_model

def assert_requests_equal(self, req1, req2):
assert req1.url == req2.url
assert req1.method == req2.method
assert req1.data == req2.data
assert req1.params == req2.params
assert dict(req1.headers) == dict(req2.headers)

def test_create_query_request(self):
operation_name = 'ReceiveMessage',
params = {
'MaxNumberOfMessages': 10,
'AttributeName.1': 'ApproximateReceiveCount',
'WaitTimeSeconds': 20
}
queue_url = f'{SQS_URL}/123456789012/celery-test'
verb = 'POST'
req = self.x._create_query_request(operation_name, params, queue_url,
verb)
self.assert_requests_equal(req, AWSRequest(
url=queue_url,
method=verb,
data={
'Action': (operation_name),
**params
},
headers={},
))

def test_create_json_request(self):
operation_name = 'ReceiveMessage'
method = 'POST'
params = {
'MaxNumberOfMessages': 10,
'AttributeName.1': 'ApproximateReceiveCount',
'WaitTimeSeconds': 20
}
queue_url = f'{SQS_URL}/123456789012/celery-test'

self.x.sqs_connection = Mock()
self.x.sqs_connection._request_signer = Mock()
self.x.sqs_connection._endpoint.host = SQS_URL
self.x.sqs_connection.meta.service_model = Mock()
self.x.sqs_connection.meta.service_model.protocol = 'json',
self.x.sqs_connection.meta.service_model.operation_model = MagicMock(
return_value=self.MockOperationModel(operation_name, method)
)

req = self.x._create_json_request(operation_name, params, queue_url)
self.assert_requests_equal(req, AWSRequest(
url=SQS_URL,
method=method,
data=json.dumps({
**params,
"QueueUrl": queue_url
}),
headers={
'Content-Type': 'application/x-amz-json-1.0',
'X-Amz-Target': f'sqs.{operation_name}'
},
))

def test_make_request__with_query_protocol(self):
# Do the necessary mocking.
self.x.sqs_connection = Mock()
self.x.sqs_connection._request_signer = Mock()
self.x.sqs_connection.meta.service_model.protocol = 'query'
self.x._create_query_request = Mock(return_value=self.MockRequest())

# Execute the make_request called and confirm we are creating a
# query request.
operation = 'ReceiveMessage',
params = {
'MaxNumberOfMessages': 10,
'AttributeName.1': 'ApproximateReceiveCount',
'WaitTimeSeconds': 20
}
queue_url = f'{SQS_URL}/123456789012/celery-test'
verb = 'POST'
self.x.make_request(operation, params, queue_url, verb)
self.x._create_query_request.assert_called_with(
operation, params, queue_url, verb
)

def test_make_request__with_json_protocol(self):
# Do the necessary mocking.
self.x.sqs_connection = Mock()
self.x.sqs_connection._request_signer = Mock()
self.x.sqs_connection.meta.service_model.protocol = 'json'
self.x._create_json_request = Mock(return_value=self.MockRequest())

# Execute the make_request called and confirm we are creating a
# query request.
operation = 'ReceiveMessage',
params = {
'MaxNumberOfMessages': 10,
'AttributeName.1': 'ApproximateReceiveCount',
'WaitTimeSeconds': 20
}
queue_url = f'{SQS_URL}/123456789012/celery-test'
verb = 'POST'
self.x.make_request(operation, params, queue_url, verb)
self.x._create_json_request.assert_called_with(
operation, params, queue_url
)

def test_create_queue(self):
self.x.create_queue('foo', callback=self.callback)
self.x.get_object.assert_called_with(
Expand Down