diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 9f3efe30341..e3bdbd55abd 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -261,13 +261,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran AugMix v2.AugMix -Cutmix - Mixup +CutMix - MixUp -------------- -Cutmix and Mixup are special transforms that +CutMix and MixUp are special transforms that are meant to be used on batches rather than on individual images, because they -are combining pairs of images together. These can be used after the dataloader, -or part of a collation function. See +are combining pairs of images together. These can be used after the dataloader +(once the samples are batched), or part of a collation function. See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples. .. autosummary:: diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py index 19838fe907d..d1c92a27812 100644 --- a/gallery/plot_cutmix_mixup.py +++ b/gallery/plot_cutmix_mixup.py @@ -1,8 +1,152 @@ """ =========================== -How to use Cutmix and Mixup +How to use CutMix and MixUp =========================== -TODO +:class:`~torchvision.transforms.v2.Cutmix` and +:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies +that can improve classification accuracy. + +These transforms are slightly different from the rest of the Torchvision +transforms, because they expect +**batches** of samples as input, not individual images. In this example we'll +explain how to use them: after the ``DataLoader``, or as part of a collation +function. 
""" + +# %% +import torch +import torchvision +from torchvision.datasets import FakeData + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision.transforms import v2 + + +NUM_CLASSES = 100 + +# %% +# Pre-processing pipeline +# ----------------------- +# +# We'll use a simple but typical image classification pipeline: + +preproc = v2.Compose([ + v2.PILToTensor(), + v2.RandomResizedCrop(size=(224, 224), antialias=True), + v2.RandomHorizontalFlip(p=0.5), + v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1] + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet +]) + +dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc) + +img, label = dataset[0] +print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }") + +# %% +# +# One important thing to note is that neither CutMix nor MixUp are part of this +# pre-processing pipeline. We'll add them a bit later once we define the +# DataLoader. Just as a refresher, this is what the DataLoader and training loop +# would look like if we weren't using CutMix or MixUp: + +from torch.utils.data import DataLoader + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + +for images, labels in dataloader: + print(f"{images.shape = }, {labels.shape = }") + print(labels.dtype) + # + break +# %% + +# %% +# Where to use MixUp and CutMix +# ----------------------------- +# +# After the DataLoader +# ^^^^^^^^^^^^^^^^^^^^ +# +# Now let's add CutMix and MixUp. 
The simplest way to do this is right after the +# DataLoader: the DataLoader has already batched the images and labels for us, +# and this is exactly what these transforms expect as input: + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + +cutmix = v2.Cutmix(num_classes=NUM_CLASSES) +mixup = v2.Mixup(num_classes=NUM_CLASSES) +cutmix_or_mixup = v2.RandomChoice([cutmix, mixup]) + +for images, labels in dataloader: + print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }") + images, labels = cutmix_or_mixup(images, labels) + print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }") + + # + break +# %% +# +# Note how the labels were also transformed: we went from a batched label of +# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The +# transformed labels can still be passed as-is to a loss function like +# :func:`torch.nn.functional.cross_entropy`. +# +# As part of the collation function +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Passing the transforms after the DataLoader is the simplest way to use CutMix +# and MixUp, but one disadvantage is that it does not take advantage of the +# DataLoader multi-processing. For that, we can pass those transforms as part of +# the collation function (refer to the `PyTorch docs +# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn +# more about collation). + +from torch.utils.data import default_collate + + +def collate_fn(batch): + return cutmix_or_mixup(*default_collate(batch)) + + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn) + +for images, labels in dataloader: + print(f"{images.shape = }, {labels.shape = }") + # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader! + # + break + +# %% +# Non-standard input format +# ------------------------- +# +# So far we've used a typical sample structure where we pass ``(images, +# labels)`` as inputs. 
MixUp and CutMix will magically work by default with most +# common sample structures: tuples where the second parameter is a tensor label, +# or dict with a "label[s]" key. Look at the documentation of the +# ``labels_getter`` parameter for more details. +# +# If your samples have a different structure, you can still use CutMix and MixUp +# by passing a callable to the ``labels_getter`` parameter. For example: + +batch = { + "imgs": torch.rand(4, 3, 224, 224), + "target": { + "classes": torch.randint(0, NUM_CLASSES, size=(4,)), + "some_other_key": "this is going to be passed-through" + } +} + + +def labels_getter(batch): + return batch["target"]["classes"] + + +out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch) +print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }") diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 3b808d6b73c..f4e00a2b8f5 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1922,7 +1922,7 @@ def test_supported_input_structure(self, T): dataset = self.DummyDataset(size=batch_size, num_classes=num_classes) - cutmix_mixup = T(alpha=0.5, num_classes=num_classes) + cutmix_mixup = T(num_classes=num_classes) dl = DataLoader(dataset, batch_size=batch_size) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index f9038c6af32..2c6844c969e 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -141,9 +141,9 @@ def _transform( class _BaseMixupCutmix(Transform): - def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None: + def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None: super().__init__() - self.alpha = alpha + self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) self.num_classes = num_classes @@ -204,13 +204,20 
@@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: class Mixup(_BaseMixupCutmix): - """[BETA] Apply Mixup to the provided batch of images and labels. + """[BETA] Apply MixUp to the provided batch of images and labels. .. v2betastatus:: Mixup transform Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_. - See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples. + .. note:: + This transform is meant to be used on **batches** of samples, not + individual images. See + :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage + examples. + The sample pairing is deterministic and done by matching consecutive + samples in the batch, so the batch needs to be shuffled (this is an + implementation detail, not a guaranteed convention). In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed into a tensor of shape ``(batch_size, num_classes)``. @@ -246,14 +253,21 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class Cutmix(_BaseMixupCutmix): - """[BETA] Apply Cutmix to the provided batch of images and labels. + """[BETA] Apply CutMix to the provided batch of images and labels. .. v2betastatus:: Cutmix transform Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features <https://arxiv.org/abs/1905.04899>`_. - See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples. + .. note:: + This transform is meant to be used on **batches** of samples, not + individual images. See + :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage + examples. + The sample pairing is deterministic and done by matching consecutive + samples in the batch, so the batch needs to be shuffled (this is an + implementation detail, not a guaranteed convention). In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed into a tensor of shape ``(batch_size, num_classes)``.