Skip to content

Commit e78bdac

Browse files
authored
Merge pull request #348 from Visual-Behavior/augmentations_warning
Add warning when augmentations fail
2 parents 7b98304 + d9b3878 commit e78bdac

File tree

2 files changed

+51
-17
lines changed

2 files changed

+51
-17
lines changed

aloscene/tensors/augmented_tensor.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def _torch_function_get_self(cls, func, types, args, kwargs):
10-
""" Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
10+
"""Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
1111
1212
"A simple solution would be to scan the args for the first subclass of this class.
1313
My question is more: will forcing this to be a subclass actually be a problem for some use case?
@@ -36,7 +36,7 @@ class AugmentedTensor(torch.Tensor):
3636

3737
# Ignore named tansors userwarning.
3838
ERROR_MSG = "Named tensors and all their associated APIs are an experimental feature and subject to change"
39-
warnings.filterwarnings(action='ignore', message=ERROR_MSG)
39+
warnings.filterwarnings(action="ignore", message=ERROR_MSG)
4040

4141
@staticmethod
4242
def __new__(cls, x, names=None, device=None, *args, **kwargs):
@@ -335,7 +335,6 @@ def __getitem__(self, idx):
335335

336336
tensor = tensor.reset_names() if len(self.shape) == len(tensor.shape) else tensor.as_tensor()
337337

338-
339338
# if not idx.dtype == torch.bool:
340339
# if not torch.equal(idx ** 3, idx):
341340
# raise IndexError(f"Unvalid mask. Expected mask elements to be in [0, 1, True, False]")
@@ -475,6 +474,7 @@ def _fillup_dict(dm, sub_label, dim, target_dim):
475474
else:
476475
for s in range(len(sub_label)):
477476
_fillup_dict(dm[s], sub_label[s], dim + 1, target_dim)
477+
478478
_fillup_dict(dict_merge[key], label, 0, target_dim)
479479

480480
return dict_merge
@@ -499,7 +499,9 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
499499
setattr(n_tensor, prop, None)
500500
else:
501501
values = set([prop_name_to_value[prop], getattr(tensor, prop)])
502-
raise RuntimeError(f"Encountered different values for property '{prop}' while merging AugmentedTensor: {values}")
502+
raise RuntimeError(
503+
f"Encountered different values for property '{prop}' while merging AugmentedTensor: {values}"
504+
)
503505
else:
504506
prop_name_to_value[prop] = getattr(tensor, prop)
505507

@@ -545,15 +547,19 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
545547
if intersection:
546548
del labels_dict2list[label_name][key]
547549
else:
548-
raise RuntimeError(f"Error during merging. Some tensors have label '{label_name}' with key '{key}' and some don't")
550+
raise RuntimeError(
551+
f"Error during merging. Some tensors have label '{label_name}' with key '{key}' and some don't"
552+
)
549553
else:
550554
args = list(args)
551555
args[0] = labels_dict2list[label_name][key]
552556
labels_dict2list[label_name][key] = func(*tuple(args), **kwargs)
553557
# if we removed all keys, set this child to None
554558
if intersection and not labels_dict2list[label_name]:
555559
labels_dict2list[label_name] = None
556-
elif intersection and (len(labels_dict2list[label_name]) != dim_size or (None in labels_dict2list[label_name])):
560+
elif intersection and (
561+
len(labels_dict2list[label_name]) != dim_size or (None in labels_dict2list[label_name])
562+
):
557563
labels_dict2list[label_name] = None
558564
else:
559565
args = list(args)
@@ -595,7 +601,6 @@ def __iter__(self):
595601
for t in range(len(self)):
596602
yield self[t]
597603

