Skip to content

Commit 7fb6c3e

Browse files
committed
removed pad_max_tiles
1 parent 4107cc4 commit 7fb6c3e

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

torchtune/data/_collate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def padded_collate_tiled_images_and_mask(
222222
padding_idx: int = 0,
223223
ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
224224
pad_direction: str = "right",
225+
pad_max_tiles: Optional[int] = None,
225226
pad_max_images: Optional[int] = None,
226227
) -> Dict[str, torch.Tensor]:
227228
"""Pad a batch of text sequences, tiled image tensors, aspect ratios,
@@ -259,6 +260,8 @@ def padded_collate_tiled_images_and_mask(
259260
:func:`torch.nn.utils.rnn.pad_sequence`, otherwise if ``pad_direction="left"``,
260261
we use :func:`torchtune.data.left_pad_sequence`. For training, we typically want to pad from the right.
261262
For inference, we typically want to pad from the left. Defaults to "right".
263+
pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles
264+
in the batch. Defaults to None.
262265
pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
263266
in the batch. Defaults to None.
264267
@@ -272,6 +275,7 @@ def padded_collate_tiled_images_and_mask(
272275
273276
Raises:
274277
ValueError: if ``pad_direction`` is not one of "left" or "right".
278+
ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image.
275279
276280
Example:
277281
>>> image_id = 1
@@ -355,6 +359,13 @@ def padded_collate_tiled_images_and_mask(
355359
for sample in batch
356360
for image in sample["encoder_input"]["images"]
357361
)
362+
if pad_max_tiles is not None:
363+
if pad_max_tiles < max_num_tiles:
364+
raise ValueError(
365+
f"More tiles in image {max_num_tiles}, than pad_max_tiles {pad_max_tiles}"
366+
)
367+
max_num_tiles = pad_max_tiles
368+
358369
# Second loop: pad images and masks to max number of tiles, max text seq len in batch
359370
batch_images = []
360371
batch_masks = []

torchtune/models/clip/_transform.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
find_supported_resolutions,
1616
get_canvas_best_fit,
1717
)
18-
from torchtune.modules.transforms.vision_utils.pad_dim_to_size import pad_dim_to_size
1918
from torchtune.modules.transforms.vision_utils.resize_with_pad import resize_with_pad
2019
from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop
2120

@@ -63,7 +62,6 @@ class CLIPImageTransform:
6362
This will be used to generate possible_resolutions,
6463
e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
6564
Default 4.
66-
pad_max_tiles (bool): If True, the image will be padded to have tiles == max_num_tiles. Default False.
6765
dtype (torch.dtype): Data type of the output image. Default torch.bfloat16.
6866
resample (str): Resampling method used when resizing images. Supports any enum of
6967
``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
@@ -101,7 +99,6 @@ def __init__(
10199
possible_resolutions: Optional[List[Tuple[int, int]]] = None,
102100
tile_size: int = 224,
103101
max_num_tiles: Optional[int] = 4,
104-
pad_max_tiles: bool = False,
105102
dtype: torch.dtype = torch.bfloat16,
106103
resample: str = "bilinear",
107104
resize_to_max_canvas: bool = False,
@@ -142,7 +139,6 @@ def __init__(
142139
# tile_crop
143140
self.tile_size = tile_size
144141
self.tile_crop = tile_crop
145-
self.pad_tile_size = max_num_tiles if pad_max_tiles else None
146142

147143
def __call__(
148144
self, sample: Mapping[str, Any], inference: bool = False
@@ -190,8 +186,6 @@ def __call__(
190186

191187
# Divide the image into equally sized tiles
192188
image = self.tile_crop(image=image, tile_size=self.tile_size)
193-
if self.pad_tile_size:
194-
image = pad_dim_to_size(image, size=self.pad_tile_size, dim=0)
195189

196190
aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size
197191

torchtune/models/llama3_2_vision/_transform.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def __init__(
8686
tile_size=tile_size,
8787
possible_resolutions=None,
8888
max_num_tiles=max_num_tiles,
89-
pad_max_tiles=True,
9089
resample="bilinear",
9190
resize_to_max_canvas=False,
9291
)

0 commit comments

Comments
 (0)