Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 127 additions & 20 deletions rllm-model-gateway/src/rllm_model_gateway/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,29 +309,40 @@ async def _handle_cumulative_non_streaming(
token_ids: list[int],
originally_requested_logprobs: bool = False,
) -> Response:
"""Non-streaming cumulative turn: send non-streaming to vLLM, return JSON."""
"""Non-streaming cumulative turn: send pre-tokenized prompt, return JSON.

Routes to the in-process ``local_handler`` (e.g. Tinker) when present,
otherwise POSTs ``/v1/completions`` to a vLLM worker. Both return a
completions-style body carrying ``prompt_token_ids`` + ``token_ids``.
"""
t0 = time.perf_counter()

worker = self.router.route(session_id)
url = self._build_url(worker.api_url, "/v1/completions", "")
headers = self._forward_headers(request)
raw_body = json.dumps(completions_body).encode()
try:
resp = await self._send_with_retry(
method="POST",
url=url,
content=raw_body,
headers=headers,
)
content = resp.content
status_code = resp.status_code
finally:
self.router.release(worker.url)
if self.local_handler is not None:
# In-process path (Tinker): sample directly from the pre-tokenized
# prompt; no HTTP worker, no re-tokenization.
response_body = await self.local_handler(completions_body)
status_code = 200
else:
worker = self.router.route(session_id)
url = self._build_url(worker.api_url, "/v1/completions", "")
headers = self._forward_headers(request)
raw_body = json.dumps(completions_body).encode()
try:
resp = await self._send_with_retry(
method="POST",
url=url,
content=raw_body,
headers=headers,
)
content = resp.content
status_code = resp.status_code
finally:
self.router.release(worker.url)

try:
response_body = json.loads(content)
except (json.JSONDecodeError, UnicodeDecodeError):
response_body = {}
try:
response_body = json.loads(content)
except (json.JSONDecodeError, UnicodeDecodeError):
response_body = {}

latency_ms = (time.perf_counter() - t0) * 1000

Expand Down Expand Up @@ -375,6 +386,11 @@ async def _handle_cumulative_streaming(
token_ids: list[int],
) -> StreamingResponse:
"""Streaming cumulative turn: stream from vLLM, translate chunks to chat format."""
if self.local_handler is not None:
# In-process backends (Tinker) don't stream; synthesize an SSE
# stream from a single pre-tokenized completion call.
return await self._handle_cumulative_streaming_local(request, request_body, completions_body, session_id, acc, token_ids)

completions_body["stream"] = True

worker = self.router.route(session_id)
Expand Down Expand Up @@ -502,6 +518,97 @@ async def event_generator():
status_code=resp.status_code,
)

async def _handle_cumulative_streaming_local(
self,
request: Request,
request_body: dict[str, Any],
completions_body: dict[str, Any],
session_id: str,
acc: TokenAccumulator,
token_ids: list[int],
) -> StreamingResponse:
"""Cumulative streaming via the in-process handler (fake-streaming).

The local handler returns a full completion in one shot; we ingest its
token IDs into the accumulator, translate to chat format, and emit it as
a synthesized SSE stream (role → content+tokens → finish → [DONE]).
"""
assert self.local_handler is not None
t0 = time.perf_counter()
response_body = await self.local_handler(completions_body)
latency_ms = (time.perf_counter() - t0) * 1000

prompt_token_ids = extract_prompt_token_ids(response_body) or token_ids
completion_token_ids = extract_completion_token_ids(response_body)
acc.ingest_turn(prompt_token_ids, completion_token_ids)
acc.update_prefix(request_body.get("messages", []))

choices = response_body.get("choices") or []
content = choices[0].pop("text", "") if choices else ""
finish_reason = (choices[0].get("finish_reason") if choices else None) or "stop"
completion_logprobs = choices[0].get("logprobs") if choices else None

if session_id and response_body:
chat_body = dict(response_body)
chat_body["object"] = "chat.completion"
if chat_body.get("choices"):
chat_body["choices"][0]["message"] = {"role": "assistant", "content": content}
trace = build_trace_record(session_id, request_body, chat_body, latency_ms, weight_version=request.state.weight_version)
await self._persist(trace)

chat_id = response_body.get("id", "chatcmpl-local")
created = response_body.get("created", int(time.time()))
model = response_body.get("model", "")
usage = response_body.get("usage", {})

def _sanitize_chunk(chunk: dict[str, Any]) -> dict[str, Any]:
return strip_vllm_fields(chunk) if self.strip_vllm else chunk

async def event_generator():
def _sse(data: str) -> str:
return f"data: {data}\n\n"

