7
7
import time
8
8
from threading import RLock
9
9
10
+ from mlagents_envs .side_channel .stats_side_channel import StatsAggregationMethod
11
+
10
12
from mlagents_envs .logging_util import get_logger
11
13
from mlagents_envs .timers import set_gauge
12
14
from torch .utils .tensorboard import SummaryWriter
@@ -20,8 +22,9 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
20
22
"""
21
23
Takes a parameter dictionary and converts it to a human-readable string.
22
24
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
23
- param: param_dict: A Dictionary of key, value parameters.
24
- return: A string version of this dictionary.
25
+
26
+ :param param_dict: A Dictionary of key, value parameters.
27
+ :return: A string version of this dictionary.
25
28
"""
26
29
if not isinstance (param_dict , dict ):
27
30
return str (param_dict )
@@ -37,14 +40,23 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
37
40
)
38
41
39
42
40
class StatsSummary(NamedTuple):  # pylint: disable=inherit-non-class
    """
    Immutable summary of the values collected for a single statistic.

    Besides the usual mean / standard deviation / count, the summary carries the
    sum of the values and the aggregation method that decides which of those
    numbers is reported as the statistic's value (see ``aggregated_value``).
    """

    # Mean of the collected values.
    mean: float
    # Standard deviation of the collected values.
    std: float
    # Number of collected values.
    num: int
    # Sum of the collected values.
    sum: float
    # How the collected values should be collapsed into one reported number.
    aggregation_method: StatsAggregationMethod

    @staticmethod
    def empty() -> "StatsSummary":
        """
        Build the summary used when no values have been collected.

        :return: A StatsSummary with all numeric fields zeroed and AVERAGE aggregation.
        """
        return StatsSummary(
            mean=0.0,
            std=0.0,
            num=0,
            sum=0.0,
            aggregation_method=StatsAggregationMethod.AVERAGE,
        )

    @property
    def aggregated_value(self):
        """
        The single number that represents this statistic.

        :return: ``sum`` when the aggregation method is SUM, otherwise ``mean``.
        """
        wants_sum = self.aggregation_method == StatsAggregationMethod.SUM
        return self.sum if wants_sum else self.mean
