Skip to content

Commit 43751a7

Browse files
authored
Add explicit reset points to page cache, decode state, etc. (#1291)
Instead of trying to unwind the allocated sequences, just reset to the initial state explicitly.
1 parent 6cd783c commit 43751a7

File tree

5 files changed

+70
-38
lines changed

5 files changed

+70
-38
lines changed

src/levanter/inference/engine.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ class GenState(eqx.Module):
272272
cache: PageCache
273273
decode_state: DecodeState
274274

275+
def reset(self):
276+
return GenState(
277+
cache=self.cache.reset(),
278+
decode_state=self.decode_state.reset(),
279+
)
280+
275281
def clone_sequence(
276282
self, parent_local_id: int, child_local_id: int | None = None, seq_params: SeqDecodingParams | None = None
277283
) -> tuple["GenState", int]:
@@ -797,22 +803,6 @@ def __init__(
797803
# Results by request id -> choice -> DecodeResult
798804
self.results: dict[int, dict[int, DecodeResult]] = {}
799805

800-
def _verify_free_slot_view(self, *, context: str) -> None:
801-
"""Ensure host free-list matches the device page-table used mask."""
802-
803-
used_mask = np.asarray(jax.device_get(self.gen_state.decode_state.sequences.used_mask.array)).astype(bool)
804-
free_set = set(self.free_slots)
805-
806-
for slot_id, is_used in enumerate(used_mask):
807-
if is_used and slot_id in free_set:
808-
raise RuntimeError(
809-
f"[free slot invariant] slot {slot_id} marked used but present in free list during {context}"
810-
)
811-
if not is_used and slot_id not in free_set:
812-
raise RuntimeError(
813-
f"[free slot invariant] slot {slot_id} free in page table but missing from free list during {context}"
814-
)
815-
816806
@classmethod
817807
def from_model_with_config(
818808
cls,
@@ -853,26 +843,12 @@ def reset(self) -> None:
853843
854844
Keeps the KV cache memory allocated. Reuses current `PageTable` object with pages freed.
855845
"""
856-
decode_state = self.gen_state.decode_state
857-
page_table = decode_state.page_table
858-
sequences = decode_state.sequences
859-
for slot_id in range(page_table.max_seqs):
860-
sequences, page_table = sequences.free_pages(page_table, slot_id)
861-
862-
new_decode_state = DecodeState.init(
863-
page_table,
864-
max_stop_seqs=self.config.max_stop_seqs,
865-
max_stop_tokens=self.config.max_stop_tokens,
866-
max_queued_tokens=self.config.max_queued_tokens,
867-
)
868-
self.gen_state = dataclasses.replace(self.gen_state, decode_state=new_decode_state)
869-
self.free_slots = list(range(int(page_table.max_seqs)))
846+
self.gen_state = self.gen_state.reset()
847+
self.free_slots = list(range(int(self.gen_state.decode_state.max_seqs)))
870848
self.local_map.clear()
871849
self.sequences.clear()
872850
self.results = {}
873851

874-
self._verify_free_slot_view(context="reset")
875-
876852
def _prefill_batch(self, batch: Sequence[Request]) -> _DecodeOutputs | None:
877853
"""Admit a batch from the head of the queue that fits in free slots/pages.
878854

src/levanter/inference/jit_scheduler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ class DecodeState(eqx.Module):
564564
# Page table for KV page allocation and per-sequence lengths/usage
565565
page_table: PageTable
566566

567+
pad_token_id: int
568+
567569
# Per sequence sampling parameters
568570
max_num_tokens: ht.i32[NamedArray, "seq"]
569571
"""
@@ -583,6 +585,15 @@ class DecodeState(eqx.Module):
583585
# Cached finished flags per sequence (updated when tokens are enqueued)
584586
finished: ht.bool_[NamedArray, "seq"]
585587

588+
def reset(self):
589+
return DecodeState.init(
590+
page_table=self.page_table.reset(),
591+
pad_token_id=self.pad_token_id,
592+
max_stop_seqs=self.stop_tokens.shape["stop_seq"] if self.stop_tokens is not None else 0,
593+
max_stop_tokens=self.stop_tokens.shape["position"] if self.stop_tokens is not None else 0,
594+
max_queued_tokens=self.tqueue.max_queued_tokens,
595+
)
596+
586597
@staticmethod
587598
def init(
588599
page_table: PageTable,
@@ -605,6 +616,7 @@ def init(
605616
sequences=sequence_table,
606617
page_size=page_size,
607618
page_table=page_table,
619+
pad_token_id=pad_token_id,
608620
tokens=hax.full({"seq": max_seqs, "position": max_seq_len}, pad_token_id, dtype=jnp.int32),
609621
logprobs=hax.full({"seq": max_seqs, "position": max_seq_len}, jnp.nan, dtype=jnp.float32),
610622
max_num_tokens=hax.full({"seq": max_seqs}, 0, dtype=jnp.int32),

src/levanter/inference/page_table.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import dataclasses
55

66
import equinox as eqx
7-
import jax.numpy as jnp
87
import haliax as hax
98
import haliax.haxtyping as ht
9+
import jax.numpy as jnp
1010
from haliax import NamedArray
1111

1212
from levanter.inference.utils import INVALID, is_valid
@@ -35,6 +35,10 @@ def init(max_pages: int, max_seqs: int, page_size: int, max_pages_per_seq: int)
3535
ref_counts = hax.full({"page": max_pages}, 0, dtype=jnp.int32)
3636
return PageTable(ref_counts, page_size, max_seqs, max_pages_per_seq)
3737

38+
def reset(self) -> "PageTable":
39+
ref_counts = hax.full_like(self.page_ref_counts, 0)
40+
return PageTable(ref_counts, self.page_size, self._max_seqs, self._pages_per_seq)
41+
3842
@property
3943
def num_pages(self) -> int:
4044
return self.page_ref_counts.axis_size("page")

src/levanter/layers/kv_cache.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@
55

66
import dataclasses
77
import functools
8-
from typing import Generic, Iterable, Iterator, TypeVar, Self
8+
from typing import Generic, Iterable, Iterator, Self, TypeVar
99

1010
import equinox as eqx
11+
import haliax as hax
1112
import jax
1213
import jax.numpy as jnp
13-
from jax import lax
14-
15-
import haliax as hax
1614
from haliax import Axis, NamedArray
1715
from haliax.jax_utils import named_call
16+
from jax import lax
1817

1918
from levanter.inference.page_table import PageBatchInfo, PageTableSpec
2019

@@ -26,6 +25,10 @@ def copy_page(self, src_page: int, dst_page: int) -> Self:
2625
"""Return a copy of this cache with ``src_page`` cloned into ``dst_page``."""
2726
raise NotImplementedError
2827

28+
def reset(self) -> Self:
29+
"""Return a reset version of this cache."""
30+
raise NotImplementedError
31+
2932

3033
class KvPageCache(PageCache):
3134
"""Concrete KV cache storing interleaved key/value pages for paged attention."""
@@ -54,6 +57,11 @@ def init(spec: PageTableSpec, kv_heads: Axis, head_size: Axis, dtype=jnp.float32
5457
)
5558
return KvPageCache(kv_pages)
5659

60+
def reset(self) -> "KvPageCache":
61+
"""Return a reset version of this cache."""
62+
reset_pages = jnp.zeros_like(self.kv_pages.array)
63+
return dataclasses.replace(self, kv_pages=NamedArray(reset_pages, self.kv_pages.axes))
64+
5765
@named_call
5866
def update(
5967
self,
@@ -103,6 +111,9 @@ class ListCache(PageCache, Generic[PageCacheT]):
103111
def __post_init__(self):
104112
object.__setattr__(self, "caches", tuple(self.caches))
105113

114+
def reset(self) -> "ListCache[PageCacheT]":
115+
return ListCache(tuple(cache.reset() for cache in self.caches))
116+
106117
@staticmethod
107118
def from_iterable(caches: Iterable[PageCacheT]) -> "ListCache[PageCacheT]":
108119
return ListCache(tuple(caches))

tests/inference/test_inference_server.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def trainer_config():
3838
def baby_llama_config():
3939
return InferenceServerConfig(
4040
service=InferenceEngineConfig(
41-
max_seq_len=16,
41+
max_seq_len=32,
4242
max_seqs=2,
4343
page_size=4,
4444
max_queued_tokens=32,
@@ -424,6 +424,35 @@ def test_logprobs_deterministic_behavior(test_client):
424424
print("Deterministic logprobs test passed!")
425425

426426

427+
def test_many_requests_threaded(test_client):
428+
executor = ThreadPoolExecutor(max_workers=8)
429+
client, server = test_client
430+
futures = []
431+
num_requests = 20
432+
for i in range(num_requests):
433+
futures.append(
434+
executor.submit(
435+
client.post,
436+
"/v1/completions",
437+
json={
438+
"model": "timinar/baby-llama-58m",
439+
"prompt": "The quick brown fox",
440+
"max_tokens": 16,
441+
"temperature": 0.0,
442+
"seed": i,
443+
},
444+
)
445+
)
446+
447+
for i, future in enumerate(futures):
448+
response = future.result()
449+
assert response.status_code == 200
450+
completion = Completion.model_validate(response.json())
451+
choice = completion.choices[0]
452+
assert choice.text
453+
print(f"Request {i} generated text: '{choice.text}'")
454+
455+
427456
def test_reload_with_zeros_clears_outputs(test_client):
428457
"""Test that reloading with a zeroed-out model properly clears outputs."""
429458
client, server = test_client

0 commit comments

Comments (0)