
Commit 3c9dc48

Simplify KV cache usage (#728)
* Simplify KV cache usage
* Swap MarkText with Ghostwriter
1 parent 9cf6417 commit 3c9dc48

File tree: 4 files changed, +31 -39 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ git clone --depth 1 https://github.com/rasbt/LLMs-from-scratch.git
 
 # Table of Contents
 
-Please note that this `README.md` file is a Markdown (`.md`) file. If you have downloaded this code bundle from the Manning website and are viewing it on your local computer, I recommend using a Markdown editor or previewer for proper viewing. If you haven't installed a Markdown editor yet, [MarkText](https://www.marktext.cc) is a good free option.
+Please note that this `README.md` file is a Markdown (`.md`) file. If you have downloaded this code bundle from the Manning website and are viewing it on your local computer, I recommend using a Markdown editor or previewer for proper viewing. If you haven't installed a Markdown editor yet, [Ghostwriter](https://ghostwriter.kde.org) is a good free option.
 
 You can alternatively view this and other files on GitHub at [https://github.com/rasbt/LLMs-from-scratch](https://github.com/rasbt/LLMs-from-scratch) in your browser, which renders Markdown automatically.
 
pkg/llms_from_scratch/kv_cache/generate.py

Lines changed: 4 additions & 4 deletions
@@ -10,20 +10,20 @@
 def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
     ctx_len = context_size or model.cfg["context_length"]
-    cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None
 
     with torch.no_grad():
         if use_cache:
+            cache = KVCache(n_layers=model.cfg["n_layers"])
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
+            logits = model(idx[:, -ctx_len:], cache=cache)
 
             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
-                logits = model(next_idx, use_cache=True, cache=cache)
+                logits = model(next_idx, cache=cache)
         else:
             for _ in range(max_new_tokens):
-                logits = model(idx[:, -ctx_len:], use_cache=False)
+                logits = model(idx[:, -ctx_len:], cache=None)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
 
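For callers, nothing changes besides the internal cache handling: `use_cache` still selects the decoding path, and the `KVCache` is now created inside the function only when it is actually used. A minimal usage sketch (the import path is inferred from the pkg/ layout above; `model` and `idx` are assumed to stand for a kv_cache model instance and a prompt-token tensor):

# Assumed import path based on the pkg/ directory layout shown in this commit
from llms_from_scratch.kv_cache.generate import generate_text_simple

# `model`: a Llama 3 / Qwen3 model from this kv_cache package (assumed to exist)
# `idx`: tensor of shape (batch_size, num_tokens) holding prompt token IDs
token_ids = generate_text_simple(model, idx, max_new_tokens=20, use_cache=True)

# Without the KV cache, the (truncated) full sequence is re-fed at every step
token_ids_plain = generate_text_simple(model, idx, max_new_tokens=20, use_cache=False)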

pkg/llms_from_scratch/kv_cache/llama3.py

Lines changed: 14 additions & 18 deletions
@@ -77,12 +77,12 @@ def __init__(self, cfg):
         self.cfg = cfg
         self.current_pos = 0  # Track current position in KV cache
 
-    def forward(self, in_idx, use_cache=False, cache=None):
+    def forward(self, in_idx, cache=None):
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds
 
         num_tokens = x.shape[1]
-        if use_cache:
+        if cache is not None:
             pos_start = self.current_pos
             pos_end = pos_start + num_tokens
             self.current_pos = pos_end
@@ -101,10 +101,9 @@ def forward(self, in_idx, use_cache=False, cache=None):
         for i, block in enumerate(self.trf_blocks):
             blk_cache = cache.get(i) if cache else None
             x, new_blk_cache = block(x, mask, self.cos, self.sin,
-                                     use_cache=use_cache,
                                      start_pos=pos_start,
                                      cache=blk_cache)
-            if cache:
+            if cache is not None:
                 cache.update(i, new_blk_cache)
             next_cache.append(new_blk_cache)
 
@@ -130,11 +129,11 @@ def __init__(self, cfg):
         self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
         self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
 
-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         # Shortcut connection for attention block
         shortcut = x
         x = self.norm1(x)
-        x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
+        x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
         x = x + shortcut  # Add the original input back
 
         # Shortcut connection for feed-forward block
@@ -180,7 +179,7 @@ def __init__(
         self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
         self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
 
-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         b, num_tokens, _ = x.shape
 
         # Apply projections
@@ -197,18 +196,15 @@ def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
         queries = apply_rope(queries, cos, sin, offset=start_pos)
         keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
 
-        if use_cache:
-            if cache is None:
-                keys = keys_new
-                values = values_new
-            else:
-                prev_k, prev_v = cache
-                keys = torch.cat([prev_k, keys_new], dim=2)
-                values = torch.cat([prev_v, values_new], dim=2)
+        if cache is not None:
+            prev_k, prev_v = cache
+            keys = torch.cat([prev_k, keys_new], dim=2)
+            values = torch.cat([prev_v, values_new], dim=2)
             next_cache = (keys, values)
         else:
+            start_pos = 0  # reset RoPE
            keys, values = keys_new, values_new
-            next_cache = None
+            next_cache = (keys, values)
 
         # Expand keys and values to match the number of heads
         # Shape: (b, num_heads, num_tokens, head_dim)
@@ -226,7 +222,7 @@ def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
         # Use the mask to fill attention scores
-        attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
+        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
 
         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
         assert keys.shape[-1] == self.head_dim
@@ -286,7 +282,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c
     return cos, sin
 
 
-def apply_rope(x, cos, sin, offset=9):
+def apply_rope(x, cos, sin, offset=0):
     # x: (batch_size, num_heads, seq_len, head_dim)
     batch_size, num_heads, seq_len, head_dim = x.shape
     assert head_dim % 2 == 0, "Head dimension must be even"
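The simplified forward methods only assume that the cache object exposes `get(layer_idx)` and `update(layer_idx, kv)` (plus `reset_kv_cache()` on the model). As a point of reference, a minimal container with that interface could look like the sketch below; the package's actual `KVCache` may differ in its details:

class KVCache:
    """Minimal per-layer key/value store matching the calls in the diff above."""

    def __init__(self, n_layers):
        # One (keys, values) tuple slot per transformer block
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        # Returns None until a forward pass has cached this layer
        return self.cache[layer_idx]

    def update(self, layer_idx, kv):
        # kv is the (keys, values) tuple returned by the attention layer
        self.cache[layer_idx] = kv

    def reset(self):
        self.cache = [None] * len(self.cache)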

pkg/llms_from_scratch/kv_cache/qwen3.py

Lines changed: 12 additions & 16 deletions
@@ -44,13 +44,13 @@ def __init__(self, cfg):
         self.cfg = cfg
         self.current_pos = 0  # Track current position in KV cache
 
-    def forward(self, in_idx, use_cache=False, cache=None):
+    def forward(self, in_idx, cache=None):
         # Forward pass
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds
 
         num_tokens = x.shape[1]
-        if use_cache:
+        if cache is not None:
             pos_start = self.current_pos
             pos_end = pos_start + num_tokens
             self.current_pos = pos_end
@@ -69,10 +69,9 @@ def forward(self, in_idx, use_cache=False, cache=None):
         for i, block in enumerate(self.trf_blocks):
             blk_cache = cache.get(i) if cache else None
             x, new_blk_cache = block(x, mask, self.cos, self.sin,
-                                     use_cache=use_cache,
                                      start_pos=pos_start,
                                      cache=blk_cache)
-            if cache:
+            if cache is not None:
                 cache.update(i, new_blk_cache)
             next_cache.append(new_blk_cache)
 
@@ -99,11 +98,11 @@ def __init__(self, cfg):
         self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
         self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
 
-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         # Shortcut connection for attention block
         shortcut = x
         x = self.norm1(x)
-        x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
+        x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
         x = x + shortcut  # Add the original input back
 
         # Shortcut connection for feed-forward block
@@ -159,7 +158,7 @@ def __init__(
         else:
             self.q_norm = self.k_norm = None
 
-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         b, num_tokens, _ = x.shape
 
         # Apply projections
@@ -182,18 +181,15 @@ def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
         queries = apply_rope(queries, cos, sin, offset=start_pos)
         keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
 
-        if use_cache:
-            if cache is None:
-                keys = keys_new
-                values = values_new
-            else:
-                prev_k, prev_v = cache
-                keys = torch.cat([prev_k, keys_new], dim=2)
-                values = torch.cat([prev_v, values_new], dim=2)
+        if cache is not None:
+            prev_k, prev_v = cache
+            keys = torch.cat([prev_k, keys_new], dim=2)
+            values = torch.cat([prev_v, values_new], dim=2)
             next_cache = (keys, values)
         else:
+            start_pos = 0  # reset RoPE
            keys, values = keys_new, values_new
-            next_cache = None
+            next_cache = (keys, values)
 
         # Expand K and V to match number of heads
         keys = keys.repeat_interleave(self.group_size, dim=1)
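In both model files the attention layer now returns `next_cache = (keys, values)` unconditionally and, whenever a per-layer cache tuple is present, appends the new keys and values along the sequence dimension (`dim=2`). A tiny standalone check of that append step, using arbitrary made-up shapes:

import torch

# Layout follows (batch_size, num_kv_heads, seq_len, head_dim); sizes are arbitrary
prev_k = torch.randn(1, 2, 5, 4)    # 5 cached positions
keys_new = torch.randn(1, 2, 1, 4)  # 1 newly decoded token
keys = torch.cat([prev_k, keys_new], dim=2)
print(keys.shape)                   # torch.Size([1, 2, 6, 4])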
