Skip to content

Commit 08ae56f

Browse files
committed
Update comments.
1 parent 5e0be6e commit 08ae56f

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,13 @@ def forward(self, *inputs: Any) -> Any:
505505
aug = self._apply_image_or_video_transform(
506506
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
507507
)
508-
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
508+
mix.add_(
509+
# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
510+
# Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`.
511+
# TODO: change this once all ops in `F` support float inputs.
512+
combined_weights[:, i].reshape(batch_dims)
513+
* aug
514+
)
509515
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
510516

511517
if isinstance(orig_image_or_video, (features.Image, features.Video)):

0 commit comments

Comments
 (0)