diff --git a/references/classification/train.py b/references/classification/train.py index 9ba99b3dc54..79b99156a05 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -175,7 +175,7 @@ def main(args): if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - criterion = nn.CrossEntropyLoss() + criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) opt_name = args.opt.lower() if opt_name == 'sgd': @@ -256,6 +256,9 @@ def get_args_parser(add_help=True): parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') + parser.add_argument('--label-smoothing', default=0.0, type=float, + help='label smoothing (default: 0.0)', + dest='label_smoothing') parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') parser.add_argument('--print-freq', default=10, type=int, help='print frequency')