@@ -497,3 +497,77 @@ def forward(self, in_idx):
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
+
+
+def assign(left, right, tensor_name="unknown"):
+    if left.shape != right.shape:
+        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
+
+    if isinstance(right, torch.Tensor):
+        return torch.nn.Parameter(right.clone().detach())
+    else:
+        return torch.nn.Parameter(torch.tensor(right))
+
+
512
+def load_weights_into_llama(model, param_config, params):
+    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+    for l in range(param_config["n_layers"]):
+
+        # Load attention weights
+        model.trf_blocks[l].att.W_query.weight = assign(
+            model.trf_blocks[l].att.W_query.weight,
+            params[f"model.layers.{l}.self_attn.q_proj.weight"],
+            f"model.layers.{l}.self_attn.q_proj.weight"
+        )
+        model.trf_blocks[l].att.W_key.weight = assign(
+            model.trf_blocks[l].att.W_key.weight,
+            params[f"model.layers.{l}.self_attn.k_proj.weight"],
+            f"model.layers.{l}.self_attn.k_proj.weight"
+        )
+        model.trf_blocks[l].att.W_value.weight = assign(
+            model.trf_blocks[l].att.W_value.weight,
+            params[f"model.layers.{l}.self_attn.v_proj.weight"],
+            f"model.layers.{l}.self_attn.v_proj.weight"
+        )
+        model.trf_blocks[l].att.out_proj.weight = assign(
+            model.trf_blocks[l].att.out_proj.weight,
+            params[f"model.layers.{l}.self_attn.o_proj.weight"],
+            f"model.layers.{l}.self_attn.o_proj.weight"
+        )
+        model.trf_blocks[l].norm1.weight = assign(
+            model.trf_blocks[l].norm1.weight,
+            params[f"model.layers.{l}.input_layernorm.weight"],
+            f"model.layers.{l}.input_layernorm.weight"
+        )
+
+        # Load FeedForward weights
+        model.trf_blocks[l].ff.fc1.weight = assign(
+            model.trf_blocks[l].ff.fc1.weight,
+            params[f"model.layers.{l}.mlp.gate_proj.weight"],
+            f"model.layers.{l}.mlp.gate_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc2.weight = assign(
+            model.trf_blocks[l].ff.fc2.weight,
+            params[f"model.layers.{l}.mlp.up_proj.weight"],
+            f"model.layers.{l}.mlp.up_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc3.weight = assign(
+            model.trf_blocks[l].ff.fc3.weight,
+            params[f"model.layers.{l}.mlp.down_proj.weight"],
+            f"model.layers.{l}.mlp.down_proj.weight"
+        )
+        model.trf_blocks[l].norm2.weight = assign(
+            model.trf_blocks[l].norm2.weight,
+            params[f"model.layers.{l}.post_attention_layernorm.weight"],
+            f"model.layers.{l}.post_attention_layernorm.weight"
+        )
+
+    # Load output layer weights
+    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")
+
+    if "lm_head.weight" in params.keys():
+        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
+    else:
+        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+        print("Model uses weight tying.")