@@ -763,7 +763,7 @@ def _get_window_type(self, iteration: int) -> str:
763
763
return "1d" if iteration % 2 == 0 else "2d"
764
764
765
765
def forward (
766
- self , left_image : Tensor , right_image : Tensor , flow_init : Optional [Tensor ], num_iters : int = 10
766
+ self , left_image : Tensor , right_image : Tensor , flow_init : Optional [Tensor ] = None , num_iters : int = 10
767
767
) -> List [Tensor ]:
768
768
features = torch .cat ([left_image , right_image ], dim = 0 )
769
769
features = self .feature_encoder (features )
@@ -781,10 +781,10 @@ def forward(
781
781
ctx_pyramid = self .downsampling_pyramid (ctx )
782
782
783
783
# we store in reversed order because we process the pyramid from top to bottom
784
- l_pyramid : Dict [ str , Tensor ] = {res : l_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
785
- r_pyramid : Dict [ str , Tensor ] = {res : r_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
786
- net_pyramid : Dict [ str , Tensor ] = {res : net_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
787
- ctx_pyramid : Dict [ str , Tensor ] = {res : ctx_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
784
+ l_pyramid = {res : l_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
785
+ r_pyramid = {res : r_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
786
+ net_pyramid = {res : net_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
787
+ ctx_pyramid = {res : ctx_pyramid [idx ] for idx , res in enumerate (self .resolutions )}
788
788
789
789
# offsets for sampling pixel candidates in the correlation ops
790
790
offsets : Dict [str , Tensor ] = {}
@@ -1425,6 +1425,9 @@ def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress
1425
1425
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
1426
1426
:members:
1427
1427
"""
1428
+
1429
+ weights = CREStereo_Base_Weights .verify (weights )
1430
+
1428
1431
return _crestereo (
1429
1432
weights = weights ,
1430
1433
progress = progress ,
0 commit comments