Skip to content

Commit bdf4108

Browse files
2ez4bz authored and dominicshanshan committed
[None][fix] Fix batching bug in Mistral3 model (NVIDIA#6841)
Prior to this commit, if multiple requests with images were in the same batch, the batching logic for the images would fail. This commit fixes that and adds unit tests that were verified to fail prior to the fix.

Signed-off-by: William Zhang <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent 5d1249b commit bdf4108

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -387,7 +387,7 @@ def forward(
387387
f"Expected as many `pixel_values` ({len(pixel_values)}) and "
388388
f"`image_sizes` ({len(image_sizes)}) as number of multimodal parameters "
389389
f"({multimodal_params_len}).")
390-
batched_pixel_values, batched_image_sizes = self._batch_pixel_values(
390+
batched_pixel_values, batched_image_sizes = self.batch_pixel_values(
391391
pixel_values=pixel_values, image_sizes=image_sizes)
392392
mm_embeds = [
393393
self._get_image_features(pixel_values=batched_pixel_values,
@@ -454,21 +454,38 @@ def _get_image_features(
454454
# (the transformers one expected numpy arrays).
455455
@staticmethod
456456
@torch.inference_mode()
457-
def _batch_pixel_values(
457+
def batch_pixel_values(
458458
pixel_values: List[torch.Tensor],
459459
image_sizes: List[torch.Tensor],
460460
) -> tuple[torch.Tensor, torch.Tensor]:
461+
# NOTES:
462+
# * `pixel_values` is a list of `[B_idx, C, H_idx, W_idx]` tensors, i.e. a batch of images as
463+
# padded + batched by the input processor.
464+
# The height (H_idx) and width (W_idx) of each element need not coincide.
465+
# * Similarly, each element in `image_sizes` describes the original image sizes prior to
466+
# padding for the corresponding element in `pixel_values`.
467+
468+
# The below creates a single `[sum(B_idx), 2]` tensor describing all image sizes, and then
469+
# calculates the maximum height / width across all of them.
461470
batched_image_sizes = torch.cat(image_sizes)
462471
max_shape = batched_image_sizes.max(dim=0).values
472+
473+
# This next step then pads the pixel values potentially a second time by using the `max_shape`
474+
# computed above. Note that as far as this function is concerned, the original sizes for
475+
# batching purposes can be deduced from looking at the tensors in `pixel_values`, NOT in
476+
# `image_sizes`.
463477
pixel_values = [
464478
torchvision.transforms.v2.functional.pad(
465479
image,
466480
# Per torchvision docs, this should be in LTRB order if it's a sequence of 4 numbers.
467-
padding=[0, 0, max_shape[1] - size[1], max_shape[0] - size[0]],
481+
padding=[
482+
0, 0, max_shape[1] - image.shape[-1],
483+
max_shape[0] - image.shape[-2]
484+
],
468485
# Values extracted from HF implementation.
469486
fill=0.0,
470487
padding_mode="constant",
471-
) for image, size in zip(pixel_values, batched_image_sizes)
488+
) for image in pixel_values
472489
]
473490
return torch.cat(pixel_values), batched_image_sizes
474491

tests/unittest/_torch/modeling/test_modeling_mistral.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -443,3 +443,58 @@ def run_forward(input_ids, position_ids, attn_metadata):
443443
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
444444
if graph_runner is not None:
445445
graph_runner.clear()
446+
447+
448+
@pytest.mark.parametrize(
449+
"in_shapes, image_sizes, expected_out_shape",
450+
[
451+
(
452+
[(2, 3, 100, 150), (1, 3, 200, 100), (3, 3, 120, 180)],
453+
[
454+
[[92, 150], [100, 73]],
455+
[[200, 100]],
456+
[[37, 130], [120, 83], [73, 180]],
457+
],
458+
[6, 3, 200, 180],
459+
),
460+
# Single batch, single image.
461+
(
462+
[(1, 3, 64, 128)],
463+
[[[64, 128]]],
464+
[1, 3, 64, 128],
465+
),
466+
# Same max size across batches.
467+
(
468+
[(2, 3, 59, 59), (1, 3, 59, 59), (5, 3, 59, 59)],
469+
[
470+
[[13, 59], [59, 17]],
471+
[[59, 59]],
472+
[[19, 29], [59, 31], [17, 54], [13, 59], [11, 37]],
473+
],
474+
[8, 3, 59, 59],
475+
),
476+
],
477+
)
478+
def test_batch_pixel_values(in_shapes, image_sizes, expected_out_shape):
479+
# Test case 1: Basic functionality with different sized images
480+
pixel_values = [torch.randn(*shape) for shape in in_shapes]
481+
image_sizes = [torch.tensor(size) for size in image_sizes]
482+
483+
batched_pixels, batched_sizes = modeling_mistral.Mistral3VLM.batch_pixel_values(
484+
pixel_values, image_sizes
485+
)
486+
487+
# Check output shapes
488+
assert list(batched_pixels.shape) == expected_out_shape
489+
assert list(batched_sizes.shape) == [expected_out_shape[0], 2]
490+
491+
# Check that the original image data is preserved (with padding).
492+
start_idx = 0
493+
for original_values in pixel_values:
494+
batch_size = original_values.shape[0]
495+
end_idx = start_idx + batch_size
496+
orig_h, orig_w = original_values.shape[-2:]
497+
padded_values = batched_pixels[start_idx:end_idx, :, :orig_h, :orig_w]
498+
torch.testing.assert_close(padded_values, original_values)
499+
500+
start_idx += batch_size

0 commit comments

Comments
 (0)