
Commit ad48a78

feat(langgraph): add rag.py.
1 parent 606bb28 commit ad48a78

File tree

1 file changed: +179 −0 lines

langgraph/rag/rag.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
import os
from typing import Annotated, Literal

import bs4  # BeautifulSoup, used to parse HTML
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing_extensions import TypedDict

# LangChain Hub is a centralized platform for uploading, browsing,
# pulling, and managing prompts, helping developers discover and
# share polished prompt templates for various large language
# models (LLMs)
from langchain import hub
from langchain.chat_models import init_chat_model
from langgraph.graph import START, StateGraph

# Prompt for the OpenAI API key so it is never hard-coded
os.environ["OPENAI_API_KEY"] = input("OpenAI API key: ")

# Chat model used for answer generation
llm = init_chat_model("gpt-4o-mini", model_provider="openai")

# Embedding model
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

# Vector store (in-memory, so the index is rebuilt on every run)
vector_store = InMemoryVectorStore(embeddings)

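# ↳ Any LangChain vector store exposes the same interface. A sketch
#   swapping in Chroma for persistence (assumes the `langchain-chroma`
#   package; the collection name and path are hypothetical):
# from langchain_chroma import Chroma
# vector_store = Chroma(
#     collection_name="rag-demo",
#     embedding_function=embeddings,
#     persist_directory="./chroma_db",
# )
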
# Load and chunk the contents of the blog post
loader = WebBaseLoader(  # A document loader for web pages
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )  # Keep only the post body, title, and header
    ),
)
docs = loader.load()

# Split into ~1000-character chunks with 200 characters of overlap
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)

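# ↳ Optional sanity check: confirm the post was actually split
# print(f"Split the post into {len(all_splits)} chunks")
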
# Add documents to the vector store
_ = vector_store.add_documents(documents=all_splits)

# Define the prompt for question-answering
# ↳ You can pull prompts from LangChain Hub with `hub.pull`
prompt = hub.pull("rlm/rag-prompt")

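# ↳ To avoid the network call to the Hub, a roughly equivalent prompt
#   can be defined locally (a sketch; the exact wording of
#   "rlm/rag-prompt" differs slightly):
# from langchain_core.prompts import ChatPromptTemplate
# prompt = ChatPromptTemplate.from_template(
#     "You are an assistant for question-answering tasks. Use the "
#     "following pieces of retrieved context to answer the question. "
#     "If you don't know the answer, just say that you don't know.\n\n"
#     "Question: {question}\n\nContext: {context}\n\nAnswer:"
# )
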
# Define the state for the application
# ↳ The state is typically a Python `TypedDict` but can also
#   be a `pydantic` model
class State(TypedDict):
    question: str
    context: list[Document]
    answer: str

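# ↳ For reference, a sketch of the same state as a pydantic model
#   (assumes pydantic v2; LangGraph then validates inputs at runtime):
# from pydantic import BaseModel
# class State(BaseModel):
#     question: str
#     context: list[Document] = []
#     answer: str = ""
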
# Define application steps
# ↳ The retrieve step looks up the most similar documents in the
#   vector store
def retrieve(state: State):
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}

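# ↳ `similarity_search` returns the top-k most similar chunks (k=4 by
#   default); pass `k` explicitly to retrieve more or fewer, e.g.:
# retrieved_docs = vector_store.similarity_search(state["question"], k=2)
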
# ↳ The generate step stuffs the retrieved context into the prompt
#   and asks the LLM for an answer
def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    # The two following instructions could be rewritten using
    # LCEL (the LangChain Expression Language):
    # ```python
    # chain = prompt | llm
    # response = chain.invoke({"question": state["question"],
    #                          "context": docs_content})
    # ```
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}

# Compile the application
graph_builder = StateGraph(State)

# ↳ `add_sequence` adds a list of nodes to the graph that will be
#   executed sequentially in the order provided
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

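# ↳ A plain synchronous call returns the final state directly:
# result = graph.invoke({"question": "What is Task Decomposition?"})
# print(result["answer"])
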
# Test it out
# ↳ LangGraph supports multiple invocation modes, including sync,
#   async, and streaming. Here we stream individual LLM tokens:
for message, metadata in graph.stream(
    {"question": "What is Task Decomposition?"}, stream_mode="messages"
):
    print(message.content, end="|")

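# ↳ The async equivalent (a sketch; run it inside an event loop,
#   e.g. with `asyncio.run`):
# async for message, metadata in graph.astream(
#     {"question": "What is Task Decomposition?"}, stream_mode="messages"
# ):
#     print(message.content, end="|")
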
# Query analysis (self-querying)
# ↳ Query analysis uses a model to transform or construct optimized
#   search queries from raw user input. For illustration, we add some
#   (contrived) section metadata to the documents in the vector store
#   so we can filter on it later
total_documents = len(all_splits)
third = total_documents // 3

# ↳ Tag each chunk with the third of the post it came from
for i, document in enumerate(all_splits):
    if i < third:
        document.metadata["section"] = "beginning"
    elif i < 2 * third:
        document.metadata["section"] = "middle"
    else:
        document.metadata["section"] = "end"

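# ↳ Optional: inspect one chunk to confirm the tag was applied
# print(all_splits[0].metadata)
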
# ↳ Rebuild the vector store so the index reflects the new metadata
vector_store = InMemoryVectorStore(embeddings)
_ = vector_store.add_documents(all_splits)

# ↳ Define a schema for the search query
# ↳ We use structured output for this (i.e. function calling or JSON
#   mode underneath). A query contains a string query and a document
#   section (either "beginning", "middle", or "end")
class Search(TypedDict):
    """Search query."""

    query: Annotated[str, ..., "Search query to run."]
    section: Annotated[
        Literal["beginning", "middle", "end"],
        ...,
        "Section to query.",
    ]

# ↳ Update the state to carry the structured query
class State(TypedDict):
    question: str
    query: Search
    context: list[Document]
    answer: str

# ↳ Add a new application step that turns the raw question into a
#   structured `Search` query
def analyze_query(state: State):
    structured_llm = llm.with_structured_output(Search)
    query = structured_llm.invoke(state["question"])
    return {"query": query}

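# ↳ Illustrative only (exact output varies by model): for the final
#   test question below, the structured model would produce something
#   like {"query": "Task Decomposition", "section": "end"}
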
# ↳ Update the retrieve step to search only within the requested section
# ↳ `InMemoryVectorStore.similarity_search` accepts a callable filter
#   that is applied to each candidate document
def retrieve(state: State):
    query = state["query"]
    retrieved_docs = vector_store.similarity_search(
        query["query"],
        filter=lambda doc: doc.metadata.get("section") == query["section"],
    )
    return {"context": retrieved_docs}

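# ↳ Note: the callable-style `filter` is specific to
#   `InMemoryVectorStore`; most persistent stores take a metadata dict
#   instead, roughly (a sketch, store-dependent):
# retrieved_docs = vector_store.similarity_search(
#     query["query"], filter={"section": query["section"]}
# )
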
# ↳ Recompile the application
graph_builder = StateGraph(State).add_sequence([analyze_query, retrieve, generate])
graph_builder.add_edge(START, "analyze_query")
graph = graph_builder.compile()

# ↳ Test it out
# ↳ `stream_mode="updates"` emits each node's state update as it runs
for step in graph.stream(
    {"question": "What does the end of the post say about Task Decomposition?"},
    stream_mode="updates",
):
    print(f"{step}\n\n----------------\n")
