fastchat/serve/vllm_worker.py (109 additions, 99 deletions)
@@ -17,6 +17,7 @@
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
logger,
@@ -66,108 +67,117 @@ def __init__(

    async def generate_stream(self, params):
        self.call_ct += 1

        context = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = params.get("top_k", -1.0)
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        echo = params.get("echo", True)
        use_beam_search = params.get("use_beam_search", False)
        best_of = params.get("best_of", None)

        request = params.get("request", None)

        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)

        # make sampling params in vllm
        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        sampling_params = SamplingParams(
            n=1,
            temperature=temperature,
            top_p=top_p,
            use_beam_search=use_beam_search,
            stop=list(stop),
            stop_token_ids=stop_token_ids,
            max_tokens=max_new_tokens,
            top_k=top_k,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            best_of=best_of,
        )
        results_generator = engine.generate(context, sampling_params, request_id)

        async for request_output in results_generator:
            prompt = request_output.prompt
            if echo:
                text_outputs = [
                    prompt + output.text for output in request_output.outputs
                ]
            else:
                text_outputs = [output.text for output in request_output.outputs]
            text_outputs = " ".join(text_outputs)

            partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
            # prevent yielding partial stop sequence
            if partial_stop:
                continue

            aborted = False
            if request and await request.is_disconnected():
                await engine.abort(request_id)
                request_output.finished = True
                aborted = True
                for output in request_output.outputs:
                    output.finish_reason = "abort"

            prompt_tokens = len(request_output.prompt_token_ids)
            completion_tokens = sum(
                len(output.token_ids) for output in request_output.outputs
            )
            ret = {
                "text": text_outputs,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                "cumulative_logprob": [
                    output.cumulative_logprob for output in request_output.outputs
                ],
                "finish_reason": request_output.outputs[0].finish_reason
                if len(request_output.outputs) == 1
                else [output.finish_reason for output in request_output.outputs],
            }
            # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
            # This aligns with the behavior of model_worker.
            if request_output.finished:
                yield (json.dumps({**ret, **{"finish_reason": None}}) + "\0").encode()
            yield (json.dumps(ret) + "\0").encode()

            if aborted:
                break
        try:
            context = params.pop("prompt")
            request_id = params.pop("request_id")
            temperature = float(params.get("temperature", 1.0))
            top_p = float(params.get("top_p", 1.0))
            top_k = params.get("top_k", -1.0)
            presence_penalty = float(params.get("presence_penalty", 0.0))
            frequency_penalty = float(params.get("frequency_penalty", 0.0))
            max_new_tokens = params.get("max_new_tokens", 256)
            stop_str = params.get("stop", None)
            stop_token_ids = params.get("stop_token_ids", None) or []
            if self.tokenizer.eos_token_id is not None:
                stop_token_ids.append(self.tokenizer.eos_token_id)
            echo = params.get("echo", True)
            use_beam_search = params.get("use_beam_search", False)
            best_of = params.get("best_of", None)

            request = params.get("request", None)

            # Handle stop_str
            stop = set()
            if isinstance(stop_str, str) and stop_str != "":
                stop.add(stop_str)
            elif isinstance(stop_str, list) and stop_str != []:
                stop.update(stop_str)

            for tid in stop_token_ids:
                if tid is not None:
                    s = self.tokenizer.decode(tid)
                    if s != "":
                        stop.add(s)

            # make sampling params in vllm
            top_p = max(top_p, 1e-5)
            if temperature <= 1e-5:
                top_p = 1.0
            sampling_params = SamplingParams(
                n=1,
                temperature=temperature,
                top_p=top_p,
                use_beam_search=use_beam_search,
                stop=list(stop),
                stop_token_ids=stop_token_ids,
                max_tokens=max_new_tokens,
                top_k=top_k,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                best_of=best_of,
            )
            results_generator = engine.generate(context, sampling_params, request_id)

            async for request_output in results_generator:
                prompt = request_output.prompt
                if echo:
                    text_outputs = [
                        prompt + output.text for output in request_output.outputs
                    ]
                else:
                    text_outputs = [output.text for output in request_output.outputs]
                text_outputs = " ".join(text_outputs)

                partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
                # prevent yielding partial stop sequence
                if partial_stop:
                    continue

                aborted = False
                if request and await request.is_disconnected():
                    await engine.abort(request_id)
                    request_output.finished = True
                    aborted = True
                    for output in request_output.outputs:
                        output.finish_reason = "abort"

                prompt_tokens = len(request_output.prompt_token_ids)
                completion_tokens = sum(
                    len(output.token_ids) for output in request_output.outputs
                )
                ret = {
                    "text": text_outputs,
                    "error_code": 0,
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                    "cumulative_logprob": [
                        output.cumulative_logprob for output in request_output.outputs
                    ],
                    "finish_reason": (
                        request_output.outputs[0].finish_reason
                        if len(request_output.outputs) == 1
                        else [output.finish_reason for output in request_output.outputs]
                    ),
                }
                # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
                # This aligns with the behavior of model_worker.
                if request_output.finished:
                    yield (
                        json.dumps({**ret, **{"finish_reason": None}}) + "\0"
                    ).encode()
                yield (json.dumps(ret) + "\0").encode()

                if aborted:
                    break
        except ValueError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    async def generate(self, params):
        async for x in self.generate_stream(params):
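For context on how the stream produced by generate_stream is consumed: each partial result is framed as a JSON object terminated by a NUL byte ("\0"), and error payloads carry a non-zero error_code after this change. Below is a minimal client sketch under that framing. The worker address, the /worker_generate_stream route, and the exact payload fields are assumptions drawn from FastChat's usual worker API, not from this diff.

import json

import requests

# Assumed worker address and route; adjust to your deployment.
WORKER_ADDR = "http://localhost:21002"

payload = {
    "prompt": "Hello! Tell me a short joke.",
    "temperature": 0.7,
    "top_p": 1.0,
    "max_new_tokens": 64,
    "echo": False,
}

with requests.post(
    f"{WORKER_ADDR}/worker_generate_stream", json=payload, stream=True
) as resp:
    # Each chunk mirrors one `yield (json.dumps(ret) + "\0").encode()` above.
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            # Error payloads contain SERVER_ERROR_MSG plus the exception text.
            raise RuntimeError(data["text"])
        print(data["text"], flush=True)

Note that "text" holds the cumulative generation so far rather than a delta, and the final chunk is emitted twice, first with "finish_reason" set to None and then with the real finish reason, matching the emit-twice comment in the code above.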