From ca4e28305ba906df119aefbd684a64829352f5ca Mon Sep 17 00:00:00 2001 From: Samuel Vasconcelos Date: Thu, 18 Jun 2026 14:56:59 -0300 Subject: [PATCH 1/3] feat: add configurable streaming mode for LLM requests Signed-off-by: Samuel Vasconcelos --- app/routers/websocket.py | 66 +++++++++++-------- app/services/llm.py | 34 ++++++---- .../agent/templates/ai-agent-deployment.yaml | 6 ++ chart/agent/templates/llm-config.yaml | 1 + chart/agent/values.yaml | 1 + 5 files changed, 67 insertions(+), 41 deletions(-) diff --git a/app/routers/websocket.py b/app/routers/websocket.py index 0b5f43a5..0c3ed30a 100644 --- a/app/routers/websocket.py +++ b/app/routers/websocket.py @@ -137,45 +137,57 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B async def _call_agent( agent: CompiledStateGraph, - input_data: any, + input_data: any, config: dict, websocket: WebSocket, ) -> None: """ - Streams the agent's response to a WebSocket connection, handling interruptions. - + Invokes the agent and sends the response to a WebSocket connection, handling interruptions. + Supports both streaming and non-streaming modes based on DISABLE_STREAMING env var. + Args: agent: The compiled LangGraph agent. input_data: The input data for the agent's run. config: The run configuration. websocket: The WebSocket connection. - stream_mode: The types of events to stream from the agent. """ + disable_streaming = os.environ.get("DISABLE_STREAMING", "false").lower() == "true" await websocket.send_text("") - - async for stream in agent.astream_events( - input_data, - config=config, - stream_mode=["updates", "messages", "custom", "events"], - ): - if stream["event"] == "on_chat_model_stream": - if not _should_stream_text(stream): - continue - if text := _extract_streaming_text(stream): - await websocket.send_text(text) - - if stream["event"] == "on_custom_event": - event_data = stream.get("data", "") - # Send custom events as-is (they should already be formatted) - if isinstance(event_data, str): - await websocket.send_text(event_data) - else: - await websocket.send_text(json.dumps(event_data)) - - if stream["event"] == "on_chain_stream": - if interrupt_value := _extract_interrupt_value(stream): - await websocket.send_text(interrupt_value) + + if disable_streaming: + # Non-streaming mode: invoke and send complete response + result = await agent.ainvoke(input_data, config=config) + + # Extract and send the final message content + if "messages" in result and len(result["messages"]) > 0: + last_message = result["messages"][-1] + if hasattr(last_message, "content") and last_message.content: + await websocket.send_text(last_message.content) + else: + # Streaming mode: stream events as they arrive + async for stream in agent.astream_events( + input_data, + config=config, + stream_mode=["updates", "messages", "custom", "events"], + ): + if stream["event"] == "on_chat_model_stream": + if not _should_stream_text(stream): + continue + if text := _extract_streaming_text(stream): + await websocket.send_text(text) + + if stream["event"] == "on_custom_event": + event_data = stream.get("data", "") + # Send custom events as-is (they should already be formatted) + if isinstance(event_data, str): + await websocket.send_text(event_data) + else: + await websocket.send_text(json.dumps(event_data)) + + if stream["event"] == "on_chain_stream": + if interrupt_value := _extract_interrupt_value(stream): + await websocket.send_text(interrupt_value) def _should_stream_text(stream: dict) -> bool: """ diff --git a/app/services/llm.py b/app/services/llm.py index d27d67ee..35f93f9e 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -38,17 +38,22 @@ def get_llm() -> BaseLanguageModel: Selects and returns a language model instance based on environment variables. - If the active LLM or the model is not configured, it raises a ValueError. - If LLM mocking is enabled, it configures the connections to the mock server. - + - If streaming is disabled, sets disable_streaming=True on the model instance. + Returns: An instance of a language model. - + Raises: ValueError: If the active LLM or the model is not configured. """ activeLlm = get_active_llm() model = get_llm_model(activeLlm) - + + disable_streaming = os.environ.get("DISABLE_STREAMING", "false").lower() == "true" + if disable_streaming: + logging.info("Streaming is disabled via DISABLE_STREAMING environment variable") + llm_mock_enabled = os.environ.get("LLM_MOCK_ENABLED", "false").lower() == "true" llm_mock_url = os.environ.get("LLM_MOCK_URL", "") if llm_mock_enabled: @@ -56,34 +61,35 @@ def get_llm() -> BaseLanguageModel: if activeLlm == "ollama": if llm_mock_enabled: - return ChatOllama(model=model, base_url=llm_mock_url) + return ChatOllama(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming) ollama_url = os.environ.get("OLLAMA_URL") - return ChatOllama(model=model, base_url=ollama_url) + return ChatOllama(model=model, base_url=ollama_url, disable_streaming=disable_streaming) if activeLlm == "gemini": if llm_mock_enabled: return ChatGoogleGenerativeAI( model=model, base_url=llm_mock_url, - transport="rest" + transport="rest", + disable_streaming=disable_streaming ) if model == "gemini-2.5-flash": # Disable thinking budget for gemini-2.5-flash to avoid empty responses due to all tokens being used for thinking budget - return ChatGoogleGenerativeAI(model=model, thinking_budget=0) - - return ChatGoogleGenerativeAI(model=model) + return ChatGoogleGenerativeAI(model=model, thinking_budget=0, disable_streaming=disable_streaming) + + return ChatGoogleGenerativeAI(model=model, disable_streaming=disable_streaming) if activeLlm == "openai": if llm_mock_enabled: - return ChatOpenAI(model=model, base_url=llm_mock_url) - + return ChatOpenAI(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming) + openai_url = os.environ.get("OPENAI_URL") if openai_url: - return ChatOpenAI(model=model, base_url=openai_url) - return ChatOpenAI(model=model) + return ChatOpenAI(model=model, base_url=openai_url, disable_streaming=disable_streaming) + return ChatOpenAI(model=model, disable_streaming=disable_streaming) if activeLlm == "bedrock": if llm_mock_enabled: os.environ["AWS_ENDPOINT_URL"] = llm_mock_url - return ChatBedrockConverse(model=model) + return ChatBedrockConverse(model=model, disable_streaming=disable_streaming) def get_active_llm() -> str: """ diff --git a/chart/agent/templates/ai-agent-deployment.yaml b/chart/agent/templates/ai-agent-deployment.yaml index 152659ae..cd8de748 100644 --- a/chart/agent/templates/ai-agent-deployment.yaml +++ b/chart/agent/templates/ai-agent-deployment.yaml @@ -109,6 +109,12 @@ spec: name: llm-config key: ACTIVE_LLM optional: true + - name: DISABLE_STREAMING + valueFrom: + configMapKeyRef: + name: llm-config + key: DISABLE_STREAMING + optional: true - name: LANGFUSE_PUBLIC_KEY valueFrom: secretKeyRef: diff --git a/chart/agent/templates/llm-config.yaml b/chart/agent/templates/llm-config.yaml index ef2bb704..68a02952 100644 --- a/chart/agent/templates/llm-config.yaml +++ b/chart/agent/templates/llm-config.yaml @@ -9,3 +9,4 @@ data: OPENAI_MODEL: "{{ .Values.openaiLlmModel }}" BEDROCK_MODEL: "{{ .Values.bedrockLlmModel }}" ACTIVE_LLM: "{{ .Values.activeLlm }}" + DISABLE_STREAMING: "{{ .Values.disableStreaming }}" diff --git a/chart/agent/values.yaml b/chart/agent/values.yaml index 461b1300..f8c3d0c5 100644 --- a/chart/agent/values.yaml +++ b/chart/agent/values.yaml @@ -26,6 +26,7 @@ awsBedrock: bearerToken: region: activeLlm: # ollama, gemini, openai, bedrock +disableStreaming: false # Set to true to disable streaming mode for LLM responses # Enable RAG with embedded rancher documentation rag: enabled: false From 24239cf45dda9d97264899fe12c4f24c458005aa Mon Sep 17 00:00:00 2001 From: Samuel Vasconcelos Date: Thu, 18 Jun 2026 15:13:32 -0300 Subject: [PATCH 2/3] fix tests Signed-off-by: Samuel Vasconcelos --- tests/unit/services/test_llm.py | 68 +++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 8 deletions(-) diff --git a/tests/unit/services/test_llm.py b/tests/unit/services/test_llm.py index 2f6a9ff4..b3078f13 100644 --- a/tests/unit/services/test_llm.py +++ b/tests/unit/services/test_llm.py @@ -13,21 +13,21 @@ def test_get_llm_ollama(mock_chat_ollama): with patch.dict(os.environ, {"OLLAMA_MODEL": "test-model", "ACTIVE_LLM": "ollama", "OLLAMA_URL": "http://localhost:11434"}, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434") + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=False) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') def test_get_llm_gemini(mock_chat_gemini): with patch.dict(os.environ, {"GEMINI_MODEL": "gemini-pro", "ACTIVE_LLM": "gemini", "GOOGLE_API_KEY": "fake-key"}, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro") + mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=False) assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') def test_get_llm_openai(mock_openai): with patch.dict(os.environ, {"OPENAI_MODEL": "gpt-4", "ACTIVE_LLM": "openai", "OPENAI_API_KEY": "fake-key"}, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4") + mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=False) assert llm == mock_openai.return_value def test_get_active_llm_not_configured(): @@ -57,7 +57,7 @@ def test_get_llm_with_mock(mock_chat_ollama): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000") + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000", disable_streaming=False) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatOllama') @@ -71,7 +71,7 @@ def test_get_llm_without_mock(mock_chat_ollama): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434") + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=False) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') @@ -85,7 +85,7 @@ def test_get_llm_gemini_with_mock(mock_chat_gemini): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest") + mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest", disable_streaming=False) assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') @@ -99,7 +99,7 @@ def test_get_llm_openai_with_mock(mock_openai): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000") + mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000", disable_streaming=False) assert llm == mock_openai.return_value @patch('app.services.llm.ChatOpenAI') @@ -112,6 +112,58 @@ def test_get_llm_openai_with_custom_url(mock_openai): "OPENAI_URL": "http://custom-openai:8000" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000") + mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000", disable_streaming=False) assert llm == mock_openai.return_value +@patch('app.services.llm.ChatOllama') +def test_get_llm_with_streaming_disabled(mock_chat_ollama): + """Test that disable_streaming=True is passed when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "OLLAMA_MODEL": "test-model", + "ACTIVE_LLM": "ollama", + "OLLAMA_URL": "http://localhost:11434", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=True) + assert llm == mock_chat_ollama.return_value + +@patch('app.services.llm.ChatGoogleGenerativeAI') +def test_get_llm_gemini_with_streaming_disabled(mock_chat_gemini): + """Test that disable_streaming=True is passed for Gemini when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "GEMINI_MODEL": "gemini-pro", + "ACTIVE_LLM": "gemini", + "GOOGLE_API_KEY": "fake-key", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=True) + assert llm == mock_chat_gemini.return_value + +@patch('app.services.llm.ChatOpenAI') +def test_get_llm_openai_with_streaming_disabled(mock_openai): + """Test that disable_streaming=True is passed for OpenAI when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "OPENAI_MODEL": "gpt-4", + "ACTIVE_LLM": "openai", + "OPENAI_API_KEY": "fake-key", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=True) + assert llm == mock_openai.return_value + +@patch('app.services.llm.ChatBedrockConverse') +def test_get_llm_bedrock_with_streaming_disabled(mock_bedrock): + """Test that disable_streaming=True is passed for Bedrock when DISABLE_STREAMING=true""" + with patch.dict(os.environ, { + "BEDROCK_MODEL": "anthropic.claude-3-sonnet-20240229-v1:0", + "ACTIVE_LLM": "bedrock", + "AWS_REGION": "us-east-1", + "DISABLE_STREAMING": "true" + }, clear=True): + llm = get_llm() + mock_bedrock.assert_called_once_with(model="anthropic.claude-3-sonnet-20240229-v1:0", disable_streaming=True) + assert llm == mock_bedrock.return_value + From 6cec3a56cd45c51bbe9b2be9c17be69dd8298e1a Mon Sep 17 00:00:00 2001 From: Samuel Vasconcelos Date: Fri, 19 Jun 2026 14:25:32 -0300 Subject: [PATCH 3/3] disable streaming only for tool calls --- app/routers/websocket.py | 58 +++++++++++++-------------------- app/services/llm.py | 29 +++++++++-------- chart/agent/values.yaml | 2 +- tests/unit/services/test_llm.py | 32 +++++++++--------- 4 files changed, 56 insertions(+), 65 deletions(-) diff --git a/app/routers/websocket.py b/app/routers/websocket.py index 0c3ed30a..92f16292 100644 --- a/app/routers/websocket.py +++ b/app/routers/websocket.py @@ -142,52 +142,40 @@ async def _call_agent( websocket: WebSocket, ) -> None: """ - Invokes the agent and sends the response to a WebSocket connection, handling interruptions. - Supports both streaming and non-streaming modes based on DISABLE_STREAMING env var. + Streams the agent's response to a WebSocket connection, handling interruptions. Args: agent: The compiled LangGraph agent. input_data: The input data for the agent's run. config: The run configuration. websocket: The WebSocket connection. + stream_mode: The types of events to stream from the agent. """ - disable_streaming = os.environ.get("DISABLE_STREAMING", "false").lower() == "true" await websocket.send_text("") - if disable_streaming: - # Non-streaming mode: invoke and send complete response - result = await agent.ainvoke(input_data, config=config) + async for stream in agent.astream_events( + input_data, + config=config, + stream_mode=["updates", "messages", "custom", "events"], + ): + if stream["event"] == "on_chat_model_stream": + if not _should_stream_text(stream): + continue + if text := _extract_streaming_text(stream): + await websocket.send_text(text) + + if stream["event"] == "on_custom_event": + event_data = stream.get("data", "") + # Send custom events as-is (they should already be formatted) + if isinstance(event_data, str): + await websocket.send_text(event_data) + else: + await websocket.send_text(json.dumps(event_data)) - # Extract and send the final message content - if "messages" in result and len(result["messages"]) > 0: - last_message = result["messages"][-1] - if hasattr(last_message, "content") and last_message.content: - await websocket.send_text(last_message.content) - else: - # Streaming mode: stream events as they arrive - async for stream in agent.astream_events( - input_data, - config=config, - stream_mode=["updates", "messages", "custom", "events"], - ): - if stream["event"] == "on_chat_model_stream": - if not _should_stream_text(stream): - continue - if text := _extract_streaming_text(stream): - await websocket.send_text(text) - - if stream["event"] == "on_custom_event": - event_data = stream.get("data", "") - # Send custom events as-is (they should already be formatted) - if isinstance(event_data, str): - await websocket.send_text(event_data) - else: - await websocket.send_text(json.dumps(event_data)) - - if stream["event"] == "on_chain_stream": - if interrupt_value := _extract_interrupt_value(stream): - await websocket.send_text(interrupt_value) + if stream["event"] == "on_chain_stream": + if interrupt_value := _extract_interrupt_value(stream): + await websocket.send_text(interrupt_value) def _should_stream_text(stream: dict) -> bool: """ diff --git a/app/services/llm.py b/app/services/llm.py index 35f93f9e..b8549155 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -38,7 +38,8 @@ def get_llm() -> BaseLanguageModel: Selects and returns a language model instance based on environment variables. - If the active LLM or the model is not configured, it raises a ValueError. - If LLM mocking is enabled, it configures the connections to the mock server. - - If streaming is disabled, sets disable_streaming=True on the model instance. + - If streaming is disabled, sets disable_streaming="tool_calling" to disable streaming + only for tool calls while keeping it enabled for regular text generation. Returns: An instance of a language model. @@ -50,9 +51,11 @@ def get_llm() -> BaseLanguageModel: activeLlm = get_active_llm() model = get_llm_model(activeLlm) - disable_streaming = os.environ.get("DISABLE_STREAMING", "false").lower() == "true" - if disable_streaming: - logging.info("Streaming is disabled via DISABLE_STREAMING environment variable") + # Use "tool_calling" to disable streaming only for tool calls, not regular text + disable_streaming_value = None + if os.environ.get("DISABLE_STREAMING", "false").lower() == "true": + disable_streaming_value = "tool_calling" + logging.info("Streaming disabled for tool calls via DISABLE_STREAMING environment variable") llm_mock_enabled = os.environ.get("LLM_MOCK_ENABLED", "false").lower() == "true" llm_mock_url = os.environ.get("LLM_MOCK_URL", "") @@ -61,35 +64,35 @@ def get_llm() -> BaseLanguageModel: if activeLlm == "ollama": if llm_mock_enabled: - return ChatOllama(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming) + return ChatOllama(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming_value) ollama_url = os.environ.get("OLLAMA_URL") - return ChatOllama(model=model, base_url=ollama_url, disable_streaming=disable_streaming) + return ChatOllama(model=model, base_url=ollama_url, disable_streaming=disable_streaming_value) if activeLlm == "gemini": if llm_mock_enabled: return ChatGoogleGenerativeAI( model=model, base_url=llm_mock_url, transport="rest", - disable_streaming=disable_streaming + disable_streaming=disable_streaming_value ) if model == "gemini-2.5-flash": # Disable thinking budget for gemini-2.5-flash to avoid empty responses due to all tokens being used for thinking budget - return ChatGoogleGenerativeAI(model=model, thinking_budget=0, disable_streaming=disable_streaming) + return ChatGoogleGenerativeAI(model=model, thinking_budget=0, disable_streaming=disable_streaming_value) - return ChatGoogleGenerativeAI(model=model, disable_streaming=disable_streaming) + return ChatGoogleGenerativeAI(model=model, disable_streaming=disable_streaming_value) if activeLlm == "openai": if llm_mock_enabled: - return ChatOpenAI(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming) + return ChatOpenAI(model=model, base_url=llm_mock_url, disable_streaming=disable_streaming_value) openai_url = os.environ.get("OPENAI_URL") if openai_url: - return ChatOpenAI(model=model, base_url=openai_url, disable_streaming=disable_streaming) - return ChatOpenAI(model=model, disable_streaming=disable_streaming) + return ChatOpenAI(model=model, base_url=openai_url, disable_streaming=disable_streaming_value) + return ChatOpenAI(model=model, disable_streaming=disable_streaming_value) if activeLlm == "bedrock": if llm_mock_enabled: os.environ["AWS_ENDPOINT_URL"] = llm_mock_url - return ChatBedrockConverse(model=model, disable_streaming=disable_streaming) + return ChatBedrockConverse(model=model, disable_streaming=disable_streaming_value) def get_active_llm() -> str: """ diff --git a/chart/agent/values.yaml b/chart/agent/values.yaml index f8c3d0c5..ecaf49ce 100644 --- a/chart/agent/values.yaml +++ b/chart/agent/values.yaml @@ -26,7 +26,7 @@ awsBedrock: bearerToken: region: activeLlm: # ollama, gemini, openai, bedrock -disableStreaming: false # Set to true to disable streaming mode for LLM responses +disableStreaming: false # Set to true to disable streaming for tool calls (keeps text streaming enabled) # Enable RAG with embedded rancher documentation rag: enabled: false diff --git a/tests/unit/services/test_llm.py b/tests/unit/services/test_llm.py index b3078f13..e821fa53 100644 --- a/tests/unit/services/test_llm.py +++ b/tests/unit/services/test_llm.py @@ -13,21 +13,21 @@ def test_get_llm_ollama(mock_chat_ollama): with patch.dict(os.environ, {"OLLAMA_MODEL": "test-model", "ACTIVE_LLM": "ollama", "OLLAMA_URL": "http://localhost:11434"}, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=False) + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=None) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') def test_get_llm_gemini(mock_chat_gemini): with patch.dict(os.environ, {"GEMINI_MODEL": "gemini-pro", "ACTIVE_LLM": "gemini", "GOOGLE_API_KEY": "fake-key"}, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=False) + mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=None) assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') def test_get_llm_openai(mock_openai): with patch.dict(os.environ, {"OPENAI_MODEL": "gpt-4", "ACTIVE_LLM": "openai", "OPENAI_API_KEY": "fake-key"}, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=False) + mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=None) assert llm == mock_openai.return_value def test_get_active_llm_not_configured(): @@ -57,7 +57,7 @@ def test_get_llm_with_mock(mock_chat_ollama): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000", disable_streaming=False) + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://mock-server:8000", disable_streaming=None) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatOllama') @@ -71,7 +71,7 @@ def test_get_llm_without_mock(mock_chat_ollama): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=False) + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=None) assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') @@ -85,7 +85,7 @@ def test_get_llm_gemini_with_mock(mock_chat_gemini): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest", disable_streaming=False) + mock_chat_gemini.assert_called_once_with(model="gemini-pro", base_url="http://mock-server:8000", transport="rest", disable_streaming=None) assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') @@ -99,7 +99,7 @@ def test_get_llm_openai_with_mock(mock_openai): "LLM_MOCK_URL": "http://mock-server:8000" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000", disable_streaming=False) + mock_openai.assert_called_once_with(model="gpt-4", base_url="http://mock-server:8000", disable_streaming=None) assert llm == mock_openai.return_value @patch('app.services.llm.ChatOpenAI') @@ -112,12 +112,12 @@ def test_get_llm_openai_with_custom_url(mock_openai): "OPENAI_URL": "http://custom-openai:8000" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000", disable_streaming=False) + mock_openai.assert_called_once_with(model="gpt-4", base_url="http://custom-openai:8000", disable_streaming=None) assert llm == mock_openai.return_value @patch('app.services.llm.ChatOllama') def test_get_llm_with_streaming_disabled(mock_chat_ollama): - """Test that disable_streaming=True is passed when DISABLE_STREAMING=true""" + """Test that disable_streaming='tool_calling' is passed when DISABLE_STREAMING=true""" with patch.dict(os.environ, { "OLLAMA_MODEL": "test-model", "ACTIVE_LLM": "ollama", @@ -125,12 +125,12 @@ def test_get_llm_with_streaming_disabled(mock_chat_ollama): "DISABLE_STREAMING": "true" }, clear=True): llm = get_llm() - mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming=True) + mock_chat_ollama.assert_called_once_with(model="test-model", base_url="http://localhost:11434", disable_streaming="tool_calling") assert llm == mock_chat_ollama.return_value @patch('app.services.llm.ChatGoogleGenerativeAI') def test_get_llm_gemini_with_streaming_disabled(mock_chat_gemini): - """Test that disable_streaming=True is passed for Gemini when DISABLE_STREAMING=true""" + """Test that disable_streaming='tool_calling' is passed for Gemini when DISABLE_STREAMING=true""" with patch.dict(os.environ, { "GEMINI_MODEL": "gemini-pro", "ACTIVE_LLM": "gemini", @@ -138,12 +138,12 @@ def test_get_llm_gemini_with_streaming_disabled(mock_chat_gemini): "DISABLE_STREAMING": "true" }, clear=True): llm = get_llm() - mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming=True) + mock_chat_gemini.assert_called_once_with(model="gemini-pro", disable_streaming="tool_calling") assert llm == mock_chat_gemini.return_value @patch('app.services.llm.ChatOpenAI') def test_get_llm_openai_with_streaming_disabled(mock_openai): - """Test that disable_streaming=True is passed for OpenAI when DISABLE_STREAMING=true""" + """Test that disable_streaming='tool_calling' is passed for OpenAI when DISABLE_STREAMING=true""" with patch.dict(os.environ, { "OPENAI_MODEL": "gpt-4", "ACTIVE_LLM": "openai", @@ -151,12 +151,12 @@ def test_get_llm_openai_with_streaming_disabled(mock_openai): "DISABLE_STREAMING": "true" }, clear=True): llm = get_llm() - mock_openai.assert_called_once_with(model="gpt-4", disable_streaming=True) + mock_openai.assert_called_once_with(model="gpt-4", disable_streaming="tool_calling") assert llm == mock_openai.return_value @patch('app.services.llm.ChatBedrockConverse') def test_get_llm_bedrock_with_streaming_disabled(mock_bedrock): - """Test that disable_streaming=True is passed for Bedrock when DISABLE_STREAMING=true""" + """Test that disable_streaming='tool_calling' is passed for Bedrock when DISABLE_STREAMING=true""" with patch.dict(os.environ, { "BEDROCK_MODEL": "anthropic.claude-3-sonnet-20240229-v1:0", "ACTIVE_LLM": "bedrock", @@ -164,6 +164,6 @@ def test_get_llm_bedrock_with_streaming_disabled(mock_bedrock): "DISABLE_STREAMING": "true" }, clear=True): llm = get_llm() - mock_bedrock.assert_called_once_with(model="anthropic.claude-3-sonnet-20240229-v1:0", disable_streaming=True) + mock_bedrock.assert_called_once_with(model="anthropic.claude-3-sonnet-20240229-v1:0", disable_streaming="tool_calling") assert llm == mock_bedrock.return_value