1111from pydantic import BaseModel , ConfigDict , Field
1212from typing_extensions import Annotated
1313
14- from .schema_utils import json_schema_type
14+ from .schema_utils import json_schema_type , register_schema
1515
1616
1717@json_schema_type
@@ -32,10 +32,13 @@ class TopKSamplingStrategy(BaseModel):
3232 top_k : int = Field (..., ge = 1 )
3333
3434
35- SamplingStrategy = Annotated [
36- Union [GreedySamplingStrategy , TopPSamplingStrategy , TopKSamplingStrategy ],
37- Field (discriminator = "type" ),
38- ]
35+ SamplingStrategy = register_schema (
36+ Annotated [
37+ Union [GreedySamplingStrategy , TopPSamplingStrategy , TopKSamplingStrategy ],
38+ Field (discriminator = "type" ),
39+ ],
40+ name = "SamplingStrategy" ,
41+ )
3942
4043
4144@json_schema_type
@@ -46,15 +49,6 @@ class SamplingParams(BaseModel):
4649 repetition_penalty : Optional [float ] = 1.0
4750
4851
49- @json_schema_type (
50- schema = {
51- "description" : """
52- The format in which weights are specified. This does not necessarily
53- always equal what quantization is desired at runtime since there
54- can be on-the-fly conversions done.
55- """ ,
56- }
57- )
5852class CheckpointQuantizationFormat (Enum ):
5953 # default format
6054 bf16 = "bf16"
@@ -67,7 +61,6 @@ class CheckpointQuantizationFormat(Enum):
6761 int4 = "int4"
6862
6963
70- @json_schema_type
7164class ModelFamily (Enum ):
7265 llama2 = "llama2"
7366 llama3 = "llama3"
@@ -77,7 +70,6 @@ class ModelFamily(Enum):
7770 safety = "safety"
7871
7972
80- @json_schema_type
8173class CoreModelId (Enum ):
8274 """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
8375
@@ -187,11 +179,6 @@ def model_family(model_id) -> ModelFamily:
187179 raise ValueError (f"Unknown model family for { model_id } " )
188180
189181
190- @json_schema_type (
191- schema = {
192- "description" : "The model family and SKU of the model along with other parameters corresponding to the model."
193- }
194- )
195182class Model (BaseModel ):
196183 core_model_id : CoreModelId
197184 description : str
0 commit comments