 the llama3_2_1b model builder uses the llama3_2 component builder to create the
 Llama3.2 1B model.
 """
-def llama3_2_1b() -> TransformerDecoder:
+def llama3_2_1b(
+    tie_word_embeddings: bool = True,
+) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values.
 
+    Args:
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
+
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 1B model
     """
@@ -35,11 +40,17 @@ def llama3_2_1b() -> TransformerDecoder:
         norm_eps=1e-5,
         rope_base=500_000,
         scale_factor=32,
+        tie_word_embeddings=tie_word_embeddings,
     )
-def llama3_2_3b() -> TransformerDecoder:
+def llama3_2_3b(
+    tie_word_embeddings: bool = True,
+) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 model initialized w/ the default 3b parameter values.
 
+    Args:
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
+
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 3B model
     """
@@ -55,6 +66,7 @@ def llama3_2_3b() -> TransformerDecoder:
         norm_eps=1e-5,
         rope_base=500_000,
         scale_factor=32,
+        tie_word_embeddings=tie_word_embeddings,
     )
 def lora_llama3_2_1b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
@@ -65,6 +77,7 @@ def lora_llama3_2_1b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
+    tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 1B model with LoRA enabled.
@@ -86,6 +99,7 @@ def lora_llama3_2_1b(
         use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
             introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
         quantize_base (bool): Whether to quantize base model weights
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
 
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 1B model with LoRA applied
@@ -110,6 +124,7 @@ def lora_llama3_2_1b(
         lora_dropout=lora_dropout,
         use_dora=use_dora,
         quantize_base=quantize_base,
+        tie_word_embeddings=tie_word_embeddings,
     )
 def lora_llama3_2_3b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
@@ -120,6 +135,7 @@ def lora_llama3_2_3b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
+    tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 3B model with LoRA enabled.
@@ -141,6 +157,7 @@ def lora_llama3_2_3b(
         use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
         quantize_base (bool): Whether to quantize base model weights
+        tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.
 
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied
@@ -166,6 +183,7 @@ def lora_llama3_2_3b(
         lora_dropout=lora_dropout,
         use_dora=use_dora,
         quantize_base=quantize_base,
+        tie_word_embeddings=tie_word_embeddings,
     )
 qlora_llama3_2_1b = partial(lora_llama3_2_1b, quantize_base=True)
 qlora_llama3_2_1b.__doc__ = """
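For reference, a minimal usage sketch of the builders touched by this change (not part of the commit itself). It assumes the builders are exposed under torchtune.models.llama3_2 as in released torchtune versions, and that "q_proj"/"v_proj" are valid LORA_ATTN_MODULES values; the new tie_word_embeddings flag defaults to True, so existing call sites keep the previous tied-embedding behavior.

# Hypothetical usage sketch; the import path and attention-module names are assumptions.
from torchtune.models.llama3_2 import llama3_2_1b, lora_llama3_2_1b, qlora_llama3_2_1b

# Default: input and output word embeddings stay tied, as before this change.
tied_model = llama3_2_1b()

# Opting out gives the decoder a separate (untied) output projection.
untied_model = llama3_2_1b(tie_word_embeddings=False)

# The LoRA builder simply forwards the flag to the underlying component builder.
lora_model = lora_llama3_2_1b(
    lora_attn_modules=["q_proj", "v_proj"],
    tie_word_embeddings=True,
)

# qlora_llama3_2_1b is a partial over lora_llama3_2_1b, so it accepts the same flag.
qlora_model = qlora_llama3_2_1b(
    lora_attn_modules=["q_proj", "v_proj"],
    tie_word_embeddings=True,
)

Defaulting the flag to True keeps the tied-embedding layout used by the released Llama 3.2 1B/3B checkpoints, while still letting callers opt into a separate output projection.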