Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 13 additions & 34 deletions src/transformers/cli/serving/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,19 @@
from fastapi.responses import JSONResponse, StreamingResponse
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionMessage as OpenAIChatCompletionMessage,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
ChoiceDelta,
ChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
from openai.types.completion_usage import CompletionUsage

class ChatCompletionMessage(OpenAIChatCompletionMessage):
"""OpenAI ``ChatCompletionMessage`` extended with the ``reasoning_content`` field."""

reasoning_content: str | None = None

class ChoiceDelta(OpenAIChoiceDelta):
"""OpenAI ``ChoiceDelta`` extended with the ``reasoning_content`` field."""

reasoning_content: str | None = None


from .utils import (
BaseGenerateManager,
Expand Down Expand Up @@ -413,21 +399,17 @@ def _build_completion(
Returns:
`dict`: Serialized ``ChatCompletion`` ready for JSON response.
"""
message = ChatCompletionMessage(
content=content, role="assistant", tool_calls=tool_calls, reasoning_content=reasoning_content
# reasoning_content is added as an extra field (base types set extra="allow")
# we use model_validate rather than __init__ to avoid ty raising errors for the extra field
message = ChatCompletionMessage.model_validate(
{"content": content, "role": "assistant", "tool_calls": tool_calls, "reasoning_content": reasoning_content}
)
result = ChatCompletion(
id=request_id,
created=int(time.time()),
object="chat.completion",
model=model_id,
choices=[
Choice(
index=0,
message=message,
finish_reason=finish_reason,
)
],
choices=[Choice(index=0, message=message, finish_reason=finish_reason)],
usage=usage,
)
return result.model_dump(exclude_none=True)
Expand Down Expand Up @@ -458,19 +440,16 @@ def _build_chunk_sse(
Returns:
`str`: A formatted SSE event string.
"""
# reasoning_content is added as an extra field (base types set extra="allow")
# we use model_validate rather than __init__ to avoid ty raising errors for the extra field
delta = ChoiceDelta.model_validate(
{"content": content, "role": role, "tool_calls": tool_calls, "reasoning_content": reasoning_content}
)
chunk = ChatCompletionChunk(
id=request_id,
created=int(time.time()),
model=model,
choices=[
ChoiceChunk(
delta=ChoiceDelta(
content=content, role=role, tool_calls=tool_calls, reasoning_content=reasoning_content
),
index=0,
finish_reason=finish_reason,
)
],
choices=[ChoiceChunk(delta=delta, index=0, finish_reason=finish_reason)],
usage=usage,
system_fingerprint="",
object="chat.completion.chunk",
Expand Down
Loading