Skip to content

Change default weights of RAFT model builders #5381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


_MODELS_URLS = {
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
}

Expand Down Expand Up @@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Args:
pretrained (bool): Whether to use pretrained weights.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`
with two fine-tuning steps:

- one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D`
- one on :class:`~torchvsion.datasets.KittiFlow`.

This corresponds to the ``C+T+S/K`` strategy in the paper.

progress (bool): If True, displays a progress bar of the download to stderr.

Returns:
nn.Module: The model.
Expand Down Expand Up @@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Args:
pretrained (bool): Whether to use pretrained weights.
pretrained (bool): Whether to use weights that have been pre-trained on
:class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
nn.Module: The model.
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum):
},
)

DEFAULT = C_T_V2
DEFAULT = C_T_SKHT_V2


class Raft_Small_Weights(WeightsEnum):
Expand Down Expand Up @@ -151,7 +151,7 @@ class Raft_Small_Weights(WeightsEnum):
DEFAULT = C_T_V2


@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Expand Down