11"""Callback Handler that logs to streamlit."""
22from typing import Any , Dict , List , Optional , Union
33
4- import streamlit as st
5-
64from langchain .callbacks .base import BaseCallbackHandler
75from langchain .schema import AgentAction , AgentFinish , LLMResult
86
@@ -11,16 +9,25 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
119 """Callback Handler that logs to streamlit."""
1210
1311 def __init__ (self ) -> None :
12+ try :
13+ import streamlit as st
14+ except ImportError as e :
15+ raise ImportError (
16+ "Could not import streamlit Python package. "
17+ "Please install it with `pip install streamlit`."
18+ ) from e
19+
1420 self .tokens_area = st .empty ()
1521 self .tokens_stream = ""
22+ self .st = st
1623
1724 def on_llm_start (
1825 self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any
1926 ) -> None :
2027 """Print out the prompts."""
21- st .write ("Prompts after formatting:" )
28+ self . st .write ("Prompts after formatting:" )
2229 for prompt in prompts :
23- st .write (prompt )
30+ self . st .write (prompt )
2431
2532 def on_llm_new_token (self , token : str , ** kwargs : Any ) -> None :
2633 """Run on new LLM token. Only available when streaming is enabled."""
@@ -42,11 +49,11 @@ def on_chain_start(
4249 ) -> None :
4350 """Print out that we are entering a chain."""
4451 class_name = serialized ["name" ]
45- st .write (f"Entering new { class_name } chain..." )
52+ self . st .write (f"Entering new { class_name } chain..." )
4653
4754 def on_chain_end (self , outputs : Dict [str , Any ], ** kwargs : Any ) -> None :
4855 """Print out that we finished a chain."""
49- st .write ("Finished chain." )
56+ self . st .write ("Finished chain." )
5057
5158 def on_chain_error (
5259 self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
@@ -66,7 +73,8 @@ def on_tool_start(
6673 def on_agent_action (self , action : AgentAction , ** kwargs : Any ) -> Any :
6774 """Run on agent action."""
6875 # st.write requires two spaces before a newline to render it
69- st .markdown (action .log .replace ("\n " , " \n " ))
76+
77+ self .st .markdown (action .log .replace ("\n " , " \n " ))
7078
7179 def on_tool_end (
7280 self ,
@@ -76,8 +84,8 @@ def on_tool_end(
7684 ** kwargs : Any ,
7785 ) -> None :
7886 """If not the final action, print out observation."""
79- st .write (f"{ observation_prefix } { output } " )
80- st .write (llm_prefix )
87+ self . st .write (f"{ observation_prefix } { output } " )
88+ self . st .write (llm_prefix )
8189
8290 def on_tool_error (
8391 self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
@@ -88,9 +96,9 @@ def on_tool_error(
8896 def on_text (self , text : str , ** kwargs : Any ) -> None :
8997 """Run on text."""
9098 # st.write requires two spaces before a newline to render it
91- st .write (text .replace ("\n " , " \n " ))
99+ self . st .write (text .replace ("\n " , " \n " ))
92100
93101 def on_agent_finish (self , finish : AgentFinish , ** kwargs : Any ) -> None :
94102 """Run on agent end."""
95103 # st.write requires two spaces before a newline to render it
96- st .write (finish .log .replace ("\n " , " \n " ))
104+ self . st .write (finish .log .replace ("\n " , " \n " ))
0 commit comments