11import datetime
22import os
3+ import random
34import time
45import warnings
56
1516from torchvision .transforms .functional import InterpolationMode
1617
1718
18- def train_one_epoch (model , criterion , optimizer , data_loader , device , epoch , args , model_ema = None , scaler = None ):
19+ def train_one_epoch (model , criterion , optimizer , data_loader , device , epoch , args , model_ema = None , scaler = None , scheduler = None ):
1920 model .train ()
2021 metric_logger = utils .MetricLogger (delimiter = " " )
2122 metric_logger .add_meter ("lr" , utils .SmoothedValue (window_size = 1 , fmt = "{value}" ))
@@ -43,6 +44,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
4344 if args .clip_grad_norm is not None :
4445 nn .utils .clip_grad_norm_ (model .parameters (), args .clip_grad_norm )
4546 optimizer .step ()
47+
48+ if scheduler is not None and args .lr_step_every_batch :
49+ scheduler .step ()
4650
4751 if model_ema and i % args .model_ema_steps == 0 :
4852 model_ema .update_parameters (model )
@@ -113,7 +117,7 @@ def _get_cache_path(filepath):
113117def load_data (traindir , valdir , args ):
114118 # Data loading code
115119 print ("Loading data" )
116- val_resize_size , val_crop_size , train_crop_size = args .val_resize_size , args .val_crop_size , args .train_crop_size
120+ val_resize_size , val_crop_size , train_crop_size , center_crop , policy_magnitude = args .val_resize_size , args .val_crop_size , args .train_crop_size , args . train_center_crop , args . policy_magnitude
117121 interpolation = InterpolationMode (args .interpolation )
118122
119123 print ("Loading training data" )
@@ -129,10 +133,12 @@ def load_data(traindir, valdir, args):
129133 dataset = torchvision .datasets .ImageFolder (
130134 traindir ,
131135 presets .ClassificationPresetTrain (
136+ center_crop = center_crop ,
132137 crop_size = train_crop_size ,
133138 interpolation = interpolation ,
134139 auto_augment_policy = auto_augment_policy ,
135140 random_erase_prob = random_erase_prob ,
141+ policy_magnitude = policy_magnitude ,
136142 ),
137143 )
138144 if args .cache_dataset :
@@ -182,7 +188,12 @@ def load_data(traindir, valdir, args):
182188def main (args ):
183189 if args .output_dir :
184190 utils .mkdir (args .output_dir )
185-
191+
192+ if args .seed is None :
193+ # randomly choose a seed
194+ args .seed = random .randint (0 , 2 ** 32 )
195+ utils .set_seed (args .seed )
196+
186197 utils .init_distributed_mode (args )
187198 print (args )
188199
@@ -261,13 +272,21 @@ def main(args):
261272 raise RuntimeError (f"Invalid optimizer { args .opt } . Only SGD, RMSprop and AdamW are supported." )
262273
263274 scaler = torch .cuda .amp .GradScaler () if args .amp else None
275+
276+ batches_per_epoch = len (data_loader )
277+ warmup_iters = args .lr_warmup_epochs
278+ total_iters = args .epochs
279+
280+ if args .lr_step_every_batch :
281+ warmup_iters *= batches_per_epoch
282+ total_iters *= batches_per_epoch
264283
265284 args .lr_scheduler = args .lr_scheduler .lower ()
266285 if args .lr_scheduler == "steplr" :
267286 main_lr_scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = args .lr_step_size , gamma = args .lr_gamma )
268287 elif args .lr_scheduler == "cosineannealinglr" :
269288 main_lr_scheduler = torch .optim .lr_scheduler .CosineAnnealingLR (
270- optimizer , T_max = args . epochs - args . lr_warmup_epochs , eta_min = args .lr_min
289+ optimizer , T_max = total_iters - warmup_iters , eta_min = args .lr_min
271290 )
272291 elif args .lr_scheduler == "exponentiallr" :
273292 main_lr_scheduler = torch .optim .lr_scheduler .ExponentialLR (optimizer , gamma = args .lr_gamma )
@@ -280,18 +299,18 @@ def main(args):
280299 if args .lr_warmup_epochs > 0 :
281300 if args .lr_warmup_method == "linear" :
282301 warmup_lr_scheduler = torch .optim .lr_scheduler .LinearLR (
283- optimizer , start_factor = args .lr_warmup_decay , total_iters = args . lr_warmup_epochs
302+ optimizer , start_factor = args .lr_warmup_decay , total_iters = warmup_iters
284303 )
285304 elif args .lr_warmup_method == "constant" :
286305 warmup_lr_scheduler = torch .optim .lr_scheduler .ConstantLR (
287- optimizer , factor = args .lr_warmup_decay , total_iters = args . lr_warmup_epochs
306+ optimizer , factor = args .lr_warmup_decay , total_iters = warmup_iters
288307 )
289308 else :
290309 raise RuntimeError (
291310 f"Invalid warmup lr method '{ args .lr_warmup_method } '. Only linear and constant are supported."
292311 )
293312 lr_scheduler = torch .optim .lr_scheduler .SequentialLR (
294- optimizer , schedulers = [warmup_lr_scheduler , main_lr_scheduler ], milestones = [args . lr_warmup_epochs ]
313+ optimizer , schedulers = [warmup_lr_scheduler , main_lr_scheduler ], milestones = [warmup_iters ]
295314 )
296315 else :
297316 lr_scheduler = main_lr_scheduler
@@ -341,8 +360,9 @@ def main(args):
341360 for epoch in range (args .start_epoch , args .epochs ):
342361 if args .distributed :
343362 train_sampler .set_epoch (epoch )
344- train_one_epoch (model , criterion , optimizer , data_loader , device , epoch , args , model_ema , scaler )
345- lr_scheduler .step ()
363+ train_one_epoch (model , criterion , optimizer , data_loader , device , epoch , args , model_ema , scaler , lr_scheduler )
364+ if not args .lr_step_every_batch :
365+ lr_scheduler .step ()
346366 evaluate (model , criterion , data_loader_test , device = device )
347367 if model_ema :
348368 evaluate (model_ema , criterion , data_loader_test , device = device , log_suffix = "EMA" )
@@ -371,7 +391,7 @@ def get_args_parser(add_help=True):
371391
372392 parser = argparse .ArgumentParser (description = "PyTorch Classification Training" , add_help = add_help )
373393
374- parser .add_argument ("--data-path" , default = "/datasets01 /imagenet_full_size/061417/" , type = str , help = "dataset path" )
394+ parser .add_argument ("--data-path" , default = "/datasets01_ontap /imagenet_full_size/061417/" , type = str , help = "dataset path" )
375395 parser .add_argument ("--model" , default = "resnet18" , type = str , help = "model name" )
376396 parser .add_argument ("--device" , default = "cuda" , type = str , help = "device (Use cuda or cpu Default: cuda)" )
377397 parser .add_argument (
@@ -425,6 +445,7 @@ def get_args_parser(add_help=True):
425445 parser .add_argument ("--lr-step-size" , default = 30 , type = int , help = "decrease lr every step-size epochs" )
426446 parser .add_argument ("--lr-gamma" , default = 0.1 , type = float , help = "decrease lr by a factor of lr-gamma" )
427447 parser .add_argument ("--lr-min" , default = 0.0 , type = float , help = "minimum lr of lr schedule (default: 0.0)" )
448+ parser .add_argument ("--lr-step-every-batch" , action = "store_true" , help = "decrease lr every step-size batches" , default = False )
428449 parser .add_argument ("--print-freq" , default = 10 , type = int , help = "print frequency" )
429450 parser .add_argument ("--output-dir" , default = "." , type = str , help = "path to save outputs" )
430451 parser .add_argument ("--resume" , default = "" , type = str , help = "path of checkpoint" )
@@ -448,6 +469,7 @@ def get_args_parser(add_help=True):
448469 action = "store_true" ,
449470 )
450471 parser .add_argument ("--auto-augment" , default = None , type = str , help = "auto augment policy (default: None)" )
472+ parser .add_argument ("--policy-magnitude" , default = 9 , type = int , help = "magnitude of auto augment policy" )
451473 parser .add_argument ("--random-erase" , default = 0.0 , type = float , help = "random erasing probability (default: 0.0)" )
452474
453475 # Mixed precision training parameters
@@ -486,13 +508,16 @@ def get_args_parser(add_help=True):
486508 parser .add_argument (
487509 "--train-crop-size" , default = 224 , type = int , help = "the random crop size used for training (default: 224)"
488510 )
511+ parser .add_argument (
512+ "--train-center-crop" , action = "store_true" , help = "use center crop instead of random crop for training (default: False)"
513+ )
489514 parser .add_argument ("--clip-grad-norm" , default = None , type = float , help = "the maximum gradient norm (default None)" )
490515 parser .add_argument ("--ra-sampler" , action = "store_true" , help = "whether to use Repeated Augmentation in training" )
491516 parser .add_argument (
492517 "--ra-reps" , default = 3 , type = int , help = "number of repetitions for Repeated Augmentation (default: 3)"
493518 )
494519 parser .add_argument ("--weights" , default = None , type = str , help = "the weights enum name to load" )
495-
520+ parser . add_argument ( "--seed" , default = None , type = int , help = "the seed for randomness (default: None). A `None` value means a seed will be randomly generated" )
496521 return parser
497522
498523
0 commit comments