Commit 0a77122

Update sequence packing case when dummy PackedSeqParams are used (#2743)
1 parent f5d4c3a commit 0a77122

File tree

1 file changed: +2 -1 lines changed

megatron/rl/sequence_packing_utils.py

Lines changed: 2 additions & 1 deletion
@@ -413,7 +413,8 @@ def get_default_packed_seq_params(seq_length: int, device: torch.device) -> PackedSeqParams:
         PackedSeqParams configured as a single unpacked sequence.
     """
     # Single sequence spanning the full length = no actual packing
-    cu_seqlens = torch.tensor([0, seq_length], dtype=torch.int32, device=device)
+    cu_seqlens = torch.full((seq_length,), seq_length, dtype=torch.int32, device=device)
+    cu_seqlens[0] = 0
 
     return PackedSeqParams(
         qkv_format='thd',
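
For context on what the change does, here is a minimal standalone sketch (not the Megatron source; the helper names cu_seqlens_before and cu_seqlens_after are illustrative) contrasting the old and new cu_seqlens construction. Under the cumulative-sequence-length convention, repeated boundary values denote zero-length segments, so the new fixed-size tensor still encodes a single sequence spanning the full length.

```python
# Minimal sketch (not the Megatron implementation): contrasts how the "dummy"
# single-sequence cu_seqlens tensor is built before and after this commit.
# The helper names below are illustrative only.
import torch


def cu_seqlens_before(seq_length: int, device: torch.device) -> torch.Tensor:
    # Old construction: two boundaries -> one sequence covering [0, seq_length).
    return torch.tensor([0, seq_length], dtype=torch.int32, device=device)


def cu_seqlens_after(seq_length: int, device: torch.device) -> torch.Tensor:
    # New construction: a fixed-size tensor of length seq_length, [0, L, L, ..., L].
    # Repeated boundary values encode zero-length segments, so this still
    # describes a single sequence spanning the full length.
    cu_seqlens = torch.full((seq_length,), seq_length, dtype=torch.int32, device=device)
    cu_seqlens[0] = 0
    return cu_seqlens


if __name__ == "__main__":
    length, device = 8, torch.device("cpu")
    old = cu_seqlens_before(length, device)  # tensor([0, 8], dtype=torch.int32)
    new = cu_seqlens_after(length, device)   # tensor([0, 8, 8, 8, 8, 8, 8, 8], dtype=torch.int32)
    # Implied segment lengths: [8] before vs. [8, 0, 0, 0, 0, 0, 0] after.
    print(old.diff(), new.diff())
```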
