Skip to content

Commit abfe387

Browse files
flatten results if single class dataset
1 parent 3817231 commit abfe387

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

boxmot/engine/evaluator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,14 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
510510

511511
trackeval_results = trackeval(opt, seq_paths, save_dir, gt_folder)
512512
parsed_results = parse_mot_results(trackeval_results)
513+
514+
# Load config to filter classes
515+
cfg = load_dataset_cfg(str(opt.source.parent.name))
516+
517+
# Filter parsed_results to only include classes from the benchmark
518+
if "benchmark" in cfg and "classes" in cfg["benchmark"]:
519+
bench_classes = cfg["benchmark"]["classes"].split()
520+
parsed_results = {k: v for k, v in parsed_results.items() if k in bench_classes}
513521

514522
# Print results summary
515523
if verbose:
@@ -540,10 +548,10 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
540548
LOGGER.opt(colors=True).info("<blue>" + "="*90 + "</blue>")
541549

542550
# Flatten results if only one class is present (backward compatibility)
543-
cfg = load_dataset_cfg(str(opt.source.parent.name))
544551
final_results = parsed_results
545-
if len(cfg["benchmark"]["classes"].split()) == 1:
546-
final_results = list(parsed_results.values())[0]
552+
if "benchmark" in cfg and "classes" in cfg["benchmark"]:
553+
if len(cfg["benchmark"]["classes"].split()) == 1 and len(parsed_results) > 0:
554+
final_results = list(parsed_results.values())[0]
547555

548556
if opt.ci:
549557
with open(opt.tracking_method + "_output.json", "w") as outfile:

0 commit comments

Comments
 (0)