Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions alonet/common/pl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def vb_folder():
alofolder = os.path.join(home, ".aloception")
if not os.path.exists(alofolder):
raise Exception(
f"{alofolder} do not exists. Please, create the folder with the appropriate files. (Checkout documentation)"
f"{alofolder} do not exist. Please, create the folder with the appropriate files. (Checkout documentation)"
)
return alofolder

Expand Down Expand Up @@ -74,41 +74,42 @@ def add_argparse_args(parent_parser, add_pl_args=True, mode="training"):
parser.add_argument("--project_run_id", type=str, help="Project related with the run ID to load")
parser.add_argument("--expe_name", type=str, default=None, help="expe_name to be logged in wandb")
parser.add_argument("--no_suffix", action="store_true", help="do not add date suffix to expe_name")
parser.add_argument("--nostrict", action="store_true", help="load from checkpoint to run a model with different weights names (default False)")
parser.add_argument(
"--nostrict",
action="store_true",
help="load from checkpoint to run a model with different weights names (default False)",
)

return parent_parser


def load_training(
lit_model_class,
args=None,
no_run_id: bool = None,
run_id: str = None,
project_run_id: str = None,
no_exception=False,
**kwargs,
lit_model_class, args: Namespace = None, run_id: str = None, project_run_id: str = None, **kwargs,
):
"""Load training"""
run_id = args.run_id if run_id is None else run_id
project_run_id = args.project_run_id if project_run_id is None else project_run_id
no_run_id = args.no_run_id if no_run_id is None else no_run_id
run_id = args.run_id if run_id is None and "run_id" in args else run_id
project_run_id = args.project_run_id if project_run_id is None and "project_run_id" in args else project_run_id
weights_path = getattr(args, "weights", None)
if "weights" in kwargs and kwargs["weights"] is not None: # Highest priority
weights_path = kwargs["weights"]

strict = True if "nostrict" in args else not args.nostrict
if run_id is not None and project_run_id is not None:
strict = not args.nostrict
run_id_project_dir = os.path.join(vb_folder(), f"project_{project_run_id}")
ckpt_path = os.path.join(run_id_project_dir, run_id, "last.ckpt")
if not os.path.exists(ckpt_path):
raise Exception(f"Impossible to load the ckpt at the following destination:{ckpt_path}")
print(f"Loading ckpt from {run_id} at {ckpt_path}")
lit_model = lit_model_class.load_from_checkpoint(ckpt_path, strict=strict, args=args, **kwargs)
elif no_exception and getattr(args, "weights", None) is not None:
lit_model = lit_model_class(args=args, **kwargs)
elif no_run_id:
lit_model = lit_model_class(args=args, **kwargs)
elif weights_path is not None:
if ".pth" in weights_path:
lit_model = lit_model_class(args=args, **kwargs)
elif ".ckpt" in weights_path:
lit_model = lit_model_class.load_from_checkpoint(weights_path, strict=strict, args=args, **kwargs)
else:
raise Exception(f"Impossible to load the weights at the following destination:{weights_path}")
else:
raise Exception(
"--run_id (optionally --project_run_id) must be given to load the experiment. (--no_run_id to skip this warning) "
)
raise Exception("--run_id (optionally --project_run_id) must be given to load the experiment.")

return lit_model

Expand Down
7 changes: 6 additions & 1 deletion alonet/common/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def load_weights(model, weights, device, strict_load_weights=True):
if "model" in checkpoint:
checkpoint = checkpoint["model"]
model.load_state_dict(checkpoint)

print(f"Weights loaded from {weights}")
elif ".ckpt" in weights:
checkpoint = torch.load(weights, map_location=device)["state_dict"]
checkpoint = {k.replace("model.", "") if "model." in k else k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint)
print(f"Weights loaded from {weights}")
elif weights in WEIGHT_NAME_TO_FILES:
weights_dir = os.path.join(weights_dir, weights)
if not os.path.exists(weights_dir):
Expand Down
22 changes: 10 additions & 12 deletions alonet/deformable_detr/deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def __init__(
in_channels = backbone.num_channels[i]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1),
nn.GroupNorm(32, self.hidden_dim),
nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1), nn.GroupNorm(32, self.hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
Expand Down Expand Up @@ -168,9 +167,14 @@ def __init__(
if device is not None:
self.to(device)

if weights is not None and weights in ["deformable-detr-r50", "deformable-detr-r50-refinement"]:
if weights is not None and (
weights in ["deformable-detr-r50", "deformable-detr-r50-refinement"]
or ".pth" in weights
or ".ckpt" in weights
):
alonet.common.load_weights(self, weights, device, strict_load_weights=strict_load_weights)
print(f"Loaded: {weights}")
else:
raise ValueError(f"Unknown weights: '{weights}'")

@assert_and_export_onnx(check_mean_std=True, input_mean_std=INPUT_MEAN_STD)
def forward(self, frames: aloscene.Frame, **kwargs):
Expand Down Expand Up @@ -310,11 +314,7 @@ def _set_aux_loss(self, outputs_class, outputs_coord, **kwargs):
# as a dict having both a Tensor and a list.
return [{"pred_logits": a, "pred_boxes": b, **kwargs} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

def get_outs_labels(
self,
m_outputs: dict = None,
activation_fn: str = None,
) -> List[torch.Tensor]:
def get_outs_labels(self, m_outputs: dict = None, activation_fn: str = None,) -> List[torch.Tensor]:
"""Given the model outs_scores and the model outs_labels,
return the labels and the associated scores.

Expand Down Expand Up @@ -492,9 +492,7 @@ def build_decoder_layer(
)

def build_decoder(
self,
dec_layers: int = 6,
return_intermediate_dec=True,
self, dec_layers: int = 6, return_intermediate_dec=True,
):

decoder_layer = self.build_decoder_layer()
Expand Down
8 changes: 4 additions & 4 deletions alonet/detr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ def __init__(
if device is not None:
self.to(device)

if weights is not None and weights == "detr-r50":
if weights is not None and (weights == "detr-r50" or ".pth" in weights or ".ckpt" in weights):
alonet.common.load_weights(self, "detr-r50", device, strict_load_weights=strict_load_weights)
else:
raise ValueError(f"Unknown weights: '{weights}'")

self.device = device
self.INPUT_MEAN_STD = INPUT_MEAN_STD
Expand Down Expand Up @@ -306,9 +308,7 @@ def build_decoder_layer(
)

def build_decoder(
self,
hidden_dim: int = 256,
num_decoder_layers: int = 6,
self, hidden_dim: int = 256, num_decoder_layers: int = 6,
):

decoder_layer = self.build_decoder_layer()
Expand Down
2 changes: 1 addition & 1 deletion alonet/raft/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(

if weights is not None:
weights_from_original_repo = ["raft-things", "raft-chairs", "raft-small", "raft-kitti", "raft-sintel"]
if weights in weights_from_original_repo or ".pth" in weights:
if weights in weights_from_original_repo or ".pth" in weights or ".ckpt" in weights:
alonet.common.load_weights(self, weights, device)
else:
raise ValueError(f"Unknown weights: '{weights}'")
Expand Down