diff --git a/references/classification/README.md b/references/classification/README.md
index d18ab17bf73..d8c5eff8c05 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -124,22 +124,9 @@ Training converges at about 10 epochs.
 For post training quant, device is set to CPU. For training, the device is set to CUDA
 
 ### Command to evaluate quantized models using the pre-trained weights:
-For all quantized models except inception_v3:
+For all quantized models:
 ```
 python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
     --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
 ```
-
-For inception_v3, since it expects tensors with a size of N x 3 x 299 x 299, before running above command,
-need to change the input size of dataset_test in train.py to:
-```
-dataset_test = torchvision.datasets.ImageFolder(
-    valdir,
-    transforms.Compose([
-        transforms.Resize(342),
-        transforms.CenterCrop(299),
-        transforms.ToTensor(),
-        normalize,
-    ]))
-```
-
diff --git a/references/classification/presets.py b/references/classification/presets.py
new file mode 100644
index 00000000000..6bb389ba8db
--- /dev/null
+++ b/references/classification/presets.py
@@ -0,0 +1,37 @@
+from torchvision.transforms import autoaugment, transforms
+
+
+class ClassificationPresetTrain:
+    def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
+                 auto_augment_policy=None, random_erase_prob=0.0):
+        trans = [transforms.RandomResizedCrop(crop_size)]
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+        if auto_augment_policy is not None:
+            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
+            trans.append(autoaugment.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        if random_erase_prob > 0:
+            trans.append(transforms.RandomErasing(p=random_erase_prob))
+
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, img):
+        return self.transforms(img)
+
+
+class ClassificationPresetEval:
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+
+        self.transforms = transforms.Compose([
+            transforms.Resize(resize_size),
+            transforms.CenterCrop(crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, img):
+        return self.transforms(img)
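`ClassificationPresetTrain` bundles the crop/flip/AutoAugment/erasing pipeline that train.py previously composed inline. A minimal sketch of how the two classes are meant to be called (the random PIL image is a stand-in, and the snippet assumes it runs from `references/classification/` so that `presets` is importable):

```python
# Illustrative only: exercise the new classification presets outside train.py.
import numpy as np
from PIL import Image

import presets  # the new references/classification/presets.py

# Stand-in input: a random RGB image.
img = Image.fromarray(np.random.randint(0, 256, (300, 300, 3), dtype=np.uint8))

train_tf = presets.ClassificationPresetTrain(crop_size=224, auto_augment_policy="imagenet",
                                             random_erase_prob=0.1)
print(train_tf(img).shape)  # torch.Size([3, 224, 224])

eval_tf = presets.ClassificationPresetEval(crop_size=224, resize_size=256)
print(eval_tf(img).shape)   # torch.Size([3, 224, 224])
```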
diff --git a/references/classification/train.py b/references/classification/train.py
index 47a7e5955e6..522ceaf3daa 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -6,8 +6,8 @@
 import torch.utils.data
 from torch import nn
 import torchvision
-from torchvision import transforms
 
+import presets
 import utils
 
 try:
@@ -82,8 +82,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
 
     print("Loading training data")
     st = time.time()
@@ -93,22 +92,10 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
-        trans = [
-            transforms.RandomResizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-        ]
-        if args.auto_augment is not None:
-            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
-            trans.append(transforms.AutoAugment(policy=aa_policy))
-        trans.extend([
-            transforms.ToTensor(),
-            normalize,
-        ])
-        if args.random_erase > 0:
-            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose(trans))
+            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
+                                              random_erase_prob=args.random_erase))
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
@@ -124,12 +111,7 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
diff --git a/references/detection/presets.py b/references/detection/presets.py
new file mode 100644
index 00000000000..b0c86ed1265
--- /dev/null
+++ b/references/detection/presets.py
@@ -0,0 +1,21 @@
+import transforms as T
+
+
+class DetectionPresetTrain:
+    def __init__(self, hflip_prob=0.5):
+        trans = [T.ToTensor()]
+        if hflip_prob > 0:
+            trans.append(T.RandomHorizontalFlip(hflip_prob))
+
+        self.transforms = T.Compose(trans)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+class DetectionPresetEval:
+    def __init__(self):
+        self.transforms = T.ToTensor()
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
diff --git a/references/detection/train.py b/references/detection/train.py
index 7aa71314230..83fad36d2cc 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -32,8 +32,8 @@
 from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
 from engine import train_one_epoch, evaluate
 
+import presets
 import utils
-import transforms as T
 
 
 def get_dataset(name, image_set, transform, data_path):
@@ -48,11 +48,7 @@ def get_dataset(name, image_set, transform, data_path):
 
 
 def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
+    return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
 
 
 def main(args):
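Unlike the classification presets, the detection presets are joint transforms: `__call__(img, target)` returns both, and the horizontal flip in the detection references' `transforms.py` adjusts the boxes to match the flipped image. A hedged sketch (the random image and the `boxes`/`labels` dict are illustrative stand-ins following the detection references' target convention, assumed to run from `references/detection/`):

```python
# Illustrative only: the detection presets transform image and target together.
import numpy as np
import torch
from PIL import Image

import presets  # the new references/detection/presets.py

# Stand-in inputs: a random RGB image and a one-box target dict.
img = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))
target = {"boxes": torch.tensor([[10.0, 20.0, 100.0, 150.0]]), "labels": torch.tensor([1])}

train_tf = presets.DetectionPresetTrain()
img_t, target_t = train_tf(img, target)  # if a flip is sampled, target["boxes"] is mirrored too
print(img_t.shape, target_t["boxes"])
```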
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
new file mode 100644
index 00000000000..3bf29c23751
--- /dev/null
+++ b/references/segmentation/presets.py
@@ -0,0 +1,32 @@
+import transforms as T
+
+
+class SegmentationPresetTrain:
+    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+        min_size = int(0.5 * base_size)
+        max_size = int(2.0 * base_size)
+
+        trans = [T.RandomResize(min_size, max_size)]
+        if hflip_prob > 0:
+            trans.append(T.RandomHorizontalFlip(hflip_prob))
+        trans.extend([
+            T.RandomCrop(crop_size),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+        ])
+        self.transforms = T.Compose(trans)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+class SegmentationPresetEval:
+    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+        self.transforms = T.Compose([
+            T.RandomResize(base_size, base_size),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 5e5e5615e19..690e248323e 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -8,7 +8,7 @@
 import torchvision
 
 from coco_utils import get_coco
-import transforms as T
+import presets
 import utils
 
 
@@ -30,18 +30,7 @@ def get_transform(train):
     base_size = 520
     crop_size = 480
 
-    min_size = int((0.5 if train else 1.0) * base_size)
-    max_size = int((2.0 if train else 1.0) * base_size)
-    transforms = []
-    transforms.append(T.RandomResize(min_size, max_size))
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-        transforms.append(T.RandomCrop(crop_size))
-    transforms.append(T.ToTensor())
-    transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
-                                  std=[0.229, 0.224, 0.225]))
-
-    return T.Compose(transforms)
+    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
 
 
 def criterion(inputs, target):
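The segmentation presets are joint transforms as well, and `SegmentationPresetTrain` derives its resize range from `base_size` (0.5x to 2.0x) exactly as the removed `get_transform` did. A sketch under the same assumptions as above (random image and mask as stand-ins, run from `references/segmentation/`):

```python
# Illustrative only: drive the new SegmentationPresetTrain outside train.py.
import numpy as np
from PIL import Image

import presets  # the new references/segmentation/presets.py

# Stand-in inputs: a random RGB image and a fake segmentation mask.
img = Image.fromarray(np.random.randint(0, 256, (520, 520, 3), dtype=np.uint8))
target = Image.fromarray(np.random.randint(0, 21, (520, 520), dtype=np.uint8))

train_tf = presets.SegmentationPresetTrain(base_size=520, crop_size=480)
img_t, target_t = train_tf(img, target)  # image and mask are resized/flipped/cropped together
print(img_t.shape, target_t.shape)       # e.g. torch.Size([3, 480, 480]) torch.Size([480, 480])
```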
diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py
new file mode 100644
index 00000000000..3ee679ad5af
--- /dev/null
+++ b/references/video_classification/presets.py
@@ -0,0 +1,40 @@
+import torch
+
+from torchvision.transforms import transforms
+from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
+
+
+class VideoClassificationPresetTrain:
+    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989),
+                 hflip_prob=0.5):
+        trans = [
+            ConvertBHWCtoBCHW(),
+            transforms.ConvertImageDtype(torch.float32),
+            transforms.Resize(resize_size),
+        ]
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+        trans.extend([
+            transforms.Normalize(mean=mean, std=std),
+            transforms.RandomCrop(crop_size),
+            ConvertBCHWtoCBHW()
+        ])
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, x):
+        return self.transforms(x)
+
+
+class VideoClassificationPresetEval:
+    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
+        self.transforms = transforms.Compose([
+            ConvertBHWCtoBCHW(),
+            transforms.ConvertImageDtype(torch.float32),
+            transforms.Resize(resize_size),
+            transforms.Normalize(mean=mean, std=std),
+            transforms.CenterCrop(crop_size),
+            ConvertBCHWtoCBHW()
+        ])
+
+    def __call__(self, x):
+        return self.transforms(x)
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 3b5d8d8d206..bcc74064344 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -7,13 +7,12 @@
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
+import presets
 import utils
 
 from scheduler import WarmupMultiStepLR
-from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
 
 try:
     from apex import amp
@@ -112,21 +111,11 @@ def main(args):
     print("Loading data")
     traindir = os.path.join(args.data_path, args.train_dir)
     valdir = os.path.join(args.data_path, args.val_dir)
-    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
-                            std=[0.22803, 0.22145, 0.216989])
 
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    transform_train = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        T.RandomHorizontalFlip(),
-        normalize,
-        T.RandomCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
 
     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_train from {}".format(cache_path))
@@ -154,14 +143,7 @@ def main(args):
 
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    transform_test = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        normalize,
-        T.CenterCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
 
     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_test from {}".format(cache_path))
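One subtlety the video presets preserve from the inlined pipelines: `Normalize` runs before `RandomCrop` at train time, matching the removed `transform_train`, so numeric behaviour is unchanged. A quick sanity-check sketch (the random uint8 clip stands in for decoded video in the (T, H, W, C) layout the video datasets yield, assumed to run from `references/video_classification/`):

```python
# Illustrative only: run a fake clip through the new video presets.
import torch

import presets  # the new references/video_classification/presets.py

# Stand-in clip: 16 frames in (T, H, W, C) uint8 layout.
clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)

train_tf = presets.VideoClassificationPresetTrain(resize_size=(128, 171), crop_size=(112, 112))
out = train_tf(clip)
print(out.shape)  # torch.Size([3, 16, 112, 112]): (C, T, H, W) after ConvertBCHWtoCBHW
```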