48
60
49
61
50
62
class StatsPropertyType (Enum ):
@@ -71,8 +83,9 @@ def add_property(
71
83
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
72
84
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
73
85
with all types of properties. For instance, a TB writer doesn't need a max step.
86
+
74
87
:param category: The category that the property belongs to.
75
- :param type : The type of property.
88
+ :param property_type : The type of property.
76
89
:param value: The property itself.
77
90
"""
78
91
pass
@@ -98,6 +111,10 @@ def write_stats(
98
111
GaugeWriter .sanitize_string (f"{ category } .{ val } .mean" ),
99
112
float (stats_summary .mean ),
100
113
)
114
+ set_gauge (
115
+ GaugeWriter .sanitize_string (f"{ category } .{ val } .sum" ),
116
+ float (stats_summary .sum ),
117
+ )
101
118
102
119
103
120
class ConsoleWriter (StatsWriter ):
@@ -114,7 +131,7 @@ def write_stats(
114
131
is_training = "Not Training"
115
132
if "Is Training" in values :
116
133
stats_summary = values ["Is Training" ]
117
- if stats_summary .mean > 0.0 :
134
+ if stats_summary .aggregated_value > 0.0 :
118
135
is_training = "Training"
119
136
120
137
elapsed_time = time .time () - self .training_start_time
@@ -156,10 +173,11 @@ class TensorboardWriter(StatsWriter):
156
173
def __init__ (self , base_dir : str , clear_past_data : bool = False ):
157
174
"""
158
175
A StatsWriter that writes to a Tensorboard summary.
176
+
159
177
:param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a
160
178
{base_dir}/{category} directory.
161
179
:param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and
162
- category.
180
+ category.
163
181
"""
164
182
self .summary_writers : Dict [str , SummaryWriter ] = {}
165
183
self .base_dir : str = base_dir
@@ -170,7 +188,9 @@ def write_stats(
170
188
) -> None :
171
189
self ._maybe_create_summary_writer (category )
172
190
for key , value in values .items ():
173
- self .summary_writers [category ].add_scalar (f"{ key } " , value .mean , step )
191
+ self .summary_writers [category ].add_scalar (
192
+ f"{ key } " , value .aggregated_value , step
193
+ )
174
194
self .summary_writers [category ].flush ()
175
195
176
196
def _maybe_create_summary_writer (self , category : str ) -> None :
@@ -214,6 +234,9 @@ class StatsReporter:
214
234
writers : List [StatsWriter ] = []
215
235
stats_dict : Dict [str , Dict [str , List ]] = defaultdict (lambda : defaultdict (list ))
216
236
lock = RLock ()
237
+ stats_aggregation : Dict [str , Dict [str , StatsAggregationMethod ]] = defaultdict (
238
+ lambda : defaultdict (lambda : StatsAggregationMethod .AVERAGE )
239
+ )
217
240
218
241
def __init__ (self , category : str ):
219
242
"""
@@ -234,37 +257,51 @@ def add_property(self, property_type: StatsPropertyType, value: Any) -> None:
234
257
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
235
258
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
236
259
with all types of properties. For instance, a TB writer doesn't need a max step.
237
- :param key: The type of property.
260
+
261
+ :param property_type: The type of property.
238
262
:param value: The property itself.
239
263
"""
240
264
with StatsReporter .lock :
241
265
for writer in StatsReporter .writers :
242
266
writer .add_property (self .category , property_type , value )
243
267
244
def add_stat(
    self,
    key: str,
    value: float,
    aggregation: StatsAggregationMethod = StatsAggregationMethod.AVERAGE,
) -> None:
    """
    Append a float value to the values buffered for a statistic.

    The aggregation method is stored per (category, key) and is overwritten on
    every call, so the method passed most recently is the one in effect.

    :param key: The type of statistic, e.g. Environment/Reward.
    :param value: the value of the statistic.
    :param aggregation: the aggregation method for the statistic, default StatsAggregationMethod.AVERAGE.
    """
    with StatsReporter.lock:
        buffered_values = StatsReporter.stats_dict[self.category][key]
        buffered_values.append(value)
        StatsReporter.stats_aggregation[self.category][key] = aggregation
252
284
253
285
def set_stat(self, key: str, value: float) -> None:
    """
    Replace whatever is buffered for a statistic with a single float value.

    Use this for values that should not be averaged: only the latest value is
    kept, and the statistic's aggregation method is set to MOST_RECENT.

    :param key: The type of statistic, e.g. Environment/Reward.
    :param value: the value of the statistic.
    """
    with StatsReporter.lock:
        # Overwrite (rather than append) so earlier values are discarded.
        StatsReporter.stats_dict[self.category][key] = [value]
        method = StatsAggregationMethod.MOST_RECENT
        StatsReporter.stats_aggregation[self.category][key] = method
262
298
263
299
def write_stats (self , step : int ) -> None :
264
300
"""
265
301
Write out all stored statistics that fall under the category specified.
266
302
The currently stored values will be averaged, written out as a single value,
267
303
and the buffer cleared.
304
+
268
305
:param step: Training step which to write these stats as.
269
306
"""
270
307
with StatsReporter .lock :
@@ -279,14 +316,19 @@ def write_stats(self, step: int) -> None:
279
316
280
317
def get_stats_summaries(self, key: str) -> StatsSummary:
    """
    Summarize the values currently buffered for a statistic.

    The summary holds the mean, standard deviation, count and sum of the
    buffered values together with the aggregation method recorded for the key.

    :param key: The type of statistic, e.g. Environment/Reward.
    :returns: A StatsSummary containing summary statistics, or
        StatsSummary.empty() when nothing is buffered for the key.
    """
    samples = StatsReporter.stats_dict[self.category][key]
    if len(samples) > 0:
        # The aggregation method is looked up only when there is data to summarize.
        return StatsSummary(
            mean=np.mean(samples),
            std=np.std(samples),
            num=len(samples),
            sum=np.sum(samples),
            aggregation_method=StatsReporter.stats_aggregation[self.category][key],
        )
    return StatsSummary.empty()
0 commit comments