diff --git a/torchvision/prototype/models/depth/stereo/crestereo.py b/torchvision/prototype/models/depth/stereo/crestereo.py
index 49643852285..29c0be93618 100644
--- a/torchvision/prototype/models/depth/stereo/crestereo.py
+++ b/torchvision/prototype/models/depth/stereo/crestereo.py
@@ -763,7 +763,7 @@ def _get_window_type(self, iteration: int) -> str:
         return "1d" if iteration % 2 == 0 else "2d"
 
     def forward(
-        self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10
+        self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
     ) -> List[Tensor]:
         features = torch.cat([left_image, right_image], dim=0)
         features = self.feature_encoder(features)
@@ -781,10 +781,10 @@ def forward(
         ctx_pyramid = self.downsampling_pyramid(ctx)
 
         # we store in reversed order because we process the pyramid from top to bottom
-        l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
-        r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
-        net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
-        ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
+        l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
+        r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
+        net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
+        ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
 
         # offsets for sampling pixel candidates in the correlation ops
         offsets: Dict[str, Tensor] = {}
@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
     .. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
         :members:
     """
+
+    weights = CREStereo_Base_Weights.verify(weights)
+
     return _crestereo(
         weights=weights,
         progress=progress,
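
A minimal usage sketch of the two behavioral changes above (the new flow_init default and the verify() call in the builder). The tensor names and the 1x3x256x256 input shape are illustrative assumptions, not part of the diff:

    import torch
    from torchvision.prototype.models.depth.stereo import crestereo_base

    # verify() validates the weights argument; None passes through unchanged,
    # so constructing an uninitialized model keeps working after this change.
    model = crestereo_base(weights=None).eval()

    # Illustrative inputs; CREStereo builds an internal downsampling pyramid,
    # so spatial sizes are assumed to be divisible by its stride.
    left = torch.rand(1, 3, 256, 256)
    right = torch.rand(1, 3, 256, 256)

    with torch.no_grad():
        # flow_init now defaults to None, so callers without an initial
        # flow estimate can simply omit the argument.
        flow_predictions = model(left, right, num_iters=10)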