Skip to content

Commit 32965e0

Browse files
authored
remove redundant next_cache (#817)
1 parent c7a4362 commit 32965e0

File tree

5 files changed

0 additions and 10 deletions (+0 / -10 lines changed)

5 files changed

+0
-10
lines changed

ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -496,15 +496,13 @@
496496
" # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
497497
" mask = mask[None, None, :, :]\n",
498498
"\n",
499-
" next_cache = []\n",
500499
" for i, block in enumerate(self.trf_blocks):\n",
501500
" blk_cache = cache.get(i) if cache else None\n",
502501
" x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
503502
" start_pos=pos_start,\n",
504503
" cache=blk_cache)\n",
505504
" if cache is not None:\n",
506505
" cache.update(i, new_blk_cache)\n",
507-
" next_cache.append(new_blk_cache)\n",
508506
"\n",
509507
" x = self.final_norm(x)\n",
510508
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",

ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -422,15 +422,13 @@
422422
" # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
423423
" mask = mask[None, None, :, :]\n",
424424
"\n",
425-
" next_cache = []\n",
426425
" for i, block in enumerate(self.trf_blocks):\n",
427426
" blk_cache = cache.get(i) if cache else None\n",
428427
" x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
429428
" start_pos=pos_start,\n",
430429
" cache=blk_cache)\n",
431430
" if cache is not None:\n",
432431
" cache.update(i, new_blk_cache)\n",
433-
" next_cache.append(new_blk_cache)\n",
434432
"\n",
435433
" x = self.final_norm(x)\n",
436434
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",

pkg/llms_from_scratch/kv_cache/gpt2.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -177,13 +177,11 @@ def forward(self, in_idx, use_cache=False, cache=None):
177177
else:
178178
start_pos = 0
179179

180-
next_cache = []
181180
for i, block in enumerate(self.trf_blocks):
182181
blk_cache = cache.get(i) if cache else None
183182
x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
184183
if cache:
185184
cache.update(i, new_cache)
186-
next_cache.append(new_cache)
187185

188186
x = self.final_norm(x)
189187
logits = self.out_head(x)

pkg/llms_from_scratch/kv_cache/llama3.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,15 +97,13 @@ def forward(self, in_idx, cache=None):
9797
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
9898
mask = mask[None, None, :, :]
9999

100-
next_cache = []
101100
for i, block in enumerate(self.trf_blocks):
102101
blk_cache = cache.get(i) if cache else None
103102
x, new_blk_cache = block(x, mask, self.cos, self.sin,
104103
start_pos=pos_start,
105104
cache=blk_cache)
106105
if cache is not None:
107106
cache.update(i, new_blk_cache)
108-
next_cache.append(new_blk_cache)
109107

110108
x = self.final_norm(x)
111109
logits = self.out_head(x.to(self.cfg["dtype"]))

pkg/llms_from_scratch/kv_cache/qwen3.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,15 +65,13 @@ def forward(self, in_idx, cache=None):
6565
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
6666
mask = mask[None, None, :, :]
6767

68-
next_cache = []
6968
for i, block in enumerate(self.trf_blocks):
7069
blk_cache = cache.get(i) if cache else None
7170
x, new_blk_cache = block(x, mask, self.cos, self.sin,
7271
start_pos=pos_start,
7372
cache=blk_cache)
7473
if cache is not None:
7574
cache.update(i, new_blk_cache)
76-
next_cache.append(new_blk_cache)
7775

7876
x = self.final_norm(x)
7977
logits = self.out_head(x.to(self.cfg["dtype"]))

0 commit comments

Comments
 (0)