Skip to content

Commit a2e9306

Browse files
committed
Adding presets in the segmentation reference scripts.
1 parent 9f7a0f7 commit a2e9306

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

references/segmentation/presets.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import transforms as T
2+
3+
4+
class SegmentationPresetTrain:
5+
def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
6+
min_size = int(0.5 * base_size)
7+
max_size = int(2.0 * base_size)
8+
9+
trans = [T.RandomResize(min_size, max_size)]
10+
if hflip_prob > 0:
11+
trans.append(T.RandomHorizontalFlip(hflip_prob))
12+
trans.extend([
13+
T.RandomCrop(crop_size),
14+
T.ToTensor(),
15+
T.Normalize(mean=mean, std=std),
16+
])
17+
self.transforms = T.Compose(trans)
18+
19+
def __call__(self, img, target):
20+
return self.transforms(img, target)
21+
22+
23+
class SegmentationPresetEval:
24+
def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
25+
self.transforms = T.Compose([
26+
T.RandomResize(base_size, base_size),
27+
T.ToTensor(),
28+
T.Normalize(mean=mean, std=std),
29+
])
30+
31+
def __call__(self, img, target):
32+
return self.transforms(img, target)

references/segmentation/train.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torchvision
99

1010
from coco_utils import get_coco
11-
import transforms as T
11+
import presets
1212
import utils
1313

1414

@@ -30,18 +30,7 @@ def get_transform(train):
3030
base_size = 520
3131
crop_size = 480
3232

33-
min_size = int((0.5 if train else 1.0) * base_size)
34-
max_size = int((2.0 if train else 1.0) * base_size)
35-
transforms = []
36-
transforms.append(T.RandomResize(min_size, max_size))
37-
if train:
38-
transforms.append(T.RandomHorizontalFlip(0.5))
39-
transforms.append(T.RandomCrop(crop_size))
40-
transforms.append(T.ToTensor())
41-
transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
42-
std=[0.229, 0.224, 0.225]))
43-
44-
return T.Compose(transforms)
33+
return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
4534

4635

4736
def criterion(inputs, target):

0 commit comments

Comments
 (0)