Skip to content

Commit 17b482a

Browse files
simonwei97, hangfei, and genquan9
authored
fix: agent generate config err (google#1305)
* fix: agent generate config err * fix: resovle comment --------- Co-authored-by: Hangfei Lin <[email protected]> Co-authored-by: genquan9 <[email protected]>
1 parent 675faef commit 17b482a

File tree

2 files changed

+86
-12
lines changed

2 files changed

+86
-12
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import Dict
2424
from typing import Generator
2525
from typing import Iterable
26+
from typing import List
2627
from typing import Literal
2728
from typing import Optional
2829
from typing import Tuple
@@ -481,16 +482,22 @@ def _message_to_generate_content_response(
481482

482483
def _get_completion_inputs(
483484
llm_request: LlmRequest,
484-
) -> tuple[Iterable[Message], Iterable[dict]]:
485-
"""Converts an LlmRequest to litellm inputs.
485+
) -> Tuple[
486+
List[Message],
487+
Optional[List[Dict]],
488+
Optional[types.SchemaUnion],
489+
Optional[Dict],
490+
]:
491+
"""Converts an LlmRequest to litellm inputs and extracts generation params.
486492
487493
Args:
488494
llm_request: The LlmRequest to convert.
489495
490496
Returns:
491-
The litellm inputs (message list, tool dictionary and response format).
497+
The litellm inputs (message list, tool dictionary, response format, and generation params).
492498
"""
493-
messages = []
499+
# 1. Construct messages
500+
messages: List[Message] = []
494501
for content in llm_request.contents or []:
495502
message_param_or_list = _content_to_message_param(content)
496503
if isinstance(message_param_or_list, list):
@@ -507,7 +514,8 @@ def _get_completion_inputs(
507514
),
508515
)
509516

510-
tools = None
517+
# 2. Convert tool declarations
518+
tools: Optional[List[Dict]] = None
511519
if (
512520
llm_request.config
513521
and llm_request.config.tools
@@ -518,12 +526,39 @@ def _get_completion_inputs(
518526
for tool in llm_request.config.tools[0].function_declarations
519527
]
520528

521-
response_format = None
529+
# 3. Handle response format
530+
response_format: Optional[types.SchemaUnion] = (
531+
llm_request.config.response_schema if llm_request.config else None
532+
)
533+
534+
# 4. Extract generation parameters
535+
generation_params: Optional[Dict] = None
536+
if llm_request.config:
537+
config_dict = llm_request.config.model_dump(exclude_none=True)
538+
# Generate LiteLlm parameters here,
539+
# Following https://docs.litellm.ai/docs/completion/input.
540+
generation_params = {}
541+
param_mapping = {
542+
"max_output_tokens": "max_completion_tokens",
543+
"stop_sequences": "stop",
544+
}
545+
for key in (
546+
"temperature",
547+
"max_output_tokens",
548+
"top_p",
549+
"top_k",
550+
"stop_sequences",
551+
"presence_penalty",
552+
"frequency_penalty",
553+
):
554+
if key in config_dict:
555+
mapped_key = param_mapping.get(key, key)
556+
generation_params[mapped_key] = config_dict[key]
522557

523-
if llm_request.config.response_schema:
524-
response_format = llm_request.config.response_schema
558+
if not generation_params:
559+
generation_params = None
525560

526-
return messages, tools, response_format
561+
return messages, tools, response_format, generation_params
527562

528563

529564
def _build_function_declaration_log(
@@ -660,15 +695,23 @@ async def generate_content_async(
660695
self._maybe_append_user_content(llm_request)
661696
logger.debug(_build_request_log(llm_request))
662697

663-
messages, tools, response_format = _get_completion_inputs(llm_request)
698+
messages, tools, response_format, generation_params = (
699+
_get_completion_inputs(llm_request)
700+
)
664701

665702
completion_args = {
666703
"model": self.model,
667704
"messages": messages,
668705
"tools": tools,
669706
"response_format": response_format,
670707
}
671-
completion_args.update(self._additional_args)
708+
709+
# Merge additional arguments and generation parameters safely
710+
if hasattr(self, "_additional_args") and self._additional_args:
711+
completion_args.update(self._additional_args)
712+
713+
if generation_params:
714+
completion_args.update(generation_params)
672715

673716
if stream:
674717
text = ""

tests/unittests/models/test_litellm.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515

16-
import json
1716
from unittest.mock import AsyncMock
1817
from unittest.mock import Mock
1918

@@ -1430,3 +1429,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
14301429
assert final_response.content.parts[1].function_call.name == "function_2"
14311430
assert final_response.content.parts[1].function_call.id == "1"
14321431
assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
1432+
1433+
1434+
@pytest.mark.asyncio
1435+
def test_get_completion_inputs_generation_params():
1436+
# Test that generation_params are extracted and mapped correctly
1437+
req = LlmRequest(
1438+
contents=[
1439+
types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
1440+
],
1441+
config=types.GenerateContentConfig(
1442+
temperature=0.33,
1443+
max_output_tokens=123,
1444+
top_p=0.88,
1445+
top_k=7,
1446+
stop_sequences=["foo", "bar"],
1447+
presence_penalty=0.1,
1448+
frequency_penalty=0.2,
1449+
),
1450+
)
1451+
from google.adk.models.lite_llm import _get_completion_inputs
1452+
1453+
_, _, _, generation_params = _get_completion_inputs(req)
1454+
assert generation_params["temperature"] == 0.33
1455+
assert generation_params["max_completion_tokens"] == 123
1456+
assert generation_params["top_p"] == 0.88
1457+
assert generation_params["top_k"] == 7
1458+
assert generation_params["stop"] == ["foo", "bar"]
1459+
assert generation_params["presence_penalty"] == 0.1
1460+
assert generation_params["frequency_penalty"] == 0.2
1461+
# Should not include max_output_tokens
1462+
assert "max_output_tokens" not in generation_params
1463+
assert "stop_sequences" not in generation_params

0 commit comments

Comments (0)