@@ -113,7 +113,12 @@ def _get_cache_path(filepath):
113113def load_data (traindir , valdir , args ):
114114 # Data loading code
115115 print ("Loading data" )
116- val_resize_size , val_crop_size , train_crop_size = args .val_resize_size , args .val_crop_size , args .train_crop_size
116+ val_resize_size , val_crop_size , train_crop_size , center_crop = (
117+ args .val_resize_size ,
118+ args .val_crop_size ,
119+ args .train_crop_size ,
120+ args .train_center_crop ,
121+ )
117122 interpolation = InterpolationMode (args .interpolation )
118123
119124 print ("Loading training data" )
@@ -126,13 +131,18 @@ def load_data(traindir, valdir, args):
126131 else :
127132 auto_augment_policy = getattr (args , "auto_augment" , None )
128133 random_erase_prob = getattr (args , "random_erase" , 0.0 )
134+ ra_magnitude = args .ra_magnitude
135+ augmix_severity = args .augmix_severity
129136 dataset = torchvision .datasets .ImageFolder (
130137 traindir ,
131138 presets .ClassificationPresetTrain (
139+ center_crop = center_crop ,
132140 crop_size = train_crop_size ,
133141 interpolation = interpolation ,
134142 auto_augment_policy = auto_augment_policy ,
135143 random_erase_prob = random_erase_prob ,
144+ ra_magnitude = ra_magnitude ,
145+ augmix_severity = augmix_severity ,
136146 ),
137147 )
138148 if args .cache_dataset :
@@ -207,7 +217,10 @@ def main(args):
207217 mixup_transforms .append (transforms .RandomCutmix (num_classes , p = 1.0 , alpha = args .cutmix_alpha ))
208218 if mixup_transforms :
209219 mixupcutmix = torchvision .transforms .RandomChoice (mixup_transforms )
210- collate_fn = lambda batch : mixupcutmix (* default_collate (batch )) # noqa: E731
220+
221+ def collate_fn (batch ):
222+ return mixupcutmix (* default_collate (batch ))
223+
211224 data_loader = torch .utils .data .DataLoader (
212225 dataset ,
213226 batch_size = args .batch_size ,
@@ -448,6 +461,8 @@ def get_args_parser(add_help=True):
448461 action = "store_true" ,
449462 )
450463 parser .add_argument ("--auto-augment" , default = None , type = str , help = "auto augment policy (default: None)" )
464+ parser .add_argument ("--ra-magnitude" , default = 9 , type = int , help = "magnitude of auto augment policy" )
465+ parser .add_argument ("--augmix-severity" , default = 3 , type = int , help = "severity of augmix policy" )
451466 parser .add_argument ("--random-erase" , default = 0.0 , type = float , help = "random erasing probability (default: 0.0)" )
452467
453468 # Mixed precision training parameters
@@ -486,13 +501,17 @@ def get_args_parser(add_help=True):
486501 parser .add_argument (
487502 "--train-crop-size" , default = 224 , type = int , help = "the random crop size used for training (default: 224)"
488503 )
504+ parser .add_argument (
505+ "--train-center-crop" ,
506+ action = "store_true" ,
507+ help = "use center crop instead of random crop for training (default: False)" ,
508+ )
489509 parser .add_argument ("--clip-grad-norm" , default = None , type = float , help = "the maximum gradient norm (default None)" )
490510 parser .add_argument ("--ra-sampler" , action = "store_true" , help = "whether to use Repeated Augmentation in training" )
491511 parser .add_argument (
492512 "--ra-reps" , default = 3 , type = int , help = "number of repetitions for Repeated Augmentation (default: 3)"
493513 )
494514 parser .add_argument ("--weights" , default = None , type = str , help = "the weights enum name to load" )
495-
496515 return parser
497516
498517
0 commit comments