Commit e7f61d8

Merge branch 'main' into datasets/download
2 parents: 5a4a2cb + 4282c9f

11 files changed: +1038 −11 lines

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Generated file; diff not rendered by default.

.circleci/config.yml.in

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ jobs:
           descr: Install Python type check utilities
       - run:
           name: Check Python types statically
-          command: mypy --config-file mypy.ini
+          command: mypy --install-types --non-interactive --config-file mypy.ini

   unittest_torchhub:
     docker:

packaging/torchvision/meta.yaml

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ requirements:
   run:
     - python
     - defaults::numpy >=1.11
+    - requests
     - libpng
     - ffmpeg >=4.2  # [not win]
     - jpeg
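The new requests runtime dependency presumably supports the HTTP download work on this branch (datasets/download). As a rough illustration of the kind of streaming download requests enables — a minimal sketch only; download_file, the URL handling, and the chunk size are hypothetical, not torchvision's actual helper:

import requests

def download_file(url: str, dest: str, chunk_size: int = 32 * 1024) -> None:
    # Hypothetical helper: stream the response so large archives are not
    # held in memory all at once.
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        with open(dest, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)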

references/classification/sampler.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample is visible to a
    different process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT Repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat each index 3 times (repeated augmentation), then pad so the
        # total is evenly divisible across processes
        indices = [ele for ele in indices for i in range(3)]
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample: each rank takes every num_replicas-th index
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
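For context, the sampler drops into a standard DataLoader the same way DistributedSampler does. A minimal sketch (the toy TensorDataset, batch size, and explicit num_replicas/rank are illustrative stand-ins, not the reference script's setup):

import torch
from torch.utils.data import DataLoader, TensorDataset

from sampler import RASampler  # the module added by this commit; assumes running from references/classification

# Toy stand-in dataset; in the reference script this is the ImageFolder train set.
dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))

# Passing num_replicas/rank explicitly avoids needing an initialized process
# group for this illustration; the training script lets the sampler read them
# from torch.distributed instead.
train_sampler = RASampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)

for epoch in range(3):
    train_sampler.set_epoch(epoch)  # reseed the per-epoch shuffle, as with DistributedSampler
    for images, targets in loader:
        pass  # forward / backward / optimizer step would go here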

references/classification/train.py

Lines changed: 6 additions & 1 deletion
@@ -9,6 +9,7 @@
 import torchvision
 import transforms
 import utils
+from references.classification.sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
@@ -172,7 +173,10 @@ def load_data(traindir, valdir, args):

     print("Creating data loaders")
     if args.distributed:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        if args.ra_sampler:
+            train_sampler = RASampler(dataset, shuffle=True)
+        else:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
@@ -481,6 +485,7 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")

     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

references/optical_flow/presets.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
import torch
import transforms as T


class OpticalFlowPresetEval(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.transforms = T.Compose(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.ValidateModelInput(),
            ]
        )

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)


class OpticalFlowPresetTrain(torch.nn.Module):
    def __init__(
        self,
        # RandomResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        # Random[H,V]Flip params
        asymmetric_jitter_prob=0.2,
        do_flip=True,
    ):
        super().__init__()

        transforms = [
            T.PILToTensor(),
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.RandomResizeAndCrop(
                crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
            ),
        ]

        if do_flip:
            transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]

        transforms += [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.RandomErasing(max_erase=2),
            T.MakeValidFlowMask(),
            T.ValidateModelInput(),
        ]
        self.transforms = T.Compose(transforms)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
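A minimal sketch of how these presets are meant to be used; it assumes it runs from references/optical_flow so the local transforms and presets modules are importable, and the crop_size value is illustrative rather than a recommendation:

from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain

train_preset = OpticalFlowPresetTrain(crop_size=(368, 496), do_flip=True)
eval_preset = OpticalFlowPresetEval()

# Both presets are applied to a full optical-flow sample at once:
#   img1, img2, flow, valid = train_preset(img1, img2, flow, valid)
# where img1/img2 are the two input frames, flow is the ground-truth flow,
# and valid is the (possibly None) validity mask for that flow.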
