diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 8e1d2573068..15df4d8ae39 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -32,7 +32,7 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True aux_classifier = FCNHead(inplanes, num_classes) model_map = { - 'deeplab': (DeepLabHead, DeepLabV3), + 'deeplabv3': (DeepLabHead, DeepLabV3), 'fcn': (FCNHead, FCN), } inplanes = 2048 @@ -43,20 +43,12 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True return model -def fcn_resnet50(pretrained=False, progress=True, - num_classes=21, aux_loss=None, **kwargs): - """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. - - Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC - progress (bool): If True, displays a progress bar of the download to stderr - """ +def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs): if pretrained: aux_loss = True - model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs) + model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs) if pretrained: - arch = 'fcn_resnet50_coco' + arch = arch_type + '_' + backbone + '_coco' model_url = model_urls[arch] if model_url is None: raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) @@ -66,6 +58,18 @@ def fcn_resnet50(pretrained=False, progress=True, return model +def fcn_resnet50(pretrained=False, progress=True, + num_classes=21, aux_loss=None, **kwargs): + """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which + contains the same classes as Pascal VOC + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) + + def fcn_resnet101(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs): """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. @@ -75,18 +79,7 @@ def fcn_resnet101(pretrained=False, progress=True, contains the same classes as Pascal VOC progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - aux_loss = True - model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs) - if pretrained: - arch = 'fcn_resnet101_coco' - model_url = model_urls[arch] - if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - return model + return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) def deeplabv3_resnet50(pretrained=False, progress=True, @@ -98,18 +91,7 @@ def deeplabv3_resnet50(pretrained=False, progress=True, contains the same classes as Pascal VOC progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - aux_loss = True - model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs) - if pretrained: - arch = 'deeplabv3_resnet50_coco' - model_url = model_urls[arch] - if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - return model + return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) def deeplabv3_resnet101(pretrained=False, progress=True, @@ -121,15 +103,4 @@ def deeplabv3_resnet101(pretrained=False, progress=True, contains the same classes as Pascal VOC progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - aux_loss = True - model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs) - if pretrained: - arch = 'deeplabv3_resnet101_coco' - model_url = model_urls[arch] - if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - return model + return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)