Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion alonet/common/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
"https://storage.googleapis.com/visualbehavior-publicweights/trackformer-deformable-mot/trackformer-deformable-mot.pth"
],
"trackformer-crowdhuman-deformable-mot": [
"https://storage.googleapis.com/visualbehavior-publicweights/trackformer-crowdhuman-deformable-mot/trackformer-crowdhuman-deformable-mot.pth"
"https://storage.googleapis.com/visualbehavior-publicweights/trackformer-crowdhuman-deformable-mot/trackformer-crowdhuman-deformable-mot.pth"
],
"detr-r50-panoptic": [
"https://storage.googleapis.com/visualbehavior-publicweights/detr-r50-panoptic/detr-r50-panoptic.pth"
],
"detr-r50-things-stuffs": [
"https://storage.googleapis.com/visualbehavior-publicweights/detr-r50-things-stuffs/detr-r50-things-stuffs.pth"
]
}

Expand Down
5 changes: 2 additions & 3 deletions alonet/detr_panoptic/detr_panoptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(

# Load weights
if weights is not None:
if ".pth" in weights or ".ckpt" in weights:
if ".pth" in weights or ".ckpt" in weights or weights in ["detr-r50-panoptic"]:
alonet.common.load_weights(self, weights, device, strict_load_weights=strict_load_weights)
else:
raise ValueError(f"Unknown weights: '{weights}'")
Expand Down Expand Up @@ -213,8 +213,7 @@ def main(image_path):
device = torch.device("cuda")

# Load model
weights = os.path.expanduser("~/.aloception/weights/detr-r50-panoptic/detr-r50-panoptic.pth")
model = PanopticHead(DetrR50(num_classes=250), weights=weights)
model = PanopticHead(DetrR50(num_classes=250), weights="detr-r50-panoptic")
model.to(device)

# Open and prepare a batch for the model
Expand Down
5 changes: 4 additions & 1 deletion aloscene/bounding_boxes_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from aloscene.labels import Labels
from torchvision.ops.boxes import nms

from aloscene.renderer import View, put_adapative_cv2_text


class BoundingBoxes2D(aloscene.tensors.AugmentedTensor):
"""BoundingBoxes2D Augmented Tensor. Used to represents 2D boxes in space encoded as `xcyc` (xc, yc, width, height
Expand Down Expand Up @@ -508,6 +510,7 @@ def get_view(self, frame: Tensor = None, size: tuple = None, labels_set: str = N
boxes_abs = self.xyxy().abs_pos(frame.HW)

# Get an imave with values between 0 and 1
frame_size = frame.HW
frame = frame.norm01().cpu().rename(None).permute([1, 2, 0]).detach().contiguous().numpy()
# Draw bouding boxes

Expand All @@ -531,12 +534,12 @@ def get_view(self, frame: Tensor = None, size: tuple = None, labels_set: str = N
box = box.round()
x1, y1, x2, y2 = box.as_tensor()
color = (0, 1, 0)
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 3)
if label is not None:
color = self._GLOBAL_COLOR_SET[int(label) % len(self._GLOBAL_COLOR_SET)]
cv2.putText(
frame, str(int(label)), (int(x2), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA
)
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 3)
# Return the view to display
return View(frame, **kwargs)

Expand Down
13 changes: 7 additions & 6 deletions aloscene/tensors/augmented_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,15 +763,16 @@ def recursive_apply_on_children_(self, func):
"""
Recursively apply function on labels to modify tensor inplace
"""
# def __apply(l):
# if isinstance(l, torch.Tensor):
# return l
# else:
# return func(l).recursive_apply_on_children_(func)

def __apply(l):
if isinstance(l, torch.Tensor):
return l
else:
return func(l).recursive_apply_on_children_(func)

for name in self._children_list:
label = getattr(self, name)
modified_label = self.apply_on_child(label, lambda l: func(l).recursive_apply_on_children_(func))
modified_label = self.apply_on_child(label, __apply)
setattr(self, name, modified_label)
return self

Expand Down