
[proposal] Use self.flatten instead of torch.flatten and, when it becomes possible, derive ResNet from nn.Sequential (scripting+quantization is the blocker); this would simplify model surgery in the most frequent cases #3331

Open
@vadimkantorov

Description

Currently, in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243:

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

If it instead used x = self.flatten(x) (with self.flatten = nn.Flatten(1) registered as a submodule), model surgery would reduce to del model.avgpool, model.flatten, model.fc. In that case the class could also derive from nn.Sequential and pass its submodules via an OrderedDict (as in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov), which would preserve checkpoint compatibility because the state_dict keys stay the same. The hand-written forward method could then be removed. A minimal sketch of both steps is shown below.
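A minimal, self-contained sketch of the idea. TinyNet and TinySequentialNet are hypothetical stand-ins for illustration, not the actual torchvision ResNet, and the surgery shown assumes the classifier head carries no parameters other than fc:

    # Sketch only: TinyNet / TinySequentialNet are hypothetical stand-ins,
    # not the torchvision ResNet implementation.
    from collections import OrderedDict

    import torch
    import torch.nn as nn


    class TinyNet(nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.flatten = nn.Flatten(1)   # registered submodule, replaces torch.flatten(x, 1)
            self.fc = nn.Linear(8, num_classes)

        def forward(self, x):
            x = self.conv(x)
            x = self.avgpool(x)
            x = self.flatten(x)
            x = self.fc(x)
            return x


    # Sequential-based variant: submodule names come from the OrderedDict keys,
    # so state_dict keys ("conv.weight", "fc.bias", ...) are unchanged and no
    # hand-written forward() is needed.
    class TinySequentialNet(nn.Sequential):
        def __init__(self, num_classes=10):
            super().__init__(OrderedDict([
                ('conv', nn.Conv2d(3, 8, kernel_size=3, padding=1)),
                ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
                ('flatten', nn.Flatten(1)),
                ('fc', nn.Linear(8, num_classes)),
            ]))


    model = TinySequentialNet()
    print(list(model.state_dict().keys()))
    # ['conv.weight', 'conv.bias', 'fc.weight', 'fc.bias']

    # Model surgery is now plain attribute deletion: drop the head to get a
    # feature extractor, no forward() rewrite required.
    del model.avgpool, model.flatten, model.fc
    features = model(torch.randn(1, 3, 32, 32))  # shape: (1, 8, 32, 32)

Because nn.Module.__delattr__ removes entries from the module's _modules dict and nn.Sequential.forward simply iterates over the remaining children, deleting the head modules this way leaves a working truncated network.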
