@@ -458,16 +458,20 @@ def forward(
458458 packed_seq_params : PackedSeqParams = None ,
459459 ):
460460 """Forward."""
461+ SUPPORTED_QKV_FORMATS = "sbhd"
462+
461463 packed_seq_kwargs = (
462464 {key : getattr (packed_seq_params , key ) for key in self .kept_packed_seq_params }
463465 if packed_seq_params is not None
464466 else {}
465467 )
466468
467469 qkv_format = packed_seq_kwargs .get ("qkv_format" , self .qkv_format )
468- assert qkv_format in ("sbhd" , "bhsd" ), "qkv_format only support bshd, but got {qkv_format}"
469- if qkv_format == "sbhd" :
470- query , key , value = [x .transpose (0 , 1 ).contiguous () for x in (query , key , value )]
470+ assert (
471+ qkv_format in SUPPORTED_QKV_FORMATS
472+ ), f"qkv_format only support { SUPPORTED_QKV_FORMATS } , but got { qkv_format } "
473+ # NOTE(ruibin): The layout of q, k and v is (S, B, H, D). But attn accept the shape of qkv is (B, S, H, D).
474+ query , key , value = [x .permute (1 , 0 , 2 , 3 ) for x in (query , key , value )]
471475 mask_type = attn_mask_type .name
472476 if mask_type == AttnMaskType .causal .name :
473477 causal = True
@@ -523,9 +527,10 @@ def forward(
523527 ** self .attn_kwargs ,
524528 )
525529
526- o = o .reshape (o .shape [0 ], o .shape [1 ], - 1 ).transpose (0 , 1 )
527- if not o .is_contiguous ():
528- o = o .contiguous ()
530+ # NOTE(ruibin): The output of attn is BSHD. Use permute to convert the layout to SBHD.
531+ o = o .permute (1 , 0 , 2 , 3 ).contiguous ()
532+ o = o .view (o .shape [0 ], o .shape [1 ], - 1 )
533+
529534 return o
530535
531536
0 commit comments