fix: agent generate config err #1305

Merged · 6 commits · Jun 16, 2025

65 changes: 54 additions & 11 deletions src/google/adk/models/lite_llm.py
@@ -23,6 +23,7 @@
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
@@ -481,16 +482,22 @@ def _message_to_generate_content_response(

def _get_completion_inputs(
llm_request: LlmRequest,
) -> tuple[Iterable[Message], Iterable[dict]]:
"""Converts an LlmRequest to litellm inputs.
) -> Tuple[
List[Message],
Optional[List[Dict]],
Optional[types.SchemaUnion],
Optional[Dict],
]:
"""Converts an LlmRequest to litellm inputs and extracts generation params.

Args:
llm_request: The LlmRequest to convert.

Returns:
The litellm inputs (message list, tool dictionary and response format).
The litellm inputs (message list, list of tool dicts, response format, and generation params).
"""
messages = []
# 1. Construct messages
messages: List[Message] = []
for content in llm_request.contents or []:
message_param_or_list = _content_to_message_param(content)
if isinstance(message_param_or_list, list):
@@ -507,7 +514,8 @@ def _get_completion_inputs(
),
)

tools = None
# 2. Convert tool declarations
tools: Optional[List[Dict]] = None
if (
llm_request.config
and llm_request.config.tools
@@ -518,12 +526,39 @@
for tool in llm_request.config.tools[0].function_declarations
]

response_format = None
# 3. Handle response format
response_format: Optional[types.SchemaUnion] = (
llm_request.config.response_schema if llm_request.config else None
)

# 4. Extract generation parameters
generation_params: Optional[Dict] = None
if llm_request.config:
config_dict = llm_request.config.model_dump(exclude_none=True)
# Map generation config fields to LiteLLM completion parameters,
# following https://docs.litellm.ai/docs/completion/input.
generation_params = {}
param_mapping = {
"max_output_tokens": "max_completion_tokens",
"stop_sequences": "stop",
}
for key in (
"temperature",
"max_output_tokens",
"top_p",
"top_k",
"stop_sequences",
"presence_penalty",
"frequency_penalty",
):
if key in config_dict:
mapped_key = param_mapping.get(key, key)
generation_params[mapped_key] = config_dict[key]

if llm_request.config.response_schema:
response_format = llm_request.config.response_schema
if not generation_params:
generation_params = None

return messages, tools, response_format
return messages, tools, response_format, generation_params


def _build_function_declaration_log(
@@ -660,15 +695,23 @@ async def generate_content_async(
self._maybe_append_user_content(llm_request)
logger.debug(_build_request_log(llm_request))

messages, tools, response_format = _get_completion_inputs(llm_request)
messages, tools, response_format, generation_params = (
_get_completion_inputs(llm_request)
)

completion_args = {
"model": self.model,
"messages": messages,
"tools": tools,
"response_format": response_format,
}
completion_args.update(self._additional_args)

# Merge additional arguments and generation parameters safely
if hasattr(self, "_additional_args") and self._additional_args:
completion_args.update(self._additional_args)

if generation_params:
completion_args.update(generation_params)

if stream:
text = ""
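For reference, the heart of this change is the mapping from `GenerateContentConfig` field names to the parameter names that `litellm.completion` accepts (per https://docs.litellm.ai/docs/completion/input). Below is a minimal standalone sketch of that mapping, using only the keys shown in the diff above; the helper name `map_generation_params` is illustrative and not part of the PR.

```python
from typing import Dict, Optional

# Illustrative helper mirroring the mapping added in _get_completion_inputs.
# The name map_generation_params is hypothetical; it is not part of the PR.
_PARAM_MAPPING = {
    "max_output_tokens": "max_completion_tokens",  # litellm expects max_completion_tokens
    "stop_sequences": "stop",                      # litellm expects stop
}

_SUPPORTED_KEYS = (
    "temperature",
    "max_output_tokens",
    "top_p",
    "top_k",
    "stop_sequences",
    "presence_penalty",
    "frequency_penalty",
)


def map_generation_params(config_dict: Dict) -> Optional[Dict]:
    """Translate GenerateContentConfig fields into litellm completion kwargs."""
    params = {
        _PARAM_MAPPING.get(key, key): config_dict[key]
        for key in _SUPPORTED_KEYS
        if key in config_dict
    }
    # Return None rather than an empty dict, matching the behavior in the diff.
    return params or None


print(map_generation_params({"temperature": 0.3, "max_output_tokens": 123}))
# -> {'temperature': 0.3, 'max_completion_tokens': 123}
```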
33 changes: 32 additions & 1 deletion tests/unittests/models/test_litellm.py
@@ -13,7 +13,6 @@
# limitations under the License.


import json
from unittest.mock import AsyncMock
from unittest.mock import Mock

@@ -1430,3 +1429,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
assert final_response.content.parts[1].function_call.name == "function_2"
assert final_response.content.parts[1].function_call.id == "1"
assert final_response.content.parts[1].function_call.args == {"arg": "value2"}


def test_get_completion_inputs_generation_params():
# Test that generation_params are extracted and mapped correctly
req = LlmRequest(
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
],
config=types.GenerateContentConfig(
temperature=0.33,
max_output_tokens=123,
top_p=0.88,
top_k=7,
stop_sequences=["foo", "bar"],
presence_penalty=0.1,
frequency_penalty=0.2,
),
)
from google.adk.models.lite_llm import _get_completion_inputs

_, _, _, generation_params = _get_completion_inputs(req)
assert generation_params["temperature"] == 0.33
assert generation_params["max_completion_tokens"] == 123
assert generation_params["top_p"] == 0.88
assert generation_params["top_k"] == 7
assert generation_params["stop"] == ["foo", "bar"]
assert generation_params["presence_penalty"] == 0.1
assert generation_params["frequency_penalty"] == 0.2
# Original config keys should not leak through; they are remapped above
assert "max_output_tokens" not in generation_params
assert "stop_sequences" not in generation_params