Skip to content

Commit f813287

Browse files
authored
Added SUM as aggregation type for custom statistics (#4816)
1 parent eea679d commit f813287

File tree

12 files changed

+171
-38
lines changed

12 files changed

+171
-38
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
/summaries
44
# Output Artifacts
55
/results
6+
# Output Builds
7+
/Builds
68

79
# Training environments
810
/envs

Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ public class HallwayAgent : Agent
1818
Renderer m_GroundRenderer;
1919
HallwaySettings m_HallwaySettings;
2020
int m_Selection;
21+
StatsRecorder m_statsRecorder;
2122

2223
public override void Initialize()
2324
{
2425
m_HallwaySettings = FindObjectOfType<HallwaySettings>();
2526
m_AgentRb = GetComponent<Rigidbody>();
2627
m_GroundRenderer = ground.GetComponent<Renderer>();
2728
m_GroundMaterial = m_GroundRenderer.material;
29+
m_statsRecorder = Academy.Instance.StatsRecorder;
2830
}
2931

3032
public override void CollectObservations(VectorSensor sensor)
@@ -83,11 +85,13 @@ void OnCollisionEnter(Collision col)
8385
{
8486
SetReward(1f);
8587
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, 0.5f));
88+
m_statsRecorder.Add("Goal/Correct", 1, StatAggregationMethod.Sum);
8689
}
8790
else
8891
{
8992
SetReward(-0.1f);
9093
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
94+
m_statsRecorder.Add("Goal/Wrong", 1, StatAggregationMethod.Sum);
9195
}
9296
EndEpisode();
9397
}
@@ -156,5 +160,7 @@ public override void OnEpisodeBegin()
156160
symbolXGoal.transform.position = new Vector3(7f, 0.5f, 22.29f) + area.transform.position;
157161
symbolOGoal.transform.position = new Vector3(-7f, 0.5f, 22.29f) + area.transform.position;
158162
}
163+
m_statsRecorder.Add("Goal/Correct", 0, StatAggregationMethod.Sum);
164+
m_statsRecorder.Add("Goal/Wrong", 0, StatAggregationMethod.Sum);
159165
}
160166
}

com.unity.ml-agents/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ and this project adheres to
1515

1616
### Minor Changes
1717
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
18+
- `StatAggregationMethod.Sum` can now be passed to `StatsRecorder.Add()`. This
19+
will result in the values being summed (instead of averaged) when written to
20+
TensorBoard. Thanks to @brccabral for the contribution! (#4816)
1821

1922
#### ml-agents / ml-agents-envs / gym-unity (Python)
2023

com.unity.ml-agents/Runtime/StatsRecorder.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ public enum StatAggregationMethod
1717
/// To avoid conflicts when training with multiple concurrent environments, only
1818
/// stats from worker index 0 will be tracked.
1919
/// </summary>
20-
MostRecent = 1
20+
MostRecent = 1,
21+
22+
/// <summary>
23+
/// Values within the summary period are summed up before reporting.
24+
/// </summary>
25+
Sum = 2
2126
}
2227

2328
/// <summary>

ml-agents-envs/mlagents_envs/side_channel/stats_side_channel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ class StatsAggregationMethod(Enum):
1414
# Only the most recent value is reported.
1515
MOST_RECENT = 1
1616

17+
# Values within the summary period are summed up before reporting.
18+
SUM = 2
19+
1720

