Skip to content

Commit b3c49e9

Browse files
authored
Make streamlit import optional (#6510)
1 parent cece8c8 commit b3c49e9

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

langchain/callbacks/streamlit.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Callback Handler that logs to streamlit."""
22
from typing import Any, Dict, List, Optional, Union
33

4-
import streamlit as st
5-
64
from langchain.callbacks.base import BaseCallbackHandler
75
from langchain.schema import AgentAction, AgentFinish, LLMResult
86

@@ -11,16 +9,25 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
119
"""Callback Handler that logs to streamlit."""
1210

1311
def __init__(self) -> None:
12+
try:
13+
import streamlit as st
14+
except ImportError as e:
15+
raise ImportError(
16+
"Could not import streamlit Python package. "
17+
"Please install it with `pip install streamlit`."
18+
) from e
19+
1420
self.tokens_area = st.empty()
1521
self.tokens_stream = ""
22+
self.st = st
1623

1724
def on_llm_start(
1825
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
1926
) -> None:
2027
"""Print out the prompts."""
21-
st.write("Prompts after formatting:")
28+
self.st.write("Prompts after formatting:")
2229
for prompt in prompts:
23-
st.write(prompt)
30+
self.st.write(prompt)
2431

2532
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
2633
"""Run on new LLM token. Only available when streaming is enabled."""
@@ -42,11 +49,11 @@ def on_chain_start(
4249
) -> None:
4350
"""Print out that we are entering a chain."""
4451
class_name = serialized["name"]
45-
st.write(f"Entering new {class_name} chain...")
52+
self.st.write(f"Entering new {class_name} chain...")
4653

4754
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
4855
"""Print out that we finished a chain."""
49-
st.write("Finished chain.")
56+
self.st.write("Finished chain.")
5057

5158
def on_chain_error(
5259
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -66,7 +73,8 @@ def on_tool_start(
6673
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
6774
"""Run on agent action."""
6875
# st.write requires two spaces before a newline to render it
69-
st.markdown(action.log.replace("\n", " \n"))
76+
77+
self.st.markdown(action.log.replace("\n", " \n"))
7078

7179
def on_tool_end(
7280
self,
@@ -76,8 +84,8 @@ def on_tool_end(
7684
**kwargs: Any,
7785
) -> None:
7886
"""If not the final action, print out observation."""
79-
st.write(f"{observation_prefix}{output}")
80-
st.write(llm_prefix)
87+
self.st.write(f"{observation_prefix}{output}")
88+
self.st.write(llm_prefix)
8189

8290
def on_tool_error(
8391
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -88,9 +96,9 @@ def on_tool_error(
8896
def on_text(self, text: str, **kwargs: Any) -> None:
8997
"""Run on text."""
9098
# st.write requires two spaces before a newline to render it
91-
st.write(text.replace("\n", " \n"))
99+
self.st.write(text.replace("\n", " \n"))
92100

93101
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
94102
"""Run on agent end."""
95103
# st.write requires two spaces before a newline to render it
96-
st.write(finish.log.replace("\n", " \n"))
104+
self.st.write(finish.log.replace("\n", " \n"))

0 commit comments

Comments
 (0)