Skip to content

Commit 0634be4

Browse files
committed
Add cluster option to result backend
1 parent 204caa5 commit 0634be4

File tree

5 files changed

+279
-3
lines changed

5 files changed

+279
-3
lines changed

.github/workflows/test.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,30 @@ jobs:
4242
--health-retries=30
4343
ports:
4444
- 6379:6379
45+
redis-cluster:
46+
image: bitnami/redis-cluster:6.2.5
47+
env:
48+
ALLOW_EMPTY_PASSWORD: "yes"
49+
REDIS_NODES: "localhost"
50+
options: >-
51+
--health-cmd="redis-cli ping"
52+
--health-interval=5s
53+
--health-timeout=5s
54+
--health-retries=30
55+
ports:
56+
- 7000:6379
4557
strategy:
4658
matrix:
4759
py_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
4860
runs-on: "ubuntu-latest"
4961
steps:
5062
- uses: actions/checkout@v4
63+
- uses: shogo82148/actions-setup-redis@v1
64+
with:
65+
redis-version: "6.x"
66+
auto-start: false
67+
- name: Set up single-node Redis cluster
68+
run: redis-cli -h localhost -p 7000 --cluster-yes CLUSTER ADDSLOTSRANGE 0 16383
5169
- name: Set up Python
5270
uses: actions/setup-python@v2
5371
with:

taskiq_redis/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Package for redis integration."""
2-
from taskiq_redis.redis_backend import RedisAsyncResultBackend
2+
from taskiq_redis.redis_backend import (
3+
RedisAsyncClusterResultBackend,
4+
RedisAsyncResultBackend,
5+
)
36
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
47
from taskiq_redis.schedule_source import RedisScheduleSource
58

69
__all__ = [
10+
"RedisAsyncClusterResultBackend",
711
"RedisAsyncResultBackend",
812
"ListQueueBroker",
913
"PubSubBroker",

taskiq_redis/redis_backend.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pickle
22
from typing import Dict, Optional, TypeVar, Union
33

4-
from redis.asyncio import ConnectionPool, Redis
4+
from redis.asyncio import ConnectionPool, Redis, RedisCluster
55
from taskiq import AsyncResultBackend
66
from taskiq.abc.result_backend import TaskiqResult
77

@@ -134,3 +134,117 @@ async def get_result(
134134
taskiq_result.log = None
135135

136136
return taskiq_result
137+
138+
139+
class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
140+
"""Async result backend based on redis cluster."""
141+
142+
def __init__(
143+
self,
144+
redis_url: str,
145+
keep_results: bool = True,
146+
result_ex_time: Optional[int] = None,
147+
result_px_time: Optional[int] = None,
148+
) -> None:
149+
"""
150+
Constructs a new result backend.
151+
152+
:param redis_url: url to redis cluster.
153+
:param keep_results: flag to not remove results from Redis after reading.
154+
:param result_ex_time: expire time in seconds for result.
155+
:param result_px_time: expire time in milliseconds for result.
156+
157+
:raises DuplicateExpireTimeSelectedError: if result_ex_time
158+
and result_px_time are selected.
159+
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
160+
and result_px_time are equal zero.
161+
"""
162+
self.redis = RedisCluster.from_url(redis_url)
163+
self.keep_results = keep_results
164+
self.result_ex_time = result_ex_time
165+
self.result_px_time = result_px_time
166+
167+
unavailable_conditions = any(
168+
(
169+
self.result_ex_time is not None and self.result_ex_time <= 0,
170+
self.result_px_time is not None and self.result_px_time <= 0,
171+
),
172+
)
173+
if unavailable_conditions:
174+
raise ExpireTimeMustBeMoreThanZeroError(
175+
"You must select one expire time param and it must be more than zero.",
176+
)
177+
178+
if self.result_ex_time and self.result_px_time:
179+
raise DuplicateExpireTimeSelectedError(
180+
"Choose either result_ex_time or result_px_time.",
181+
)
182+
183+
async def set_result(
184+
self,
185+
task_id: str,
186+
result: TaskiqResult[_ReturnType],
187+
) -> None:
188+
"""
189+
Sets task result in redis.
190+
191+
Dumps TaskiqResult instance into the bytes and writes
192+
it to redis.
193+
194+
:param task_id: ID of the task.
195+
:param result: TaskiqResult instance.
196+
"""
197+
redis_set_params: Dict[str, Union[str, bytes, int]] = {
198+
"name": task_id,
199+
"value": pickle.dumps(result),
200+
}
201+
if self.result_ex_time:
202+
redis_set_params["ex"] = self.result_ex_time
203+
elif self.result_px_time:
204+
redis_set_params["px"] = self.result_px_time
205+
206+
await self.redis.set(**redis_set_params) # type: ignore
207+
208+
async def is_result_ready(self, task_id: str) -> bool:
209+
"""
210+
Returns whether the result is ready.
211+
212+
:param task_id: ID of the task.
213+
214+
:returns: True if the result is ready else False.
215+
"""
216+
return bool(await self.redis.exists(task_id))
217+
218+
async def get_result(
219+
self,
220+
task_id: str,
221+
with_logs: bool = False,
222+
) -> TaskiqResult[_ReturnType]:
223+
"""
224+
Gets result from the task.
225+
226+
:param task_id: task's id.
227+
:param with_logs: if True it will download task's logs.
228+
:raises ResultIsMissingError: if there is no result when trying to get it.
229+
:return: task's return value.
230+
"""
231+
if self.keep_results:
232+
result_value = await self.redis.get(
233+
name=task_id,
234+
)
235+
else:
236+
result_value = await self.redis.getdel(
237+
name=task_id,
238+
)
239+
240+
if result_value is None:
241+
raise ResultIsMissingError
242+
243+
taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
244+
result_value,
245+
)
246+
247+
if not with_logs:
248+
taskiq_result.log = None
249+
250+
return taskiq_result

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,17 @@ def redis_url() -> str:
2626
:return: URL string.
2727
"""
2828
return os.environ.get("TEST_REDIS_URL", "redis://localhost")
29+
30+
31+
@pytest.fixture
32+
def redis_cluster_url() -> str:
33+
"""
34+
URL to connect to redis cluster.
35+
36+
It tries to get it from environ,
37+
and return default one if the variable is
38+
not set.
39+
40+
:return: URL string.
41+
"""
42+
return os.environ.get("TEST_REDIS_CLUSTER_URL", "redis://localhost:7000")

