Skip to content

Commit b071b61

Browse files
authored
Merge pull request #549 from FlorentinD/sessions-block-non-remote-projections
Block local projections in AuraGDS
2 parents a055272 + 9c99329 commit b071b61

File tree

10 files changed

+94
-74
lines changed

10 files changed

+94
-74
lines changed

examples/dev/aura-only-features.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@
316316
"There are two key differences between the remote projection and Cypher projections V2:\n",
317317
"\n",
318318
"1. In AuraDB, the aggregating function does not take a graph name as a parameter.\n",
319-
"2. The aggregation function should only be called through the GDS Python Client endpoint `gds.graph.project.remoteDb`\n",
319+
"2. The aggregation function should only be called through the GDS Python Client endpoint `gds.graph.project`\n",
320320
"\n",
321321
"### Limitations\n",
322322
"\n",
@@ -335,7 +335,7 @@
335335
"metadata": {},
336336
"outputs": [],
337337
"source": [
338-
"G, result = gds.graph.project.remoteDb(\n",
338+
"G, result = gds.graph.project(\n",
339339
" \"pagerank-graph\",\n",
340340
" \"\"\"\n",
341341
" MATCH (u:User) \n",

graphdatascience/aura_graph_data_science.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional, Tuple
1+
from typing import Any, Dict, Optional, Tuple, Union
22

33
from neo4j import GraphDatabase
44
from pandas import DataFrame
@@ -8,11 +8,13 @@
88
from graphdatascience.call_builder import IndirectCallBuilder
99
from graphdatascience.endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
1010
from graphdatascience.error.uncallable_namespace import UncallableNamespace
11+
from graphdatascience.graph.graph_proc_runner import GraphRemoteProcRunner
1112
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
1213
from graphdatascience.query_runner.aura_db_arrow_query_runner import (
1314
AuraDbArrowQueryRunner,
1415
AuraDbConnectionInfo,
1516
)
17+
from graphdatascience.query_runner.query_runner import QueryRunner
1618

1719

1820
class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
@@ -23,46 +25,51 @@ class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
2325

2426
def __init__(
2527
self,
26-
endpoint: str,
28+
endpoint: Union[str, QueryRunner],
2729
auth: Tuple[str, str],
2830
aura_db_connection_info: AuraDbConnectionInfo,
2931
arrow_disable_server_verification: bool = True,
3032
arrow_tls_root_certs: Optional[bytes] = None,
3133
bookmarks: Optional[Any] = None,
3234
):
33-
gds_query_runner = ArrowQueryRunner.create(
34-
Neo4jQueryRunner.create(endpoint, auth, aura_ds=True),
35-
auth,
36-
True,
37-
arrow_disable_server_verification,
38-
arrow_tls_root_certs,
39-
)
40-
41-
self._server_version = gds_query_runner.server_version()
42-
43-
if self._server_version < ServerVersion(2, 6, 0):
44-
raise RuntimeError(
45-
f"AuraDB connection info was provided but GDS version {self._server_version} \
46-
does not support connecting to AuraDB"
35+
if isinstance(endpoint, str):
36+
gds_query_runner = ArrowQueryRunner.create(
37+
Neo4jQueryRunner.create(endpoint, auth, aura_ds=True),
38+
auth,
39+
True,
40+
arrow_disable_server_verification,
41+
arrow_tls_root_certs,
4742
)
4843

49-
self._driver_config = gds_query_runner.driver_config()
50-
driver = GraphDatabase.driver(
51-
aura_db_connection_info.uri, auth=aura_db_connection_info.auth, **self._driver_config
52-
)
53-
self._db_query_runner = Neo4jQueryRunner(
54-
driver, auto_close=True, bookmarks=bookmarks, server_version=self._server_version
55-
)
44+
self._server_version = gds_query_runner.server_version()
5645

57-
# we need to explicitly set these as the default value is None
58-
# which signals the driver to use the default configured database
59-
# from the dbms.
60-
gds_query_runner.set_database("neo4j")
61-
self._db_query_runner.set_database("neo4j")
46+
if self._server_version < ServerVersion(2, 6, 0):
47+
raise RuntimeError(
48+
f"AuraDB connection info was provided but GDS version {self._server_version} \
49+
does not support connecting to AuraDB"
50+
)
6251

63-
self._query_runner = AuraDbArrowQueryRunner(
64-
gds_query_runner, self._db_query_runner, driver.encrypted, aura_db_connection_info
65-
)
52+
self._driver_config = gds_query_runner.driver_config()
53+
driver = GraphDatabase.driver(
54+
aura_db_connection_info.uri, auth=aura_db_connection_info.auth, **self._driver_config
55+
)
56+
self._db_query_runner: QueryRunner = Neo4jQueryRunner(
57+
driver, auto_close=True, bookmarks=bookmarks, server_version=self._server_version
58+
)
59+
60+
# we need to explicitly set these as the default value is None
61+
# which signals the driver to use the default configured database
62+
# from the dbms.
63+
gds_query_runner.set_database("neo4j")
64+
self._db_query_runner.set_database("neo4j")
65+
66+
self._query_runner = AuraDbArrowQueryRunner(
67+
gds_query_runner, self._db_query_runner, driver.encrypted, aura_db_connection_info
68+
)
69+
else:
70+
self._query_runner = endpoint
71+
self._db_query_runner = endpoint
72+
self._server_version = self._query_runner.server_version()
6673

6774
super().__init__(self._query_runner, "gds", self._server_version)
6875

@@ -87,6 +94,10 @@ def run_cypher(
8794
# This will avoid calling valid gds procedures through a raw string
8895
return self._db_query_runner.run_cypher(query, params, database, False)
8996

97+
@property
98+
def graph(self) -> GraphRemoteProcRunner:
99+
return GraphRemoteProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
100+
90101
@property
91102
def alpha(self) -> AlphaEndpoints:
92103
return AlphaEndpoints(self._query_runner, "gds.alpha", self._server_version)

graphdatascience/endpoints.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
SingleModeAlphaAlgoEndpoints,
44
)
55
from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
6-
from .graph.graph_endpoints import (
7-
GraphAlphaEndpoints,
8-
GraphBetaEndpoints,
9-
GraphEndpoints,
10-
)
6+
from .graph.graph_endpoints import GraphAlphaEndpoints, GraphBetaEndpoints
117
from .model.model_endpoints import (
128
ModelAlphaEndpoints,
139
ModelBetaEndpoints,
@@ -39,7 +35,6 @@ class DirectEndpoints(
3935
SingleModeAlgoEndpoints,
4036
DirectSystemEndpoints,
4137
DirectUtilEndpoints,
42-
GraphEndpoints,
4338
PipelineEndpoints,
4439
ModelEndpoints,
4540
ConfigEndpoints,

graphdatascience/graph/graph_endpoints.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
from ..caller_base import CallerBase
22
from .graph_alpha_proc_runner import GraphAlphaProcRunner
33
from .graph_beta_proc_runner import GraphBetaProcRunner
4-
from .graph_proc_runner import GraphProcRunner
5-
6-
7-
class GraphEndpoints(CallerBase):
8-
@property
9-
def graph(self) -> GraphProcRunner:
10-
return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
114

125

136
class GraphAlphaEndpoints(CallerBase):

graphdatascience/graph/graph_proc_runner.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from .graph_export_runner import GraphExportRunner
2626
from .graph_object import Graph
27-
from .graph_project_runner import GraphProjectRunner
27+
from .graph_project_runner import GraphProjectRemoteRunner, GraphProjectRunner
2828
from .graph_sample_runner import GraphSampleRunner
2929
from .graph_type_check import (
3030
from_graph_type_check,
@@ -41,7 +41,7 @@
4141
is_neo4j_4_driver = ServerVersion.from_string(neo4j_driver_version) < ServerVersion(5, 0, 0)
4242

4343

44-
class GraphProcRunner(UncallableNamespace, IllegalAttrChecker):
44+
class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
4545
@staticmethod
4646
def _path(package: str, resource: str) -> pathlib.Path:
4747
if sys.version_info >= (3, 9):
@@ -217,11 +217,6 @@ def networkx(self): # type: ignore
217217
self._namespace += ".networkx"
218218
return NXLoader(self._query_runner, self._namespace, self._server_version)
219219

220-
@property
221-
def project(self) -> GraphProjectRunner:
222-
self._namespace += ".project"
223-
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
224-
225220
@property
226221
@compatible_with("graphProperty", min_inclusive=ServerVersion(2, 5, 0))
227222
def graphProperty(self) -> GraphPropertyRunner:
@@ -234,11 +229,6 @@ def nodeLabel(self) -> GraphLabelRunner:
234229
self._namespace += ".nodeLabel"
235230
return GraphLabelRunner(self._query_runner, self._namespace, self._server_version)
236231

237-
@property
238-
def cypher(self) -> GraphCypherRunner:
239-
self._namespace += ".project"
240-
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
241-
242232
@compatible_with("generate", min_inclusive=ServerVersion(2, 5, 0))
243233
def generate(self, graph_name: str, node_count: int, average_degree: int, **config: Any) -> GraphCreateResult:
244234
self._namespace += ".generate"
@@ -568,3 +558,22 @@ def deleteRelationships(self, G: Graph, relationship_type: str) -> "Series[Any]"
568558
endpoint=self._namespace,
569559
params=params,
570560
).squeeze()
561+
562+
563+
class GraphProcRunner(BaseGraphProcRunner):
564+
@property
565+
def project(self) -> GraphProjectRunner:
566+
self._namespace += ".project"
567+
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
568+
569+
@property
570+
def cypher(self) -> GraphCypherRunner:
571+
self._namespace += ".project"
572+
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
573+
574+
575+
class GraphRemoteProcRunner(BaseGraphProcRunner):
576+
@property
577+
def project(self) -> GraphProjectRemoteRunner:
578+
self._namespace += ".project.remoteDb"
579+
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)

graphdatascience/graph/graph_project_runner.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ def estimate(self, node_projection: Any, relationship_projection: Any, **config:
4646
def cypher(self) -> GraphProjectRunner:
4747
return GraphProjectRunner(self._query_runner, self._namespace + ".cypher", self._server_version)
4848

49-
@property
50-
def remoteDb(self) -> GraphProjectRemoteRunner:
51-
return GraphProjectRemoteRunner(self._query_runner, self._namespace + ".remoteDb", self._server_version)
52-
5349

5450
class GraphProjectBetaRunner(IllegalAttrChecker):
5551
@from_graph_type_check
@@ -79,7 +75,7 @@ def subgraph(
7975

8076

8177
class GraphProjectRemoteRunner(IllegalAttrChecker):
82-
@compatible_with("remoteDb", min_inclusive=ServerVersion(2, 6, 0))
78+
@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
8379
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
8480
placeholder = "<>" # host and token will be added by query runner
8581
params = CallParameters(

graphdatascience/graph_data_science.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
1313
from .query_runner.query_runner import QueryRunner
1414
from .server_version.server_version import ServerVersion
15+
from graphdatascience.graph.graph_proc_runner import GraphProcRunner
1516

1617

1718
class GraphDataScience(DirectEndpoints, UncallableNamespace):
@@ -75,6 +76,10 @@ def __init__(
7576

7677
super().__init__(self._query_runner, "gds", self._server_version)
7778

79+
@property
80+
def graph(self) -> GraphProcRunner:
81+
return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
82+
7883
@property
7984
def alpha(self) -> AlphaEndpoints:
8085
return AlphaEndpoints(self._query_runner, "gds.alpha", self._server_version)

graphdatascience/tests/integration/test_remote_graph_ops.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ def run_around_tests(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[No
3737
@pytest.mark.cloud_architecture
3838
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
3939
def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
40-
G, result = gds_with_cloud_setup.graph.project.remoteDb(
41-
GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)"
42-
)
40+
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
4341

4442
assert G.name() == GRAPH_NAME
4543
assert result["nodeCount"] == 3
@@ -48,9 +46,7 @@ def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
4846
@pytest.mark.cloud_architecture
4947
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
5048
def test_remote_write_back(gds_with_cloud_setup: AuraGraphDataScience) -> None:
51-
G, result = gds_with_cloud_setup.graph.project.remoteDb(
52-
GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)"
53-
)
49+
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
5450

5551
result = gds_with_cloud_setup.pageRank.write(G, writeProperty="score")
5652

graphdatascience/tests/unit/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
from pandas import DataFrame
55

66
from graphdatascience import QueryRunner
7+
from graphdatascience.aura_graph_data_science import AuraGraphDataScience
78
from graphdatascience.call_parameters import CallParameters
89
from graphdatascience.graph_data_science import GraphDataScience
10+
from graphdatascience.query_runner.aura_db_arrow_query_runner import (
11+
AuraDbConnectionInfo,
12+
)
913
from graphdatascience.query_runner.cypher_graph_constructor import (
1014
CypherGraphConstructor,
1115
)
@@ -109,6 +113,16 @@ def gds(runner: CollectingQueryRunner) -> Generator[GraphDataScience, None, None
109113
gds.close()
110114

111115

116+
@pytest.fixture
117+
def aura_gds(runner: CollectingQueryRunner) -> Generator[AuraGraphDataScience, None, None]:
118+
aura_gds = AuraGraphDataScience(
119+
endpoint=runner, auth=("some", "auth"), aura_db_connection_info=AuraDbConnectionInfo("uri", ("some", "auth"))
120+
)
121+
yield aura_gds
122+
123+
aura_gds.close()
124+
125+
112126
@pytest.fixture(scope="package")
113127
def server_version() -> ServerVersion:
114128
return DEFAULT_SERVER_VERSION

graphdatascience/tests/unit/test_graph_ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pandas import DataFrame
33

44
from .conftest import CollectingQueryRunner
5+
from graphdatascience.aura_graph_data_science import AuraGraphDataScience
56
from graphdatascience.graph_data_science import GraphDataScience
67
from graphdatascience.server_version.server_version import ServerVersion
78

@@ -89,8 +90,8 @@ def test_project_subgraph(runner: CollectingQueryRunner, gds: GraphDataScience)
8990
}
9091

9192

92-
def test_project_remote(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
93-
gds.graph.project.remoteDb("g", "RETURN gds.graph.project.remote(0, 1, null)")
93+
def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
94+
aura_gds.graph.project("g", "RETURN gds.graph.project.remote(0, 1, null)")
9495

9596
assert (
9697
runner.last_query()
@@ -687,9 +688,9 @@ def test_graph_sample_cnarw(runner: CollectingQueryRunner, gds: GraphDataScience
687688

688689

689690
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
690-
def test_remote_projection_on_specific_database(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
691-
gds.set_database("bar")
692-
G, _ = gds.graph.project.remoteDb("g", "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
691+
def test_remote_projection_on_specific_database(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
692+
aura_gds.set_database("bar")
693+
G, _ = aura_gds.graph.project("g", "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
693694

694695
assert (
695696
runner.last_query()

0 commit comments

Comments
 (0)