diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 031beaeec8b..be3b3c349ce 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -21,7 +21,7 @@ _MODELS_URLS = { - "raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", } @@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Args: - pretrained (bool): Whether to use pretrained weights. - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. + pretrained (bool): Whether to use weights that have been pre-trained on + :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D` + with two fine-tuning steps: + + - one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D` + - one on :class:`~torchvsion.datasets.KittiFlow`. + + This corresponds to the ``C+T+S/K`` strategy in the paper. + + progress (bool): If True, displays a progress bar of the download to stderr. Returns: nn.Module: The model. @@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Args: - pretrained (bool): Whether to use pretrained weights. + pretrained (bool): Whether to use weights that have been pre-trained on + :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`. progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. Returns: nn.Module: The model. diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 44ea84deed0..bf8634efd4f 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum): }, ) - DEFAULT = C_T_V2 + DEFAULT = C_T_SKHT_V2 class Raft_Small_Weights(WeightsEnum): @@ -151,7 +151,7 @@ class Raft_Small_Weights(WeightsEnum): DEFAULT = C_T_V2 -@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) +@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): """RAFT model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_.