1185 lines
44 KiB
Python
1185 lines
44 KiB
Python
|
|
"""Amazon SQS transport module for Kombu.
|
|||
|
|
|
|||
|
|
This package implements an AMQP-like interface on top of Amazons SQS service,
|
|||
|
|
with the goal of being optimized for high performance and reliability.
|
|||
|
|
|
|||
|
|
The default settings for this module are focused now on high performance in
|
|||
|
|
task queue situations where tasks are small, idempotent and run very fast.
|
|||
|
|
|
|||
|
|
SQS Features supported by this transport
|
|||
|
|
========================================
|
|||
|
|
Long Polling
|
|||
|
|
------------
|
|||
|
|
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-long-polling.html
|
|||
|
|
|
|||
|
|
Long polling is enabled by setting the `wait_time_seconds` transport
|
|||
|
|
option to a number > 1. Amazon supports up to 20 seconds. This is
|
|||
|
|
enabled with 10 seconds by default.
|
|||
|
|
|
|||
|
|
Batch API Actions
|
|||
|
|
-----------------
|
|||
|
|
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-batch-api.html
|
|||
|
|
|
|||
|
|
The default behavior of the SQS Channel.drain_events() method is to
|
|||
|
|
request up to the 'prefetch_count' messages on every request to SQS.
|
|||
|
|
These messages are stored locally in a deque object and passed back
|
|||
|
|
to the Transport until the deque is empty, before triggering a new
|
|||
|
|
API call to Amazon.
|
|||
|
|
|
|||
|
|
This behavior dramatically speeds up the rate that you can pull tasks
|
|||
|
|
from SQS when you have short-running tasks (or a large number of workers).
|
|||
|
|
|
|||
|
|
When a Celery worker has multiple queues to monitor, it will pull down
|
|||
|
|
up to 'prefetch_count' messages from queueA and work on them all before
|
|||
|
|
moving on to queueB. If queueB is empty, it will wait up until
|
|||
|
|
'polling_interval' expires before moving back and checking on queueA.
|
|||
|
|
|
|||
|
|
Message Attributes
|
|||
|
|
-----------------
|
|||
|
|
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
|
|||
|
|
|
|||
|
|
SQS supports sending message attributes along with the message body.
|
|||
|
|
To use this feature, you can pass a 'message_attributes' as keyword argument
|
|||
|
|
to `basic_publish` method.
|
|||
|
|
|
|||
|
|
Other Features supported by this transport
|
|||
|
|
==========================================
|
|||
|
|
Predefined Queues
|
|||
|
|
-----------------
|
|||
|
|
The default behavior of this transport is to use a single AWS credential
|
|||
|
|
pair in order to manage all SQS queues (e.g. listing queues, creating
|
|||
|
|
queues, polling queues, deleting messages).
|
|||
|
|
|
|||
|
|
If it is preferable for your environment to use multiple AWS credentials, you
|
|||
|
|
can use the 'predefined_queues' setting inside the 'transport_options' map.
|
|||
|
|
This setting allows you to specify the SQS queue URL and AWS credentials for
|
|||
|
|
each of your queues. For example, if you have two queues which both already
|
|||
|
|
exist in AWS) you can tell this transport about them as follows:
|
|||
|
|
|
|||
|
|
.. code-block:: python
|
|||
|
|
|
|||
|
|
transport_options = {
|
|||
|
|
'predefined_queues': {
|
|||
|
|
'queue-1': {
|
|||
|
|
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/aaa',
|
|||
|
|
'access_key_id': 'a',
|
|||
|
|
'secret_access_key': 'b',
|
|||
|
|
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
|
|||
|
|
'backoff_tasks': ['svc.tasks.tasks.task1'] # optional
|
|||
|
|
},
|
|||
|
|
'queue-2.fifo': {
|
|||
|
|
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo',
|
|||
|
|
'access_key_id': 'c',
|
|||
|
|
'secret_access_key': 'd',
|
|||
|
|
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
|
|||
|
|
'backoff_tasks': ['svc.tasks.tasks.task2'] # optional
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
|
|||
|
|
'sts_token_timeout': 900, # optional
|
|||
|
|
'sts_token_buffer_time': 0, # optional, added in 5.6.0
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Note that FIFO and standard queues must be named accordingly (the name of
|
|||
|
|
a FIFO queue must end with the .fifo suffix).
|
|||
|
|
|
|||
|
|
backoff_policy & backoff_tasks are optional arguments. These arguments
|
|||
|
|
automatically change the message visibility timeout, in order to have
|
|||
|
|
different times between specific task retries. This would apply after
|
|||
|
|
task failure.
|
|||
|
|
|
|||
|
|
AWS STS authentication is supported, by using sts_role_arn, and
|
|||
|
|
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
|
|||
|
|
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
|
|||
|
|
to 900 seconds. After the mentioned period, a new token will be created.
|
|||
|
|
|
|||
|
|
.. versionadded:: 5.6.0
|
|||
|
|
sts_token_buffer_time (seconds) is the time by which you want to refresh your token
|
|||
|
|
earlier than its actual expiration time, defaults to 0 (no time buffer will be added),
|
|||
|
|
should be less than sts_token_timeout.
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
If you authenticate using Okta_ (e.g. calling |gac|_), you can also specify
|
|||
|
|
a 'session_token' to connect to a queue. Note that those tokens have a
|
|||
|
|
limited lifetime and are therefore only suited for short-lived tests.
|
|||
|
|
|
|||
|
|
.. _Okta: https://www.okta.com/
|
|||
|
|
.. _gac: https://github.com/Nike-Inc/gimme-aws-creds#readme
|
|||
|
|
.. |gac| replace:: ``gimme-aws-creds``
|
|||
|
|
|
|||
|
|
|
|||
|
|
Client config
|
|||
|
|
-------------
|
|||
|
|
In some cases you may need to override the botocore config. You can do it
|
|||
|
|
as follows:
|
|||
|
|
|
|||
|
|
.. code-block:: python
|
|||
|
|
|
|||
|
|
transport_option = {
|
|||
|
|
'client-config': {
|
|||
|
|
'connect_timeout': 5,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
For a complete list of settings you can adjust using this option see
|
|||
|
|
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
|
|||
|
|
|
|||
|
|
Features
|
|||
|
|
========
|
|||
|
|
* Type: Virtual
|
|||
|
|
* Supports Direct: Yes
|
|||
|
|
* Supports Topic: Yes
|
|||
|
|
* Supports Fanout: Yes
|
|||
|
|
* Supports Priority: No
|
|||
|
|
* Supports TTL: No
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import base64
|
|||
|
|
import binascii
|
|||
|
|
import re
|
|||
|
|
import socket
|
|||
|
|
import string
|
|||
|
|
import uuid
|
|||
|
|
from datetime import datetime, timedelta, timezone
|
|||
|
|
from json import JSONDecodeError
|
|||
|
|
from queue import Empty
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from botocore.client import Config
|
|||
|
|
from botocore.exceptions import ClientError
|
|||
|
|
from vine import ensure_promise, promise, transform
|
|||
|
|
|
|||
|
|
from kombu.asynchronous import get_event_loop
|
|||
|
|
from kombu.asynchronous.aws.ext import boto3, exceptions
|
|||
|
|
from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
|
|||
|
|
from kombu.asynchronous.aws.sqs.message import AsyncMessage
|
|||
|
|
from kombu.log import get_logger
|
|||
|
|
from kombu.utils import scheduling
|
|||
|
|
from kombu.utils.encoding import bytes_to_str, safe_str
|
|||
|
|
from kombu.utils.json import dumps, loads
|
|||
|
|
from kombu.utils.objects import cached_property
|
|||
|
|
|
|||
|
|
from . import virtual
|
|||
|
|
|
|||
|
|
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
# dots are replaced by dash, dash remains dash, all other punctuation
|
|||
|
|
# replaced by underscore.
|
|||
|
|
CHARS_REPLACE_TABLE = {
|
|||
|
|
ord(c): 0x5f for c in string.punctuation if c not in '-_.'
|
|||
|
|
}
|
|||
|
|
CHARS_REPLACE_TABLE[0x2e] = 0x2d # '.' -> '-'
|
|||
|
|
|
|||
|
|
#: SQS bulk get supports a maximum of 10 messages at a time.
|
|||
|
|
SQS_MAX_MESSAGES = 10
|
|||
|
|
|
|||
|
|
|
|||
|
|
def maybe_int(x):
|
|||
|
|
"""Try to convert x' to int, or return x' if that fails."""
|
|||
|
|
try:
|
|||
|
|
return int(x)
|
|||
|
|
except ValueError:
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UndefinedQueueException(Exception):
|
|||
|
|
"""Predefined queues are being used and an undefined queue was used."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class InvalidQueueException(Exception):
|
|||
|
|
"""Predefined queues are being used and configuration is not valid."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AccessDeniedQueueException(Exception):
|
|||
|
|
"""Raised when access to the AWS queue is denied.
|
|||
|
|
|
|||
|
|
This may occur if the permissions are not correctly set or the
|
|||
|
|
credentials are invalid.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DoesNotExistQueueException(Exception):
|
|||
|
|
"""The specified queue doesn't exist."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class QoS(virtual.QoS):
|
|||
|
|
"""Quality of Service guarantees implementation for SQS."""
|
|||
|
|
|
|||
|
|
def reject(self, delivery_tag, requeue=False):
|
|||
|
|
super().reject(delivery_tag, requeue=requeue)
|
|||
|
|
routing_key, message, backoff_tasks, backoff_policy = \
|
|||
|
|
self._extract_backoff_policy_configuration_and_message(
|
|||
|
|
delivery_tag)
|
|||
|
|
if routing_key and message and backoff_tasks and backoff_policy:
|
|||
|
|
self.apply_backoff_policy(
|
|||
|
|
routing_key, delivery_tag, backoff_policy, backoff_tasks)
|
|||
|
|
|
|||
|
|
def _extract_backoff_policy_configuration_and_message(self, delivery_tag):
|
|||
|
|
try:
|
|||
|
|
message = self._delivered[delivery_tag]
|
|||
|
|
routing_key = message.delivery_info['routing_key']
|
|||
|
|
except KeyError:
|
|||
|
|
return None, None, None, None
|
|||
|
|
if not routing_key or not message:
|
|||
|
|
return None, None, None, None
|
|||
|
|
queue_config = self.channel.predefined_queues.get(routing_key, {})
|
|||
|
|
backoff_tasks = queue_config.get('backoff_tasks')
|
|||
|
|
backoff_policy = queue_config.get('backoff_policy')
|
|||
|
|
return routing_key, message, backoff_tasks, backoff_policy
|
|||
|
|
|
|||
|
|
def apply_backoff_policy(self, routing_key, delivery_tag,
|
|||
|
|
backoff_policy, backoff_tasks):
|
|||
|
|
queue_url = self.channel._queue_cache[routing_key]
|
|||
|
|
task_name, number_of_retries = \
|
|||
|
|
self.extract_task_name_and_number_of_retries(delivery_tag)
|
|||
|
|
if not task_name or not number_of_retries:
|
|||
|
|
return None
|
|||
|
|
policy_value = backoff_policy.get(number_of_retries)
|
|||
|
|
if task_name in backoff_tasks and policy_value is not None:
|
|||
|
|
c = self.channel.sqs(routing_key)
|
|||
|
|
c.change_message_visibility(
|
|||
|
|
QueueUrl=queue_url,
|
|||
|
|
ReceiptHandle=delivery_tag,
|
|||
|
|
VisibilityTimeout=policy_value
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def extract_task_name_and_number_of_retries(self, delivery_tag):
|
|||
|
|
message = self._delivered[delivery_tag]
|
|||
|
|
message_headers = message.headers
|
|||
|
|
task_name = message_headers['task']
|
|||
|
|
number_of_retries = int(
|
|||
|
|
message.properties['delivery_info']['sqs_message']
|
|||
|
|
['Attributes']['ApproximateReceiveCount'])
|
|||
|
|
return task_name, number_of_retries
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Channel(virtual.Channel):
|
|||
|
|
"""SQS Channel."""
|
|||
|
|
|
|||
|
|
default_region = 'us-east-1'
|
|||
|
|
default_visibility_timeout = 1800 # 30 minutes.
|
|||
|
|
default_wait_time_seconds = 10 # up to 20 seconds max
|
|||
|
|
domain_format = 'kombu%(vhost)s'
|
|||
|
|
_asynsqs = None
|
|||
|
|
_predefined_queue_async_clients = {} # A client for each predefined queue
|
|||
|
|
_sqs = None
|
|||
|
|
_predefined_queue_clients = {} # A client for each predefined queue
|
|||
|
|
_queue_cache = {} # SQS queue name => SQS queue URL
|
|||
|
|
_noack_queues = set()
|
|||
|
|
QoS = QoS
|
|||
|
|
# https://stackoverflow.com/questions/475074/regex-to-parse-or-validate-base64-data
|
|||
|
|
B64_REGEX = re.compile(rb'^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$')
|
|||
|
|
|
|||
|
|
def __init__(self, *args, **kwargs):
|
|||
|
|
if boto3 is None:
|
|||
|
|
raise ImportError('boto3 is not installed')
|
|||
|
|
super().__init__(*args, **kwargs)
|
|||
|
|
self._validate_predifined_queues()
|
|||
|
|
|
|||
|
|
# SQS blows up if you try to create a new queue when one already
|
|||
|
|
# exists but with a different visibility_timeout. This prepopulates
|
|||
|
|
# the queue_cache to protect us from recreating
|
|||
|
|
# queues that are known to already exist.
|
|||
|
|
self._update_queue_cache(self.queue_name_prefix)
|
|||
|
|
|
|||
|
|
self.hub = kwargs.get('hub') or get_event_loop()
|
|||
|
|
|
|||
|
|
def _validate_predifined_queues(self):
|
|||
|
|
"""Check that standard and FIFO queues are named properly.
|
|||
|
|
|
|||
|
|
AWS requires FIFO queues to have a name
|
|||
|
|
that ends with the .fifo suffix.
|
|||
|
|
"""
|
|||
|
|
for queue_name, q in self.predefined_queues.items():
|
|||
|
|
fifo_url = q['url'].endswith('.fifo')
|
|||
|
|
fifo_name = queue_name.endswith('.fifo')
|
|||
|
|
if fifo_url and not fifo_name:
|
|||
|
|
raise InvalidQueueException(
|
|||
|
|
"Queue with url '{}' must have a name "
|
|||
|
|
"ending with .fifo".format(q['url'])
|
|||
|
|
)
|
|||
|
|
elif not fifo_url and fifo_name:
|
|||
|
|
raise InvalidQueueException(
|
|||
|
|
"Queue with name '{}' is not a FIFO queue: "
|
|||
|
|
"'{}'".format(queue_name, q['url'])
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def _update_queue_cache(self, queue_name_prefix):
|
|||
|
|
if self.predefined_queues:
|
|||
|
|
for queue_name, q in self.predefined_queues.items():
|
|||
|
|
self._queue_cache[queue_name] = q['url']
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
resp = self.sqs().list_queues(QueueNamePrefix=queue_name_prefix)
|
|||
|
|
for url in resp.get('QueueUrls', []):
|
|||
|
|
queue_name = url.split('/')[-1]
|
|||
|
|
self._queue_cache[queue_name] = url
|
|||
|
|
|
|||
|
|
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
|||
|
|
if no_ack:
|
|||
|
|
self._noack_queues.add(queue)
|
|||
|
|
if self.hub:
|
|||
|
|
self._loop1(queue)
|
|||
|
|
return super().basic_consume(
|
|||
|
|
queue, no_ack, *args, **kwargs
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def basic_cancel(self, consumer_tag):
|
|||
|
|
if consumer_tag in self._consumers:
|
|||
|
|
queue = self._tag_to_queue[consumer_tag]
|
|||
|
|
self._noack_queues.discard(queue)
|
|||
|
|
return super().basic_cancel(consumer_tag)
|
|||
|
|
|
|||
|
|
def drain_events(self, timeout=None, callback=None, **kwargs):
|
|||
|
|
"""Return a single payload message from one of our queues.
|
|||
|
|
|
|||
|
|
Raises
|
|||
|
|
------
|
|||
|
|
Queue.Empty: if no messages available.
|
|||
|
|
"""
|
|||
|
|
# If we're not allowed to consume or have no consumers, raise Empty
|
|||
|
|
if not self._consumers or not self.qos.can_consume():
|
|||
|
|
raise Empty()
|
|||
|
|
|
|||
|
|
# At this point, go and get more messages from SQS
|
|||
|
|
self._poll(self.cycle, callback, timeout=timeout)
|
|||
|
|
|
|||
|
|
def _reset_cycle(self):
|
|||
|
|
"""Reset the consume cycle.
|
|||
|
|
|
|||
|
|
Returns
|
|||
|
|
-------
|
|||
|
|
FairCycle: object that points to our _get_bulk() method
|
|||
|
|
rather than the standard _get() method. This allows for
|
|||
|
|
multiple messages to be returned at once from SQS (
|
|||
|
|
based on the prefetch limit).
|
|||
|
|
"""
|
|||
|
|
self._cycle = scheduling.FairCycle(
|
|||
|
|
self._get_bulk, self._active_queues, Empty,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
|
|||
|
|
"""Format AMQP queue name into a legal SQS queue name."""
|
|||
|
|
if name.endswith('.fifo'):
|
|||
|
|
partial = name[:-len('.fifo')]
|
|||
|
|
partial = str(safe_str(partial)).translate(table)
|
|||
|
|
return partial + '.fifo'
|
|||
|
|
else:
|
|||
|
|
return str(safe_str(name)).translate(table)
|
|||
|
|
|
|||
|
|
def canonical_queue_name(self, queue_name):
|
|||
|
|
return self.entity_name(self.queue_name_prefix + queue_name)
|
|||
|
|
|
|||
|
|
def _resolve_queue_url(self, queue):
|
|||
|
|
"""Try to retrieve the SQS queue URL for a given queue name."""
|
|||
|
|
# Translate to SQS name for consistency with initial
|
|||
|
|
# _queue_cache population.
|
|||
|
|
sqs_qname = self.canonical_queue_name(queue)
|
|||
|
|
|
|||
|
|
# The SQS ListQueues method only returns 1000 queues. When you have
|
|||
|
|
# so many queues, it's possible that the queue you are looking for is
|
|||
|
|
# not cached. In this case, we could update the cache with the exact
|
|||
|
|
# queue name first.
|
|||
|
|
if sqs_qname not in self._queue_cache:
|
|||
|
|
self._update_queue_cache(sqs_qname)
|
|||
|
|
try:
|
|||
|
|
return self._queue_cache[sqs_qname]
|
|||
|
|
except KeyError:
|
|||
|
|
if self.predefined_queues:
|
|||
|
|
raise UndefinedQueueException((
|
|||
|
|
"Queue with name '{}' must be "
|
|||
|
|
"defined in 'predefined_queues'."
|
|||
|
|
).format(sqs_qname))
|
|||
|
|
|
|||
|
|
raise DoesNotExistQueueException(
|
|||
|
|
f"Queue with name '{sqs_qname}' doesn't exist in SQS"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def _new_queue(self, queue, **kwargs):
|
|||
|
|
"""Ensure a queue with given name exists in SQS.
|
|||
|
|
|
|||
|
|
Arguments:
|
|||
|
|
---------
|
|||
|
|
queue (str): the AMQP queue name
|
|||
|
|
Returns
|
|||
|
|
str: the SQS queue URL
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
return self._resolve_queue_url(queue)
|
|||
|
|
except DoesNotExistQueueException:
|
|||
|
|
sqs_qname = self.canonical_queue_name(queue)
|
|||
|
|
attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
|
|||
|
|
if sqs_qname.endswith('.fifo'):
|
|||
|
|
attributes['FifoQueue'] = 'true'
|
|||
|
|
|
|||
|
|
resp = self._create_queue(sqs_qname, attributes)
|
|||
|
|
self._queue_cache[sqs_qname] = resp['QueueUrl']
|
|||
|
|
return resp['QueueUrl']
|
|||
|
|
|
|||
|
|
def _create_queue(self, queue_name, attributes):
|
|||
|
|
"""Create an SQS queue with a given name and nominal attributes."""
|
|||
|
|
# Allow specifying additional boto create_queue Attributes
|
|||
|
|
# via transport options
|
|||
|
|
if self.predefined_queues:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
attributes.update(
|
|||
|
|
self.transport_options.get('sqs-creation-attributes') or {},
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
queue_tags = self.transport_options.get('queue_tags')
|
|||
|
|
|
|||
|
|
create_params = {
|
|||
|
|
'QueueName': queue_name,
|
|||
|
|
'Attributes': attributes,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if queue_tags:
|
|||
|
|
create_params['tags'] = queue_tags
|
|||
|
|
|
|||
|
|
return self.sqs(queue=queue_name).create_queue(**create_params)
|
|||
|
|
|
|||
|
|
def _delete(self, queue, *args, **kwargs):
|
|||
|
|
"""Delete queue by name."""
|
|||
|
|
if self.predefined_queues:
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
q_url = self._resolve_queue_url(queue)
|
|||
|
|
self.sqs().delete_queue(
|
|||
|
|
QueueUrl=q_url,
|
|||
|
|
)
|
|||
|
|
self._queue_cache.pop(queue, None)
|
|||
|
|
|
|||
|
|
def _put(self, queue, message, **kwargs):
|
|||
|
|
"""Put message onto queue."""
|
|||
|
|
q_url = self._new_queue(queue)
|
|||
|
|
kwargs = {'QueueUrl': q_url}
|
|||
|
|
if 'properties' in message:
|
|||
|
|
if 'message_attributes' in message['properties']:
|
|||
|
|
# we don't want to want to have the attribute in the body
|
|||
|
|
kwargs['MessageAttributes'] = \
|
|||
|
|
message['properties'].pop('message_attributes')
|
|||
|
|
if queue.endswith('.fifo'):
|
|||
|
|
if 'MessageGroupId' in message['properties']:
|
|||
|
|
kwargs['MessageGroupId'] = \
|
|||
|
|
message['properties']['MessageGroupId']
|
|||
|
|
else:
|
|||
|
|
kwargs['MessageGroupId'] = 'default'
|
|||
|
|
if 'MessageDeduplicationId' in message['properties']:
|
|||
|
|
kwargs['MessageDeduplicationId'] = \
|
|||
|
|
message['properties']['MessageDeduplicationId']
|
|||
|
|
else:
|
|||
|
|
kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
|
|||
|
|
else:
|
|||
|
|
if "DelaySeconds" in message['properties']:
|
|||
|
|
kwargs['DelaySeconds'] = \
|
|||
|
|
message['properties']['DelaySeconds']
|
|||
|
|
|
|||
|
|
if self.sqs_base64_encoding:
|
|||
|
|
body = AsyncMessage().encode(dumps(message))
|
|||
|
|
else:
|
|||
|
|
body = dumps(message)
|
|||
|
|
kwargs['MessageBody'] = body
|
|||
|
|
|
|||
|
|
c = self.sqs(queue=self.canonical_queue_name(queue))
|
|||
|
|
if message.get('redelivered'):
|
|||
|
|
c.change_message_visibility(
|
|||
|
|
QueueUrl=q_url,
|
|||
|
|
ReceiptHandle=message['properties']['delivery_tag'],
|
|||
|
|
VisibilityTimeout=self.wait_time_seconds
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
c.send_message(**kwargs)
|
|||
|
|
|
|||
|
|
def _message_to_python(self, message, queue_name, q_url):
|
|||
|
|
raw_msg_body = message['Body']
|
|||
|
|
decoded_bytes = self._decode_python_message_body(raw_msg_body)
|
|||
|
|
text = bytes_to_str(decoded_bytes)
|
|||
|
|
|
|||
|
|
payload = self._prepare_json_payload(text)
|
|||
|
|
|
|||
|
|
# handle no-ack queues immediately
|
|||
|
|
if queue_name in self._noack_queues:
|
|||
|
|
self._delete_message(queue_name, message)
|
|||
|
|
return payload
|
|||
|
|
|
|||
|
|
return self._envelope_payload(payload, text, message, q_url)
|
|||
|
|
|
|||
|
|
def _messages_to_python(self, messages, queue):
|
|||
|
|
"""Convert a list of SQS Message objects into Payloads.
|
|||
|
|
|
|||
|
|
This method handles converting SQS Message objects into
|
|||
|
|
Payloads, and appropriately updating the queue depending on
|
|||
|
|
the 'ack' settings for that queue.
|
|||
|
|
|
|||
|
|
Arguments:
|
|||
|
|
---------
|
|||
|
|
messages (SQSMessage): A list of SQS Message objects.
|
|||
|
|
queue (str): Name representing the queue they came from.
|
|||
|
|
|
|||
|
|
Returns
|
|||
|
|
-------
|
|||
|
|
List: A list of Payload objects
|
|||
|
|
"""
|
|||
|
|
q_url = self._new_queue(queue)
|
|||
|
|
return [self._message_to_python(m, queue, q_url) for m in messages]
|
|||
|
|
|
|||
|
|
def _receive_message(
|
|||
|
|
self,
|
|||
|
|
queue: str,
|
|||
|
|
max_number_of_messages: int = 1,
|
|||
|
|
wait_time_seconds: int | None = None
|
|||
|
|
):
|
|||
|
|
"""Unified receive_message wrapper for SQS (boto3.client.SQS) with full attribute support.
|
|||
|
|
|
|||
|
|
:param queue: The queue as a string
|
|||
|
|
:param max_number_of_messages: Int of max number of messages to receive.
|
|||
|
|
:param wait_time_seconds: Int of sqs wait time in seconds.
|
|||
|
|
:return: SQS client recieve_message
|
|||
|
|
"""
|
|||
|
|
q_url: str = self._new_queue(queue)
|
|||
|
|
client = self.sqs(queue=queue)
|
|||
|
|
|
|||
|
|
message_system_attribute_names = self.get_message_attributes.get(
|
|||
|
|
'MessageSystemAttributeNames') or []
|
|||
|
|
|
|||
|
|
message_attribute_names = self.get_message_attributes.get(
|
|||
|
|
'MessageAttributeNames') or []
|
|||
|
|
|
|||
|
|
params: dict[str, Any] = {
|
|||
|
|
'QueueUrl': q_url,
|
|||
|
|
'MaxNumberOfMessages': max_number_of_messages,
|
|||
|
|
'WaitTimeSeconds': wait_time_seconds or self.wait_time_seconds,
|
|||
|
|
'MessageAttributeNames': message_attribute_names,
|
|||
|
|
'MessageSystemAttributeNames': message_system_attribute_names
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return client.receive_message(**params)
|
|||
|
|
|
|||
|
|
def _get_bulk(self, queue,
|
|||
|
|
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
|
|||
|
|
"""Try to retrieve multiple messages off ``queue``.
|
|||
|
|
|
|||
|
|
Where :meth:`_get` returns a single Payload object, this method
|
|||
|
|
returns a list of Payload objects. The number of objects returned
|
|||
|
|
is determined by the total number of messages available in the queue
|
|||
|
|
and the number of messages the QoS object allows (based on the
|
|||
|
|
prefetch_count).
|
|||
|
|
|
|||
|
|
Note:
|
|||
|
|
----
|
|||
|
|
Ignores QoS limits so caller is responsible for checking
|
|||
|
|
that we are allowed to consume at least one message from the
|
|||
|
|
queue. get_bulk will then ask QoS for an estimate of
|
|||
|
|
the number of extra messages that we can consume.
|
|||
|
|
|
|||
|
|
Arguments:
|
|||
|
|
---------
|
|||
|
|
queue (str): The queue name to pull from.
|
|||
|
|
|
|||
|
|
Returns
|
|||
|
|
-------
|
|||
|
|
List[Message]
|
|||
|
|
"""
|
|||
|
|
# drain_events calls `can_consume` first, consuming
|
|||
|
|
# a token, so we know that we are allowed to consume at least
|
|||
|
|
# one message.
|
|||
|
|
|
|||
|
|
# Note: ignoring max_messages for SQS with boto3
|
|||
|
|
max_count = self._get_message_estimate()
|
|||
|
|
if max_count:
|
|||
|
|
resp = self._receive_message(
|
|||
|
|
queue=queue,
|
|||
|
|
wait_time_seconds=self.wait_time_seconds,
|
|||
|
|
max_number_of_messages=max_count
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if resp.get('Messages'):
|
|||
|
|
for m in resp['Messages']:
|
|||
|
|
m['Body'] = AsyncMessage(body=m['Body']).decode()
|
|||
|
|
for msg in self._messages_to_python(resp['Messages'], queue):
|
|||
|
|
self.connection._deliver(msg, queue)
|
|||
|
|
return
|
|||
|
|
raise Empty()
|
|||
|
|
|
|||
|
|
def _get(self, queue):
|
|||
|
|
"""Try to retrieve a single message off ``queue``."""
|
|||
|
|
resp = self._receive_message(
|
|||
|
|
queue=queue,
|
|||
|
|
wait_time_seconds=self.wait_time_seconds,
|
|||
|
|
max_number_of_messages=1
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if resp.get('Messages'):
|
|||
|
|
body = AsyncMessage(body=resp['Messages'][0]['Body']).decode()
|
|||
|
|
resp['Messages'][0]['Body'] = body
|
|||
|
|
return self._messages_to_python(resp['Messages'], queue)[0]
|
|||
|
|
raise Empty()
|
|||
|
|
|
|||
|
|
def _loop1(self, queue, _=None):
|
|||
|
|
self.hub.call_soon(self._schedule_queue, queue)
|
|||
|
|
|
|||
|
|
def _schedule_queue(self, queue):
|
|||
|
|
if queue in self._active_queues:
|
|||
|
|
if self.qos.can_consume():
|
|||
|
|
self._get_bulk_async(
|
|||
|
|
queue, callback=promise(self._loop1, (queue,)),
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
self._loop1(queue)
|
|||
|
|
|
|||
|
|
def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES):
|
|||
|
|
maxcount = self.qos.can_consume_max_estimate()
|
|||
|
|
return min(
|
|||
|
|
max_if_unlimited if maxcount is None else max(maxcount, 1),
|
|||
|
|
max_if_unlimited,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def _get_bulk_async(self, queue, callback=None):
|
|||
|
|
maxcount = self._get_message_estimate()
|
|||
|
|
if maxcount:
|
|||
|
|
return self._get_async(queue, maxcount, callback=callback)
|
|||
|
|
# Not allowed to consume, make sure to notify callback..
|
|||
|
|
callback = ensure_promise(callback)
|
|||
|
|
callback([])
|
|||
|
|
return callback
|
|||
|
|
|
|||
|
|
def _get_async(self, queue, count=1, callback=None):
|
|||
|
|
q_url = self._new_queue(queue)
|
|||
|
|
qname = self.canonical_queue_name(queue)
|
|||
|
|
return self._get_from_sqs(
|
|||
|
|
queue_name=qname, queue_url=q_url, count=count,
|
|||
|
|
connection=self.asynsqs(queue=qname),
|
|||
|
|
callback=transform(
|
|||
|
|
self._on_messages_ready, callback, q_url, queue
|
|||
|
|
),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def _on_messages_ready(self, queue, qname, messages):
|
|||
|
|
if 'Messages' in messages and messages['Messages']:
|
|||
|
|
callbacks = self.connection._callbacks
|
|||
|
|
for msg in messages['Messages']:
|
|||
|
|
msg_parsed = self._message_to_python(msg, qname, queue)
|
|||
|
|
callbacks[qname](msg_parsed)
|
|||
|
|
|
|||
|
|
def _get_from_sqs(self, queue_name, queue_url,
|
|||
|
|
connection, count=1, callback=None):
|
|||
|
|
"""Retrieve and handle messages from SQS.
|
|||
|
|
|
|||
|
|
Uses long polling and returns :class:`~vine.promises.promise`.
|
|||
|
|
"""
|
|||
|
|
return connection.receive_message(
|
|||
|
|
queue_name, queue_url, number_messages=count,
|
|||
|
|
wait_time_seconds=self.wait_time_seconds,
|
|||
|
|
callback=callback,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def _restore(self, message,
|
|||
|
|
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
|
|||
|
|
for unwanted_key in unwanted_delivery_info:
|
|||
|
|
# Remove objects that aren't JSON serializable (Issue #1108).
|
|||
|
|
message.delivery_info.pop(unwanted_key, None)
|
|||
|
|
return super()._restore(message)
|
|||
|
|
|
|||
|
|
def basic_ack(self, delivery_tag, multiple=False):
|
|||
|
|
try:
|
|||
|
|
message = self.qos.get(delivery_tag).delivery_info
|
|||
|
|
sqs_message = message['sqs_message']
|
|||
|
|
except KeyError:
|
|||
|
|
super().basic_ack(delivery_tag)
|
|||
|
|
else:
|
|||
|
|
queue = None
|
|||
|
|
if 'routing_key' in message:
|
|||
|
|
queue = self.canonical_queue_name(message['routing_key'])
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
self.sqs(queue=queue).delete_message(
|
|||
|
|
QueueUrl=message['sqs_queue'],
|
|||
|
|
ReceiptHandle=sqs_message['ReceiptHandle']
|
|||
|
|
)
|
|||
|
|
except ClientError as exception:
|
|||
|
|
if exception.response['Error']['Code'] == 'AccessDenied':
|
|||
|
|
raise AccessDeniedQueueException(
|
|||
|
|
exception.response["Error"]["Message"]
|
|||
|
|
)
|
|||
|
|
super().basic_reject(delivery_tag)
|
|||
|
|
else:
|
|||
|
|
super().basic_ack(delivery_tag)
|
|||
|
|
|
|||
|
|
def _size(self, queue):
|
|||
|
|
"""Return the number of messages in a queue."""
|
|||
|
|
q_url = self._new_queue(queue)
|
|||
|
|
c = self.sqs(queue=self.canonical_queue_name(queue))
|
|||
|
|
resp = c.get_queue_attributes(
|
|||
|
|
QueueUrl=q_url,
|
|||
|
|
AttributeNames=['ApproximateNumberOfMessages'])
|
|||
|
|
return int(resp['Attributes']['ApproximateNumberOfMessages'])
|
|||
|
|
|
|||
|
|
def _purge(self, queue):
|
|||
|
|
"""Delete all current messages in a queue."""
|
|||
|
|
q_url = self._new_queue(queue)
|
|||
|
|
# SQS is slow at registering messages, so run for a few
|
|||
|
|
# iterations to ensure messages are detected and deleted.
|
|||
|
|
size = 0
|
|||
|
|
for i in range(10):
|
|||
|
|
size += int(self._size(queue))
|
|||
|
|
if not size:
|
|||
|
|
break
|
|||
|
|
self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
|
|||
|
|
return size
|
|||
|
|
|
|||
|
|
def close(self):
|
|||
|
|
super().close()
|
|||
|
|
# if self._asynsqs:
|
|||
|
|
# try:
|
|||
|
|
# self.asynsqs().close()
|
|||
|
|
# except AttributeError as exc: # FIXME ???
|
|||
|
|
# if "can't set attribute" not in str(exc):
|
|||
|
|
# raise
|
|||
|
|
|
|||
|
|
def new_sqs_client(self, region, access_key_id,
|
|||
|
|
secret_access_key, session_token=None):
|
|||
|
|
session = boto3.session.Session(
|
|||
|
|
region_name=region,
|
|||
|
|
aws_access_key_id=access_key_id,
|
|||
|
|
aws_secret_access_key=secret_access_key,
|
|||
|
|
aws_session_token=session_token,
|
|||
|
|
)
|
|||
|
|
is_secure = self.is_secure if self.is_secure is not None else True
|
|||
|
|
client_kwargs = {
|
|||
|
|
'use_ssl': is_secure
|
|||
|
|
}
|
|||
|
|
if self.endpoint_url is not None:
|
|||
|
|
client_kwargs['endpoint_url'] = self.endpoint_url
|
|||
|
|
client_config = self.transport_options.get('client-config') or {}
|
|||
|
|
config = Config(**client_config)
|
|||
|
|
return session.client('sqs', config=config, **client_kwargs)
|
|||
|
|
|
|||
|
|
def sqs(self, queue=None):
|
|||
|
|
if queue is not None and self.predefined_queues:
|
|||
|
|
|
|||
|
|
if queue not in self.predefined_queues:
|
|||
|
|
raise UndefinedQueueException(
|
|||
|
|
f"Queue with name '{queue}' must be defined"
|
|||
|
|
" in 'predefined_queues'.")
|
|||
|
|
q = self.predefined_queues[queue]
|
|||
|
|
if self.transport_options.get('sts_role_arn'):
|
|||
|
|
return self._handle_sts_session(queue, q)
|
|||
|
|
if not self.transport_options.get('sts_role_arn'):
|
|||
|
|
if queue in self._predefined_queue_clients:
|
|||
|
|
return self._predefined_queue_clients[queue]
|
|||
|
|
else:
|
|||
|
|
c = self._predefined_queue_clients[queue] = \
|
|||
|
|
self.new_sqs_client(
|
|||
|
|
region=q.get('region', self.region),
|
|||
|
|
access_key_id=q.get(
|
|||
|
|
'access_key_id', self.conninfo.userid),
|
|||
|
|
secret_access_key=q.get(
|
|||
|
|
'secret_access_key', self.conninfo.password)
|
|||
|
|
)
|
|||
|
|
return c
|
|||
|
|
|
|||
|
|
if self._sqs is not None:
|
|||
|
|
return self._sqs
|
|||
|
|
|
|||
|
|
c = self._sqs = self.new_sqs_client(
|
|||
|
|
region=self.region,
|
|||
|
|
access_key_id=self.conninfo.userid,
|
|||
|
|
secret_access_key=self.conninfo.password,
|
|||
|
|
)
|
|||
|
|
return c
|
|||
|
|
|
|||
|
|
def _handle_sts_session(self, queue, q):
|
|||
|
|
region = q.get('region', self.region)
|
|||
|
|
if not hasattr(self, 'sts_expiration'): # STS token - token init
|
|||
|
|
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
|||
|
|
# STS token - refresh if expired
|
|||
|
|
elif self.sts_expiration.replace(tzinfo=None) < datetime.now(timezone.utc).replace(tzinfo=None):
|
|||
|
|
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
|||
|
|
else: # STS token - ruse existing
|
|||
|
|
if queue not in self._predefined_queue_clients:
|
|||
|
|
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
|||
|
|
return self._predefined_queue_clients[queue]
|
|||
|
|
|
|||
|
|
def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
|
|||
|
|
"""Generate STS session credentials with an optional expiration buffer.
|
|||
|
|
|
|||
|
|
The buffer is only applied if it is less than `token_expiry_seconds` to prevent an expired token.
|
|||
|
|
"""
|
|||
|
|
credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds)
|
|||
|
|
if token_buffer_seconds and 0 < token_buffer_seconds < token_expiry_seconds:
|
|||
|
|
credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds)
|
|||
|
|
return credentials
|
|||
|
|
|
|||
|
|
def _new_predefined_queue_client_with_sts_session(self, queue, region):
|
|||
|
|
sts_creds = self.generate_sts_session_token_with_buffer(
|
|||
|
|
self.transport_options.get('sts_role_arn'),
|
|||
|
|
self.transport_options.get('sts_token_timeout', 900),
|
|||
|
|
self.transport_options.get('sts_token_buffer_time', 0),
|
|||
|
|
)
|
|||
|
|
self.sts_expiration = sts_creds['Expiration']
|
|||
|
|
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
|
|||
|
|
region=region,
|
|||
|
|
access_key_id=sts_creds['AccessKeyId'],
|
|||
|
|
secret_access_key=sts_creds['SecretAccessKey'],
|
|||
|
|
session_token=sts_creds['SessionToken'],
|
|||
|
|
)
|
|||
|
|
return c
|
|||
|
|
|
|||
|
|
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
|
|||
|
|
sts_client = boto3.client('sts')
|
|||
|
|
sts_policy = sts_client.assume_role(
|
|||
|
|
RoleArn=role_arn,
|
|||
|
|
RoleSessionName='Celery',
|
|||
|
|
DurationSeconds=token_expiry_seconds
|
|||
|
|
)
|
|||
|
|
return sts_policy['Credentials']
|
|||
|
|
|
|||
|
|
def asynsqs(self, queue=None):
|
|||
|
|
message_system_attribute_names = self.get_message_attributes.get(
|
|||
|
|
'MessageSystemAttributeNames')
|
|||
|
|
message_attribute_names = self.get_message_attributes.get(
|
|||
|
|
'MessageAttributeNames')
|
|||
|
|
|
|||
|
|
if queue is not None and self.predefined_queues:
|
|||
|
|
if queue in self._predefined_queue_async_clients and \
|
|||
|
|
not hasattr(self, 'sts_expiration'):
|
|||
|
|
return self._predefined_queue_async_clients[queue]
|
|||
|
|
if queue not in self.predefined_queues:
|
|||
|
|
raise UndefinedQueueException((
|
|||
|
|
"Queue with name '{}' must be defined in "
|
|||
|
|
"'predefined_queues'."
|
|||
|
|
).format(queue))
|
|||
|
|
q = self.predefined_queues[queue]
|
|||
|
|
c = self._predefined_queue_async_clients[queue] = \
|
|||
|
|
AsyncSQSConnection(
|
|||
|
|
sqs_connection=self.sqs(queue=queue),
|
|||
|
|
region=q.get('region', self.region),
|
|||
|
|
message_system_attribute_names=message_system_attribute_names,
|
|||
|
|
message_attribute_names=message_attribute_names
|
|||
|
|
)
|
|||
|
|
return c
|
|||
|
|
|
|||
|
|
if self._asynsqs is not None:
|
|||
|
|
return self._asynsqs
|
|||
|
|
|
|||
|
|
c = self._asynsqs = AsyncSQSConnection(
|
|||
|
|
sqs_connection=self.sqs(queue=queue),
|
|||
|
|
region=self.region,
|
|||
|
|
message_system_attribute_names=message_system_attribute_names,
|
|||
|
|
message_attribute_names=message_attribute_names
|
|||
|
|
)
|
|||
|
|
return c
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def conninfo(self):
|
|||
|
|
return self.connection.client
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def transport_options(self):
|
|||
|
|
return self.connection.client.transport_options
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def visibility_timeout(self):
|
|||
|
|
return (self.transport_options.get('visibility_timeout') or
|
|||
|
|
self.default_visibility_timeout)
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def predefined_queues(self):
|
|||
|
|
"""Map of queue_name to predefined queue settings."""
|
|||
|
|
return self.transport_options.get('predefined_queues', {})
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def queue_name_prefix(self):
|
|||
|
|
return self.transport_options.get('queue_name_prefix', '')
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def supports_fanout(self):
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def region(self):
|
|||
|
|
return (self.transport_options.get('region') or
|
|||
|
|
boto3.Session().region_name or
|
|||
|
|
self.default_region)
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def regioninfo(self):
|
|||
|
|
return self.transport_options.get('regioninfo')
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def is_secure(self):
|
|||
|
|
return self.transport_options.get('is_secure')
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def port(self):
|
|||
|
|
return self.transport_options.get('port')
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def endpoint_url(self):
|
|||
|
|
if self.conninfo.hostname is not None:
|
|||
|
|
scheme = 'https' if self.is_secure else 'http'
|
|||
|
|
if self.conninfo.port is not None:
|
|||
|
|
port = f':{self.conninfo.port}'
|
|||
|
|
else:
|
|||
|
|
port = ''
|
|||
|
|
return '{}://{}{}'.format(
|
|||
|
|
scheme,
|
|||
|
|
self.conninfo.hostname,
|
|||
|
|
port
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def wait_time_seconds(self) -> int:
|
|||
|
|
return self.transport_options.get('wait_time_seconds',
|
|||
|
|
self.default_wait_time_seconds)
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def sqs_base64_encoding(self):
|
|||
|
|
return self.transport_options.get('sqs_base64_encoding', True)
|
|||
|
|
|
|||
|
|
@cached_property
|
|||
|
|
def fetch_message_attributes(self):
|
|||
|
|
return self.transport_options.get('fetch_message_attributes', None)
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def get_message_attributes(self) -> dict[str, Any]:
|
|||
|
|
"""Get the message attributes to be fetched from SQS.
|
|||
|
|
|
|||
|
|
Ensures 'ApproximateReceiveCount' is included in system attributes if list is provided.
|
|||
|
|
- The number of retries is managed by SQS /
|
|||
|
|
(specifically by the ``ApproximateReceiveCount`` message attribute)
|
|||
|
|
- See: class QoS(virtual.QoS):
|
|||
|
|
(method) def extract_task_name_and_number_of_retries
|
|||
|
|
|
|||
|
|
:return: A dictionary with SQS message attribute fetch config.
|
|||
|
|
"""
|
|||
|
|
APPROXIMATE_RECEIVE_COUNT = 'ApproximateReceiveCount'
|
|||
|
|
fetch = self.fetch_message_attributes
|
|||
|
|
message_system_attrs = None
|
|||
|
|
message_attrs = None
|
|||
|
|
|
|||
|
|
if fetch is None or isinstance(fetch, str):
|
|||
|
|
return {
|
|||
|
|
'MessageAttributeNames': [],
|
|||
|
|
'MessageSystemAttributeNames': [APPROXIMATE_RECEIVE_COUNT],
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if isinstance(fetch, list):
|
|||
|
|
message_system_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in fetch] else (
|
|||
|
|
list(set(fetch + [APPROXIMATE_RECEIVE_COUNT]))
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
elif isinstance(fetch, dict):
|
|||
|
|
system = fetch.get('MessageSystemAttributeNames', [])
|
|||
|
|
attrs = fetch.get('MessageAttributeNames', None)
|
|||
|
|
|
|||
|
|
if isinstance(system, list):
|
|||
|
|
message_system_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in system] else (
|
|||
|
|
list(set(system + [APPROXIMATE_RECEIVE_COUNT]))
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if isinstance(attrs, list) and attrs:
|
|||
|
|
message_attrs = ['ALL'] if 'ALL'.lower() in [s.lower() for s in attrs] else (
|
|||
|
|
list(set(attrs))
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return {
|
|||
|
|
'MessageAttributeNames': sorted(message_attrs) if message_attrs else [],
|
|||
|
|
'MessageSystemAttributeNames': (
|
|||
|
|
sorted(message_system_attrs) if message_system_attrs else [APPROXIMATE_RECEIVE_COUNT]
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# —————————————————————————————————————————————————————————————
|
|||
|
|
# _message_to_python helper methods (extracted for testing/readability)
|
|||
|
|
# —————————————————————————————————————————————————————————————
|
|||
|
|
|
|||
|
|
def _optional_b64_decode(self, raw: bytes) -> bytes:
|
|||
|
|
"""Optionally decode a base64 encoded string.
|
|||
|
|
|
|||
|
|
:param raw: The raw bytes object to decode.
|
|||
|
|
:return: Bytes of the optionally decoded raw input.
|
|||
|
|
"""
|
|||
|
|
candidate = raw.strip()
|
|||
|
|
|
|||
|
|
if self.B64_REGEX.fullmatch(candidate) is None:
|
|||
|
|
return raw
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
decoded = base64.b64decode(candidate, validate=True)
|
|||
|
|
except (binascii.Error, ValueError):
|
|||
|
|
return raw
|
|||
|
|
|
|||
|
|
reencoded = base64.b64encode(decoded).rstrip(b'=')
|
|||
|
|
if reencoded != candidate.rstrip(b'='):
|
|||
|
|
return raw
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
decoded.decode('utf-8')
|
|||
|
|
except UnicodeDecodeError:
|
|||
|
|
return raw
|
|||
|
|
|
|||
|
|
return decoded
|
|||
|
|
|
|||
|
|
def _decode_python_message_body(self, raw_body):
|
|||
|
|
"""Decode the message body when needed.
|
|||
|
|
|
|||
|
|
raw_body: bytes or str
|
|||
|
|
returns: bytes (decoded Base64 if it looks like Base64, otherwise raw bytes)
|
|||
|
|
"""
|
|||
|
|
b = raw_body.encode() if isinstance(raw_body, str) else raw_body
|
|||
|
|
return self._optional_b64_decode(b)
|
|||
|
|
|
|||
|
|
def _prepare_json_payload(self, text):
|
|||
|
|
"""Try to JSON-decode text into a dict; on failure return {}."""
|
|||
|
|
try:
|
|||
|
|
data = loads(text)
|
|||
|
|
return data if isinstance(data, dict) else {}
|
|||
|
|
except (JSONDecodeError, TypeError):
|
|||
|
|
return {}
|
|||
|
|
|
|||
|
|
def _delete_message(self, queue_name, message):
|
|||
|
|
"""Move the message over to the new queue URL and delete it."""
|
|||
|
|
new_q = self._new_queue(queue_name)
|
|||
|
|
self.asynsqs(queue=queue_name).delete_message(
|
|||
|
|
new_q, message['ReceiptHandle']
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def _envelope_payload(self, payload, raw_text, message, q_url):
|
|||
|
|
"""Prepare the payload envelope.
|
|||
|
|
|
|||
|
|
Ensure we have a dict with 'body' and 'properties.delivery_info',
|
|||
|
|
then stamp on SQS-specific metadata.
|
|||
|
|
|
|||
|
|
:param payload: The payload as an object
|
|||
|
|
:param raw_text: Text that will be set as the payload body.
|
|||
|
|
:param message: A kombu Message.
|
|||
|
|
:param q_url: The SQS queue URL.
|
|||
|
|
|
|||
|
|
:return: Payload object.
|
|||
|
|
"""
|
|||
|
|
# if payload wasn’t already a Kombu JSON dict, wrap it
|
|||
|
|
if 'properties' not in payload:
|
|||
|
|
payload = {
|
|||
|
|
'body': raw_text,
|
|||
|
|
'properties': {'delivery_info': {}},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
props = payload.setdefault('properties', {})
|
|||
|
|
di = props.setdefault('delivery_info', {})
|
|||
|
|
|
|||
|
|
# add SQS metadata
|
|||
|
|
di.update({
|
|||
|
|
'sqs_message': message,
|
|||
|
|
'sqs_queue': q_url,
|
|||
|
|
})
|
|||
|
|
props['delivery_tag'] = message['ReceiptHandle']
|
|||
|
|
|
|||
|
|
return payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Transport(virtual.Transport):
|
|||
|
|
"""SQS Transport.
|
|||
|
|
|
|||
|
|
Additional queue attributes can be supplied to SQS during queue
|
|||
|
|
creation by passing an ``sqs-creation-attributes`` key in
|
|||
|
|
transport_options. ``sqs-creation-attributes`` must be a dict whose
|
|||
|
|
key-value pairs correspond with Attributes in the
|
|||
|
|
`CreateQueue SQS API`_.
|
|||
|
|
|
|||
|
|
For example, to have SQS queues created with server-side encryption
|
|||
|
|
enabled using the default Amazon Managed Customer Master Key, you
|
|||
|
|
can set ``KmsMasterKeyId`` Attribute. When the queue is initially
|
|||
|
|
created by Kombu, encryption will be enabled.
|
|||
|
|
|
|||
|
|
.. code-block:: python
|
|||
|
|
|
|||
|
|
from kombu.transport.SQS import Transport
|
|||
|
|
|
|||
|
|
transport = Transport(
|
|||
|
|
...,
|
|||
|
|
transport_options={
|
|||
|
|
'sqs-creation-attributes': {
|
|||
|
|
'KmsMasterKeyId': 'alias/aws/sqs',
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
.. _CreateQueue SQS API: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_CreateQueue.html#API_CreateQueue_RequestParameters
|
|||
|
|
|
|||
|
|
.. versionadded:: 5.6
|
|||
|
|
Queue tags can be applied to SQS queues during creation by passing an
|
|||
|
|
``queue_tags`` key in transport_options. ``queue_tags`` must be
|
|||
|
|
a dict of tag key-value pairs.
|
|||
|
|
|
|||
|
|
.. code-block:: python
|
|||
|
|
|
|||
|
|
from kombu.transport.SQS import Transport
|
|||
|
|
|
|||
|
|
transport = Transport(
|
|||
|
|
...,
|
|||
|
|
transport_options={
|
|||
|
|
'queue_tags': {
|
|||
|
|
'Environment': 'production',
|
|||
|
|
'Team': 'backend',
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
The ``ApproximateReceiveCount`` message attribute is fetched by this
|
|||
|
|
transport by default. Requested message attributes can be changed by
|
|||
|
|
setting ``fetch_message_attributes`` in the transport options.
|
|||
|
|
|
|||
|
|
.. code-block:: python
|
|||
|
|
|
|||
|
|
from kombu.transport.SQS import Transport
|
|||
|
|
|
|||
|
|
transport = Transport(
|
|||
|
|
...,
|
|||
|
|
transport_options={
|
|||
|
|
'fetch_message_attributes': ["All"], # Get all of the MessageSystemAttributeNames (formerly AttributeNames)
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
# Preferred - A dict specifying system and custom message attributes
|
|||
|
|
transport = Transport(
|
|||
|
|
...,
|
|||
|
|
transport_options={
|
|||
|
|
'fetch_message_attributes': {
|
|||
|
|
'MessageSystemAttributeNames': ["SenderId", "SentTimestamp"],
|
|||
|
|
'MessageAttributeNames': ['S3MessageBodyKey']
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
.. _Message Attributes: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-AttributeNames
|
|||
|
|
|
|||
|
|
""" # noqa: E501
|
|||
|
|
|
|||
|
|
Channel = Channel
|
|||
|
|
|
|||
|
|
polling_interval = 1
|
|||
|
|
wait_time_seconds = 0
|
|||
|
|
default_port = None
|
|||
|
|
connection_errors = (
|
|||
|
|
virtual.Transport.connection_errors +
|
|||
|
|
(exceptions.BotoCoreError, socket.error)
|
|||
|
|
)
|
|||
|
|
channel_errors = (
|
|||
|
|
virtual.Transport.channel_errors + (exceptions.BotoCoreError,)
|
|||
|
|
)
|
|||
|
|
driver_type = 'sqs'
|
|||
|
|
driver_name = 'sqs'
|
|||
|
|
|
|||
|
|
implements = virtual.Transport.implements.extend(
|
|||
|
|
asynchronous=True,
|
|||
|
|
exchange_type=frozenset(['direct']),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def default_connection_params(self):
|
|||
|
|
return {'port': self.default_port}
|