Skip to content

Commit f6ea58b

Browse files
GWealecopybara-github
authored andcommitted
fix: Add read-only session support in DatabaseSessionService
This change introduces a separate async session factory for Spanner connections configured with `read_only=True` Close #4771 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 886267104
1 parent fdc2b43 commit f6ea58b

File tree

2 files changed

+151
-5
lines changed

2 files changed

+151
-5
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import logging
2222
from typing import Any
2323
from typing import AsyncIterator
24-
from typing import Optional
2524
from typing import TypeAlias
2625
from typing import TypeVar
2726

@@ -178,11 +177,14 @@ def __init__(self, db_url: str, **kwargs: Any):
178177
) from e
179178

180179
self.db_engine: AsyncEngine = db_engine
181-
182180
# DB session factory method
183181
self.database_session_factory: async_sessionmaker[
184182
DatabaseSessionFactory
185183
] = async_sessionmaker(bind=self.db_engine, expire_on_commit=False)
184+
read_only_engine = self.db_engine.execution_options(read_only=True)
185+
self._read_only_database_session_factory: async_sessionmaker[
186+
DatabaseSessionFactory
187+
] = async_sessionmaker(bind=read_only_engine, expire_on_commit=False)
186188

187189
# Flag to indicate if tables are created
188190
self._tables_created = False
@@ -201,17 +203,27 @@ def __init__(self, db_url: str, **kwargs: Any):
201203
def _get_schema_classes(self) -> _SchemaClasses:
202204
return _SchemaClasses(self._db_schema_version)
203205

