diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 95d8dce0f28..2b031ef39b7 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,5 +1,5 @@ import torch.nn as nn -import torch.utils.model_zoo as model_zoo +from .utils import load_state_dict_from_url __all__ = ['AlexNet', 'alexnet'] @@ -48,14 +48,17 @@ def forward(self, x): return x -def alexnet(pretrained=False, **kwargs): +def alexnet(pretrained=False, progress=True, **kwargs): r"""AlexNet model architecture from the `"One weird trick..." `_ paper. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ model = AlexNet(**kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) + state_dict = load_state_dict_from_url(model_urls['alexnet'], + progress=progress) + model.load_state_dict(state_dict) return model diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 3d63f38249d..35536770feb 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -2,12 +2,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo +from .utils import load_state_dict_from_url from collections import OrderedDict __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] - model_urls = { 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', @@ -22,17 +21,20 @@ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): self.add_module('norm1', nn.BatchNorm2d(num_input_features)), self.add_module('relu1', nn.ReLU(inplace=True)), self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, bias=False)), + growth_rate, kernel_size=1, stride=1, + bias=False)), self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), self.add_module('relu2', nn.ReLU(inplace=True)), self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, bias=False)), + kernel_size=3, stride=1, padding=1, + bias=False)), self.drop_rate = drop_rate def forward(self, x): new_features = super(_DenseLayer, self).forward(x) if self.drop_rate > 0: - new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + new_features = F.dropout(new_features, p=self.drop_rate, + training=self.training) return torch.cat([x, new_features], 1) @@ -40,7 +42,8 @@ class _DenseBlock(nn.Sequential): def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): super(_DenseBlock, self).__init__() for i in range(num_layers): - layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) + layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, + bn_size, drop_rate) self.add_module('denselayer%d' % (i + 1), layer) @@ -75,7 +78,8 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), # First convolution self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, + padding=3, bias=False)), ('norm0', nn.BatchNorm2d(num_init_features)), ('relu0', nn.ReLU(inplace=True)), ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), @@ -85,11 +89,13 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_features = num_init_features for i, num_layers in enumerate(block_config): block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, - bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) + bn_size=bn_size, growth_rate=growth_rate, + drop_rate=drop_rate) self.features.add_module('denseblock%d' % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: - trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) + trans = _Transition(num_input_features=num_features, + num_output_features=num_features // 2) self.features.add_module('transition%d' % (i + 1), trans) num_features = num_features // 2 @@ -117,14 +123,15 @@ def forward(self, x): return out -def _load_state_dict(model, model_url): +def _load_state_dict(model, model_url, progress): # '.'s are no longer allowed in module names, but previous _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_url) + + state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): res = pattern.match(key) if res: @@ -134,57 +141,57 @@ def _load_state_dict(model, model_url): model.load_state_dict(state_dict) -def densenet121(pretrained=False, **kwargs): +def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, + **kwargs): + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) + if pretrained: + _load_state_dict(model, model_urls[arch], progress) + return model + + +def densenet121(pretrained=False, progress=True, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), + return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs) - if pretrained: - _load_state_dict(model, model_urls['densenet121']) - return model -def densenet169(pretrained=False, **kwargs): - r"""Densenet-169 model from +def densenet161(pretrained=False, progress=True, **kwargs): + r"""Densenet-161 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), + return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs) - if pretrained: - _load_state_dict(model, model_urls['densenet169']) - return model -def densenet201(pretrained=False, **kwargs): - r"""Densenet-201 model from +def densenet169(pretrained=False, progress=True, **kwargs): + r"""Densenet-169 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), + return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs) - if pretrained: - _load_state_dict(model, model_urls['densenet201']) - return model -def densenet161(pretrained=False, **kwargs): - r"""Densenet-161 model from +def densenet201(pretrained=False, progress=True, **kwargs): + r"""Densenet-201 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), + return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs) - if pretrained: - _load_state_dict(model, model_urls['densenet161']) - return model diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 4725dc95819..0889cd37ba8 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils import model_zoo +from .utils import load_state_dict_from_url __all__ = ['GoogLeNet', 'googlenet'] @@ -15,12 +15,13 @@ _GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1']) -def googlenet(pretrained=False, **kwargs): +def googlenet(pretrained=False, progress=True, **kwargs): r"""GoogLeNet (Inception v1) model architecture from `"Going Deeper with Convolutions" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, adds two auxiliary branches that can improve training. Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it @@ -38,7 +39,9 @@ def googlenet(pretrained=False, **kwargs): kwargs['aux_logits'] = True kwargs['init_weights'] = False model = GoogLeNet(**kwargs) - model.load_state_dict(model_zoo.load_url(model_urls['googlenet'])) + state_dict = load_state_dict_from_url(model_urls['googlenet'], + progress=progress) + model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False del model.aux1, model.aux2 diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index f8217b72fbe..33117223423 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo +from .utils import load_state_dict_from_url __all__ = ['Inception3', 'inception_v3'] @@ -16,7 +16,7 @@ _InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits']) -def inception_v3(pretrained=False, **kwargs): +def inception_v3(pretrained=False, progress=True, **kwargs): r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. @@ -26,6 +26,7 @@ def inception_v3(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it @@ -40,7 +41,9 @@ def inception_v3(pretrained=False, **kwargs): else: original_aux_logits = True model = Inception3(**kwargs) - model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'])) + state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], + progress=progress) + model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False del model.AuxLogits diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 70411ea4fd3..2f4bdff2505 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,5 +1,5 @@ import torch.nn as nn -import torch.utils.model_zoo as model_zoo +from .utils import load_state_dict_from_url __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -185,75 +185,79 @@ def forward(self, x): return x -def resnet18(pretrained=False, **kwargs): +def _resnet(arch, inplanes, planes, pretrained, progress, **kwargs): + model = ResNet(inplanes, planes, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-18 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) - return model + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) -def resnet34(pretrained=False, **kwargs): +def resnet34(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-34 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) - return model + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) -def resnet50(pretrained=False, **kwargs): +def resnet50(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-50 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) - return model + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) -def resnet101(pretrained=False, **kwargs): +def resnet101(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-101 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) - return model + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) -def resnet152(pretrained=False, **kwargs): +def resnet152(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-152 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) - return model + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) -def resnext50_32x4d(pretrained=False, **kwargs): - model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) - # if pretrained: - # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) - return model +def resnext50_32x4d(**kwargs): + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained=False, progress=True, **kwargs) -def resnext101_32x8d(pretrained=False, **kwargs): - model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) - # if pretrained: - # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) - return model +def resnext101_32x8d(**kwargs): + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained=False, progress=True, **kwargs) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 456f4d9187b..1f8cd6bace0 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,12 +1,10 @@ import torch import torch.nn as nn import torch.nn.init as init -import torch.utils.model_zoo as model_zoo - +from .utils import load_state_dict_from_url __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] - model_urls = { 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', @@ -38,13 +36,10 @@ def forward(self, x): class SqueezeNet(nn.Module): - def __init__(self, version=1.0, num_classes=1000): + def __init__(self, version='1_0', num_classes=1000): super(SqueezeNet, self).__init__() - if version not in [1.0, 1.1]: - raise ValueError("Unsupported SqueezeNet version {version}:" - "1.0 or 1.1 expected".format(version=version)) self.num_classes = num_classes - if version == 1.0: + if version == '1_0': self.features = nn.Sequential( nn.Conv2d(3, 96, kernel_size=7, stride=2), nn.ReLU(inplace=True), @@ -60,7 +55,7 @@ def __init__(self, version=1.0, num_classes=1000): nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(512, 64, 256, 256), ) - else: + elif version == '1_1': self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2), nn.ReLU(inplace=True), @@ -76,6 +71,13 @@ def __init__(self, version=1.0, num_classes=1000): Fire(384, 64, 256, 256), Fire(512, 64, 256, 256), ) + else: + # FIXME: Is this needed? SqueezeNet should only be called from the + # FIXME: squeezenet1_x() functions + # FIXME: This checking is not done for the other models + raise ValueError("Unsupported SqueezeNet version {version}:" + "1_0 or 1_1 expected".format(version=version)) + # Final convolution is initialized differently from the rest final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) self.classifier = nn.Sequential( @@ -100,21 +102,29 @@ def forward(self, x): return x.view(x.size(0), self.num_classes) -def squeezenet1_0(pretrained=False, **kwargs): +def _squeezenet(version, pretrained, progress, **kwargs): + model = SqueezeNet(version, **kwargs) + if pretrained: + arch = 'squeezenet' + version + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def squeezenet1_0(pretrained=False, progress=True, **kwargs): r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = SqueezeNet(version=1.0, **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) - return model + return _squeezenet('1_0', pretrained, progress, **kwargs) -def squeezenet1_1(pretrained=False, **kwargs): +def squeezenet1_1(pretrained=False, progress=True, **kwargs): r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters @@ -122,8 +132,6 @@ def squeezenet1_1(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - model = SqueezeNet(version=1.1, **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) - return model + return _squeezenet('1_1', pretrained, progress, **kwargs) diff --git a/torchvision/models/utils.py b/torchvision/models/utils.py new file mode 100644 index 00000000000..638ef07cd85 --- /dev/null +++ b/torchvision/models/utils.py @@ -0,0 +1,4 @@ +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index ada15fd1c09..0e72e99538c 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,5 +1,5 @@ import torch.nn as nn -import torch.utils.model_zoo as model_zoo +from .utils import load_state_dict_from_url __all__ = [ @@ -75,7 +75,7 @@ def make_layers(cfg, batch_norm=False): return nn.Sequential(*layers) -cfg = { +cfgs = { 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], @@ -83,113 +83,92 @@ def make_layers(cfg, batch_norm=False): } -def vgg11(pretrained=False, **kwargs): +def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def vgg11(pretrained=False, progress=True, **kwargs): """VGG 11-layer model (configuration "A") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['A']), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) - return model + return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) -def vgg11_bn(pretrained=False, **kwargs): +def vgg11_bn(pretrained=False, progress=True, **kwargs): """VGG 11-layer model (configuration "A") with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) - return model + return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) -def vgg13(pretrained=False, **kwargs): +def vgg13(pretrained=False, progress=True, **kwargs): """VGG 13-layer model (configuration "B") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['B']), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) - return model + return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) -def vgg13_bn(pretrained=False, **kwargs): +def vgg13_bn(pretrained=False, progress=True, **kwargs): """VGG 13-layer model (configuration "B") with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) - return model + return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) -def vgg16(pretrained=False, **kwargs): +def vgg16(pretrained=False, progress=True, **kwargs): """VGG 16-layer model (configuration "D") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['D']), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) - return model + return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) -def vgg16_bn(pretrained=False, **kwargs): +def vgg16_bn(pretrained=False, progress=True, **kwargs): """VGG 16-layer model (configuration "D") with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) - return model + return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) -def vgg19(pretrained=False, **kwargs): +def vgg19(pretrained=False, progress=True, **kwargs): """VGG 19-layer model (configuration "E") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['E']), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) - return model + return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) -def vgg19_bn(pretrained=False, **kwargs): +def vgg19_bn(pretrained=False, progress=True, **kwargs): """VGG 19-layer model (configuration 'E') with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) - return model + return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)