references/classification/train.py (24 additions, 4 deletions)
@@ -17,7 +17,8 @@
 amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
+        if model_ema:
+            model_ema.update_parameters(model)
+
 
-def evaluate(model, criterion, data_loader, device, print_freq=100):
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -215,16 +225,20 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')
 
     return parser

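Taken together, the train.py changes thread EMA through the existing flow: train_one_epoch refreshes the averaged weights during training, main evaluates the EMA copy once per epoch under a distinguishing 'EMA' log header, and the averaged weights are saved to and restored from the regular checkpoint. With the new flags this is enabled from the command line, e.g. python train.py --model-ema --model-ema-decay 0.999. The sketch below mirrors the same pattern in isolation; the toy model, data, and decay value are illustrative stand-ins, not part of this PR.

import torch
from torch import nn

# Toy stand-ins (not from this PR): a tiny model and random data.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

# Same averaging rule the PR implements in utils.ExponentialMovingAverage:
#   ema_param = decay * ema_param + (1 - decay) * param
decay = 0.99
model_ema = torch.optim.swa_utils.AveragedModel(
    model, avg_fn=lambda ema_p, p, n: decay * ema_p + (1 - decay) * p)

for step in range(100):
    x, y = torch.randn(32, 10), torch.randn(32, 2)
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update_parameters(model)  # refresh the running average

# Checkpoint both copies, as main() does above.
checkpoint = {'model': model.state_dict(), 'model_ema': model_ema.state_dict()}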
references/classification/utils.py (19 additions, 0 deletions)
@@ -161,6 +161,25 @@ def log_every(self, iterable, print_freq, header=None):
         print('{} Total time: {}'.format(header, total_time_str))
 
 
+class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+    """Maintains moving averages of model parameters using an
+    exponential decay.
+    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
+    is used to compute the EMA.
+    """
+    def __init__(self, model, decay, device='cpu', name='ExponentialMovingAverage'):
+        ema_avg = (lambda avg_model_param, model_param, num_averaged:
+                   decay * avg_model_param + (1 - decay) * model_param)
+        super().__init__(model, device, ema_avg)
+        self._name = name
+
+    @property
+    def name(self):
+        """ExponentialMovingAverage object's name."""
+        return self._name
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
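Because ExponentialMovingAverage subclasses AveragedModel, the averaged copy behaves like an ordinary nn.Module: it can be called for inference and round-tripped through state_dict(), which is what lets train.py evaluate and checkpoint it directly. Below is a quick check of the decay arithmetic; the one-parameter module is a made-up example, assuming the class above is importable from utils.

import torch
from torch import nn
from utils import ExponentialMovingAverage  # the class added in this PR

class Scalar(nn.Module):
    """One-parameter module, just to make the arithmetic visible."""
    def __init__(self, value):
        super().__init__()
        self.w = nn.Parameter(torch.tensor(value))

model = Scalar(0.0)
ema = ExponentialMovingAverage(model, decay=0.9, device='cpu')

ema.update_parameters(model)   # first update copies the weights: EMA w == 0.0
model.w.data.fill_(1.0)        # pretend an optimizer step moved the weight
ema.update_parameters(model)   # 0.9 * 0.0 + 0.1 * 1.0 == 0.1
print(float(ema.module.w))     # -> 0.1; the EMA copy lags the raw weight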