Commit 8451b0d

Integrate flex attention (#1193)
1 parent eb92658 commit 8451b0d

21 files changed: +817 -201 lines changed

recipes/full_finetune_distributed.py

Lines changed: 10 additions & 10 deletions

@@ -20,7 +20,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -227,7 +227,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._loss_fn = config.instantiate(cfg.loss)

         if self._compile:
-            training.compile_loss(self.loss_fn, verbose=self._is_rank_zero)
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
@@ -491,14 +491,14 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
             ),
         )
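The same collate dispatch is repeated in each recipe touched by this commit: unpacked datasets keep padded_collate_sft with the tokenizer's pad id and the loss's ignore index, while packed datasets now go through padded_collate_packed instead of falling back to the default collate. A minimal, hypothetical helper sketching that dispatch (the build_collate_fn name and argument list are illustrative, not part of the commit):

```python
from functools import partial
from typing import Callable

from torchtune.data import padded_collate_packed, padded_collate_sft


def build_collate_fn(packed: bool, pad_id: int, ignore_idx: int) -> Callable:
    """Hypothetical helper mirroring the dispatch added to the recipes above."""
    if not packed:
        # Unpacked SFT samples are padded to a common length per batch.
        return partial(padded_collate_sft, padding_idx=pad_id, ignore_idx=ignore_idx)
    # Packed samples are already a fixed length; the packed collate stacks them
    # and builds the attention mask (see the new tests below).
    return padded_collate_packed
```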

recipes/full_finetune_single_device.py

Lines changed: 9 additions & 9 deletions

@@ -18,7 +18,7 @@
 from torch.utils.data import DataLoader, DistributedSampler

 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -451,14 +451,14 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
             ),
         )

recipes/lora_finetune_distributed.py

Lines changed: 9 additions & 9 deletions

@@ -20,7 +20,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     DoRALinear,
@@ -559,14 +559,14 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
             ),
         )

recipes/lora_finetune_single_device.py

Lines changed: 9 additions & 9 deletions

@@ -19,7 +19,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -486,14 +486,14 @@ def _setup_data(
             dataset=ds,
             sampler=sampler,
             batch_size=batch_size,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
             ),
         )

recipes/qat_distributed.py

Lines changed: 9 additions & 9 deletions

@@ -21,7 +21,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -523,14 +523,14 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
             ),
         )

tests/torchtune/config/test_config_utils.py

Lines changed: 4 additions & 4 deletions

@@ -131,13 +131,13 @@ def test_log_config(self, capsys):
         with mock.patch(
             "torchtune.config._utils.get_logger", return_value=logger
         ), mock.patch(
-            "torchtune.config._utils.dist.is_available", return_value=True
+            "torchtune.utils.logging.dist.is_available", return_value=True
         ), mock.patch(
-            "torchtune.config._utils.dist.is_initialized", return_value=True
+            "torchtune.utils.logging.dist.is_initialized", return_value=True
         ):
             # Make sure rank 0 logs as expected
             with mock.patch(
-                "torchtune.config._utils.dist.get_rank",
+                "torchtune.utils.logging.dist.get_rank",
                 return_value=0,
             ):
                 log_config("test", cfg)
@@ -153,7 +153,7 @@ def test_log_config(self, capsys):

             # Make sure all other ranks do not log anything
             with mock.patch(
-                "torchtune.config._utils.dist.get_rank",
+                "torchtune.utils.logging.dist.get_rank",
                 return_value=1,
             ):
                 log_config("test", cfg)

tests/torchtune/data/test_collate.py

Lines changed: 118 additions & 0 deletions

@@ -6,14 +6,19 @@

 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+from unittest import mock
+
 import pytest
 import torch
+from tests.test_utils import gpu_test
 from torchtune.data import (
     left_pad_sequence,
     padded_collate,
     padded_collate_dpo,
+    padded_collate_packed,
     padded_collate_sft,
 )
+from torchtune.modules.attention_utils import _SUPPORTS_FLEX_ATTENTION


 class TestPaddedCollateSFT:
@@ -47,6 +52,119 @@ def test_batch_pad_sequence(self):
             padded_label, torch.tensor([10, ignore_idx, ignore_idx])
         )

+    @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False)
+    def test_padded_collate_packed_sdpa(self):
+        token_pairs = [
+            {
+                "tokens": torch.tensor([1, 2, 3, 4, 5, 6]),
+                "labels": torch.tensor([7, 8, 9, 10, 11, 12]),
+                "input_pos": torch.tensor([0, 1, 2, 0, 1, 0]),
+                "seq_lens": torch.tensor([3, 2, 1]),
+            },
+            {
+                "tokens": torch.tensor([13, 14, 15, 16, 17, 18]),
+                "labels": torch.tensor([19, 20, 21, 22, 23, 24]),
+                "input_pos": torch.tensor([0, 1, 0, 1, 0, 1]),
+                "seq_lens": torch.tensor([2, 2, 2]),
+            },
+        ]
+        collated = padded_collate_packed(
+            batch=token_pairs,
+        )
+        torch.testing.assert_close(
+            collated["tokens"],
+            torch.tensor([[1, 2, 3, 4, 5, 6], [13, 14, 15, 16, 17, 18]]),
+        )
+        torch.testing.assert_close(
+            collated["labels"],
+            torch.tensor([[7, 8, 9, 10, 11, 12], [19, 20, 21, 22, 23, 24]]),
+        )
+        torch.testing.assert_close(
+            collated["input_pos"],
+            torch.tensor([[0, 1, 2, 0, 1, 0], [0, 1, 0, 1, 0, 1]]),
+        )
+        torch.testing.assert_close(
+            collated["mask"],
+            torch.tensor(
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [0, 0, 0, 1, 0, 0],
+                        [0, 0, 0, 1, 1, 0],
+                        [0, 0, 0, 0, 0, 1],
+                    ],
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [0, 0, 1, 0, 0, 0],
+                        [0, 0, 1, 1, 0, 0],
+                        [0, 0, 0, 0, 1, 0],
+                        [0, 0, 0, 0, 1, 1],
+                    ],
+                ],
+                dtype=torch.bool,
+            ),
+        )
+
+    @pytest.mark.skipif(
+        not _SUPPORTS_FLEX_ATTENTION,
+        reason="Please install a nightly build of torch to run this test.",
+    )
+    @gpu_test(gpu_count=1)
+    def test_padded_collate_packed_flex(self):
+        # create_block_mask requires that seq_len be divisible by 128, the default block size.
+        # see https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L636
+        batch = [
+            {
+                "tokens": torch.arange(128, dtype=torch.long),
+                "labels": torch.arange(128, dtype=torch.long),
+                "input_pos": torch.arange(128, dtype=torch.long),
+                "seq_lens": torch.ones(64, dtype=torch.long) * 2,
+            },
+            {
+                "tokens": torch.arange(128, 256, dtype=torch.long),
+                "labels": torch.arange(128, 256, dtype=torch.long),
+                "input_pos": torch.arange(128, 256, dtype=torch.long),
+                "seq_lens": torch.ones(32, dtype=torch.long) * 4,
+            },
+        ]
+        collated = padded_collate_packed(
+            batch=batch,
+        )
+        torch.testing.assert_close(
+            collated["tokens"],
+            torch.stack(
+                [
+                    torch.arange(128, dtype=torch.long),
+                    torch.arange(128, 256, dtype=torch.long),
+                ]
+            ),
+        )
+        torch.testing.assert_close(
+            collated["labels"],
+            torch.stack(
+                [
+                    torch.arange(128, dtype=torch.long),
+                    torch.arange(128, 256, dtype=torch.long),
+                ]
+            ),
+        )
+        torch.testing.assert_close(
+            collated["input_pos"],
+            torch.stack(
+                [
+                    torch.arange(128, dtype=torch.long),
+                    torch.arange(128, 256, dtype=torch.long),
+                ]
+            ),
+        )
+        torch.testing.assert_close(
+            collated["mask"].to_dense(),
+            torch.tensor([[[[1]]], [[[1]]]], dtype=torch.int32, device="cuda"),
+        )
+

 class TestLeftPadSequence:
     def test_left_pad_sequence(self):
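The expected mask in test_padded_collate_packed_sdpa is block-causal: each packed document attends causally within itself and never across a document boundary, so the full mask is a block diagonal of lower-triangular blocks sized by seq_lens. A small sketch reproducing the first expected mask from those seq_lens (illustrative only; this is not claimed to be how padded_collate_packed builds the mask internally):

```python
import torch


def block_causal_mask(seq_lens: torch.Tensor) -> torch.Tensor:
    """Lay one lower-triangular (causal) block per packed document along the
    diagonal, so tokens only attend to earlier tokens of their own document."""
    blocks = [
        torch.tril(torch.ones(n, n, dtype=torch.int64)) for n in seq_lens.tolist()
    ]
    return torch.block_diag(*blocks).bool()


# Matches the first sample in the SDPA test above: three packed documents
# of lengths 3, 2, and 1 inside a length-6 sequence.
mask = block_causal_mask(torch.tensor([3, 2, 1]))
print(mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])
```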

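test_padded_collate_packed_flex only runs on a CUDA device with a recent enough PyTorch (_SUPPORTS_FLEX_ATTENTION), because flex attention represents the packed mask as a sparse BlockMask rather than a dense boolean tensor, which is why the test compares collated["mask"].to_dense(). A rough sketch of how a per-document block-causal mask could be built with torch.nn.attention.flex_attention.create_block_mask for the first sample's seq_lens, assuming a nightly PyTorch and a CUDA device; this is illustrative and may differ from what padded_collate_packed does internally:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Document id per token position: seq_lens = [2] * 64 packs 64 two-token
# documents into a 128-token sequence, as in the first sample of the test.
seq_lens = torch.ones(64, dtype=torch.long) * 2
document_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens).to("cuda")


def document_causal_mask(b, h, q_idx, kv_idx):
    # Allow attention only within the same document and only to past tokens.
    same_doc = document_ids[q_idx] == document_ids[kv_idx]
    causal = q_idx >= kv_idx
    return same_doc & causal


# Q_LEN/KV_LEN must be a multiple of the default 128 block size,
# which is why the test uses 128-token samples.
block_mask = create_block_mask(
    document_causal_mask, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cuda"
)
print(block_mask)
```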