
Commit 6a8a027

Remove pad_max_tiles in CLIP (#1836)
1 parent a6fd945 commit 6a8a027

File tree

8 files changed: +19 -43 lines changed


recipes/configs/llama3_2_vision/11B_full.yaml

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ tokenizer:
   _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
   path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
   image_size: 560
+  max_seq_len: 8192
 
 # Checkpointer
 checkpointer:

recipes/configs/llama3_2_vision/11B_full_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ tokenizer:
   _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
   path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
   image_size: 560
+  max_seq_len: 8192
 
 # Checkpointer
 checkpointer:

recipes/configs/llama3_2_vision/11B_lora.yaml

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ tokenizer:
   _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
   path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
   image_size: 560
+  max_seq_len: 8192
 
 # Checkpointer
 checkpointer:

recipes/configs/llama3_2_vision/11B_lora_single_device.yaml

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ tokenizer:
   _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
   path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
   image_size: 560
+  max_seq_len: 8192
 
 # Checkpointer
 checkpointer:
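
All four configs pin the same field. As a rough illustration, the updated tokenizer block resolves to a Python call like the one below (a sketch only; the keyword names are taken from the YAML keys above and are assumed to match the builder's signature):

from torchtune.models.llama3_2_vision import llama3_2_vision_transform

# Mirrors the config block above: same tokenizer path and image size,
# with the newly added max_seq_len cap.
transform = llama3_2_vision_transform(
    path="/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model",
    image_size=560,
    max_seq_len=8192,
)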

tests/torchtune/models/clip/test_clip_image_transform.py

Lines changed: 4 additions & 36 deletions

@@ -37,17 +37,6 @@ class TestCLIPImageTransform:
                 "expected_tile_max": [1.0, 1.0],
                 "expected_tile_min": [0.0, 0.0],
                 "expected_aspect_ratio": [1, 2],
-                "pad_max_tiles": False,
-            },
-            {
-                "image_size": (100, 400, 3),
-                "expected_shape": torch.Size([4, 3, 224, 224]),
-                "resize_to_max_canvas": False,
-                "expected_tile_means": [0.2230, 0.1763, 0.0, 0.0],
-                "expected_tile_max": [1.0, 1.0, 0.0, 0.0],
-                "expected_tile_min": [0.0, 0.0, 0.0, 0.0],
-                "expected_aspect_ratio": [1, 2],
-                "pad_max_tiles": True,
             },
             {
                 "image_size": (1000, 300, 3),
@@ -57,7 +46,6 @@ class TestCLIPImageTransform:
                 "expected_tile_max": [0.9705, 0.9694, 0.9521, 0.9314],
                 "expected_tile_min": [0.0353, 0.0435, 0.0528, 0.0],
                 "expected_aspect_ratio": [4, 1],
-                "pad_max_tiles": False,
             },
             {
                 "image_size": (200, 200, 3),
@@ -67,7 +55,6 @@ class TestCLIPImageTransform:
                 "expected_tile_max": [0.9922, 0.9926, 0.9970, 0.9908],
                 "expected_tile_min": [0.0056, 0.0069, 0.0059, 0.0033],
                 "expected_aspect_ratio": [2, 2],
-                "pad_max_tiles": False,
                 "pad_tiles": 1,
             },
             {
@@ -78,17 +65,6 @@ class TestCLIPImageTransform:
                 "expected_tile_max": [1.0, 1.0, 1.0],
                 "expected_tile_min": [0.0, 0.0, 0.0],
                 "expected_aspect_ratio": [3, 1],
-                "pad_max_tiles": False,
-            },
-            {
-                "image_size": (600, 200, 3),
-                "expected_shape": torch.Size([4, 3, 224, 224]),
-                "resize_to_max_canvas": False,
-                "expected_tile_means": [0.4473, 0.4469, 0.3032, 0.0],
-                "expected_tile_max": [1.0, 1.0, 1.0, 0.0],
-                "expected_tile_min": [0.0, 0.0, 0.0, 0.0],
-                "expected_aspect_ratio": [3, 1],
-                "pad_max_tiles": True,
             },
         ],
     )
@@ -103,7 +79,6 @@ def test_clip_image_transform(self, params):
             resample="bilinear",
             dtype=torch.float32,
             resize_to_max_canvas=params["resize_to_max_canvas"],
-            pad_max_tiles=params["pad_max_tiles"],
         )
 
         image_transform_inference = CLIPImageTransformInference(
@@ -115,7 +90,6 @@ def test_clip_image_transform(self, params):
             resample="bilinear",
             resize_to_max_canvas=params["resize_to_max_canvas"],
             antialias=True,
-            pad_max_tiles=params["pad_max_tiles"],
         )
 
         # Generate a deterministic image using np.arange for reproducibility
@@ -169,13 +143,7 @@ def test_clip_image_transform(self, params):
         ), f"Expected aspect ratio {params['expected_aspect_ratio']} but got {tuple(output_ar.numpy())}"
 
         # number of tiles matches the product of the aspect ratio
-        if params["pad_max_tiles"]:
-            # max_num_tiles=4.
-            assert (
-                4 == output_image.shape[0]
-            ), f"Expected 4 tiles but got {output_image.shape[0]}"
-        else:
-            expected_num_tiles = output_ar[0] * output_ar[1]
-            assert (
-                expected_num_tiles == output_image.shape[0]
-            ), f"Expected {expected_num_tiles} tiles but got {output_image.shape[0]}"
+        expected_num_tiles = output_ar[0] * output_ar[1]
+        assert (
+            expected_num_tiles == output_image.shape[0]
+        ), f"Expected {expected_num_tiles} tiles but got {output_image.shape[0]}"

torchtune/data/_collate.py

Lines changed: 11 additions & 0 deletions

@@ -222,6 +222,7 @@ def padded_collate_tiled_images_and_mask(
     padding_idx: int = 0,
     ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
     pad_direction: str = "right",
+    pad_max_tiles: Optional[int] = None,
     pad_max_images: Optional[int] = None,
 ) -> Dict[str, torch.Tensor]:
     """Pad a batch of text sequences, tiled image tensors, aspect ratios,
@@ -259,6 +260,8 @@ def padded_collate_tiled_images_and_mask(
             :func:`torch.nn.utils.rnn.pad_sequence`, otherwise if ``pad_direction="left"``,
             we use :func:`torchtune.data.left_pad_sequence`. For training, we typically want to pad from the right.
             For inference, we typically want to pad from the left. Defaults to "right".
+        pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles
+            in the batch. Defaults to None.
         pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
             in the batch. Defaults to None.
 
@@ -272,6 +275,7 @@ def padded_collate_tiled_images_and_mask(
 
     Raises:
         ValueError: if ``pad_direction`` is not one of "left" or "right".
+        ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image.
 
     Example:
         >>> image_id = 1
@@ -355,6 +359,13 @@ def padded_collate_tiled_images_and_mask(
         for sample in batch
         for image in sample["encoder_input"]["images"]
     )
+    if pad_max_tiles is not None:
+        if pad_max_tiles < max_num_tiles:
+            raise ValueError(
+                f"More tiles in image {max_num_tiles}, than pad_max_tiles {pad_max_tiles}"
+            )
+        max_num_tiles = pad_max_tiles
+
     # Second loop: pad images and masks to max number of tiles, max text seq len in batch
     batch_images = []
     batch_masks = []
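
With this change, padding every image to a fixed tile count is requested at collation time instead of inside the image transform. A minimal sketch of wiring the new argument into a DataLoader (the dataset and batch size are placeholders; the function and argument names come from the diff above, and the import path is assumed from the module location):

from functools import partial

from torch.utils.data import DataLoader
from torchtune.data import padded_collate_tiled_images_and_mask

# Pad every image in the batch to 4 tiles. If any image already has more
# than 4 tiles, the check added above raises a ValueError.
collate_fn = partial(
    padded_collate_tiled_images_and_mask,
    pad_direction="right",
    pad_max_tiles=4,
)

# dataset is a placeholder for any dataset yielding the expected sample dicts.
# loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)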

torchtune/models/clip/_transform.py

Lines changed: 0 additions & 6 deletions

@@ -15,7 +15,6 @@
     find_supported_resolutions,
     get_canvas_best_fit,
 )
-from torchtune.modules.transforms.vision_utils.pad_dim_to_size import pad_dim_to_size
 from torchtune.modules.transforms.vision_utils.resize_with_pad import resize_with_pad
 from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop
 
@@ -63,7 +62,6 @@ class CLIPImageTransform:
             This will be used to generate possible_resolutions,
             e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
             Default 4.
-        pad_max_tiles (bool): If True, the image will be padded to have tiles == max_num_tiles. Default False.
         dtype (torch.dtype): Data type of the output image. Default torch.bfloat16.
         resample (str): Resampling method used when resizing images. Supports any enum of
             ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
@@ -101,7 +99,6 @@ def __init__(
         possible_resolutions: Optional[List[Tuple[int, int]]] = None,
         tile_size: int = 224,
         max_num_tiles: Optional[int] = 4,
-        pad_max_tiles: bool = False,
         dtype: torch.dtype = torch.bfloat16,
         resample: str = "bilinear",
         resize_to_max_canvas: bool = False,
@@ -142,7 +139,6 @@ def __init__(
         # tile_crop
         self.tile_size = tile_size
         self.tile_crop = tile_crop
-        self.pad_tile_size = max_num_tiles if pad_max_tiles else None
 
     def __call__(
         self, sample: Mapping[str, Any], inference: bool = False
@@ -190,8 +186,6 @@ def __call__(
 
         # Divide the image into equally sized tiles
         image = self.tile_crop(image=image, tile_size=self.tile_size)
-        if self.pad_tile_size:
-            image = pad_dim_to_size(image, size=self.pad_tile_size, dim=0)
 
         aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size
 
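
After this change the transform always returns exactly as many tiles as the best-fit canvas implies (the product of the aspect ratio); callers who need a fixed tile count pad at collation time instead. A short sketch of constructing the transform with the arguments that remain (values are illustrative, not prescriptive):

import torch

from torchtune.models.clip._transform import CLIPImageTransform

# Only arguments still present after this commit; pad_max_tiles is gone.
transform = CLIPImageTransform(
    tile_size=224,
    max_num_tiles=4,
    dtype=torch.float32,
    resample="bilinear",
    resize_to_max_canvas=False,
)
# e.g. a best-fit canvas of 448x224 now yields 2 tiles (aspect ratio [2, 1])
# rather than being padded up to max_num_tiles.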

torchtune/models/llama3_2_vision/_transform.py

Lines changed: 0 additions & 1 deletion

@@ -86,7 +86,6 @@ def __init__(
             tile_size=tile_size,
             possible_resolutions=None,
             max_num_tiles=max_num_tiles,
-            pad_max_tiles=True,
             resample="bilinear",
             resize_to_max_canvas=False,
         )