598-
599604
@classmethod
600605
def __torch_function__(cls, func, types, args=(), kwargs=None):
601606
self = _torch_function_get_self(cls, func, types, args, kwargs)
@@ -614,10 +619,10 @@ def _merging_frame(args):
614619
if func.__name__ == "__reduce_ex__":
615620
self.rename_(None, auto_restore_names=True)
616621
tensor = super().__torch_function__(func, types, args, kwargs)
617-
#tensor = super().torch_func_method(func, types, args, kwargs)
622+
# tensor = super().torch_func_method(func, types, args, kwargs)
618623
else:
619624
tensor = super().__torch_function__(func, types, args, kwargs)
620-
#tensor = super().torch_func_method(func, types, args, kwargs)
625+
# tensor = super().torch_func_method(func, types, args, kwargs)
621626

622627
if isinstance(tensor, type(self)):
623628
tensor._property_list = self._property_list
@@ -853,6 +858,9 @@ def _hflip_label(self, label, **kwargs):
853858
try:
854859
label_flipped = label._hflip(**kwargs)
855860
except AttributeError:
861+
print(
862+
f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
863+
)
856864
return label
857865
else:
858866
return label_flipped
@@ -914,6 +922,7 @@ def resize_func(label):
914922
label_resized = label._resize(size01, **kwargs)
915923
return label_resized
916924
except AttributeError:
925+
print(f"[WARNING] resize returned AttributeError on {type(label).__name__}, returning initial tensor.")
917926
return label
918927

919928
resized = self._resize(size01, **kwargs)
@@ -924,7 +933,7 @@ def resize_func(label):
924933
def _resize(self, *args, **kwargs):
925934
raise Exception("This Augmented tensor should implement this method")
926935

