Skip to content

Commit d354a07

Browse files
feat(megatron-lm): reduce extra qkv transpose in attn (#645)
Keep the input in SBHD layout to avoid extra q, k, v transposes in attention. Co-authored-by: RuibinCheung <ruibzhan@amd.com>
1 parent b2a561b commit d354a07

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -458,16 +458,20 @@ def forward(
458458
packed_seq_params: PackedSeqParams = None,
459459
):
460460
"""Forward."""
461+
SUPPORTED_QKV_FORMATS = "sbhd"
462+
461463
packed_seq_kwargs = (
462464
{key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
463465
if packed_seq_params is not None
464466
else {}
465467
)
466468

467469
qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
468-
assert qkv_format in ("sbhd", "bhsd"), "qkv_format only support bshd, but got {qkv_format}"
469-
if qkv_format == "sbhd":
470-
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
470+
assert (
471+
qkv_format in SUPPORTED_QKV_FORMATS
472+
), f"qkv_format only support {SUPPORTED_QKV_FORMATS}, but got {qkv_format}"
473+
# NOTE(ruibin): q, k and v arrive in (S, B, H, D) layout, but attn expects them in (B, S, H, D).
474+
query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]
471475
mask_type = attn_mask_type.name
472476
if mask_type == AttnMaskType.causal.name:
473477
causal = True
@@ -523,9 +527,10 @@ def forward(
523527
**self.attn_kwargs,
524528
)
525529

526-
o = o.reshape(o.shape[0], o.shape[1], -1).transpose(0, 1)
527-
if not o.is_contiguous():
528-
o = o.contiguous()
530+
# NOTE(ruibin): The output of attn is BSHD. Use permute to convert the layout to SBHD.
531+
o = o.permute(1, 0, 2, 3).contiguous()
532+
o = o.view(o.shape[0], o.shape[1], -1)
533+
529534
return o
530535

531536

0 commit comments

Comments
 (0)