7
7
8
8
9
9
def _torch_function_get_self (cls , func , types , args , kwargs ):
10
- """ Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
10
+ """Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
11
11
12
12
"A simple solution would be to scan the args for the first subclass of this class.
13
13
My question is more: will forcing this to be a subclass actually be a problem for some use case?
@@ -36,7 +36,7 @@ class AugmentedTensor(torch.Tensor):
36
36
37
37
# Ignore named tansors userwarning.
38
38
ERROR_MSG = "Named tensors and all their associated APIs are an experimental feature and subject to change"
39
- warnings .filterwarnings (action = ' ignore' , message = ERROR_MSG )
39
+ warnings .filterwarnings (action = " ignore" , message = ERROR_MSG )
40
40
41
41
@staticmethod
42
42
def __new__ (cls , x , names = None , device = None , * args , ** kwargs ):
@@ -335,7 +335,6 @@ def __getitem__(self, idx):
335
335
336
336
tensor = tensor .reset_names () if len (self .shape ) == len (tensor .shape ) else tensor .as_tensor ()
337
337
338
-
339
338
# if not idx.dtype == torch.bool:
340
339
# if not torch.equal(idx ** 3, idx):
341
340
# raise IndexError(f"Unvalid mask. Expected mask elements to be in [0, 1, True, False]")
@@ -475,6 +474,7 @@ def _fillup_dict(dm, sub_label, dim, target_dim):
475
474
else :
476
475
for s in range (len (sub_label )):
477
476
_fillup_dict (dm [s ], sub_label [s ], dim + 1 , target_dim )
477
+
478
478
_fillup_dict (dict_merge [key ], label , 0 , target_dim )
479
479
480
480
return dict_merge
@@ -499,7 +499,9 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
499
499
setattr (n_tensor , prop , None )
500
500
else :
501
501
values = set ([prop_name_to_value [prop ], getattr (tensor , prop )])
502
- raise RuntimeError (f"Encountered different values for property '{ prop } ' while merging AugmentedTensor: { values } " )
502
+ raise RuntimeError (
503
+ f"Encountered different values for property '{ prop } ' while merging AugmentedTensor: { values } "
504
+ )
503
505
else :
504
506
prop_name_to_value [prop ] = getattr (tensor , prop )
505
507
@@ -545,15 +547,19 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
545
547
if intersection :
546
548
del labels_dict2list [label_name ][key ]
547
549
else :
548
- raise RuntimeError (f"Error during merging. Some tensors have label '{ label_name } ' with key '{ key } ' and some don't" )
550
+ raise RuntimeError (
551
+ f"Error during merging. Some tensors have label '{ label_name } ' with key '{ key } ' and some don't"
552
+ )
549
553
else :
550
554
args = list (args )
551
555
args [0 ] = labels_dict2list [label_name ][key ]
552
556
labels_dict2list [label_name ][key ] = func (* tuple (args ), ** kwargs )
553
557
# if we removed all keys, set this child to None
554
558
if intersection and not labels_dict2list [label_name ]:
555
559
labels_dict2list [label_name ] = None
556
- elif intersection and (len (labels_dict2list [label_name ]) != dim_size or (None in labels_dict2list [label_name ])):
560
+ elif intersection and (
561
+ len (labels_dict2list [label_name ]) != dim_size or (None in labels_dict2list [label_name ])
562
+ ):
557
563
labels_dict2list [label_name ] = None
558
564
else :
559
565
args = list (args )
@@ -595,7 +601,6 @@ def __iter__(self):
595
601
for t in range (len (self )):
596
602
yield self [t ]
597
603
598
-
599
604
@classmethod
600
605
def __torch_function__ (cls , func , types , args = (), kwargs = None ):
601
606
self = _torch_function_get_self (cls , func , types , args , kwargs )
@@ -614,10 +619,10 @@ def _merging_frame(args):
614
619
if func .__name__ == "__reduce_ex__" :
615
620
self .rename_ (None , auto_restore_names = True )
616
621
tensor = super ().__torch_function__ (func , types , args , kwargs )
617
- #tensor = super().torch_func_method(func, types, args, kwargs)
622
+ # tensor = super().torch_func_method(func, types, args, kwargs)
618
623
else :
619
624
tensor = super ().__torch_function__ (func , types , args , kwargs )
620
- #tensor = super().torch_func_method(func, types, args, kwargs)
625
+ # tensor = super().torch_func_method(func, types, args, kwargs)
621
626
622
627
if isinstance (tensor , type (self )):
623
628
tensor ._property_list = self ._property_list
@@ -853,6 +858,9 @@ def _hflip_label(self, label, **kwargs):
853
858
try :
854
859
label_flipped = label ._hflip (** kwargs )
855
860
except AttributeError :
861
+ print (
862
+ f"[WARNING] Horizontal flip returned AttributeError on { type (label ).__name__ } , returning unflipped tensor."
863
+ )
856
864
return label
857
865
else :
858
866
return label_flipped
@@ -914,6 +922,7 @@ def resize_func(label):
914
922
label_resized = label ._resize (size01 , ** kwargs )
915
923
return label_resized
916
924
except AttributeError :
925
+ print (f"[WARNING] resize returned AttributeError on { type (label ).__name__ } , returning initial tensor." )
917
926
return label
918
927
919
928
resized = self ._resize (size01 , ** kwargs )
@@ -924,7 +933,7 @@ def resize_func(label):
924
933
def _resize (self , * args , ** kwargs ):
925
934
raise Exception ("This Augmented tensor should implement this method" )
926
935
927
- def rotate (self , angle , center = None ,** kwargs ):
936
+ def rotate (self , angle , center = None , ** kwargs ):
928
937
"""
929
938
Rotate AugmentedTensor, and its labels recursively
930
939
@@ -941,12 +950,15 @@ def rotate(self, angle, center=None,**kwargs):
941
950
942
951
def rotate_func (label ):
943
952
try :
944
- label_rotated = label ._rotate (angle , center ,** kwargs )
953
+ label_rotated = label ._rotate (angle , center , ** kwargs )
945
954
return label_rotated
946
955
except AttributeError :
956
+ print (
957
+ f"[WARNING] Rotate returned AttributeError on { type (label ).__name__ } , returning unrotated tensor."
958
+ )
947
959
return label
948
960
949
- rotated = self ._rotate (angle , center ,** kwargs )
961
+ rotated = self ._rotate (angle , center , ** kwargs )
950
962
rotated .recursive_apply_on_children_ (rotate_func )
951
963
952
964
return rotated
@@ -956,6 +968,7 @@ def _crop_label(self, label, H_crop, W_crop, **kwargs):
956
968
label_resized = label ._crop (H_crop , W_crop , ** kwargs )
957
969
return label_resized
958
970
except AttributeError :
971
+ print (f"[WARNING] Crop returned AttributeError on { type (label ).__name__ } , returning uncropped tensor." )
959
972
return label
960
973
961
974
def crop (self , H_crop : tuple , W_crop : tuple , ** kwargs ):
@@ -991,6 +1004,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
991
1004
label_pad = label ._pad (offset_y , offset_x , ** kwargs )
992
1005
return label_pad
993
1006
except AttributeError :
1007
+ print (f"[WARNING] Padding returned AttributeError on { type (label ).__name__ } , returning unpadded tensor." )
994
1008
return label
995
1009
996
1010
def pad (self , offset_y : tuple = None , offset_x : tuple = None , multiple : int = None , ** kwargs ):
@@ -1043,6 +1057,9 @@ def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):
1043
1057
label_shift = label ._spatial_shift (shift_y , shift_x , ** kwargs )
1044
1058
return label_shift
1045
1059
except AttributeError :
1060
+ print (
1061
+ f"[WARNING] Spatial shift returned AttributeError on { type (label ).__name__ } , returning unshifted tensor."
1062
+ )
1046
1063
return label
1047
1064
1048
1065
def spatial_shift (self , shift_y : float , shift_x : float , ** kwargs ):
0 commit comments