55# LICENSE file in the root directory of this source tree.
66
77from itertools import zip_longest
8- from typing import Sequence , Union
8+ from typing import List , Optional , Sequence , Tuple , Union
99
1010import numpy as np
1111import torch
@@ -240,7 +240,9 @@ def __init__(self, points, normals=None, features=None) -> None:
240240 if features_C is not None :
241241 self ._C = features_C
242242
243- def _parse_auxiliary_input (self , aux_input ):
243+ def _parse_auxiliary_input (
244+ self , aux_input
245+ ) -> Tuple [Optional [List [torch .Tensor ]], Optional [torch .Tensor ], Optional [int ]]:
244246 """
245247 Interpret the auxiliary inputs (normals, features) given to __init__.
246248
@@ -323,24 +325,26 @@ def __getitem__(self, index) -> "Pointclouds":
323325 Pointclouds object with selected clouds. The tensors are not cloned.
324326 """
325327 normals , features = None , None
328+ normals_list = self .normals_list ()
329+ features_list = self .features_list ()
326330 if isinstance (index , int ):
327331 points = [self .points_list ()[index ]]
328- if self . normals_list () is not None :
329- normals = [self . normals_list () [index ]]
330- if self . features_list () is not None :
331- features = [self . features_list () [index ]]
332+ if normals_list is not None :
333+ normals = [normals_list [index ]]
334+ if features_list is not None :
335+ features = [features_list [index ]]
332336 elif isinstance (index , slice ):
333337 points = self .points_list ()[index ]
334- if self . normals_list () is not None :
335- normals = self . normals_list () [index ]
336- if self . features_list () is not None :
337- features = self . features_list () [index ]
338+ if normals_list is not None :
339+ normals = normals_list [index ]
340+ if features_list is not None :
341+ features = features_list [index ]
338342 elif isinstance (index , list ):
339343 points = [self .points_list ()[i ] for i in index ]
340- if self . normals_list () is not None :
341- normals = [self . normals_list () [i ] for i in index ]
342- if self . features_list () is not None :
343- features = [self . features_list () [i ] for i in index ]
344+ if normals_list is not None :
345+ normals = [normals_list [i ] for i in index ]
346+ if features_list is not None :
347+ features = [features_list [i ] for i in index ]
344348 elif isinstance (index , torch .Tensor ):
345349 if index .dim () != 1 or index .dtype .is_floating_point :
346350 raise IndexError (index )
@@ -351,10 +355,10 @@ def __getitem__(self, index) -> "Pointclouds":
351355 index = index .squeeze (1 ) if index .numel () > 0 else index
352356 index = index .tolist ()
353357 points = [self .points_list ()[i ] for i in index ]
354- if self . normals_list () is not None :
355- normals = [self . normals_list () [i ] for i in index ]
356- if self . features_list () is not None :
357- features = [self . features_list () [i ] for i in index ]
358+ if normals_list is not None :
359+ normals = [normals_list [i ] for i in index ]
360+ if features_list is not None :
361+ features = [features_list [i ] for i in index ]
358362 else :
359363 raise IndexError (index )
360364
@@ -369,7 +373,7 @@ def isempty(self) -> bool:
369373 """
370374 return self ._N == 0 or self .valid .eq (False ).all ()
371375
372- def points_list (self ):
376+ def points_list (self ) -> List [ torch . Tensor ] :
373377 """
374378 Get the list representation of the points.
375379
@@ -388,9 +392,10 @@ def points_list(self):
388392 self ._points_list = points_list
389393 return self ._points_list
390394
391- def normals_list (self ):
395+ def normals_list (self ) -> Optional [ List [ torch . Tensor ]] :
392396 """
393- Get the list representation of the normals.
397+ Get the list representation of the normals,
398+ or None if there are no normals.
394399
395400 Returns:
396401 list of tensors of normals of shape (P_n, 3).
@@ -404,9 +409,10 @@ def normals_list(self):
404409 )
405410 return self ._normals_list
406411
407- def features_list (self ):
412+ def features_list (self ) -> Optional [ List [ torch . Tensor ]] :
408413 """
409- Get the list representation of the features.
414+ Get the list representation of the features,
415+ or None if there are no features.
410416
411417 Returns:
412418 list of tensors of features of shape (P_n, C).
@@ -420,7 +426,7 @@ def features_list(self):
420426 )
421427 return self ._features_list
422428
423- def points_packed (self ):
429+ def points_packed (self ) -> torch . Tensor :
424430 """
425431 Get the packed representation of the points.
426432
@@ -430,22 +436,24 @@ def points_packed(self):
430436 self ._compute_packed ()
431437 return self ._points_packed
432438
433- def normals_packed (self ):
439+ def normals_packed (self ) -> Optional [ torch . Tensor ] :
434440 """
435441 Get the packed representation of the normals.
436442
437443 Returns:
438- tensor of normals of shape (sum(P_n), 3).
444+ tensor of normals of shape (sum(P_n), 3),
445+ or None if there are no normals.
439446 """
440447 self ._compute_packed ()
441448 return self ._normals_packed
442449
443- def features_packed (self ):
450+ def features_packed (self ) -> Optional [ torch . Tensor ] :
444451 """
445452 Get the packed representation of the features.
446453
447454 Returns:
448- tensor of features of shape (sum(P_n), C).
455+ tensor of features of shape (sum(P_n), C),
456+ or None if there are no features
449457 """
450458 self ._compute_packed ()
451459 return self ._features_packed
@@ -483,7 +491,7 @@ def num_points_per_cloud(self):
483491 """
484492 return self ._num_points_per_cloud
485493
486- def points_padded (self ):
494+ def points_padded (self ) -> torch . Tensor :
487495 """
488496 Get the padded representation of the points.
489497
@@ -493,19 +501,21 @@ def points_padded(self):
493501 self ._compute_padded ()
494502 return self ._points_padded
495503
496- def normals_padded (self ):
504+ def normals_padded (self ) -> Optional [ torch . Tensor ] :
497505 """
498- Get the padded representation of the normals.
506+ Get the padded representation of the normals,
507+ or None if there are no normals.
499508
500509 Returns:
501510 tensor of normals of shape (N, max(P_n), 3).
502511 """
503512 self ._compute_padded ()
504513 return self ._normals_padded
505514
506- def features_padded (self ):
515+ def features_padded (self ) -> Optional [ torch . Tensor ] :
507516 """
508- Get the padded representation of the features.
517+ Get the padded representation of the features,
518+ or None if there are no features.
509519
510520 Returns:
511521 tensor of features of shape (N, max(P_n), 3).
@@ -562,16 +572,18 @@ def _compute_padded(self, refresh: bool = False):
562572 pad_value = 0.0 ,
563573 equisized = self .equisized ,
564574 )
565- if self .normals_list () is not None :
575+ normals_list = self .normals_list ()
576+ if normals_list is not None :
566577 self ._normals_padded = struct_utils .list_to_padded (
567- self . normals_list () ,
578+ normals_list ,
568579 (self ._P , 3 ),
569580 pad_value = 0.0 ,
570581 equisized = self .equisized ,
571582 )
572- if self .features_list () is not None :
583+ features_list = self .features_list ()
584+ if features_list is not None :
573585 self ._features_padded = struct_utils .list_to_padded (
574- self . features_list () ,
586+ features_list ,
575587 (self ._P , self ._C ),
576588 pad_value = 0.0 ,
577589 equisized = self .equisized ,
@@ -772,10 +784,12 @@ def get_cloud(self, index: int):
772784 )
773785 points = self .points_list ()[index ]
774786 normals , features = None , None
775- if self .normals_list () is not None :
776- normals = self .normals_list ()[index ]
777- if self .features_list () is not None :
778- features = self .features_list ()[index ]
787+ normals_list = self .normals_list ()
788+ if normals_list is not None :
789+ normals = normals_list [index ]
790+ features_list = self .features_list ()
791+ if features_list is not None :
792+ features = features_list [index ]
779793 return points , normals , features
780794
781795 # TODO(nikhilar) Move function to a utils file.
@@ -1022,13 +1036,15 @@ def extend(self, N: int):
10221036 new_points_list , new_normals_list , new_features_list = [], None , None
10231037 for points in self .points_list ():
10241038 new_points_list .extend (points .clone () for _ in range (N ))
1025- if self .normals_list () is not None :
1039+ normals_list = self .normals_list ()
1040+ if normals_list is not None :
10261041 new_normals_list = []
1027- for normals in self . normals_list () :
1042+ for normals in normals_list :
10281043 new_normals_list .extend (normals .clone () for _ in range (N ))
1029- if self .features_list () is not None :
1044+ features_list = self .features_list ()
1045+ if features_list is not None :
10301046 new_features_list = []
1031- for features in self . features_list () :
1047+ for features in features_list :
10321048 new_features_list .extend (features .clone () for _ in range (N ))
10331049 return self .__class__ (
10341050 points = new_points_list , normals = new_normals_list , features = new_features_list
0 commit comments