8 | 8 |
9 | 9 | import haliax as hax |
10 | 10 |
11 | | -from levanter.inference.jit_scheduler import SequenceTable |
12 | 11 | from levanter.inference.page_table import PageTable |
13 | 12 | from levanter.layers.kv_cache import KvPageCache |
14 | 13 |
15 | 14 |
16 | | -def _make_allocator(max_pages=8, max_seqs=2, page_size=4, pages_per_seq=3): |
17 | | - pt = PageTable.init(max_pages, max_seqs, page_size, pages_per_seq) |
18 | | - sequences = SequenceTable.init(max_seqs, pages_per_seq, page_size) |
19 | | - return sequences, pt |
| 15 | +def _make_table(max_pages=8, max_seqs=2, page_size=4, pages_per_seq=3): |
| 16 | + return PageTable.init(max_pages, max_seqs, page_size, pages_per_seq) |
20 | 17 |
21 | 18 |
22 | 19 | def test_clone_pages_from_partial_last_page_allocates_fresh_page(): |
23 | | - sequences, pt = _make_allocator(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3) |
| 20 | + pt = _make_table(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3) |
24 | 21 |
| 22 | + # Parent uses two pages with a partial last page: length = 5 (pages 0 and 1 used) |
25 | 23 | parent = 0 |
26 | 24 | child = 1 |
27 | 25 |
28 | | - sequences, parent_id = sequences.reserve_slot(parent) |
29 | | - sequences, child_id = sequences.reserve_slot(child) |
| 26 | + pt = dataclasses.replace( |
| 27 | + pt, |
| 28 | + page_indices=pt.page_indices.at["seq", parent, "page", 0] |
| 29 | + .set(jnp.array(2, dtype=jnp.int32)) |
| 30 | + .at["seq", parent, "page", 1] |
| 31 | + .set(jnp.array(3, dtype=jnp.int32)) |
| 32 | + .at["seq", parent, "page", 2] |
| 33 | + .set(jnp.array(-1, dtype=jnp.int32)), |
| 34 | + seq_lens=pt.seq_lens.at["seq", parent].set(jnp.array(5, dtype=jnp.int32)), |
| 35 | + ) |
| 36 | + |
| 37 | + # Clone the parent's pages into the child slot
| 38 | + new_pt = pt.clone_pages_from(parent, child) |
30 | 39 |
31 | | - seq_lens = sequences.seq_lens.at["seq", parent_id].set(5) |
32 | | - page_indices = sequences.page_indices.at["seq", parent_id, "page", 0].set(2) |
33 | | - page_indices = page_indices.at["seq", parent_id, "page", 1].set(3) |
34 | | - sequences = dataclasses.replace(sequences, seq_lens=seq_lens, page_indices=page_indices) |
35 | | - ref_counts = pt.page_ref_counts.at["page", 2].set(1) |
36 | | - ref_counts = ref_counts.at["page", 3].set(1) |
37 | | - pt = PageTable(ref_counts, pt.page_size, pt._max_seqs, pt._pages_per_seq) |
| 40 | + # Fully used pages (everything before the partial last page) are shared: the child's page 0 maps to the same physical page
| 41 | + assert int(new_pt.page_indices["seq", child, "page", 0].scalar()) == 2 |
38 | 42 |
39 | | - sequences, new_pt = sequences.clone_pages_from(pt, parent_id, child_id) |
| 43 | + # Last page must be a fresh allocation, different from parent's last page (3) |
| 44 | + assert int(new_pt.page_indices["seq", child, "page", 1].scalar()) != 3 |
40 | 45 |
41 | | - assert int(sequences.page_indices["seq", child_id, "page", 0].scalar()) == 2 |
42 | | - assert int(sequences.page_indices["seq", child_id, "page", 1].scalar()) != 3 |
43 | | - assert int(new_pt.page_ref_counts["page", 2].scalar()) == 2 |
44 | | - assert int(new_pt.page_ref_counts["page", 3].scalar()) == 1 |
45 | | - assert int(sequences.seq_lens["seq", child_id].scalar()) == 5 |
| 46 | + # Refcounts: the shared full page (2) and the child's newly allocated page each gain one reference; the parent's partial last page (3) is not shared, so its count is unchanged
| 47 | + ref_shared = int(new_pt.page_ref_counts["page", 2].scalar()) |
| 48 | + ref_parent_last = int(new_pt.page_ref_counts["page", 3].scalar()) |
| 49 | + assert ref_shared == 1 |
| 50 | + assert ref_parent_last == 0 |
| 51 | + |
| 52 | + # Lengths equal (no rounding) |
| 53 | + assert int(new_pt.seq_lens["seq", child].scalar()) == 5 |
46 | 54 |
47 | 55 |
48 | 56 | def test_clone_pages_from_boundary_shares_last_page(): |
49 | | - sequences, pt = _make_allocator(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3) |
| 57 | + pt = _make_table(max_pages=10, max_seqs=2, page_size=4, pages_per_seq=3) |
50 | 58 |
| 59 | + # Parent uses exactly 2 full pages: length = 8 (pages 4 and 5 used) |
51 | 60 | parent = 0 |
52 | 61 | child = 1 |
53 | 62 |
54 | | - sequences, parent_id = sequences.reserve_slot(parent) |
55 | | - sequences, child_id = sequences.reserve_slot(child) |
| 63 | + pt = dataclasses.replace( |
| 64 | + pt, |
| 65 | + page_indices=pt.page_indices.at["seq", parent, "page", 0] |
| 66 | + .set(jnp.array(4, dtype=jnp.int32)) |
| 67 | + .at["seq", parent, "page", 1] |
| 68 | + .set(jnp.array(5, dtype=jnp.int32)) |
| 69 | + .at["seq", parent, "page", 2] |
| 70 | + .set(jnp.array(-1, dtype=jnp.int32)), |
| 71 | + seq_lens=pt.seq_lens.at["seq", parent].set(jnp.array(8, dtype=jnp.int32)), |
| 72 | + ) |
56 | 73 |
57 | | - seq_lens = sequences.seq_lens.at["seq", parent_id].set(8) |
58 | | - page_indices = sequences.page_indices.at["seq", parent_id, "page", 0].set(4) |
59 | | - page_indices = page_indices.at["seq", parent_id, "page", 1].set(5) |
60 | | - sequences = dataclasses.replace(sequences, seq_lens=seq_lens, page_indices=page_indices) |
61 | | - ref_counts = pt.page_ref_counts.at["page", 4].set(1) |
62 | | - ref_counts = ref_counts.at["page", 5].set(1) |
63 | | - pt = PageTable(ref_counts, pt.page_size, pt._max_seqs, pt._pages_per_seq) |
| 74 | + new_pt = pt.clone_pages_from(parent, child) |
64 | 75 |
65 | | - sequences, new_pt = sequences.clone_pages_from(pt, parent_id, child_id) |
| 76 | + # Child should share both pages |
| 77 | + assert int(new_pt.page_indices["seq", child, "page", 0].scalar()) == 4 |
| 78 | + assert int(new_pt.page_indices["seq", child, "page", 1].scalar()) == 5 |
66 | 79 |
67 | | - assert int(sequences.page_indices["seq", child_id, "page", 0].scalar()) == 4 |
68 | | - assert int(sequences.page_indices["seq", child_id, "page", 1].scalar()) == 5 |
69 | | - assert int(new_pt.page_ref_counts["page", 4].scalar()) == 2 |
70 | | - assert int(new_pt.page_ref_counts["page", 5].scalar()) == 2 |
71 | | - assert int(sequences.seq_lens["seq", child_id].scalar()) == 8 |
| 80 | + # Refcounts: both shared pages now carry the child's reference (they were installed above without reserving, so they go from 0 to 1)
| 81 | + assert int(new_pt.page_ref_counts["page", 4].scalar()) == 1 |
| 82 | + assert int(new_pt.page_ref_counts["page", 5].scalar()) == 1 |
| 83 | + |
| 84 | + # Lengths equal |
| 85 | + assert int(new_pt.seq_lens["seq", child].scalar()) == 8 |
72 | 86 |
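A minimal, hypothetical sketch (plain Python, not levanter's implementation) of the clone semantics the two tests above assert: every full page is shared with the parent and its refcount bumped, a partial last page is backed by a freshly allocated page, and the child starts at the parent's length. The container names (page_indices as nested lists, free_pages as a list of free page ids) are illustrative only.

def clone_pages_sketch(page_indices, seq_lens, ref_counts, free_pages, parent, child, page_size):
    length = seq_lens[parent]
    num_pages = -(-length // page_size)    # ceil-divide: pages the parent occupies
    full_pages = length // page_size       # pages that are completely full
    for p in range(full_pages):
        page = page_indices[parent][p]
        page_indices[child][p] = page      # share the full page with the parent
        ref_counts[page] += 1
    if full_pages < num_pages:             # partial last page: copy into a fresh page
        fresh = free_pages.pop()
        page_indices[child][full_pages] = fresh
        ref_counts[fresh] += 1
    seq_lens[child] = length               # child continues from the same length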
73 | 87 |
74 | 88 | def test_kv_cache_copy_page(): |
75 | | - sequences, pt = _make_allocator(max_pages=3, max_seqs=1, page_size=2, pages_per_seq=1) |
76 | | - kv = KvPageCache.init(pt.spec(), kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32) |
| 89 | + # Minimal KvPageCache with 3 pages and small dims |
| 90 | + from levanter.inference.page_table import PageTable as _PT |
| 91 | + |
| 92 | + pt = _PT.init(max_pages=3, max_seqs=1, page_size=2, max_pages_per_seq=1) |
| 93 | + kv = KvPageCache.init(pt, kv_heads=hax.Axis("kv_head", 2), head_size=hax.Axis("head", 3), dtype=jnp.float32) |
77 | 94 |
| 95 | + # Write identifiable values into page 1 for both K and V |
78 | 96 | src_page = 1 |
79 | 97 | dst_page = 2 |
80 | 98 | k_pattern = hax.full_like(kv.kv_pages["page", src_page, "kv_head", 0::2], 7.0) |
81 | 99 | v_pattern = hax.full_like(kv.kv_pages["page", src_page, "kv_head", 1::2], 3.0) |
| 100 | + # K and V are interleaved along the kv_head axis (K on even slots, V on odd slots)
82 | 101 | kv_pattern = hax.stack("inter", [k_pattern, v_pattern]).rearrange( |
83 | 102 | "{inter kv_head} -> ... (kv_head: kv_head inter)" |
84 | 103 | ) |
85 | | - kv = dataclasses.replace(kv, kv_pages=kv.kv_pages.at["page", src_page].set(kv_pattern)) |
| 104 | + kv = dataclasses.replace( |
| 105 | + kv, |
| 106 | + kv_pages=kv.kv_pages.at["page", src_page].set(kv_pattern), |
| 107 | + ) |
86 | 108 |
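As a hypothetical NumPy illustration (separate from the haliax code above) of the interleaving the stack + rearrange performs: K heads end up on even kv_head slots and V heads on odd slots, matching the 0::2 / 1::2 slices used to build the patterns.

import numpy as np

k = np.full((2, 3), 7.0)   # 2 K heads, head_size 3
v = np.full((2, 3), 3.0)   # 2 V heads, head_size 3
interleaved = np.stack([k, v], axis=1).reshape(4, 3)
# rows along the combined axis read k0, v0, k1, v1 -> 7.0, 3.0, 7.0, 3.0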
87 | 109 | kv2 = kv.copy_page(src_page, dst_page) |
88 | 110 | np.testing.assert_allclose( |