Skip to content

Commit 7334f93

Browse files
authored
[None][fix] Accommodate Phi3/4 to work with ModelOpt's FP8 ckpts in Torch (#6761)
Signed-off-by: Michal Guzek <[email protected]>
1 parent d26a5a9 commit 7334f93

File tree

1 file changed

+30
-19
lines changed

1 file changed

+30
-19
lines changed

tensorrt_llm/_torch/models/modeling_phi3.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -222,31 +222,42 @@ def filter_weights(prefix: str, weights: dict):
222222
num_kv_heads * head_dim, :]
223223
v_weight = qkv_weight[hidden_size +
224224
num_kv_heads * head_dim:, :]
225-
module.load_weights(weights=[
226-
{
227-
'weight': q_weight
228-
},
229-
{
230-
'weight': k_weight
231-
},
232-
{
233-
'weight': v_weight
234-
},
235-
])
225+
226+
# Get the scale factor for the fused QKV projection
227+
qkv_scale = module_weights.get('weight_scale', None)
228+
229+
q_dict = {'weight': q_weight}
230+
if qkv_scale is not None:
231+
q_dict['weight_scale'] = qkv_scale
232+
233+
k_dict = {'weight': k_weight}
234+
if qkv_scale is not None:
235+
k_dict['weight_scale'] = qkv_scale # Use same scale
236+
237+
v_dict = {'weight': v_weight}
238+
if qkv_scale is not None:
239+
v_dict['weight_scale'] = qkv_scale # Use same scale
240+
241+
module.load_weights(weights=[q_dict, k_dict, v_dict])
236242
elif "mlp.gate_up_proj" in name:
237243
# The weights need to be split correctly before sharding to support tp_size >1.
238244
intermediate_size = self.config.intermediate_size
239245
gate_up_weight = module_weights['weight'][:]
240246
gate_weight = gate_up_weight[:intermediate_size, :]
241247
up_weight = gate_up_weight[intermediate_size:, :]
242-
module.load_weights(weights=[
243-
{
244-
'weight': gate_weight
245-
},
246-
{
247-
'weight': up_weight
248-
},
249-
])
248+
249+
# Get the scale factors if they exist
250+
gate_up_scale = module_weights.get('weight_scale', None)
251+
252+
gate_dict = {'weight': gate_weight}
253+
if gate_up_scale is not None:
254+
gate_dict['weight_scale'] = gate_up_scale
255+
256+
up_dict = {'weight': up_weight}
257+
if gate_up_scale is not None:
258+
up_dict['weight_scale'] = gate_up_scale
259+
260+
module.load_weights(weights=[gate_dict, up_dict])
250261
else:
251262
module.load_weights(weights=[module_weights])
252263
else:

0 commit comments

Comments (0)