33
44from pydantic import BaseModel
55
6+ from langchain import BasePromptTemplate
67from langchain .chains .llm import LLMChain
78from langchain .graphs .networkx_graph import NetworkxEntityGraph , parse_triples
89from langchain .indexes .prompts .knowledge_triplet_extraction import (
@@ -17,24 +18,28 @@ class GraphIndexCreator(BaseModel):
1718 llm : Optional [BaseLanguageModel ] = None
1819 graph_type : Type [NetworkxEntityGraph ] = NetworkxEntityGraph
1920
20- def from_text (self , text : str ) -> NetworkxEntityGraph :
21+ def from_text (
22+ self , text : str , prompt : BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
23+ ) -> NetworkxEntityGraph :
2124 """Create graph index from text."""
2225 if self .llm is None :
2326 raise ValueError ("llm should not be None" )
2427 graph = self .graph_type ()
25- chain = LLMChain (llm = self .llm , prompt = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT )
28+ chain = LLMChain (llm = self .llm , prompt = prompt )
2629 output = chain .predict (text = text )
2730 knowledge = parse_triples (output )
2831 for triple in knowledge :
2932 graph .add_triple (triple )
3033 return graph
3134
32- async def afrom_text (self , text : str ) -> NetworkxEntityGraph :
35+ async def afrom_text (
36+ self , text : str , prompt : BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
37+ ) -> NetworkxEntityGraph :
3338 """Create graph index from text asynchronously."""
3439 if self .llm is None :
3540 raise ValueError ("llm should not be None" )
3641 graph = self .graph_type ()
37- chain = LLMChain (llm = self .llm , prompt = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT )
42+ chain = LLMChain (llm = self .llm , prompt = prompt )
3843 output = await chain .apredict (text = text )
3944 knowledge = parse_triples (output )
4045 for triple in knowledge :
0 commit comments