206+
def _get_database_session_factory(
207+
self, *, read_only: bool = False
208+
) -> async_sessionmaker[DatabaseSessionFactory]:
209+
if read_only:
210+
return self._read_only_database_session_factory
211+
return self.database_session_factory
212+
204213
@asynccontextmanager
205214
async def _rollback_on_exception_session(
206215
self,
216+
*,
217+
read_only: bool = False,
207218
) -> AsyncIterator[DatabaseSessionFactory]:
208219
"""Yields a database session with guaranteed rollback on errors.
209220
210221
On normal exit the caller is responsible for committing; on any exception
211222
the transaction is explicitly rolled back before the error propagates,
212223
preventing connection-pool exhaustion from lingering invalid transactions.
213224
"""
214-
async with self.database_session_factory() as sql_session:
225+
session_factory = self._get_database_session_factory(read_only=read_only)
226+
async with session_factory() as sql_session:
215227
try:
216228
yield sql_session
217229
except BaseException:
@@ -441,7 +453,9 @@ async def get_session(
441453
# 2. Get all the events based on session id and filtering config
442454
# 3. Convert and return the session
443455
schema = self._get_schema_classes()
444-
async with self._rollback_on_exception_session() as sql_session:
456+
async with self._rollback_on_exception_session(
457+
read_only=True
458+
) as sql_session:
445459
storage_session = await sql_session.get(
446460
schema.StorageSession, (app_name, user_id, session_id)
447461
)
@@ -496,7 +510,9 @@ async def list_sessions(
496510
) -> ListSessionsResponse:
497511
await self._prepare_tables()
498512
schema = self._get_schema_classes()
499-
async with self._rollback_on_exception_session() as sql_session:
513+
async with self._rollback_on_exception_session(
514+
read_only=True
515+
) as sql_session:
500516
stmt = select(schema.StorageSession).filter(
501517
schema.StorageSession.app_name == app_name
502518
)

tests/unittests/sessions/test_session_service.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
from contextlib import asynccontextmanager
1617
from datetime import datetime
1718
from datetime import timezone
1819
import enum
@@ -151,6 +152,74 @@ def fake_create_async_engine(_db_url: str, **kwargs):
151152
assert captured_kwargs.get('pool_pre_ping') is False
152153

153154

155+
def test_database_session_service_creates_read_only_engine_for_spanner():
156+
captured_binds = []
157+
fake_engine = mock.Mock()
158+
fake_engine.dialect.name = 'spanner+spanner'
159+
fake_engine.sync_engine = mock.Mock()
160+
read_only_engine = mock.Mock()
161+
fake_engine.execution_options.return_value = read_only_engine
162+
163+
def fake_async_sessionmaker(*, bind, expire_on_commit, **kwargs):
164+
del expire_on_commit
165+
del kwargs
166+
captured_binds.append(bind)
167+
return mock.Mock()
168+
169+
with (
170+
mock.patch.object(
171+
database_session_service,
172+
'create_async_engine',
173+
return_value=fake_engine,
174+
),
175+
mock.patch.object(
176+
database_session_service,
177+
'async_sessionmaker',
178+
side_effect=fake_async_sessionmaker,
179+
),
180+
):
181+
database_session_service.DatabaseSessionService(
182+
'spanner+spanner:///projects/test/instances/test/databases/test'
183+
)
184+
185+
assert captured_binds == [fake_engine, read_only_engine]
186+
fake_engine.execution_options.assert_called_once_with(read_only=True)
187+
188+
189+
def test_database_session_service_creates_read_only_engine_for_other_dialects():
190+
captured_binds = []
191+
fake_engine = mock.Mock()
192+
fake_engine.dialect.name = 'postgresql'
193+
fake_engine.sync_engine = mock.Mock()
194+
read_only_engine = mock.Mock()
195+
fake_engine.execution_options.return_value = read_only_engine
196+
197+
def fake_async_sessionmaker(*, bind, expire_on_commit, **kwargs):
198+
del expire_on_commit
199+
del kwargs
200+
captured_binds.append(bind)
201+
return mock.Mock()
202+
203+
with (
204+
mock.patch.object(
205+
database_session_service,
206+
'create_async_engine',
207+
return_value=fake_engine,
208+
),
209+
mock.patch.object(
210+
database_session_service,
211+
'async_sessionmaker',
212+
side_effect=fake_async_sessionmaker,
213+
),
214+
):
215+
database_session_service.DatabaseSessionService(
216+
'postgresql+psycopg2://user:pass@localhost:5432/db'
217+
)
218+
219+
assert captured_binds == [fake_engine, read_only_engine]
220+
fake_engine.execution_options.assert_called_once_with(read_only=True)
221+
222+
154223
@pytest.mark.asyncio
155224
async def test_sqlite_session_service_accepts_sqlite_urls(
156225
tmp_path, monkeypatch
@@ -198,6 +267,67 @@ async def test_get_empty_session(session_service):
198267
)
199268

200269

270+
@pytest.mark.asyncio
271+
async def test_database_session_service_get_session_uses_read_only_factory():
272+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
273+
service._prepare_tables = mock.AsyncMock()
274+
275+
read_only_session = mock.AsyncMock()
276+
read_only_session.get = mock.AsyncMock(return_value=None)
277+
278+
@asynccontextmanager
279+
async def fake_read_only_session():
280+
yield read_only_session
281+
282+
service.database_session_factory = mock.Mock(
283+
side_effect=AssertionError('write session factory should not be used')
284+
)
285+
service._read_only_database_session_factory = mock.Mock(
286+
return_value=fake_read_only_session()
287+
)
288+
289+
session = await service.get_session(
290+
app_name='my_app', user_id='test_user', session_id='123'
291+
)
292+
293+
assert session is None
294+
service._read_only_database_session_factory.assert_called_once_with()
295+
service.database_session_factory.assert_not_called()
296+
297+
await service.close()
298+
299+
300+
@pytest.mark.asyncio
301+
async def test_database_session_service_list_sessions_uses_read_only_factory():
302+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
303+
service._prepare_tables = mock.AsyncMock()
304+
305+
read_only_session = mock.AsyncMock()
306+
empty_result = mock.Mock()
307+
empty_result.scalars.return_value.all.return_value = []
308+
read_only_session.execute = mock.AsyncMock(return_value=empty_result)
309+
read_only_session.get = mock.AsyncMock(return_value=None)
310+
311+
@asynccontextmanager
312+
async def fake_read_only_session():
313+
yield read_only_session
314+
315+
service.database_session_factory = mock.Mock(
316+
side_effect=AssertionError('write session factory should not be used')
317+
)
318+
service._read_only_database_session_factory = mock.Mock(
319+
return_value=fake_read_only_session()
320+
)
321+
322+
response = await service.list_sessions(app_name='my_app', user_id='test_user')
323+
324+
assert response.sessions == []
325+
service._read_only_database_session_factory.assert_called_once_with()
326+
service.database_session_factory.assert_not_called()
327+
328+
await service.close()
329+
330+
201331
@pytest.mark.asyncio
202332
async def test_create_get_session(session_service):
203333
app_name = 'my_app'

0 commit comments

Comments
 (0)