yield _sse(
json.dumps(
{
"id": chat_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
}
)
)
yield _sse(
json.dumps(
_sanitize_chunk(
{
"id": chat_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None, "token_ids": completion_token_ids, "logprobs": completion_logprobs}],
"prompt_token_ids": prompt_token_ids,
}
)
)
)
yield _sse(
json.dumps(
{
"id": chat_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}],
"usage": usage,
}
)
)
yield _sse("[DONE]")

return StreamingResponse(event_generator(), media_type="text/event-stream", status_code=200)

# ------------------------------------------------------------------
# Streaming (SSE)
# ------------------------------------------------------------------
Expand Down
136 changes: 136 additions & 0 deletions rllm-model-gateway/tests/unit/test_cumulative_token_mode_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Cumulative token mode over the in-process ``local_handler`` (e.g. Tinker).

The HTTP-worker path is covered by ``test_cumulative_token_mode.py``. These
tests cover the local-handler branch added so backends without a vLLM
``/v1/completions`` worker (Tinker runs in-process) get the same drift-free
prefix-extension: turn 2+ samples directly from pre-tokenized prompt IDs built
by the renderer, and the accumulator ingests the result.

Tests drive the proxy methods directly (no HTTP server) via ``asyncio.run`` to
avoid a pytest-asyncio dependency.
"""

import asyncio
import json

from rllm_model_gateway.proxy import ReverseProxy
from rllm_model_gateway.store.memory_store import MemoryTraceStore
from rllm_model_gateway.token_accumulator import TokenAccumulator


class _State:
weight_version = 0


class _Request:
"""Minimal stand-in: the cumulative-local paths only read request.state."""

state = _State()


def _make_proxy(local_handler):
"""ReverseProxy wired for the local cumulative path (no router/worker)."""
return ReverseProxy(
router=None,
store=MemoryTraceStore(),
sync_traces=True,
local_handler=local_handler,
cumulative_token_mode=True,
renderer=None, # accumulator state is set directly in these tests
)


def _completion_handler(record):
"""Fake Tinker token-path handler: echoes ``prompt`` as prompt_token_ids,
returns a fixed 2-token completion. Records the body it was called with."""

async def handler(body):
record.append(body)
return {
"id": "cmpl-x",
"object": "text_completion",
"choices": [
{
"index": 0,
"text": "next action",
"token_ids": [91, 92],
"finish_reason": "stop",
"logprobs": {"token_logprobs": [-0.1, -0.2]},
}
],
"prompt_token_ids": body["prompt"],
"usage": {"prompt_tokens": len(body["prompt"]), "completion_tokens": 2},
}

return handler


def test_cumulative_local_non_streaming_ingests_and_translates():
record = []
proxy = _make_proxy(_completion_handler(record))

acc = TokenAccumulator(renderer=None)
acc.ingest_turn([1, 2, 3], [4, 5]) # turn 1: prompt + completion already captured
bridged = [1, 2, 3, 4, 5, 6, 7] # what the renderer would produce for turn 2

completions_body = {"prompt": bridged, "add_special_tokens": False, "model": "q"}
resp = asyncio.run(
proxy._handle_cumulative_non_streaming(
_Request(),
{"messages": [{"role": "user", "content": "x"}]},
completions_body,
"sess1",
acc,
bridged,
)
)

# The local handler was called with the pre-tokenized prompt, not messages.
assert record and record[0]["prompt"] == bridged
assert "messages" not in record[0]

# Response is translated back to chat format for the agent.
body = json.loads(resp.body)
assert resp.status_code == 200
assert body["object"] == "chat.completion"
assert body["choices"][0]["message"] == {"role": "assistant", "content": "next action"}

# Accumulator advanced: prefix-extension holds (bridged starts with prev prompt+completion).
assert acc.turn_count == 2
assert acc.prev_prompt_ids == bridged
assert acc.prev_completion_ids == [91, 92]
assert acc.cumulative_ids[: len([1, 2, 3, 4, 5])] == [1, 2, 3, 4, 5]


def test_cumulative_local_streaming_emits_sse_and_ingests():
record = []
proxy = _make_proxy(_completion_handler(record))

acc = TokenAccumulator(renderer=None)
acc.ingest_turn([1, 2, 3], [4, 5])
bridged = [1, 2, 3, 4, 5, 6, 7]

resp = asyncio.run(
proxy._handle_cumulative_streaming_local(
_Request(),
{"messages": [{"role": "user", "content": "x"}]},
{"prompt": bridged},
"sess1",
acc,
bridged,
)
)

chunks = []

async def drain():
async for c in resp.body_iterator:
chunks.append(c if isinstance(c, str) else c.decode())

asyncio.run(drain())

assert any('"role": "assistant"' in c for c in chunks)
assert any('"next action"' in c for c in chunks)
assert chunks[-1].strip().endswith("[DONE]")
# Same ingest as non-streaming.
assert acc.turn_count == 2 and acc.prev_completion_ids == [91, 92]
Loading
Loading