@@ -222,31 +222,42 @@ def filter_weights(prefix: str, weights: dict):
222
222
num_kv_heads * head_dim , :]
223
223
v_weight = qkv_weight [hidden_size +
224
224
num_kv_heads * head_dim :, :]
225
- module .load_weights (weights = [
226
- {
227
- 'weight' : q_weight
228
- },
229
- {
230
- 'weight' : k_weight
231
- },
232
- {
233
- 'weight' : v_weight
234
- },
235
- ])
225
+
226
+ # Get the scale factor for the fused QKV projection
227
+ qkv_scale = module_weights .get ('weight_scale' , None )
228
+
229
+ q_dict = {'weight' : q_weight }
230
+ if qkv_scale is not None :
231
+ q_dict ['weight_scale' ] = qkv_scale
232
+
233
+ k_dict = {'weight' : k_weight }
234
+ if qkv_scale is not None :
235
+ k_dict ['weight_scale' ] = qkv_scale # Use same scale
236
+
237
+ v_dict = {'weight' : v_weight }
238
+ if qkv_scale is not None :
239
+ v_dict ['weight_scale' ] = qkv_scale # Use same scale
240
+
241
+ module .load_weights (weights = [q_dict , k_dict , v_dict ])
236
242
elif "mlp.gate_up_proj" in name :
237
243
# The weights need to be split correctly before sharding to support tp_size >1.
238
244
intermediate_size = self .config .intermediate_size
239
245
gate_up_weight = module_weights ['weight' ][:]
240
246
gate_weight = gate_up_weight [:intermediate_size , :]
241
247
up_weight = gate_up_weight [intermediate_size :, :]
242
- module .load_weights (weights = [
243
- {
244
- 'weight' : gate_weight
245
- },
246
- {
247
- 'weight' : up_weight
248
- },
249
- ])
248
+
249
+ # Get the scale factors if they exist
250
+ gate_up_scale = module_weights .get ('weight_scale' , None )
251
+
252
+ gate_dict = {'weight' : gate_weight }
253
+ if gate_up_scale is not None :
254
+ gate_dict ['weight_scale' ] = gate_up_scale
255
+
256
+ up_dict = {'weight' : up_weight }
257
+ if gate_up_scale is not None :
258
+ up_dict ['weight_scale' ] = gate_up_scale
259
+
260
+ module .load_weights (weights = [gate_dict , up_dict ])
250
261
else :
251
262
module .load_weights (weights = [module_weights ])
252
263
else :
0 commit comments