From 4dea3d3685051c5531a7a96a0b2c6fa0301518a0 Mon Sep 17 00:00:00 2001 From: kaustubh-darekar Date: Mon, 19 May 2025 09:58:28 +0000 Subject: [PATCH 1/2] added validation for tuple schema relationships --- backend/score.py | 1 + backend/src/llm.py | 51 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/backend/score.py b/backend/score.py index 7edb541ad..11209a0b4 100644 --- a/backend/score.py +++ b/backend/score.py @@ -4,6 +4,7 @@ from src.main import * from src.QA_integration import * from src.shared.common_fn import * +from src.shared.llm_graph_builder_exception import LLMGraphBuilderException import uvicorn import asyncio import base64 diff --git a/backend/src/llm.py b/backend/src/llm.py index a36c7cc33..854e5926f 100644 --- a/backend/src/llm.py +++ b/backend/src/llm.py @@ -14,8 +14,9 @@ import boto3 import google.auth from src.shared.constants import ADDITIONAL_INSTRUCTIONS +from src.shared.llm_graph_builder_exception import LLMGraphBuilderException import re -import json +from typing import List def get_llm(model: str): """Retrieve the specified language model based on the model name.""" @@ -209,21 +210,45 @@ async def get_graph_document_list( return graph_document_list async def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship, chunks_to_combine, additional_instructions=None): + try: + llm, model_name = get_llm(model) + logging.info(f"Using model: {model_name}") - llm, model_name = get_llm(model) - combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list, chunks_to_combine) + combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list, chunks_to_combine) + logging.info(f"Combined {len(combined_chunk_document_list)} chunks") - allowedNodes = allowedNodes.split(',') if allowedNodes else [] + allowed_nodes = [node.strip() for node in allowedNodes.split(',') if node.strip()] + logging.info(f"Allowed nodes: {allowed_nodes}") + + allowed_relationships = [] + if allowedRelationship: + items = [item.strip() for item in allowedRelationship.split(',') if item.strip()] + if len(items) % 3 != 0: + raise LLMGraphBuilderException("allowedRelationship must be a multiple of 3 (source, relationship, target)") + for i in range(0, len(items), 3): + source, relation, target = items[i:i + 3] + if source not in allowed_nodes or target not in allowed_nodes: + raise LLMGraphBuilderException( + f"Invalid relationship ({source}, {relation}, {target}): " + f"source or target not in allowedNodes" + ) + allowed_relationships.append((source, relation, target)) + logging.info(f"Allowed relationships: {allowed_relationships}") + else: + logging.info("No allowed relationships provided") - if not allowedRelationship: - allowedRelationship = [] - else: - items = allowedRelationship.split(',') - allowedRelationship = [tuple(items[i:i+3]) for i in range(0, len(items), 3)] - graph_document_list = await get_graph_document_list( - llm, combined_chunk_document_list, allowedNodes, allowedRelationship, additional_instructions - ) - return graph_document_list + graph_document_list = await get_graph_document_list( + llm, + combined_chunk_document_list, + allowed_nodes, + allowed_relationships, + additional_instructions + ) + logging.info(f"Generated {len(graph_document_list)} graph documents") + return graph_document_list + except Exception as e: + logging.error(f"Error in get_graph_from_llm: {e}", exc_info=True) + raise LLMGraphBuilderException(f"Error in getting graph from llm: {e}") def sanitize_additional_instruction(instruction: str) -> str: """ From c3598731830a953daa949789ce6d46dd85b12ad2 Mon Sep 17 00:00:00 2001 From: Prakriti Solankey <156313631+prakriti-solankey@users.noreply.github.com> Date: Mon, 19 May 2025 11:51:35 +0000 Subject: [PATCH 2/2] remove duplicate --- frontend/src/components/Layout/PageLayout.tsx | 8 ++++---- .../EnitityExtraction/NewEntityExtractionSetting.tsx | 6 +++--- frontend/src/utils/Utils.ts | 8 +++----- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/frontend/src/components/Layout/PageLayout.tsx b/frontend/src/components/Layout/PageLayout.tsx index d0826daa4..5d544f4f8 100644 --- a/frontend/src/components/Layout/PageLayout.tsx +++ b/frontend/src/components/Layout/PageLayout.tsx @@ -22,7 +22,7 @@ import LoadDBSchemaDialog from '../Popups/GraphEnhancementDialog/EnitityExtracti import PredefinedSchemaDialog from '../Popups/GraphEnhancementDialog/EnitityExtraction/PredefinedSchemaDialog'; import { SKIP_AUTH } from '../../utils/Constants'; import { useNavigate } from 'react-router'; -import { deduplicateByRelationshipTypeOnly, deduplicateNodeByValue } from '../../utils/Utils'; +import { deduplicateByFullPattern, deduplicateNodeByValue } from '../../utils/Utils'; const GCSModal = lazy(() => import('../DataSources/GCS/GCSModal')); const S3Modal = lazy(() => import('../DataSources/AWS/S3Modal')); @@ -378,7 +378,7 @@ const PageLayout: React.FC = () => { setSchemaValRels(rels); setCombinedRelsVal((prevRels: OptionType[]) => { const combined = [...rels, ...prevRels]; - return deduplicateByRelationshipTypeOnly(combined); + return deduplicateByFullPattern(combined); }); setSchemaView('text'); localStorage.setItem(LOCAL_KEYS.source, JSON.stringify(updatedSource)); @@ -418,7 +418,7 @@ const PageLayout: React.FC = () => { setDbRels(rels); setCombinedRelsVal((prevRels: OptionType[]) => { const combined = [...rels, ...prevRels]; - return deduplicateByRelationshipTypeOnly(combined); + return deduplicateByFullPattern(combined); }); localStorage.setItem(LOCAL_KEYS.source, JSON.stringify(updatedSource)); localStorage.setItem(LOCAL_KEYS.type, JSON.stringify(updatedType)); @@ -456,7 +456,7 @@ const PageLayout: React.FC = () => { setPreDefinedRels(rels); setCombinedRelsVal((prevRels: OptionType[]) => { const combined = [...rels, ...prevRels]; - return deduplicateByRelationshipTypeOnly(combined); + return deduplicateByFullPattern(combined); }); localStorage.setItem(LOCAL_KEYS.source, JSON.stringify(updatedSource)); localStorage.setItem(LOCAL_KEYS.type, JSON.stringify(updatedType)); diff --git a/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/NewEntityExtractionSetting.tsx b/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/NewEntityExtractionSetting.tsx index 7b51cae28..727f27d4c 100644 --- a/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/NewEntityExtractionSetting.tsx +++ b/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/NewEntityExtractionSetting.tsx @@ -13,7 +13,7 @@ import { updateLocalStorage, extractOptions, parseRelationshipString, - deduplicateByRelationshipTypeOnly, + deduplicateByFullPattern, deduplicateNodeByValue, } from '../../../../utils/Utils'; import TooltipWrapper from '../../../UI/TipWrapper'; @@ -175,7 +175,7 @@ export default function NewEntityExtractionSetting({ }); setUserDefinedRels((prev: OptionType[]) => { const combined = [...prev, ...relationshipTypeOptions]; - return deduplicateByRelationshipTypeOnly(combined); + return deduplicateByFullPattern(combined); }); setCombinedNodes((prev: OptionType[]) => { const combined = [...prev, ...nodeLabelOptions]; @@ -183,7 +183,7 @@ export default function NewEntityExtractionSetting({ }); setCombinedRels((prev: OptionType[]) => { const combined = [...prev, ...relationshipTypeOptions]; - return deduplicateByRelationshipTypeOnly(combined); + return deduplicateByFullPattern(combined); }); setTupleOptions((prev) => [...updatedTuples, ...prev]); } else { diff --git a/frontend/src/utils/Utils.ts b/frontend/src/utils/Utils.ts index b623da9ad..0996932d4 100644 --- a/frontend/src/utils/Utils.ts +++ b/frontend/src/utils/Utils.ts @@ -881,14 +881,12 @@ export const deduplicateNodeByValue = (arrays: { value: any }[]) => { }); return Array.from(map.values()); }; - -export const deduplicateByRelationshipTypeOnly = (arrays: { value: string; label: string }[]) => { +export const deduplicateByFullPattern = (arrays: { value: string; label: string }[]) => { const seen = new Set(); const result: { value: string; label: string }[] = []; arrays.forEach((item) => { - const [, type] = item.value.split(','); - if (!seen.has(type)) { - seen.add(type); + if (!seen.has(item.value)) { + seen.add(item.value); result.push(item); } });