
Commit a226a58

added tie_word_embeddings to llama3_2 models (#2331)
1 parent 9f14fe9 · commit a226a58

File tree

torchtune/models/llama3_2/_component_builders.py
torchtune/models/llama3_2/_model_builders.py

2 files changed: +34 −4 lines

torchtune/models/llama3_2/_component_builders.py

Lines changed: 14 additions & 2 deletions
@@ -52,6 +52,7 @@ def llama3_2(
     intermediate_dim: Optional[int] = None,
     norm_eps: float = 1e-5,
     scale_factor: int = 32,
+    tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Build the decoder associated with the Llama3.2 model. This includes:
@@ -78,6 +79,7 @@ def llama3_2(
             this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`
         norm_eps (float): epsilon in RMS norms.
         scale_factor (int): scaling factor for RoPE. Default: 32
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.

     Returns:
         TransformerDecoder: Instantiation of Llama3.2 model.
@@ -112,7 +114,11 @@ def llama3_2(
         layers.append(layer)

     tok_embeddings = nn.Embedding(vocab_size, embed_dim)
-    output_proj = TiedLinear(tok_embeddings)
+    if tie_word_embeddings:
+        output_proj = TiedLinear(tok_embeddings)
+    else:
+        output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
+
     return TransformerDecoder(
         tok_embeddings=tok_embeddings,
         layers=layers,
@@ -161,6 +167,7 @@ def lora_llama3_2(
     use_dora: bool = False,
     # Quantization args
     quantize_base: bool = False,
+    tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Return a version of Llama3.2 (an instance of :func:`~torchtune.modules.TransformerDecoder`)
@@ -197,6 +204,7 @@ def lora_llama3_2(
         quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base
             weights within linear layers LoRA is applied to. The final output linear projection is not
             supported for quantization currently.
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.

     Returns:
         TransformerDecoder: Instantiation of Llama3.2 model with LoRA applied to
@@ -254,7 +262,11 @@ def lora_llama3_2(
                 "apply_lora_to_output is currently not supporting in llama3.2 1b and 3b,"
                 "as the projection layer weights are tied to the embeddings"
             )
-        output_proj = TiedLinear(tok_embeddings)
+        if tie_word_embeddings:
+            output_proj = TiedLinear(tok_embeddings)
+        else:
+            output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
+
     model = TransformerDecoder(
         tok_embeddings=tok_embeddings,
         layers=layers,
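
Note: the component builder above now switches between two output-projection strategies. The following is a minimal plain-PyTorch sketch of that distinction, not torchtune internals; the vocab and embedding sizes are Llama3.2-1B-style numbers assumed purely for illustration.

# Sketch of the two output projections the builder chooses between.
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embed_dim = 128_256, 2048  # assumed, Llama3.2-1B-style sizes

tok_embeddings = nn.Embedding(vocab_size, embed_dim)

# tie_word_embeddings=True: the output head reuses the embedding weight,
# in the spirit of TiedLinear(tok_embeddings) in the diff above.
def tied_output(hidden: torch.Tensor) -> torch.Tensor:
    return F.linear(hidden, tok_embeddings.weight)

# tie_word_embeddings=False: an independent projection with its own weight.
untied_output = nn.Linear(embed_dim, vocab_size, bias=False)

hidden = torch.randn(2, 16, embed_dim)   # (batch, seq_len, embed_dim)
print(tied_output(hidden).shape)         # torch.Size([2, 16, 128256])
print(untied_output(hidden).shape)       # torch.Size([2, 16, 128256])

# The untied head carries vocab_size * embed_dim extra parameters (~263M with
# these sizes), consistent with tie_word_embeddings defaulting to True above.
print(sum(p.numel() for p in untied_output.parameters()))  # 262668288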

torchtune/models/llama3_2/_model_builders.py

Lines changed: 20 additions & 2 deletions
@@ -16,10 +16,15 @@
 the llama3_2_1b model builder uses the llama3_2 component builder to create the
 Llama3.2 1B model.
 """
-def llama3_2_1b() -> TransformerDecoder:
+def llama3_2_1b(
+    tie_word_embeddings: bool = True,
+) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values.

+    Args:
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
+
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 1B model
     """
@@ -35,11 +40,17 @@ def llama3_2_1b() -> TransformerDecoder:
         norm_eps=1e-5,
         rope_base=500_000,
         scale_factor=32,
+        tie_word_embeddings=tie_word_embeddings,
     )
-def llama3_2_3b() -> TransformerDecoder:
+def llama3_2_3b(
+    tie_word_embeddings: bool = True,
+) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 model initialized w/ the default 3b parameter values.

+    Args:
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
+
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 3B model
     """
@@ -55,6 +66,7 @@ def llama3_2_3b() -> TransformerDecoder:
         norm_eps=1e-5,
         rope_base=500_000,
         scale_factor=32,
+        tie_word_embeddings=tie_word_embeddings,
     )
 def lora_llama3_2_1b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
@@ -65,6 +77,7 @@ def lora_llama3_2_1b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
+    tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 1B model with LoRA enabled.
@@ -86,6 +99,7 @@ def lora_llama3_2_1b(
         use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
         quantize_base (bool): Whether to quantize base model weights
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.

     Returns:
         TransformerDecoder: Instantiation of Llama3.2 1B model with LoRA applied
@@ -110,6 +124,7 @@ def lora_llama3_2_1b(
         lora_dropout=lora_dropout,
         use_dora=use_dora,
         quantize_base=quantize_base,
+        tie_word_embeddings=tie_word_embeddings,
     )
 def lora_llama3_2_3b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
@@ -120,6 +135,7 @@ def lora_llama3_2_3b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
+    tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 3B model with LoRA enabled.
@@ -141,6 +157,7 @@ def lora_llama3_2_3b(
         use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
         quantize_base (bool): Whether to quantize base model weights
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.

     Returns:
         TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied
@@ -166,6 +183,7 @@ def lora_llama3_2_3b(
         lora_dropout=lora_dropout,
         use_dora=use_dora,
         quantize_base=quantize_base,
+        tie_word_embeddings=tie_word_embeddings,
     )
 qlora_llama3_2_1b = partial(lora_llama3_2_1b, quantize_base=True)
 qlora_llama3_2_1b.__doc__ = """
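
With the model builders updated, the flag can be threaded through from user code. Below is a hedged usage sketch, assuming torchtune at this commit is installed; the `.output` attribute name on TransformerDecoder is an assumption made for illustration.

# Usage sketch for the updated builders (assumptions noted above).
from torchtune.models.llama3_2 import llama3_2_1b

tied = llama3_2_1b()                             # default: tie_word_embeddings=True
untied = llama3_2_1b(tie_word_embeddings=False)  # separate lm-head weight

print(type(tied.output).__name__)    # expected: TiedLinear
print(type(untied.output).__name__)  # expected: Linear

# The untied variant should report more parameters, since its output
# projection no longer shares storage with tok_embeddings.
n_tied = sum(p.numel() for p in tied.parameters())
n_untied = sum(p.numel() for p in untied.parameters())
print(n_untied - n_tied)

The LoRA builders (lora_llama3_2_1b, lora_llama3_2_3b) accept the same keyword, so recipes that previously relied on the implicitly tied projection can now opt out explicitly.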