1821
StatList = List[Tuple[float, StatsAggregationMethod]]
1922
EnvironmentStats = Mapping[str, StatList]
@@ -35,6 +38,7 @@ def __init__(self) -> None:
3538
def on_message_received(self, msg: IncomingMessage) -> None:
3639
"""
3740
Receive the message from the environment, and save it for later retrieval.
41+
3842
:param msg:
3943
:return:
4044
"""
@@ -47,6 +51,7 @@ def on_message_received(self, msg: IncomingMessage) -> None:
4751
def get_and_reset_stats(self) -> EnvironmentStats:
4852
"""
4953
Returns the current stats, and resets the internal storage of the stats.
54+
5055
:return:
5156
"""
5257
s = self.stats

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
):
4141
"""
4242
Create an AgentProcessor.
43+
4344
:param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory
4445
when it is finished.
4546
:param policy: Policy instance associated with this AgentProcessor.
@@ -112,7 +113,12 @@ def add_experiences(
112113
)
113114

114115
def _process_step(
115-
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
116+
self,
117+
step: Union[
118+
TerminalStep, DecisionStep
119+
], # pylint: disable=unsubscriptable-object
120+
global_id: str,
121+
index: int,
116122
) -> None:
117123
terminated = isinstance(step, TerminalStep)
118124
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
@@ -318,15 +324,18 @@ def record_environment_stats(
318324
"""
319325
Pass stats from the environment to the StatsReporter.
320326
Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used.
321-
The worker_id is used to determin whether StatsReporter.set_stat should be used.
327+
The worker_id is used to determine whether StatsReporter.set_stat should be used.
328+
322329
:param env_stats:
323330
:param worker_id:
324331
:return:
325332
"""
326333
for stat_name, value_list in env_stats.items():
327334
for val, agg_type in value_list:
328335
if agg_type == StatsAggregationMethod.AVERAGE:
329-
self.stats_reporter.add_stat(stat_name, val)
336+
self.stats_reporter.add_stat(stat_name, val, agg_type)
337+
elif agg_type == StatsAggregationMethod.SUM:
338+
self.stats_reporter.add_stat(stat_name, val, agg_type)
330339
elif agg_type == StatsAggregationMethod.MOST_RECENT:
331340
# In order to prevent conflicts between multiple environments,
332341
# only stats from the first environment are recorded.

ml-agents/mlagents/trainers/stats.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import time
88
from threading import RLock
99

10+
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
11+
1012
from mlagents_envs.logging_util import get_logger
1113
from mlagents_envs.timers import set_gauge
1214
from torch.utils.tensorboard import SummaryWriter
@@ -20,8 +22,9 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
2022
"""
2123
Takes a parameter dictionary and converts it to a human-readable string.
2224
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
23-
param: param_dict: A Dictionary of key, value parameters.
24-
return: A string version of this dictionary.
25+
26+
:param param_dict: A Dictionary of key, value parameters.
27+
:return: A string version of this dictionary.
2528
"""
2629
if not isinstance(param_dict, dict):
2730
return str(param_dict)
@@ -37,14 +40,23 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
3740
)
3841

3942

40-
class StatsSummary(NamedTuple):
43+
class StatsSummary(NamedTuple): # pylint: disable=inherit-non-class
4144
mean: float
4245
std: float
4346
num: int
47+
sum: float
48+
aggregation_method: StatsAggregationMethod
4449

4550
@staticmethod
4651
def empty() -> "StatsSummary":
47-
return StatsSummary(0.0, 0.0, 0)
52+
return StatsSummary(0.0, 0.0, 0, 0.0, StatsAggregationMethod.AVERAGE)
53+
54+
@property
55+
def aggregated_value(self):
56+
if self.aggregation_method == StatsAggregationMethod.SUM:
57+
return self.sum
58+
else:
59+
return self.mean
4860

4961

5062
class StatsPropertyType(Enum):
@@ -71,8 +83,9 @@ def add_property(
7183
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
7284
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
7385
with all types of properties. For instance, a TB writer doesn't need a max step.
86+
7487
:param category: The category that the property belongs to.
75-
:param type: The type of property.
88+
:param property_type: The type of property.
7689
:param value: The property itself.
7790
"""
7891
pass
@@ -98,6 +111,10 @@ def write_stats(
98111
GaugeWriter.sanitize_string(f"{category}.{val}.mean"),
99112
float(stats_summary.mean),
100113
)
114+
set_gauge(
115+
GaugeWriter.sanitize_string(f"{category}.{val}.sum"),
116+
float(stats_summary.sum),
117+
)
101118

