
Commit 2d32730

don't forget about gemma, nicer docstrings, fix a dora bug
1 parent 8440cee commit 2d32730

File tree

3 files changed: +50, -21 lines


torchtune/models/gemma/_component_builders.py

Lines changed: 0 additions & 7 deletions
@@ -289,7 +289,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "q_proj" in lora_modules
         else (
@@ -306,7 +305,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "k_proj" in lora_modules
         else (
@@ -323,7 +321,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "v_proj" in lora_modules
         else (
@@ -340,7 +337,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "output_proj" in lora_modules
         else (
@@ -385,7 +381,6 @@ def lora_gemma_mlp(
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
-        use_dora=use_dora,
     )
     down_proj = adapter_cls(
         in_dim=hidden_dim,
@@ -394,7 +389,6 @@ def lora_gemma_mlp(
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
-        use_dora=use_dora,
     )
     up_proj = adapter_cls(
         in_dim=dim,
@@ -403,7 +397,6 @@ def lora_gemma_mlp(
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
-        use_dora=use_dora,
     )
     activation = nn.GELU(approximate="tanh")
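
This is the "fix a dora bug" part of the commit: the Gemma builders already select `adapter_cls` according to whether DoRA is enabled, so the DoRA-vs-LoRA decision lives in which class gets constructed, and forwarding `use_dora` again as a constructor kwarg is unnecessary and fails if the adapter class does not accept it. Below is a minimal sketch of that selection pattern, assuming the adapters are torchtune's `LoRALinear`/`DoRALinear` with simplified keyword arguments; the helper function is illustrative, not part of this diff.

# Illustrative sketch (not the torchtune source): the DoRA/LoRA choice is encoded
# in which adapter class is constructed, so the adapter constructors themselves
# take no `use_dora` kwarg -- which is why the kwarg is deleted above.
from torchtune.modules.peft import DoRALinear, LoRALinear


def build_adapter(in_dim: int, out_dim: int, use_dora: bool = False):
    # Pick the adapter class once; past this point `use_dora` is no longer needed.
    adapter_cls = DoRALinear if use_dora else LoRALinear
    return adapter_cls(
        in_dim=in_dim,
        out_dim=out_dim,
        rank=8,
        alpha=16.0,
        dropout=0.0,
        # use_dora=use_dora,  # the removed kwarg: neither adapter class accepts it
    )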

torchtune/models/gemma/transformer.py

Lines changed: 24 additions & 8 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -98,6 +98,28 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:
             torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool)
         )
 
+    @torch.compiler.disable
+    def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
+        """
+        Apply output projection in chunks. This should be applied in conjunction with
+        :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
+
+        To use this method, you should first call
+        :func:`~torchtune.models.gemma.GemmaTransformerDecoder.set_num_output_chunks`.
+
+        Args:
+            last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
+                [b, seq_len, embed_dim].
+
+        Returns:
+            List[torch.Tensor]: List of num_chunks output tensors, each with shape
+                [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
+        """
+        return [
+            F.linear(chunk, self.tok_embeddings.weight)
+            for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
+        ]
+
     def forward(
         self,
         tokens: torch.Tensor,
@@ -168,13 +190,7 @@ def forward(
         h = self.norm(h)
 
         if self.num_output_chunks > 0:
-            # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size
-            # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe,
-            # before calling forward. Upcasting it done inside of the loss function.
-            output = [
-                F.linear(chunk, self.tok_embeddings.weight)
-                for chunk in h.chunk(self.num_output_chunks, dim=1)
-            ]
+            output = self.chunked_output(h)
         else:
             # shape: [b, seq_len, out_dim]
             output = F.linear(h, self.tok_embeddings.weight).float()
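
The new `chunked_output` method is the same tied-embedding projection the forward pass previously inlined, factored out into its own method (and kept out of `torch.compile` via the decorator). The following self-contained sketch, with made-up shapes, shows what the chunking does and that concatenating the chunked logits recovers the unchunked projection.

# Self-contained illustration of the chunked output projection (shapes are illustrative).
import torch
import torch.nn.functional as F

b, seq_len, embed_dim, vocab = 2, 8, 16, 32
num_output_chunks = 4

tok_embeddings = torch.nn.Embedding(vocab, embed_dim)   # tied input/output embedding
h = torch.randn(b, seq_len, embed_dim)                  # last hidden state of the decoder

# What chunked_output does: project seq_len/num_chunks tokens at a time.
chunked = [
    F.linear(chunk, tok_embeddings.weight)
    for chunk in h.chunk(num_output_chunks, dim=1)
]

# Unchunked projection for comparison; stitching the chunks back gives the same logits.
full = F.linear(h, tok_embeddings.weight)
assert torch.allclose(torch.cat(chunked, dim=1), full, atol=1e-5)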

torchtune/modules/transformer.py

Lines changed: 26 additions & 6 deletions
@@ -383,9 +383,19 @@ def reset_caches(self):
     @torch.compiler.disable
     def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
         """
-        shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size
-        Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe,
-        before calling forward. Upcasting it done inside of the loss function.
+        Apply output projection in chunks. This should be applied in conjunction with
+        :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
+
+        To use this method, you should first call
+        :func:`~torchtune.modules.TransformerDecoder.set_num_output_chunks`.
+
+        Args:
+            last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
+                [b, seq_len, embed_dim].
+
+        Returns:
+            List[torch.Tensor]: List of num_chunks output tensors, each with shape
+                [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
         """
         return [
             self.output(chunk)
@@ -604,9 +614,19 @@ def reset_caches(self):
     @torch.compiler.disable
     def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
         """
-        shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size
-        Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe,
-        before calling forward. Upcasting it done inside of the loss function.
+        Apply output projection in chunks. This should be applied in conjunction with
+        :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
+
+        To use this method, you should first call
+        :func:`~torchtune.modules.TiedEmbeddingTransformerDecoder.set_num_output_chunks`.
+
+        Args:
+            last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
+                [b, seq_len, embed_dim].
+
+        Returns:
+            List[torch.Tensor]: List of num_chunks output tensors, each with shape
+                [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
         """
         return [
             F.linear(chunk, self.tok_embeddings.weight)
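
Both docstrings now point at the intended pairing: call `set_num_output_chunks` on the decoder, then feed the resulting list of logit chunks to `CEWithChunkedOutputLoss`, which does the fp32 upcast internally. The sketch below builds the chunk list by hand instead of running a model; the loss interface (a `num_output_chunks` attribute and a `forward(logits_list, labels)` call) is assumed from its usage here, not shown in this diff.

# Hedged sketch of the usage the docstrings describe; loss interface assumed, not shown here.
import torch
from torchtune.modules.loss import CEWithChunkedOutputLoss

loss_fn = CEWithChunkedOutputLoss()       # upcasts each chunk to fp32 internally
num_chunks = loss_fn.num_output_chunks    # a recipe would pass this to model.set_num_output_chunks

# Stand-in for the decoder's chunked output: a list of [b, seq_len/num_chunks, vocab] tensors.
b, seq_len, vocab = 2, 8, 128
logits = [torch.randn(b, seq_len // num_chunks, vocab) for _ in range(num_chunks)]
labels = torch.randint(0, vocab, (b, seq_len))

loss = loss_fn(logits, labels)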
