Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions lightrag/kg/neo4j_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@
logging.getLogger("neo4j").setLevel(logging.ERROR)


READ_RETRY_EXCEPTIONS = (
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)

READ_RETRY = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
reraise=True,
)


@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
Expand Down Expand Up @@ -352,6 +369,7 @@ async def index_done_callback(self) -> None:
# Neo4J handles persistence automatically
pass

@READ_RETRY
async def has_node(self, node_id: str) -> bool:
"""
Check if a node with the given label exists in the database
Expand Down Expand Up @@ -385,6 +403,7 @@ async def has_node(self, node_id: str) -> bool:
await result.consume() # Ensure results are consumed even on error
raise

@READ_RETRY
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if an edge exists between two nodes
Expand Down Expand Up @@ -426,6 +445,7 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
await result.consume() # Ensure results are consumed even on error
raise

@READ_RETRY
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties

Expand Down Expand Up @@ -479,6 +499,7 @@ async def get_node(self, node_id: str) -> dict[str, str] | None:
)
raise

@READ_RETRY
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Expand Down Expand Up @@ -515,6 +536,7 @@ async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
await result.consume() # Make sure to consume the result fully
return nodes

@READ_RETRY
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
Expand Down Expand Up @@ -563,6 +585,7 @@ async def node_degree(self, node_id: str) -> int:
)
raise

@READ_RETRY
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Expand Down Expand Up @@ -621,6 +644,7 @@ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
degrees = int(src_degree) + int(trg_degree)
return degrees

@READ_RETRY
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
Expand All @@ -647,6 +671,7 @@ async def edge_degrees_batch(
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees

@READ_RETRY
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
Expand Down Expand Up @@ -734,6 +759,7 @@ async def get_edge(
)
raise

@READ_RETRY
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
Expand Down Expand Up @@ -784,6 +810,7 @@ async def get_edges_batch(
await result.consume()
return edges_dict

@READ_RETRY
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.

Expand Down Expand Up @@ -851,6 +878,7 @@ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | N
)
raise

@READ_RETRY
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pytest = [
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pre-commit",
"ruff",
]

api = [
Expand Down Expand Up @@ -132,10 +133,11 @@ offline = [
]

evaluation = [
# Test framework dependencies (for evaluation)
# Test framework dependencies
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pre-commit",
"ruff",
# RAG evaluation dependencies (RAGAS framework)
"ragas>=0.3.7",
"datasets>=4.3.0",
Expand Down
30 changes: 30 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading