@@ -144,6 +144,9 @@ def get_args_parser(add_help=True):
144
144
# Prototype models only
145
145
parser .add_argument ("--weights" , default = None , type = str , help = "the weights enum name to load" )
146
146
147
+ # Mixed precision training parameters
148
+ parser .add_argument ("--amp" , action = "store_true" , help = "Use torch.cuda.amp for mixed precision training" )
149
+
147
150
return parser
148
151
149
152
@@ -209,6 +212,8 @@ def main(args):
209
212
params = [p for p in model .parameters () if p .requires_grad ]
210
213
optimizer = torch .optim .SGD (params , lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
211
214
215
+ scaler = torch .cuda .amp .GradScaler () if args .amp else None
216
+
212
217
args .lr_scheduler = args .lr_scheduler .lower ()
213
218
if args .lr_scheduler == "multisteplr" :
214
219
lr_scheduler = torch .optim .lr_scheduler .MultiStepLR (optimizer , milestones = args .lr_steps , gamma = args .lr_gamma )
@@ -225,6 +230,8 @@ def main(args):
225
230
optimizer .load_state_dict (checkpoint ["optimizer" ])
226
231
lr_scheduler .load_state_dict (checkpoint ["lr_scheduler" ])
227
232
args .start_epoch = checkpoint ["epoch" ] + 1
233
+ if args .amp :
234
+ scaler .load_state_dict (checkpoint ["scaler" ])
228
235
229
236
if args .test_only :
230
237
evaluate (model , data_loader_test , device = device )
@@ -235,7 +242,7 @@ def main(args):
235
242
for epoch in range (args .start_epoch , args .epochs ):
236
243
if args .distributed :
237
244
train_sampler .set_epoch (epoch )
238
- train_one_epoch (model , optimizer , data_loader , device , epoch , args .print_freq )
245
+ train_one_epoch (model , optimizer , data_loader , device , epoch , args .print_freq , scaler )
239
246
lr_scheduler .step ()
240
247
if args .output_dir :
241
248
checkpoint = {
@@ -245,6 +252,8 @@ def main(args):
245
252
"args" : args ,
246
253
"epoch" : epoch ,
247
254
}
255
+ if args .amp :
256
+ checkpoint ["scaler" ] = scaler .state_dict ()
248
257
utils .save_on_master (checkpoint , os .path .join (args .output_dir , f"model_{ epoch } .pth" ))
249
258
utils .save_on_master (checkpoint , os .path .join (args .output_dir , "checkpoint.pth" ))
250
259
0 commit comments