Skip to content

Commit 3ac97c9

Browse files
committed
Add cluster option to result backend
1 parent 204caa5 commit 3ac97c9

File tree

4 files changed

+92
-21
lines changed

4 files changed

+92
-21
lines changed

.github/workflows/test.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,35 @@ jobs:
4242
--health-retries=30
4343
ports:
4444
- 6379:6379
45+
redis-cluster1:
46+
image: bitnami/redis-cluster:6.2.5
47+
env:
48+
ALLOW_EMPTY_PASSWORD: "yes"
49+
REDIS_NODES: "localhost"
50+
ports:
51+
- 7000:6379
52+
redis-cluster2:
53+
image: bitnami/redis-cluster:6.2.5
54+
env:
55+
ALLOW_EMPTY_PASSWORD: "yes"
56+
REDIS_NODES: "localhost"
57+
ports:
58+
- 7001:6379
59+
redis-cluster3:
60+
image: bitnami/redis-cluster:6.2.5
61+
env:
62+
ALLOW_EMPTY_PASSWORD: "yes"
63+
REDIS_NODES: "localhost"
64+
ports:
65+
- 7002:6379
4566
strategy:
4667
matrix:
4768
py_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
4869
runs-on: "ubuntu-latest"
4970
steps:
5071
- uses: actions/checkout@v4
72+
- name: Set up Redis cluster
73+
run: redis-cli --cluster create localhost:7000 localhost:7001 localhost:7002 --cluster-replicas 0 --cluster-yes
5174
- name: Set up Python
5275
uses: actions/setup-python@v2
5376
with:

taskiq_redis/redis_backend.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pickle
2-
from typing import Dict, Optional, TypeVar, Union
2+
from typing import Dict, Optional, Type, TypeVar, Union
33

4-
from redis.asyncio import ConnectionPool, Redis
4+
from redis.asyncio import Redis, RedisCluster
55
from taskiq import AsyncResultBackend
66
from taskiq.abc.result_backend import TaskiqResult
77

@@ -23,11 +23,15 @@ def __init__(
2323
keep_results: bool = True,
2424
result_ex_time: Optional[int] = None,
2525
result_px_time: Optional[int] = None,
26+
*,
27+
redis_cls: Union[Type[Redis], Type[RedisCluster], None] = None,
2628
) -> None:
2729
"""
2830
Constructs a new result backend.
2931
3032
:param redis_url: url to redis.
33+
:param redis_cls: async redis class, should be either redis.asyncio.Redis
34+
or redis.asyncio.RedisCluster.
3135
:param keep_results: flag to not remove results from Redis after reading.
3236
:param result_ex_time: expire time in seconds for result.
3337
:param result_px_time: expire time in milliseconds for result.
@@ -37,7 +41,10 @@ def __init__(
3741
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
3842
and result_px_time are equal zero.
3943
"""
40-
self.redis_pool = ConnectionPool.from_url(redis_url)
44+
if redis_cls is None:
45+
redis_cls = Redis
46+
47+
self.redis = redis_cls.from_url(redis_url)
4148
self.keep_results = keep_results
4249
self.result_ex_time = result_ex_time
4350
self.result_px_time = result_px_time
@@ -58,11 +65,6 @@ def __init__(
5865
"Choose either result_ex_time or result_px_time.",
5966
)
6067

61-
async def shutdown(self) -> None:
62-
"""Closes redis connection."""
63-
await self.redis_pool.disconnect()
64-
await super().shutdown()
65-
6668
async def set_result(
6769
self,
6870
task_id: str,
@@ -86,8 +88,7 @@ async def set_result(
8688
elif self.result_px_time:
8789
redis_set_params["px"] = self.result_px_time
8890

89-
async with Redis(connection_pool=self.redis_pool) as redis:
90-
await redis.set(**redis_set_params) # type: ignore
91+
await self.redis.set(**redis_set_params) # type: ignore
9192

9293
async def is_result_ready(self, task_id: str) -> bool:
9394
"""
@@ -97,8 +98,7 @@ async def is_result_ready(self, task_id: str) -> bool:
9798
9899
:returns: True if the result is ready else False.
99100
"""
100-
async with Redis(connection_pool=self.redis_pool) as redis:
101-
return bool(await redis.exists(task_id))
101+
return bool(await self.redis.exists(task_id))
102102

103103
async def get_result(
104104
self,
@@ -113,15 +113,14 @@ async def get_result(
113113
:raises ResultIsMissingError: if there is no result when trying to get it.
114114
:return: task's return value.
115115
"""
116-
async with Redis(connection_pool=self.redis_pool) as redis:
117-
if self.keep_results:
118-
result_value = await redis.get(
119-
name=task_id,
120-
)
121-
else:
122-
result_value = await redis.getdel(
123-
name=task_id,
124-
)
116+
if self.keep_results:
117+
result_value = await self.redis.get(
118+
name=task_id,
119+
)
120+
else:
121+
result_value = await self.redis.getdel(
122+
name=task_id,
123+
)
125124

126125
if result_value is None:
127126
raise ResultIsMissingError

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,17 @@ def redis_url() -> str:
2626
:return: URL string.
2727
"""
2828
return os.environ.get("TEST_REDIS_URL", "redis://localhost")
29+
30+
31+
@pytest.fixture
32+
def redis_cluster_url() -> str:
33+
"""
34+
URL to connect to redis cluster.
35+
36+
It tries to get it from environ,
37+
and return default one if the variable is
38+
not set.
39+
40+
:return: URL string.
41+
"""
42+
return os.environ.get("TEST_REDIS_CLUSTER_URL", "redis://localhost:7000")

tests/test_result_backend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import uuid
22

33
import pytest
4+
from redis.asyncio import RedisCluster
45
from taskiq import TaskiqResult
56

67
from taskiq_redis import RedisAsyncResultBackend
@@ -130,3 +131,37 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
130131
res2 = await result_backend.get_result(task_id=task_id)
131132
assert res1 == res2
132133
await result_backend.shutdown()
134+
135+
136+
@pytest.mark.anyio
137+
async def test_set_result_success_cluster(redis_cluster_url: str) -> None:
138+
"""
139+
Tests that results can be set without errors in cluster mode.
140+
141+
:param redis_url: redis URL.
142+
"""
143+
result_backend = RedisAsyncResultBackend( # type: ignore
144+
redis_url=redis_cluster_url,
145+
redis_cls=RedisCluster,
146+
)
147+
task_id = uuid.uuid4().hex
148+
result: "TaskiqResult[int]" = TaskiqResult(
149+
is_err=True,
150+
log="My Log",
151+
return_value=11,
152+
execution_time=112.2,
153+
)
154+
await result_backend.set_result(
155+
task_id=task_id,
156+
result=result,
157+
)
158+
159+
fetched_result = await result_backend.get_result(
160+
task_id=task_id,
161+
with_logs=True,
162+
)
163+
assert fetched_result.log == "My Log"
164+
assert fetched_result.return_value == 11
165+
assert fetched_result.execution_time == 112.2
166+
assert fetched_result.is_err
167+
await result_backend.shutdown()

0 commit comments

Comments
 (0)