From 26b63323b0331c1bfa91289f7e234eefd9586b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Barc=C3=A9los?= Date: Tue, 23 Jul 2024 14:45:24 +0200 Subject: [PATCH 1/2] Fix Notification.description polyfill from GqlStatusObject Bolt 5.6 introduces the original notification description back in the protocol level. This avoids the `Notification.description` changes when connected to GQL aware servers. This issues was detected during homologation, so the problem won't happen with any released server since the bolt version which miss information will not be released. --- src/neo4j/_async/io/_bolt.py | 7 +- src/neo4j/_async/io/_bolt5.py | 36 +- src/neo4j/_sync/io/_bolt.py | 7 +- src/neo4j/_sync/io/_bolt5.py | 36 +- src/neo4j/_work/summary.py | 2 +- testkitbackend/test_config.json | 1 + tests/unit/async_/io/test_class_bolt.py | 12 +- tests/unit/async_/io/test_class_bolt5x5.py | 10 +- tests/unit/async_/io/test_class_bolt5x6.py | 666 +++++++++++++++++++++ tests/unit/common/work/test_summary.py | 31 +- tests/unit/sync/io/test_class_bolt.py | 12 +- tests/unit/sync/io/test_class_bolt5x5.py | 10 +- tests/unit/sync/io/test_class_bolt5x6.py | 666 +++++++++++++++++++++ 13 files changed, 1461 insertions(+), 35 deletions(-) create mode 100644 tests/unit/async_/io/test_class_bolt5x6.py create mode 100644 tests/unit/sync/io/test_class_bolt5x6.py diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index d161371f0..e8923e33e 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -284,6 +284,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x3, AsyncBolt5x4, AsyncBolt5x5, + AsyncBolt5x6, ) handlers = { @@ -299,6 +300,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x3.PROTOCOL_VERSION: AsyncBolt5x3, AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, + AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, } if protocol_version is None: @@ -413,7 +415,10 @@ async def open( # Carry out Bolt subclass imports locally to avoid circular dependency # issues. - if protocol_version == (5, 5): + if protocol_version == (5, 6): + from ._bolt5 import AsyncBolt5x6 + bolt_cls = AsyncBolt5x6 + elif protocol_version == (5, 5): from ._bolt5 import AsyncBolt5x5 bolt_cls = AsyncBolt5x5 elif protocol_version == (5, 4): diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index cd015f6e8..2b04b61a2 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -783,7 +783,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, ("CURRENT_SCHEMA", "/"), ) - def _make_enrich_diagnostic_record_handler(self, wrapped_handler=None): + def _make_enrich_statuses_handler(self, wrapped_handler=None): async def handler(metadata): def enrich(metadata_): if not isinstance(metadata_, dict): @@ -794,6 +794,7 @@ def enrich(metadata_): for status in statuses: if not isinstance(status, dict): continue + status["description"] = status.get("status_description") diag_record = status.setdefault("diagnostic_record", {}) if not isinstance(diag_record, dict): log.info("[#%04X] _: Server supplied an " @@ -810,14 +811,43 @@ def enrich(metadata_): def discard(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().discard(n, qid, dehydration_hooks, hydration_hooks, **handlers) def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().pull(n, qid, dehydration_hooks, hydration_hooks, **handlers) + +class AsyncBolt5x6(AsyncBolt5x5): + + PROTOCOL_VERSION = Version(5, 6) + + def _make_enrich_statuses_handler(self, wrapped_handler=None): + async def handler(metadata): + def enrich(metadata_): + if not isinstance(metadata_, dict): + return + statuses = metadata_.get("statuses") + if not isinstance(statuses, list): + return + for status in statuses: + if not isinstance(status, dict): + continue + diag_record = status.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info("[#%04X] _: Server supplied an " + "invalid diagnostic record (%r).", + self.local_port, diag_record) + continue + for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + enrich(metadata) + await AsyncUtil.callback(wrapped_handler, metadata) + + return handler diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 92258c2c1..1c6763900 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -284,6 +284,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x3, Bolt5x4, Bolt5x5, + Bolt5x6, ) handlers = { @@ -299,6 +300,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x3.PROTOCOL_VERSION: Bolt5x3, Bolt5x4.PROTOCOL_VERSION: Bolt5x4, Bolt5x5.PROTOCOL_VERSION: Bolt5x5, + Bolt5x6.PROTOCOL_VERSION: Bolt5x6, } if protocol_version is None: @@ -413,7 +415,10 @@ def open( # Carry out Bolt subclass imports locally to avoid circular dependency # issues. - if protocol_version == (5, 5): + if protocol_version == (5, 6): + from ._bolt5 import Bolt5x6 + bolt_cls = Bolt5x6 + elif protocol_version == (5, 5): from ._bolt5 import Bolt5x5 bolt_cls = Bolt5x5 elif protocol_version == (5, 4): diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 12740a6c5..af321ffd5 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -783,7 +783,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, ("CURRENT_SCHEMA", "/"), ) - def _make_enrich_diagnostic_record_handler(self, wrapped_handler=None): + def _make_enrich_statuses_handler(self, wrapped_handler=None): def handler(metadata): def enrich(metadata_): if not isinstance(metadata_, dict): @@ -794,6 +794,7 @@ def enrich(metadata_): for status in statuses: if not isinstance(status, dict): continue + status["description"] = status.get("status_description") diag_record = status.setdefault("diagnostic_record", {}) if not isinstance(diag_record, dict): log.info("[#%04X] _: Server supplied an " @@ -810,14 +811,43 @@ def enrich(metadata_): def discard(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().discard(n, qid, dehydration_hooks, hydration_hooks, **handlers) def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().pull(n, qid, dehydration_hooks, hydration_hooks, **handlers) + +class Bolt5x6(Bolt5x5): + + PROTOCOL_VERSION = Version(5, 6) + + def _make_enrich_statuses_handler(self, wrapped_handler=None): + def handler(metadata): + def enrich(metadata_): + if not isinstance(metadata_, dict): + return + statuses = metadata_.get("statuses") + if not isinstance(statuses, list): + return + for status in statuses: + if not isinstance(status, dict): + continue + diag_record = status.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info("[#%04X] _: Server supplied an " + "invalid diagnostic record (%r).", + self.local_port, diag_record) + continue + for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + enrich(metadata) + Util.callback(wrapped_handler, metadata) + + return handler diff --git a/src/neo4j/_work/summary.py b/src/neo4j/_work/summary.py index 3b795d3ba..3029c53f5 100644 --- a/src/neo4j/_work/summary.py +++ b/src/neo4j/_work/summary.py @@ -163,7 +163,7 @@ def _set_notifications(self): for notification_key, status_key in ( ("title", "title"), ("code", "neo4j_code"), - ("description", "status_description"), + ("description", "description"), ): value = status.get(status_key) if not isinstance(value, str) or not value: diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 7e16d0cfc..e4d8b14b5 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -57,6 +57,7 @@ "Feature:Bolt:5.3": true, "Feature:Bolt:5.4": true, "Feature:Bolt:5.5": true, + "Feature:Bolt:5.6": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 7d1e5217b..686b5a393 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -37,7 +37,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } protocol_handlers = AsyncBolt.protocol_handlers() @@ -65,7 +65,8 @@ def test_class_method_protocol_handlers(): ((5, 3), 1), ((5, 4), 1), ((5, 5), 1), - ((5, 6), 0), + ((5, 6), 1), + ((5, 7), 0), ((6, 0), 0), ] ) @@ -85,7 +86,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): # [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() - assert (b"\x00\x05\x05\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) @@ -134,6 +135,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 3), "neo4j._async.io._bolt5.AsyncBolt5x3"), ((5, 4), "neo4j._async.io._bolt5.AsyncBolt5x4"), ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), + ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ), ) @mark_async_test @@ -166,14 +168,14 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 6), + (5, 7), (6, 0), )) @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index 3da6dee7c..413bf9abf 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -615,7 +615,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): ) @pytest.mark.parametrize("method", ("pull", "discard")) @mark_async_test -async def test_enriches_diagnostic_record( +async def test_enriches_statuses( sent_diag_records, method, fake_socket_pair, @@ -628,7 +628,9 @@ async def test_enriches_diagnostic_record( sent_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description" } for r in sent_diag_records ] } @@ -654,7 +656,9 @@ def extend_diag_record(r): expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] expected_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "the status description" } for r in expected_diag_records ] } diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py new file mode 100644 index 000000000..84521bfb9 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -0,0 +1,666 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x6 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6( + address, socket, AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + + +@pytest.mark.parametrize(("method", "args", "extra_idx"), ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), +)) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2) +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2) +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, + cls_dis_clss, method_dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6( + address, socket, AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss + ) + method = getattr(connection, method) + + method(*args, notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize("dis_clss", + (None, [], ["HINT"], ["HINT", "DEPRECATION"])) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss + ) + + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0") + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + limit=3, + ) +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 6e86436e5..bb3610716 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -190,6 +190,7 @@ def test_statuses_and_notifications_dont_mix(summary_args_kwargs) -> None: raw_status = { "gql_status": "12345", "status_description": "cool description", + "description": "cool notification description", "neo4j_code": "Neo.Foo.Bar.Baz", "title": "nice title", "diagnostic_record": raw_diag_rec, @@ -314,6 +315,7 @@ def make_raw_status( "gql_status": gql_status, "status_description": "note: successful completion - " f"custom stuff {i}", + "description": f"notification description {i}", "neo4j_code": f"Neo.Foo.Bar.{type_}-{i}", "title": f"Some cool title which defo is dope! {i}", "diagnostic_record": { @@ -606,15 +608,18 @@ def test_status( ) -> None: args, kwargs = summary_args_kwargs default_position = SummaryInputPosition(line=1337, column=42, offset=420) - default_description = "some nice description goes here" + default_status_description = "some nice description goes here" + default_description = "some nice notification description here" default_severity = "WARNING" default_classification = "HINT" default_code = "Neo.Cool.Legacy.Code" default_title = "Cool Title" default_gql_status = "12345" + raw_status: t.Dict[str, t.Any] = { "gql_status": default_gql_status, - "status_description": default_description, + "status_description": default_status_description, + "description": default_description, "neo4j_code": default_code, "title": default_title, "diagnostic_record": { @@ -661,7 +666,7 @@ def test_status( == expectation_overwrite.get("gql_status", default_gql_status)) assert (status.status_description == expectation_overwrite.get("status_description", - default_description)) + default_status_description)) assert (status.position == expectation_overwrite.get("position", default_position)) assert (status.raw_classification @@ -837,6 +842,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 3), "t_first"), ((5, 4), "t_first"), ((5, 5), "t_first"), + ((5, 6), "t_first"), )) def test_summary_result_available_after( summary_args_kwargs, exists, bolt_version, meta_name @@ -869,6 +875,7 @@ def test_summary_result_available_after( ((5, 3), "t_last"), ((5, 4), "t_last"), ((5, 5), "t_last"), + ((5, 6), "t_last"), )) def test_summary_result_consumed_after( summary_args_kwargs, exists, bolt_version, meta_name @@ -1438,9 +1445,9 @@ def test_no_notification_from_status(raw_status, summary_args_kwargs) -> None: ("FOOBAR", None, ..., -1, 1.6, False, [], {})) ), - # copies status_description to description + # copies description to description ( - {"status_description": "something completely different 👀"}, {}, + {"description": "something completely different 👀"}, {}, {"description": "something completely different 👀"} ), @@ -1537,15 +1544,17 @@ def test_notification_from_status( summary_args_kwargs ) -> None: default_status = "03BAZ" - default_description = "note: successful completion - custom stuff" + default_status_description = "note: successful completion - custom stuff" default_code = "Neo.Foo.Bar.Baz" default_title = "Some cool title which defo is dope!" default_severity = "INFORMATION" default_classification = "HINT" + default_description = "nice message" default_position = SummaryInputPosition(line=1337, column=42, offset=420) raw_status_obj: t.Dict[str, t.Any] = { "gql_status": default_status, - "status_description": default_description, + "status_description": default_status_description, + "description": default_description, "neo4j_code": default_code, "title": default_title, "diagnostic_record": { @@ -1767,7 +1776,7 @@ def test_broken_diagnostic_record(in_status, summary_args_kwargs) -> None: ("status_overwrite", "diagnostic_record_overwrite"), ( *( - ({"status_description": value}, {}) + ({"description": value}, {}) for value in t.cast(t.Iterable[t.Any], ("", None, ..., 1, False, [], {})) ), @@ -1818,7 +1827,8 @@ def test_no_notification_from_broken_status( status_overwrite, diagnostic_record_overwrite, summary_args_kwargs ) -> None: default_status = "03BAZ" - default_description = "note: successful completion - custom stuff" + default_status_description = "note: successful completion - custom stuff" + default_description = "some description" default_code = "Neo.Foo.Bar.Baz" default_title = "Some cool title which defo is dope!" default_severity = "INFORMATION" @@ -1826,7 +1836,8 @@ def test_no_notification_from_broken_status( default_position = SummaryInputPosition(line=1337, column=42, offset=420) raw_status_obj: t.Dict[str, t.Any] = { "gql_status": default_status, - "status_description": default_description, + "description": default_description, + "status_description": default_status_description, "neo4j_code": default_code, "title": default_title, "diagnostic_record": { diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index d07f673c1..b0d000cf9 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -37,7 +37,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } protocol_handlers = Bolt.protocol_handlers() @@ -65,7 +65,8 @@ def test_class_method_protocol_handlers(): ((5, 3), 1), ((5, 4), 1), ((5, 5), 1), - ((5, 6), 0), + ((5, 6), 1), + ((5, 7), 0), ((6, 0), 0), ] ) @@ -85,7 +86,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): # [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = Bolt.get_handshake() - assert (b"\x00\x05\x05\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) @@ -134,6 +135,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 3), "neo4j._sync.io._bolt5.Bolt5x3"), ((5, 4), "neo4j._sync.io._bolt5.Bolt5x4"), ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), + ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ), ) @mark_sync_test @@ -166,14 +168,14 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 6), + (5, 7), (6, 0), )) @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index 550cf8546..fe30b7e9a 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -615,7 +615,7 @@ def test_tracks_last_database(fake_socket_pair, actions): ) @pytest.mark.parametrize("method", ("pull", "discard")) @mark_sync_test -def test_enriches_diagnostic_record( +def test_enriches_statuses( sent_diag_records, method, fake_socket_pair, @@ -628,7 +628,9 @@ def test_enriches_diagnostic_record( sent_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description" } for r in sent_diag_records ] } @@ -654,7 +656,9 @@ def extend_diag_record(r): expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] expected_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "the status description" } for r in expected_diag_records ] } diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py new file mode 100644 index 000000000..0504f0731 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -0,0 +1,666 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x6 + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + + +@pytest.mark.parametrize(("method", "args", "extra_idx"), ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), +)) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2) +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2) +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, + cls_dis_clss, method_dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6( + address, socket, PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss + ) + method = getattr(connection, method) + + method(*args, notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize("dis_clss", + (None, [], ["HINT"], ["HINT", "DEPRECATION"])) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss + ) + + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0") + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + limit=3, + ) +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata From 75e4afe71ec5efa6a49e76bae634f567d40ba8e2 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 29 Jul 2024 09:17:39 +0200 Subject: [PATCH 2/2] Code formatting --- src/neo4j/_async/io/_bolt5.py | 2 ++ src/neo4j/_sync/io/_bolt5.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 2b04b61a2..20e5ef7bd 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -685,6 +685,7 @@ def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, Response(self, "telemetry", hydration_hooks, **handlers), dehydration_hooks=dehydration_hooks) + class AsyncBolt5x5(AsyncBolt5x4): PROTOCOL_VERSION = Version(5, 5) @@ -823,6 +824,7 @@ def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, ) super().pull(n, qid, dehydration_hooks, hydration_hooks, **handlers) + class AsyncBolt5x6(AsyncBolt5x5): PROTOCOL_VERSION = Version(5, 6) diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index af321ffd5..86924b8a3 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -685,6 +685,7 @@ def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, Response(self, "telemetry", hydration_hooks, **handlers), dehydration_hooks=dehydration_hooks) + class Bolt5x5(Bolt5x4): PROTOCOL_VERSION = Version(5, 5) @@ -823,6 +824,7 @@ def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, ) super().pull(n, qid, dehydration_hooks, hydration_hooks, **handlers) + class Bolt5x6(Bolt5x5): PROTOCOL_VERSION = Version(5, 6)