Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions app/routers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B

async def _call_agent(
agent: CompiledStateGraph,
input_data: any,
input_data: any,
config: dict,
websocket: WebSocket,
) -> None:
"""
Streams the agent's response to a WebSocket connection, handling interruptions.

Args:
agent: The compiled LangGraph agent.
input_data: The input data for the agent's run.
Expand All @@ -153,7 +153,7 @@ async def _call_agent(
"""

await websocket.send_text("<message>")

async for stream in agent.astream_events(
input_data,
config=config,
Expand All @@ -164,15 +164,15 @@ async def _call_agent(
continue
if text := _extract_streaming_text(stream):
await websocket.send_text(text)

if stream["event"] == "on_custom_event":
event_data = stream.get("data", "")
# Send custom events as-is (they should already be formatted)
if isinstance(event_data, str):
await websocket.send_text(event_data)
else:
await websocket.send_text(json.dumps(event_data))

if stream["event"] == "on_chain_stream":
if interrupt_value := _extract_interrupt_value(stream):
await websocket.send_text(interrupt_value)
Expand Down
37 changes: 23 additions & 14 deletions app/services/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,52 +38,61 @@ def get_llm() -> BaseLanguageModel:
Selects and returns a language model instance based on environment variables.
- If the active LLM or the model is not configured, it raises a ValueError.
- If LLM mocking is enabled, it configures the connections to the mock server.

- If streaming is disabled, sets disable_streaming="tool_calling" to disable streaming
only for tool calls while keeping it enabled for regular text generation.

Returns:
An instance of a language model.

Raises:
ValueError: If the active LLM or the model is not configured.
"""

activeLlm = get_active_llm()
model = get_llm_model(activeLlm)


# Use "tool_calling" to disable streaming only for tool calls, not regular text
disable_streaming_value = None
if os.environ.get("DISABLE_STREAMING", "false").lower() == "true":
disable_streaming_value = "tool_calling"
logging.info("Streaming disabled for tool calls via DISABLE_STREAMING environment variable")

llm_mock_enabled = os.environ.get("LLM_MOCK_ENABLED", "false").lower() == "true"
llm_mock_url = os.environ.get("LLM_MOCK_URL", "")
if llm_mock_enabled:
logging.info(f"Connecting to LLM Mock server at {llm_mock_url}")

if activeLlm == "ollama":
if llm_mock_enabled:
return ChatOllama(model=model, base_url=llm_mock_url)
return ChatOllama(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming_value)

ollama_url = os.environ.get("OLLAMA_URL")
return ChatOllama(model=model, base_url=ollama_url)
return ChatOllama(model=model, base_url=ollama_url, disable_streaming=disable_streaming_value)
if activeLlm == "gemini":
if llm_mock_enabled:
return ChatGoogleGenerativeAI(
model=model,
base_url=llm_mock_url,
transport="rest"
transport="rest",
disable_streaming=disable_streaming_value
)
if model == "gemini-2.5-flash":
# Disable thinking budget for gemini-2.5-flash to avoid empty responses due to all tokens being used for thinking budget
return ChatGoogleGenerativeAI(model=model, thinking_budget=0)
return ChatGoogleGenerativeAI(model=model)
return ChatGoogleGenerativeAI(model=model, thinking_budget=0, disable_streaming=disable_streaming_value)

return ChatGoogleGenerativeAI(model=model, disable_streaming=disable_streaming_value)
if activeLlm == "openai":
if llm_mock_enabled:
return ChatOpenAI(model=model, base_url=llm_mock_url)
return ChatOpenAI(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming_value)

openai_url = os.environ.get("OPENAI_URL")
if openai_url:
return ChatOpenAI(model=model, base_url=openai_url)
return ChatOpenAI(model=model)
return ChatOpenAI(model=model, base_url=openai_url, disable_streaming=disable_streaming_value)
return ChatOpenAI(model=model, disable_streaming=disable_streaming_value)
if activeLlm == "bedrock":
if llm_mock_enabled:
os.environ["AWS_ENDPOINT_URL"] = llm_mock_url
return ChatBedrockConverse(model=model)
return ChatBedrockConverse(model=model, disable_streaming=disable_streaming_value)

def get_active_llm() -> str:
"""
Expand Down
6 changes: 6 additions & 0 deletions chart/agent/templates/ai-agent-deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ spec:
name: llm-config
key: ACTIVE_LLM
optional: true
- name: DISABLE_STREAMING
valueFrom:
configMapKeyRef:
name: llm-config
key: DISABLE_STREAMING
optional: true
- name: LANGFUSE_PUBLIC_KEY
valueFrom:
secretKeyRef:
Expand Down
1 change: 1 addition & 0 deletions chart/agent/templates/llm-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ data:
OPENAI_MODEL: "{{ .Values.openaiLlmModel }}"
BEDROCK_MODEL: "{{ .Values.bedrockLlmModel }}"
ACTIVE_LLM: "{{ .Values.activeLlm }}"
DISABLE_STREAMING: "{{ .Values.disableStreaming }}"
1 change: 1 addition & 0 deletions chart/agent/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ awsBedrock:
bearerToken:
region:
activeLlm: # ollama, gemini, openai, bedrock
disableStreaming: false # Set to true to disable streaming for tool calls (keeps text streaming enabled)
# Enable RAG with embedded rancher documentation
rag:
enabled: false
Expand Down
68 changes: 60 additions & 8 deletions tests/unit/services/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@
def test_get_llm_ollama(mock_chat_ollama):
with patch.dict(os.environ, {"OLLAMA_MODEL": "test-model", "ACTIVE_LLM": "ollama", "OLLAMA_URL": "http://localhost:11434"}, clear=True):
llm = get_llm()
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434")
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=None)
assert llm == mock_chat_ollama.return_value

@patch('app.services.llm.ChatGoogleGenerativeAI')
def test_get_llm_gemini(mock_chat_gemini):
with patch.dict(os.environ, {"GEMINI_MODEL": "gemini-pro", "ACTIVE_LLM": "gemini", "GOOGLE_API_KEY": "fake-key"}, clear=True):
llm = get_llm()
mock_chat_gemini.assert_called_once_with(model="gemini-pro")
mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=None)
assert llm == mock_chat_gemini.return_value

