Add the SWAG pre-trained weights in TorchVision #5708

Closed

Description

@datumbox

🚀 The feature

It would be great to add support for the pre-trained weights from the Supervised Weakly from hashtAGs (SWAG) project to TorchVision. We will focus on porting the weights developed by @lauragustafson, @mannatsingh and @aadcock.

There are two sets of weights to add for each model variant:

  1. The original frozen trunk SWAG weights with a linear classifier learnt on ImageNet1K
  2. The end-to-end fine-tuned weights on ImageNet1K
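
For illustration, here is a rough sketch of how the two weight sets might eventually surface through the multi-weight API. Entry names such as `IMAGENET1K_SWAG_E2E_V1` (and a corresponding `IMAGENET1K_SWAG_LINEAR_V1` for the frozen-trunk variant) are assumptions; the final names, URLs and preprocessing are part of this task:

```python
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Illustrative entry name for the end-to-end fine-tuned SWAG weights of ViT-B/16.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_b_16(weights=weights).eval()

# Each weight entry carries its own preprocessing (resize, crop, normalization).
preprocess = weights.transforms()
img = torch.rand(3, 500, 500)  # stand-in for a real image tensor
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    scores = model(batch).softmax(dim=-1)
print(scores.argmax(dim=-1))
```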

We will focus on the variants that are currently supported by TorchVision (regnet_y_16gf, regnet_y_32gf, vit_b_16 and vit_l_16). We should also investigate whether the larger variants (regnet_y_128gf and vit_h_14) can be added or whether they cause issues on our CI (memory, increased execution times, etc.).
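
As a quick way to gauge what the larger variants would add to CI, a minimal sketch along these lines can print approximate parameter counts (it assumes the installed torchvision exposes builders under these names, which may not yet be true for regnet_y_128gf and vit_h_14):

```python
import torchvision.models as models

# Approximate model sizes for the variants under consideration.
for name in ["regnet_y_16gf", "regnet_y_32gf", "vit_b_16", "vit_l_16",
             "regnet_y_128gf", "vit_h_14"]:
    builder = getattr(models, name, None)
    if builder is None:
        print(f"{name}: no builder in this torchvision version")
        continue
    n_params = sum(p.numel() for p in builder().parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")
```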

This task includes the following subtasks:

  • Convert the weights to be compatible with TorchVision's implementation
  • Add the weight entries with the right transform configuration and meta-data (see the sketch after this list)
  • Add the necessary licensing info (name, URL, etc.) in the meta-data; update the README to clarify that the weights are offered under CC-BY-NC 4.0
  • Verify that the accuracies reported by our reference scripts match the ones reported on the SWAG repo
  • Confirm that our CI works well and that the additions don't introduce significant slowdowns or breakages. If they do, take action to mitigate them
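
As a rough sketch of what a weight entry with its transform configuration and licensing meta-data could look like (a sketch only: the URL, input resolution and accuracy fields are placeholders, and the `_api` / `_presets` module paths are assumptions that depend on the torchvision version):

```python
from functools import partial

from torchvision.models._api import Weights, WeightsEnum
from torchvision.transforms._presets import ImageClassification
from torchvision.transforms import InterpolationMode


class ViT_B_16_Weights(WeightsEnum):
    # End-to-end fine-tuned SWAG weights; every value below is a placeholder.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/<converted-checkpoint>.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            "recipe": "https://github.com/facebookresearch/SWAG",
            "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
            "acc@1": None,  # fill in after verifying with the reference scripts
            "acc@5": None,
        },
    )
```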
