
Commit aac3ace

joecummings authored and maximegmd committed
Add 'on-the-fly' sample packing (meta-pytorch#1109)
1 parent efbab6f commit aac3ace

File tree: 5 files changed (+196, −146 lines)


tests/torchtune/datasets/test_packed_dataset.py

Lines changed: 41 additions & 19 deletions
@@ -77,9 +77,29 @@ def _get_expected_mask_and_input_pos(
 
         return mask[:max_seq_len, :max_seq_len], torch.tensor(input_pos[:max_seq_len])
 
+    def _calculate_num_packs(
+        self, dataset_size, max_seq_len, sample_size, split_across_pack, max_packs
+    ):
+        # First see how many samples we can fit in a single pack
+        num_samples_per_pack, remainder = divmod(max_seq_len, sample_size)
+
+        # If we split across pack (and the samples don't fit perfectly in max_seq_len), we can fit more
+        if split_across_pack and remainder > 0:
+            # Now we need the fractional to see how many we can partially fit in each pack
+            num_samples_per_pack = max_seq_len / sample_size
+
+        # If we don't split across pack, we will need more packs
+        num_packs, remainder = divmod(dataset_size, num_samples_per_pack)
+
+        # If there's leftover, we need to add one more pack
+        if remainder > 0:
+            num_packs += 1
+
+        return num_packs if num_packs < max_packs else max_packs
+
     @pytest.mark.parametrize("max_seq_len", [25])
     @pytest.mark.parametrize("sample_size", [2, 5])
-    @pytest.mark.parametrize("max_packs", [5])
+    @pytest.mark.parametrize("max_packs", [5, 200])
     @pytest.mark.parametrize("split_across_pack", [True, False])
     def test_packed_dataset(
         self, max_seq_len, sample_size, max_packs, split_across_pack
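For intuition, here is a worked instance of the pack-count arithmetic that the _calculate_num_packs helper above encodes. It is a standalone sketch with hypothetical sizes (dataset_size=100, max_seq_len=25, sample_size=2), not part of the commit:

def calculate_num_packs(dataset_size, max_seq_len, sample_size, split_across_pack, max_packs):
    # Whole samples that fit in one pack
    num_samples_per_pack, remainder = divmod(max_seq_len, sample_size)
    if split_across_pack and remainder > 0:
        # Samples may straddle pack boundaries, so the per-pack count can be fractional
        num_samples_per_pack = max_seq_len / sample_size
    num_packs, remainder = divmod(dataset_size, num_samples_per_pack)
    if remainder > 0:
        num_packs += 1  # leftover samples spill into one more pack
    return num_packs if num_packs < max_packs else max_packs

# 25 // 2 = 12 whole samples per pack -> ceil(100 / 12) = 9 packs
assert calculate_num_packs(100, 25, 2, split_across_pack=False, max_packs=200) == 9
# splitting lets 25 / 2 = 12.5 samples fit -> 100 / 12.5 = exactly 8 packs
assert calculate_num_packs(100, 25, 2, split_across_pack=True, max_packs=200) == 8
# the cap always wins when it is smaller
assert calculate_num_packs(100, 25, 2, split_across_pack=False, max_packs=5) == 5
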
@@ -91,8 +111,13 @@ def test_packed_dataset(
             max_packs=max_packs,
             split_across_pack=split_across_pack,
         )
+
         # Check we get right number of packs
-        assert len(packed) == max_packs
+        correct_num_packs = self._calculate_num_packs(
+            len(dataset), max_seq_len, sample_size, split_across_pack, max_packs
+        )
+        assert len(packed) == correct_num_packs
+
         # Check all fields are same length
         assert (
             len(packed[0]["tokens"])
@@ -105,15 +130,15 @@ def test_packed_dataset(
         if split_across_pack:
             # If we split samples, we'll know how many samples by taking the
             # full length and dividing by sample size
-            last_index, remainder = divmod(max_packs * max_seq_len, sample_size)
+            last_index, remainder = divmod(len(packed) * max_seq_len, sample_size)
             # Account for remaining sample that didn't fit in window
             last_index = last_index if remainder > 0 else last_index - 1
         else:
            # If we don't split samples, we know how many samples by taking
            # how much fits in a single window and multiplying by max rows.
            # If there is a remainder, this will end up being a pad token.
            last_index = (
-                (max_seq_len // sample_size) * max_packs - 1
+                (max_seq_len // sample_size) * len(packed) - 1
                 if max_seq_len % sample_size == 0
                 else 0
            )
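As a concrete check of the last_index bookkeeping above, using the parametrized max_seq_len=25 and sample_size=2 and a hypothetical pack count of 8:

max_seq_len, sample_size, num_packs = 25, 2, 8  # num_packs is hypothetical

# split_across_pack=True: every token slot is filled, so samples consumed = total tokens / sample_size
last_index, remainder = divmod(num_packs * max_seq_len, sample_size)  # (100, 0)
last_index = last_index if remainder > 0 else last_index - 1          # -> 99

# split_across_pack=False: 25 % 2 != 0, so the test falls back to checking index 0
last_index = (max_seq_len // sample_size) * num_packs - 1 if max_seq_len % sample_size == 0 else 0  # -> 0
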
@@ -207,11 +232,11 @@ def test_packed_dataset_real_data(self):
 
     def test_pad_pack(self):
         padding_idx = -8
-        ignore_idx = -9
+        ignore_idx = -100  # Same as CROSS_ENTROPY_IGNORE_IDX
         pack = {
             "tokens": [2, 5],
             "labels": [3, 7],
-            "mask": torch.tensor([[True, False], [True, True]]),
+            "seq_lens": [1, 1],
             # Let the first token be the end of the previous sample (pos 8),
             # and the second token the start of the next sample (pos 0). Collate
             # should continue from 0 -> 1, 2, ...
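The pack no longer carries an explicit attention mask; it records per-sample lengths in "seq_lens" instead. As a rough illustration of why that is sufficient (and not necessarily how PackedDataset itself builds its mask), a block-diagonal causal mask can be reconstructed from those lengths:

import torch

def mask_from_seq_lens(seq_lens):
    # Each sample attends causally within itself and never across sample boundaries.
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in seq_lens:
        mask[start : start + n, start : start + n] = torch.tril(torch.ones(n, n, dtype=torch.bool))
        start += n
    return mask

mask_from_seq_lens([1, 1])  # -> [[True, False], [False, True]]
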
@@ -224,11 +249,11 @@ def test_pad_pack(self):
             max_seq_len=4,
         )
 
-        padded = packed._pad_pack(pack, padding_idx=padding_idx, ignore_idx=ignore_idx)
+        pack = packed._convert_to_tensors(pack)
+        padded = packed._pad_pack(pack, padding_idx=padding_idx)
 
         padded_input = padded["tokens"]
         padded_label = padded["labels"]
-        padded_mask = padded["mask"]
         padded_input_pos = padded["input_pos"]
 
         torch.testing.assert_close(
@@ -237,15 +262,12 @@
         torch.testing.assert_close(
             padded_label, torch.tensor([3, 7, ignore_idx, ignore_idx])
         )
-        assert torch.equal(
-            padded_mask,
-            torch.tensor(
-                [
-                    [True, False, False, False],
-                    [True, True, False, False],
-                    [False, False, True, False],
-                    [False, False, False, True],
-                ]
-            ),
-        )
         torch.testing.assert_close(padded_input_pos, torch.tensor([8, 0, 1, 2]))
+
+    def test_pack_errors_if_sample_too_long(self):
+        dataset = DummyDataset(8)
+        with pytest.raises(ValueError, match="Dataset sample is too long"):
+            PackedDataset(
+                dataset,
+                max_seq_len=4,
+            )
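A minimal sketch of the padding behaviour these assertions exercise, reusing the toy pack from test_pad_pack; pad_pack below is an illustrative stand-in, not the library's _pad_pack:

import torch
import torch.nn.functional as F

CROSS_ENTROPY_IGNORE_IDX = -100

def pad_pack(pack, padding_idx, max_seq_len):
    num_pad = max_seq_len - len(pack["tokens"])
    tokens = F.pad(pack["tokens"], (0, num_pad), value=padding_idx)
    labels = F.pad(pack["labels"], (0, num_pad), value=CROSS_ENTROPY_IGNORE_IDX)
    # Positions keep counting up from wherever the last (possibly split) sample stopped.
    last_pos = int(pack["input_pos"][-1])
    input_pos = torch.cat([pack["input_pos"], torch.arange(last_pos + 1, last_pos + 1 + num_pad)])
    return {"tokens": tokens, "labels": labels, "input_pos": input_pos}

padded = pad_pack(
    {
        "tokens": torch.tensor([2, 5]),
        "labels": torch.tensor([3, 7]),
        "input_pos": torch.tensor([8, 0]),
    },
    padding_idx=-8,
    max_seq_len=4,
)
# tokens -> [2, 5, -8, -8]; labels -> [3, 7, -100, -100]; input_pos -> [8, 0, 1, 2]

Padded label positions get CROSS_ENTROPY_IGNORE_IDX so they drop out of the loss, while input_pos keeps counting so position ids stay monotonically increasing within the pack.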

torchtune/datasets/_chat.py

Lines changed: 5 additions & 1 deletion
@@ -185,4 +185,8 @@ def chat_dataset(
         train_on_input=train_on_input,
         **load_dataset_kwargs,
     )
-    return PackedDataset(ds, max_seq_len=max_seq_len) if packed else ds
+    return (
+        PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id)
+        if packed
+        else ds
+    )

torchtune/datasets/_instruct.py

Lines changed: 5 additions & 1 deletion
@@ -177,4 +177,8 @@ def instruct_dataset(
         max_seq_len=max_seq_len,
         **load_dataset_kwargs,
     )
-    return PackedDataset(ds, max_seq_len=max_seq_len) if packed else ds
+    return (
+        PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id)
+        if packed
+        else ds
+    )
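Both builders now forward the tokenizer's pad id, so packs are padded with the model's actual pad token rather than PackedDataset's default. A hedged sketch of that shared packed path (ds, tokenizer, max_seq_len, and packed are placeholders; the import path is assumed):

from torchtune.datasets import PackedDataset  # import path assumed

def maybe_pack(ds, tokenizer, max_seq_len, packed):
    # Mirrors the return statements now used by chat_dataset and instruct_dataset.
    return (
        PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id)
        if packed
        else ds
    )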
