Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions lightrag/kg/postgres_impl.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1843,10 +1843,11 @@ async def get_by_id(self, id: str) -> dict[str, Any] | None:
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.workspace}
if not ids:
return []

sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, list(params.values()), multirows=True)

def _order_results(
Expand Down Expand Up @@ -1949,11 +1950,12 @@ def _order_results(

async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
if not keys:
return set()

table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
Expand Down Expand Up @@ -2532,11 +2534,12 @@ async def finalize(self):

async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
if not keys:
return set()

table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
Expand Down Expand Up @@ -2849,34 +2852,41 @@ async def get_docs_paginated(
elif page_size > 200:
page_size = 200

if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
# Whitelist validation for sort_field to prevent SQL injection
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
if sort_field not in allowed_sort_fields:
sort_field = "updated_at"

# Whitelist validation for sort_direction to prevent SQL injection
if sort_direction.lower() not in ["asc", "desc"]:
sort_direction = "desc"
else:
sort_direction = sort_direction.lower()

# Calculate offset
offset = (page - 1) * page_size

# Build WHERE clause
where_clause = "WHERE workspace=$1"
# Build parameterized query components
params = {"workspace": self.workspace}
param_count = 1

# Build WHERE clause with parameterized query
if status_filter is not None:
param_count += 1
where_clause += f" AND status=${param_count}"
where_clause = "WHERE workspace=$1 AND status=$2"
params["status"] = status_filter.value
else:
where_clause = "WHERE workspace=$1"

# Build ORDER BY clause
# Build ORDER BY clause using validated whitelist values
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"

# Query for total count
count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0

# Query for paginated data
# Query for paginated data with parameterized LIMIT and OFFSET
data_sql = f"""
SELECT * FROM LIGHTRAG_DOC_STATUS
{where_clause}
Expand Down Expand Up @@ -4874,19 +4884,19 @@ def namespace_to_table_name(namespace: str) -> str:
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
COALESCE(doc_name, '') as file_path
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_id_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
Expand All @@ -4901,12 +4911,12 @@ def namespace_to_table_name(namespace: str) -> str:
"get_by_ids_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
Expand Down
Loading