Skip to content

Commit 9f7a0f7

Browse files
committed
Adding presets in the object detection reference scripts.
1 parent 992d41f commit 9f7a0f7

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

references/detection/presets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import transforms as T
2+
3+
4+
class DetectionPresetTrain:
5+
def __init__(self, hflip_prob=0.5):
6+
trans = [T.ToTensor()]
7+
if hflip_prob > 0:
8+
trans.append(T.RandomHorizontalFlip(hflip_prob))
9+
10+
self.transforms = T.Compose(trans)
11+
12+
def __call__(self, img, target):
13+
return self.transforms(img, target)
14+
15+
16+
class DetectionPresetEval:
17+
def __init__(self):
18+
self.transforms = T.ToTensor()
19+
20+
def __call__(self, img, target):
21+
return self.transforms(img, target)

references/detection/train.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
3333
from engine import train_one_epoch, evaluate
3434

35+
import presets
3536
import utils
36-
import transforms as T
3737

3838

3939
def get_dataset(name, image_set, transform, data_path):
@@ -48,11 +48,7 @@ def get_dataset(name, image_set, transform, data_path):
4848

4949

5050
def get_transform(train):
51-
transforms = []
52-
transforms.append(T.ToTensor())
53-
if train:
54-
transforms.append(T.RandomHorizontalFlip(0.5))
55-
return T.Compose(transforms)
51+
return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
5652

5753

5854
def main(args):

0 commit comments

Comments
 (0)