2 changes: 1 addition & 1 deletion clients/python/llmengine/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.0b27"
__version__ = "0.0.0b28"

import os
from typing import Sequence
42 changes: 41 additions & 1 deletion clients/python/llmengine/completion.py
@@ -1,4 +1,4 @@
-from typing import AsyncIterable, Iterator, List, Optional, Union
+from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union

from llmengine.api_engine import APIEngine
from llmengine.data_types import (
@@ -43,6 +43,10 @@ async def acreate(
frequency_penalty: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
include_stop_str_in_output: Optional[bool] = False,
guided_json: Optional[Dict[str, Any]] = None,
guided_regex: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
@@ -102,6 +106,18 @@ async def acreate(
Float that controls the cumulative probability of the top tokens to consider.
Range: (0.0, 1.0]. 1.0 means consider all tokens.

include_stop_str_in_output (Optional[bool]):
Whether to include the stop sequence in the output. Defaults to False.

guided_json (Optional[Dict[str, Any]]):
If specified, the output will follow the JSON schema. For examples see https://json-schema.org/learn/miscellaneous-examples.

guided_regex (Optional[str]):
If specified, the output will follow the regex pattern.

guided_choice (Optional[List[str]]):
If specified, the output will be exactly one of the choices.

timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.

@@ -198,6 +214,10 @@ async def _acreate_stream(
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
include_stop_str_in_output=include_stop_str_in_output,
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
timeout=timeout,
)

@@ -237,6 +257,10 @@ def create(
frequency_penalty: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
include_stop_str_in_output: Optional[bool] = False,
guided_json: Optional[Dict[str, Any]] = None,
guided_regex: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
@@ -297,6 +321,18 @@ def create(
Float that controls the cumulative probability of the top tokens to consider.
Range: (0.0, 1.0]. 1.0 means consider all tokens.

include_stop_str_in_output (Optional[bool]):
Whether to include the stop sequence in the output. Defaults to False.

guided_json (Optional[Dict[str, Any]]):
If specified, the output will follow the JSON schema.

guided_regex (Optional[str]):
If specified, the output will follow the regex pattern.

guided_choice (Optional[List[str]]):
If specified, the output will be exactly one of the choices.

timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.

@@ -396,6 +432,10 @@ def _create_stream(**kwargs):
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
include_stop_str_in_output=include_stop_str_in_output,
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
).dict()
response = cls.post_sync(
resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",
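For orientation, here is a minimal sketch of how the new parameters surface through the async client updated above. It mirrors the synchronous examples in the docs change later in this diff; the model name and values are illustrative only.

```python
import asyncio

from llmengine import Completion


async def main():
    # guided_choice constrains the completion to exactly one of the listed strings;
    # include_stop_str_in_output controls whether a matched stop sequence is kept in the text.
    response = await Completion.acreate(
        model="llama-2-7b",
        prompt="Hello, my name is",
        max_new_tokens=10,
        temperature=0.2,
        guided_choice=["Sean", "Brian", "Tim"],
        include_stop_str_in_output=False,
    )
    print(response.json())


asyncio.run(main())
```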
8 changes: 8 additions & 0 deletions clients/python/llmengine/data_types.py
@@ -279,6 +279,10 @@ class CompletionSyncV1Request(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=-1)
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
include_stop_str_in_output: Optional[bool] = Field(default=False)
guided_json: Optional[Dict[str, Any]] = Field(default=None)
guided_regex: Optional[str] = Field(default=None)
guided_choice: Optional[List[str]] = Field(default=None)


class TokenOutput(BaseModel):
@@ -349,6 +353,10 @@ class CompletionStreamV1Request(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=-1)
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
include_stop_str_in_output: Optional[bool] = Field(default=False)
guided_json: Optional[Dict[str, Any]] = Field(default=None)
guided_regex: Optional[str] = Field(default=None)
guided_choice: Optional[List[str]] = Field(default=None)


class CompletionStreamOutput(BaseModel):
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scale-llm-engine"
version = "0.0.0.beta27"
version = "0.0.0.beta28"
description = "Scale LLM Engine Python client"
license = "Apache-2.0"
authors = ["Phil Chen <[email protected]>"]
2 changes: 1 addition & 1 deletion clients/python/setup.py
@@ -3,6 +3,6 @@
setup(
name="scale-llm-engine",
python_requires=">=3.7",
version="0.0.0.beta27",
version="0.0.0.beta28",
packages=find_packages(),
)
53 changes: 53 additions & 0 deletions docs/guides/completions.md
@@ -193,6 +193,59 @@ response = Completion.batch_create(
print(response.json())
```

## Guided decoding

Guided decoding is supported on vLLM endpoints and is backed by [Outlines](https://github.com/outlines-dev/outlines).
It constrains token generation to a given pattern by masking the sampling logits. Only one of `guided_json`, `guided_regex`, or `guided_choice` may be set on a single request; combining them is rejected during request validation.

=== "Guided decoding with regex"
```python
from llmengine import Completion

response = Completion.create(
model="llama-2-7b",
prompt="Hello, my name is",
max_new_tokens=10,
temperature=0.2,
guided_regex="Sean.*",
)

print(response.json())
# {"request_id":"c19f0fae-317e-4f69-8e06-c04189299b9c","output":{"text":"Sean. I'm a 2","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
```

=== "Guided decoding with choice"
```python
from llmengine import Completion

response = Completion.create(
model="llama-2-7b",
prompt="Hello, my name is",
max_new_tokens=10,
temperature=0.2,
guided_choice=["Sean", "Brian", "Tim"],
)

print(response.json())
# {"request_id":"641e2af3-a3e3-4493-98b9-d38115ba0d22","output":{"text":"Sean","num_prompt_tokens":6,"num_completion_tokens":4,"tokens":null}}
```

=== "Guided decoding with JSON schema"
```python
from llmengine import Completion

response = Completion.create(
model="llama-2-7b",
prompt="Hello, my name is",
max_new_tokens=10,
temperature=0.2,
guided_json={"properties":{"myString":{"type":"string"}},"required":["myString"]},
)

print(response.json())
# {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
```
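
The streaming request schema added in this PR carries the same `guided_*` fields, so guided decoding should also work with `stream=True`; below is a minimal sketch, assuming the usual streaming response shape.

=== "Guided decoding with choice (streaming)"
```python
from llmengine import Completion

stream = Completion.create(
    model="llama-2-7b",
    prompt="Hello, my name is",
    max_new_tokens=10,
    temperature=0.2,
    guided_choice=["Sean", "Brian", "Tim"],
    stream=True,
)

# Each chunk is a CompletionStreamResponse; print tokens as they arrive.
for chunk in stream:
    print(chunk.json())
```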

## Which model should I use?

See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions.
24 changes: 24 additions & 0 deletions model-engine/model_engine_server/common/dtos/llms.py
@@ -184,6 +184,18 @@ class CompletionSyncV1Request(BaseModel):
"""
Whether to include the stop strings in output text.
"""
guided_json: Optional[Dict[str, Any]] = None
"""
JSON schema for guided decoding.
"""
guided_regex: Optional[str] = None
"""
Regex for guided decoding.
"""
guided_choice: Optional[List[str]] = None
"""
Choices for guided decoding.
"""


class TokenOutput(BaseModel):
Expand Down Expand Up @@ -248,6 +260,18 @@ class CompletionStreamV1Request(BaseModel):
"""
Whether to include the stop strings in output text.
"""
guided_json: Optional[Dict[str, Any]] = None
"""
JSON schema for guided decoding. Only supported in vllm.
"""
guided_regex: Optional[str] = None
"""
Regex for guided decoding. Only supported in vllm.
"""
guided_choice: Optional[List[str]] = None
"""
Choices for guided decoding. Only supported in vllm.
"""


class CompletionStreamOutput(BaseModel):
@@ -1365,6 +1365,26 @@ def validate_and_update_completion_params(
"include_stop_str_in_output is only supported in vllm."
)

guided_count = 0
if request.guided_choice is not None:
guided_count += 1
if request.guided_json is not None:
guided_count += 1
if request.guided_regex is not None:
guided_count += 1

if guided_count > 1:
raise ObjectHasInvalidValueException(
"Only one of guided_json, guided_choice, guided_regex can be enabled."
)

if (
request.guided_choice is not None
or request.guided_regex is not None
or request.guided_json is not None
) and not inference_framework == LLMInferenceFramework.VLLM:
raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.")

return request


@@ -1656,6 +1676,12 @@ async def execute(
vllm_args["logprobs"] = 1
if request.include_stop_str_in_output is not None:
vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output
if request.guided_choice is not None:
vllm_args["guided_choice"] = request.guided_choice
if request.guided_regex is not None:
vllm_args["guided_regex"] = request.guided_regex
if request.guided_json is not None:
vllm_args["guided_json"] = request.guided_json

inference_request = SyncEndpointPredictV1Request(
args=vllm_args,
@@ -1918,6 +1944,12 @@ async def execute(
args["logprobs"] = 1
if request.include_stop_str_in_output is not None:
args["include_stop_str_in_output"] = request.include_stop_str_in_output
if request.guided_choice is not None:
args["guided_choice"] = request.guided_choice
if request.guided_regex is not None:
args["guided_regex"] = request.guided_regex
if request.guided_json is not None:
args["guided_json"] = request.guided_json
args["stream"] = True
elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
args = {
@@ -1,3 +1,2 @@
ray>=2.9
-vllm==0.3.2
+vllm==0.3.3
pydantic>=2.0
32 changes: 31 additions & 1 deletion model-engine/model_engine_server/inference/vllm/vllm_server.py
@@ -7,10 +7,12 @@
from typing import AsyncGenerator

import uvicorn
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

@@ -38,7 +40,35 @@ async def generate(request: Request) -> Response:
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
guided_json = request_dict.pop("guided_json", None)
guided_regex = request_dict.pop("guided_regex", None)
guided_choice = request_dict.pop("guided_choice", None)
sampling_params = SamplingParams(**request_dict)

# Dummy request to get guided decode logit processor
try:
partial_openai_request = OpenAICompletionRequest.model_validate(
{
"model": "",
"prompt": "",
"guided_json": guided_json,
"guided_regex": guided_regex,
"guided_choice": guided_choice,
}
)
except Exception:
raise HTTPException(
status_code=400, detail="Bad request: failed to parse guided decoding parameters."
)

guided_decode_logit_processor = await get_guided_decoding_logits_processor(
partial_openai_request, engine.get_tokenizer()
)
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(guided_decode_logit_processor)

request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)

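
For completeness, a hedged sketch of exercising the server change above by posting directly to the vLLM server's generate route. The route path and port are assumptions (not shown in this diff); the guided_* keys are popped off before the rest of the payload is passed to `SamplingParams(**request_dict)`, so the remaining keys must be valid `SamplingParams` fields.

```python
# Sketch only: route path and port are assumptions, not confirmed by this diff.
import requests

payload = {
    "prompt": "Hello, my name is",
    "max_tokens": 10,          # SamplingParams field
    "temperature": 0.2,        # SamplingParams field
    "guided_regex": "Sean.*",  # popped and turned into a guided-decoding logits processor
}
response = requests.post("http://localhost:5005/generate", json=payload)
print(response.json())
```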