
Commit a9ff30f

Merge pull request #122 from fixie-ai/juberti/mistral

Add Mistral Large 2 and Nemo

2 parents: 30d8fe0 + fb93386

2 files changed: 20 additions, 3 deletions


llm_benchmark_suite.py

Lines changed: 16 additions & 1 deletion
@@ -40,7 +40,6 @@
 LLAMA_3_8B_CHAT_FP4 = "llama-3-8b-chat-fp4"
 MIXTRAL_8X7B_INSTRUCT = "mixtral-8x7b-instruct"
 MIXTRAL_8X7B_INSTRUCT_FP8 = "mixtral-8x7b-instruct-fp8"
-PHI_2 = "phi-2"
 
 
 parser = argparse.ArgumentParser()
@@ -183,6 +182,18 @@ def __init__(self, model: str, display_model: Optional[str] = None):
         )
 
 
+class _MistralLlm(_Llm):
+    """See https://docs.mistral.ai/getting-started/models"""
+
+    def __init__(self, model: str, display_model: Optional[str] = None):
+        super().__init__(
+            model,
+            "mistral.ai/" + (display_model or model),
+            api_key=os.getenv("MISTRAL_API_KEY"),
+            base_url="https://api.mistral.ai/v1",
+        )
+
+
 class _NvidiaLlm(_Llm):
     """See https://build.nvidia.com/explore/discover"""
 
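The new wrapper simply points the suite's generic OpenAI-style client at Mistral's OpenAI-compatible endpoint, authenticating via MISTRAL_API_KEY. A minimal standalone sketch of the request it configures (the model name, prompt, and response handling below are illustrative and assume the standard OpenAI chat schema):

    import asyncio
    import os

    import httpx

    async def mistral_chat(prompt: str) -> str:
        # Same endpoint and auth that _MistralLlm wires into the benchmark.
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                "https://api.mistral.ai/v1/chat/completions",
                headers={"Authorization": f"Bearer {os.environ['MISTRAL_API_KEY']}"},
                json={
                    "model": "mistral-large-latest",
                    "messages": [{"role": "user", "content": prompt}],
                },
                timeout=30,
            )
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]

    if __name__ == "__main__":
        print(asyncio.run(mistral_chat("Say hello.")))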
@@ -330,6 +341,9 @@ def _text_models():
         _Llm("gemini-pro"),
         _Llm(GEMINI_1_5_PRO),
         _Llm(GEMINI_1_5_FLASH),
+        # Mistral
+        _MistralLlm("mistral-large-latest", "mistral-large"),
+        _MistralLlm("open-mistral-nemo", "mistral-nemo"),
         # Mistral 8x7b
         _DatabricksLlm("databricks-mixtral-8x7b-instruct", MIXTRAL_8X7B_INSTRUCT),
         _DeepInfraLlm("mistralai/Mixtral-8x7B-Instruct-v0.1", MIXTRAL_8X7B_INSTRUCT),
@@ -484,6 +498,7 @@ def _image_models():
         _FireworksLlm(
             "accounts/fireworks/models/phi-3-vision-128k-instruct", "phi-3-vision"
         ),
+        _MistralLlm("pixtral-latest", "pixtral"),
     ]
 
 

llm_request.py

Lines changed: 4 additions & 2 deletions
@@ -146,7 +146,9 @@ async def run(self, on_token: Optional[Callable[["ApiContext", str], None]] = None):
         if not self.metrics.error:
             token_time = end_time - first_token_time
             self.metrics.total_time = end_time - start_time
-            self.metrics.tps = min((self.metrics.output_tokens - 1) / token_time, MAX_TPS)
+            self.metrics.tps = min(
+                (self.metrics.output_tokens - 1) / token_time, MAX_TPS
+            )
             if self.metrics.tps == MAX_TPS:
                 self.metrics.tps = 0.0
         else:
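This reflow leaves the metric unchanged: throughput is tokens per second measured over the window after the first token arrives (hence output_tokens - 1), clamped to MAX_TPS, with a value at the clamp discarded as a measurement artifact. A self-contained sketch of the same calculation (the MAX_TPS value here is an assumed placeholder):

    # Sketch of the throughput metric above; the MAX_TPS value is assumed.
    MAX_TPS = 9999.0

    def tokens_per_second(output_tokens: int, token_time: float) -> float:
        # The first token starts the clock, so it is excluded from the count.
        tps = min((output_tokens - 1) / token_time, MAX_TPS)
        # Hitting the cap means token_time was too small to be meaningful.
        return 0.0 if tps == MAX_TPS else tps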
@@ -293,7 +295,7 @@ async def openai_chat(ctx: ApiContext, path: str = "/chat/completions") -> ApiResult:
     # Some providers require opt-in for stream stats, but some providers don't like this opt-in.
     # Regardless of opt-in, Azure and ovh.net don't return stream stats at the moment.
     # See https://github.com/Azure/azure-rest-api-specs/issues/25062
-    if not any(p in ctx.name for p in ["azure", "databricks", "fireworks"]):
+    if not any(p in ctx.name for p in ["azure", "databricks", "fireworks", "mistral"]):
         kwargs["stream_options"] = {"include_usage": True}
     data = make_openai_chat_body(ctx, **kwargs)
     return await post(ctx, url, headers, data, openai_chunk_gen)
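For context, stream_options is the OpenAI chat API's opt-in for a final stream chunk carrying usage (token-count) stats; this diff adds "mistral" to the providers that reject the field. A hedged sketch of the two request-body shapes involved (field values are illustrative):

    # Opt-in form: asks the server to append a final chunk with usage stats.
    body_with_usage = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
        "stream_options": {"include_usage": True},
    }

    # Providers listed above (azure, databricks, fireworks, mistral) get the
    # same body without the opt-in field.
    body_plain = {k: v for k, v in body_with_usage.items() if k != "stream_options"}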