@patch('app.services.llm.ChatOpenAI')
def test_get_llm_openai(mock_openai):
with patch.dict(os.environ, {"OPENAI_MODEL": "gpt-4", "ACTIVE_LLM": "openai", "OPENAI_API_KEY": "fake-key"}, clear=True):
llm = get_llm()
mock_openai.assert_called_once_with(model="gpt-4")
mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=None)
assert llm == mock_openai.return_value

def test_get_active_llm_not_configured():
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_get_llm_with_mock(mock_chat_ollama):
"LLM_MOCK_URL": "http://mock-server:8000"
}, clear=True):
llm = get_llm()
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000")
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000", disable_streaming=None)
assert llm == mock_chat_ollama.return_value

@patch('app.services.llm.ChatOllama')
Expand All @@ -71,7 +71,7 @@ def test_get_llm_without_mock(mock_chat_ollama):
"LLM_MOCK_URL": "http://mock-server:8000"
}, clear=True):
llm = get_llm()
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434")
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=None)
assert llm == mock_chat_ollama.return_value

@patch('app.services.llm.ChatGoogleGenerativeAI')
Expand All @@ -85,7 +85,7 @@ def test_get_llm_gemini_with_mock(mock_chat_gemini):
"LLM_MOCK_URL": "http://mock-server:8000"
}, clear=True):
llm = get_llm()
mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest")
mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest", disable_streaming=None)
assert llm == mock_chat_gemini.return_value

@patch('app.services.llm.ChatOpenAI')
Expand All @@ -99,7 +99,7 @@ def test_get_llm_openai_with_mock(mock_openai):
"LLM_MOCK_URL": "http://mock-server:8000"
}, clear=True):
llm = get_llm()
mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000")
mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000", disable_streaming=None)
assert llm == mock_openai.return_value

@patch('app.services.llm.ChatOpenAI')
Expand All @@ -112,6 +112,58 @@ def test_get_llm_openai_with_custom_url(mock_openai):
"OPENAI_URL": "http://custom-openai:8000"
}, clear=True):
llm = get_llm()
mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000")
mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000", disable_streaming=None)
assert llm == mock_openai.return_value

@patch('app.services.llm.ChatOllama')
def test_get_llm_with_streaming_disabled(mock_chat_ollama):
"""Test that disable_streaming='tool_calling' is passed when DISABLE_STREAMING=true"""
with patch.dict(os.environ, {
"OLLAMA_MODEL": "test-model",
"ACTIVE_LLM": "ollama",
"OLLAMA_URL": "http://localhost:11434",
"DISABLE_STREAMING": "true"
}, clear=True):
llm = get_llm()
mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming="tool_calling")
assert llm == mock_chat_ollama.return_value

@patch('app.services.llm.ChatGoogleGenerativeAI')
def test_get_llm_gemini_with_streaming_disabled(mock_chat_gemini):
"""Test that disable_streaming='tool_calling' is passed for Gemini when DISABLE_STREAMING=true"""
with patch.dict(os.environ, {
"GEMINI_MODEL": "gemini-pro",
"ACTIVE_LLM": "gemini",
"GOOGLE_API_KEY": "fake-key",
"DISABLE_STREAMING": "true"
}, clear=True):
llm = get_llm()
mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming="tool_calling")
assert llm == mock_chat_gemini.return_value

@patch('app.services.llm.ChatOpenAI')
def test_get_llm_openai_with_streaming_disabled(mock_openai):
"""Test that disable_streaming='tool_calling' is passed for OpenAI when DISABLE_STREAMING=true"""
with patch.dict(os.environ, {
"OPENAI_MODEL": "gpt-4",
"ACTIVE_LLM": "openai",
"OPENAI_API_KEY": "fake-key",
"DISABLE_STREAMING": "true"
}, clear=True):
llm = get_llm()
mock_openai.assert_called_once_with(model="gpt-4", disable_streaming="tool_calling")
assert llm == mock_openai.return_value

@patch('app.services.llm.ChatBedrockConverse')
def test_get_llm_bedrock_with_streaming_disabled(mock_bedrock):
"""Test that disable_streaming='tool_calling' is passed for Bedrock when DISABLE_STREAMING=true"""
with patch.dict(os.environ, {
"BEDROCK_MODEL": "anthropic.claude-3-sonnet-20240229-v1:0",
"ACTIVE_LLM": "bedrock",
"AWS_REGION": "us-east-1",
"DISABLE_STREAMING": "true"
}, clear=True):
llm = get_llm()
mock_bedrock.assert_called_once_with(model="anthropic.claude-3-sonnet-20240229-v1:0", disable_streaming="tool_calling")
assert llm == mock_bedrock.return_value

Loading