Skip to content

Commit 18a2969

Browse files
committed
Final code clean up.
1 parent eb932b9 commit 18a2969

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

references/classification/transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
5454
target = target.clone()
5555

5656
if target.ndim == 1:
57-
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
57+
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
5858

5959
if torch.rand(1).item() >= self.p:
6060
return batch, target
6161

6262
# It's faster to roll the batch by one instead of shuffling it to create image pairs
6363
batch_rolled = batch.roll(1, 0)
64-
target_rolled = target.roll(1)
64+
target_rolled = target.roll(1, 0)
6565

6666
# Implemented as on mixup paper, page 3.
6767
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
@@ -132,14 +132,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
132132
target = target.clone()
133133

134134
if target.ndim == 1:
135-
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
135+
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
136136

137137
if torch.rand(1).item() >= self.p:
138138
return batch, target
139139

140140
# It's faster to roll the batch by one instead of shuffling it to create image pairs
141141
batch_rolled = batch.roll(1, 0)
142-
target_rolled = target.roll(1)
142+
target_rolled = target.roll(1, 0)
143143

144144
# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
145145
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])

torchvision/transforms/transforms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
except ImportError:
1414
accimage = None
1515

16-
from . import functional as F
1716
from .functional import InterpolationMode, _interpolation_modes_from_int
1817

1918

0 commit comments

Comments
 (0)