@@ -54,14 +54,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
54
54
target = target .clone ()
55
55
56
56
if target .ndim == 1 :
57
- target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = torch . float32 )
57
+ target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = batch . dtype )
58
58
59
59
if torch .rand (1 ).item () >= self .p :
60
60
return batch , target
61
61
62
62
# It's faster to roll the batch by one instead of shuffling it to create image pairs
63
63
batch_rolled = batch .roll (1 , 0 )
64
- target_rolled = target .roll (1 )
64
+ target_rolled = target .roll (1 , 0 )
65
65
66
66
# Implemented as on mixup paper, page 3.
67
67
lambda_param = float (torch ._sample_dirichlet (torch .tensor ([self .alpha , self .alpha ]))[0 ])
@@ -132,14 +132,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
132
132
target = target .clone ()
133
133
134
134
if target .ndim == 1 :
135
- target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = torch . float32 )
135
+ target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = batch . dtype )
136
136
137
137
if torch .rand (1 ).item () >= self .p :
138
138
return batch , target
139
139
140
140
# It's faster to roll the batch by one instead of shuffling it to create image pairs
141
141
batch_rolled = batch .roll (1 , 0 )
142
- target_rolled = target .roll (1 )
142
+ target_rolled = target .roll (1 , 0 )
143
143
144
144
# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
145
145
lambda_param = float (torch ._sample_dirichlet (torch .tensor ([self .alpha , self .alpha ]))[0 ])
0 commit comments