Adding Preset Transforms in reference scripts #3317

Merged · 5 commits · Jan 28, 2021
15 changes: 1 addition & 14 deletions references/classification/README.md
@@ -124,22 +124,9 @@ Training converges at about 10 epochs.
 For post training quant, device is set to CPU. For training, the device is set to CUDA

 ### Command to evaluate quantized models using the pre-trained weights:
-For all quantized models except inception_v3:
+For all quantized models:
 ```
 python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
 --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
 ```

-For inception_v3, since it expects tensors with a size of N x 3 x 299 x 299, before running above command,
-need to change the input size of dataset_test in train.py to:
-```
-dataset_test = torchvision.datasets.ImageFolder(
-    valdir,
-    transforms.Compose([
-        transforms.Resize(342),
-        transforms.CenterCrop(299),
-        transforms.ToTensor(),
-        normalize,
-    ]))
-```
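As an illustration of the evaluation command (filling in the model-name placeholder; mobilenet_v2 is one of the quantized models available in torchvision):
```
python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
--device='cpu' --test-only --backend='fbgemm' --model='mobilenet_v2'
```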

37 changes: 37 additions & 0 deletions references/classification/presets.py
@@ -0,0 +1,37 @@
from torchvision.transforms import autoaugment, transforms


class ClassificationPresetTrain:
    def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
                 auto_augment_policy=None, random_erase_prob=0.0):
        trans = [transforms.RandomResizedCrop(crop_size)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
            trans.append(autoaugment.AutoAugment(policy=aa_policy))
        trans.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)

    def __call__(self, img):
        return self.transforms(img)


class ClassificationPresetEval:
    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):

        self.transforms = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img):
        return self.transforms(img)

Review comment (Contributor): Usually, for the text domain, we will need to download the transform, for example a sentencepiece model or a vocabulary saved in a text file.

Reply (Author): I'm writing here what we discussed on the call.

It seems that supporting your case is possible by using PyTorch Hub's load_state_dict_from_url() method and then passing the result to your code. This is a very common pattern in TorchVision, used mainly for pre-trained models. Example:

```
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
    if model_urls.get(arch, None) is None:
        raise ValueError("No checkpoint is available for model type {}".format(arch))
    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(state_dict)
```
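To make that pattern concrete for the text case, here is a minimal sketch; the URL and the TextClassificationPreset class are hypothetical, and torch.hub.download_url_to_file stands in for load_state_dict_from_url when the artifact is a plain file rather than a checkpoint:
```
import os

import torch.hub


class TextClassificationPreset:
    """Hypothetical text preset that downloads its vocabulary on first use."""

    # Placeholder URL; a real preset would point at a published artifact.
    VOCAB_URL = "https://example.com/artifacts/vocab.txt"

    def __init__(self, vocab_url=VOCAB_URL):
        cache_dir = torch.hub.get_dir()
        os.makedirs(cache_dir, exist_ok=True)
        vocab_path = os.path.join(cache_dir, os.path.basename(vocab_url))
        if not os.path.exists(vocab_path):
            # Generic file download from torch.hub, analogous to
            # load_state_dict_from_url for model checkpoints.
            torch.hub.download_url_to_file(vocab_url, vocab_path)
        with open(vocab_path, encoding="utf-8") as f:
            self.vocab = {line.strip(): i for i, line in enumerate(f)}

    def __call__(self, text):
        # Map whitespace-separated tokens to ids; 0 for unknown tokens.
        return [self.vocab.get(token, 0) for token in text.split()]
```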
28 changes: 5 additions & 23 deletions references/classification/train.py
@@ -6,8 +6,8 @@
 import torch.utils.data
 from torch import nn
 import torchvision
-from torchvision import transforms

+import presets
 import utils

 try:
@@ -82,8 +82,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)

     print("Loading training data")
     st = time.time()
@@ -93,22 +92,10 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
-        trans = [
-            transforms.RandomResizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-        ]
-        if args.auto_augment is not None:
-            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
-            trans.append(transforms.AutoAugment(policy=aa_policy))
-        trans.extend([
-            transforms.ToTensor(),
-            normalize,
-        ])
-        if args.random_erase > 0:
-            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose(trans))
+            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
+                                              random_erase_prob=args.random_erase))
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
@@ -124,12 +111,7 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
21 changes: 21 additions & 0 deletions references/detection/presets.py
@@ -0,0 +1,21 @@
import transforms as T


class DetectionPresetTrain:
    def __init__(self, hflip_prob=0.5):
        trans = [T.ToTensor()]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))

        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class DetectionPresetEval:
    def __init__(self):
        self.transforms = T.ToTensor()

    def __call__(self, img, target):
        return self.transforms(img, target)
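A minimal usage sketch for these presets, assuming it runs inside references/detection/ (so that presets.py and the local transforms.py are importable); the image and target below are made up:
```
from PIL import Image
import torch

import presets

train_preset = presets.DetectionPresetTrain()  # ToTensor + RandomHorizontalFlip(0.5)
eval_preset = presets.DetectionPresetEval()    # ToTensor only

img = Image.new("RGB", (640, 480))
target = {"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]),
          "labels": torch.tensor([1])}
img, target = train_preset(img, target)  # img becomes a float CHW tensor;
                                         # a flip also updates target["boxes"]
```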
8 changes: 2 additions & 6 deletions references/detection/train.py
@@ -32,8 +32,8 @@
 from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
 from engine import train_one_epoch, evaluate

+import presets
 import utils
-import transforms as T


 def get_dataset(name, image_set, transform, data_path):
@@ -48,11 +48,7 @@ def get_dataset(name, image_set, transform, data_path):


 def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
+    return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()


 def main(args):
32 changes: 32 additions & 0 deletions references/segmentation/presets.py
@@ -0,0 +1,32 @@
import transforms as T


class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)
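A minimal usage sketch, assuming it runs inside references/segmentation/; base_size=520 and crop_size=480 are the values train.py passes below, and the blank image and mask are made up:
```
from PIL import Image

import presets

train_preset = presets.SegmentationPresetTrain(base_size=520, crop_size=480)
eval_preset = presets.SegmentationPresetEval(base_size=520)

img = Image.new("RGB", (640, 480))
mask = Image.new("L", (640, 480))        # per-pixel class ids
img_t, mask_t = train_preset(img, mask)  # float CHW image tensor, int64 HW mask tensor
```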
15 changes: 2 additions & 13 deletions references/segmentation/train.py
@@ -8,7 +8,7 @@
 import torchvision

 from coco_utils import get_coco
-import transforms as T
+import presets
 import utils


@@ -30,18 +30,7 @@ def get_transform(train):
     base_size = 520
     crop_size = 480

-    min_size = int((0.5 if train else 1.0) * base_size)
-    max_size = int((2.0 if train else 1.0) * base_size)
-    transforms = []
-    transforms.append(T.RandomResize(min_size, max_size))
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    transforms.append(T.RandomCrop(crop_size))
-    transforms.append(T.ToTensor())
-    transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
-                                  std=[0.229, 0.224, 0.225]))
-
-    return T.Compose(transforms)
+    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)


 def criterion(inputs, target):
40 changes: 40 additions & 0 deletions references/video_classification/presets.py
@@ -0,0 +1,40 @@
import torch

from torchvision.transforms import transforms
from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW


class VideoClassificationPresetTrain:
    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989),
                 hflip_prob=0.5):
        trans = [
            ConvertBHWCtoBCHW(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size),
        ]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            transforms.Normalize(mean=mean, std=std),
            transforms.RandomCrop(crop_size),
            ConvertBCHWtoCBHW()
        ])
        self.transforms = transforms.Compose(trans)

    def __call__(self, x):
        return self.transforms(x)


class VideoClassificationPresetEval:
    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
        self.transforms = transforms.Compose([
            ConvertBHWCtoBCHW(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size),
            transforms.Normalize(mean=mean, std=std),
            transforms.CenterCrop(crop_size),
            ConvertBCHWtoCBHW()
        ])

    def __call__(self, x):
        return self.transforms(x)
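A minimal usage sketch, assuming it runs inside references/video_classification/; the (128, 171) resize and (112, 112) crop match what train.py passes below, and the random clip is made up. The presets take a THWC clip and return CTHW:
```
import torch

import presets

train_preset = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
eval_preset = presets.VideoClassificationPresetEval((128, 171), (112, 112))

clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)  # (T, H, W, C)
out = train_preset(clip)
print(out.shape)  # torch.Size([3, 16, 112, 112]), i.e. (C, T, H, W)
```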
24 changes: 3 additions & 21 deletions references/video_classification/train.py
@@ -7,13 +7,12 @@
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler

+import presets
 import utils

 from scheduler import WarmupMultiStepLR
-from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW

 try:
     from apex import amp
@@ -112,21 +111,11 @@ def main(args):
     print("Loading data")
     traindir = os.path.join(args.data_path, args.train_dir)
     valdir = os.path.join(args.data_path, args.val_dir)
-    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
-                            std=[0.22803, 0.22145, 0.216989])

     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    transform_train = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        T.RandomHorizontalFlip(),
-        normalize,
-        T.RandomCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))

     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_train from {}".format(cache_path))
@@ -154,14 +143,7 @@ def main(args):
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)

-    transform_test = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        normalize,
-        T.CenterCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))

     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_test from {}".format(cache_path))