Skip to content

Commit

Permalink
Revert "[fix #1726] Use boto3 for SQS async requests (#1759)" (#1799)
Browse files Browse the repository at this point in the history
This reverts commit 862d0bc.
  • Loading branch information
auvipy committed Oct 10, 2023
1 parent 678c0db commit 1dfe4f3
Show file tree
Hide file tree
Showing 6 changed files with 801 additions and 104 deletions.
202 changes: 164 additions & 38 deletions kombu/asynchronous/aws/sqs/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

from kombu.asynchronous import get_event_loop
from vine import transform

from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection

from .ext import boto3
from .message import AsyncMessage
from .queue import AsyncQueue

__all__ = ('AsyncSQSConnection',)

Expand All @@ -18,50 +21,173 @@ def __init__(self, sqs_connection, debug=0, region=None, **kwargs):
raise ImportError('boto3 is not installed')
super().__init__(
sqs_connection,
region_name=region,
debug=debug,
region_name=region, debug=debug,
**kwargs
)
self.hub = kwargs.get('hub') or get_event_loop()

def _async_sqs_request(self, api, callback, *args, **kwargs):
"""Makes an asynchronous request to an SQS API.
Arguments:
---------
api -- The name of the API, e.g. 'receive_message'.
callback -- The callback to pass the response to when it is available.
*args, **kwargs -- The arguments and keyword arguments to pass to the
SQS API. Those are API dependent and can be found in the boto3
documentation.
"""
# Define a method to execute the SQS API synchronously.
def sqs_request(api, callback, args, kwargs):
method = getattr(self.sqs_connection, api)
resp = method(*args, **kwargs)
if callback:
callback(resp)

# Hand off the request to the event loop to execute it asynchronously.
self.hub.call_soon(sqs_request, api, callback, args, kwargs)

def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
if visibility_timeout:
params['DefaultVisibilityTimeout'] = format(
visibility_timeout, 'd',
)
return self.get_object('CreateQueue', params,
callback=callback)

def delete_queue(self, queue, force_deletion=False, callback=None):
return self.get_status('DeleteQueue', None, queue.id,
callback=callback)

def get_queue_url(self, queue):
res = self.sqs_connection.get_queue_url(QueueName=queue)
return res['QueueUrl']

def get_queue_attributes(self, queue, attribute='All', callback=None):
return self.get_object(
'GetQueueAttributes', {'AttributeName': attribute},
queue.id, callback=callback,
)

def set_queue_attribute(self, queue, attribute, value, callback=None):
return self.get_status(
'SetQueueAttribute',
{'Attribute.Name': attribute, 'Attribute.Value': value},
queue.id, callback=callback,
)

def receive_message(
self, queue_url, number_messages=1, visibility_timeout=None,
self, queue, queue_url, number_messages=1, visibility_timeout=None,
attributes=('ApproximateReceiveCount',), wait_time_seconds=None,
callback=None
):
kwargs = {
"QueueUrl": queue_url,
"MaxNumberOfMessages": number_messages,
"MessageAttributeNames": attributes,
"WaitTimeSeconds": wait_time_seconds,
}
params = {'MaxNumberOfMessages': number_messages}
if visibility_timeout:
kwargs["VisibilityTimeout"] = visibility_timeout
params['VisibilityTimeout'] = visibility_timeout
if attributes:
attrs = {}
for idx, attr in enumerate(attributes):
attrs['AttributeName.' + str(idx + 1)] = attr
params.update(attrs)
if wait_time_seconds is not None:
params['WaitTimeSeconds'] = wait_time_seconds
return self.get_list(
'ReceiveMessage', params, [('Message', AsyncMessage)],
queue_url, callback=callback, parent=queue,
)

def delete_message(self, queue, receipt_handle, callback=None):
return self.delete_message_from_handle(
queue, receipt_handle, callback,
)

def delete_message_batch(self, queue, messages, callback=None):
params = {}
for i, m in enumerate(messages):
prefix = f'DeleteMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': m.id,
f'{prefix}.ReceiptHandle': m.receipt_handle,
})
return self.get_object(
'DeleteMessageBatch', params, queue.id,
verb='POST', callback=callback,
)

def delete_message_from_handle(self, queue, receipt_handle,
callback=None):
return self.get_status(
'DeleteMessage', {'ReceiptHandle': receipt_handle},
queue, callback=callback,
)

def send_message(self, queue, message_content,
delay_seconds=None, callback=None):
params = {'MessageBody': message_content}
if delay_seconds:
params['DelaySeconds'] = int(delay_seconds)
return self.get_object(
'SendMessage', params, queue.id,
verb='POST', callback=callback,
)

return self._async_sqs_request('receive_message', callback, **kwargs)
def send_message_batch(self, queue, messages, callback=None):
params = {}
for i, msg in enumerate(messages):
prefix = f'SendMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': msg[0],
f'{prefix}.MessageBody': msg[1],
f'{prefix}.DelaySeconds': msg[2],
})
return self.get_object(
'SendMessageBatch', params, queue.id,
verb='POST', callback=callback,
)

def delete_message(self, queue_url, receipt_handle, callback=None):
return self._async_sqs_request('delete_message', callback,
QueueUrl=queue_url,
ReceiptHandle=receipt_handle)
def change_message_visibility(self, queue, receipt_handle,
visibility_timeout, callback=None):
return self.get_status(
'ChangeMessageVisibility',
{'ReceiptHandle': receipt_handle,
'VisibilityTimeout': visibility_timeout},
queue.id, callback=callback,
)

