Commit 3902140

torchvision QAT tutorial: update for QAT with DDP (#2280)
Summary:

We've made two recent changes to QAT in PyTorch core:
1. add support for SyncBatchNorm
2. make eager mode QAT prepare scripts respect device affinity

This PR updates the torchvision QAT reference script to take advantage of both of these. This should be landed after pytorch/pytorch#39337 (the last PT fix) to avoid compatibility issues.

Test Plan:

```
python -m torch.distributed.launch --nproc_per_node 8 --use_env references/classification/train_quantization.py --data-path {imagenet1k_subset} --output-dir {tmp} --sync-bn
```

Reviewers:

Subscribers:

Tasks:

Tags:
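As a rough illustration of the first change, the sketch below shows what converting a model to SyncBatchNorm does to its layers. The toy model and the helper `maybe_convert_sync_bn` are hypothetical and are not part of the reference script; only `torch.nn.SyncBatchNorm.convert_sync_batchnorm` comes from the commit. The sketch only performs the conversion and does not run a forward pass, since training with SyncBatchNorm needs an initialized process group.

```python
from torch import nn


def make_toy_model() -> nn.Sequential:
    # Hypothetical stand-in for the real quantizable torchvision model.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )


def maybe_convert_sync_bn(model: nn.Module, distributed: bool, sync_bn: bool) -> nn.Module:
    # Same gating as the condition in the diff: convert only when both are set.
    if distributed and sync_bn:
        # Replaces BatchNorm layers with SyncBatchNorm. Container modules are
        # updated in place, so always use the returned model.
        return nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model


converted = maybe_convert_sync_bn(make_toy_model(), distributed=True, sync_bn=True)
unchanged = maybe_convert_sync_bn(make_toy_model(), distributed=True, sync_bn=False)

print(type(converted[1]).__name__)
print(type(unchanged[1]).__name__)
```

Each call gets a fresh model because the conversion swaps child modules of the container it is given.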
1 parent 34810c0 commit 3902140

1 file changed, +10 -2 lines changed

references/classification/train_quantization.py

Lines changed: 10 additions & 2 deletions
@@ -51,12 +51,16 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
     model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    model.to(device)

     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
         model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
         torch.quantization.prepare_qat(model, inplace=True)

+        if args.distributed and args.sync_bn:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
         optimizer = torch.optim.SGD(
             model.parameters(), lr=args.lr, momentum=args.momentum,
             weight_decay=args.weight_decay)
@@ -65,8 +69,6 @@ def main(args):
                                                        step_size=args.lr_step_size,
                                                        gamma=args.lr_gamma)

-    model.to(device)
-
     criterion = nn.CrossEntropyLoss()
     model_without_ddp = model
     if args.distributed:
@@ -224,6 +226,12 @@ def parse_args():
              It also serializes the transforms",
         action="store_true",
     )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
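
The new `--sync-bn` option is a plain `store_true` flag, so it defaults to `False` and only takes effect together with distributed training. A minimal standalone sketch of that behaviour follows; `build_parser` and the literal `distributed` value are hypothetical stand-ins, and the real script defines many more arguments.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reduced parser containing only the flag added in this commit.
    parser = argparse.ArgumentParser(description="--sync-bn flag sketch")
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    return parser


with_flag = build_parser().parse_args(["--sync-bn"])
without_flag = build_parser().parse_args([])

# Assumption for illustration: pretend the run is distributed.
distributed = True

# Same condition as in main(): both must hold before converting BatchNorm layers.
use_sync_bn = distributed and with_flag.sync_bn
skip_sync_bn = distributed and without_flag.sync_bn

print(use_sync_bn, skip_sync_bn)
```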