tests/test_result_backend.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import pytest
44
from taskiq import TaskiqResult
55

6-
from taskiq_redis import RedisAsyncResultBackend
6+
from taskiq_redis import (RedisAsyncClusterResultBackend,
7+
RedisAsyncResultBackend)
78
from taskiq_redis.exceptions import ResultIsMissingError
89

910

@@ -130,3 +131,128 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
130131
res2 = await result_backend.get_result(task_id=task_id)
131132
assert res1 == res2
132133
await result_backend.shutdown()
134+
135+
136+
@pytest.mark.anyio
137+
async def test_set_result_success_cluster(redis_cluster_url: str) -> None:
138+
"""
139+
Tests that results can be set without errors in cluster mode.
140+
141+
:param redis_url: redis URL.
142+
"""
143+
result_backend = RedisAsyncClusterResultBackend( # type: ignore
144+
redis_url=redis_cluster_url,
145+
)
146+
task_id = uuid.uuid4().hex
147+
result: "TaskiqResult[int]" = TaskiqResult(
148+
is_err=True,
149+
log="My Log",
150+
return_value=11,
151+
execution_time=112.2,
152+
)
153+
await result_backend.set_result(
154+
task_id=task_id,
155+
result=result,
156+
)
157+
158+
fetched_result = await result_backend.get_result(
159+
task_id=task_id,
160+
with_logs=True,
161+
)
162+
assert fetched_result.log == "My Log"
163+
assert fetched_result.return_value == 11
164+
assert fetched_result.execution_time == 112.2
165+
assert fetched_result.is_err
166+
await result_backend.shutdown()
167+
168+
169+
@pytest.mark.anyio
170+
async def test_fetch_without_logs_cluster(redis_cluster_url: str) -> None:
171+
"""
172+
Check if fetching value without logs works fine.
173+
174+
:param redis_url: redis URL.
175+
"""
176+
result_backend = RedisAsyncClusterResultBackend( # type: ignore
177+
redis_url=redis_cluster_url,
178+
)
179+
task_id = uuid.uuid4().hex
180+
result: "TaskiqResult[int]" = TaskiqResult(
181+
is_err=True,
182+
log="My Log",
183+
return_value=11,
184+
execution_time=112.2,
185+
)
186+
await result_backend.set_result(
187+
task_id=task_id,
188+
result=result,
189+
)
190+
191+
fetched_result = await result_backend.get_result(
192+
task_id=task_id,
193+
with_logs=False,
194+
)
195+
assert fetched_result.log is None
196+
assert fetched_result.return_value == 11
197+
assert fetched_result.execution_time == 112.2
198+
assert fetched_result.is_err
199+
await result_backend.shutdown()
200+
201+
202+
@pytest.mark.anyio
203+
async def test_remove_results_after_reading_cluster(redis_cluster_url: str) -> None:
204+
"""
205+
Check if removing results after reading works fine.
206+
207+
:param redis_url: redis URL.
208+
"""
209+
result_backend = RedisAsyncClusterResultBackend( # type: ignore
210+
redis_url=redis_cluster_url,
211+
keep_results=False,
212+
)
213+
task_id = uuid.uuid4().hex
214+
result: "TaskiqResult[int]" = TaskiqResult(
215+
is_err=True,
216+
log="My Log",
217+
return_value=11,
218+
execution_time=112.2,
219+
)
220+
await result_backend.set_result(
221+
task_id=task_id,
222+
result=result,
223+
)
224+
225+
await result_backend.get_result(task_id=task_id)
226+
with pytest.raises(ResultIsMissingError):
227+
await result_backend.get_result(task_id=task_id)
228+
229+
await result_backend.shutdown()
230+
231+
232+
@pytest.mark.anyio
233+
async def test_keep_results_after_reading_cluster(redis_cluster_url: str) -> None:
234+
"""
235+
Check if keeping results after reading works fine.
236+
237+
:param redis_url: redis URL.
238+
"""
239+
result_backend = RedisAsyncClusterResultBackend( # type: ignore
240+
redis_url=redis_cluster_url,
241+
keep_results=True,
242+
)
243+
task_id = uuid.uuid4().hex
244+
result: "TaskiqResult[int]" = TaskiqResult(
245+
is_err=True,
246+
log="My Log",
247+
return_value=11,
248+
execution_time=112.2,
249+
)
250+
await result_backend.set_result(
251+
task_id=task_id,
252+
result=result,
253+
)
254+
255+
res1 = await result_backend.get_result(task_id=task_id)
256+
res2 = await result_backend.get_result(task_id=task_id)
257+
assert res1 == res2
258+
await result_backend.shutdown()

0 commit comments

Comments
 (0)