102119

103120
class ConsoleWriter(StatsWriter):
@@ -114,7 +131,7 @@ def write_stats(
114131
is_training = "Not Training"
115132
if "Is Training" in values:
116133
stats_summary = values["Is Training"]
117-
if stats_summary.mean > 0.0:
134+
if stats_summary.aggregated_value > 0.0:
118135
is_training = "Training"
119136

120137
elapsed_time = time.time() - self.training_start_time
@@ -156,10 +173,11 @@ class TensorboardWriter(StatsWriter):
156173
def __init__(self, base_dir: str, clear_past_data: bool = False):
157174
"""
158175
A StatsWriter that writes to a Tensorboard summary.
176+
159177
:param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a
160178
{base_dir}/{category} directory.
161179
:param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and
162-
category.
180+
category.
163181
"""
164182
self.summary_writers: Dict[str, SummaryWriter] = {}
165183
self.base_dir: str = base_dir
@@ -170,7 +188,9 @@ def write_stats(
170188
) -> None:
171189
self._maybe_create_summary_writer(category)
172190
for key, value in values.items():
173-
self.summary_writers[category].add_scalar(f"{key}", value.mean, step)
191+
self.summary_writers[category].add_scalar(
192+
f"{key}", value.aggregated_value, step
193+
)
174194
self.summary_writers[category].flush()
175195

176196
def _maybe_create_summary_writer(self, category: str) -> None:
@@ -214,6 +234,9 @@ class StatsReporter:
214234
writers: List[StatsWriter] = []
215235
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))
216236
lock = RLock()
237+
stats_aggregation: Dict[str, Dict[str, StatsAggregationMethod]] = defaultdict(
238+
lambda: defaultdict(lambda: StatsAggregationMethod.AVERAGE)
239+
)
217240

218241
def __init__(self, category: str):
219242
"""
@@ -234,37 +257,51 @@ def add_property(self, property_type: StatsPropertyType, value: Any) -> None:
234257
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
235258
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
236259
with all types of properties. For instance, a TB writer doesn't need a max step.
237-
:param key: The type of property.
260+
261+
:param property_type: The type of property.
238262
:param value: The property itself.
239263
"""
240264
with StatsReporter.lock:
241265
for writer in StatsReporter.writers:
242266
writer.add_property(self.category, property_type, value)
243267

244-
def add_stat(self, key: str, value: float) -> None:
268+
def add_stat(
269+
self,
270+
key: str,
271+
value: float,
272+
aggregation: StatsAggregationMethod = StatsAggregationMethod.AVERAGE,
273+
) -> None:
245274
"""
246275
Add a float value stat to the StatsReporter.
276+
247277
:param key: The type of statistic, e.g. Environment/Reward.
248278
:param value: the value of the statistic.
279+
:param aggregation: the aggregation method for the statistic, default StatsAggregationMethod.AVERAGE.
249280
"""
250281
with StatsReporter.lock:
251282
StatsReporter.stats_dict[self.category][key].append(value)
283+
StatsReporter.stats_aggregation[self.category][key] = aggregation
252284

253285
def set_stat(self, key: str, value: float) -> None:
254286
"""
255287
Sets a stat value to a float. This is for values that we don't want to average or sum, and just
256288
want the latest.
289+
257290
:param key: The type of statistic, e.g. Environment/Reward.
258291
:param value: the value of the statistic.
259292
"""
260293
with StatsReporter.lock:
261294
StatsReporter.stats_dict[self.category][key] = [value]
295+
StatsReporter.stats_aggregation[self.category][
296+
key
297+
] = StatsAggregationMethod.MOST_RECENT
262298

263299
def write_stats(self, step: int) -> None:
264300
"""
265301
Write out all stored statistics that fall under the category specified.
266302
The currently stored values will be aggregated (averaged, summed, or most recent, per stat), written out as a single value,
267303
and the buffer cleared.
304+
268305
:param step: Training step which to write these stats as.
269306
"""
270307
with StatsReporter.lock:
@@ -279,14 +316,19 @@ def write_stats(self, step: int) -> None:
279316

