diff --git a/ml-agents/mlagents/trainers/stats.py b/ml-agents/mlagents/trainers/stats.py index fe6ca6a924..2eec1993c7 100644 --- a/ml-agents/mlagents/trainers/stats.py +++ b/ml-agents/mlagents/trainers/stats.py @@ -13,6 +13,10 @@ class StatsSummary(NamedTuple): std: float num: int + @staticmethod + def empty() -> "StatsSummary": + return StatsSummary(0.0, 0.0, 0) + class StatsWriter(abc.ABC): """ @@ -184,8 +188,10 @@ def get_stats_summaries(self, key: str) -> StatsSummary: :param key: The type of statistic, e.g. Environment/Reward. :returns: A StatsSummary NamedTuple containing (mean, std, count). """ - return StatsSummary( - mean=np.mean(StatsReporter.stats_dict[self.category][key]), - std=np.std(StatsReporter.stats_dict[self.category][key]), - num=len(StatsReporter.stats_dict[self.category][key]), - ) + if len(StatsReporter.stats_dict[self.category][key]) > 0: + return StatsSummary( + mean=np.mean(StatsReporter.stats_dict[self.category][key]), + std=np.std(StatsReporter.stats_dict[self.category][key]), + num=len(StatsReporter.stats_dict[self.category][key]), + ) + return StatsSummary.empty()