Skip to content

Commit a1bcb97

Browse files
authored
[FIX] MM Eval Mask Sizes (#1920)
1 parent 4fb2464 commit a1bcb97

File tree

7 files changed

+39
-29
lines changed

7 files changed

+39
-29
lines changed

recipes/dev/generate_v2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ def generate(self, cfg: DictConfig):
152152
batch = {}
153153
if is_multimodal_input:
154154
batch = padded_collate_tiled_images_and_mask(
155-
[model_inputs], pad_direction="left", pad_max_images=1
155+
[model_inputs],
156+
pad_direction="left",
157+
pad_max_images=1,
158+
pad_max_tiles=self.model_transform.max_num_tiles,
156159
)
157160
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
158161
prompt = batch.pop("tokens").to(self._device)

recipes/eleuther_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def tok_batch_multimodal_encode(
187187
all_encoded_messages,
188188
pad_direction="left",
189189
pad_max_images=self._max_images_per_sample,
190+
pad_max_tiles=self._transform.max_num_tiles,
190191
)
191192
utils.batch_to_device(tok_batch, self.device)
192193

tests/torchtune/data/test_collate.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,41 @@ def test_batch_pad_sequence(self):
5656

5757

5858
class TestPaddedCollateTiledImagesAndMask:
59+
img_shape = 1, 1, 1
60+
tokens_per_tile = 5
61+
5962
@pytest.fixture
6063
def batch(self):
64+
c, h, w = self.img_shape
65+
s = self.tokens_per_tile
6166
return [
6267
{
6368
"tokens": [1, 2, 1, 3],
6469
"labels": [4, 5, 6, 7],
6570
"encoder_input": {
66-
"images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
71+
"images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
6772
"aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
6873
},
69-
"encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
74+
"encoder_mask": [torch.ones(4, s * 2), torch.ones(4, s * 3)],
7075
},
7176
{
7277
"tokens": [1, 4],
7378
"labels": [8, 9],
7479
"encoder_input": {
75-
"images": [torch.ones(4, 1, 1, 1)],
80+
"images": [torch.ones(4, c, h, w)],
7681
"aspect_ratio": [torch.tensor([2, 2])],
7782
},
78-
"encoder_mask": [torch.ones(2, 5 * 4)],
83+
"encoder_mask": [torch.ones(2, s * 4)],
7984
},
8085
]
8186

8287
def test_right_pad_sequence(self, batch):
8388
actual = padded_collate_tiled_images_and_mask(
8489
batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
8590
)
91+
imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
92+
seq_len = actual["encoder_mask"].shape[-1]
93+
assert imgs * tiles * self.tokens_per_tile == seq_len
8694

8795
mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
8896
mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
@@ -126,28 +134,36 @@ def test_left_pad_sequence(self, batch):
126134
ignore_idx=-100,
127135
pad_direction="left",
128136
pad_max_images=4,
137+
pad_max_tiles=5,
129138
)
139+
imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
140+
seq_len = actual["encoder_mask"].shape[-1]
141+
assert 5 * 4 * self.tokens_per_tile == seq_len
130142

131-
mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
132-
mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
143+
# pad 3 extra tiles
144+
mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1)
145+
# pad 2 extra tiles
146+
mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1)
147+
# Left pad text tokens
133148
mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
149+
mask_3 = F.pad(mask_3, (0, 5), value=0) # pad 5th tile
134150
sample_1 = torch.stack([mask_1, mask_2])
135-
sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
151+
sample_2 = torch.stack([mask_3, torch.zeros(4, 25)])
136152
expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
137-
expected_mask = F.pad(expected_mask, (0, 40), value=0)
153+
expected_mask = F.pad(expected_mask, (0, 50), value=0)
138154

