1 change: 1 addition & 0 deletions config/sampler/sample_nano.yaml
@@ -18,6 +18,7 @@ engine:
max_pages: 256
max_seqs: 8
page_size: 8
max_seq_len: 256
max_queued_tokens: 16
max_seqs_in_prefill: 8
max_prefill_size: 128
135 changes: 61 additions & 74 deletions src/levanter/inference/engine.py

Large diffs are not rendered by default.

686 changes: 603 additions & 83 deletions src/levanter/inference/jit_scheduler.py

Large diffs are not rendered by default.

496 changes: 20 additions & 476 deletions src/levanter/inference/page_table.py

Large diffs are not rendered by default.

11 changes: 3 additions & 8 deletions src/levanter/layers/attention.py
@@ -39,7 +39,7 @@
from jax.sharding import PartitionSpec
from jaxtyping import PRNGKeyArray

from ..inference.page_table import PageBatchInfo, PageTable
from ..inference.page_table import PageBatchInfo, PageTableSpec
from .kv_cache import KvPageCache
from .normalization import LayerNormConfigBase
from .rotary import RotaryEmbeddings, RotaryEmbeddingsConfig
@@ -1564,13 +1564,8 @@ def init(config: AttentionConfig, *, key) -> "Attention":

return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs)

def empty_page_cache(self, page_table: PageTable, *, dtype) -> "KvPageCache":
return KvPageCache.init(
page_table,
self.config.KVHeads,
self.config.HeadSize,
dtype=dtype,
)
def empty_page_cache(self, spec: PageTableSpec, *, dtype) -> "KvPageCache":
return KvPageCache.init(spec, self.config.KVHeads, self.config.HeadSize, dtype=dtype)

@named_call
def __call__(
30 changes: 21 additions & 9 deletions src/levanter/layers/kv_cache.py
@@ -3,8 +3,6 @@

"""Cache implementations for paged attention."""

from __future__ import annotations

import dataclasses
import functools
from typing import Generic, Iterable, Iterator, TypeVar, Self
@@ -18,7 +16,7 @@
from haliax import Axis, NamedArray
from haliax.jax_utils import named_call

from levanter.inference.page_table import PageBatchInfo, PageTable
from levanter.inference.page_table import PageBatchInfo, PageTableSpec


class PageCache(eqx.Module):
@@ -35,12 +33,20 @@ class KvPageCache(PageCache):
kv_pages: NamedArray # [Page, Slot, 2 * KVHeads, Embed]

@staticmethod
def init(page_table: PageTable, kv_heads: Axis, head_size: Axis, dtype=jnp.float32) -> "KvPageCache":
"""Allocate an empty KV cache matching *page_table* and head axes."""
def init(spec: PageTableSpec, kv_heads: Axis, head_size: Axis, dtype=jnp.float32) -> "KvPageCache":
"""
Initialize a KvPageCache with the given page table specification and dimensions.

Args:
spec: The layout specification for KV pages.
kv_heads: Axis for key/value heads.
head_size: Axis for head size.
dtype: Data type for the cache.
"""
kv_pages = hax.zeros(
{
"page": page_table.num_pages,
"slot": page_table.page_size,
"page": spec.num_pages,
"slot": spec.page_size,
"kv_head": 2 * kv_heads.size,
head_size.name: head_size.size,
},
@@ -58,7 +64,10 @@ def update(
"""Append keys and values to the cache based on *batch_info*."""
page_size = self.kv_pages.array.shape[1]

_ = page_size # keeps JIT-shapes available; retained for future assertions.
assert page_size == batch_info.page_size, (
f"Page size mismatch: {page_size} != {batch_info.page_size}. "
"Ensure that the page size in batch_info matches the kv_pages."
)

K = jnp.asarray(batch_info.num_new_tokens, jnp.int32)
t_pages, t_slots = batch_info.pages_and_slots() # [T] int32 (first K valid)
@@ -75,7 +84,10 @@
return dataclasses.replace(self, kv_pages=updated)

def copy_page(self, src_page: int, dst_page: int) -> "KvPageCache":
"""Copy an entire page of cached keys/values."""
"""Copy the entire contents of page ``src_page`` into ``dst_page``.

This is used when creating clones that should have an identical last partial page, but mapped to a fresh page.
"""
new_k = self.kv_pages.at["page", dst_page].set(self.kv_pages["page", src_page])
return dataclasses.replace(self, kv_pages=new_k)

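Note (illustration, not part of the diff): the kv_cache.py change above is exercised most directly by test_kv_cache_copy_page further down. A minimal sketch of the new spec-based flow, assuming PageTable.init keeps the positional (max_pages, max_seqs, page_size, pages_per_seq) signature used in the updated tests:

import jax.numpy as jnp
import haliax as hax

from levanter.inference.page_table import PageTable
from levanter.layers.kv_cache import KvPageCache

# Build an allocator, then derive its layout spec (num_pages / page_size only).
pt = PageTable.init(4, 1, 2, 2)  # max_pages, max_seqs, page_size, pages_per_seq
kv = KvPageCache.init(
    pt.spec(),  # PageTableSpec: page layout only, no per-sequence state
    kv_heads=hax.Axis("kv_head", 2),
    head_size=hax.Axis("head", 8),
    dtype=jnp.bfloat16,
)
# When cloning a sequence, the clone's last (partial) page is a fresh page;
# copy_page duplicates the parent's partial page into it.
kv = kv.copy_page(0, 1)

Because init reads only spec.num_pages and spec.page_size, the cache can be allocated before any sequence slots are reserved.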
17 changes: 8 additions & 9 deletions src/levanter/models/llama.py
@@ -6,7 +6,6 @@
from typing import Callable, Dict, Optional, Type, Union

import equinox as eqx
import haliax.debug
import jax
import jax.random as jrandom
from jaxtyping import PRNGKeyArray
@@ -19,7 +18,7 @@
from haliax.state_dict import ModuleWithStateDictSerialization

from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig
from levanter.inference.page_table import PageBatchInfo, PageTable
from levanter.inference.page_table import PageBatchInfo, PageTableSpec
from levanter.layers import LayerNormConfigBase, RmsNormConfig
from levanter.layers.attention import Attention, AttentionBackend, AttentionConfig, AttentionMask
from levanter.layers.kv_cache import KvPageCache, ListCache
@@ -365,12 +364,12 @@ def decode(
output = residual + mlp_output
return output, kv_cache

def initial_cache(self, page_table: PageTable, *, dtype) -> KvPageCache:
def initial_cache(self, spec: PageTableSpec, *, dtype) -> KvPageCache:
"""
Creates an empty page cache for this layer. Note that in order to create a decoder state, you
need to couple the KvPageCache to the PageTable's state with a BatchInfo object.
"""
return self.self_attn.empty_page_cache(page_table, dtype=dtype)
return self.self_attn.empty_page_cache(spec, dtype=dtype)


class LlamaTransformer(eqx.Module):
@@ -448,14 +447,14 @@ def decode(

return x, ListCache(updated_caches)

def initial_cache(self, page_table: PageTable, *, dtype) -> ListCache[KvPageCache]:
def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
"""
Creates an empty page cache for this transformer. Note that in order to create a decoder state, you
need to couple the KvPageCache to the PageTable's state with a BatchInfo object.
"""
# sadly this is too cute/smart for XLA to handle aliasing correctly
# return self.layers.vmap_via(LlamaDecoderLayer.initial_cache)(page_table, dtype=dtype)
caches = [layer.initial_cache(page_table, dtype=dtype) for layer in self.layers.unstacked()]
# return self.layers.vmap_via(LlamaDecoderLayer.initial_cache)(spec, dtype=dtype)
caches = [layer.initial_cache(spec, dtype=dtype) for layer in self.layers.unstacked()]
return ListCache(caches)


@@ -605,12 +604,12 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
return {"transformer": "model", "embeddings": None}

def initial_cache(self, page_table: PageTable, *, dtype) -> ListCache[KvPageCache]:
def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
"""
Creates an initial cache for this model. Note that in order to create a decoder state, you
need to couple the KvPageCache to the PageTable's state with a BatchInfo object.
"""
return hax.auto_sharded(self.transformer.initial_cache(page_table, dtype=dtype))
return hax.auto_sharded(self.transformer.initial_cache(spec, dtype=dtype))

@named_call
def decode(
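Note (illustration, not part of the diff): a hedged sketch of how a caller wires the renamed API together. The initial_cache/decode signatures come from this file and from the DummyModel stub in tests/inference/test_engine.py; the batch_info/pos_ids construction and the exact return value of the top-level decode are assumptions here.

import jax.numpy as jnp

def decode_step(model, page_table, input_ids, batch_info, pos_ids):
    # Allocate an empty cache from the table's layout spec (per the diff above),
    # then run one decode step. Assumes decode returns (activations, updated cache),
    # mirroring LlamaTransformer.decode; batch_info is the PageBatchInfo produced
    # by the scheduler and is treated as given.
    cache = model.initial_cache(page_table.spec(), dtype=jnp.bfloat16)
    out, cache = model.decode(input_ids, cache, batch_info, pos_ids)
    return out, cache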
102 changes: 40 additions & 62 deletions tests/inference/test_clone_pages.py
@@ -8,103 +8,81 @@

import haliax as hax

from levanter.inference.jit_scheduler import SequenceTable
from levanter.inference.page_table import PageTable
from levanter.layers.kv_cache import KvPageCache


def _make_table(max_pages=8, max_seqs=2, page_size=4, pages_per_seq=3):
return PageTable.init(max_pages, max_seqs, page_size, pages_per_seq)
def _make_allocator(max_pages=8, max_seqs=2, page_size=4, pages_per_seq=3):
pt = PageTable.init(max_pages, max_seqs, page_size, pages_per_seq)
sequences = SequenceTable.init(max_seqs, pages_per_seq, page_size)
return sequences, pt


def test_clone_pages_from_partial_last_page_allocates_fresh_page():
pt = _make_table(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)
sequences, pt = _make_allocator(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)

# Parent uses two pages with a partial last page: length = 5 (pages 0 and 1 used)
parent = 0
child = 1

pt = dataclasses.replace(
pt,
page_indices=pt.page_indices.at["seq", parent, "page", 0]
.set(jnp.array(2, dtype=jnp.int32))
.at["seq", parent, "page", 1]
.set(jnp.array(3, dtype=jnp.int32))
.at["seq", parent, "page", 2]
.set(jnp.array(-1, dtype=jnp.int32)),
seq_lens=pt.seq_lens.at["seq", parent].set(jnp.array(5, dtype=jnp.int32)),
)

# Clone
new_pt = pt.clone_pages_from(parent, child)
sequences, parent_id = sequences.reserve_slot(parent)
sequences, child_id = sequences.reserve_slot(child)

# Fully used pages (all but last partial) should be shared: page 0 mapping identical
assert int(new_pt.page_indices["seq", child, "page", 0].scalar()) == 2
seq_lens = sequences.seq_lens.at["seq", parent_id].set(5)
page_indices = sequences.page_indices.at["seq", parent_id, "page", 0].set(2)
page_indices = page_indices.at["seq", parent_id, "page", 1].set(3)
sequences = dataclasses.replace(sequences, seq_lens=seq_lens, page_indices=page_indices)
ref_counts = pt.page_ref_counts.at["page", 2].set(1)
ref_counts = ref_counts.at["page", 3].set(1)
pt = PageTable(ref_counts, pt.page_size, pt._max_seqs, pt._pages_per_seq)

# Last page must be a fresh allocation, different from parent's last page (3)
assert int(new_pt.page_indices["seq", child, "page", 1].scalar()) != 3
sequences, new_pt = sequences.clone_pages_from(pt, parent_id, child_id)

# Refcounts: +1 for shared full page (page 2) and +1 for newly allocated page; parent's partial last page unchanged
ref_shared = int(new_pt.page_ref_counts["page", 2].scalar())
ref_parent_last = int(new_pt.page_ref_counts["page", 3].scalar())
assert ref_shared == 1
assert ref_parent_last == 0

# Lengths equal (no rounding)
assert int(new_pt.seq_lens["seq", child].scalar()) == 5
assert int(sequences.page_indices["seq", child_id, "page", 0].scalar()) == 2
assert int(sequences.page_indices["seq", child_id, "page", 1].scalar()) != 3
assert int(new_pt.page_ref_counts["page", 2].scalar()) == 2
assert int(new_pt.page_ref_counts["page", 3].scalar()) == 1
assert int(sequences.seq_lens["seq", child_id].scalar()) == 5


def test_clone_pages_from_boundary_shares_last_page():
pt = _make_table(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)
sequences, pt = _make_allocator(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3)

# Parent uses exactly 2 full pages: length = 8 (pages 4 and 5 used)
parent = 0
child = 1

pt = dataclasses.replace(
pt,
page_indices=pt.page_indices.at["seq", parent, "page", 0]
.set(jnp.array(4, dtype=jnp.int32))
.at["seq", parent, "page", 1]
.set(jnp.array(5, dtype=jnp.int32))
.at["seq", parent, "page", 2]
.set(jnp.array(-1, dtype=jnp.int32)),
seq_lens=pt.seq_lens.at["seq", parent].set(jnp.array(8, dtype=jnp.int32)),
)
sequences, parent_id = sequences.reserve_slot(parent)
sequences, child_id = sequences.reserve_slot(child)

new_pt = pt.clone_pages_from(parent, child)
seq_lens = sequences.seq_lens.at["seq", parent_id].set(8)
page_indices = sequences.page_indices.at["seq", parent_id, "page", 0].set(4)
page_indices = page_indices.at["seq", parent_id, "page", 1].set(5)
sequences = dataclasses.replace(sequences, seq_lens=seq_lens, page_indices=page_indices)
ref_counts = pt.page_ref_counts.at["page", 4].set(1)
ref_counts = ref_counts.at["page", 5].set(1)
pt = PageTable(ref_counts, pt.page_size, pt._max_seqs, pt._pages_per_seq)

# Child should share both pages
assert int(new_pt.page_indices["seq", child, "page", 0].scalar()) == 4
assert int(new_pt.page_indices["seq", child, "page", 1].scalar()) == 5
sequences, new_pt = sequences.clone_pages_from(pt, parent_id, child_id)

# Refcounts incremented for both pages
assert int(new_pt.page_ref_counts["page", 4].scalar()) == 1
assert int(new_pt.page_ref_counts["page", 5].scalar()) == 1

# Lengths equal
assert int(new_pt.seq_lens["seq", child].scalar()) == 8
assert int(sequences.page_indices["seq", child_id, "page", 0].scalar()) == 4
assert int(sequences.page_indices["seq", child_id, "page", 1].scalar()) == 5
assert int(new_pt.page_ref_counts["page", 4].scalar()) == 2
assert int(new_pt.page_ref_counts["page", 5].scalar()) == 2
assert int(sequences.seq_lens["seq", child_id].scalar()) == 8


def test_kv_cache_copy_page():
# Minimal KvPageCache with 3 pages and small dims
from levanter.inference.page_table import PageTable as _PT

pt = _PT.init(max_pages=3, max_seqs=1, page_size=2, max_pages_per_seq=1)
kv = KvPageCache.init(pt, kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32)
sequences, pt = _make_allocator(max_pages=3, max_seqs=1, page_size=2, pages_per_seq=1)
kv = KvPageCache.init(pt.spec(), kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32)

# Write identifiable values into page 1 for both K and V
src_page = 1
dst_page = 2
k_pattern = hax.full_like(kv.kv_pages["page", src_page, "kv_head", 0::2], 7.0)
v_pattern = hax.full_like(kv.kv_pages["page", src_page, "kv_head", 1::2], 3.0)
# need to interleave k and v along kv_heads
kv_pattern = hax.stack("inter", [k_pattern, v_pattern]).rearrange(
"{inter kv_head} -> ... (kv_head: kv_head inter)"
)
kv = dataclasses.replace(
kv,
kv_pages=kv.kv_pages.at["page", src_page].set(kv_pattern),
)
kv = dataclasses.replace(kv, kv_pages=kv.kv_pages.at["page", src_page].set(kv_pattern))

kv2 = kv.copy_page(src_page, dst_page)
np.testing.assert_allclose(
18 changes: 10 additions & 8 deletions tests/inference/test_engine.py
@@ -12,14 +12,15 @@

from levanter.inference.engine import InferenceEngine, Request, InferenceEngineConfig
from levanter.inference.jit_scheduler import SeqDecodingParams
from levanter.inference.page_table import PageTable
from levanter.inference.page_table import PageTableSpec
from levanter.inference.utils import INVALID
from levanter.layers.kv_cache import KvPageCache


class DummyModel(eqx.Module):
"""Minimal model stub to drive GenerationService for tests.

- `initial_cache` returns an empty KvPageCache sized to the PageTable.
- `initial_cache` returns an empty KvPageCache sized to the page-table spec.
- `decode` returns constant logits that strongly prefer token `EOS`.
"""

@@ -30,11 +31,11 @@ def __init__(self, vocab_size: int, eos_id: int = 3):
self.Vocab = Axis("vocab", vocab_size)
self.eos = eos_id

def initial_cache(self, page_table: PageTable, *, dtype):
def initial_cache(self, spec: PageTableSpec, *, dtype):
# Use trivial cache dimensions; the cache is unused by this dummy model
kv_heads = Axis("kv_head", 1)
head_size = Axis("embed", 1)
return KvPageCache.init(page_table, kv_heads, head_size, dtype=dtype)
return KvPageCache.init(spec, kv_heads, head_size, dtype=dtype)

def decode(self, input_ids, kv_cache, batch_info, pos_ids):
# Produce logits that prefer `eos` for every sampled position
@@ -93,13 +94,14 @@ def test_release_on_finish_and_reuse_slots(caplog: pytest.LogCaptureFixture):
assert result.tokens[1] == [3]
assert result.total_generated == 2 # one new token per prompt

# Finished sequences are auto-released; PageTable should have no active seqs
# Finished sequences are auto-released; no active seqs remain
sequences = svc.gen_state.decode_state.sequences
pt = svc.gen_state.decode_state.page_table
# All slots should be marked unused and lengths zeroed
seq_lens = jax.device_get(pt.seq_lens.array)
used_mask = jax.device_get(pt.used_mask.array)
seq_lens = jax.device_get(sequences.seq_lens.array)
used_mask = jax.device_get(sequences.used_mask.array)
assert (used_mask == 0).all()
assert (seq_lens == 0).all()
assert ((seq_lens == 0) | (seq_lens == INVALID)).all()
# No pages should be held
ref_counts = jax.device_get(pt.page_ref_counts.array)
assert int(ref_counts.sum()) == 0