diff --git a/supervision/dataset/utils.py b/supervision/dataset/utils.py index 6c30eeab0..2264df3ae 100644 --- a/supervision/dataset/utils.py +++ b/supervision/dataset/utils.py @@ -50,6 +50,18 @@ def approximate_mask_with_polygons( ] +def get_class_distribution(dataset: "DetectionDataset") -> Dict[str, int]: + """ + Returns a dictionary with class names as keys and sample counts as values. + """ + from collections import Counter + + all_classes = [] + for det in dataset.annotations: + all_classes.extend(det["class_names"]) + return dict(Counter(all_classes)) + + def merge_class_lists(class_lists: List[List[str]]) -> List[str]: unique_classes = set()