Skip to content

Commit 180ae76

Browse files
authored
Merge pull request #689 from ydb-platform/query_session_pool_async_methods
Add async methods to QuerySessionPool
2 parents 456cd74 + 84015ee commit 180ae76

File tree

3 files changed

+132
-2
lines changed

3 files changed

+132
-2
lines changed

tests/query/test_query_session_pool.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22
import ydb
3+
import time
4+
from concurrent import futures
35

46
from typing import Optional
57

@@ -132,7 +134,7 @@ def test_pool_recreates_bad_sessions(self, pool: QuerySessionPool):
132134

133135
def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPool):
134136
pool.stop()
135-
with pytest.raises(RuntimeError):
137+
with pytest.raises(ydb.SessionPoolClosed):
136138
pool.acquire(1)
137139

138140
def test_no_session_leak(self, driver_sync, docker_project):
@@ -146,3 +148,55 @@ def test_no_session_leak(self, driver_sync, docker_project):
146148

147149
docker_project.start()
148150
pool.stop()
151+
152+
def test_execute_with_retries_async(self, pool: QuerySessionPool):
153+
fut = pool.execute_with_retries_async("select 1;")
154+
res = fut.result()
155+
assert len(res) == 1
156+
157+
def test_retry_operation_async(self, pool: QuerySessionPool):
158+
def callee(session: QuerySession):
159+
with session.transaction() as tx:
160+
iterator = tx.execute("select 1;", commit_tx=True)
161+
return [result_set for result_set in iterator]
162+
163+
fut = pool.retry_operation_async(callee)
164+
res = fut.result()
165+
assert len(res) == 1
166+
167+
def test_retry_tx_async(self, pool: QuerySessionPool):
168+
retry_no = 0
169+
170+
def callee(tx: QueryTxContext):
171+
nonlocal retry_no
172+
if retry_no < 2:
173+
retry_no += 1
174+
raise ydb.Unavailable("Fake fast backoff error")
175+
result_stream = tx.execute("SELECT 1")
176+
return [result_set for result_set in result_stream]
177+
178+
result = pool.retry_tx_async(callee=callee).result()
179+
assert len(result) == 1
180+
assert retry_no == 2
181+
182+
def test_execute_with_retries_async_many_calls(self, pool: QuerySessionPool):
183+
futs = [pool.execute_with_retries_async("select 1;") for _ in range(10)]
184+
results = [f.result() for f in futures.as_completed(futs)]
185+
assert all(len(r) == 1 for r in results)
186+
187+
def test_future_waits_on_stop(self, pool: QuerySessionPool):
188+
def callee(session: QuerySession):
189+
time.sleep(0.1)
190+
with session.transaction() as tx:
191+
it = tx.execute("select 1;", commit_tx=True)
192+
return [rs for rs in it]
193+
194+
fut = pool.retry_operation_async(callee)
195+
pool.stop()
196+
assert fut.done()
197+
assert len(fut.result()) == 1
198+
199+
def test_async_methods_after_stop_raise(self, pool: QuerySessionPool):
200+
pool.stop()
201+
with pytest.raises(ydb.SessionPoolClosed):
202+
pool.execute_with_retries_async("select 1;")

ydb/issues.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class StatusCode(enum.IntEnum):
5151

5252
UNAUTHENTICATED = _CLIENT_STATUSES_FIRST + 30
5353
SESSION_POOL_EMPTY = _CLIENT_STATUSES_FIRST + 40
54+
SESSION_POOL_CLOSED = _CLIENT_STATUSES_FIRST + 50
5455

5556

5657
# TODO: convert from proto IssueMessage
@@ -179,6 +180,13 @@ class SessionPoolEmpty(Error, queue.Empty):
179180
status = StatusCode.SESSION_POOL_EMPTY
180181

181182

183+
class SessionPoolClosed(Error):
184+
status = StatusCode.SESSION_POOL_CLOSED
185+
186+
def __init__(self):
187+
super().__init__("Session pool is closed.")
188+
189+
182190
class ClientInternalError(Error):
183191
status = StatusCode.CLIENT_INTERNAL_ERROR
184192

