
Commit 59ec1df

xiaohu2015 and datumbox authored
support amp training for detection models (#4933)
* support amp training
* support amp training
* support amp training
* Update references/detection/train.py
  Co-authored-by: Vasilis Vryniotis <[email protected]>
* Update references/detection/engine.py
  Co-authored-by: Vasilis Vryniotis <[email protected]>
* fix lint issues

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 4b20ac5 commit 59ec1df

File tree

2 files changed, +21 -8 lines changed


references/detection/engine.py

Lines changed: 11 additions & 7 deletions
@@ -9,7 +9,7 @@
 from coco_utils import get_coco_api_from_dataset
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
@@ -27,10 +27,9 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        loss_dict = model(images, targets)
-
-        losses = sum(loss for loss in loss_dict.values())
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
 
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
@@ -44,8 +43,13 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
             sys.exit(1)
 
         optimizer.zero_grad()
-        losses.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
 
         if lr_scheduler is not None:
             lr_scheduler.step()
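For readers unfamiliar with the pattern, the engine.py change follows the usual torch.cuda.amp recipe: run the forward pass under autocast, and when a GradScaler is supplied, scale the loss before backward and let the scaler drive the optimizer step. Below is a minimal, self-contained sketch of the same idea; the linear model, random data, and hyperparameters are made up for illustration and are not part of the commit.

```python
import torch

# Toy stand-ins purely for illustration; the real train_one_epoch loops over a
# detection data_loader and sums the model's loss dict.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

for _ in range(3):
    x = torch.randn(8, 10, device=device)
    y = torch.randn(8, 1, device=device)

    # Forward pass runs in mixed precision only when a scaler was passed in,
    # mirroring the autocast(enabled=scaler is not None) guard above.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = torch.nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad()
    if scaler is not None:
        # Scale the loss so fp16 gradients do not underflow, then let the
        # scaler unscale and apply the optimizer step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
```

Because `scaler` defaults to `None` and autocast is enabled only when a scaler is supplied, existing callers of `train_one_epoch` keep the full-precision behaviour unchanged.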

references/detection/train.py

Lines changed: 10 additions & 1 deletion
@@ -144,6 +144,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser
 
 
@@ -209,6 +212,8 @@ def main(args):
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
@@ -225,6 +230,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
@@ -235,7 +242,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
        lr_scheduler.step()
         if args.output_dir:
             checkpoint = {
@@ -245,6 +252,8 @@ def main(args):
                 "args": args,
                 "epoch": epoch,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
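The train.py change wires the new --amp flag through to a GradScaler and also persists its state: the scaler's running loss scale is written into each checkpoint and restored on resume, so a resumed AMP run does not restart loss-scale calibration from scratch. Below is a small sketch of that save/restore round trip; the toy model, optimizer, and checkpoint path are illustrative only.

```python
import torch

# Toy stand-ins purely for illustration; the real script checkpoints a detection
# model together with its optimizer, lr_scheduler, args and epoch.
use_amp = torch.cuda.is_available()
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler() if use_amp else None

# Saving: include the scaler state so a resumed AMP run keeps its loss-scale history.
checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
if scaler is not None:
    checkpoint["scaler"] = scaler.state_dict()
torch.save(checkpoint, "checkpoint.pth")  # hypothetical path

# Resuming: restore the scaler only when AMP is enabled, mirroring the diff above.
loaded = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(loaded["model"])
optimizer.load_state_dict(loaded["optimizer"])
if scaler is not None:
    scaler.load_state_dict(loaded["scaler"])
```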
