diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2a085bb2..9743354a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,10 +21,11 @@ repos: ftfy, loguru, numpy, - openai, + pillow, pydantic, pydantic_settings, pyyaml, + respx, requests, rich, transformers, diff --git a/pyproject.toml b/pyproject.toml index 6ab2c6e9..e0b47007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,10 @@ dependencies = [ "click", "datasets", "ftfy>=6.0.0", + "httpx[http2]<1.0.0", "loguru", "numpy", - "openai", + "pillow", "pydantic>=2.0.0", "pydantic-settings>=2.0.0", "pyyaml>=6.0.0", @@ -48,12 +49,14 @@ dev = [ "tox~=4.16.0", # testing + "lorem~=0.1.1", "pytest~=8.2.2", "pytest-asyncio~=0.23.8", "pytest-cov~=5.0.0", "pytest-mock~=3.14.0", "pytest-rerunfailures~=14.0", "requests-mock~=1.12.1", + "respx~=0.22.0", # code quality "mypy~=1.10.1", @@ -82,10 +85,6 @@ guidellm-config = "guidellm.config:print_config" # ********** Code Quality Tools ********** # ************************************************ -[tool.black] -line-length = 88 -target-version = ['py38'] - [tool.isort] profile = "black" @@ -127,8 +126,8 @@ ignore = [ "TCH002", "PLW1514", # allow Path.open without encoding "RET505", # allow `else` blocks - "RET506" # allow `else` blocks - + "RET506", # allow `else` blocks + "PD011", # ignore .values usage since ruff assumes it's a Pandas DataFrame ] select = [ # Rules reference: https://docs.astral.sh/ruff/rules/ diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py index 875e319e..a45a66a7 100644 --- a/src/guidellm/backend/__init__.py +++ b/src/guidellm/backend/__init__.py @@ -1,10 +1,21 @@ -from .base import Backend, BackendEngine, BackendEnginePublic, GenerativeResponse -from .openai import OpenAIBackend +from .backend import ( + Backend, + BackendType, +) +from .openai import OpenAIHTTPBackend +from .response import ( + RequestArgs, + ResponseSummary, + StreamingResponseType, + StreamingTextResponse, +) __all__ = [ + "StreamingResponseType", + "StreamingTextResponse", + "RequestArgs", + "ResponseSummary", "Backend", - "BackendEngine", - "BackendEnginePublic", - "GenerativeResponse", - "OpenAIBackend", + "BackendType", + "OpenAIHTTPBackend", ] diff --git a/src/guidellm/backend/backend.py b/src/guidellm/backend/backend.py new file mode 100644 index 00000000..e2b89f1e --- /dev/null +++ b/src/guidellm/backend/backend.py @@ -0,0 +1,223 @@ +import asyncio +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Type, Union + +from loguru import logger +from PIL import Image + +from guidellm.backend.response import ResponseSummary, StreamingTextResponse + +__all__ = [ + "Backend", + "BackendType", +] + + +BackendType = Literal["openai_http"] + + +class Backend(ABC): + """ + Abstract base class for generative AI backends. + + This class provides a common interface for creating and interacting with different + generative AI backends. Subclasses should implement the abstract methods to + define specific backend behavior. + + :cvar _registry: A registration dictionary that maps BackendType to backend classes. + :param type_: The type of the backend. + """ + + _registry: Dict[BackendType, "Type[Backend]"] = {} + + @classmethod + def register(cls, backend_type: BackendType): + """ + A decorator to register a backend class in the backend registry. + + :param backend_type: The type of backend to register. 
+ :type backend_type: BackendType + :return: The decorated backend class. + :rtype: Type[Backend] + """ + if backend_type in cls._registry: + raise ValueError(f"Backend type already registered: {backend_type}") + + if not issubclass(cls, Backend): + raise TypeError("Only subclasses of Backend can be registered") + + def inner_wrapper(wrapped_class: Type["Backend"]): + cls._registry[backend_type] = wrapped_class + logger.info("Registered backend type: {}", backend_type) + return wrapped_class + + return inner_wrapper + + @classmethod + def create(cls, type_: BackendType, **kwargs) -> "Backend": + """ + Factory method to create a backend instance based on the backend type. + + :param type_: The type of backend to create. + :type type_: BackendType + :param kwargs: Additional arguments for backend initialization. + :return: An instance of a subclass of Backend. + :rtype: Backend + :raises ValueError: If the backend type is not registered. + """ + + logger.info("Creating backend of type {}", type_) + + if type_ not in cls._registry: + err = ValueError(f"Unsupported backend type: {type_}") + logger.error("{}", err) + raise err + + return Backend._registry[type_](**kwargs) + + def __init__(self, type_: BackendType): + self._type = type_ + + @property + def type_(self) -> BackendType: + """ + :return: The type of the backend. + """ + return self._type + + @property + @abstractmethod + def target(self) -> str: + """ + :return: The target location for the backend. + """ + ... + + @property + @abstractmethod + def model(self) -> Optional[str]: + """ + :return: The model used for the backend requests. + """ + ... + + def validate(self): + """ + Handle final setup and validate the backend is ready for use. + If not successful, raises the appropriate exception. + """ + logger.info("{} validating backend {}", self.__class__.__name__, self.type_) + self.check_setup() + models = self.available_models() + if not models: + raise ValueError("No models available for the backend") + + async def _test_request(): + async for _ in self.text_completions( + prompt="Test connection", output_token_count=1 + ): # type: ignore[attr-defined] + pass + + asyncio.run(_test_request()) + + @abstractmethod + def check_setup(self): + """ + Check the setup for the backend. + If unsuccessful, raises the appropriate exception. + + :raises ValueError: If the setup check fails. + """ + ... + + @abstractmethod + def available_models(self) -> List[str]: + """ + Get the list of available models for the backend. + + :return: The list of available models. + :rtype: List[str] + """ + ... + + @abstractmethod + async def text_completions( + self, + prompt: Union[str, List[str]], + request_id: Optional[str] = None, + prompt_token_count: Optional[int] = None, + output_token_count: Optional[int] = None, + **kwargs, + ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: + """ + Generate text only completions for the given prompt. + Does not support multiple modalities, complicated chat interfaces, + or chat templates. Specifically, it requests with only the prompt. + + :param prompt: The prompt (or list of prompts) to generate a completion for. + If a list is supplied, these are concatenated and run through the model + for a single prompt. + :param request_id: The unique identifier for the request, if any. + Added to logging statements and the response for tracking purposes. + :param prompt_token_count: The number of tokens measured in the prompt, if any. + Returned in the response stats for later analysis, if applicable. 
+        :param output_token_count: If supplied, the number of tokens to enforce
+            generation of for the output for this request.
+        :param kwargs: Additional keyword arguments to pass with the request.
+        :return: An async generator that yields a StreamingTextResponse for start,
+            a StreamingTextResponse for each received iteration,
+            and a ResponseSummary for the final response.
+        """
+        ...
+
+    @abstractmethod
+    async def chat_completions(
+        self,
+        content: Union[
+            str,
+            List[Union[str, Dict[str, Union[str, Dict[str, str]]], Path, Image.Image]],
+            Any,
+        ],
+        request_id: Optional[str] = None,
+        prompt_token_count: Optional[int] = None,
+        output_token_count: Optional[int] = None,
+        raw_content: bool = False,
+        **kwargs,
+    ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
+        """
+        Generate chat completions for the given content.
+        Supports multiple modalities, complicated chat interfaces, and chat templates.
+        Specifically, it requests with the content, which can be any combination of
+        text, images, and audio provided the target model supports it,
+        and returns the output text. Additionally, any chat templates
+        for the model are applied within the backend.
+
+        :param content: The content (or list of content) to generate a completion for.
+            This supports any combination of text, images, and audio (model dependent).
+            Supported text only request examples:
+            content="Sample prompt", content=["Sample prompt", "Second prompt"],
+            content=[{"type": "text", "value": "Sample prompt"}].
+            Supported text and image request examples:
+            content=["Describe the image", PIL.Image.open("image.jpg")],
+            content=["Describe the image", Path("image.jpg")],
+            content=["Describe the image", {"type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}].
+            Supported text and audio request examples:
+            content=["Transcribe the audio", Path("audio.wav")],
+            content=["Transcribe the audio", {"type": "input_audio",
+            "input_audio": {"data": f"{base64_bytes}", "format": "wav"}}].
+            Additionally, if raw_content=True then the content is passed directly to the
+            backend without any processing.
+        :param request_id: The unique identifier for the request, if any.
+            Added to logging statements and the response for tracking purposes.
+        :param prompt_token_count: The number of tokens measured in the prompt, if any.
+            Returned in the response stats for later analysis, if applicable.
+        :param output_token_count: If supplied, the number of tokens to enforce
+            generation of for the output for this request.
+        :param kwargs: Additional keyword arguments to pass with the request.
+        :return: An async generator that yields a StreamingTextResponse for start,
+            a StreamingTextResponse for each received iteration,
+            and a ResponseSummary for the final response.
+        """
+        ...
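Note: the following is an illustrative usage sketch of the streaming backend API added above, not part of the diff. The target URL and prompt are placeholder assumptions; Backend.create, validate, text_completions, StreamingTextResponse, and ResponseSummary are the names introduced by the new guidellm.backend modules in this change.

import asyncio

from guidellm.backend import Backend, ResponseSummary, StreamingTextResponse

# Assumes an OpenAI-compatible server (e.g. a vLLM instance) at this placeholder URL.
backend = Backend.create("openai_http", target="http://localhost:8000")
backend.validate()  # resolves the default model and sends a one-token probe request


async def run_once() -> None:
    # text_completions yields a "start" StreamingTextResponse, one "iter" response
    # per received delta, and finally a ResponseSummary with token accounting.
    async for item in backend.text_completions(
        prompt="Write a haiku about latency.", output_token_count=32
    ):
        if isinstance(item, StreamingTextResponse) and item.type_ == "iter":
            print(item.delta, end="", flush=True)
        elif isinstance(item, ResponseSummary):
            print(f"\noutput tokens reported: {item.output_tokens}")


asyncio.run(run_once())

chat_completions follows the same pattern, accepting mixed text/image/audio content instead of a plain prompt.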
diff --git a/src/guidellm/backend/base.py b/src/guidellm/backend/base.py deleted file mode 100644 index d71c5f66..00000000 --- a/src/guidellm/backend/base.py +++ /dev/null @@ -1,320 +0,0 @@ -import asyncio -import functools -from abc import ABC, abstractmethod -from typing import AsyncGenerator, Dict, List, Literal, Optional, Type, Union - -from loguru import logger -from pydantic import BaseModel -from transformers import ( # type: ignore # noqa: PGH003 - AutoTokenizer, - PreTrainedTokenizer, -) - -from guidellm.core import TextGenerationRequest, TextGenerationResult - -__all__ = ["Backend", "BackendEngine", "BackendEnginePublic", "GenerativeResponse"] - - -BackendEnginePublic = Literal["openai_server"] -BackendEngine = Union[BackendEnginePublic, Literal["test"]] - - -class GenerativeResponse(BaseModel): - """ - A model representing a response from a generative AI backend. - - :param type_: The type of response, either 'token_iter' for intermediate - token output or 'final' for the final result. - :type type_: Literal["token_iter", "final"] - :param add_token: The token to add to the output - (only applicable if type_ is 'token_iter'). - :type add_token: Optional[str] - :param prompt: The original prompt sent to the backend. - :type prompt: Optional[str] - :param output: The final generated output (only applicable if type_ is 'final'). - :type output: Optional[str] - :param prompt_token_count: The number of tokens in the prompt. - :type prompt_token_count: Optional[int] - :param output_token_count: The number of tokens in the output. - :type output_token_count: Optional[int] - """ - - type_: Literal["token_iter", "final"] - add_token: Optional[str] = None - prompt: Optional[str] = None - output: Optional[str] = None - prompt_token_count: Optional[int] = None - output_token_count: Optional[int] = None - - -class Backend(ABC): - """ - Abstract base class for generative AI backends. - - This class provides a common interface for creating and interacting with different - generative AI backends. Subclasses should implement the abstract methods to - define specific backend behavior. - - :cvar _registry: A dictionary that maps BackendEngine types to backend classes. - :type _registry: Dict[BackendEngine, Type[Backend]] - :param type_: The type of the backend. - :type type_: BackendEngine - :param target: The target URL for the backend. - :type target: str - :param model: The model used by the backend. - :type model: str - """ - - _registry: Dict[BackendEngine, "Type[Backend]"] = {} - - @classmethod - def register(cls, backend_type: BackendEngine): - """ - A decorator to register a backend class in the backend registry. - - :param backend_type: The type of backend to register. - :type backend_type: BackendEngine - :return: The decorated backend class. - :rtype: Type[Backend] - """ - - def inner_wrapper(wrapped_class: Type["Backend"]): - cls._registry[backend_type] = wrapped_class - logger.info("Registered backend type: {}", backend_type) - return wrapped_class - - return inner_wrapper - - @classmethod - def create(cls, backend_type: BackendEngine, **kwargs) -> "Backend": - """ - Factory method to create a backend instance based on the backend type. - - :param backend_type: The type of backend to create. - :type backend_type: BackendEngine - :param kwargs: Additional arguments for backend initialization. - :return: An instance of a subclass of Backend. - :rtype: Backend - :raises ValueError: If the backend type is not registered. 
- """ - - logger.info("Creating backend of type {}", backend_type) - - if backend_type not in cls._registry: - err = ValueError(f"Unsupported backend type: {backend_type}") - logger.error("{}", err) - raise err - - return Backend._registry[backend_type](**kwargs) - - def __init__(self, type_: BackendEngine, target: str, model: str): - """ - Base constructor for the Backend class. - Calls into test_connection to ensure the backend is reachable. - Ensure all setup is done in the subclass constructor before calling super. - - :param type_: The type of the backend. - :param target: The target URL for the backend. - :param model: The model used by the backend. - """ - self._type = type_ - self._target = target - self._model = model - - self.test_connection() - - @property - def default_model(self) -> str: - """ - Get the default model for the backend. - - :return: The default model. - :rtype: str - :raises ValueError: If no models are available. - """ - return _cachable_default_model(self) - - @property - def type_(self) -> BackendEngine: - """ - Get the type of the backend. - - :return: The type of the backend. - :rtype: BackendEngine - """ - return self._type - - @property - def target(self) -> str: - """ - Get the target URL for the backend. - - :return: The target URL. - :rtype: str - """ - return self._target - - @property - def model(self) -> str: - """ - Get the model used by the backend. - - :return: The model name. - :rtype: str - """ - return self._model - - def model_tokenizer(self) -> PreTrainedTokenizer: - """ - Get the tokenizer for the backend model. - - :return: The tokenizer instance. - """ - return AutoTokenizer.from_pretrained(self.model) - - def test_connection(self) -> bool: - """ - Test the connection to the backend by running a short text generation request. - If successful, returns True, otherwise raises an exception. - - :return: True if the connection is successful. - :rtype: bool - :raises ValueError: If the connection test fails. - """ - try: - asyncio.get_running_loop() - is_async = True - except RuntimeError: - is_async = False - - if is_async: - logger.warning("Running in async mode, cannot test connection") - return True - - try: - request = TextGenerationRequest( - prompt="Test connection", output_token_count=5 - ) - - asyncio.run(self.submit(request)) - return True - except Exception as err: - raise_err = RuntimeError( - f"Backend connection test failed for backend type={self.type_} " - f"with target={self.target} and model={self.model} with error: {err}" - ) - logger.error(raise_err) - raise raise_err from err - - async def submit(self, request: TextGenerationRequest) -> TextGenerationResult: - """ - Submit a text generation request and return the result. - - This method handles the request submission to the backend and processes - the response in a streaming fashion if applicable. - - :param request: The request object containing the prompt - and other configurations. - :type request: TextGenerationRequest - :return: The result of the text generation request. - :rtype: TextGenerationResult - :raises ValueError: If no response is received from the backend. 
- """ - - logger.debug("Submitting request with prompt: {}", request.prompt) - - result = TextGenerationResult(request=request) - result.start(request.prompt) - received_final = False - - async for response in self.make_request(request): - logger.debug("Received response: {}", response) - if response.type_ == "token_iter": - result.output_token(response.add_token if response.add_token else "") - elif response.type_ == "final": - if received_final: - err = ValueError( - "Received multiple final responses from the backend." - ) - logger.error(err) - raise err - - result.end( - output=response.output, - prompt_token_count=response.prompt_token_count, - output_token_count=response.output_token_count, - ) - received_final = True - else: - err = ValueError( - f"Invalid response received from the backend of type: " - f"{response.type_} for {response}" - ) - logger.error(err) - raise err - - if not received_final: - err = ValueError("No final response received from the backend.") - logger.error(err) - raise err - - logger.info("Request completed with output: {}", result.output) - - return result - - @abstractmethod - async def make_request( - self, - request: TextGenerationRequest, - ) -> AsyncGenerator[GenerativeResponse, None]: - """ - Abstract method to make a request to the backend. - - Subclasses must implement this method to define how requests are handled - by the backend. - - :param request: The request object containing the prompt and - other configurations. - :type request: TextGenerationRequest - :yield: A generator yielding responses from the backend. - :rtype: AsyncGenerator[GenerativeResponse, None] - """ - yield None # type: ignore # noqa: PGH003 - - @abstractmethod - def available_models(self) -> List[str]: - """ - Abstract method to get the available models for the backend. - - Subclasses must implement this method to provide the list of models - supported by the backend. - - :return: A list of available models. - :rtype: List[str] - :raises NotImplementedError: If the method is not implemented by a subclass. - """ - raise NotImplementedError - - -@functools.lru_cache(maxsize=1) -def _cachable_default_model(backend: Backend) -> str: - """ - Get the default model for a backend using LRU caching. - - This function caches the default model to optimize repeated lookups. - - :param backend: The backend instance for which to get the default model. - :type backend: Backend - :return: The default model. - :rtype: str - :raises ValueError: If no models are available. 
- """ - logger.debug("Getting default model for backend: {}", backend) - models = backend.available_models() - if models: - logger.debug("Default model: {}", models[0]) - return models[0] - - err = ValueError("No models available.") - logger.error(err) - raise err diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index 8c83f914..7870a949 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -1,169 +1,509 @@ -from typing import AsyncGenerator, Dict, List, Optional +import base64 +import json +import time +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union +import httpx from loguru import logger -from openai import AsyncOpenAI, OpenAI - -from guidellm.backend.base import Backend, GenerativeResponse +from PIL import Image + +from guidellm.backend.backend import Backend +from guidellm.backend.response import ( + RequestArgs, + ResponseSummary, + StreamingTextResponse, +) from guidellm.config import settings -from guidellm.core import TextGenerationRequest -__all__ = ["OpenAIBackend"] +__all__ = ["OpenAIHTTPBackend"] -@Backend.register("openai_server") -class OpenAIBackend(Backend): +@Backend.register("openai_http") +class OpenAIHTTPBackend(Backend): """ - An OpenAI backend implementation for generative AI results. - - This class provides an interface to communicate with the - OpenAI server for generating responses based on given prompts. - - :param openai_api_key: The API key for OpenAI. - If not provided, it will default to the key from settings. - :type openai_api_key: Optional[str] - :param target: The target URL string for the OpenAI server. - :type target: Optional[str] - :param model: The OpenAI model to use, defaults to the first available model. - :type model: Optional[str] - :param request_args: Additional arguments for the OpenAI request. - :type request_args: Dict[str, Any] + A HTTP-based backend implementation for requests to an OpenAI compatible server. + For example, a vLLM server instance or requests to OpenAI's API. + + :param target: The target URL string for the OpenAI server. ex: http://0.0.0.0:8000 + :param model: The model to use for all requests on the target server. + If none is provided, the first available model will be used. + :param api_key: The API key to use for requests to the OpenAI server. + If provided, adds an Authorization header with the value + "Authorization: Bearer {api_key}". + If not provided, no Authorization header is added. + :param organization: The organization to use for requests to the OpenAI server. + For example, if set to "org_123", adds an OpenAI-Organization header with the + value "OpenAI-Organization: org_123". + If not provided, no OpenAI-Organization header is added. + :param project: The project to use for requests to the OpenAI server. + For example, if set to "project_123", adds an OpenAI-Project header with the + value "OpenAI-Project: project_123". + If not provided, no OpenAI-Project header is added. + :param timeout: The timeout to use for requests to the OpenAI server. + If not provided, the default timeout provided from settings is used. + :param http2: If True, uses HTTP/2 for requests to the OpenAI server. + Defaults to True. + :param max_output_tokens: The maximum number of tokens to request for completions. + If not provided, the default maximum tokens provided from settings is used. 
""" def __init__( self, - openai_api_key: Optional[str] = None, target: Optional[str] = None, model: Optional[str] = None, - **request_args, + api_key: Optional[str] = None, + organization: Optional[str] = None, + project: Optional[str] = None, + timeout: Optional[float] = None, + http2: Optional[bool] = True, + max_output_tokens: Optional[int] = None, ): - self._request_args: Dict = request_args - api_key: str = openai_api_key or settings.openai.api_key - - if not api_key: - err = ValueError( - "`GUIDELLM__OPENAI__API_KEY` environment variable or " - "--openai-api-key CLI parameter must be specified for the " - "OpenAI backend." - ) - logger.error("{}", err) - raise err + super().__init__(type_="openai_http") + self._target = target or settings.openai.base_url + self._model = model + + api_key = api_key or settings.openai.api_key + self.authorization = ( + f"Bearer {api_key}" if api_key else settings.openai.bearer_token + ) - base_url = target or settings.openai.base_url + self.organization = organization or settings.openai.organization + self.project = project or settings.openai.project + self.timeout = timeout if timeout is not None else settings.request_timeout + self.http2 = http2 if http2 is not None else settings.request_http2 + self.max_output_tokens = ( + max_output_tokens + if max_output_tokens is not None + else settings.openai.max_output_tokens + ) + + @property + def target(self) -> str: + """ + :return: The target URL string for the OpenAI server. + """ + return self._target + + @property + def model(self) -> Optional[str]: + """ + :return: The model to use for all requests on the target server. + If validate hasn't been called yet and no model was passed in, + this will be None until validate is called to set the default. + """ + return self._model + + def check_setup(self): + """ + Check if the backend is setup correctly and can be used for requests. + Specifically, if a model is not provided, it grabs the first available model. + If no models are available, raises a ValueError. + If a model is provided and not available, raises a ValueError. - if not base_url: - err = ValueError( - "`GUIDELLM__OPENAI__BASE_URL` environment variable or " - "target parameter must be specified for the OpenAI backend." + :raises ValueError: If no models or the provided model is not available. 
+        """
+        models = self.available_models()
+        if not models:
+            raise ValueError(f"No models available for target: {self.target}")
+
+        if not self.model:
+            self._model = models[0]
+        elif self.model not in models:
+            raise ValueError(
+                f"Model {self.model} not found in available models: "
+                f"{models} for target: {self.target}"
             )
-            logger.error("{}", err)
-            raise err

-        self._async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
-        self._client = OpenAI(api_key=api_key, base_url=base_url)
-        self._model = model or self.default_model
+    def available_models(self) -> List[str]:
+        """
+        Get the available models for the target server using the OpenAI models endpoint:
+        /v1/models
+        """
+        target = f"{self.target}/v1/models"
+        headers = self._headers()
+
+        with httpx.Client(http2=self.http2, timeout=self.timeout) as client:
+            response = client.get(target, headers=headers)
+            response.raise_for_status()
+
+            models = []
+
+            for item in response.json()["data"]:
+                models.append(item["id"])

-        super().__init__(type_="openai_server", target=base_url, model=self._model)
-        logger.info("OpenAI {} Backend listening on {}", self._model, base_url)
+            return models

-    async def make_request(
+    async def text_completions(  # type: ignore[override]
         self,
-        request: TextGenerationRequest,
-    ) -> AsyncGenerator[GenerativeResponse, None]:
+        prompt: Union[str, List[str]],
+        request_id: Optional[str] = None,
+        prompt_token_count: Optional[int] = None,
+        output_token_count: Optional[int] = None,
+        **kwargs,
+    ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
         """
-        Make a request to the OpenAI backend.
+        Generate text completions for the given prompt using the OpenAI
+        completions endpoint: /v1/completions.
+
+        :param prompt: The prompt (or list of prompts) to generate a completion for.
+            If a list is supplied, these are concatenated and run through the model
+            for a single prompt.
+        :param request_id: The unique identifier for the request, if any.
+            Added to logging statements and the response for tracking purposes.
+        :param prompt_token_count: The number of tokens measured in the prompt, if any.
+            Returned in the response stats for later analysis, if applicable.
+        :param output_token_count: If supplied, the number of tokens to enforce
+            generation of for the output for this request.
+        :param kwargs: Additional keyword arguments to pass with the request.
+        :return: An async generator that yields a StreamingTextResponse for start,
+            a StreamingTextResponse for each received iteration,
+            and a ResponseSummary for the final response.
+        """
+
+        logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
+        headers = self._headers()
+        payload = self._completions_payload(
+            orig_kwargs=kwargs,
+            max_output_tokens=output_token_count,
+            prompt=prompt,
+        )

-        This method sends a prompt to the OpenAI backend and streams
-        the response tokens back.
+        try:
+            async for resp in self._iterative_completions_request(
+                type_="text",
+                request_id=request_id,
+                request_prompt_tokens=prompt_token_count,
+                request_output_tokens=output_token_count,
+                headers=headers,
+                payload=payload,
+            ):
+                yield resp
+        except Exception as ex:
+            logger.error(
+                "{} request with headers: {} and payload: {} failed: {}",
+                self.__class__.__name__,
+                headers,
+                payload,
+                ex,
+            )
+            raise ex

-        :param request: The text generation request to submit.
-        :type request: TextGenerationRequest
-        :yield: A stream of GenerativeResponse objects.
-        :rtype: AsyncGenerator[GenerativeResponse, None]
+    async def chat_completions(  # type: ignore[override]
+        self,
+        content: Union[
+            str,
+            List[Union[str, Dict[str, Union[str, Dict[str, str]]], Path, Image.Image]],
+            Any,
+        ],
+        request_id: Optional[str] = None,
+        prompt_token_count: Optional[int] = None,
+        output_token_count: Optional[int] = None,
+        raw_content: bool = False,
+        **kwargs,
+    ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
         """
+        Generate chat completions for the given content using the OpenAI
+        chat completions endpoint: /v1/chat/completions.
+
+        :param content: The content (or list of content) to generate a completion for.
+            This supports any combination of text, images, and audio (model dependent).
+            Supported text only request examples:
+            content="Sample prompt", content=["Sample prompt", "Second prompt"],
+            content=[{"type": "text", "value": "Sample prompt"}].
+            Supported text and image request examples:
+            content=["Describe the image", PIL.Image.open("image.jpg")],
+            content=["Describe the image", Path("image.jpg")],
+            content=["Describe the image", {"type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}].
+            Supported text and audio request examples:
+            content=["Transcribe the audio", Path("audio.wav")],
+            content=["Transcribe the audio", {"type": "input_audio",
+            "input_audio": {"data": f"{base64_bytes}", "format": "wav"}}].
+            Additionally, if raw_content=True then the content is passed directly to the
+            backend without any processing.
+        :param request_id: The unique identifier for the request, if any.
+            Added to logging statements and the response for tracking purposes.
+        :param prompt_token_count: The number of tokens measured in the prompt, if any.
+            Returned in the response stats for later analysis, if applicable.
+        :param output_token_count: If supplied, the number of tokens to enforce
+            generation of for the output for this request.
+        :param kwargs: Additional keyword arguments to pass with the request.
+        :return: An async generator that yields a StreamingTextResponse for start,
+            a StreamingTextResponse for each received iteration,
+            and a ResponseSummary for the final response.
+ """ + logger.debug("{} invocation with args: {}", self.__class__.__name__, locals()) + headers = self._headers() + messages = ( + content if raw_content else self._create_chat_messages(content=content) + ) + payload = self._completions_payload( + orig_kwargs=kwargs, + max_output_tokens=output_token_count, + messages=messages, + ) - logger.debug("Making request to OpenAI backend with prompt: {}", request.prompt) + try: + async for resp in self._iterative_completions_request( + type_="chat", + request_id=request_id, + request_prompt_tokens=prompt_token_count, + request_output_tokens=output_token_count, + headers=headers, + payload=payload, + ): + yield resp + except Exception as ex: + logger.error( + "{} request with headers: {} and payload: {} failed: {}", + self.__class__.__name__, + headers, + payload, + ex, + ) + raise ex - request_args: Dict = { - "n": 1, # Number of completions for each prompt + def _headers(self) -> Dict[str, str]: + headers = { + "Content-Type": "application/json", } - if request.output_token_count is not None: - request_args.update( + if self.authorization: + headers["Authorization"] = self.authorization + + if self.organization: + headers["OpenAI-Organization"] = self.organization + + if self.project: + headers["OpenAI-Project"] = self.project + + return headers + + def _completions_payload( + self, orig_kwargs: Optional[Dict], max_output_tokens: Optional[int], **kwargs + ) -> Dict: + payload = orig_kwargs or {} + payload.update(kwargs) + payload["model"] = self.model + payload["stream"] = True + payload["stream_options"] = { + "include_usage": True, + } + + if max_output_tokens or self.max_output_tokens: + logger.debug( + "{} adding payload args for setting output_token_count: {}", + self.__class__.__name__, + max_output_tokens or self.max_output_tokens, + ) + payload["max_tokens"] = max_output_tokens or self.max_output_tokens + payload["max_completion_tokens"] = payload["max_tokens"] + + if max_output_tokens: + # only set stop and ignore_eos if max_output_tokens set at request level + # otherwise the instance value is just the max to enforce we stay below + payload["stop"] = None + payload["ignore_eos"] = True + + return payload + + @staticmethod + def _create_chat_messages( + content: Union[ + str, + List[Union[str, Dict[str, Union[str, Dict[str, str]]], Path, Image.Image]], + Any, + ], + ) -> List[Dict]: + if isinstance(content, str): + return [ { - "max_tokens": request.output_token_count, - "stop": None, + "role": "user", + "content": content, } - ) - elif settings.openai.max_gen_tokens and settings.openai.max_gen_tokens > 0: - request_args.update( + ] + + if isinstance(content, list): + resolved_content = [] + + for item in content: + if isinstance(item, Dict): + resolved_content.append(item) + elif isinstance(item, str): + resolved_content.append({"type": "text", "text": item}) + elif isinstance(item, Image.Image) or ( + isinstance(item, Path) and item.suffix.lower() in [".jpg", ".jpeg"] + ): + image = item if isinstance(item, Image.Image) else Image.open(item) + encoded = base64.b64encode(image.tobytes()).decode("utf-8") + resolved_content.append( + { + "type": "image", + "image": { + "url": f"data:image/jpeg;base64,{encoded}", + }, + } + ) + elif isinstance(item, Path) and item.suffix.lower() in [".wav"]: + encoded = base64.b64encode(item.read_bytes()).decode("utf-8") + resolved_content.append( + { + "type": "input_audio", + "input_audio": { + "data": f"{encoded}", + "format": "wav", + }, + } + ) + else: + raise ValueError( + f"Unsupported content item 
type: {item} in list: {content}"
+                    )
+
+            return [
                 {
-                    "max_tokens": settings.openai.max_gen_tokens,
+                    "role": "user",
+                    "content": resolved_content,
                 }
-            )
+            ]

-        request_args.update(self._request_args)
+        raise ValueError(f"Unsupported content type: {content}")

-        stream = await self._async_client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {"role": "user", "content": request.prompt},
-            ],
-            stream=True,
-            **request_args,
+    async def _iterative_completions_request(
+        self,
+        type_: Literal["text", "chat"],
+        request_id: Optional[str],
+        request_prompt_tokens: Optional[int],
+        request_output_tokens: Optional[int],
+        headers: Dict,
+        payload: Dict,
+    ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
+        target = f"{self.target}/v1/"
+
+        if type_ == "text":
+            target += "completions"
+        elif type_ == "chat":
+            target += "chat/completions"
+        else:
+            raise ValueError(f"Unsupported type: {type_}")
+
+        logger.info(
+            "{} making request: {} to target: {} using http2: {} for "
+            "timeout: {} with headers: {} and payload: {}",
+            self.__class__.__name__,
+            request_id,
+            target,
+            self.http2,
+            self.timeout,
+            headers,
+            payload,
         )
-        token_count = 0
-        async for chunk in stream:
-            choice = chunk.choices[0]
-            token = choice.delta.content or ""
-
-            if choice.finish_reason is not None:
-                yield GenerativeResponse(
-                    type_="final",
-                    prompt=request.prompt,
-                    prompt_token_count=request.prompt_token_count,
-                    output_token_count=token_count,
-                )
-                break
-
-            token_count += 1
-            yield GenerativeResponse(
-                type_="token_iter",
-                add_token=token,
-                prompt=request.prompt,
-                prompt_token_count=request.prompt_token_count,
-                output_token_count=token_count,
+        async with httpx.AsyncClient(http2=self.http2, timeout=self.timeout) as client:
+            response_value = ""
+            response_prompt_count: Optional[int] = None
+            response_output_count: Optional[int] = None
+            iter_count = 0
+            start_time = time.time()
+            iter_time = start_time
+
+            yield StreamingTextResponse(
+                type_="start",
+                iter_count=iter_count,
+                delta="",
+                time=start_time,
+                request_id=request_id,
             )

-    def available_models(self) -> List[str]:
-        """
-        Get the available models for the backend.
+            async with client.stream(
+                "POST", target, headers=headers, json=payload
+            ) as stream:
+                stream.raise_for_status()
+
+                async for line in stream.aiter_lines():
+                    iter_time = time.time()
+                    logger.debug(
+                        "{} request: {} received iter response line: {}",
+                        self.__class__.__name__,
+                        request_id,
+                        line,
+                    )
+
+                    if not line or not line.strip().startswith("data:"):
+                        continue
+
+                    if line.strip() == "data: [DONE]":
+                        break
+
+                    data = json.loads(line.strip()[len("data: ") :])
+                    if delta := self._extract_completions_delta_content(type_, data):
+                        iter_count += 1
+                        response_value += delta
+
+                        yield StreamingTextResponse(
+                            type_="iter",
+                            iter_count=iter_count,
+                            delta=delta,
+                            time=iter_time,
+                            request_id=request_id,
+                        )
+
+                    if usage := self._extract_completions_usage(data):
+                        response_prompt_count = usage["prompt"]
+                        response_output_count = usage["output"]
+
+            logger.info(
+                "{} request: {} with headers: {} and payload: {} completed with: {}",
+                self.__class__.__name__,
+                request_id,
+                headers,
+                payload,
+                response_value,
+            )

-        This method queries the OpenAI API to retrieve a list of available models.
+ yield ResponseSummary( + value=response_value, + request_args=RequestArgs( + target=target, + headers=headers, + payload=payload, + timeout=self.timeout, + http2=self.http2, + ), + start_time=start_time, + end_time=iter_time, + iterations=iter_count, + request_prompt_tokens=request_prompt_tokens, + request_output_tokens=request_output_tokens, + response_prompt_tokens=response_prompt_count, + response_output_tokens=response_output_count, + request_id=request_id, + ) - :return: A list of available models. - :rtype: List[str] - :raises openai.OpenAIError: If an error occurs while retrieving models. - """ + @staticmethod + def _extract_completions_delta_content( + type_: Literal["text", "chat"], data: Dict + ) -> Optional[str]: + if "choices" not in data or not data["choices"]: + return None - try: - return [model.id for model in self._client.models.list().data] - except Exception as error: - logger.error("Failed to retrieve available models: {}", error) - raise error + if type_ == "text": + return data["choices"][0]["text"] - def validate_connection(self): - """ - Validate the connection to the OpenAI backend. + if type_ == "chat": + return data["choices"][0]["delta"]["content"] - This method checks that the OpenAI backend is reachable and - the API key is valid. + raise ValueError(f"Unsupported type: {type_}") - :raises openai.OpenAIError: If the connection is invalid. - """ + @staticmethod + def _extract_completions_usage( + data: Dict, + ) -> Optional[Dict[Literal["prompt", "output"], int]]: + if "usage" not in data or not data["usage"]: + return None - try: - self._client.models.list() - except Exception as error: - logger.error("Failed to validate OpenAI connection: {}", error) - raise error + return { + "prompt": data["usage"]["prompt_tokens"], + "output": data["usage"]["completion_tokens"], + } diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py new file mode 100644 index 00000000..699f41cc --- /dev/null +++ b/src/guidellm/backend/response.py @@ -0,0 +1,139 @@ +from typing import Any, Dict, Literal, Optional + +from loguru import logger +from pydantic import BaseModel, computed_field + +from guidellm.config import settings + +__all__ = [ + "StreamingResponseType", + "StreamingTextResponse", + "RequestArgs", + "ResponseSummary", +] + + +StreamingResponseType = Literal["start", "iter"] + + +class StreamingTextResponse(BaseModel): + """ + A model representing the response content for a streaming text request. + + :param type_: The type of the response; either 'start' or 'iter'. + :param iter_count: The iteration count for the response. For 'start' this is 0 + and for the first 'iter' it is 1. + :param delta: The text delta added to the response for this stream iteration. + :param time: If 'start', the time.time() the request started. + If 'iter', the time.time() the iteration was received. + :param request_id: The unique identifier for the request, if any. + """ + + type_: StreamingResponseType + iter_count: int + delta: str + time: float + request_id: Optional[str] = None + + +class RequestArgs(BaseModel): + """ + A model representing the arguments for a request to a backend. + Biases towards an HTTP request, but can be used for other types of backends. + + :param target: The target URL or function for the request. + :param headers: The headers, if any, included in the request such as authorization. + :param payload: The payload / arguments for the request including the prompt / + content and other configurations. 
+    :param timeout: The timeout for the request in seconds, if any.
+    :param http2: Whether HTTP/2 was used for the request, if applicable.
+    """
+
+    target: str
+    headers: Dict[str, str]
+    payload: Dict[str, Any]
+    timeout: Optional[float] = None
+    http2: Optional[bool] = None
+
+
+class ResponseSummary(BaseModel):
+    """
+    A model representing a summary of a backend request.
+    Always returned as the final iteration of a streaming request.
+
+    :param value: The final value returned from the request.
+    :param request_args: The arguments used to make the request.
+    :param start_time: The time the request started.
+    :param end_time: The time the request ended.
+    :param iterations: The number of iterations in the request.
+    :param prompt_tokens: The number of tokens in the prompt, if any usage was returned.
+    :param output_tokens: The number of tokens in the output, if any usage was returned.
+    :param request_id: The unique identifier for the request, if any.
+    """
+
+    value: str
+    request_args: RequestArgs
+    iterations: int = 0
+    start_time: float
+    end_time: float
+    request_prompt_tokens: Optional[int] = None
+    request_output_tokens: Optional[int] = None
+    response_prompt_tokens: Optional[int] = None
+    response_output_tokens: Optional[int] = None
+    request_id: Optional[str] = None
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def prompt_tokens(self) -> Optional[int]:
+        """
+        The number of tokens measured in the prompt based on preferences
+        for trusting the input or response.
+
+        :return: The number of tokens in the prompt, if any.
+        """
+        if settings.preferred_prompt_tokens_source == "backend":
+            if self.response_prompt_tokens is None:
+                logger.warning(
+                    "Preferred prompt tokens source is backend, but no prompt token "
+                    f"values were returned with the response for {self}. "
+                    "Defaulting to request_prompt_tokens (if available)."
+                )
+            return self.response_prompt_tokens or self.request_prompt_tokens
+        elif settings.preferred_prompt_tokens_source == "request":
+            if self.request_prompt_tokens is None:
+                logger.warning(
+                    "Preferred prompt tokens source is request, but no prompt token "
+                    f"values were returned with the request for {self}. "
+                    "Defaulting to response_prompt_tokens (if available)."
+                )
+            return self.request_prompt_tokens or self.response_prompt_tokens
+
+        return self.response_prompt_tokens or self.request_prompt_tokens
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def output_tokens(self) -> Optional[int]:
+        """
+        The number of tokens measured in the output based on preferences
+        for trusting the input or response.
+
+        :return: The number of tokens in the output, if any.
+        """
+        if settings.preferred_output_tokens_source == "backend":
+            if self.response_output_tokens is None:
+                logger.warning(
+                    "Preferred output tokens source is backend, but no output token "
+                    f"values were returned with the response for {self}. "
+                    "Defaulting to request_output_tokens (if available)."
+                )
+            return self.response_output_tokens or self.request_output_tokens
+        elif settings.preferred_output_tokens_source == "request":
+            if self.request_output_tokens is None:
+                logger.warning(
+                    "Preferred output tokens source is request, but no output token "
+                    f"values were returned with the request for {self}. "
+                    "Defaulting to response_output_tokens (if available)."
+ ) + return self.request_output_tokens or self.response_output_tokens + + return self.response_output_tokens or self.request_output_tokens diff --git a/src/guidellm/config.py b/src/guidellm/config.py index c3d950ec..2d4e102a 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/config.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Literal, Optional, Sequence from pydantic import BaseModel, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -98,14 +98,12 @@ class OpenAISettings(BaseModel): for OpenAI server based pathways """ - # OpenAI API key. - api_key: str = "invalid_token" - - # OpenAI-compatible server URL - # NOTE: The default value is default address of llama.cpp web server - base_url: str = "http://localhost:8000/v1" - - max_gen_tokens: int = 4096 + api_key: Optional[str] = None + bearer_token: Optional[str] = None + organization: Optional[str] = None + project: Optional[str] = None + base_url: str = "http://localhost:8000" + max_output_tokens: int = 16384 class ReportGenerationSettings(BaseModel): @@ -141,7 +139,8 @@ class Settings(BaseSettings): # general settings env: Environment = Environment.PROD - request_timeout: int = 30 + request_timeout: int = 60 * 5 # 5 minutes + request_http2: bool = True max_concurrency: int = 512 num_sweep_profiles: int = 9 logging: LoggingSettings = LoggingSettings() @@ -150,7 +149,10 @@ class Settings(BaseSettings): dataset: DatasetSettings = DatasetSettings() emulated_data: EmulatedDataSettings = EmulatedDataSettings() - # Request settings + # Request/stats settings + preferred_prompt_tokens_source: Optional[Literal["backend", "local"]] = None + preferred_output_tokens_source: Optional[Literal["backend", "local"]] = None + preferred_backend: Literal["openai"] = "openai" openai: OpenAISettings = OpenAISettings() # Report settings diff --git a/src/guidellm/core/distribution.py b/src/guidellm/core/distribution.py index 3f770528..749d6818 100644 --- a/src/guidellm/core/distribution.py +++ b/src/guidellm/core/distribution.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from typing import List, Sequence, Union import numpy as np from loguru import logger @@ -33,7 +33,7 @@ def mean(self) -> float: :return: The mean of the distribution. """ if not self.data: - logger.warning("No data points available to calculate mean.") + logger.info("No data points available to calculate mean.") return 0.0 mean_value = np.mean(self.data).item() @@ -47,7 +47,7 @@ def median(self) -> float: :return: The median of the distribution. """ if not self.data: - logger.warning("No data points available to calculate median.") + logger.info("No data points available to calculate median.") return 0.0 median_value = np.median(self.data).item() @@ -61,7 +61,7 @@ def variance(self) -> float: :return: The variance of the distribution. """ if not self.data: - logger.warning("No data points available to calculate variance.") + logger.info("No data points available to calculate variance.") return 0.0 variance_value = np.var(self.data).item() @@ -75,7 +75,7 @@ def std_deviation(self) -> float: :return: The standard deviation of the distribution. 
""" if not self.data: - logger.warning("No data points available to calculate standard deviation.") + logger.info("No data points available to calculate standard deviation.") return 0.0 std_deviation_value = np.std(self.data).item() @@ -89,21 +89,21 @@ def percentile(self, percentile: float) -> float: :return: The specified percentile of the distribution. """ if not self.data: - logger.warning("No data points available to calculate percentile.") + logger.info("No data points available to calculate percentile.") return 0.0 percentile_value = np.percentile(self.data, percentile).item() logger.debug(f"Calculated {percentile}th percentile: {percentile_value}") return percentile_value - def percentiles(self, percentiles: List[float]) -> List[float]: + def percentiles(self, percentiles: Union[List[int], List[float]]) -> List[float]: """ Calculate and return the specified percentiles of the distribution. :param percentiles: A list of desired percentiles to calculate (0-100). :return: A list of the specified percentiles of the distribution. """ if not self.data: - logger.warning("No data points available to calculate percentiles.") + logger.info("No data points available to calculate percentiles.") return [0.0] * len(percentiles) percentiles_values: List[float] = np.percentile(self.data, percentiles).tolist() # type: ignore # noqa: PGH003 @@ -117,7 +117,7 @@ def min(self) -> float: :return: The minimum value of the distribution. """ if not self.data: - logger.warning("No data points available to calculate minimum.") + logger.info("No data points available to calculate minimum.") return 0.0 min_value: float = np.min(self.data).item() # type: ignore # noqa: PGH003 @@ -131,7 +131,7 @@ def max(self) -> float: :return: The maximum value of the distribution. """ if not self.data: - logger.warning("No data points available to calculate maximum.") + logger.info("No data points available to calculate maximum.") return 0.0 max_value: float = np.max(self.data).item() # type: ignore # noqa: PGH003 @@ -145,7 +145,7 @@ def range(self) -> float: :return: The range of the distribution. 
""" if not self.data: - logger.warning("No data points available to calculate range.") + logger.info("No data points available to calculate range.") return 0.0 range_value = self.max - self.min diff --git a/src/guidellm/core/report.py b/src/guidellm/core/report.py index b6791e45..584fe63c 100644 --- a/src/guidellm/core/report.py +++ b/src/guidellm/core/report.py @@ -135,9 +135,9 @@ def _create_benchmark_report_data_tokens_summary( table = Table( "Benchmark", "Prompt", - "Prompt (1%, 5%, 50%, 95%, 99%)", + "Prompt (1%, 5%, 10%, 50%, 90%, 95%, 99%)", "Output", - "Output (1%, 5%, 50%, 95%, 99%)", + "Output (1%, 5%, 10%, 50%, 90%, 95%, 99%)", title="[magenta]Tokens Data by Benchmark[/magenta]", title_style="bold", title_justify="left", @@ -147,19 +147,15 @@ def _create_benchmark_report_data_tokens_summary( for benchmark in report.benchmarks_sorted: table.add_row( _benchmark_rate_id(benchmark), - f"{benchmark.prompt_token_distribution.mean:.2f}", + f"{benchmark.prompt_token:.2f}", ", ".join( f"{percentile:.1f}" - for percentile in benchmark.prompt_token_distribution.percentiles( - [1, 5, 50, 95, 99] - ) + for percentile in benchmark.prompt_token_percentiles.values() ), - f"{benchmark.output_token_distribution.mean:.2f}", + f"{benchmark.output_token:.2f}", ", ".join( f"{percentile:.1f}" - for percentile in benchmark.output_token_distribution.percentiles( - [1, 5, 50, 95, 99] - ) + for percentile in benchmark.output_token_percentiles.values() ), ) logger.debug("Created data tokens summary table for the report.") @@ -181,7 +177,7 @@ def _create_benchmark_report_dist_perf_summary( "Benchmark", "Request Latency [1%, 5%, 10%, 50%, 90%, 95%, 99%] (sec)", "Time to First Token [1%, 5%, 10%, 50%, 90%, 95%, 99%] (ms)", - "Inter Token Latency [1%, 5%, 10%, 50%, 90% 95%, 99%] (ms)", + "Inter Token Latency [1%, 5%, 10%, 50%, 90%, 95%, 99%] (ms)", title="[magenta]Performance Stats by Benchmark[/magenta]", title_style="bold", title_justify="left", @@ -193,21 +189,15 @@ def _create_benchmark_report_dist_perf_summary( _benchmark_rate_id(benchmark), ", ".join( f"{percentile:.2f}" - for percentile in benchmark.request_latency_distribution.percentiles( - [1, 5, 10, 50, 90, 95, 99] - ) + for percentile in benchmark.request_latency_percentiles.values() ), ", ".join( - f"{percentile * 1000:.1f}" - for percentile in benchmark.ttft_distribution.percentiles( - [1, 5, 10, 50, 90, 95, 99] - ) + f"{percentile:.1f}" + for percentile in benchmark.time_to_first_token_percentiles.values() ), ", ".join( - f"{percentile * 1000:.1f}" - for percentile in benchmark.itl_distribution.percentiles( - [1, 5, 10, 50, 90, 95, 99] - ) + f"{percentile:.1f}" + for percentile in benchmark.inter_token_latency_percentiles.values() ), ) logger.debug("Created distribution performance summary table for the report.") diff --git a/src/guidellm/core/request.py b/src/guidellm/core/request.py index 4f7315c5..547ac60a 100644 --- a/src/guidellm/core/request.py +++ b/src/guidellm/core/request.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional from pydantic import Field @@ -15,6 +15,10 @@ class TextGenerationRequest(Serializable): default_factory=lambda: str(uuid.uuid4()), description="The unique identifier for the request.", ) + type_: Literal["text", "chat"] = Field( + default="text", + description="The type of text generation request (e.g., text, chat).", + ) prompt: str = Field(description="The input prompt for the text generation.") prompt_token_count: Optional[int] = Field( 
default=None, @@ -38,6 +42,7 @@ def __str__(self) -> str: return ( f"TextGenerationRequest(id={self.id}, " + f"type_={self.type_}" f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, " f"output_token_count={self.output_token_count}, " f"params={self.params})" diff --git a/src/guidellm/core/result.py b/src/guidellm/core/result.py index f218784c..2670c105 100644 --- a/src/guidellm/core/result.py +++ b/src/guidellm/core/result.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from loguru import logger -from pydantic import Field +from pydantic import Field, computed_field from guidellm.core.distribution import Distribution from guidellm.core.request import TextGenerationRequest @@ -17,6 +17,9 @@ ] +DEFAULT_PERCENTILES = [1, 5, 10, 50, 90, 95, 99] + + class TextGenerationResult(Serializable): """ A class to represent the result of a text generation request @@ -26,139 +29,98 @@ class TextGenerationResult(Serializable): request: TextGenerationRequest = Field( description="The text generation request used to generate the result.", ) - prompt: str = Field( - default_factory=str, - description="The input prompt for the text generation.", - ) - prompt_word_count: int = Field( - default=0, - description="The number of words in the input prompt.", - ) - prompt_token_count: int = Field( - default=0, + prompt_token_count: Optional[int] = Field( + default=None, description="The number of tokens in the input prompt.", ) output: str = Field( default_factory=str, description="The generated output for the text generation.", ) - output_word_count: int = Field( - default=0, - description="The number of words in the output.", - ) - output_token_count: int = Field( - default=0, - description="The number of tokens in the output.", - ) - last_time: Optional[float] = Field( + output_token_count: Optional[int] = Field( default=None, - description="The last time recorded.", - ) - first_token_set: bool = Field( - default=False, - description="Whether the first token time is set.", + description="The number of tokens in the output.", ) start_time: Optional[float] = Field( default=None, - description="The start time of the text generation.", + description="The absolute start time, in seconds, of the text generation.", ) end_time: Optional[float] = Field( default=None, - description="The end time of the text generation.", + description="The absolute end time, in seconds, of the text generation.", ) first_token_time: Optional[float] = Field( default=None, - description="The time taken to decode the first token.", + description="The absolute time, in seconds, the first token was received.", ) - decode_times: Distribution = Field( - default_factory=Distribution, - description="The distribution of decode times.", + last_token_time: Optional[float] = Field( + default=None, + description="The absolute time, in seconds, the last token was received.", ) - def start(self, prompt: str): + @computed_field # type: ignore[misc] + @property + def request_latency(self) -> Optional[float]: """ - Start the text generation by recording the prompt and start time. + Get the request latency in seconds. - :param prompt: The input prompt for the text generation. - :type prompt: str + :return: The request latency in seconds. 
""" - self.prompt = prompt - self.prompt_word_count = len(prompt.split()) - self.prompt_token_count = len(prompt) # Token count placeholder - self.start_time = time() - self.last_time = time() - self.first_token_set = False + if not self.end_time or not self.start_time: + return None - logger.info("Text generation started with prompt: '{}'", prompt) + return self.end_time - self.start_time - def output_token(self, token: str): + @computed_field # type: ignore[misc] + @property + def time_to_first_token(self) -> Optional[float]: """ - Add a token to the output and record the decode time. + Get the time taken to decode the first token in milliseconds. - :param token: The decoded token. - :type token: str + :return: The time taken to decode the first token in milliseconds. """ - self._check_recording_started() + if not self.first_token_time or not self.start_time: + return None - if self.last_time is None: - raise ValueError( - "last time is not specified. " - "Did you call `text_generation_benchmark.start()`?" - ) + return 1000 * (self.first_token_time - self.start_time) - current_counter = time() + @computed_field # type: ignore[misc] + @property + def inter_token_latency(self) -> Optional[float]: + """ + Get the average time between tokens in milliseconds. - if not self.first_token_set: - self.first_token_time = current_counter - self.last_time - self.first_token_set = True - logger.debug(f"First token decode time: {self.first_token_time}") - else: - decode_time = current_counter - self.last_time - self.decode_times.add_data([decode_time]) - logger.debug(f"Token '{token}' decoded in {decode_time} seconds") + :return: The average time between tokens. + """ + if ( + not self.last_token_time + or not self.first_token_time + or not self.output_token_count + or self.output_token_count < 2 # noqa: PLR2004 + ): + return None - self.last_time = current_counter - self.output += token - logger.debug("Added token {} to output", token) + return ( + 1000 + * (self.last_token_time - self.first_token_time) + / (self.output_token_count - 1) # ignore first token + ) - def end( - self, - output: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - ): + @computed_field # type: ignore[misc] + @property + def output_tokens_per_second(self) -> Optional[float]: """ - End the text generation by recording the output and end time. + Get the average token throughput in tokens per second for the entire request. + Note, does not account for the time taken to decode the first token. - :param output: The generated output for the text generation. - :type output: str - :param prompt_token_count: Optional token count for the prompt, - defaults to word count. - :type prompt_token_count: Optional[int] - :param output_token_count: Optional token count for the output, - defaults to word count. - :type output_token_count: Optional[int] + :return: The average token throughput. """ - self._check_recording_started() - self.end_time = time() - - if output: - self.output = output + itl = self.inter_token_latency - self.output_word_count = len(self.output.split()) - self.output_token_count = output_token_count or self.output_word_count - self.prompt_token_count = prompt_token_count or self.prompt_word_count - - logger.info(f"Text generation ended with output: '{self.output}'") + if itl is None: + return None - def _check_recording_started( - self, - ): - if self.start_time is None: - raise ValueError( - "start time is not specified. 
" - "Did you make the `text_generation_benchmark.start()`?", - ) + return 1000.0 / itl class TextGenerationError(Serializable): @@ -221,88 +183,96 @@ def __iter__(self): """ return iter(self.results) + @computed_field # type: ignore[misc] @property def request_count(self) -> int: """ Get the number of requests in the result. :return: The number of requests. - :rtype: int """ return len(self.results) + @computed_field # type: ignore[misc] @property def error_count(self) -> int: """ Get the number of errors in the result. :return: The number of errors. - :rtype: int """ return len(self.errors) + @computed_field # type: ignore[misc] @property def total_count(self) -> int: """ Get the total number of requests in the result. :return: The total number of requests. - :rtype: int """ return self.request_count + self.error_count + @computed_field # type: ignore[misc] @property def start_time(self) -> Optional[float]: """ Get the start time of the first request in the result. :return: The start time of the first request. - :rtype: Optional[float] """ - if not self.results: - return None - - return self.results[0].start_time + return self.results[0].start_time if self.results else None + @computed_field # type: ignore[misc] @property def end_time(self) -> Optional[float]: """ Get the end time of the last request in the result. :return: The end time of the last request. - :rtype: Optional[float] """ - if not self.results: - return None - - return self.results[-1].end_time + return self.results[-1].end_time if self.results else None + @computed_field # type: ignore[misc] @property def duration(self) -> float: """ Get the duration of the result in seconds. :return: The duration of the result. - :rtype: float """ - if not self.results or not self.start_time or not self.end_time: - return 0.0 - - return self.end_time - self.start_time + return ( + self.end_time - self.start_time + if self.end_time and self.start_time + else 0.0 + ) + @computed_field # type: ignore[misc] @property def completed_request_rate(self) -> float: """ Get the rate of requests per second in the result. :return: The rate of requests per second. - :rtype: float """ - if not self.results or not self.duration: - return 0.0 + return self.request_count / self.duration if self.duration else 0.0 + + @property + def request_latency_distribution(self) -> Distribution: + """ + Get the distribution of request latencies in seconds. - return len(self.results) / self.duration + :return: The distribution of request latencies. + """ + return Distribution( + data=[ + result.request_latency + for result in self.results + if result.request_latency + ] + ) + @computed_field # type: ignore[misc] @property def request_latency(self) -> float: """ @@ -311,97 +281,124 @@ def request_latency(self) -> float: :return: The average request latency in seconds. :rtype: float """ + return self.request_latency_distribution.mean + + @computed_field # type: ignore[misc] + @property + def request_latency_percentiles(self) -> Dict[str, float]: + """ + Get standard percentiles of request latency in seconds. + + :return: A dictionary mapping percentile to request latency in seconds. 
+ """ if not self.results: - return 0.0 + return {} - return self.request_latency_distribution.mean + values = self.request_latency_distribution.percentiles(DEFAULT_PERCENTILES) + + return dict(zip(map(str, DEFAULT_PERCENTILES), values)) @property - def request_latency_distribution(self) -> Distribution: + def ttft_distribution(self) -> Distribution: """ - Get the distribution of request latencies. + Get the distribution of time taken to decode the first token. - :return: The distribution of request latencies. - :rtype: Distribution + :return: The distribution of time taken to decode the first token. """ return Distribution( data=[ - result.end_time - result.start_time + result.time_to_first_token for result in self.results - if result.end_time is not None and result.start_time is not None + if result.time_to_first_token ] ) + @computed_field # type: ignore[misc] @property def time_to_first_token(self) -> float: """ Get the time taken to decode the first token in milliseconds. :return: The time taken to decode the first token in milliseconds. - :rtype: float + """ + return self.ttft_distribution.mean + + @computed_field # type: ignore[misc] + @property + def time_to_first_token_percentiles(self) -> Dict[str, float]: + """ + Get standard percentiles for time taken to decode the first token + in milliseconds. + + :return: A dictionary mapping percentile to time taken for the first token. """ if not self.results: - return 0.0 + return {} - return 1000 * self.ttft_distribution.mean + values = self.ttft_distribution.percentiles(DEFAULT_PERCENTILES) + + return dict(zip(map(str, DEFAULT_PERCENTILES), values)) @property - def ttft_distribution(self) -> Distribution: + def itl_distribution(self) -> Distribution: """ - Get the distribution of time taken to decode the first token. + Get the distribution of time between tokens in milliseconds. - :return: The distribution of time taken to decode the first token. - :rtype: Distribution + :return: The distribution of time between tokens. """ return Distribution( data=[ - result.first_token_time + result.inter_token_latency for result in self.results - if result.first_token_time is not None + for _ in range( + result.output_token_count - 1 + if result.output_token_count and result.output_token_count > 1 + else 0 + ) + if (result.inter_token_latency) ] ) + @computed_field # type: ignore[misc] @property def inter_token_latency(self) -> float: """ Get the average time between tokens in milliseconds. :return: The average time between tokens. - :rtype: float """ - if not self.results: - return 0.0 - - return 1000 * self.itl_distribution.mean + return self.itl_distribution.mean + @computed_field # type: ignore[misc] @property - def itl_distribution(self) -> Distribution: + def inter_token_latency_percentiles(self) -> Dict[str, float]: """ - Get the distribution of time between tokens. + Get standard percentiles for the time between tokens in milliseconds. - :return: The distribution of time between tokens. - :rtype: Distribution + :return: A dictionary mapping percentile to time between tokens. """ - return Distribution( - data=[ - decode for result in self.results for decode in result.decode_times.data - ] - ) + if not self.results: + return {} + + values = self.itl_distribution.percentiles(DEFAULT_PERCENTILES) + return dict(zip(map(str, DEFAULT_PERCENTILES), values)) + + @computed_field # type: ignore[misc] @property def output_token_throughput(self) -> float: """ Get the average token throughput in tokens per second. :return: The average token throughput. 
- :rtype: float """ - if not self.results or not self.duration: - return 0.0 - - total_tokens = sum(result.output_token_count for result in self.results) + output_tokens = sum( + result.output_token_count + for result in self.results + if result.output_token_count and result.output_token_count > 0 + ) - return total_tokens / self.duration + return output_tokens / self.duration if self.duration else 0.0 @property def prompt_token_distribution(self) -> Distribution: @@ -409,9 +406,39 @@ def prompt_token_distribution(self) -> Distribution: Get the distribution of prompt token counts. :return: The distribution of prompt token counts. - :rtype: Distribution """ - return Distribution(data=[result.prompt_token_count for result in self.results]) + return Distribution( + data=[ + result.prompt_token_count + for result in self.results + if result.prompt_token_count + ] + ) + + @computed_field # type: ignore[misc] + @property + def prompt_token(self) -> float: + """ + Get the average number of prompt tokens. + + :return: The average number of prompt tokens. + """ + return self.prompt_token_distribution.mean + + @computed_field # type: ignore[misc] + @property + def prompt_token_percentiles(self) -> Dict[str, float]: + """ + Get standard percentiles for number of prompt tokens. + + :return: A dictionary mapping percentile to number of prompt tokens. + """ + if not self.results: + return {} + + values = self.prompt_token_distribution.percentiles(DEFAULT_PERCENTILES) + + return dict(zip(map(str, DEFAULT_PERCENTILES), values)) @property def output_token_distribution(self) -> Distribution: @@ -419,26 +446,39 @@ def output_token_distribution(self) -> Distribution: Get the distribution of output token counts. :return: The distribution of output token counts. - :rtype: Distribution """ - return Distribution(data=[result.output_token_count for result in self.results]) + return Distribution( + data=[ + result.output_token_count + for result in self.results + if result.output_token_count + ] + ) + + @computed_field # type: ignore[misc] + @property + def output_token(self) -> float: + """ + Get the average number of output tokens. + + :return: The average number of output tokens. + """ + return self.output_token_distribution.mean + @computed_field # type: ignore[misc] @property - def overloaded(self) -> bool: - if ( - self.rate is None - or not self.results - or not self.concurrencies - or len(self.concurrencies) < 2 # noqa: PLR2004 - ): - # if rate was not set, sync mode is assumed, - # or we have less than 2 data points, - # then we cannot be overloaded by definition - return False - - # if the calculated rate is less than 75% of the requested rate, - # safe to assume the system is overloaded - return self.completed_request_rate < 0.75 * self.rate + def output_token_percentiles(self) -> Dict[str, float]: + """ + Get standard percentiles for number of output tokens. + + :return: List of percentiles of number of output tokens. 
+ """ + if not self.results: + return {} + + values = self.output_token_distribution.percentiles(DEFAULT_PERCENTILES) + + return dict(zip(map(str, DEFAULT_PERCENTILES), values)) def request_started(self): """ diff --git a/src/guidellm/core/serializable.py b/src/guidellm/core/serializable.py index 1e6b2944..23e6845a 100644 --- a/src/guidellm/core/serializable.py +++ b/src/guidellm/core/serializable.py @@ -18,7 +18,7 @@ class Serializable(BaseModel): """ model_config = ConfigDict( - extra="forbid", + extra="ignore", use_enum_values=True, validate_assignment=True, from_attributes=True, diff --git a/src/guidellm/executor/__init__.py b/src/guidellm/executor/__init__.py index d5858d07..7665e898 100644 --- a/src/guidellm/executor/__init__.py +++ b/src/guidellm/executor/__init__.py @@ -1,4 +1,4 @@ -from .base import Executor, ExecutorResult +from .executor import Executor, ExecutorResult from .profile_generator import Profile, ProfileGenerationMode, ProfileGenerator __all__ = [ diff --git a/src/guidellm/executor/base.py b/src/guidellm/executor/executor.py similarity index 99% rename from src/guidellm/executor/base.py rename to src/guidellm/executor/executor.py index 865ab30d..bfecf17f 100644 --- a/src/guidellm/executor/base.py +++ b/src/guidellm/executor/executor.py @@ -170,7 +170,7 @@ async def run(self) -> AsyncGenerator[ExecutorResult, None]: logger.debug("Generated profile: {}", profile) scheduler = Scheduler( generator=self.request_generator, - worker=self.backend, + backend=self.backend, mode=profile.load_gen_mode, rate=profile.load_gen_rate, max_number=self.max_number or profile.args.get("max_number", None), diff --git a/src/guidellm/main.py b/src/guidellm/main.py index 1e6bed80..e7363c6e 100644 --- a/src/guidellm/main.py +++ b/src/guidellm/main.py @@ -3,8 +3,9 @@ import click from loguru import logger +from transformers import AutoTokenizer # type: ignore[import-untyped] -from guidellm.backend import Backend, BackendEnginePublic +from guidellm.backend import Backend, BackendType from guidellm.core import GuidanceReport, TextGenerationBenchmarkReport from guidellm.executor import Executor, ProfileGenerationMode from guidellm.request import ( @@ -25,13 +26,13 @@ required=True, help=( "The target path or url for the backend to evaluate. " - "Ex: 'http://localhost:8000/v1'" + "Ex: 'http://localhost:8000'" ), ) @click.option( "--backend", - type=click.Choice(get_args(BackendEnginePublic)), - default="openai_server", + type=click.Choice(get_args(BackendType)), + default="openai_http", help=( "The backend to use for benchmarking. 
" "The default is OpenAI Server enabling compatability with any server that " @@ -153,7 +154,7 @@ ) def generate_benchmark_report_cli( target: str, - backend: BackendEnginePublic, + backend: BackendType, model: Optional[str], data: Optional[str], data_type: Literal["emulated", "file", "transformers"], @@ -186,18 +187,18 @@ def generate_benchmark_report_cli( def generate_benchmark_report( target: str, - backend: BackendEnginePublic, - model: Optional[str], data: Optional[str], data_type: Literal["emulated", "file", "transformers"], - tokenizer: Optional[str], - rate_type: ProfileGenerationMode, - rate: Optional[float], - max_seconds: Optional[int], - max_requests: Union[Literal["dataset"], int, None], - output_path: str, - cont_refresh_table: bool, + backend: BackendType = "openai_http", backend_kwargs: Optional[Mapping[str, Any]] = None, + model: Optional[str] = None, + tokenizer: Optional[str] = None, + rate_type: ProfileGenerationMode = "sweep", + rate: Optional[float] = None, + max_seconds: Optional[int] = 120, + max_requests: Union[Literal["dataset"], int, None] = None, + output_path: Optional[str] = None, + cont_refresh_table: bool = False, ) -> GuidanceReport: """ Generate a benchmark report for a specified backend and dataset. @@ -227,11 +228,12 @@ def generate_benchmark_report( # Create backend backend_inst = Backend.create( - backend_type=backend, + type_=backend, target=target, model=model, **(backend_kwargs or {}), ) + backend_inst.validate() request_generator: RequestGenerator @@ -239,7 +241,7 @@ def generate_benchmark_report( tokenizer_inst = tokenizer if not tokenizer_inst: try: - tokenizer_inst = backend_inst.model_tokenizer() + tokenizer_inst = AutoTokenizer.from_pretrained(backend_inst.model) except Exception as err: raise ValueError( "Could not load model's tokenizer, " diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index b3b4ac50..39485648 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -1,4 +1,4 @@ -from .base import Scheduler, SchedulerResult from .load_generator import LoadGenerationMode, LoadGenerator +from .scheduler import Scheduler, SchedulerResult __all__ = ["LoadGenerationMode", "LoadGenerator", "Scheduler", "SchedulerResult"] diff --git a/src/guidellm/scheduler/base.py b/src/guidellm/scheduler/scheduler.py similarity index 79% rename from src/guidellm/scheduler/base.py rename to src/guidellm/scheduler/scheduler.py index f0087330..2f8c44fe 100644 --- a/src/guidellm/scheduler/base.py +++ b/src/guidellm/scheduler/scheduler.py @@ -6,7 +6,7 @@ from loguru import logger -from guidellm.backend import Backend +from guidellm.backend import Backend, ResponseSummary, StreamingTextResponse from guidellm.config import settings from guidellm.core import ( TextGenerationBenchmark, @@ -50,8 +50,8 @@ class Scheduler: :param generator: The request generator that produces text generation requests. :type generator: RequestGenerator - :param worker: The backend worker that processes the requests. - :type worker: Backend + :param backend: The backend that processes the requests. + :type backend: Backend :param mode: The mode of load generation (e.g., synchronous, asynchronous). :type mode: LoadGenerationMode :param rate: The rate at which requests are generated, if applicable. 
@@ -69,17 +69,17 @@ class Scheduler: def __init__( self, generator: RequestGenerator, - worker: Backend, + backend: Backend, mode: LoadGenerationMode = "synchronous", rate: Optional[float] = None, max_number: Optional[int] = None, max_duration: Optional[float] = None, ): logger.info( - "Scheduler initialized with params: generator={}, worker={}, mode={}, " + "Scheduler initialized with params: generator={}, backend={}, mode={}, " "rate={}, max_number={}, max_duration={}", generator, - worker, + backend, mode, rate, max_number, @@ -115,7 +115,7 @@ def __init__( raise err self._generator = generator - self._worker = worker + self._backend = backend self._mode = mode self._rate = rate self._max_number = max_number @@ -134,14 +134,14 @@ def generator(self) -> RequestGenerator: return self._generator @property - def worker(self) -> Backend: + def backend(self) -> Backend: """ - The backend worker that processes the requests. + The backend that processes the requests. - :return: The backend worker instance. + :return: The backend instance. :rtype: Backend """ - return self._worker + return self._backend @property def mode(self) -> LoadGenerationMode: @@ -289,7 +289,7 @@ async def _run_sync( submit_at, ) benchmark.request_started() - result = await self._submit_task_coroutine(request, submit_at, end_time) + result = await self._scheduled_request(request, submit_at, end_time) if result is not None: benchmark.request_completed(result) logger.debug("Request completed with output: {}", result) @@ -328,7 +328,7 @@ def _completed(_task: asyncio.Task) -> None: benchmark.request_started() task = asyncio.create_task( - self._submit_task_coroutine(request, submit_at, end_time) + self._scheduled_request(request, submit_at, end_time) ) task.add_done_callback(_completed) tasks.append(task) @@ -341,17 +341,11 @@ def _completed(_task: asyncio.Task) -> None: if task_res is not None: yield task_res - async def _submit_task_coroutine( + async def _scheduled_request( self, request: TextGenerationRequest, submit_at: float, end_time: float ) -> Optional[Union[TextGenerationResult, TextGenerationError]]: try: if submit_at > end_time: - logger.info( - "Request {} submission time {} is greater than end time {}", - request, - submit_at, - end_time, - ) raise asyncio.TimeoutError( f"Request submission time {submit_at} " f"is greater than end time {end_time}" @@ -364,12 +358,60 @@ async def _submit_task_coroutine( end_time - time.time() if end_time and end_time < math.inf else None ) - return await asyncio.wait_for(self._worker.submit(request), timeout=timeout) - except asyncio.TimeoutError as exc: - logger.info("Request {} timed out: {}", request, exc) - - return None + return await asyncio.wait_for( + self._resolve_text_request(request), timeout=timeout + ) except Exception as exc: # noqa: BLE001 - logger.warning("Request {} failed: {}", request, exc) + if not isinstance(exc, asyncio.TimeoutError): + logger.warning("Request {} failed: {}", request, exc) return TextGenerationError(request=request, message=str(exc)) + + async def _resolve_text_request( + self, request: TextGenerationRequest + ) -> TextGenerationResult: + final_resp = None + first_token_time = None + last_token_time = None + + if request.type_ == "text": + async for resp in self._backend.text_completions( # type: ignore[attr-defined] + prompt=request.prompt, + id_=request.id, + prompt_token_count=request.prompt_token_count, + output_token_count=request.output_token_count, + ): + if isinstance(resp, StreamingTextResponse) and resp.type_ == "iter": + 
                    first_token_time = first_token_time or resp.time
+                    last_token_time = resp.time
+
+                final_resp = resp
+        elif request.type_ == "chat":
+            async for resp in self._backend.chat_completions(  # type: ignore[attr-defined]
+                content=request.prompt,
+                id_=request.id,
+                prompt_token_count=request.prompt_token_count,
+                output_token_count=request.output_token_count,
+            ):
+                if isinstance(resp, StreamingTextResponse) and resp.type_ == "iter":
+                    first_token_time = first_token_time or resp.time
+                    last_token_time = resp.time
+
+                final_resp = resp
+
+        if not final_resp or not isinstance(final_resp, ResponseSummary):
+            raise ValueError(
+                f"Invalid final response for request: {request} "
+                f"and backend: {self._backend}, received: {final_resp}"
+            )
+
+        return TextGenerationResult(
+            request=request,
+            prompt_token_count=final_resp.prompt_tokens,
+            output=final_resp.value,
+            output_token_count=final_resp.output_tokens,
+            start_time=final_resp.start_time,
+            end_time=final_resp.end_time,
+            first_token_time=first_token_time,
+            last_token_time=last_token_time,
+        )
diff --git a/tests/dummy/data/__init__.py b/tests/dummy/data/__init__.py
index 95a2c946..e69de29b 100644
--- a/tests/dummy/data/__init__.py
+++ b/tests/dummy/data/__init__.py
@@ -1,3 +0,0 @@
-from .openai import openai_completion_factory, openai_model_factory
-
-__all__ = ["openai_completion_factory", "openai_model_factory"]
diff --git a/tests/dummy/data/openai.py b/tests/dummy/data/openai.py
deleted file mode 100644
index 6e168658..00000000
--- a/tests/dummy/data/openai.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""
-This module includes data models factories for openai 3-rd party package
-"""
-
-import random
-import string
-import time
-import uuid
-from typing import Generator
-
-from openai.types import Completion, Model
-
-
-def words(n: int = 1) -> Generator[str, None, None]:
-    for _ in range(n):
-        yield "".join(
-            random.choice(string.ascii_letters) for _ in range(random.randint(3, 10))
-        )
-
-
-def openai_completion_factory(
-    n: int = 3,
-    **kwargs,
-) -> Generator[Completion, None, None]:
-    """
-    The factory that yields the openai Completion instance.
-    """
-
-    for i in range(1, n + 1):
-        payload = {
-            "id": str(uuid.uuid4()),
-            "choices": [],
-            "stop": not i < n,
-            "content": " ".join(words(random.randint(3, 10))) if i < n else "",
-            "object": "text_completion",
-            "model": "mock-model",
-            "created": int(time.time()),
-        }
-        payload.update(kwargs)
-
-        yield Completion(**payload)  # type: ignore
-
-
-def openai_model_factory(n: int = 3) -> Generator[Model, None, None]:
-    """
-    The factory that yields the random openai Model instance.
- """ - for _ in range(n): - yield Model( - id=str(uuid.uuid4()), - created=int(time.time()), - object="model", - owned_by="neuralmagic", - ) diff --git a/tests/unit/backend/test_backend.py b/tests/unit/backend/test_backend.py new file mode 100644 index 00000000..29a008e1 --- /dev/null +++ b/tests/unit/backend/test_backend.py @@ -0,0 +1,133 @@ +import time + +import pytest + +from guidellm.backend import ( + Backend, + ResponseSummary, + StreamingTextResponse, +) + + +@pytest.mark.smoke() +def test_backend_registry(): + assert Backend._registry["mock"] is not None # type: ignore + + backend_instance = Backend.create("mock") # type: ignore + assert backend_instance is not None + + with pytest.raises(ValueError): + Backend.register("mock")("backend") # type: ignore + + with pytest.raises(ValueError): + Backend.create("invalid_type") # type: ignore + + +@pytest.mark.smoke() +@pytest.mark.asyncio() +async def test_backend_text_completions(mock_backend): + index = 0 + prompt = "Test Prompt" + request_id = "test-request-id" + prompt_token_count = 3 + output_token_count = 10 + final_resp = None + + async for response in mock_backend.text_completions( + prompt=prompt, + request_id=request_id, + prompt_token_count=prompt_token_count, + output_token_count=output_token_count, + ): + assert isinstance(response, (StreamingTextResponse, ResponseSummary)) + + if index == 0: + assert isinstance(response, StreamingTextResponse) + assert response.type_ == "start" + assert response.iter_count == 0 + assert response.delta == "" + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == request_id + elif not isinstance(response, ResponseSummary): + assert response.type_ == "iter" + assert response.iter_count == index + assert len(response.delta) > 0 + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == request_id + else: + assert not final_resp + final_resp = response + assert isinstance(response, ResponseSummary) + assert len(response.value) > 0 + assert response.iterations > 0 + assert response.start_time > 0 + assert response.end_time == pytest.approx(time.time(), abs=0.01) + assert response.request_prompt_tokens == prompt_token_count + assert response.request_output_tokens == output_token_count + assert response.response_prompt_tokens == 3 + assert response.response_output_tokens == 10 + assert response.request_id == request_id + + index += 1 + + assert final_resp + + +@pytest.mark.smoke() +@pytest.mark.asyncio() +async def test_backend_chat_completions(mock_backend): + index = 0 + prompt = "Test Prompt" + request_id = "test-request-id" + prompt_token_count = 3 + output_token_count = 10 + final_resp = None + + async for response in mock_backend.chat_completions( + content=prompt, + request_id=request_id, + prompt_token_count=prompt_token_count, + output_token_count=output_token_count, + ): + assert isinstance(response, (StreamingTextResponse, ResponseSummary)) + + if index == 0: + assert isinstance(response, StreamingTextResponse) + assert response.type_ == "start" + assert response.iter_count == 0 + assert response.delta == "" + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == request_id + elif not isinstance(response, ResponseSummary): + assert response.type_ == "iter" + assert response.iter_count == index + assert len(response.delta) > 0 + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == request_id + else: + assert not final_resp + 
final_resp = response + assert isinstance(response, ResponseSummary) + assert len(response.value) > 0 + assert response.iterations > 0 + assert response.start_time > 0 + assert response.end_time == pytest.approx(time.time(), abs=0.01) + assert response.request_prompt_tokens == prompt_token_count + assert response.request_output_tokens == output_token_count + assert response.response_prompt_tokens == 3 + assert response.response_output_tokens == 10 + assert response.request_id == request_id + + index += 1 + + assert final_resp + + +@pytest.mark.smoke() +def test_backend_models(mock_backend): + assert mock_backend.available_models() == ["mock-model"] + + +@pytest.mark.smoke() +def test_backend_validate(mock_backend): + mock_backend.validate() diff --git a/tests/unit/backend/test_base.py b/tests/unit/backend/test_base.py deleted file mode 100644 index edd61a90..00000000 --- a/tests/unit/backend/test_base.py +++ /dev/null @@ -1,258 +0,0 @@ -import pytest - -from guidellm.backend import Backend, GenerativeResponse -from guidellm.core import TextGenerationRequest, TextGenerationResult - - -@pytest.mark.smoke() -def test_backend_registry(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - yield GenerativeResponse(type_="final", output="Test") - - def available_models(self): - return ["mock-model"] - - backend_type = "test" - Backend.register(backend_type)(MockBackend) # type: ignore - assert Backend._registry[backend_type] is MockBackend # type: ignore - - backend_instance = Backend.create(backend_type) # type: ignore - assert isinstance(backend_instance, MockBackend) - - with pytest.raises(ValueError): - Backend.create("invalid_type") # type: ignore - - -@pytest.mark.smoke() -def test_generative_response_creation(): - response = GenerativeResponse(type_="final", output="Test Output") - assert response.type_ == "final" - assert response.output == "Test Output" - assert response.add_token is None - assert response.prompt is None - - response = GenerativeResponse(type_="token_iter", add_token="token") - assert response.type_ == "token_iter" - assert response.add_token == "token" - assert response.output is None - - -@pytest.mark.smoke() -@pytest.mark.asyncio() -async def test_backend_make_request(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - yield GenerativeResponse( - type_="token_iter", - add_token="Token", - prompt="Hello, world!", - prompt_token_count=5, - ) - yield GenerativeResponse( - type_="final", - output="This is a final response.", - prompt="Hello, world!", - prompt_token_count=5, - output_token_count=10, - ) - - def available_models(self): - return ["mock-model"] - - backend = MockBackend() - index = 0 - - async for response in backend.make_request(TextGenerationRequest(prompt="Test")): - if index == 0: - assert response.type_ == "token_iter" - assert response.add_token == "Token" - assert response.prompt == "Hello, world!" - assert response.prompt_token_count == 5 - else: - assert response.type_ == "final" - assert response.output == "This is a final response." - assert response.prompt == "Hello, world!" 
- assert response.prompt_token_count == 5 - assert response.output_token_count == 10 - index += 1 - - -@pytest.mark.smoke() -@pytest.mark.asyncio() -async def test_backend_submit_final(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - yield GenerativeResponse(type_="final", output="Test") - - def available_models(self): - return ["mock-model"] - - backend = MockBackend() - result = await backend.submit(TextGenerationRequest(prompt="Test")) - assert isinstance(result, TextGenerationResult) - assert result.output == "Test" - - -@pytest.mark.smoke() -@pytest.mark.asyncio() -async def test_backend_submit_multi(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - yield GenerativeResponse(type_="token_iter", add_token="Token") - yield GenerativeResponse(type_="token_iter", add_token=" ") - yield GenerativeResponse(type_="token_iter", add_token="Test") - yield GenerativeResponse(type_="final") - - def available_models(self): - return ["mock-model"] - - backend = MockBackend() - result = await backend.submit(TextGenerationRequest(prompt="Test")) - assert isinstance(result, TextGenerationResult) - assert result.output == "Token Test" - - -@pytest.mark.regression() -@pytest.mark.asyncio() -async def test_backend_submit_no_response(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - if False: # simulate no yield - yield - - def available_models(self): - return ["mock-model"] - - backend = MockBackend() - - with pytest.raises(ValueError): - await backend.submit(TextGenerationRequest(prompt="Test")) - - -@pytest.mark.smoke() -@pytest.mark.asyncio() -async def test_backend_submit_multi_final(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - yield GenerativeResponse(type_="token_iter", add_token="Token") - yield GenerativeResponse(type_="token_iter", add_token=" ") - yield GenerativeResponse(type_="token_iter", add_token="Test") - yield GenerativeResponse(type_="final") - yield GenerativeResponse(type_="final") - - def available_models(self): - return ["mock-model"] - - backend = MockBackend() - - with pytest.raises(ValueError): - await backend.submit(TextGenerationRequest(prompt="Test")) - - -@pytest.mark.smoke() -def test_backend_models(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - def available_models(self): - return ["mock-model", "mock-model-2"] - - async def make_request(self, request): - yield GenerativeResponse(type_="final", output="") - - backend = MockBackend() - assert backend.available_models() == ["mock-model", "mock-model-2"] - assert backend.default_model == "mock-model" - - -@pytest.mark.smoke() -def test_backend_test_connection(): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def available_models(self): - return 
["mock-model", "mock-model-2"] - - async def make_request(self, request): - yield GenerativeResponse(type_="final", output="") - - assert MockBackend().test_connection() - - -@pytest.mark.smoke() -def test_backend_tokenizer(mock_auto_tokenizer): - class MockBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def available_models(self): - return ["mock-model", "mock-model-2"] - - async def make_request(self, request): - yield GenerativeResponse(type_="final", output="") - - backend = MockBackend() - tokenizer = backend.model_tokenizer() - assert tokenizer is not None - assert tokenizer.tokenize("text") is not None - - -@pytest.mark.regression() -def test_backend_abstract_methods(): - with pytest.raises(TypeError): - Backend() # type: ignore - - class IncompleteBackend(Backend): - def __init__(self): - super().__init__("test", "http://localhost:8000", "mock-model") - - def test_connection(self) -> bool: - return True - - async def make_request(self, request): - yield GenerativeResponse(type_="final", output="Test") - - with pytest.raises(TypeError): - IncompleteBackend() # type: ignore diff --git a/tests/unit/backend/test_openai_backend.py b/tests/unit/backend/test_openai_backend.py index 00b74236..db03c259 100644 --- a/tests/unit/backend/test_openai_backend.py +++ b/tests/unit/backend/test_openai_backend.py @@ -1,282 +1,199 @@ -from unittest.mock import AsyncMock, Mock, patch +import time import pytest -from guidellm.backend import Backend, OpenAIBackend -from guidellm.config import reload_settings, settings -from guidellm.core import TextGenerationRequest +from guidellm.backend import OpenAIHTTPBackend, ResponseSummary, StreamingTextResponse +from guidellm.config import settings -@pytest.fixture() -def mock_openai_client(): - with patch("guidellm.backend.openai.AsyncOpenAI") as mock_async_const, patch( - "guidellm.backend.openai.OpenAI" - ) as mock_sync_const: - mock_model = Mock() - mock_model.id = "mock-model" - mock_model_2 = Mock() - mock_model_2.id = "mock-model-2" - mock_model_data = Mock() - mock_model_data.data = [mock_model, mock_model_2] - - def create_async_create(inst): - async def stream(): - for ind in range(3): - choice = Mock() - choice.delta.content = f"token{ind}" if ind % 2 == 0 else " " - choice.finish_reason = None - chunk = Mock() - chunk.choices = [choice] - - yield chunk - - choice = Mock() - choice.finish_reason = "stop" - chunk = Mock() - chunk.choices = [choice] - yield chunk - - async def create(*args, **kwargs): - inst.create_args = args - inst.create_kwargs = kwargs - return stream() - - return create - - def async_constructor(*args, **kwargs): - mock_async_instance = AsyncMock() - mock_async_instance.models.list.return_value = mock_model_data - mock_async_instance.args = args - mock_async_instance.kwargs = kwargs - mock_async_instance.chat.completions.create.side_effect = ( - create_async_create(mock_async_instance) - ) - - return mock_async_instance +@pytest.mark.smoke() +def test_openai_http_backend_default_initialization(): + backend = OpenAIHTTPBackend() + assert backend.target == settings.openai.base_url + assert backend.model is None + assert backend.authorization == settings.openai.bearer_token + assert backend.organization == settings.openai.organization + assert backend.project == settings.openai.project + assert backend.timeout == settings.request_timeout + assert backend.http2 is True + assert backend.max_output_tokens == settings.openai.max_output_tokens - def sync_constructor(*args, 
**kwargs): - mock_sync_instance = Mock() - mock_sync_instance.models.list.return_value = mock_model_data - mock_sync_instance.args = args - mock_sync_instance.kwargs = kwargs - return mock_sync_instance - mock_async_const.side_effect = async_constructor - mock_sync_const.side_effect = sync_constructor - yield mock_async_const, mock_sync_const +@pytest.mark.smoke() +def test_openai_http_backend_intialization(): + backend = OpenAIHTTPBackend( + target="http://test-target", + model="test-model", + api_key="test-key", + organization="test-org", + project="test-proj", + timeout=10, + http2=False, + max_output_tokens=100, + ) + assert backend.target == "http://test-target" + assert backend.model == "test-model" + assert backend.authorization == "Bearer test-key" + assert backend.organization == "test-org" + assert backend.project == "test-proj" + assert backend.timeout == 10 + assert backend.http2 is False + assert backend.max_output_tokens == 100 @pytest.mark.smoke() -@pytest.mark.parametrize( - ( - "openai_api_key", - "target", - "model", - "request_args", - "expected_base_url", - ), - [ - ( - "test_key", - "http://test-target", - "test-model", - {"arg1": "value1"}, - "http://test-target", - ), - (None, None, None, {}, settings.openai.base_url), - ], -) -def test_openai_backend_create( - openai_api_key, - target, - model, - request_args, - expected_base_url, - mock_openai_client, -): - backends = [ - Backend.create( - "openai_server", - openai_api_key=openai_api_key, - target=target, - model=model, - **request_args, - ), - OpenAIBackend( - openai_api_key=openai_api_key, - target=target, - model=model, - **request_args, - ), - ] - - for backend in backends: - assert backend._async_client.kwargs["api_key"] == ( # type: ignore - openai_api_key or settings.openai.api_key - ) - assert backend._async_client.kwargs["base_url"] == expected_base_url # type: ignore - assert backend._client.kwargs["api_key"] == ( # type: ignore - openai_api_key or settings.openai.api_key - ) - assert backend._client.kwargs["base_url"] == expected_base_url # type: ignore - if model: - assert backend._model == model # type: ignore +def test_openai_http_backend_available_models(httpx_openai_mock): + backend = OpenAIHTTPBackend(target="http://target.mock") + models = backend.available_models() + assert models == ["mock-model"] @pytest.mark.smoke() -def test_openai_backend_models(mock_openai_client): - backend = OpenAIBackend() - assert backend.available_models() == ["mock-model", "mock-model-2"] - assert backend.default_model == "mock-model" +def test_openai_http_backend_validate(httpx_openai_mock): + backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") + backend.validate() + + backend = OpenAIHTTPBackend(target="http://target.mock") + backend.validate() assert backend.model == "mock-model" + backend = OpenAIHTTPBackend(target="http://target.mock", model="invalid-model") + with pytest.raises(ValueError): + backend.validate() + @pytest.mark.smoke() -@pytest.mark.parametrize( - ("req", "request_args"), - [ - (TextGenerationRequest(prompt="Test"), None), - ( - TextGenerationRequest(prompt="Test", params={"generated_tokens": 10}), - None, - ), - ( - TextGenerationRequest(prompt="Test", params={"generated_tokens": 10}), - {"max_tokens": 10}, - ), - ( - TextGenerationRequest(prompt="Test"), - {"max_tokens": 10, "stop": "stop"}, - ), - ], -) @pytest.mark.asyncio() -async def test_openai_backend_make_request(req, request_args, mock_openai_client): - backend = OpenAIBackend(**(request_args or {})) - counter = 
0 - - async for response in backend.make_request(req): - if counter < 3: - assert response.type_ == "token_iter" - assert response.add_token == f"token{counter}" if counter % 2 == 0 else " " - elif counter == 3: - assert response.type_ == "final" +async def test_openai_http_backend_text_completions(httpx_openai_mock): + backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") + + index = 0 + final_resp = None + async for response in backend.text_completions("Test Prompt", request_id="test-id"): + assert isinstance(response, (StreamingTextResponse, ResponseSummary)) + + if index == 0: + assert isinstance(response, StreamingTextResponse) + assert response.type_ == "start" + assert response.iter_count == 0 + assert response.delta == "" + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == "test-id" + elif not isinstance(response, ResponseSummary): + assert response.type_ == "iter" + assert response.iter_count == index + assert len(response.delta) > 0 + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == "test-id" else: - raise ValueError("Too many responses received from the backend") - - counter += 1 - - # check the kwargs passed to the openai client - # now that the generator has been consumed - assert backend._async_client.create_args == () # type: ignore - assert backend._async_client.create_kwargs["model"] == "mock-model" # type: ignore - assert backend._async_client.create_kwargs["messages"] == [ # type: ignore - {"role": "user", "content": req.prompt} - ] - assert backend._async_client.create_kwargs["stream"] # type: ignore - assert backend._async_client.create_kwargs["n"] == 1 # type: ignore - - if req.output_token_count is not None: - assert ( - backend._async_client.create_kwargs["max_tokens"] == req.output_token_count # type: ignore - ) - assert backend._async_client.create_kwargs["stop"] is None # type: ignore - elif request_args is not None and "max_tokens" not in request_args: - assert ( - backend._async_client.create_kwargs["max_tokens"] # type: ignore - == settings.openai.max_gen_tokens - ) + assert not final_resp + final_resp = response + assert isinstance(response, ResponseSummary) + assert len(response.value) > 0 + assert response.request_args is not None + assert response.iterations > 0 + assert response.start_time > 0 + assert response.end_time == pytest.approx(time.time(), abs=0.01) + assert response.request_prompt_tokens is None + assert response.request_output_tokens is None + assert response.response_prompt_tokens == 3 + assert response.response_output_tokens > 0 # type: ignore + assert response.request_id == "test-id" + + index += 1 + assert final_resp - if request_args: - for key, value in request_args.items(): - assert backend._async_client.create_kwargs[key] == value # type: ignore - -@pytest.mark.sanity() +@pytest.mark.smoke() @pytest.mark.asyncio() -async def test_openai_backend_submit(mock_openai_client): - backend = OpenAIBackend() - request = TextGenerationRequest(prompt="Test", prompt_token_count=1) - result = await backend.submit(request) - - assert result.request == request - assert result.prompt == request.prompt - assert result.prompt_token_count == 1 - assert result.output == "token0 token2" - assert result.output_token_count == 3 - assert result.last_time is not None - assert result.first_token_set - assert result.start_time is not None - assert result.first_token_time is not None - assert result.end_time is not None - assert len(result.decode_times) 
== 2 - - -@pytest.mark.sanity() -def test_openai_backend_api_key(mock_openai_client): - backend = OpenAIBackend() - assert backend._async_client.kwargs["api_key"] == settings.openai.api_key # type: ignore - assert backend._client.kwargs["api_key"] == settings.openai.api_key # type: ignore - - backend = OpenAIBackend(openai_api_key="test_key") - assert backend._async_client.kwargs["api_key"] == "test_key" # type: ignore - assert backend._client.kwargs["api_key"] == "test_key" # type: ignore - - -@pytest.mark.sanity() -def test_openai_backend_api_key_env(mock_openai_client, mocker): - mocker.patch.dict( - "os.environ", - { - "GUIDELLM__OPENAI__API_KEY": "test_key", - }, +async def test_openai_http_backend_text_completions_counts(httpx_openai_mock): + backend = OpenAIHTTPBackend( + target="http://target.mock", + model="mock-model", + max_output_tokens=100, ) - reload_settings() - - backend = OpenAIBackend() - assert backend._async_client.kwargs["api_key"] == "test_key" # type: ignore - assert backend._client.kwargs["api_key"] == "test_key" # type: ignore + final_resp = None + async for response in backend.text_completions( + "Test Prompt", request_id="test-id", prompt_token_count=3, output_token_count=10 + ): + final_resp = response -@pytest.mark.sanity() -def test_openai_backend_target(mock_openai_client): - backend = OpenAIBackend(target="http://test-target") - assert backend._async_client.kwargs["base_url"] == "http://test-target" # type: ignore - assert backend._client.kwargs["base_url"] == "http://test-target" # type: ignore + assert final_resp + assert isinstance(final_resp, ResponseSummary) + assert len(final_resp.value) > 0 + assert final_resp.request_args is not None + assert final_resp.request_prompt_tokens == 3 + assert final_resp.request_output_tokens == 10 + assert final_resp.response_prompt_tokens == 3 + assert final_resp.response_output_tokens == 10 + assert final_resp.request_id == "test-id" - backend = OpenAIBackend() - assert backend._async_client.kwargs["base_url"] == "http://localhost:8000/v1" # type: ignore - assert backend._client.kwargs["base_url"] == "http://localhost:8000/v1" # type: ignore - - backend = OpenAIBackend() - assert backend._async_client.kwargs["base_url"] == settings.openai.base_url # type: ignore - assert backend._client.kwargs["base_url"] == settings.openai.base_url # type: ignore +@pytest.mark.smoke() +@pytest.mark.asyncio() +async def test_openai_http_backend_chat_completions(httpx_openai_mock): + backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") + + index = 0 + final_resp = None + async for response in backend.chat_completions("Test Prompt", request_id="test-id"): + assert isinstance(response, (StreamingTextResponse, ResponseSummary)) + + if index == 0: + assert isinstance(response, StreamingTextResponse) + assert response.type_ == "start" + assert response.iter_count == 0 + assert response.delta == "" + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == "test-id" + elif not isinstance(response, ResponseSummary): + assert response.type_ == "iter" + assert response.iter_count == index + assert len(response.delta) > 0 + assert response.time == pytest.approx(time.time(), abs=0.01) + assert response.request_id == "test-id" + else: + assert not final_resp + final_resp = response + assert isinstance(response, ResponseSummary) + assert len(response.value) > 0 + assert response.request_args is not None + assert response.iterations > 0 + assert response.start_time > 0 + assert 
response.end_time == pytest.approx(time.time(), abs=0.01) + assert response.request_prompt_tokens is None + assert response.request_output_tokens is None + assert response.response_prompt_tokens == 3 + assert response.response_output_tokens > 0 # type: ignore + assert response.request_id == "test-id" -@pytest.mark.sanity() -def test_openai_backend_target_env(mock_openai_client, mocker): - mocker.patch.dict( - "os.environ", - { - "GUIDELLM__OPENAI__BASE_URL": "http://test-target", - }, - ) - reload_settings() + index += 1 - backend = OpenAIBackend() - assert backend._async_client.kwargs["base_url"] == "http://test-target" # type: ignore - assert backend._client.kwargs["base_url"] == "http://test-target" # type: ignore + assert final_resp -@pytest.mark.regression() -def test_openai_backend_target_none_error(mock_openai_client, mocker): - mocker.patch.dict( - "os.environ", - { - "GUIDELLM__OPENAI__BASE_URL": "", - }, +@pytest.mark.smoke() +@pytest.mark.asyncio() +async def test_openai_http_backend_chat_completions_counts(httpx_openai_mock): + backend = OpenAIHTTPBackend( + target="http://target.mock", + model="mock-model", + max_output_tokens=100, ) - reload_settings() - - with pytest.raises(ValueError): - OpenAIBackend(target=None) + final_resp = None + + async for response in backend.chat_completions( + "Test Prompt", request_id="test-id", prompt_token_count=3, output_token_count=10 + ): + final_resp = response + + assert final_resp + assert isinstance(final_resp, ResponseSummary) + assert len(final_resp.value) > 0 + assert final_resp.request_args is not None + assert final_resp.request_prompt_tokens == 3 + assert final_resp.request_output_tokens == 10 + assert final_resp.response_prompt_tokens == 3 + assert final_resp.response_output_tokens == 10 + assert final_resp.request_id == "test-id" diff --git a/tests/unit/backend/test_response.py b/tests/unit/backend/test_response.py new file mode 100644 index 00000000..8de78925 --- /dev/null +++ b/tests/unit/backend/test_response.py @@ -0,0 +1,157 @@ +from typing import get_args + +import pytest + +from guidellm.backend import ( + RequestArgs, + ResponseSummary, + StreamingResponseType, + StreamingTextResponse, +) + + +@pytest.mark.smoke() +def test_streaming_response_types(): + valid_types = get_args(StreamingResponseType) + assert valid_types == ("start", "iter") + + +@pytest.mark.smoke() +def test_streaming_text_response_default_initilization(): + response = StreamingTextResponse( + type_="start", + iter_count=0, + delta="", + time=0.0, + ) + assert response.request_id is None + + +@pytest.mark.smoke() +def test_streaming_text_response_initialization(): + response = StreamingTextResponse( + type_="start", + iter_count=0, + delta="Hello, world!", + time=1.0, + request_id="123", + ) + assert response.type_ == "start" + assert response.iter_count == 0 + assert response.delta == "Hello, world!" 
+ assert response.time == 1.0 + assert response.request_id == "123" + + +@pytest.mark.smoke() +def test_streaming_text_response_marshalling(): + response = StreamingTextResponse( + type_="start", + iter_count=0, + delta="Hello, world!", + time=1.0, + request_id="123", + ) + serialized = response.model_dump() + deserialized = StreamingTextResponse.model_validate(serialized) + + for key, value in vars(response).items(): + assert getattr(deserialized, key) == value + + +@pytest.mark.smoke() +def test_request_args_default_initialization(): + args = RequestArgs( + target="http://example.com", + headers={}, + payload={}, + ) + assert args.timeout is None + assert args.http2 is None + + +@pytest.mark.smoke() +def test_request_args_initialization(): + args = RequestArgs( + target="http://example.com", + headers={ + "Authorization": "Bearer token", + }, + payload={ + "query": "Hello, world!", + }, + timeout=10.0, + http2=True, + ) + assert args.target == "http://example.com" + assert args.headers == {"Authorization": "Bearer token"} + assert args.payload == {"query": "Hello, world!"} + assert args.timeout == 10.0 + assert args.http2 is True + + +@pytest.mark.smoke() +def test_response_args_marshalling(): + args = RequestArgs( + target="http://example.com", + headers={"Authorization": "Bearer token"}, + payload={"query": "Hello, world!"}, + timeout=10.0, + http2=True, + ) + serialized = args.model_dump() + deserialized = RequestArgs.model_validate(serialized) + + for key, value in vars(args).items(): + assert getattr(deserialized, key) == value + + +@pytest.mark.smoke() +def test_response_summary_default_initialization(): + summary = ResponseSummary( + value="Hello, world!", + request_args=RequestArgs( + target="http://example.com", + headers={}, + payload={}, + ), + start_time=0.0, + end_time=0.0, + ) + assert summary.request_prompt_tokens is None + assert summary.request_output_tokens is None + assert summary.response_prompt_tokens is None + assert summary.response_output_tokens is None + assert summary.request_id is None + + +@pytest.mark.smoke() +def test_response_summary_initialization(): + summary = ResponseSummary( + value="Hello, world!", + request_args=RequestArgs( + target="http://example.com", + headers={}, + payload={}, + ), + start_time=1.0, + end_time=2.0, + iterations=3, + request_prompt_tokens=5, + request_output_tokens=10, + response_prompt_tokens=5, + response_output_tokens=10, + request_id="123", + ) + assert summary.value == "Hello, world!" 
+ assert summary.request_args.target == "http://example.com" + assert summary.request_args.headers == {} + assert summary.request_args.payload == {} + assert summary.start_time == 1.0 + assert summary.end_time == 2.0 + assert summary.iterations == 3 + assert summary.request_prompt_tokens == 5 + assert summary.request_output_tokens == 10 + assert summary.response_prompt_tokens == 5 + assert summary.response_output_tokens == 10 + assert summary.request_id == "123" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3257a8d2..2a31df5d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,9 +1,16 @@ +import json from pathlib import Path -from typing import List +from typing import Any, AsyncIterable, Dict, List, Literal, Optional from unittest.mock import MagicMock, patch +import httpx import pytest import requests_mock +import respx + +from guidellm.backend import ResponseSummary, StreamingTextResponse + +from .mock_backend import MockBackend @pytest.fixture() @@ -33,3 +40,172 @@ def mock_requests_pride_and_prejudice(): text=text_content, ) yield mock + + +@pytest.fixture() +def mock_backend(request): + params = request.param if hasattr(request, "param") else {} + kwargs = {} + + for key in ("model", "target", "iter_delay"): + if key in params: + kwargs[key] = params[key] + + return MockBackend(**kwargs) + + +class MockCompletionsIter(AsyncIterable): + def __init__( + self, + type_: Literal["text", "chat"], + prompt: str, + output_token_count: Optional[int], + target: Optional[str] = None, + model: Optional[str] = None, + iter_delay: Optional[float] = None, + ): + self._type = type_ + self._backend = MockBackend( + model=model, + target=target, + iter_delay=iter_delay, + ) + self._prompt = prompt + self._output_token_count = output_token_count + + async def __aiter__(self): + async for token_iter in ( + self._backend.text_completions( + prompt=self._prompt, output_token_count=self._output_token_count + ) + if self._type == "text" + else self._backend.chat_completions( + content=self._prompt, output_token_count=self._output_token_count + ) + ): + if ( + isinstance(token_iter, StreamingTextResponse) + and token_iter.type_ == "start" + ): + continue + + data: Dict[str, Any] + + if isinstance(token_iter, StreamingTextResponse): + if self._type == "text": + data = { + "choices": [ + { + "index": token_iter.iter_count, + "text": token_iter.delta, + } + ] + } + elif self._type == "chat": + data = { + "choices": [ + { + "index": token_iter.iter_count, + "delta": {"content": token_iter.delta}, + } + ] + } + else: + raise ValueError("Invalid type for mock completions") + elif isinstance(token_iter, ResponseSummary): + data = { + "usage": { + "prompt_tokens": ( + len(self._prompt.split()) + self._prompt.count(" ") + ), + "completion_tokens": token_iter.response_output_tokens, + } + } + else: + raise ValueError("Invalid token_iter type") + + yield f"data: {json.dumps(data)}\n".encode() + + yield b"data: [DONE]\n" + + +@pytest.fixture() +def httpx_openai_mock(request): + params = request.param if hasattr(request, "param") else {} + model = params.get("model", "mock-model") + target = params.get("target", "http://target.mock") + iter_delay = params.get("iter_delay", None) + + with respx.mock(assert_all_mocked=True, assert_all_called=False) as mock_router: + + async def _mock_completions_response(request) -> AsyncIterable[str]: + headers = request.headers + payload = json.loads(request.content) + + assert headers["Content-Type"] == "application/json" + assert payload["model"] 
== model + assert payload["stream"] is True + assert payload["stream_options"] == {"include_usage": True} + assert payload["prompt"] is not None + assert len(payload["prompt"]) > 0 + assert payload["max_completion_tokens"] > 0 + assert payload["max_tokens"] > 0 + + return httpx.Response( # type: ignore + 200, + stream=MockCompletionsIter( # type: ignore + type_="text", + prompt=payload["prompt"], + output_token_count=( + payload["max_completion_tokens"] + if payload.get("ignore_eos", False) + else None + ), + target=target, + model=model, + iter_delay=iter_delay, + ), + ) + + async def _mock_chat_completions_response(request): + headers = request.headers + payload = json.loads(request.content) + + assert headers["Content-Type"] == "application/json" + assert payload["model"] == model + assert payload["stream"] is True + assert payload["stream_options"] == {"include_usage": True} + assert payload["messages"] is not None + assert len(payload["messages"]) > 0 + assert payload["max_completion_tokens"] > 0 + assert payload["max_tokens"] > 0 + + return httpx.Response( # type: ignore + 200, + stream=MockCompletionsIter( # type: ignore + type_="chat", + prompt=payload["messages"][0]["content"], + output_token_count=( + payload["max_completion_tokens"] + if payload.get("ignore_eos", False) + else None + ), + target=target, + model=model, + iter_delay=iter_delay, + ), + ) + + mock_router.route(method="GET", path="/v1/models").mock( + return_value=httpx.Response( + 200, json={"data": [{"id": model} if model else {"id": "mock-model"}]} + ) + ) + mock_router.route(method="POST", path="/v1/completions").mock( + side_effect=_mock_completions_response # type: ignore + ) + mock_router.route(method="POST", path="/v1/chat/completions").mock( + side_effect=_mock_chat_completions_response + ) + + yield mock_router diff --git a/tests/unit/core/test_report.py b/tests/unit/core/test_report.py index 42b46814..c9e4ef3a 100644 --- a/tests/unit/core/test_report.py +++ b/tests/unit/core/test_report.py @@ -4,7 +4,6 @@ import pytest from guidellm.core import ( - Distribution, GuidanceReport, TextGenerationBenchmark, TextGenerationBenchmarkReport, @@ -16,21 +15,15 @@ @pytest.fixture() def sample_benchmark_report() -> TextGenerationBenchmarkReport: sample_request = TextGenerationRequest(prompt="sample prompt") - sample_distribution = Distribution() sample_result = TextGenerationResult( request=sample_request, - prompt="sample prompt", - prompt_word_count=2, prompt_token_count=2, output="sample output", - output_word_count=2, output_token_count=2, - last_time=None, - first_token_set=False, start_time=None, end_time=None, first_token_time=None, - decode_times=sample_distribution, + last_token_time=None, ) sample_benchmark = TextGenerationBenchmark( mode="asynchronous", diff --git a/tests/unit/core/test_result.py b/tests/unit/core/test_result.py index 02232ba9..ddd62d7f 100644 --- a/tests/unit/core/test_result.py +++ b/tests/unit/core/test_result.py @@ -3,6 +3,7 @@ import pytest from guidellm.core import ( + RequestConcurrencyMeasurement, TextGenerationBenchmark, TextGenerationBenchmarkReport, TextGenerationError, @@ -11,400 +12,268 @@ ) -@pytest.mark.smoke() -def test_text_generation_result_initialization(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - assert result.request == request - assert result.prompt == "" - assert result.output == "" +def create_sample_request(): + return TextGenerationRequest(prompt="Hello, world!") -@pytest.mark.smoke() -def 
test_text_generation_result_start(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - prompt = "Once upon a time" - result.start(prompt) - assert result.prompt == prompt - assert result.start_time is not None +def create_sample_result(): + start_time = time.time() + + return TextGenerationResult( + request=create_sample_request(), + prompt_token_count=4, + output="Generated text", + output_token_count=3, + start_time=start_time, + end_time=start_time + 1.5, + first_token_time=start_time + 0.5, + last_token_time=start_time + 1.4, + ) @pytest.mark.smoke() -def test_text_generation_result_output_token(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - prompt = "Once upon a time" - result.start(prompt) - tokens = ["the", " ", "quick", " ", "brown", " ", "fox"] - for token in tokens: - result.output_token(token) - result.end() - - assert result.last_time - assert result.start_time - assert result.output == "the quick brown fox" - assert result.last_time is not None - assert result.last_time > result.start_time +def test_text_generation_result_default_initialization(): + result = TextGenerationResult(request=create_sample_request()) + assert result.request.prompt == "Hello, world!" + assert result.prompt_token_count is None + assert result.output == "" + assert result.output_token_count is None + assert result.start_time is None + assert result.end_time is None + assert result.first_token_time is None + assert result.last_token_time is None @pytest.mark.smoke() -def test_text_generation_result_end(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - result.start("Once upon a time") - result.end("The end") - - assert result.output == "The end" - assert result.last_time - assert result.start_time - assert result.end_time is not None - assert result.end_time > result.start_time - - -@pytest.mark.sanity() -def test_text_generation_result_improper_lifecycle(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - with pytest.raises(ValueError): - result.output_token("the") - with pytest.raises(ValueError): - result.end("The end") - - -@pytest.mark.regression() -def test_text_generation_result_json(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - prompt = "Once upon a time" - result.start(prompt) - generated = "The end" - result.end(generated) - json_str = result.to_json() - assert '"prompt":"Once upon a time"' in json_str - assert '"output":"The end"' in json_str - - result_restored = TextGenerationResult.from_json(json_str) - assert result.request == result_restored.request - assert result_restored.prompt == prompt - assert result_restored.output == generated - - json_str_restored = result_restored.to_json() - assert json_str == json_str_restored - - -@pytest.mark.regression() -def test_text_generation_result_yaml(): - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - prompt = "Once upon a time" - result.start(prompt) - generated = "The end" - result.end(generated) - yaml_str = result.to_yaml() - assert "prompt: Once upon a time" in yaml_str - assert "output: The end" in yaml_str - - result_restored = TextGenerationResult.from_yaml(yaml_str) - assert result.request == result_restored.request - assert result_restored.prompt == prompt 
- assert result_restored.output == generated - - yaml_str_restored = result_restored.to_yaml() - assert yaml_str == yaml_str_restored +def test_text_generation_result_initialization(): + result = create_sample_result() + assert result.request.prompt == "Hello, world!" + assert result.prompt_token_count == 4 + assert result.output == "Generated text" + assert result.output_token_count == 3 + assert result.start_time >= 0.0 + assert result.end_time == result.start_time + 1.5 + assert result.first_token_time == result.start_time + 0.5 + assert result.last_token_time == result.start_time + 1.4 + + # computed fields + assert result.request_latency == 1.5 + assert result.time_to_first_token == 0.5 * 1000 + assert result.inter_token_latency == pytest.approx((1.4 - 0.5) * 1000 / 2) + assert result.output_tokens_per_second == pytest.approx(2 / (1.4 - 0.5)) @pytest.mark.smoke() -def test_text_generation_error_initialization(): - request = TextGenerationRequest(prompt="Generate a story") - error = Exception("Test error") - result = TextGenerationError(request=request, message=str(error)) - assert result.request == request - assert str(result.message) == str(error) +def test_text_generation_result_marshalling(): + result = create_sample_result() + serialized = result.model_dump() + deserialized = TextGenerationResult.model_validate(serialized) + for key, value in vars(result).items(): + assert getattr(deserialized, key) == value -@pytest.mark.regression() -def test_text_generation_error_json(): - request = TextGenerationRequest(prompt="Generate a story") - error = Exception("Test error") - result = TextGenerationError(request=request, message=str(error)) - json_str = result.to_json() - result_restored = TextGenerationError.from_json(json_str) +@pytest.mark.smoke() +def test_text_generation_error_initialization(): + error = TextGenerationError( + request=create_sample_request(), message="Error message" + ) + assert error.request.prompt == "Hello, world!" 
+ assert error.message == "Error message" - assert result.message == "Test error" - assert result.request == result_restored.request - assert str(result_restored.message) == str(error) - json_str_restored = result_restored.to_json() - assert json_str == json_str_restored +@pytest.mark.smoke() +def test_text_generation_error_marshalling(): + error = TextGenerationError( + request=create_sample_request(), message="Error message" + ) + serialized = error.model_dump() + deserialized = TextGenerationError.model_validate(serialized) + for key, value in vars(error).items(): + assert getattr(deserialized, key) == value -@pytest.mark.regression() -def test_text_generation_error_yaml(): - request = TextGenerationRequest(prompt="Generate a story") - error = Exception("Test error") - result = TextGenerationError(request=request, message=str(error)) - yaml_str = result.to_yaml() - result_restored = TextGenerationError.from_yaml(yaml_str) +@pytest.mark.smoke() +def test_request_concurrency_measurement_initialization(): + start_time = time.time() + measurement = RequestConcurrencyMeasurement( + time=start_time, + completed=8, + errored=2, + processing=3, + ) + assert measurement.time == start_time + assert measurement.completed == 8 + assert measurement.errored == 2 + assert measurement.processing == 3 + - assert result.message == "Test error" - assert result.request == result_restored.request - assert str(result_restored.message) == str(error) +@pytest.mark.smoke() +def test_request_concurrency_measurement_marshalling(): + start_time = time.time() + measurement = RequestConcurrencyMeasurement( + time=start_time, + completed=8, + errored=2, + processing=3, + ) + serialized = measurement.model_dump() + deserialized = RequestConcurrencyMeasurement.model_validate(serialized) - yaml_str_restored = result_restored.to_yaml() - assert yaml_str == yaml_str_restored + for key, value in vars(measurement).items(): + assert getattr(deserialized, key) == value @pytest.mark.smoke() -def test_text_generation_benchmark_initialization(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - assert benchmark.mode == "synchronous" - assert benchmark.rate == 1.0 +def test_text_generation_benchmark_default_initialization(): + benchmark = TextGenerationBenchmark(mode="asynchronous") + assert benchmark.mode == "asynchronous" + assert benchmark.rate is None + assert benchmark.results == [] + assert benchmark.errors == [] + assert benchmark.concurrencies == [] + + # computed assert benchmark.request_count == 0 assert benchmark.error_count == 0 - - -@pytest.mark.smoke() -def test_text_generation_benchmark_started(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) + assert benchmark.total_count == 0 + assert benchmark.start_time is None + assert benchmark.end_time is None + assert benchmark.duration == 0.0 assert benchmark.completed_request_rate == 0.0 - assert not benchmark.overloaded - benchmark.request_started() - assert len(benchmark.concurrencies) == 1 + assert benchmark.request_latency_distribution is not None + assert benchmark.request_latency == 0.0 + assert benchmark.request_latency_percentiles == {} + assert benchmark.ttft_distribution is not None + assert benchmark.time_to_first_token == 0.0 + assert benchmark.time_to_first_token_percentiles == {} + assert benchmark.itl_distribution is not None + assert benchmark.inter_token_latency == 0.0 + assert benchmark.inter_token_latency_percentiles == {} + assert benchmark.output_token_throughput == 0.0 + assert 
benchmark.prompt_token_distribution is not None + assert benchmark.prompt_token == 0.0 + assert benchmark.prompt_token_percentiles == {} + assert benchmark.output_token_distribution is not None + assert benchmark.output_token == 0.0 + assert benchmark.output_token_percentiles == {} @pytest.mark.smoke() -def test_text_generation_benchmark_expected_rate(): - num_requests = 5 - time_per_request = 0.25 - expected_rate = 1.0 / time_per_request - - benchmark = TextGenerationBenchmark(mode="synchronous", rate=expected_rate) +def test_text_generation_benchmark_initialization(): + benchmark = TextGenerationBenchmark(mode="asynchronous", rate=10) + assert benchmark.mode == "asynchronous" + assert benchmark.rate == 10 - for index in range(num_requests): - request = TextGenerationRequest(prompt=f"Generate a story {index}") + for _ in range(5): benchmark.request_started() - result = TextGenerationResult(request=request) - result.start("Once upon a time") - time.sleep(time_per_request) - result.end("The end") - benchmark.request_completed(result) - - assert len(benchmark.results) == num_requests - assert len(benchmark.errors) == 0 - assert len(benchmark.concurrencies) == 10 - assert benchmark.request_count == num_requests - assert benchmark.error_count == 0 - assert benchmark.completed_request_rate == pytest.approx(expected_rate, rel=0.1) - assert not benchmark.overloaded - - -@pytest.mark.smoke() -def test_text_generation_benchmark_overloaded_rate(): - num_requests = 5 - time_per_request = 0.25 - expected_rate = 1.0 / time_per_request - - benchmark = TextGenerationBenchmark(mode="synchronous", rate=expected_rate * 1.5) + benchmark.request_completed(create_sample_result()) + time.sleep(1.5) - for index in range(num_requests): - request = TextGenerationRequest(prompt=f"Generate a story {index}") + for _ in range(2): benchmark.request_started() - result = TextGenerationResult(request=request) - result.start("Once upon a time") - time.sleep(time_per_request) - result.end("The end") - benchmark.request_completed(result) - - assert len(benchmark.results) == num_requests - assert len(benchmark.errors) == 0 - assert len(benchmark.concurrencies) == 10 - assert benchmark.request_count == num_requests - assert benchmark.error_count == 0 - assert benchmark.completed_request_rate == pytest.approx(expected_rate, rel=0.1) - assert benchmark.overloaded + benchmark.request_completed( + TextGenerationError( + request=create_sample_request(), message="Error message" + ) + ) + + def _test_percentiles(percentiles, value=None): + assert len(percentiles) == 7 + assert list(percentiles.keys()) == ["1", "5", "10", "50", "90", "95", "99"] + + if value is None: + assert all(per >= 0.0 for per in percentiles.values()) + else: + assert all(per == pytest.approx(value) for per in percentiles.values()) + + assert len(benchmark.results) == 5 + assert len(benchmark.errors) == 2 + assert len(benchmark.concurrencies) == 14 + assert benchmark.request_count == 5 + assert benchmark.error_count == 2 + assert benchmark.total_count == 7 + assert benchmark.start_time == pytest.approx(time.time() - 1.5 * 5, abs=0.01) + assert benchmark.end_time == pytest.approx(time.time(), abs=0.01) + assert benchmark.duration == benchmark.end_time - benchmark.start_time # type: ignore + assert benchmark.completed_request_rate == pytest.approx(5 / benchmark.duration) + assert benchmark.request_latency_distribution is not None + assert benchmark.request_latency == pytest.approx(1.5) + _test_percentiles(benchmark.request_latency_percentiles, 1.5) + assert 
benchmark.ttft_distribution is not None + assert benchmark.time_to_first_token == pytest.approx(500) + _test_percentiles(benchmark.time_to_first_token_percentiles, 500) + assert benchmark.itl_distribution is not None + assert benchmark.inter_token_latency == pytest.approx(450) + _test_percentiles(benchmark.inter_token_latency_percentiles, 450) + assert benchmark.output_token_throughput == pytest.approx(3.0 / 1.5, abs=0.01) + assert benchmark.prompt_token_distribution is not None + assert benchmark.prompt_token == pytest.approx(4.0) + _test_percentiles(benchmark.prompt_token_percentiles, 4.0) + assert benchmark.output_token_distribution is not None + assert benchmark.output_token == pytest.approx(3.0) + _test_percentiles(benchmark.output_token_percentiles, 3.0) @pytest.mark.smoke() -def test_text_generation_benchmark_completed_with_result(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - - with pytest.raises(ValueError): - benchmark.request_completed(None) # type: ignore - - benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - - with pytest.raises(ValueError): - benchmark.request_completed(result) +def test_text_generation_benchmark_marshalling(): + benchmark = TextGenerationBenchmark(mode="asynchronous", rate=10) + for _ in range(5): + benchmark.request_started() + benchmark.request_completed(create_sample_result()) - result.start("Once upon a time") - result.end("The end") - benchmark.request_completed(result) - assert benchmark.request_count == 1 - assert benchmark.error_count == 0 + for _ in range(2): + benchmark.request_started() + benchmark.request_completed( + TextGenerationError( + request=create_sample_request(), message="Error message" + ) + ) + serialized = benchmark.model_dump() + deserialized = TextGenerationBenchmark.model_validate(serialized) -@pytest.mark.smoke() -def test_text_generation_benchmark_completed_with_error(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - error = TextGenerationError(request=request, message=str(Exception("Test error"))) - benchmark.request_completed(error) - assert benchmark.request_count == 0 - assert benchmark.error_count == 1 - - -@pytest.mark.regression() -def test_text_generation_benchmark_iter(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - result.start("Once upon a time") - result.end("The end") - benchmark.request_completed(result) - for res in benchmark: - assert res == result - - -@pytest.mark.regression() -def test_text_generation_benchmark_json(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - result.start("Once upon a time") - result.end("The end") - benchmark.request_completed(result) - json_str = benchmark.to_json() - assert '"mode":"synchronous"' in json_str - assert '"rate":1.0' in json_str - - benchmark_restored = TextGenerationBenchmark.from_json(json_str) - assert benchmark.mode == benchmark_restored.mode - assert benchmark.rate == benchmark_restored.rate - assert benchmark.request_count == benchmark_restored.request_count - assert benchmark.error_count == 
benchmark_restored.error_count - - json_str_restored = benchmark_restored.to_json() - assert json_str == json_str_restored - - -@pytest.mark.regression() -def test_text_generation_benchmark_yaml(): - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - result.start("Once upon a time") - result.end("The end") - benchmark.request_completed(result) - yaml_str = benchmark.to_yaml() - assert "mode: synchronous" in yaml_str - assert "rate: 1.0" in yaml_str - - benchmark_restored = TextGenerationBenchmark.from_yaml(yaml_str) - assert benchmark.mode == benchmark_restored.mode - assert benchmark.rate == benchmark_restored.rate - assert benchmark.request_count == benchmark_restored.request_count - assert benchmark.error_count == benchmark_restored.error_count - - yaml_str_restored = benchmark_restored.to_yaml() - assert yaml_str == yaml_str_restored + for key, value in vars(benchmark).items(): + assert getattr(deserialized, key) == value @pytest.mark.smoke() def test_text_generation_benchmark_report_initialization(): - report = TextGenerationBenchmarkReport() - assert len(report.benchmarks) == 0 - assert len(report.args) == 0 + report = TextGenerationBenchmarkReport( + benchmarks=[ + TextGenerationBenchmark(mode="asynchronous", rate=10), + TextGenerationBenchmark(mode="asynchronous", rate=20), + ], + args={ + "backend_type": "http", + "target": "http://example.com", + "model": "test-model", + }, + ) + assert len(report.benchmarks) == 2 + assert report.args == { + "backend_type": "http", + "target": "http://example.com", + "model": "test-model", + } @pytest.mark.smoke() -def test_text_generation_benchmark_report_add_benchmark(): - report = TextGenerationBenchmarkReport() - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - report.add_benchmark(benchmark) - assert len(report.benchmarks) == 1 - - -@pytest.mark.sanity() -def test_text_generation_benchmark_report_iter(): - report = TextGenerationBenchmarkReport() - - fast_benchmark = TextGenerationBenchmark(mode="synchronous", rate=10.0) - for _ in range(5): - fast_benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - result.start("Once upon a time") - time.sleep(0.1) - result.end("The end") - fast_benchmark.request_completed(result) - report.add_benchmark(fast_benchmark) - - slow_benchmark = TextGenerationBenchmark(mode="synchronous", rate=5.0) - for _ in range(5): - slow_benchmark.request_started() - request = TextGenerationRequest(prompt="Generate a story") - result = TextGenerationResult(request=request) - result.start("Once upon a time") - time.sleep(0.2) - result.end("The end") - slow_benchmark.request_completed(result) - report.add_benchmark(slow_benchmark) - - for index, benchmark in enumerate(report): - if index == 0: - assert benchmark == fast_benchmark - elif index == 1: - assert benchmark == slow_benchmark - else: - raise AssertionError("Unexpected report in report") - - for index, benchmark in enumerate(report.benchmarks_sorted): - if index == 0: - assert benchmark == slow_benchmark - elif index == 1: - assert benchmark == fast_benchmark - else: - raise AssertionError("Unexpected report in report") - - -@pytest.mark.regression() -def test_text_generation_benchmark_report_json(): - report = TextGenerationBenchmarkReport() - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) 
- report.add_benchmark(benchmark) - json_str = report.to_json() - assert '"benchmarks":' in json_str - assert '"args":{}' in json_str - - report_restored = TextGenerationBenchmarkReport.from_json(json_str) - assert len(report.benchmarks) == len(report_restored.benchmarks) - assert len(report.args) == len(report_restored.args) - - json_str_restored = report_restored.to_json() - assert json_str == json_str_restored - - -@pytest.mark.regression() -def test_text_generation_benchmark_report_yaml(): - report = TextGenerationBenchmarkReport() - benchmark = TextGenerationBenchmark(mode="synchronous", rate=1.0) - report.add_benchmark(benchmark) - yaml_str = report.to_yaml() - assert "benchmarks:" in yaml_str - assert "args: {}" in yaml_str - - report_restored = TextGenerationBenchmarkReport.from_yaml(yaml_str) - assert len(report.benchmarks) == len(report_restored.benchmarks) - assert len(report.args) == len(report_restored.args) - - yaml_str_restored = report_restored.to_yaml() - assert yaml_str == yaml_str_restored +def test_text_generation_benchmark_report_marshalling(): + report = TextGenerationBenchmarkReport( + benchmarks=[ + TextGenerationBenchmark(mode="asynchronous", rate=10), + TextGenerationBenchmark(mode="asynchronous", rate=20), + ], + args={ + "backend_type": "http", + "target": "http://example.com", + "model": "test-model", + }, + ) + serialized = report.model_dump() + deserialized = TextGenerationBenchmarkReport.model_validate(serialized) + + for key, value in vars(report).items(): + assert getattr(deserialized, key) == value diff --git a/tests/unit/executor/test_base.py b/tests/unit/executor/test_executor.py similarity index 99% rename from tests/unit/executor/test_base.py rename to tests/unit/executor/test_executor.py index 844cf7f4..58c0a9d4 100644 --- a/tests/unit/executor/test_base.py +++ b/tests/unit/executor/test_executor.py @@ -21,7 +21,7 @@ @pytest.fixture() def mock_scheduler(): - with patch("guidellm.executor.base.Scheduler") as mock_scheduler: + with patch("guidellm.executor.executor.Scheduler") as mock_scheduler: def scheduler_constructor(*args, **kwargs): mock_instance = create_autospec(Scheduler, instance=True) diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py new file mode 100644 index 00000000..9eb4d6ee --- /dev/null +++ b/tests/unit/mock_backend.py @@ -0,0 +1,154 @@ +import asyncio +import random +import time +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +from lorem.text import TextLorem # type: ignore +from PIL import Image + +from guidellm.backend import ( + Backend, + RequestArgs, + ResponseSummary, + StreamingTextResponse, +) + + +@Backend.register("mock") # type: ignore +class MockBackend(Backend): + def __init__( + self, + model: Optional[str] = "mock-model", + target: Optional[str] = "mock-target", + iter_delay: Optional[float] = None, + ): + super().__init__(type_="mock") # type: ignore + self._model = model + self._target = target + self._iter_delay = iter_delay + + @property + def target(self) -> str: + return self._target # type: ignore + + @property + def model(self) -> Optional[str]: + return self._model + + def check_setup(self): + pass + + def available_models(self) -> List[str]: + return [self.model] # type: ignore + + async def text_completions( # type: ignore + self, + prompt: Union[str, List[str]], + request_id: Optional[str] = None, + prompt_token_count: Optional[int] = None, + output_token_count: Optional[int] = None, + **kwargs, + ) -> 
AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: + if not isinstance(prompt, str) or not prompt: + raise ValueError("Prompt must be a non-empty string") + + async for response in self._text_prompt_response_generator( + prompt, + request_id, + prompt_token_count, + output_token_count, + ): + yield response + + async def chat_completions( # type: ignore + self, + content: Union[ + str, + List[Union[str, Dict[str, Union[str, Dict[str, str]]], Path, Image.Image]], + Any, + ], + request_id: Optional[str] = None, + prompt_token_count: Optional[int] = None, + output_token_count: Optional[int] = None, + raw_content: bool = False, + **kwargs, + ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: + if not isinstance(content, str) or not content: + raise ValueError("Content must be a non-empty string") + + async for response in self._text_prompt_response_generator( + content, + request_id, + prompt_token_count, + output_token_count, + ): + yield response + + async def _text_prompt_response_generator( + self, + prompt: str, + request_id: Optional[str], + prompt_token_count: Optional[int], + output_token_count: Optional[int], + ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: + tokens = self._get_tokens(output_token_count) + start_time = time.time() + + yield StreamingTextResponse( + type_="start", + iter_count=0, + delta="", + time=start_time, + request_id=request_id, + ) + + for index, token in enumerate(tokens): + if self._iter_delay: + await asyncio.sleep(self._iter_delay) + + yield StreamingTextResponse( + type_="iter", + iter_count=index + 1, + delta=token, + time=time.time(), + request_id=request_id, + ) + + yield ResponseSummary( + value="".join(tokens), + request_args=RequestArgs( + target=self.target, + headers={}, + payload={"prompt": prompt, "output_token_count": output_token_count}, + ), + iterations=len(tokens), + start_time=start_time, + end_time=time.time(), + request_prompt_tokens=prompt_token_count, + request_output_tokens=output_token_count, + response_prompt_tokens=len(prompt.split()) + prompt.count(" "), + response_output_tokens=len(tokens), + request_id=request_id, + ) + + @staticmethod + def _get_tokens(token_count: Optional[int] = None) -> List[str]: + if token_count is None: + token_count = random.randint(8, 512) + + words = TextLorem(srange=(token_count, token_count)).sentence().split() + tokens = [] # type: ignore + + for word in words: + if len(tokens) == token_count - 1: + tokens.append(".") + break + if len(tokens) == token_count - 2: + tokens.append(word) + tokens.append(".") + break + tokens.append(word) + tokens.append(" ") + + return tokens diff --git a/tests/unit/request/test_transformers.py b/tests/unit/request/test_transformers.py index eaa465e6..d3b45325 100644 --- a/tests/unit/request/test_transformers.py +++ b/tests/unit/request/test_transformers.py @@ -114,7 +114,7 @@ def test_transformers_dataset_request_generator_lifecycle( create_sample_dataset_dict(splits=["test"], column="output"), ), (create_sample_dataset_dict(splits=["val", "train"], column="custom"), None), - (create_sample_dataset(), None) + (create_sample_dataset(), None), ], ) def test_transformers_dataset_request_generator_len( diff --git a/tests/unit/scheduler/test_base.py b/tests/unit/scheduler/test_scheduler.py similarity index 56% rename from tests/unit/scheduler/test_base.py rename to tests/unit/scheduler/test_scheduler.py index b485e59e..d765280f 100644 --- a/tests/unit/scheduler/test_base.py +++ 
b/tests/unit/scheduler/test_scheduler.py @@ -1,6 +1,5 @@ -import asyncio -import time -from unittest.mock import AsyncMock, create_autospec +import random +from unittest.mock import create_autospec import pytest @@ -19,20 +18,37 @@ @pytest.mark.smoke() -def test_scheduler_result(): +def test_scheduler_result_default_intialization(): + benchmark = create_autospec(TextGenerationBenchmark, instance=True) + scheduler_result = SchedulerResult( + completed=False, + count_total=0, + count_completed=0, + benchmark=benchmark, + ) + + assert scheduler_result.completed is False + assert scheduler_result.count_total == 0 + assert scheduler_result.count_completed == 0 + assert scheduler_result.benchmark == benchmark + assert scheduler_result.current_result is None + + +@pytest.mark.smoke() +def test_scheduler_result_initialization(): benchmark = create_autospec(TextGenerationBenchmark, instance=True) result = TextGenerationResult( request=TextGenerationRequest(prompt="prompt"), output="Test output" ) scheduler_result = SchedulerResult( - completed=True, + completed=False, count_total=10, count_completed=5, benchmark=benchmark, current_result=result, ) - assert scheduler_result.completed is True + assert scheduler_result.completed is False assert scheduler_result.count_total == 10 assert scheduler_result.count_completed == 5 assert scheduler_result.benchmark == benchmark @@ -49,12 +65,12 @@ def test_scheduler_result(): ("constant", 1.0, None, 120.0), ], ) -def test_scheduler_instantiation(mode, rate, max_number, max_duration): +def test_scheduler_initialization(mode, rate, max_number, max_duration): generator = create_autospec(RequestGenerator, instance=True) - worker = create_autospec(Backend, instance=True) + backend = create_autospec(Backend, instance=True) scheduler = Scheduler( generator, - worker, + backend, mode=mode, rate=rate, max_number=max_number, @@ -62,7 +78,7 @@ def test_scheduler_instantiation(mode, rate, max_number, max_duration): ) assert scheduler.generator == generator - assert scheduler.worker == worker + assert scheduler.backend == backend assert scheduler.mode == mode assert scheduler.rate == rate assert scheduler.max_number == max_number @@ -88,19 +104,19 @@ def test_scheduler_instantiation(mode, rate, max_number, max_duration): ("poisson", None, None, 10), ], ) -def test_scheduler_invalid_instantiation( +def test_scheduler_invalid_initialization( mode, rate, max_number, max_duration, ): generator = create_autospec(RequestGenerator, instance=True) - worker = create_autospec(Backend, instance=True) + backend = create_autospec(Backend, instance=True) with pytest.raises(ValueError): Scheduler( generator, - worker, + backend, mode=mode, rate=rate, max_number=max_number, @@ -119,30 +135,20 @@ def test_scheduler_invalid_instantiation( "constant", ], ) -async def test_scheduler_run_number(mode): +async def test_scheduler_run_number(mode, mock_backend): rate = 10.0 max_number = 20 generator = create_autospec(RequestGenerator, instance=True) - worker = create_autospec(Backend, instance=True) # Mock the request generator and backend submit behavior generator.__iter__.return_value = iter( - [TextGenerationRequest(prompt="Test")] * (max_number * 2) + [TextGenerationRequest(prompt="Test", type_=random.choice(["text", "chat"]))] + * (max_number * 2) ) - worker.submit = AsyncMock() - - def _submit(req): - res = TextGenerationResult(request=req, output="Output") - res.start(prompt=req.prompt) - res.output_token("token") - res.end() - return res - - worker.submit.side_effect = _submit scheduler = 
Scheduler( generator, - worker, + mock_backend, mode=mode, rate=rate, max_number=max_number, @@ -191,89 +197,3 @@ def _submit(req): assert received_init assert received_final assert count_completed == max_number - - -@pytest.mark.sanity() -@pytest.mark.asyncio() -@pytest.mark.parametrize( - "mode", - [ - "synchronous", - "constant", - ], -) -@pytest.mark.flaky(reruns=5) -async def test_scheduler_run_duration(mode): - rate = 10 - max_duration = 2 - generator = create_autospec(RequestGenerator, instance=True) - worker = create_autospec(Backend, instance=True) - - # Mock the request generator and backend submit behavior - generator.__iter__.return_value = iter( - [TextGenerationRequest(prompt="Test")] * (rate * max_duration * 100) - ) - worker.submit = AsyncMock() - - async def _submit(req): - await asyncio.sleep(0.1) - res = TextGenerationResult(request=req, output="Output") - res.start(prompt=req.prompt) - res.output_token("token") - res.end() - return res - - worker.submit.side_effect = _submit - - scheduler = Scheduler( - generator, - worker, - mode=mode, - rate=rate, - max_duration=max_duration, - ) - - run_count = 0 - count_completed = 0 - received_init = False - received_final = False - start_time = time.time() - async for result in scheduler.run(): - run_count += 1 - - assert run_count <= max_duration * rate + 2 - assert result.count_total == max_duration - assert result.benchmark is not None - assert isinstance(result.benchmark, TextGenerationBenchmark) - - if result.current_result is not None: - count_completed += 1 - - if run_count == 1: - assert not received_init - assert not received_final - assert count_completed == 0 - assert result.count_completed == 0 - assert not result.completed - assert result.current_result is None - received_init = True - elif time.time() - start_time >= max_duration: - assert received_init - assert not received_final - assert result.count_completed == max_duration - assert result.completed - assert result.current_result is None - received_final = True - else: - assert received_init - assert not received_final - assert result.count_completed == round(time.time() - start_time) - assert not result.completed - assert result.current_result is not None - assert isinstance(result.current_result, TextGenerationResult) - - assert received_init - assert received_final - end_time = time.time() - assert pytest.approx(end_time - start_time, abs=0.1) == max_duration - assert pytest.approx(count_completed, abs=5) == max_duration * rate diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py deleted file mode 100644 index 82de3edf..00000000 --- a/tests/unit/test_main.py +++ /dev/null @@ -1,414 +0,0 @@ -import tempfile -from pathlib import Path -from typing import List, Optional -from unittest.mock import create_autospec, patch - -import pytest -from click.testing import CliRunner - -from guidellm import generate_benchmark_report -from guidellm.backend import Backend -from guidellm.core import TextGenerationBenchmarkReport, TextGenerationResult -from guidellm.executor import Executor, ExecutorResult, Profile, ProfileGenerationMode -from guidellm.main import generate_benchmark_report_cli -from guidellm.request import ( - EmulatedRequestGenerator, - FileRequestGenerator, - TransformersDatasetRequestGenerator, -) -from guidellm.scheduler import SchedulerResult -from guidellm.utils.progress import BenchmarkReportProgress - - -@pytest.fixture() -def mock_benchmark_report(): - with patch("guidellm.main.GuidanceReport") as mock_benchmark_report: - - def 
_mock_const(*args, **kwargs): - instance = create_autospec(BenchmarkReportProgress, instance=True) - instance.args = args - instance.kwargs = kwargs - instance.benchmarks = [] - instance.save_file = lambda output_path: None - instance.print = lambda *args, **kwargs: None - - return instance - - mock_benchmark_report.side_effect = _mock_const - yield mock_benchmark_report - - -@pytest.fixture() -def mock_benchmark_report_progress(): - with patch( - "guidellm.main.BenchmarkReportProgress" - ) as mock_benchmark_report_progress: - - def _mock_const(*args, **kwargs): - instance = create_autospec(BenchmarkReportProgress, instance=True) - instance.args = args - instance.kwargs = kwargs - - return instance - - mock_benchmark_report_progress.side_effect = _mock_const - yield mock_benchmark_report_progress - - -@pytest.fixture() -def mock_backend(): - with patch("guidellm.main.Backend.create") as mock_create: - - def _mock_create(*args, **kwargs): - backend = create_autospec(Backend, instance=True) - backend.args = args - backend.kwargs = kwargs - return backend - - mock_create.side_effect = _mock_create - yield mock_create - - -@pytest.fixture() -def mock_request_generator_emulated(): - with patch( - "guidellm.main.EmulatedRequestGenerator" - ) as mock_request_generator_class: - - def _mock_const(*args, **kwargs): - request_generator = create_autospec(EmulatedRequestGenerator, instance=True) - request_generator.args = args - request_generator.kwargs = kwargs - return request_generator - - mock_request_generator_class.side_effect = _mock_const - yield mock_request_generator_class - - -@pytest.fixture() -def mock_request_generator_file(): - with patch("guidellm.main.FileRequestGenerator") as mock_request_generator_class: - - def _mock_const(*args, **kwargs): - request_generator = create_autospec(FileRequestGenerator, instance=True) - request_generator.args = args - request_generator.kwargs = kwargs - return request_generator - - mock_request_generator_class.side_effect = _mock_const - yield mock_request_generator_class - - -@pytest.fixture() -def mock_request_generator_transformers(): - with patch( - "guidellm.main.TransformersDatasetRequestGenerator" - ) as mock_request_generator_class: - - def _mock_const(*args, **kwargs): - request_generator = create_autospec( - TransformersDatasetRequestGenerator, instance=True - ) - request_generator.args = args - request_generator.kwargs = kwargs - return request_generator - - mock_request_generator_class.side_effect = _mock_const - yield mock_request_generator_class - - -@pytest.fixture() -def mock_executor(): - with patch("guidellm.main.Executor") as mock_executor_class: - - def _mock_const(*args, **kwargs): - executor = create_autospec(Executor, instance=True) - executor.args = args - executor.kwargs = kwargs - - async def _mock_executor_run(): - generation_modes: List[ProfileGenerationMode] - generation_rates: List[Optional[float]] - completed_rates: List[float] - - if kwargs["mode"] == "sweep": - num_benchmarks = 12 - generation_modes = [ # type: ignore - "synchronous", - "throughput", - ] + ["constant"] * 10 - generation_rates = [None, None] + [ind + 1.0 for ind in range(10)] - completed_rates = [1.0, 10.0] + [ind + 1.0 for ind in range(10)] - elif kwargs["rate"] is not None and isinstance(kwargs["rate"], list): - num_benchmarks = len(kwargs["rate"]) - generation_modes = [kwargs["mode"]] * num_benchmarks - generation_rates = kwargs["rate"] - completed_rates = kwargs["rate"] - else: - num_benchmarks = 1 - generation_modes = [kwargs["mode"]] - 
generation_rates = [kwargs["rate"]] - completed_rates = [1.0] - - report = create_autospec(TextGenerationBenchmarkReport, instance=True) - report.args = { - "backend": "backend", - "request_generator": "request_generator", - "mode": kwargs["mode"], - "rate": kwargs["rate"], - "max_number": kwargs["max_number"], - "max_duration": kwargs["max_duration"], - } - yield ExecutorResult( - completed=False, - count_total=num_benchmarks, - count_completed=0, - generation_modes=generation_modes, - report=report, - scheduler_result=None, - current_index=None, - current_profile=None, - ) - for bench in range(num_benchmarks): - benchmark = create_autospec( - TextGenerationBenchmarkReport, instance=True - ) - benchmark.start_time = 0 - benchmark.end_time = 1 - benchmark.completed_request_rate = completed_rates[bench] - profile = Profile( - load_gen_mode=generation_modes[bench], # type: ignore - load_gen_rate=generation_rates[bench], - ) - yield ExecutorResult( - completed=False, - count_total=num_benchmarks, - count_completed=bench, - generation_modes=generation_modes, - report=report, - scheduler_result=SchedulerResult( - completed=False, - count_total=10, - count_completed=0, - benchmark=benchmark, - current_result=None, - ), - current_index=bench, - current_profile=profile, - ) - for ind in range(10): - yield ExecutorResult( - completed=False, - count_total=num_benchmarks, - count_completed=bench, - generation_modes=generation_modes, - report=report, - scheduler_result=SchedulerResult( - completed=False, - count_total=10, - count_completed=ind + 1, - benchmark=benchmark, - current_result=create_autospec(TextGenerationResult), - ), - current_index=bench, - current_profile=profile, - ) - yield ExecutorResult( - completed=False, - count_total=num_benchmarks, - count_completed=bench + 1, - generation_modes=generation_modes, - report=report, - ) - yield ExecutorResult( - completed=True, - count_total=num_benchmarks, - count_completed=num_benchmarks, - generation_modes=generation_modes, - report=report, - ) - - executor.run.side_effect = _mock_executor_run - return executor - - mock_executor_class.side_effect = _mock_const - yield mock_executor_class - - -@pytest.mark.smoke() -def test_generate_benchmark_report_invoke_smoke( - mock_backend, mock_request_generator_emulated, mock_executor -): - report = generate_benchmark_report( - target="http://localhost:8000/v1", - backend="openai_server", - model=None, - data=None, - data_type="emulated", - tokenizer=None, - rate_type="sweep", - rate=None, - max_seconds=10, - max_requests=10, - output_path="benchmark_report.json", - cont_refresh_table=False, - ) - assert report is not None - - -@pytest.mark.smoke() -def test_generate_benchmark_report_cli_smoke( - mock_backend, mock_request_generator_emulated, mock_executor -): - runner = CliRunner() - result = runner.invoke( - generate_benchmark_report_cli, - [ - "--target", - "http://localhost:8000/v1", - "--backend", - "openai_server", - "--data-type", - "emulated", - "--data", - "prompt_tokens=512", - "--rate-type", - "sweep", - "--max-seconds", - "10", - "--max-requests", - "10", - "--output-path", - "benchmark_report.json", - ], - ) - - if result.stdout: - print(result.stdout) # noqa: T201 - - assert result.exit_code == 0 - assert "Benchmarks" in result.output - assert "Generating report..." 
in result.output - assert "Benchmark Report 1" in result.output - - -@pytest.mark.smoke() -def test_generate_benchmark_report_emulated_with_dataset_requests( - mock_backend, mock_request_generator_emulated, mock_executor -): - with pytest.raises(ValueError, match="Cannot use 'dataset' for emulated data"): - generate_benchmark_report( - target="http://localhost:8000/v1", - backend="openai_server", - model=None, - data_type="emulated", - data=None, - tokenizer=None, - rate_type="sweep", - rate=None, - max_seconds=10, - max_requests="dataset", - output_path="benchmark_report.json", - cont_refresh_table=False, - ) - - -@pytest.mark.smoke() -def test_generate_benchmark_report_cli_emulated_with_dataset_requests( - mock_backend, mock_request_generator_emulated, mock_executor -): - runner = CliRunner() - with pytest.raises(ValueError, match="Cannot use 'dataset' for emulated data"): - runner.invoke( - generate_benchmark_report_cli, - [ - "--target", - "http://localhost:8000/v1", - "--backend", - "openai_server", - "--data-type", - "emulated", - "--data", - "prompt_tokens=512", - "--rate-type", - "sweep", - "--max-seconds", - "10", - "--max-requests", - "dataset", - "--output-path", - "benchmark_report.json", - ], - catch_exceptions=False, - ) - - -@pytest.mark.sanity() -@pytest.mark.parametrize(("rate_type", "rate"), [("constant", 1.0), ("sweep", 1.0)]) -@pytest.mark.parametrize( - ("file_extension", "file_content", "expected_results"), - [ - ("txt", "Test prompt 1", 1), - ("txt", "Test prompt 1\nTest prompt 2\nTest prompt 3\n", 3), - ], -) -def test_generate_benchmark_report_openai_limited_by_file_dataset( - mocker, - mock_auto_tokenizer, - mock_benchmark_report, - mock_benchmark_report_progress, - rate_type, - rate, - file_extension, - file_content, - expected_results, -): - """ - Mock only a few functions to get the proper report result - from the ``Backend.make_request``. - - Notes: - All the results are collected in the `benchmark.errors``, - since the most of the responses are mocked and can't be processed. - But the ordering of the results is still the same for both collections. - - ``mock_benchmark_report`` and ``mock_benchmark_report_progress`` - are used for preventing working with IO bound tasks. - """ - - mocker.patch("guidellm.backend.openai.AsyncOpenAI") - mocker.patch("openai.AsyncOpenAI") - mocker.patch("guidellm.backend.openai.OpenAIBackend.test_connection") - mocker.patch("guidellm.backend.openai.OpenAIBackend.available_models") - - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / f"example.{file_extension}" - file_path.write_text(file_content) - - # Run the benchmark report generation - report = generate_benchmark_report( - target="http://localhost:8000/v1", - backend="openai_server", - model=None, - data=str(file_path), - data_type="file", - tokenizer=None, - rate_type=rate_type, - rate=rate, - max_seconds=None, - max_requests="dataset", - output_path="benchmark_report.json", - cont_refresh_table=False, - ) - - assert report is not None - assert len(report.benchmarks) == 1 - assert len(report.benchmarks[0].benchmarks[0].errors) == expected_results - - file_lines: List[str] = [line for line in file_content.split("\n") if line] - output_prompts = [ - text_generation.request.prompt - for text_generation in report.benchmarks[0].benchmarks[0].errors - ] - - assert output_prompts == file_lines
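
Illustrative usage sketch (not part of the diff above): the new MockBackend added in tests/unit/mock_backend.py streams a "start" marker, one "iter" response per token, and a closing ResponseSummary, so unit tests can exercise the backend interface without a live OpenAI-compatible server. The test name and the token count below are hypothetical; the sketch relies only on the MockBackend constructor and text_completions signature shown in the diff, plus pytest-asyncio from the dev dependencies.

import pytest

from guidellm.backend import ResponseSummary, StreamingTextResponse

from .mock_backend import MockBackend


@pytest.mark.asyncio()
async def test_mock_backend_text_completions_stream():
    # Hypothetical test; MockBackend is the fixture-backed mock defined above.
    backend = MockBackend(model="mock-model", target="mock-target")

    responses = [
        response
        async for response in backend.text_completions(
            prompt="Hello, world!", request_id="req-1", output_token_count=4
        )
    ]

    # The stream yields a "start" marker, one "iter" per token, then a summary.
    assert isinstance(responses[0], StreamingTextResponse)
    assert responses[0].type_ == "start"
    assert all(
        isinstance(resp, StreamingTextResponse) and resp.type_ == "iter"
        for resp in responses[1:-1]
    )
    assert isinstance(responses[-1], ResponseSummary)
    assert responses[-1].response_output_tokens == 4
    assert responses[-1].request_id == "req-1"

The same pattern underpins the respx-based httpx_openai_mock fixture in conftest.py, which wraps MockBackend output in OpenAI-style SSE payloads so the real HTTP backend can be tested against /v1/completions and /v1/chat/completions without network access.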