@@ -435,10 +435,16 @@ def unpermute(w, n_heads, dim1, dim2):
         return checkpoint
 
     @classmethod
-    def from_pretrained_llama3_hf(cls, model_id):
+    def from_pretrained_llama3_hf(cls, model_id, untie):
         """Loads pretrained LLaMA model weights from HuggingFace"""
         from transformers import AutoModelForCausalLM, AutoTokenizer
         model_args = MODEL_DICT[model_id]
+        if untie:
+            if not model_args.tied_embeddings:
+                print("Model embeddings are not tied, --untie has no effect.")
+            else:
+                print("Untying token embeddings and LM head.")
+                model_args.tied_embeddings = False
 
         model = AutoModelForCausalLM.from_pretrained(model_id)
         checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)
@@ -1026,6 +1032,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
     parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
     parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model")
+    parser.add_argument("--untie", type=int, default=0, help="untie token embeddings and LM head, even if they are tied in the checkpoint")
     # token layout for each step of the optimization
     parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -1131,7 +1138,7 @@ def print0(*args, **kwargs):
 
     # init the model
     if args.use_hf:
-        model = LLaMA.from_pretrained_llama3_hf(args.model)
+        model = LLaMA.from_pretrained_llama3_hf(args.model, args.untie)
     else:  # use Meta's checkpoint
         assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
         assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
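
For context (not part of the commit): a minimal, self-contained sketch of what untying means here, assuming a PyTorch-style model in which a `tied_embeddings` flag controls whether the token embedding and the LM head share a single weight tensor. The `TinyLM`, `wte`, and `lm_head` names below are illustrative only, not the repo's actual classes.

```python
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy model illustrating tied vs. untied embedding / LM-head weights."""
    def __init__(self, vocab_size, dim, tied_embeddings=True):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)            # token embedding table
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if tied_embeddings:
            # tied: both modules point at the same Parameter, so gradients
            # and updates affect one shared matrix
            self.lm_head.weight = self.wte.weight
        # untied: the two matrices start out independent; copy the checkpoint
        # weights into both if you want them to begin equal and then diverge

    def forward(self, idx):
        return self.lm_head(self.wte(idx))

tied = TinyLM(vocab_size=128, dim=16, tied_embeddings=True)
untied = TinyLM(vocab_size=128, dim=16, tied_embeddings=False)
print(tied.lm_head.weight is tied.wte.weight)      # True  -> one shared tensor
print(untied.lm_head.weight is untied.wte.weight)  # False -> trained separately
```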