Commit 992d41f

Adding presets in the classification reference scripts.
1 parent 7621a8e commit 992d41f

File tree

3 files changed: +44 -37 lines changed


references/classification/README.md

Lines changed: 1 addition & 14 deletions
@@ -124,22 +124,9 @@ Training converges at about 10 epochs.
 For post training quant, device is set to CPU. For training, the device is set to CUDA
 
 ### Command to evaluate quantized models using the pre-trained weights:
-For all quantized models except inception_v3:
+For all quantized models:
 ```
 python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
 --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
 ```
 
-For inception_v3, since it expects tensors with a size of N x 3 x 299 x 299, before running above command,
-need to change the input size of dataset_test in train.py to:
-```
-dataset_test = torchvision.datasets.ImageFolder(
-    valdir,
-    transforms.Compose([
-        transforms.Resize(342),
-        transforms.CenterCrop(299),
-        transforms.ToTensor(),
-        normalize,
-    ]))
-```
-
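The inception_v3-specific instructions are dropped because train.py now picks the evaluation input size from the model name (see the train.py diff below). A minimal sketch of that selection; the helper name here is hypothetical, train.py inlines the expression:

```
# Sketch of the size selection now done inside load_data() in train.py:
# inception_v3 expects 299x299 inputs, all other models use 224x224 crops.
def eval_sizes(model_name):  # hypothetical helper, for illustration only
    return (342, 299) if model_name == 'inception_v3' else (256, 224)

resize_size, crop_size = eval_sizes('inception_v3')  # -> (342, 299)
```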

references/classification/presets.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from torchvision.transforms import autoaugment, transforms
+
+
+class ClassificationPresetTrain:
+    def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), auto_augment_policy=None,
+                 random_erase_prob=0.0):
+        trans = [
+            transforms.RandomResizedCrop(crop_size),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if auto_augment_policy is not None:
+            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
+            trans.append(autoaugment.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        if random_erase_prob > 0:
+            trans.append(transforms.RandomErasing(p=random_erase_prob))
+
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, img):
+        return self.transforms(img)
+
+
+class ClassificationPresetEval:
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+
+        self.transforms = transforms.Compose([
+            transforms.Resize(resize_size),
+            transforms.CenterCrop(crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, img):
+        return self.transforms(img)
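A minimal usage sketch of the new presets, assuming an ImageNet-style folder layout; the paths, policy name, and batch size below are illustrative and not part of the commit:

```
import torch
import torchvision

import presets  # references/classification/presets.py added by this commit

# Training transforms: RandomResizedCrop + flip, optional AutoAugment and RandomErasing.
dataset = torchvision.datasets.ImageFolder(
    'imagenet_full_size/train',  # hypothetical path
    presets.ClassificationPresetTrain(crop_size=224,
                                      auto_augment_policy='imagenet',
                                      random_erase_prob=0.1))

# Evaluation transforms: Resize + CenterCrop + ToTensor + Normalize.
dataset_test = torchvision.datasets.ImageFolder(
    'imagenet_full_size/val',  # hypothetical path
    presets.ClassificationPresetEval(crop_size=224, resize_size=256))

data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=32, shuffle=False)
```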

references/classification/train.py

Lines changed: 5 additions & 23 deletions
@@ -6,8 +6,8 @@
 import torch.utils.data
 from torch import nn
 import torchvision
-from torchvision import transforms
 
+import presets
 import utils
 
 try:
@@ -82,8 +82,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
 
     print("Loading training data")
     st = time.time()
@@ -93,22 +92,10 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
-        trans = [
-            transforms.RandomResizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-        ]
-        if args.auto_augment is not None:
-            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
-            trans.append(transforms.AutoAugment(policy=aa_policy))
-        trans.extend([
-            transforms.ToTensor(),
-            normalize,
-        ])
-        if args.random_erase > 0:
-            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose(trans))
+            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
+                                              random_erase_prob=args.random_erase))
         if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
@@ -124,12 +111,7 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
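With this change, inception_v3 evaluation no longer needs the manual edit the README previously described: the preset built from (resize_size, crop_size) = (342, 299) reproduces the removed transform stack. A short sketch using the class added in this commit:

```
import presets

# Equivalent to the Resize(342) / CenterCrop(299) / ToTensor / Normalize pipeline
# that the old README asked users to paste into train.py for inception_v3.
inception_eval_preset = presets.ClassificationPresetEval(crop_size=299, resize_size=342)
```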
