Commit 5f7de5d

Revert "factor out a SequenceTable from PageTable/DecodeState, (#1256)"
This reverts commit cc11732.
1 parent cc11732 commit 5f7de5d

File tree

12 files changed (+834, -875 lines)

config/sampler/sample_nano.yaml

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ engine:
   max_pages: 256
   max_seqs: 8
   page_size: 8
-  max_seq_len: 256
   max_queued_tokens: 16
   max_seqs_in_prefill: 8
   max_prefill_size: 128
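
Presumably the engine no longer needs an explicit max_seq_len because, with the PageTable owning sequence state again, per-sequence capacity follows from the paging geometry. A back-of-envelope check under that assumption (max_pages_per_seq is hypothetical, not in this diff):

# Assumption, not stated in this commit: capacity per sequence = page_size * max_pages_per_seq
page_size, max_pages_per_seq = 8, 32
assert page_size * max_pages_per_seq == 256  # matches the removed max_seq_len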

src/levanter/inference/engine.py

Lines changed: 74 additions & 61 deletions
Large diffs are not rendered by default.

src/levanter/inference/jit_scheduler.py

Lines changed: 83 additions & 603 deletions
Large diffs are not rendered by default.

src/levanter/inference/page_table.py

Lines changed: 476 additions & 20 deletions
Large diffs are not rendered by default.
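
The page_table.py diff is collapsed here, but the call sites in the files below outline the surface this revert restores: PageTable again owns both the page pool and the per-sequence bookkeeping that had been split into SequenceTable. A minimal usage sketch, using only names visible elsewhere in this commit (sizes are illustrative; this is not the actual file contents):

from levanter.inference.page_table import PageTable

pt = PageTable.init(max_pages=256, max_seqs=8, page_size=8, max_pages_per_seq=32)
pt.num_pages        # pool geometry, consumed by KvPageCache.init below
pt.page_size
pt.page_indices     # NamedArray ["seq", "page"]: physical page per logical slot (-1 = unused)
pt.seq_lens         # NamedArray ["seq"]: tokens written per sequence
pt.used_mask        # NamedArray ["seq"]: slot-in-use flags
pt.page_ref_counts  # NamedArray ["page"]: sharing refcounts for cloned sequences
pt2 = pt.clone_pages_from(0, 1)  # fork sequence 0 into slot 1 (see the tests below)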

src/levanter/layers/attention.py

Lines changed: 8 additions & 3 deletions
@@ -39,7 +39,7 @@
 from jax.sharding import PartitionSpec
 from jaxtyping import PRNGKeyArray

-from ..inference.page_table import PageBatchInfo, PageTableSpec
+from ..inference.page_table import PageBatchInfo, PageTable
 from .kv_cache import KvPageCache
 from .normalization import LayerNormConfigBase
 from .rotary import RotaryEmbeddings, RotaryEmbeddingsConfig
@@ -1564,8 +1564,13 @@ def init(config: AttentionConfig, *, key) -> "Attention":

         return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs)

-    def empty_page_cache(self, spec: PageTableSpec, *, dtype) -> "KvPageCache":
-        return KvPageCache.init(spec, self.config.KVHeads, self.config.HeadSize, dtype=dtype)
+    def empty_page_cache(self, page_table: PageTable, *, dtype) -> "KvPageCache":
+        return KvPageCache.init(
+            page_table,
+            self.config.KVHeads,
+            self.config.HeadSize,
+            dtype=dtype,
+        )

     @named_call
     def __call__(
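
After the revert, Attention.empty_page_cache accepts the PageTable itself rather than a PageTableSpec. A minimal calling sketch, assuming an already-built Attention module attn (sizes illustrative):

import jax.numpy as jnp
from levanter.inference.page_table import PageTable

pt = PageTable.init(max_pages=64, max_seqs=4, page_size=8, max_pages_per_seq=16)
cache = attn.empty_page_cache(pt, dtype=jnp.bfloat16)  # zero-filled KvPageCache sized to pt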

src/levanter/layers/kv_cache.py

Lines changed: 9 additions & 21 deletions
@@ -3,6 +3,8 @@

 """Cache implementations for paged attention."""

+from __future__ import annotations
+
 import dataclasses
 import functools
 from typing import Generic, Iterable, Iterator, TypeVar, Self
@@ -16,7 +18,7 @@
 from haliax import Axis, NamedArray
 from haliax.jax_utils import named_call

-from levanter.inference.page_table import PageBatchInfo, PageTableSpec
+from levanter.inference.page_table import PageBatchInfo, PageTable


 class PageCache(eqx.Module):
@@ -33,20 +35,12 @@ class KvPageCache(PageCache):
     kv_pages: NamedArray  # [Page, Slot, 2 * KVHeads, Embed]

     @staticmethod
-    def init(spec: PageTableSpec, kv_heads: Axis, head_size: Axis, dtype=jnp.float32) -> "KvPageCache":
-        """
-        Initialize a KvPageCache with the given page table specification and dimensions.
-
-        Args:
-            spec: The layout specification for KV pages.
-            kv_heads: Axis for key/value heads.
-            head_size: Axis for head size.
-            dtype: Data type for the cache.
-        """
+    def init(page_table: PageTable, kv_heads: Axis, head_size: Axis, dtype=jnp.float32) -> "KvPageCache":
+        """Allocate an empty KV cache matching *page_table* and head axes."""
         kv_pages = hax.zeros(
             {
-                "page": spec.num_pages,
-                "slot": spec.page_size,
+                "page": page_table.num_pages,
+                "slot": page_table.page_size,
                 "kv_head": 2 * kv_heads.size,
                 head_size.name: head_size.size,
             },
@@ -64,10 +58,7 @@ def update(
         """Append keys and values to the cache based on *batch_info*."""
         page_size = self.kv_pages.array.shape[1]

-        assert page_size == batch_info.page_size, (
-            f"Page size mismatch: {page_size} != {batch_info.page_size}. "
-            "Ensure that the page size in batch_info matches the kv_pages."
-        )
+        _ = page_size  # keeps JIT-shapes available; retained for future assertions.

         K = jnp.asarray(batch_info.num_new_tokens, jnp.int32)
         t_pages, t_slots = batch_info.pages_and_slots()  # [T] int32 (first K valid)
@@ -84,10 +75,7 @@ def update(
         return dataclasses.replace(self, kv_pages=updated)

     def copy_page(self, src_page: int, dst_page: int) -> "KvPageCache":
-        """Copy the entire contents of page ``src_page`` into ``dst_page``.
-
-        This is used when creating clones that should have an identical last partial page, but mapped to a fresh page.
-        """
+        """Copy an entire page of cached keys/values."""
         new_k = self.kv_pages.at["page", dst_page].set(self.kv_pages["page", src_page])
         return dataclasses.replace(self, kv_pages=new_k)
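
The restored init signature ties the cache layout directly to the PageTable's geometry, with keys and values interleaved along the kv_head axis (hence the factor of 2). A small sketch of the resulting shape and of copy_page, mirroring the test at the bottom of this commit:

import jax.numpy as jnp
import haliax as hax
from levanter.inference.page_table import PageTable
from levanter.layers.kv_cache import KvPageCache

pt = PageTable.init(max_pages=3, max_seqs=1, page_size=2, max_pages_per_seq=1)
kv = KvPageCache.init(pt, kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32)

# Axes are [page, slot, 2 * kv_heads, head_size] == (3, 2, 4, 3)
assert kv.kv_pages.array.shape == (pt.num_pages, pt.page_size, 4, 3)

kv2 = kv.copy_page(src_page=1, dst_page=2)  # wholesale duplicate of one page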

src/levanter/models/llama.py

Lines changed: 9 additions & 8 deletions
@@ -6,6 +6,7 @@
 from typing import Callable, Dict, Optional, Type, Union

 import equinox as eqx
+import haliax.debug
 import jax
 import jax.random as jrandom
 from jaxtyping import PRNGKeyArray
@@ -18,7 +19,7 @@
 from haliax.state_dict import ModuleWithStateDictSerialization

 from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig
-from levanter.inference.page_table import PageBatchInfo, PageTableSpec
+from levanter.inference.page_table import PageBatchInfo, PageTable
 from levanter.layers import LayerNormConfigBase, RmsNormConfig
 from levanter.layers.attention import Attention, AttentionBackend, AttentionConfig, AttentionMask
 from levanter.layers.kv_cache import KvPageCache, ListCache
@@ -364,12 +365,12 @@ def decode(
         output = residual + mlp_output
         return output, kv_cache

-    def initial_cache(self, spec: PageTableSpec, *, dtype) -> KvPageCache:
+    def initial_cache(self, page_table: PageTable, *, dtype) -> KvPageCache:
         """
         Creates an empty page cache for this layer. Note that in order to create a decoder state, you
         need to couple the KvPageCache to the PageTable's state with a BatchInfo object.
         """
-        return self.self_attn.empty_page_cache(spec, dtype=dtype)
+        return self.self_attn.empty_page_cache(page_table, dtype=dtype)


 class LlamaTransformer(eqx.Module):
@@ -447,14 +448,14 @@ def decode(

         return x, ListCache(updated_caches)

-    def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
+    def initial_cache(self, page_table: PageTable, *, dtype) -> ListCache[KvPageCache]:
         """
         Creates an empty page cache for this transformer. Note that in order to create a decoder state, you
         need to couple the KvPageCache to the PageTable's state with a BatchInfo object.
         """
         # sadly this is too cute/smart for XLA to handle aliasing correctly
-        # return self.layers.vmap_via(LlamaDecoderLayer.initial_cache)(spec, dtype=dtype)
-        caches = [layer.initial_cache(spec, dtype=dtype) for layer in self.layers.unstacked()]
+        # return self.layers.vmap_via(LlamaDecoderLayer.initial_cache)(page_table, dtype=dtype)
+        caches = [layer.initial_cache(page_table, dtype=dtype) for layer in self.layers.unstacked()]
         return ListCache(caches)


@@ -604,12 +605,12 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
     def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
         return {"transformer": "model", "embeddings": None}

-    def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
+    def initial_cache(self, page_table: PageTable, *, dtype) -> ListCache[KvPageCache]:
         """
         Creates an initial cache for this model. Note that in order to create a decoder state, you
         need to couple the KvPageCache to the PageTable's state with a BatchInfo object.
         """
-        return hax.auto_sharded(self.transformer.initial_cache(spec, dtype=dtype))
+        return hax.auto_sharded(self.transformer.initial_cache(page_table, dtype=dtype))

     @named_call
     def decode(
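
The initial_cache chain (model -> transformer -> layer -> attention) now threads the PageTable all the way down. A usage sketch, assuming a constructed Llama model named model (sizes illustrative):

import jax.numpy as jnp
from levanter.inference.page_table import PageTable

pt = PageTable.init(max_pages=256, max_seqs=8, page_size=8, max_pages_per_seq=32)
caches = model.initial_cache(pt, dtype=jnp.bfloat16)  # ListCache[KvPageCache], one entry per layer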

tests/inference/test_clone_pages.py

Lines changed: 62 additions & 40 deletions
@@ -8,81 +8,103 @@

 import haliax as hax

-from levanter.inference.jit_scheduler import SequenceTable
 from levanter.inference.page_table import PageTable
 from levanter.layers.kv_cache import KvPageCache


-def _make_allocator(max_pages=8, max_seqs=2, page_size=4, pages_per_seq=3):
-    pt = PageTable.init(max_pages, max_seqs, page_size, pages_per_seq)
-    sequences = SequenceTable.init(max_seqs, pages_per_seq, page_size)
-    return sequences, pt
+def _make_table(max_pages=8, max_seqs=2, page_size=4, pages_per_seq=3):
+    return PageTable.init(max_pages, max_seqs, page_size, pages_per_seq)


 def test_clone_pages_from_partial_last_page_allocates_fresh_page():
-    sequences, pt = _make_allocator(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)
+    pt = _make_table(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)

+    # Parent uses two pages with a partial last page: length = 5 (pages 0 and 1 used)
     parent = 0
     child = 1

-    sequences, parent_id = sequences.reserve_slot(parent)
-    sequences, child_id = sequences.reserve_slot(child)
+    pt = dataclasses.replace(
+        pt,
+        page_indices=pt.page_indices.at["seq", parent, "page", 0]
+        .set(jnp.array(2, dtype=jnp.int32))
+        .at["seq", parent, "page", 1]
+        .set(jnp.array(3, dtype=jnp.int32))
+        .at["seq", parent, "page", 2]
+        .set(jnp.array(-1, dtype=jnp.int32)),
+        seq_lens=pt.seq_lens.at["seq", parent].set(jnp.array(5, dtype=jnp.int32)),
+    )
+
+    # Clone
+    new_pt = pt.clone_pages_from(parent, child)

-    seq_lens = sequences.seq_lens.at["seq", parent_id].set(5)
-    page_indices = sequences.page_indices.at["seq", parent_id, "page", 0].set(2)
-    page_indices = page_indices.at["seq", parent_id, "page", 1].set(3)
-    sequences = dataclasses.replace(sequences, seq_lens=seq_lens, page_indices=page_indices)
-    ref_counts = pt.page_ref_counts.at["page", 2].set(1)
-    ref_counts = ref_counts.at["page", 3].set(1)
-    pt = PageTable(ref_counts, pt.page_size, pt._max_seqs, pt._pages_per_seq)
+    # Fully used pages (all but last partial) should be shared: page 0 mapping identical
+    assert int(new_pt.page_indices["seq", child, "page", 0].scalar()) == 2

-    sequences, new_pt = sequences.clone_pages_from(pt, parent_id, child_id)
+    # Last page must be a fresh allocation, different from parent's last page (3)
+    assert int(new_pt.page_indices["seq", child, "page", 1].scalar()) != 3

-    assert int(sequences.page_indices["seq", child_id, "page", 0].scalar()) == 2
-    assert int(sequences.page_indices["seq", child_id, "page", 1].scalar()) != 3
-    assert int(new_pt.page_ref_counts["page", 2].scalar()) == 2
-    assert int(new_pt.page_ref_counts["page", 3].scalar()) == 1
-    assert int(sequences.seq_lens["seq", child_id].scalar()) == 5
+    # Refcounts: +1 for shared full page (page 2) and +1 for newly allocated page; parent's partial last page unchanged
+    ref_shared = int(new_pt.page_ref_counts["page", 2].scalar())
+    ref_parent_last = int(new_pt.page_ref_counts["page", 3].scalar())
+    assert ref_shared == 1
+    assert ref_parent_last == 0
+
+    # Lengths equal (no rounding)
+    assert int(new_pt.seq_lens["seq", child].scalar()) == 5


 def test_clone_pages_from_boundary_shares_last_page():
-    sequences, pt = _make_allocator(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)
+    pt = _make_table(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)

+    # Parent uses exactly 2 full pages: length = 8 (pages 4 and 5 used)
     parent = 0
     child = 1

-    sequences, parent_id = sequences.reserve_slot(parent)
-    sequences, child_id = sequences.reserve_slot(child)
+    pt = dataclasses.replace(
+        pt,
+        page_indices=pt.page_indices.at["seq", parent, "page", 0]
+        .set(jnp.array(4, dtype=jnp.int32))
+        .at["seq", parent, "page", 1]
+        .set(jnp.array(5, dtype=jnp.int32))
+        .at["seq", parent, "page", 2]
+        .set(jnp.array(-1, dtype=jnp.int32)),
+        seq_lens=pt.seq_lens.at["seq", parent].set(jnp.array(8, dtype=jnp.int32)),
+    )

-    seq_lens = sequences.seq_lens.at["seq", parent_id].set(8)
-    page_indices = sequences.page_indices.at["seq", parent_id, "page", 0].set(4)
-    page_indices = page_indices.at["seq", parent_id, "page", 1].set(5)
-    sequences = dataclasses.replace(sequences, seq_lens=seq_lens, page_indices=page_indices)
-    ref_counts = pt.page_ref_counts.at["page", 4].set(1)
-    ref_counts = ref_counts.at["page", 5].set(1)
-    pt = PageTable(ref_counts, pt.page_size, pt._max_seqs, pt._pages_per_seq)
+    new_pt = pt.clone_pages_from(parent, child)

-    sequences, new_pt = sequences.clone_pages_from(pt, parent_id, child_id)
+    # Child should share both pages
+    assert int(new_pt.page_indices["seq", child, "page", 0].scalar()) == 4
+    assert int(new_pt.page_indices["seq", child, "page", 1].scalar()) == 5

-    assert int(sequences.page_indices["seq", child_id, "page", 0].scalar()) == 4
-    assert int(sequences.page_indices["seq", child_id, "page", 1].scalar()) == 5
-    assert int(new_pt.page_ref_counts["page", 4].scalar()) == 2
-    assert int(new_pt.page_ref_counts["page", 5].scalar()) == 2
-    assert int(sequences.seq_lens["seq", child_id].scalar()) == 8
+    # Refcounts incremented for both pages
+    assert int(new_pt.page_ref_counts["page", 4].scalar()) == 1
+    assert int(new_pt.page_ref_counts["page", 5].scalar()) == 1
+
+    # Lengths equal
+    assert int(new_pt.seq_lens["seq", child].scalar()) == 8


 def test_kv_cache_copy_page():
-    sequences, pt = _make_allocator(max_pages=3, max_seqs=1, page_size=2, pages_per_seq=1)
-    kv = KvPageCache.init(pt.spec(), kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32)
+    # Minimal KvPageCache with 3 pages and small dims
+    from levanter.inference.page_table import PageTable as _PT
+
+    pt = _PT.init(max_pages=3, max_seqs=1, page_size=2, max_pages_per_seq=1)
+    kv = KvPageCache.init(pt, kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32)

+    # Write identifiable values into page 1 for both K and V
     src_page = 1
     dst_page = 2
     k_pattern = hax.full_like(kv.kv_pages["page", src_page, "kv_head", 0::2], 7.0)
     v_pattern = hax.full_like(kv.kv_pages["page", src_page, "kv_head", 1::2], 3.0)
+    # need to interleave k and v along kv_heads
     kv_pattern = hax.stack("inter", [k_pattern, v_pattern]).rearrange(
         "{inter kv_head} -> ... (kv_head: kv_head inter)"
     )
-    kv = dataclasses.replace(kv, kv_pages=kv.kv_pages.at["page", src_page].set(kv_pattern))
+    kv = dataclasses.replace(
+        kv,
+        kv_pages=kv.kv_pages.at["page", src_page].set(kv_pattern),
+    )

     kv2 = kv.copy_page(src_page, dst_page)
     np.testing.assert_allclose(
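
Taken together, these tests pin down the copy-on-write fork: fully used pages are shared (refcount bumped) while a partial last page gets a fresh allocation, whose contents must then be copied at the cache level. A hedged sketch of how the two pieces compose (the engine does this with real bookkeeping; pt and kv as in the tests above):

new_pt = pt.clone_pages_from(0, 1)  # share full pages; allocate a fresh last page if partial

src = int(pt.page_indices["seq", 0, "page", 1].scalar())      # parent's last page
dst = int(new_pt.page_indices["seq", 1, "page", 1].scalar())  # child's last page
if src != dst:  # parent's last page was partial, so the child got a fresh one
    kv = kv.copy_page(src, dst)  # make the child's partial page byte-identical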

tests/inference/test_engine.py

Lines changed: 8 additions & 10 deletions
@@ -12,15 +12,14 @@

 from levanter.inference.engine import InferenceEngine, Request, InferenceEngineConfig
 from levanter.inference.jit_scheduler import SeqDecodingParams
-from levanter.inference.page_table import PageTableSpec
-from levanter.inference.utils import INVALID
+from levanter.inference.page_table import PageTable
 from levanter.layers.kv_cache import KvPageCache


 class DummyModel(eqx.Module):
     """Minimal model stub to drive GenerationService for tests.

-    - `initial_cache` returns an empty KvPageCache sized to the page-table spec.
+    - `initial_cache` returns an empty KvPageCache sized to the PageTable.
     - `decode` returns constant logits that strongly prefer token `EOS`.
     """

@@ -31,11 +30,11 @@ def __init__(self, vocab_size: int, eos_id: int = 3):
         self.Vocab = Axis("vocab", vocab_size)
         self.eos = eos_id

-    def initial_cache(self, spec: PageTableSpec, *, dtype):
+    def initial_cache(self, page_table: PageTable, *, dtype):
         # Use trivial cache dimensions; the cache is unused by this dummy model
         kv_heads = Axis("kv_head", 1)
         head_size = Axis("embed", 1)
-        return KvPageCache.init(spec, kv_heads, head_size, dtype=dtype)
+        return KvPageCache.init(page_table, kv_heads, head_size, dtype=dtype)

     def decode(self, input_ids, kv_cache, batch_info, pos_ids):
         # Produce logits that prefer `eos` for every sampled position
@@ -94,14 +93,13 @@ def test_release_on_finish_and_reuse_slots(caplog: pytest.LogCaptureFixture):
     assert result.tokens[1] == [3]
     assert result.total_generated == 2  # one new token per prompt

-    # Finished sequences are auto-released; no active seqs remain
-    sequences = svc.gen_state.decode_state.sequences
+    # Finished sequences are auto-released; PageTable should have no active seqs
     pt = svc.gen_state.decode_state.page_table
     # All slots should be marked unused and lengths zeroed
-    seq_lens = jax.device_get(sequences.seq_lens.array)
-    used_mask = jax.device_get(sequences.used_mask.array)
+    seq_lens = jax.device_get(pt.seq_lens.array)
+    used_mask = jax.device_get(pt.used_mask.array)
     assert (used_mask == 0).all()
-    assert ((seq_lens == 0) | (seq_lens == INVALID)).all()
+    assert (seq_lens == 0).all()
     # No pages should be held
     ref_counts = jax.device_get(pt.page_ref_counts.array)
     assert int(ref_counts.sum()) == 0
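
The assertions above encode an invariant worth naming: once every sequence finishes and is auto-released, the PageTable is fully quiescent. A sketch of that check as a helper, assuming only the fields used in the test:

import jax

def page_table_quiescent(pt) -> bool:
    # No slot in use, no tokens recorded, no page held by any sequence.
    used = jax.device_get(pt.used_mask.array)
    lens = jax.device_get(pt.seq_lens.array)
    refs = jax.device_get(pt.page_ref_counts.array)
    return bool((used == 0).all() and (lens == 0).all() and int(refs.sum()) == 0)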
