Skip to content

Commit 1c46a69

Browse files
committed
Refactoring to use Compose.
1 parent c337d76 commit 1c46a69

File tree

1 file changed

+19
-25
lines changed

1 file changed

+19
-25
lines changed

references/classification/presets.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,37 @@
1-
from torch import nn, Tensor
2-
from torchvision.transforms import autoaugment, transforms
1+
from torchvision.transforms import *
32

43

5-
class ClassificationPresetTrain(nn.Module):
4+
class ClassificationPresetTrain:
65
def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), auto_augment_policy=None,
76
random_erase_prob=0.0):
8-
super().__init__()
9-
107
trans = [
11-
transforms.RandomResizedCrop(crop_size),
12-
transforms.RandomHorizontalFlip(),
8+
RandomResizedCrop(crop_size),
9+
RandomHorizontalFlip(),
1310
]
1411
if auto_augment_policy is not None:
15-
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
16-
trans.append(autoaugment.AutoAugment(policy=aa_policy))
12+
aa_policy = AutoAugmentPolicy(auto_augment_policy)
13+
trans.append(AutoAugment(policy=aa_policy))
1714
trans.extend([
18-
transforms.ToTensor(),
19-
transforms.Normalize(mean=mean, std=std),
15+
ToTensor(),
16+
Normalize(mean=mean, std=std),
2017
])
2118
if random_erase_prob > 0:
22-
trans.append(transforms.RandomErasing(p=random_erase_prob))
19+
trans.append(RandomErasing(p=random_erase_prob))
2320

24-
self.transforms = nn.Sequential(*trans)
21+
self.transforms = Compose(trans)
2522

26-
def forward(self, img: Tensor) -> Tensor:
23+
def __call__(self, img):
2724
return self.transforms(img)
2825

2926

30-
class ClassificationPresetEval(nn.Module):
27+
class ClassificationPresetEval:
3128
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
32-
super().__init__()
33-
34-
trans = [
35-
transforms.Resize(resize_size),
36-
transforms.CenterCrop(crop_size),
37-
transforms.ToTensor(),
38-
transforms.Normalize(mean=mean, std=std),
39-
]
40-
self.transforms = nn.Sequential(*trans)
29+
self.transforms = Compose([
30+
Resize(resize_size),
31+
CenterCrop(crop_size),
32+
ToTensor(),
33+
Normalize(mean=mean, std=std),
34+
])
4135

42-
def forward(self, img: Tensor) -> Tensor:
36+
def __call__(self, img):
4337
return self.transforms(img)

0 commit comments

Comments
 (0)