diff --git a/app/routers/websocket.py b/app/routers/websocket.py index 0b5f43a5..92f16292 100644 --- a/app/routers/websocket.py +++ b/app/routers/websocket.py @@ -137,13 +137,13 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B async def _call_agent( agent: CompiledStateGraph, - input_data: any, + input_data: any, config: dict, websocket: WebSocket, ) -> None: """ Streams the agent's response to a WebSocket connection, handling interruptions. - + Args: agent: The compiled LangGraph agent. input_data: The input data for the agent's run. @@ -153,7 +153,7 @@ async def _call_agent( """ await websocket.send_text("") - + async for stream in agent.astream_events( input_data, config=config, @@ -164,7 +164,7 @@ async def _call_agent( continue if text := _extract_streaming_text(stream): await websocket.send_text(text) - + if stream["event"] == "on_custom_event": event_data = stream.get("data", "") # Send custom events as-is (they should already be formatted) @@ -172,7 +172,7 @@ async def _call_agent( await websocket.send_text(event_data) else: await websocket.send_text(json.dumps(event_data)) - + if stream["event"] == "on_chain_stream": if interrupt_value := _extract_interrupt_value(stream): await websocket.send_text(interrupt_value) diff --git a/app/services/llm.py b/app/services/llm.py index d27d67ee..b8549155 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -38,17 +38,25 @@ def get_llm() -> BaseLanguageModel: Selects and returns a language model instance based on environment variables. - If the active LLM or the model is not configured, it raises a ValueError. - If LLM mocking is enabled, it configures the connections to the mock server. - + - If streaming is disabled, sets disable_streaming="tool_calling" to disable streaming + only for tool calls while keeping it enabled for regular text generation. + Returns: An instance of a language model. - + Raises: ValueError: If the active LLM or the model is not configured. """ activeLlm = get_active_llm() model = get_llm_model(activeLlm) - + + # Use "tool_calling" to disable streaming only for tool calls, not regular text + disable_streaming_value = None + if os.environ.get("DISABLE_STREAMING", "false").lower() == "true": + disable_streaming_value = "tool_calling" + logging.info("Streaming disabled for tool calls via DISABLE_STREAMING environment variable") + llm_mock_enabled = os.environ.get("LLM_MOCK_ENABLED", "false").lower() == "true" llm_mock_url = os.environ.get("LLM_MOCK_URL", "") if llm_mock_enabled: @@ -56,34 +64,35 @@ def get_llm() -> BaseLanguageModel: if activeLlm == "ollama": if llm_mock_enabled: - return ChatOllama(model=model, base_url=llm_mock_url) + return ChatOllama(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming_value) ollama_url = os.environ.get("OLLAMA_URL") - return ChatOllama(model=model, base_url=ollama_url) + return ChatOllama(model=model, base_url=ollama_url, disable_streaming=disable_streaming_value) if activeLlm == "gemini": if llm_mock_enabled: return ChatGoogleGenerativeAI( model=model, base_url=llm_mock_url, - transport="rest" + transport="rest", + disable_streaming=disable_streaming_value ) if model == "gemini-2.5-flash": # Disable thinking budget for gemini-2.5-flash to avoid empty responses due to all tokens being used for thinking budget - return ChatGoogleGenerativeAI(model=model, thinking_budget=0) - - return ChatGoogleGenerativeAI(model=model) + return ChatGoogleGenerativeAI(model=model, thinking_budget=0, disable_streaming=disable_streaming_value) + + return ChatGoogleGenerativeAI(model=model, disable_streaming=disable_streaming_value) if activeLlm == "openai": if llm_mock_enabled: - return ChatOpenAI(model=model, base_url=llm_mock_url) - + return ChatOpenAI(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming_value) + openai_url = os.environ.get("OPENAI_URL") if openai_url: - return ChatOpenAI(model=model, base_url=openai_url) - return ChatOpenAI(model=model) + return ChatOpenAI(model=model, base_url=openai_url, disable_streaming=disable_streaming_value) + return ChatOpenAI(model=model, disable_streaming=disable_streaming_value) if activeLlm == "bedrock": if llm_mock_enabled: os.environ["AWS_ENDPOINT_URL"] = llm_mock_url - return ChatBedrockConverse(model=model) + return ChatBedrockConverse(model=model, disable_streaming=disable_streaming_value) def get_active_llm() -> str: """ diff --git a/chart/agent/templates/ai-agent-deployment.yaml b/chart/agent/templates/ai-agent-deployment.yaml index 152659ae..cd8de748 100644 --- a/chart/agent/templates/ai-agent-deployment.yaml +++ b/chart/agent/templates/ai-agent-deployment.yaml @@ -109,6 +109,12 @@ spec: name: llm-config key: ACTIVE_LLM optional: true + - name: DISABLE_STREAMING + valueFrom: + configMapKeyRef: + name: llm-config + key: DISABLE_STREAMING + optional: true - name: LANGFUSE_PUBLIC_KEY valueFrom: secretKeyRef: diff --git a/chart/agent/templates/llm-config.yaml b/chart/agent/templates/llm-config.yaml index ef2bb704..68a02952 100644 --- a/chart/agent/templates/llm-config.yaml +++ b/chart/agent/templates/llm-config.yaml @@ -9,3 +9,4 @@ data: OPENAI_MODEL: "{{ .Values.openaiLlmModel }}" BEDROCK_MODEL: "{{ .Values.bedrockLlmModel }}" ACTIVE_LLM: "{{ .Values.activeLlm }}" + DISABLE_STREAMING: "{{ .Values.disableStreaming }}" diff --git a/chart/agent/values.yaml b/chart/agent/values.yaml index 461b1300..ecaf49ce 100644 --- a/chart/agent/values.yaml +++ b/chart/agent/values.yaml @@ -26,6 +26,7 @@ awsBedrock: bearerToken: region: activeLlm: # ollama, gemini, openai, bedrock +disableStreaming: false # Set to true to disable streaming for tool calls (keeps text streaming enabled) # Enable RAG with embedded rancher documentation rag: enabled: false diff --git a/tests/unit/services/test_llm.py b/tests/unit/services/test_llm.py index 2f6a9ff4..e821fa53 100644 --- a/tests/unit/services/test_llm.py +++ b/tests/unit/services/test_llm.py @@ -13,21 +13,21 @@ def test_get_llm_ollama(mock_chat_ollama): with patch.dict(os.environ, {"OLLAMA_MODEL": "test-model", "ACTIVE_LLM": "ollama", "OLLAMA_URL": "http://localhost:11434"}, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434") + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=None) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') def test_get_llm_gemini(mock_chat_gemini): with patch.dict(os.environ, {"GEMINI_MODEL": "gemini-pro", "ACTIVE_LLM": "gemini", "GOOGLE_API_KEY": "fake-key"}, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro") + mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=None) assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') def test_get_llm_openai(mock_openai): with patch.dict(os.environ, {"OPENAI_MODEL": "gpt-4", "ACTIVE_LLM": "openai", "OPENAI_API_KEY": "fake-key"}, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4") + mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=None) assert llm == mock_openai.return_value def test_get_active_llm_not_configured(): @@ -57,7 +57,7 @@ def test_get_llm_with_mock(mock_chat_ollama): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000") + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000", disable_streaming=None) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatOllama') @@ -71,7 +71,7 @@ def test_get_llm_without_mock(mock_chat_ollama): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434") + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=None) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') @@ -85,7 +85,7 @@ def test_get_llm_gemini_with_mock(mock_chat_gemini): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest") + mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest", disable_streaming=None) assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') @@ -99,7 +99,7 @@ def test_get_llm_openai_with_mock(mock_openai): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000") + mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000", disable_streaming=None) assert llm == mock_openai.return_value @patch('app.services.llm.ChatOpenAI') @@ -112,6 +112,58 @@ def test_get_llm_openai_with_custom_url(mock_openai): "OPENAI_URL": "http://custom-openai:8000" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000") + mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000", disable_streaming=None) assert llm == mock_openai.return_value +@patch('app.services.llm.ChatOllama') +def test_get_llm_with_streaming_disabled(mock_chat_ollama): + """Test that disable_streaming='tool_calling' is passed when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "OLLAMA_MODEL": "test-model", + "ACTIVE_LLM": "ollama", + "OLLAMA_URL": "http://localhost:11434", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming="tool_calling") + assert llm == mock_chat_ollama.return_value + +@patch('app.services.llm.ChatGoogleGenerativeAI') +def test_get_llm_gemini_with_streaming_disabled(mock_chat_gemini): + """Test that disable_streaming='tool_calling' is passed for Gemini when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "GEMINI_MODEL": "gemini-pro", + "ACTIVE_LLM": "gemini", + "GOOGLE_API_KEY": "fake-key", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming="tool_calling") + assert llm == mock_chat_gemini.return_value + +@patch('app.services.llm.ChatOpenAI') +def test_get_llm_openai_with_streaming_disabled(mock_openai): + """Test that disable_streaming='tool_calling' is passed for OpenAI when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "OPENAI_MODEL": "gpt-4", + "ACTIVE_LLM": "openai", + "OPENAI_API_KEY": "fake-key", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_openai.assert_called_once_with(model="gpt-4", disable_streaming="tool_calling") + assert llm == mock_openai.return_value + +@patch('app.services.llm.ChatBedrockConverse') +def test_get_llm_bedrock_with_streaming_disabled(mock_bedrock): + """Test that disable_streaming='tool_calling' is passed for Bedrock when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "BEDROCK_MODEL": "anthropic.claude-3-sonnet-20240229-v1:0", + "ACTIVE_LLM": "bedrock", + "AWS_REGION": "us-east-1", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_bedrock.assert_called_once_with(model="anthropic.claude-3-sonnet-20240229-v1:0", disable_streaming="tool_calling") + assert llm == mock_bedrock.return_value +