280317
def get_stats_summaries(self, key: str) -> StatsSummary:
281318
"""
282-
Get the mean, std, and count of a particular statistic, since last write.
319+
Get the mean, std, count, sum and aggregation method of a particular statistic, since last write.
320+
283321
:param key: The type of statistic, e.g. Environment/Reward.
284-
:returns: A StatsSummary NamedTuple containing (mean, std, count).
322+
:returns: A StatsSummary containing summary statistics.
285323
"""
286-
if len(StatsReporter.stats_dict[self.category][key]) > 0:
287-
return StatsSummary(
288-
mean=np.mean(StatsReporter.stats_dict[self.category][key]),
289-
std=np.std(StatsReporter.stats_dict[self.category][key]),
290-
num=len(StatsReporter.stats_dict[self.category][key]),
291-
)
292-
return StatsSummary.empty()
324+
stat_values = StatsReporter.stats_dict[self.category][key]
325+
if len(stat_values) == 0:
326+
return StatsSummary.empty()
327+
328+
return StatsSummary(
329+
mean=np.mean(stat_values),
330+
std=np.std(stat_values),
331+
num=len(stat_values),
332+
sum=np.sum(stat_values),
333+
aggregation_method=StatsReporter.stats_aggregation[self.category][key],
334+
)

ml-agents/mlagents/trainers/tests/check_env_trains.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def write_stats(
2828
) -> None:
2929
for val, stats_summary in values.items():
3030
if val == "Environment/Cumulative Reward":
31-
print(step, val, stats_summary.mean)
32-
self._last_reward_summary[category] = stats_summary.mean
31+
print(step, val, stats_summary.aggregated_value)
32+
self._last_reward_summary[category] = stats_summary.aggregated_value
3333

3434

3535
# The reward processor is passed as an argument to _check_environment_trains.

ml-agents/mlagents/trainers/tests/test_agent_processor.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,18 +262,39 @@ def test_agent_manager_stats():
262262
{
263263
"averaged": [(1.0, StatsAggregationMethod.AVERAGE)],
264264
"most_recent": [(2.0, StatsAggregationMethod.MOST_RECENT)],
265+
"summed": [(3.1, StatsAggregationMethod.SUM)],
265266
},
266267
{
267268
"averaged": [(3.0, StatsAggregationMethod.AVERAGE)],
268269
"most_recent": [(4.0, StatsAggregationMethod.MOST_RECENT)],
270+
"summed": [(1.1, StatsAggregationMethod.SUM)],
269271
},
270272
]
271273
for env_stats in all_env_stats:
272274
manager.record_environment_stats(env_stats, worker_id=0)
273275

274276
expected_stats = {
275-
"averaged": StatsSummary(mean=2.0, std=mock.ANY, num=2),
276-
"most_recent": StatsSummary(mean=4.0, std=0.0, num=1),
277+
"averaged": StatsSummary(
278+
mean=2.0,
279+
std=mock.ANY,
280+
num=2,
281+
sum=4.0,
282+
aggregation_method=StatsAggregationMethod.AVERAGE,
283+
),
284+
"most_recent": StatsSummary(
285+
mean=4.0,
286+
std=0.0,
287+
num=1,
288+
sum=4.0,
289+
aggregation_method=StatsAggregationMethod.MOST_RECENT,
290+
),
291+
"summed": StatsSummary(
292+
mean=2.1,
293+
std=mock.ANY,
294+
num=2,
295+
sum=4.2,
296+
aggregation_method=StatsAggregationMethod.SUM,
297+
),
277298
}
278299
stats_reporter.write_stats(123)
279300
writer.write_stats.assert_any_call("FakeCategory", expected_stats, 123)

0 commit comments

Comments
 (0)