20 changes: 20 additions & 0 deletions src/flexible_inference_benchmark/engine/backend_functions.py
@@ -43,6 +43,7 @@ class RequestFuncInput(BaseModel):
top_p: Optional[float] = None
top_k: Optional[int] = None
run_id: Optional[str] = None
json_response: bool = False


class RequestFuncOutput(BaseModel):
@@ -448,6 +449,20 @@ async def async_request_openai_chat_completions(
with otel_span as span:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search

# Apply JSON response formatting if flag is enabled
if request_func_input.json_response:
append_msg = (
"\nPlease send your response as a JSON object. "
"Follow this schema: {'assistant_response': 'your full, detailed response here'}. "
"Do not include any other text or formatting. "
"Only return the JSON object without any additional text or explanation."
)
if isinstance(content_body, str):
content_body += append_msg
else:
content_body[-1]["text"] += append_msg

payload = {
"model": request_func_input.model,
"messages": [{"role": "user", "content": content_body}],
@@ -456,6 +471,11 @@
"ignore_eos": request_func_input.ignore_eos,
"stream_options": {"include_usage": True},
}

# Add JSON response format if flag is enabled
if request_func_input.json_response:
payload["response_format"] = {"type": "json_object"}
payload["chat_template_kwargs"] = {"enable_thinking": False}
apply_sampling_params(payload, request_func_input, always_top_p=False)
if request_func_input.logprobs is not None:
payload["logprobs"] = True
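For reference, when `json_response` is enabled the two hunks above combine into a request body along these lines. This is a minimal sketch, not code from the PR: the model name and prompt are placeholders, `ignore_eos` simply stands in for whatever the request input carries, and the sampling/logprobs fields added further down in the function are omitted.

    # Sketch of the outgoing chat-completions payload with json_response=True.
    # "example-model" and the prompt text are placeholders.
    content_body = "Describe the water cycle."
    content_body += (
        "\nPlease send your response as a JSON object. "
        "Follow this schema: {'assistant_response': 'your full, detailed response here'}. "
        "Do not include any other text or formatting. "
        "Only return the JSON object without any additional text or explanation."
    )
    payload = {
        "model": "example-model",
        "messages": [{"role": "user", "content": content_body}],
        "ignore_eos": True,
        "stream_options": {"include_usage": True},
        "response_format": {"type": "json_object"},          # added by this PR
        "chat_template_kwargs": {"enable_thinking": False},  # added by this PR
    }
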
3 changes: 3 additions & 0 deletions src/flexible_inference_benchmark/engine/client.py
@@ -44,6 +44,7 @@ def __init__(
top_p: Optional[float] = None,
top_k: Optional[int] = None,
run_id: Optional[str] = None,
json_response: bool = False,
):
self.backend = backend
self.api_url = api_url
@@ -66,6 +67,7 @@ def __init__(
self.top_p = top_p
self.top_k = top_k
self.run_id = run_id or str(uuid.uuid4())
self.json_response = json_response

@property
def request_func(
@@ -178,6 +180,7 @@ async def benchmark(
top_p=self.top_p,
top_k=self.top_k,
run_id=self.run_id,
json_response=self.json_response,
)
for (data_sample, media_sample) in zip(data, requests_media)
]
5 changes: 5 additions & 0 deletions src/flexible_inference_benchmark/main.py
@@ -470,6 +470,10 @@ def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> Any: # t

benchmark_parser.add_argument("--use-beam-search", action="store_true", help="Use beam search for completions.")

benchmark_parser.add_argument(
"--json-response", action="store_true", help="Request responses in JSON format from the API."
)

benchmark_parser.add_argument(
"--output-file",
type=str,
@@ -736,6 +740,7 @@ def run_main(args: argparse.Namespace) -> None:
args.top_p,
args.top_k,
run_id=run_id,
json_response=args.json_response,
)
# disable verbose output for validation of the endpoint. This is done to avoid confusion on terminal output.
client_verbose_value = client.verbose
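With the flag wired through argparse and `Client`, enabling the behavior from the command line only requires the new switch. A hypothetical invocation, assuming the subcommand registered by `add_benchmark_subparser` is named `benchmark`; the entry-point name and the remaining arguments are placeholders, not taken from this PR:

    <benchmark-entrypoint> benchmark --json-response [other benchmark options]
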