Skip to content

Commit 4d697d3

Browse files
ozabludabaskaryan
andauthored
Allow passing custom prompts to GraphIndexCreator (#7381)
--------- Co-authored-by: Bagatur <[email protected]>
1 parent 612a74e commit 4d697d3

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

langchain/indexes/graph.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pydantic import BaseModel
55

6+
from langchain import BasePromptTemplate
67
from langchain.chains.llm import LLMChain
78
from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
89
from langchain.indexes.prompts.knowledge_triplet_extraction import (
@@ -17,24 +18,28 @@ class GraphIndexCreator(BaseModel):
1718
llm: Optional[BaseLanguageModel] = None
1819
graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
1920

20-
def from_text(self, text: str) -> NetworkxEntityGraph:
21+
def from_text(
22+
self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
23+
) -> NetworkxEntityGraph:
2124
"""Create graph index from text."""
2225
if self.llm is None:
2326
raise ValueError("llm should not be None")
2427
graph = self.graph_type()
25-
chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
28+
chain = LLMChain(llm=self.llm, prompt=prompt)
2629
output = chain.predict(text=text)
2730
knowledge = parse_triples(output)
2831
for triple in knowledge:
2932
graph.add_triple(triple)
3033
return graph
3134

32-
async def afrom_text(self, text: str) -> NetworkxEntityGraph:
35+
async def afrom_text(
36+
self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
37+
) -> NetworkxEntityGraph:
3338
"""Create graph index from text asynchronously."""
3439
if self.llm is None:
3540
raise ValueError("llm should not be None")
3641
graph = self.graph_type()
37-
chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)
42+
chain = LLMChain(llm=self.llm, prompt=prompt)
3843
output = await chain.apredict(text=text)
3944
knowledge = parse_triples(output)
4045
for triple in knowledge:

0 commit comments

Comments
 (0)