Skip to content

Commit 22ff44f

Browse files
xiaohu2015datumbox
andauthored
save grad_scaler if use amp for better resume (#4923)
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 9b034e1 commit 22ff44f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

references/classification/train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ def main(args):
325325
args.start_epoch = checkpoint["epoch"] + 1
326326
if model_ema:
327327
model_ema.load_state_dict(checkpoint["model_ema"])
328+
if scaler:
329+
scaler.load_state_dict(checkpoint["scaler"])
328330

329331
if args.test_only:
330332
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
@@ -356,6 +358,8 @@ def main(args):
356358
}
357359
if model_ema:
358360
checkpoint["model_ema"] = model_ema.state_dict()
361+
if scaler:
362+
checkpoint["scaler"] = scaler.state_dict()
359363
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
360364
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
361365

0 commit comments

Comments
 (0)