Skip to content

Commit b7419d7

Browse files
committed
Removing json_schema_type from Enums to flatten them out in the schema
1 parent cfd14a0 commit b7419d7

File tree

2 files changed

+14
-35
lines changed

2 files changed

+14
-35
lines changed

models/datatypes.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic import BaseModel, ConfigDict, Field
1212
from typing_extensions import Annotated
1313

14-
from .schema_utils import json_schema_type
14+
from .schema_utils import json_schema_type, register_schema
1515

1616

1717
@json_schema_type
@@ -32,10 +32,13 @@ class TopKSamplingStrategy(BaseModel):
3232
top_k: int = Field(..., ge=1)
3333

3434

35-
SamplingStrategy = Annotated[
36-
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
37-
Field(discriminator="type"),
38-
]
35+
SamplingStrategy = register_schema(
36+
Annotated[
37+
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
38+
Field(discriminator="type"),
39+
],
40+
name="SamplingStrategy",
41+
)
3942

4043

4144
@json_schema_type
@@ -46,15 +49,6 @@ class SamplingParams(BaseModel):
4649
repetition_penalty: Optional[float] = 1.0
4750

4851

49-
@json_schema_type(
50-
schema={
51-
"description": """
52-
The format in which weights are specified. This does not necessarily
53-
always equal what quantization is desired at runtime since there
54-
can be on-the-fly conversions done.
55-
""",
56-
}
57-
)
5852
class CheckpointQuantizationFormat(Enum):
5953
# default format
6054
bf16 = "bf16"
@@ -67,7 +61,6 @@ class CheckpointQuantizationFormat(Enum):
6761
int4 = "int4"
6862

6963

70-
@json_schema_type
7164
class ModelFamily(Enum):
7265
llama2 = "llama2"
7366
llama3 = "llama3"
@@ -77,7 +70,6 @@ class ModelFamily(Enum):
7770
safety = "safety"
7871

7972

80-
@json_schema_type
8173
class CoreModelId(Enum):
8274
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
8375

@@ -187,11 +179,6 @@ def model_family(model_id) -> ModelFamily:
187179
raise ValueError(f"Unknown model family for {model_id}")
188180

189181

190-
@json_schema_type(
191-
schema={
192-
"description": "The model family and SKU of the model along with other parameters corresponding to the model."
193-
}
194-
)
195182
class Model(BaseModel):
196183
core_model_id: CoreModelId
197184
description: str

models/llama3/api/datatypes.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
from ...schema_utils import json_schema_type
2121

2222

23-
@json_schema_type
2423
class Role(Enum):
2524
system = "system"
2625
user = "user"
2726
assistant = "assistant"
2827
tool = "tool"
2928

3029

31-
@json_schema_type
3230
class BuiltinTool(Enum):
3331
brave_search = "brave_search"
3432
wolfram_alpha = "wolfram_alpha"
@@ -82,13 +80,10 @@ def validate_field(cls, v):
8280
return v
8381

8482

85-
@json_schema_type
8683
class ToolPromptFormat(Enum):
87-
"""This Enum refers to the prompt format for calling custom / zero shot tools
84+
"""Prompt format for calling custom / zero shot tools.
8885
89-
`json` --
90-
Refers to the json format for calling tools.
91-
The json format takes the form like
86+
:cvar json: JSON format for calling tools. It takes the form:
9287
{
9388
"type": "function",
9489
"function" : {
@@ -97,22 +92,19 @@ class ToolPromptFormat(Enum):
9792
"parameters": {...}
9893
}
9994
}
100-
101-
`function_tag` --
102-
This is an example of how you could define
103-
your own user defined format for making tool calls.
104-
The function_tag format looks like this,
95+
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
10596
<function=function_name>(parameters)</function>
10697
107-
The detailed prompts for each of these formats are added to llama cli
98+
:cvar python_list: Python list. The output is a valid Python expression that can be
99+
evaluated to a list. Each element in the list is a function call. Example:
100+
["function_name(param1, param2)", "function_name(param1, param2)"]
108101
"""
109102

110103
json = "json"
111104
function_tag = "function_tag"
112105
python_list = "python_list"
113106

114107

115-
@json_schema_type
116108
class StopReason(Enum):
117109
end_of_turn = "end_of_turn"
118110
end_of_message = "end_of_message"

0 commit comments

Comments
 (0)