139155
expected = {
140156
"tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
141157
"encoder_input": {
142158
"images": torch.tensor(
143159
[
144160
[
145-
[[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
146-
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
161+
[[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
162+
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
147163
],
148164
[
149-
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
150-
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
165+
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
166+
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
151167
],
152168
]
153169
),

tests/torchtune/modules/transforms/test_transforms.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111

1212
IMAGE_TOKEN_ID = 1
13-
MAX_NUM_TILES = 4
1413

1514

1615
class TestVisionCrossAttentionMask:
@@ -54,7 +53,6 @@ def cross_attn_mask_transform(self, tile_size, patch_size):
5453
tile_size=tile_size,
5554
patch_size=patch_size,
5655
image_token_id=IMAGE_TOKEN_ID,
57-
max_num_tiles=MAX_NUM_TILES,
5856
)
5957

6058
def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
@@ -89,7 +87,7 @@ def test_inference_call(
8987
sample.update(dummy_kwargs)
9088
actual = cross_attn_mask_transform(sample, inference=True)
9189
expected = [
92-
torch.zeros(len(tokens), image_num_tokens * 2, dtype=torch.bool)
90+
torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
9391
for _ in range(len(images))
9492
]
9593
expected[0][2:6, :image_num_tokens] = True

torchtune/data/_collate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ def padded_collate_tiled_images_and_mask(
426426
if pad_max_images is not None:
427427
_, _, img_seq = concat_masks.shape
428428
concat_masks = F.pad(
429-
concat_masks, (0, pad_max_images * image_seq_len - img_seq)
429+
concat_masks,
430+
(0, pad_max_images * max_num_tiles * tokens_per_tile - img_seq),
430431
)
431432

432433
batch_dict = {

torchtune/models/llama3_2_vision/_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ def __init__(
9393
tile_size=tile_size,
9494
patch_size=patch_size,
9595
image_token_id=self.tokenizer.image_id,
96-
max_num_tiles=max_num_tiles,
9796
)
9897

9998
self.stop_tokens = self.tokenizer.stop_tokens
10099
self.max_seq_len = max_seq_len
100+
self.max_num_tiles = max_num_tiles
101101
self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1)
102102
self.prompt_template = prompt_template
103103
self.pad_id = self.tokenizer.pad_id

torchtune/modules/transforms/_transforms.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, List, Mapping, Optional, Protocol
7+
from typing import Any, List, Mapping, Protocol
88

99
import torch
1010

@@ -57,21 +57,17 @@ class VisionCrossAttentionMask(Transform):
5757
E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
5858
with shape (40, 40) each.
5959
image_token_id (int): Token ID of the image special token.
60-
max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to
61-
pad mask during inference. Defaults to None
6260
"""
6361

6462
def __init__(
6563
self,
6664
tile_size: int,
6765
patch_size: int,
6866
image_token_id: int,
69-
max_num_tiles: Optional[int] = None,
7067
):
7168
patch_grid_size = tile_size // patch_size
7269
self.patches_per_tile = patch_grid_size**2
7370
self.image_token_id = image_token_id
74-
self.max_num_tiles = max_num_tiles
7571

7672
def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
7773
"""
@@ -163,9 +159,6 @@ def __call__(
163159
# which can vary based on number of tiles since they are not yet tile padded.
164160
# The masks are padded and concatenated together in the batch collator
165161
text_seq_len = len(tokens)
166-
max_image_size = None
167-
if inference and self.max_num_tiles is not None:
168-
max_image_size = self.max_num_tiles * (self.patches_per_tile + 1)
169162
masks = []
170163
for image_num, interval in enumerate(intervals):
171164
# Identify what part of text sequence should be attended
@@ -178,9 +171,7 @@ def __call__(
178171
# to a single image, so text tokens attend to all the image's tokens.
179172
# The mask is text_seq_len x mask_image_size if defined, otherwise
180173
# it uses current text/image sequence lengths.
181-
mask = torch.zeros(
182-
text_seq_len, max_image_size or image_seq_len, dtype=torch.bool
183-
)
174+
mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
184175
mask[start:end, :image_seq_len] = True
185176
masks.append(mask)
186177

0 commit comments

Comments
 (0)