Skip to content

Commit 7ebb770

Browse files
authored
[None][fix] Fix batching bug in Mistral3 model (NVIDIA#6841)
Prior to this commit, if multiple requests with images were in the same batch, the batching logic for the images would fail. This commit fixes it, and adds unit tests for it that were verified to fail prior to the fix. Signed-off-by: William Zhang <[email protected]>
1 parent b4167cc commit 7ebb770

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def forward(
373373
f"Expected as many `pixel_values` ({len(pixel_values)}) and "
374374
f"`image_sizes` ({len(image_sizes)}) as number of multimodal parameters "
375375
f"({multimodal_params_len}).")
376-
batched_pixel_values, batched_image_sizes = self._batch_pixel_values(
376+
batched_pixel_values, batched_image_sizes = self.batch_pixel_values(
377377
pixel_values=pixel_values, image_sizes=image_sizes)
378378
mm_embeds = [
379379
self._get_image_features(pixel_values=batched_pixel_values,
@@ -440,21 +440,38 @@ def _get_image_features(
440440
# (the transformers one expected numpy arrays).
441441
@staticmethod
442442
@torch.inference_mode()
443-
def _batch_pixel_values(
443+
def batch_pixel_values(
444444
pixel_values: List[torch.Tensor],
445445
image_sizes: List[torch.Tensor],
446446
) -> tuple[torch.Tensor, torch.Tensor]:
447+
# NOTES:
448+
# * `pixel_values` is a list of `[B_idx, C, H_idx, W_idx]` tensors, i.e. a batch of images as
449+
# padded + batched by the input processor.
450+
# The height (H_idx) and width (W_idx) of each element need not coincide.
451+
# * Similarly, each element in `image_sizes` describes the original image sizes prior to
452+
# padding for the corresponding element in `pixel_values`.
453+
454+
# The below creates a single `[sum(B_idx), 2]` tensor describing all image sizes, and then
455+
# calculates the maximum height / width across all of them.
447456
batched_image_sizes = torch.cat(image_sizes)
448457
max_shape = batched_image_sizes.max(dim=0).values
458+
459+
# This next step then pads the pixel values potentially a second time by using the `max_shape`
460+
# computed above. Note that as far as this function is concerned, the original sizes for
461+
# batching purposes can be deduced from looking at the tensors in `pixel_values`, NOT in
462+
# `image_sizes`.
449463
pixel_values = [
450464
torchvision.transforms.v2.functional.pad(
451465
image,
452466
# Per torchvision docs, this should be in LTRB order if it's a sequence of 4 numbers.
453-
padding=[0, 0, max_shape[1] - size[1], max_shape[0] - size[0]],
467+
padding=[
468+
0, 0, max_shape[1] - image.shape[-1],
469+
max_shape[0] - image.shape[-2]
470+
],
454471
# Values extracted from HF implementation.
455472
fill=0.0,
456473
padding_mode="constant",
457-
) for image, size in zip(pixel_values, batched_image_sizes)
474+
) for image in pixel_values
458475
]
459476
return torch.cat(pixel_values), batched_image_sizes
460477

tests/unittest/_torch/modeling/test_modeling_mistral.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,58 @@ def run_forward(input_ids, position_ids, attn_metadata):
438438
)
439439

440440
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
441+
442+
443+
@pytest.mark.parametrize(
444+
"in_shapes, image_sizes, expected_out_shape",
445+
[
446+
(
447+
[(2, 3, 100, 150), (1, 3, 200, 100), (3, 3, 120, 180)],
448+
[
449+
[[92, 150], [100, 73]],
450+
[[200, 100]],
451+
[[37, 130], [120, 83], [73, 180]],
452+
],
453+
[6, 3, 200, 180],
454+
),
455+
# Single batch, single image.
456+
(
457+
[(1, 3, 64, 128)],
458+
[[[64, 128]]],
459+
[1, 3, 64, 128],
460+
),
461+
# Same max size across batches.
462+
(
463+
[(2, 3, 59, 59), (1, 3, 59, 59), (5, 3, 59, 59)],
464+
[
465+
[[13, 59], [59, 17]],
466+
[[59, 59]],
467+
[[19, 29], [59, 31], [17, 54], [13, 59], [11, 37]],
468+
],
469+
[8, 3, 59, 59],
470+
),
471+
],
472+
)
473+
def test_batch_pixel_values(in_shapes, image_sizes, expected_out_shape):
474+
# Test case 1: Basic functionality with different sized images
475+
pixel_values = [torch.randn(*shape) for shape in in_shapes]
476+
image_sizes = [torch.tensor(size) for size in image_sizes]
477+
478+
batched_pixels, batched_sizes = modeling_mistral.Mistral3VLM.batch_pixel_values(
479+
pixel_values, image_sizes
480+
)
481+
482+
# Check output shapes
483+
assert list(batched_pixels.shape) == expected_out_shape
484+
assert list(batched_sizes.shape) == [expected_out_shape[0], 2]
485+
486+
# Check that the original image data is preserved (with padding).
487+
start_idx = 0
488+
for original_values in pixel_values:
489+
batch_size = original_values.shape[0]
490+
end_idx = start_idx + batch_size
491+
orig_h, orig_w = original_values.shape[-2:]
492+
padded_values = batched_pixels[start_idx:end_idx, :, :orig_h, :orig_w]
493+
torch.testing.assert_close(padded_values, original_values)
494+
495+
start_idx += batch_size

0 commit comments

Comments
 (0)