Skip to content

Added progress flag to model getters #875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions torchvision/models/alexnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url


__all__ = ['AlexNet', 'alexnet']
Expand Down Expand Up @@ -48,14 +48,17 @@ def forward(self, x):
return x


def alexnet(pretrained=False, **kwargs):
def alexnet(pretrained=False, progress=True, **kwargs):
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = AlexNet(**kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
state_dict = load_state_dict_from_url(model_urls['alexnet'],
progress=progress)
model.load_state_dict(state_dict)
return model
75 changes: 41 additions & 34 deletions torchvision/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
from collections import OrderedDict

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']


model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
Expand All @@ -22,25 +21,29 @@ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, bias=False)),
growth_rate, kernel_size=1, stride=1,
bias=False)),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, bias=False)),
kernel_size=3, stride=1, padding=1,
bias=False)),
self.drop_rate = drop_rate

def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
new_features = F.dropout(new_features, p=self.drop_rate,
training=self.training)
return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
bn_size, drop_rate)
self.add_module('denselayer%d' % (i + 1), layer)


Expand Down Expand Up @@ -75,7 +78,8 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),

# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
Expand All @@ -85,11 +89,13 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
bn_size=bn_size, growth_rate=growth_rate,
drop_rate=drop_rate)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
trans = _Transition(num_input_features=num_features,
num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2

Expand Down Expand Up @@ -117,14 +123,15 @@ def forward(self, x):
return out


def _load_state_dict(model, model_url):
def _load_state_dict(model, model_url, progress):
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_url)

state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
Expand All @@ -134,57 +141,57 @@ def _load_state_dict(model, model_url):
model.load_state_dict(state_dict)


def densenet121(pretrained=False, **kwargs):
def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
**kwargs):
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained:
_load_state_dict(model, model_urls[arch], progress)
return model


def densenet121(pretrained=False, progress=True, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet121'])
return model


def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
def densenet161(pretrained=False, progress=True, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet169'])
return model


def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
def densenet169(pretrained=False, progress=True, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet201'])
return model


def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
def densenet201(pretrained=False, progress=True, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet161'])
return model
9 changes: 6 additions & 3 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from .utils import load_state_dict_from_url

__all__ = ['GoogLeNet', 'googlenet']

Expand All @@ -15,12 +15,13 @@
_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])


def googlenet(pretrained=False, **kwargs):
def googlenet(pretrained=False, progress=True, **kwargs):
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
Expand All @@ -38,7 +39,9 @@ def googlenet(pretrained=False, **kwargs):
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
model = GoogLeNet(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
state_dict = load_state_dict_from_url(model_urls['googlenet'],
progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.aux1, model.aux2
Expand Down
9 changes: 6 additions & 3 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url


__all__ = ['Inception3', 'inception_v3']
Expand All @@ -16,7 +16,7 @@
_InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])


def inception_v3(pretrained=False, **kwargs):
def inception_v3(pretrained=False, progress=True, **kwargs):
r"""Inception v3 model architecture from
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

Expand All @@ -26,6 +26,7 @@ def inception_v3(pretrained=False, **kwargs):

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
aux_logits (bool): If True, add an auxiliary branch that can improve training.
Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
Expand All @@ -40,7 +41,9 @@ def inception_v3(pretrained=False, **kwargs):
else:
original_aux_logits = True
model = Inception3(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
Expand Down
76 changes: 40 additions & 36 deletions torchvision/models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
Expand Down Expand Up @@ -185,75 +185,79 @@ def forward(self, x):
return x


def resnet18(pretrained=False, **kwargs):
def _resnet(arch, inplanes, planes, pretrained, progress, **kwargs):
model = ResNet(inplanes, planes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model


def resnet18(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-18 model.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)


def resnet34(pretrained=False, **kwargs):
def resnet34(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-34 model.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)


def resnet50(pretrained=False, **kwargs):
def resnet50(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-50 model.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)


def resnet101(pretrained=False, **kwargs):
def resnet101(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-101 model.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)


def resnet152(pretrained=False, **kwargs):
def resnet152(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-152 model.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)


def resnext50_32x4d(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
# if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
return model
def resnext50_32x4d(**kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained=False, progress=True, **kwargs)


def resnext101_32x8d(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
# if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
return model
def resnext101_32x8d(**kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained=False, progress=True, **kwargs)
Loading