Skip to content

Commit 2832a2c

Browse files
authored
Merge pull request #2417 from danielaskdd/neo4j-retry
Fix: Add Comprehensive Retry Mechanism for Neo4j Storage Operations
2 parents 5b81ef0 + 5f91063 commit 2832a2c

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

lightrag/kg/neo4j_impl.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@
4444
logging.getLogger("neo4j").setLevel(logging.ERROR)
4545

4646

47+
READ_RETRY_EXCEPTIONS = (
48+
neo4jExceptions.ServiceUnavailable,
49+
neo4jExceptions.TransientError,
50+
neo4jExceptions.SessionExpired,
51+
ConnectionResetError,
52+
OSError,
53+
AttributeError,
54+
)
55+
56+
READ_RETRY = retry(
57+
stop=stop_after_attempt(3),
58+
wait=wait_exponential(multiplier=1, min=4, max=10),
59+
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
60+
reraise=True,
61+
)
62+
63+
4764
@final
4865
@dataclass
4966
class Neo4JStorage(BaseGraphStorage):
@@ -352,6 +369,7 @@ async def index_done_callback(self) -> None:
352369
# Neo4J handles persistence automatically
353370
pass
354371

372+
@READ_RETRY
355373
async def has_node(self, node_id: str) -> bool:
356374
"""
357375
Check if a node with the given label exists in the database
@@ -385,6 +403,7 @@ async def has_node(self, node_id: str) -> bool:
385403
await result.consume() # Ensure results are consumed even on error
386404
raise
387405

406+
@READ_RETRY
388407
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
389408
"""
390409
Check if an edge exists between two nodes
@@ -426,6 +445,7 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
426445
await result.consume() # Ensure results are consumed even on error
427446
raise
428447

448+
@READ_RETRY
429449
async def get_node(self, node_id: str) -> dict[str, str] | None:
430450
"""Get node by its label identifier, return only node properties
431451
@@ -479,6 +499,7 @@ async def get_node(self, node_id: str) -> dict[str, str] | None:
479499
)
480500
raise
481501

502+
@READ_RETRY
482503
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
483504
"""
484505
Retrieve multiple nodes in one query using UNWIND.
@@ -515,6 +536,7 @@ async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
515536
await result.consume() # Make sure to consume the result fully
516537
return nodes
517538

539+
@READ_RETRY
518540
async def node_degree(self, node_id: str) -> int:
519541
"""Get the degree (number of relationships) of a node with the given label.
520542
If multiple nodes have the same label, returns the degree of the first node.
@@ -563,6 +585,7 @@ async def node_degree(self, node_id: str) -> int:
563585
)
564586
raise
565587

588+
@READ_RETRY
566589
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
567590
"""
568591
Retrieve the degree for multiple nodes in a single query using UNWIND.
@@ -621,6 +644,7 @@ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
621644
degrees = int(src_degree) + int(trg_degree)
622645
return degrees
623646

647+
@READ_RETRY
624648
async def edge_degrees_batch(
625649
self, edge_pairs: list[tuple[str, str]]
626650
) -> dict[tuple[str, str], int]:
@@ -647,6 +671,7 @@ async def edge_degrees_batch(
647671
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
648672
return edge_degrees
649673

674+
@READ_RETRY
650675
async def get_edge(
651676
self, source_node_id: str, target_node_id: str
652677
) -> dict[str, str] | None:
@@ -734,6 +759,7 @@ async def get_edge(
734759
)
735760
raise
736761

762+
@READ_RETRY
737763
async def get_edges_batch(
738764
self, pairs: list[dict[str, str]]
739765
) -> dict[tuple[str, str], dict]:
@@ -784,6 +810,7 @@ async def get_edges_batch(
784810
await result.consume()
785811
return edges_dict
786812

813+
@READ_RETRY
787814
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
788815
"""Retrieves all edges (relationships) for a particular node identified by its label.
789816
@@ -851,6 +878,7 @@ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | N
851878
)
852879
raise
853880

881+
@READ_RETRY
854882
async def get_nodes_edges_batch(
855883
self, node_ids: list[str]
856884
) -> dict[str, list[tuple[str, str]]]:

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pytest = [
4747
"pytest>=8.4.2",
4848
"pytest-asyncio>=1.2.0",
4949
"pre-commit",
50+
"ruff",
5051
]
5152

5253
api = [
@@ -132,10 +133,11 @@ offline = [
132133
]
133134

134135
evaluation = [
135-
# Test framework dependencies (for evaluation)
136+
# Test framework dependencies
136137
"pytest>=8.4.2",
137138
"pytest-asyncio>=1.2.0",
138139
"pre-commit",
140+
"ruff",
139141
# RAG evaluation dependencies (RAGAS framework)
140142
"ragas>=0.3.7",
141143
"datasets>=4.3.0",

uv.lock

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)