Skip to content

Commit 7a62a54

Browse files
authored
Some fixes for crestereo (#6791)
1 parent 78fdaf3 commit 7a62a54

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

torchvision/prototype/models/depth/stereo/crestereo.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ def _get_window_type(self, iteration: int) -> str:
763763
return "1d" if iteration % 2 == 0 else "2d"
764764

765765
def forward(
766-
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10
766+
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
767767
) -> List[Tensor]:
768768
features = torch.cat([left_image, right_image], dim=0)
769769
features = self.feature_encoder(features)
@@ -781,10 +781,10 @@ def forward(
781781
ctx_pyramid = self.downsampling_pyramid(ctx)
782782

783783
# we store in reversed order because we process the pyramid from top to bottom
784-
l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
785-
r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
786-
net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
787-
ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
784+
l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
785+
r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
786+
net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
787+
ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
788788

789789
# offsets for sampling pixel candidates in the correlation ops
790790
offsets: Dict[str, Tensor] = {}
@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
14251425
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
14261426
:members:
14271427
"""
1428+
1429+
weights = CREStereo_Base_Weights.verify(weights)
1430+
14281431
return _crestereo(
14291432
weights=weights,
14301433
progress=progress,

0 commit comments

Comments
 (0)