def change_message_visibility_batch(self, queue, messages, callback=None):
params = {}
for i, t in enumerate(messages):
pre = f'ChangeMessageVisibilityBatchRequestEntry.{i + 1}'
params.update({
f'{pre}.Id': t[0].id,
f'{pre}.ReceiptHandle': t[0].receipt_handle,
f'{pre}.VisibilityTimeout': t[1],
})
return self.get_object(
'ChangeMessageVisibilityBatch', params, queue.id,
verb='POST', callback=callback,
)

def get_all_queues(self, prefix='', callback=None):
params = {}
if prefix:
params['QueueNamePrefix'] = prefix
return self.get_list(
'ListQueues', params, [('QueueUrl', AsyncQueue)],
callback=callback,
)

def get_queue(self, queue_name, callback=None):
# TODO Does not support owner_acct_id argument
return self.get_all_queues(
queue_name,
transform(self._on_queue_ready, callback, queue_name),
)
lookup = get_queue

def _on_queue_ready(self, name, queues):
return next(
(q for q in queues if q.url.endswith(name)), None,
)

def get_dead_letter_source_queues(self, queue, callback=None):
return self.get_list(
'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
[('QueueUrl', AsyncQueue)],
callback=callback,
)

def add_permission(self, queue, label, aws_account_id, action_name,
callback=None):
return self.get_status(
'AddPermission',
{'Label': label,
'AWSAccountId': aws_account_id,
'ActionName': action_name},
queue.id, callback=callback,
)

def remove_permission(self, queue, label, callback=None):
return self.get_status(
'RemovePermission', {'Label': label}, queue.id, callback=callback,
)
130 changes: 130 additions & 0 deletions kombu/asynchronous/aws/sqs/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Amazon SQS queue implementation."""

from __future__ import annotations

from vine import transform

from .message import AsyncMessage

_all__ = ['AsyncQueue']


def list_first(rs):
"""Get the first item in a list, or None if list empty."""
return rs[0] if len(rs) == 1 else None


class AsyncQueue:
"""Async SQS Queue."""

def __init__(self, connection=None, url=None, message_class=AsyncMessage):
self.connection = connection
self.url = url
self.message_class = message_class
self.visibility_timeout = None

def _NA(self, *args, **kwargs):
raise NotImplementedError()
count_slow = dump = save_to_file = save_to_filename = save = \
save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
load = clear = _NA

def get_attributes(self, attributes='All', callback=None):
return self.connection.get_queue_attributes(
self, attributes, callback,
)

def set_attribute(self, attribute, value, callback=None):
return self.connection.set_queue_attribute(
self, attribute, value, callback,
)

def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
return self.get_attributes(
_attr, transform(
self._coerce_field_value, callback, _attr, int,
),
)

def _coerce_field_value(self, key, type, response):
return type(response[key])

def set_timeout(self, visibility_timeout, callback=None):
return self.set_attribute(
'VisibilityTimeout', visibility_timeout,
transform(
self._on_timeout_set, callback,
)
)

def _on_timeout_set(self, visibility_timeout):
if visibility_timeout:
self.visibility_timeout = visibility_timeout
return self.visibility_timeout

def add_permission(self, label, aws_account_id, action_name,
callback=None):
return self.connection.add_permission(
self, label, aws_account_id, action_name, callback,
)

def remove_permission(self, label, callback=None):
return self.connection.remove_permission(self, label, callback)

def read(self, visibility_timeout=None, wait_time_seconds=None,
callback=None):
return self.get_messages(
1, visibility_timeout,
wait_time_seconds=wait_time_seconds,
callback=transform(list_first, callback),
)

def write(self, message, delay_seconds=None, callback=None):
return self.connection.send_message(
self, message.get_body_encoded(), delay_seconds,
callback=transform(self._on_message_sent, callback, message),
)

def write_batch(self, messages, callback=None):
return self.connection.send_message_batch(
self, messages, callback=callback,
)

def _on_message_sent(self, orig_message, new_message):
orig_message.id = new_message.id
orig_message.md5 = new_message.md5
return new_message

def get_messages(self, num_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None, callback=None):
return self.connection.receive_message(
self, number_messages=num_messages,
visibility_timeout=visibility_timeout,
attributes=attributes,
wait_time_seconds=wait_time_seconds,
callback=callback,
)

def delete_message(self, message, callback=None):
return self.connection.delete_message(self, message, callback)

def delete_message_batch(self, messages, callback=None):
return self.connection.delete_message_batch(
self, messages, callback=callback,
)

def change_message_visibility_batch(self, messages, callback=None):
return self.connection.change_message_visibility_batch(
self, messages, callback=callback,
)

def delete(self, callback=None):
return self.connection.delete_queue(self, callback=callback)

def count(self, page_size=10, vtimeout=10, callback=None,
_attr='ApproximateNumberOfMessages'):
return self.get_attributes(
_attr, callback=transform(
self._coerce_field_value, callback, _attr, int,
),
)
12 changes: 9 additions & 3 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,10 @@ def _get_from_sqs(self, queue_name, queue_url,
Uses long polling and returns :class:`~vine.promises.promise`.
"""
return connection.receive_message(
queue_url,
number_messages=count,
queue_name, queue_url, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
callback=callback)
callback=callback,
)

def _restore(self, message,
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
Expand Down Expand Up @@ -674,6 +674,12 @@ def _purge(self, queue):

def close(self):
super().close()
# if self._asynsqs:
# try:
# self.asynsqs().close()
# except AttributeError as exc: # FIXME ???
# if "can't set attribute" not in str(exc):
# raise

def new_sqs_client(self, region, access_key_id,
secret_access_key, session_token=None):
Expand Down

0 comments on commit 1dfe4f3

Please sign in to comment.