ydb/query/pool.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from concurrent import futures
23
from typing import (
34
Callable,
45
Optional,
@@ -36,14 +37,17 @@ def __init__(
3637
size: int = 100,
3738
*,
3839
query_client_settings: Optional[QueryClientSettings] = None,
40+
workers_threads_count: int = 4,
3941
):
4042
"""
4143
:param driver: A driver instance.
4244
:param size: Max size of Session Pool.
4345
:param query_client_settings: ydb.QueryClientSettings object to configure QueryService behavior
46+
:param workers_threads_count: A number of threads in executor used for *_async methods
4447
"""
4548

4649
self._driver = driver
50+
self._tp = futures.ThreadPoolExecutor(workers_threads_count)
4751
self._queue = queue.Queue()
4852
self._current_size = 0
4953
self._size = size
@@ -72,7 +76,7 @@ def acquire(self, timeout: Optional[float] = None) -> QuerySession:
7276
try:
7377
if self._should_stop.is_set():
7478
logger.error("An attempt to take session from closed session pool.")
75-
raise RuntimeError("An attempt to take session from closed session pool.")
79+
raise issues.SessionPoolClosed()
7680

7781
session = None
7882
try:
@@ -132,6 +136,9 @@ def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetryS
132136
:return: Result sets or exception in case of execution errors.
133137
"""
134138

139+
if self._should_stop.is_set():
140+
raise issues.SessionPoolClosed()
141+
135142
retry_settings = RetrySettings() if retry_settings is None else retry_settings
136143

137144
def wrapped_callee():
@@ -140,6 +147,38 @@ def wrapped_callee():
140147

141148
return retry_operation_sync(wrapped_callee, retry_settings)
142149

150+
def retry_tx_async(
151+
self,
152+
callee: Callable,
153+
tx_mode: Optional[BaseQueryTxMode] = None,
154+
retry_settings: Optional[RetrySettings] = None,
155+
*args,
156+
**kwargs,
157+
) -> futures.Future:
158+
"""Asynchronously execute a transaction in a retriable way."""
159+
160+
if self._should_stop.is_set():
161+
raise issues.SessionPoolClosed()
162+
163+
return self._tp.submit(
164+
self.retry_tx_sync,
165+
callee,
166+
tx_mode,
167+
retry_settings,
168+
*args,
169+
**kwargs,
170+
)
171+
172+
def retry_operation_async(
173+
self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs
174+
) -> futures.Future:
175+
"""Asynchronously execute a retryable operation."""
176+
177+
if self._should_stop.is_set():
178+
raise issues.SessionPoolClosed()
179+
180+
return self._tp.submit(self.retry_operation_sync, callee, retry_settings, *args, **kwargs)
181+
143182
def retry_tx_sync(
144183
self,
145184
callee: Callable,
@@ -161,6 +200,9 @@ def retry_tx_sync(
161200
:return: Result sets or exception in case of execution errors.
162201
"""
163202

203+
if self._should_stop.is_set():
204+
raise issues.SessionPoolClosed()
205+
164206
tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite()
165207
retry_settings = RetrySettings() if retry_settings is None else retry_settings
166208

@@ -194,6 +236,9 @@ def execute_with_retries(
194236
:return: Result sets or exception in case of execution errors.
195237
"""
196238

239+
if self._should_stop.is_set():
240+
raise issues.SessionPoolClosed()
241+
197242
retry_settings = RetrySettings() if retry_settings is None else retry_settings
198243

199244
def wrapped_callee():
@@ -203,11 +248,34 @@ def wrapped_callee():
203248

204249
return retry_operation_sync(wrapped_callee, retry_settings)
205250

251+
def execute_with_retries_async(
252+
self,
253+
query: str,
254+
parameters: Optional[dict] = None,
255+
retry_settings: Optional[RetrySettings] = None,
256+
*args,
257+
**kwargs,
258+
) -> futures.Future:
259+
"""Asynchronously execute a query with retries."""
260+
261+
if self._should_stop.is_set():
262+
raise issues.SessionPoolClosed()
263+
264+
return self._tp.submit(
265+
self.execute_with_retries,
266+
query,
267+
parameters,
268+
retry_settings,
269+
*args,
270+
**kwargs,
271+
)
272+
206273
def stop(self, timeout=None):
207274
acquire_timeout = timeout if timeout is not None else -1
208275
acquired = self._lock.acquire(timeout=acquire_timeout)
209276
try:
210277
self._should_stop.set()
278+
self._tp.shutdown(wait=True)
211279
while True:
212280
try:
213281
session = self._queue.get_nowait()

0 commit comments

Comments
 (0)