927-
def rotate(self, angle, center=None,**kwargs):
936+
def rotate(self, angle, center=None, **kwargs):
928937
"""
929938
Rotate AugmentedTensor, and its labels recursively
930939
@@ -941,12 +950,15 @@ def rotate(self, angle, center=None,**kwargs):
941950

942951
def rotate_func(label):
943952
try:
944-
label_rotated = label._rotate(angle, center,**kwargs)
953+
label_rotated = label._rotate(angle, center, **kwargs)
945954
return label_rotated
946955
except AttributeError:
956+
print(
957+
f"[WARNING] Rotate returned AttributeError on {type(label).__name__}, returning unrotated tensor."
958+
)
947959
return label
948960

949-
rotated = self._rotate(angle, center,**kwargs)
961+
rotated = self._rotate(angle, center, **kwargs)
950962
rotated.recursive_apply_on_children_(rotate_func)
951963

952964
return rotated
@@ -956,6 +968,7 @@ def _crop_label(self, label, H_crop, W_crop, **kwargs):
956968
label_resized = label._crop(H_crop, W_crop, **kwargs)
957969
return label_resized
958970
except AttributeError:
971+
print(f"[WARNING] Crop returned AttributeError on {type(label).__name__}, returning uncropped tensor.")
959972
return label
960973

961974
def crop(self, H_crop: tuple, W_crop: tuple, **kwargs):
@@ -991,6 +1004,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
9911004
label_pad = label._pad(offset_y, offset_x, **kwargs)
9921005
return label_pad
9931006
except AttributeError:
1007+
print(f"[WARNING] Padding returned AttributeError on {type(label).__name__}, returning unpadded tensor.")
9941008
return label
9951009

9961010
def pad(self, offset_y: tuple = None, offset_x: tuple = None, multiple: int = None, **kwargs):
@@ -1043,6 +1057,9 @@ def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):
10431057
label_shift = label._spatial_shift(shift_y, shift_x, **kwargs)
10441058
return label_shift
10451059
except AttributeError:
1060+
print(
1061+
f"[WARNING] Spatial shift returned AttributeError on {type(label).__name__}, returning unshifted tensor."
1062+
)
10461063
return label
10471064

10481065
def spatial_shift(self, shift_y: float, shift_x: float, **kwargs):

aloscene/tensors/spatial_augmented_tensor.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import warnings
1616

17+
1718
class SpatialAugmentedTensor(AugmentedTensor):
1819
"""Spatial Augmented Tensor. Used to represets any 2D data. The spatial augmented tensor can be used as a
1920
basis for images, depth or and spatially related data. Moreover, for stereo setup, the augmented tensor
@@ -131,7 +132,10 @@ def get_view(self, views: list = [], exclude=[], size=None, grid_size=None, titl
131132
"""
132133
_views = [v for v in views if isinstance(v, View)]
133134
if len(_views) > 0:
134-
return View(Renderer.get_grid_view(_views, grid_size=None, cell_grid_size=size, add_title=add_title, **kwargs), title=title)
135+
return View(
136+
Renderer.get_grid_view(_views, grid_size=None, cell_grid_size=size, add_title=add_title, **kwargs),
137+
title=title,
138+
)
135139

136140
# Include type
137141
include_type = [
@@ -427,9 +431,15 @@ def _relative_to_absolute_hs_ws(self, hs=None, ws=None, assert_integer=True, war
427431
assert hs is None or isinstance(hs, (list, tuple)), "hs should be a list or a tuple of floats"
428432
assert ws is None or isinstance(ws, (list, tuple)), "ws should be a list or a tuple of floats"
429433
if hs is not None:
430-
hs = [self.relative_to_absolute(h, "H", assert_integer=assert_integer, warn_non_integer=warn_non_integer) for h in hs]
434+
hs = [
435+
self.relative_to_absolute(h, "H", assert_integer=assert_integer, warn_non_integer=warn_non_integer)
436+
for h in hs
437+
]
431438
if ws is not None:
432-
ws = [self.relative_to_absolute(w, "W", assert_integer=assert_integer, warn_non_integer=warn_non_integer) for w in ws]
439+
ws = [
440+
self.relative_to_absolute(w, "W", assert_integer=assert_integer, warn_non_integer=warn_non_integer)
441+
for w in ws
442+
]
433443
return hs, ws
434444

435445
def _hflip_label(self, label, **kwargs):
@@ -441,6 +451,9 @@ def _hflip_label(self, label, **kwargs):
441451
frame_size=self.HW, cam_intrinsic=self.cam_intrinsic, cam_extrinsic=self.cam_extrinsic, **kwargs
442452
)
443453
except AttributeError:
454+
print(
455+
f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
456+
)
444457
return label
445458
else:
446459
return label_flipped
@@ -454,6 +467,9 @@ def _vflip_label(self, label, **kwargs):
454467
frame_size=self.HW, cam_intrinsic=self.cam_intrinsic, cam_extrinsic=self.cam_extrinsic, **kwargs
455468
)
456469
except AttributeError:
470+
print(
471+
f"[WARNING] Vertical flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
472+
)
457473
return label
458474
else:
459475
return label_flipped
@@ -528,7 +544,7 @@ def _resize(self, size, interpolation=InterpolationMode.BILINEAR, **kwargs):
528544
return self.rename(None).view(shapes).reset_names()
529545
return F.resize(self.rename(None), (h, w), interpolation=interpolation).reset_names()
530546

531-
def _rotate(self, angle, center=None,**kwargs):
547+
def _rotate(self, angle, center=None, **kwargs):
532548
"""Rotate SpatialAugmentedTensor, but not its labels
533549
534550
Parameters
@@ -546,7 +562,7 @@ def _rotate(self, angle, center=None,**kwargs):
546562
assert not (
547563
("N" in self.names and self.size("N") == 0) or ("C" in self.names and self.size("C") == 0)
548564
), "rotation is not possible on an empty tensor"
549-
return F.rotate(self.rename(None), angle,center=center).reset_names()
565+
return F.rotate(self.rename(None), angle, center=center).reset_names()
550566

551567
def _crop(self, H_crop: tuple, W_crop: tuple, **kwargs):
552568
"""Crop the SpatialAugmentedTensor
@@ -576,6 +592,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
576592
label_pad = label._pad(offset_y, offset_x, **kwargs)
577593
return label_pad
578594
except AttributeError:
595+
print(f"[WARNING] Padding returned AttributeError on {type(label).__name__}, returning unpadded tensor.")
579596
return label
580597

581598
def _pad(self, offset_y: tuple, offset_x: tuple, **kwargs):

0 commit comments

Comments
 (0)