diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 4a2c0a1df..cc923b553 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -563,7 +563,9 @@ def _prepare_session_config(cls, preview_check, config_kwargs): async def close(self) -> None: """ Shut down, closing any open connections in the pool. """ - self._check_state() + # TODO: 6.0 - NOOP if already closed + # if self._closed: + # return try: await self._pool.close() except asyncio.CancelledError: diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index cad5f0442..754c2138c 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -562,7 +562,9 @@ def _prepare_session_config(cls, preview_check, config_kwargs): def close(self) -> None: """ Shut down, closing any open connections in the pool. """ - self._check_state() + # TODO: 6.0 - NOOP if already closed + # if self._closed: + # return try: self._pool.close() except asyncio.CancelledError: diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index ea11f43be..e65e456f4 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -21,6 +21,7 @@ import inspect import ssl import typing as t +import warnings import pytest import typing_extensions as te @@ -963,12 +964,11 @@ async def test_supports_session_auth(session_cls_mock) -> None: ("get_server_info", (), {}), ("supports_multi_db", (), {}), ("supports_session_auth", (), {}), - ("close", (), {}), ) ) @mark_async_test -async def test_using_closed_driver_is_deprecated( +async def test_using_closed_driver_where_deprecated( method_name, args, kwargs, session_cls_mock ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") @@ -983,3 +983,25 @@ async def test_using_closed_driver_is_deprecated( await method(*args, **kwargs) else: method(*args, **kwargs) + + +@pytest.mark.parametrize( + ("method_name", "args", "kwargs"), + ( + ("close", (), {}), + ) +) +@mark_async_test +async def test_using_closed_driver_where_not_deprecated( + method_name, args, kwargs, session_cls_mock +) -> None: + driver = AsyncGraphDatabase.driver("bolt://localhost") + await driver.close() + + method = getattr(driver, method_name) + with warnings.catch_warnings(): + warnings.simplefilter("error") + if inspect.iscoroutinefunction(method): + await method(*args, **kwargs) + else: + method(*args, **kwargs) diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index c05c821ce..18e1afa69 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -21,6 +21,7 @@ import inspect import ssl import typing as t +import warnings import pytest import typing_extensions as te @@ -962,12 +963,11 @@ def test_supports_session_auth(session_cls_mock) -> None: ("get_server_info", (), {}), ("supports_multi_db", (), {}), ("supports_session_auth", (), {}), - ("close", (), {}), ) ) @mark_sync_test -def test_using_closed_driver_is_deprecated( +def test_using_closed_driver_where_deprecated( method_name, args, kwargs, session_cls_mock ) -> None: driver = GraphDatabase.driver("bolt://localhost") @@ -982,3 +982,25 @@ def test_using_closed_driver_is_deprecated( method(*args, **kwargs) else: method(*args, **kwargs) + + +@pytest.mark.parametrize( + ("method_name", "args", "kwargs"), + ( + ("close", (), {}), + ) +) +@mark_sync_test +def test_using_closed_driver_where_not_deprecated( + method_name, args, kwargs, session_cls_mock +) -> None: + driver = GraphDatabase.driver("bolt://localhost") + driver.close() + + method = getattr(driver, method_name) + with warnings.catch_warnings(): + warnings.simplefilter("error") + if inspect.iscoroutinefunction(method): + method(*args, **kwargs) + else: + method(*args, **kwargs)