Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions tensorrt_llm/llmapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],

# Format the argument documentation with 12 spaces indent for args
arg_line = f"{indent} {field_name} ({type_str}): "
if status := field_info.get("status", None):
arg_line += f":tag:`{status}` "
if field_description:
arg_line += field_description.split('\n')[0] # First line with type

Expand Down Expand Up @@ -554,20 +556,21 @@ class ApiParamTagger:
'''

def __call__(self, cls: Type[BaseModel]) -> None:
self.process_pydantic_model(cls)
""" The main entry point to tag the api doc. """
self._process_pydantic_model(cls)

def process_pydantic_model(self, cls: Type[BaseModel]) -> None:
def _process_pydantic_model(self, cls: Type[BaseModel]) -> None:
"""Process the Pydantic model to add tags to the fields.
"""
for field_name, field_info in cls.model_fields.items():
if field_info.json_schema_extra and 'status' in field_info.json_schema_extra:
status = field_info.json_schema_extra['status']
self.amend_pydantic_field_description_with_tags(
self._amend_pydantic_field_description_with_tags(
cls, [field_name], status)

def amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
field_names: list[str],
tag: str) -> None:
def _amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
field_names: list[str],
tag: str) -> None:
"""Amend the description of the fields with tags.
e.g. :tag:`beta` or :tag:`prototype`
Args:
Expand Down
10 changes: 9 additions & 1 deletion tests/unittest/llmapi/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from tensorrt_llm.llmapi.utils import ApiStatusRegistry
from tensorrt_llm.llmapi import LlmArgs
from tensorrt_llm.llmapi.utils import (ApiStatusRegistry,
generate_api_docs_as_docstring)


def test_api_status_registry():
Expand All @@ -24,3 +26,9 @@ def _my_method(self, *args, **kwargs):
pass

assert ApiStatusRegistry.get_api_status(App._my_method) == "beta"


def test_generate_api_docs_as_docstring():
doc = generate_api_docs_as_docstring(LlmArgs)
assert ":tag:`beta`" in doc, "the label is not generated"
print(doc)