Skip to content

Commit ec0dd6e

Browse files
authored
propagate callbacks to ConversationalRetrievalChain (#5572)
# Allow callbacks to monitor ConversationalRetrievalChain <!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> I ran into an issue where load_qa_chain was not passing the callbacks down to the child LLM chains, and so made sure that callbacks are propagated. There are probably more improvements to do here but this seemed like a good place to stop. Note that I saw a lot of references to callbacks_manager, which seems to be deprecated. I left that code alone for now. ## Before submitting <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that do not rely on network access. 2. an example notebook showing its use. See contribution guidelines for more information on how to write tests, lint, etc.: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @agola11 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 -->
1 parent 3294774 commit ec0dd6e

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

langchain/chains/conversational_retrieval/base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from langchain.callbacks.manager import (
1313
AsyncCallbackManagerForChainRun,
1414
CallbackManagerForChainRun,
15+
Callbacks,
1516
)
1617
from langchain.chains.base import Chain
1718
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -204,6 +205,7 @@ def from_llm(
204205
verbose: bool = False,
205206
condense_question_llm: Optional[BaseLanguageModel] = None,
206207
combine_docs_chain_kwargs: Optional[Dict] = None,
208+
callbacks: Callbacks = None,
207209
**kwargs: Any,
208210
) -> BaseConversationalRetrievalChain:
209211
"""Load chain from LLM."""
@@ -212,17 +214,22 @@ def from_llm(
212214
llm,
213215
chain_type=chain_type,
214216
verbose=verbose,
217+
callbacks=callbacks,
215218
**combine_docs_chain_kwargs,
216219
)
217220

218221
_llm = condense_question_llm or llm
219222
condense_question_chain = LLMChain(
220-
llm=_llm, prompt=condense_question_prompt, verbose=verbose
223+
llm=_llm,
224+
prompt=condense_question_prompt,
225+
verbose=verbose,
226+
callbacks=callbacks,
221227
)
222228
return cls(
223229
retriever=retriever,
224230
combine_docs_chain=doc_chain,
225231
question_generator=condense_question_chain,
232+
callbacks=callbacks,
226233
**kwargs,
227234
)
228235

@@ -264,19 +271,24 @@ def from_llm(
264271
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
265272
chain_type: str = "stuff",
266273
combine_docs_chain_kwargs: Optional[Dict] = None,
274+
callbacks: Callbacks = None,
267275
**kwargs: Any,
268276
) -> BaseConversationalRetrievalChain:
269277
"""Load chain from LLM."""
270278
combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
271279
doc_chain = load_qa_chain(
272280
llm,
273281
chain_type=chain_type,
282+
callbacks=callbacks,
274283
**combine_docs_chain_kwargs,
275284
)
276-
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
285+
condense_question_chain = LLMChain(
286+
llm=llm, prompt=condense_question_prompt, callbacks=callbacks
287+
)
277288
return cls(
278289
vectorstore=vectorstore,
279290
combine_docs_chain=doc_chain,
280291
question_generator=condense_question_chain,
292+
callbacks=callbacks,
281293
**kwargs,
282294
)

langchain/chains/question_answering/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from langchain.base_language import BaseLanguageModel
55
from langchain.callbacks.base import BaseCallbackManager
6+
from langchain.callbacks.manager import Callbacks
67
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
78
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
89
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -35,10 +36,15 @@ def _load_map_rerank_chain(
3536
rank_key: str = "score",
3637
answer_key: str = "answer",
3738
callback_manager: Optional[BaseCallbackManager] = None,
39+
callbacks: Callbacks = None,
3840
**kwargs: Any,
3941
) -> MapRerankDocumentsChain:
4042
llm_chain = LLMChain(
41-
llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
43+
llm=llm,
44+
prompt=prompt,
45+
verbose=verbose,
46+
callback_manager=callback_manager,
47+
callbacks=callbacks,
4248
)
4349
return MapRerankDocumentsChain(
4450
llm_chain=llm_chain,
@@ -57,11 +63,16 @@ def _load_stuff_chain(
5763
document_variable_name: str = "context",
5864
verbose: Optional[bool] = None,
5965
callback_manager: Optional[BaseCallbackManager] = None,
66+
callbacks: Callbacks = None,
6067
**kwargs: Any,
6168
) -> StuffDocumentsChain:
6269
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
6370
llm_chain = LLMChain(
64-
llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
71+
llm=llm,
72+
prompt=_prompt,
73+
verbose=verbose,
74+
callback_manager=callback_manager,
75+
callbacks=callbacks,
6576
)
6677
# TODO: document prompt
6778
return StuffDocumentsChain(
@@ -84,6 +95,7 @@ def _load_map_reduce_chain(
8495
collapse_llm: Optional[BaseLanguageModel] = None,
8596
verbose: Optional[bool] = None,
8697
callback_manager: Optional[BaseCallbackManager] = None,
98+
callbacks: Callbacks = None,
8799
**kwargs: Any,
88100
) -> MapReduceDocumentsChain:
89101
_question_prompt = (
@@ -97,20 +109,23 @@ def _load_map_reduce_chain(
97109
prompt=_question_prompt,
98110
verbose=verbose,
99111
callback_manager=callback_manager,
112+
callbacks=callbacks,
100113
)
101114
_reduce_llm = reduce_llm or llm
102115
reduce_chain = LLMChain(
103116
llm=_reduce_llm,
104117
prompt=_combine_prompt,
105118
verbose=verbose,
106119
callback_manager=callback_manager,
120+
callbacks=callbacks,
107121
)
108122
# TODO: document prompt
109123
combine_document_chain = StuffDocumentsChain(
110124
llm_chain=reduce_chain,
111125
document_variable_name=combine_document_variable_name,
112126
verbose=verbose,
113127
callback_manager=callback_manager,
128+
callbacks=callbacks,
114129
)
115130
if collapse_prompt is None:
116131
collapse_chain = None
@@ -127,6 +142,7 @@ def _load_map_reduce_chain(
127142
prompt=collapse_prompt,
128143
verbose=verbose,
129144
callback_manager=callback_manager,
145+
callbacks=callbacks,
130146
),
131147
document_variable_name=combine_document_variable_name,
132148
verbose=verbose,
@@ -139,6 +155,7 @@ def _load_map_reduce_chain(
139155
collapse_document_chain=collapse_chain,
140156
verbose=verbose,
141157
callback_manager=callback_manager,
158+
callbacks=callbacks,
142159
**kwargs,
143160
)
144161

@@ -152,6 +169,7 @@ def _load_refine_chain(
152169
refine_llm: Optional[BaseLanguageModel] = None,
153170
verbose: Optional[bool] = None,
154171
callback_manager: Optional[BaseCallbackManager] = None,
172+
callbacks: Callbacks = None,
155173
**kwargs: Any,
156174
) -> RefineDocumentsChain:
157175
_question_prompt = (
@@ -165,13 +183,15 @@ def _load_refine_chain(
165183
prompt=_question_prompt,
166184
verbose=verbose,
167185
callback_manager=callback_manager,
186+
callbacks=callbacks,
168187
)
169188
_refine_llm = refine_llm or llm
170189
refine_chain = LLMChain(
171190
llm=_refine_llm,
172191
prompt=_refine_prompt,
173192
verbose=verbose,
174193
callback_manager=callback_manager,
194+
callbacks=callbacks,
175195
)
176196
return RefineDocumentsChain(
177197
initial_llm_chain=initial_chain,

0 commit comments

Comments
 (0)