From 762ba54385ba3c0e6e297762f6de45273d2bb0d6 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 18 Oct 2022 16:42:06 +0100 Subject: [PATCH] Some fixes for crestereo --- .../prototype/models/depth/stereo/crestereo.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/models/depth/stereo/crestereo.py b/torchvision/prototype/models/depth/stereo/crestereo.py index 49643852285..29c0be93618 100644 --- a/torchvision/prototype/models/depth/stereo/crestereo.py +++ b/torchvision/prototype/models/depth/stereo/crestereo.py @@ -763,7 +763,7 @@ def _get_window_type(self, iteration: int) -> str: return "1d" if iteration % 2 == 0 else "2d" def forward( - self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10 + self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10 ) -> List[Tensor]: features = torch.cat([left_image, right_image], dim=0) features = self.feature_encoder(features) @@ -781,10 +781,10 @@ def forward( ctx_pyramid = self.downsampling_pyramid(ctx) # we store in reversed order because we process the pyramid from top to bottom - l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)} - r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)} - net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)} - ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)} + l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)} + r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)} + net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)} + ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)} # offsets for sampling pixel candidates in the correlation ops offsets: Dict[str, Tensor] = {} @@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress .. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights :members: """ + + weights = CREStereo_Base_Weights.verify(weights) + return _crestereo( weights=weights, progress=progress,