Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions taskiq_redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
RedisAsyncResultBackend,
)
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
from taskiq_redis.schedule_source import RedisScheduleSource

# Public API of the taskiq_redis package.
__all__ = [
    "RedisAsyncClusterResultBackend",
    "RedisAsyncResultBackend",
    "ListQueueBroker",
    "PubSubBroker",
    "ListQueueClusterBroker",
    "RedisScheduleSource",
]
78 changes: 78 additions & 0 deletions taskiq_redis/redis_cluster_broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar

from redis.asyncio import RedisCluster
from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.message import BrokerMessage

_T = TypeVar("_T")


class BaseRedisClusterBroker(AsyncBroker):
    """Base broker that works with Redis Cluster."""

    def __init__(
        self,
        url: str,
        task_id_generator: Optional[Callable[[], str]] = None,
        result_backend: Optional[AsyncResultBackend[_T]] = None,
        queue_name: str = "taskiq",
        max_connection_pool_size: int = 2**31,
        **connection_kwargs: Any,
    ) -> None:
        """
        Construct a new broker backed by a Redis Cluster client.

        :param url: url to redis.
        :param task_id_generator: custom task_id generator.
        :param result_backend: custom result backend.
        :param queue_name: name for a list in redis.
        :param max_connection_pool_size: maximum number of connections in pool.
        :param connection_kwargs: additional arguments for aio-redis ConnectionPool.
        """
        super().__init__(
            task_id_generator=task_id_generator,
            result_backend=result_backend,
        )

        self.queue_name = queue_name
        # One cluster-aware client is shared by all broker operations.
        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
            url=url,
            max_connections=max_connection_pool_size,
            **connection_kwargs,
        )

    async def shutdown(self) -> None:
        """Close the redis connection pool, then run the base shutdown."""
        # aclose() is missing from the redis-py type stubs, hence the ignore.
        await self.redis.aclose()  # type: ignore[attr-defined]
        await super().shutdown()


class ListQueueClusterBroker(BaseRedisClusterBroker):
    """Broker that works with Redis Cluster and distributes tasks between workers."""

    async def kick(self, message: BrokerMessage) -> None:
        """
        Put a message in a list.

        The raw payload is pushed onto the head of the Redis list; workers
        consume from the tail, so each message reaches exactly one worker.

        :param message: message to append.
        """
        await self.redis.lpush(self.queue_name, message.message)

    async def listen(self) -> AsyncGenerator[bytes, None]:
        """
        Listen redis queue for new messages.

        Blocks on BRPOP and yields each payload as it arrives.

        :yields: broker messages.
        """
        while True:
            # brpop returns a (queue_name, payload) pair; only the
            # payload is of interest here.
            _, payload = await self.redis.brpop([self.queue_name])
            yield payload
33 changes: 32 additions & 1 deletion tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from taskiq import AckableMessage, AsyncBroker, BrokerMessage

from taskiq_redis import ListQueueBroker, PubSubBroker
from taskiq_redis import ListQueueBroker, PubSubBroker, ListQueueClusterBroker


def test_no_url_should_raise_typeerror() -> None:
Expand Down Expand Up @@ -96,3 +96,34 @@ async def test_list_queue_broker(
worker1_task.cancel()
worker2_task.cancel()
await broker.shutdown()


@pytest.mark.anyio
async def test_list_queue_cluster_broker(
    valid_broker_message: BrokerMessage,
    redis_cluster_url: str,
) -> None:
    """
    Test that messages are published and read correctly by ListQueueClusterBroker.

    We create two workers that listen and send a message to them.
    Expect only one worker to receive the same message we sent.
    """
    # Fresh random queue name isolates this test from concurrent runs.
    broker = ListQueueClusterBroker(
        url=redis_cluster_url,
        queue_name=uuid.uuid4().hex,
    )
    worker1_task = asyncio.create_task(get_message(broker))
    worker2_task = asyncio.create_task(get_message(broker))
    # Give both workers time to start blocking on the queue.
    await asyncio.sleep(0.3)

    await broker.kick(valid_broker_message)
    await asyncio.sleep(0.3)

    # Exactly one of the two workers must have received the message.
    assert worker1_task.done() != worker2_task.done()
    message = worker1_task.result() if worker1_task.done() else worker2_task.result()
    assert message == valid_broker_message.message
    worker1_task.cancel()
    worker2_task.cancel()
    await broker.shutdown()