|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import asyncio |
| 16 | +from contextlib import asynccontextmanager |
16 | 17 | from datetime import datetime |
17 | 18 | from datetime import timezone |
18 | 19 | import enum |
@@ -151,6 +152,74 @@ def fake_create_async_engine(_db_url: str, **kwargs): |
151 | 152 | assert captured_kwargs.get('pool_pre_ping') is False |
152 | 153 |
|
153 | 154 |
|
| 155 | +def test_database_session_service_creates_read_only_engine_for_spanner(): |
| 156 | + captured_binds = [] |
| 157 | + fake_engine = mock.Mock() |
| 158 | + fake_engine.dialect.name = 'spanner+spanner' |
| 159 | + fake_engine.sync_engine = mock.Mock() |
| 160 | + read_only_engine = mock.Mock() |
| 161 | + fake_engine.execution_options.return_value = read_only_engine |
| 162 | + |
| 163 | + def fake_async_sessionmaker(*, bind, expire_on_commit, **kwargs): |
| 164 | + del expire_on_commit |
| 165 | + del kwargs |
| 166 | + captured_binds.append(bind) |
| 167 | + return mock.Mock() |
| 168 | + |
| 169 | + with ( |
| 170 | + mock.patch.object( |
| 171 | + database_session_service, |
| 172 | + 'create_async_engine', |
| 173 | + return_value=fake_engine, |
| 174 | + ), |
| 175 | + mock.patch.object( |
| 176 | + database_session_service, |
| 177 | + 'async_sessionmaker', |
| 178 | + side_effect=fake_async_sessionmaker, |
| 179 | + ), |
| 180 | + ): |
| 181 | + database_session_service.DatabaseSessionService( |
| 182 | + 'spanner+spanner:///projects/test/instances/test/databases/test' |
| 183 | + ) |
| 184 | + |
| 185 | + assert captured_binds == [fake_engine, read_only_engine] |
| 186 | + fake_engine.execution_options.assert_called_once_with(read_only=True) |
| 187 | + |
| 188 | + |
| 189 | +def test_database_session_service_creates_read_only_engine_for_other_dialects(): |
| 190 | + captured_binds = [] |
| 191 | + fake_engine = mock.Mock() |
| 192 | + fake_engine.dialect.name = 'postgresql' |
| 193 | + fake_engine.sync_engine = mock.Mock() |
| 194 | + read_only_engine = mock.Mock() |
| 195 | + fake_engine.execution_options.return_value = read_only_engine |
| 196 | + |
| 197 | + def fake_async_sessionmaker(*, bind, expire_on_commit, **kwargs): |
| 198 | + del expire_on_commit |
| 199 | + del kwargs |
| 200 | + captured_binds.append(bind) |
| 201 | + return mock.Mock() |
| 202 | + |
| 203 | + with ( |
| 204 | + mock.patch.object( |
| 205 | + database_session_service, |
| 206 | + 'create_async_engine', |
| 207 | + return_value=fake_engine, |
| 208 | + ), |
| 209 | + mock.patch.object( |
| 210 | + database_session_service, |
| 211 | + 'async_sessionmaker', |
| 212 | + side_effect=fake_async_sessionmaker, |
| 213 | + ), |
| 214 | + ): |
| 215 | + database_session_service.DatabaseSessionService( |
| 216 | + 'postgresql+psycopg2://user:pass@localhost:5432/db' |
| 217 | + ) |
| 218 | + |
| 219 | + assert captured_binds == [fake_engine, read_only_engine] |
| 220 | + fake_engine.execution_options.assert_called_once_with(read_only=True) |
| 221 | + |
| 222 | + |
154 | 223 | @pytest.mark.asyncio |
155 | 224 | async def test_sqlite_session_service_accepts_sqlite_urls( |
156 | 225 | tmp_path, monkeypatch |
@@ -198,6 +267,67 @@ async def test_get_empty_session(session_service): |
198 | 267 | ) |
199 | 268 |
|
200 | 269 |
|
| 270 | +@pytest.mark.asyncio |
| 271 | +async def test_database_session_service_get_session_uses_read_only_factory(): |
| 272 | + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') |
| 273 | + service._prepare_tables = mock.AsyncMock() |
| 274 | + |
| 275 | + read_only_session = mock.AsyncMock() |
| 276 | + read_only_session.get = mock.AsyncMock(return_value=None) |
| 277 | + |
| 278 | + @asynccontextmanager |
| 279 | + async def fake_read_only_session(): |
| 280 | + yield read_only_session |
| 281 | + |
| 282 | + service.database_session_factory = mock.Mock( |
| 283 | + side_effect=AssertionError('write session factory should not be used') |
| 284 | + ) |
| 285 | + service._read_only_database_session_factory = mock.Mock( |
| 286 | + return_value=fake_read_only_session() |
| 287 | + ) |
| 288 | + |
| 289 | + session = await service.get_session( |
| 290 | + app_name='my_app', user_id='test_user', session_id='123' |
| 291 | + ) |
| 292 | + |
| 293 | + assert session is None |
| 294 | + service._read_only_database_session_factory.assert_called_once_with() |
| 295 | + service.database_session_factory.assert_not_called() |
| 296 | + |
| 297 | + await service.close() |
| 298 | + |
| 299 | + |
| 300 | +@pytest.mark.asyncio |
| 301 | +async def test_database_session_service_list_sessions_uses_read_only_factory(): |
| 302 | + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') |
| 303 | + service._prepare_tables = mock.AsyncMock() |
| 304 | + |
| 305 | + read_only_session = mock.AsyncMock() |
| 306 | + empty_result = mock.Mock() |
| 307 | + empty_result.scalars.return_value.all.return_value = [] |
| 308 | + read_only_session.execute = mock.AsyncMock(return_value=empty_result) |
| 309 | + read_only_session.get = mock.AsyncMock(return_value=None) |
| 310 | + |
| 311 | + @asynccontextmanager |
| 312 | + async def fake_read_only_session(): |
| 313 | + yield read_only_session |
| 314 | + |
| 315 | + service.database_session_factory = mock.Mock( |
| 316 | + side_effect=AssertionError('write session factory should not be used') |
| 317 | + ) |
| 318 | + service._read_only_database_session_factory = mock.Mock( |
| 319 | + return_value=fake_read_only_session() |
| 320 | + ) |
| 321 | + |
| 322 | + response = await service.list_sessions(app_name='my_app', user_id='test_user') |
| 323 | + |
| 324 | + assert response.sessions == [] |
| 325 | + service._read_only_database_session_factory.assert_called_once_with() |
| 326 | + service.database_session_factory.assert_not_called() |
| 327 | + |
| 328 | + await service.close() |
| 329 | + |
| 330 | + |
201 | 331 | @pytest.mark.asyncio |
202 | 332 | async def test_create_get_session(session_service): |
203 | 333 | app_name = 'my_app' |
|
0 commit comments