Skip to content

[refactor] Use PyTorch TensorBoard utils #4518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 30 additions & 56 deletions ml-agents/mlagents/trainers/stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from enum import Enum
from typing import List, Dict, NamedTuple, Any, Optional
from typing import List, Dict, NamedTuple, Any
import numpy as np
import abc
import os
Expand All @@ -9,13 +9,34 @@

from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import set_gauge
from mlagents.tf_utils import tf, generate_session_config
from torch.utils.tensorboard import SummaryWriter
from mlagents.tf_utils.globals import get_rank


logger = get_logger(__name__)


def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
param: param_dict: A Dictionary of key, value parameters.
return: A string version of this dictionary.
"""
if not isinstance(param_dict, dict):
return str(param_dict)
else:
append_newline = "\n" if num_tabs > 0 else ""
return append_newline + "\n".join(
[
"\t"
+ " " * num_tabs
+ "{}:\t{}".format(x, _dict_to_str(param_dict[x], num_tabs + 1))
for x in param_dict
]
)


class StatsSummary(NamedTuple):
mean: float
std: float
Expand Down Expand Up @@ -123,35 +144,13 @@ def add_property(
if property_type == StatsPropertyType.HYPERPARAMETERS:
logger.info(
"""Hyperparameters for behavior name {}: \n{}""".format(
category, self._dict_to_str(value, 0)
category, _dict_to_str(value, 0)
)
)
elif property_type == StatsPropertyType.SELF_PLAY:
assert isinstance(value, bool)
self.self_play = value

def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
param: param_dict: A Dictionary of key, value parameters.
return: A string version of this dictionary.
"""
if not isinstance(param_dict, dict):
return str(param_dict)
else:
append_newline = "\n" if num_tabs > 0 else ""
return append_newline + "\n".join(
[
"\t"
+ " " * num_tabs
+ "{}:\t{}".format(
x, self._dict_to_str(param_dict[x], num_tabs + 1)
)
for x in param_dict
]
)


class TensorboardWriter(StatsWriter):
def __init__(self, base_dir: str, clear_past_data: bool = False):
Expand All @@ -162,7 +161,7 @@ def __init__(self, base_dir: str, clear_past_data: bool = False):
:param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and
category.
"""
self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
self.summary_writers: Dict[str, SummaryWriter] = {}
self.base_dir: str = base_dir
self._clear_past_data = clear_past_data

Expand All @@ -171,9 +170,7 @@ def write_stats(
) -> None:
self._maybe_create_summary_writer(category)
for key, value in values.items():
summary = tf.Summary()
summary.value.add(tag=f"{key}", simple_value=value.mean)
self.summary_writers[category].add_summary(summary, step)
self.summary_writers[category].add_scalar(f"{key}", value.mean, step)
self.summary_writers[category].flush()

def _maybe_create_summary_writer(self, category: str) -> None:
Expand All @@ -184,7 +181,7 @@ def _maybe_create_summary_writer(self, category: str) -> None:
os.makedirs(filewriter_dir, exist_ok=True)
if self._clear_past_data:
self._delete_all_events_files(filewriter_dir)
self.summary_writers[category] = tf.summary.FileWriter(filewriter_dir)
self.summary_writers[category] = SummaryWriter(filewriter_dir)

def _delete_all_events_files(self, directory_name: str) -> None:
for file_name in os.listdir(directory_name):
Expand All @@ -206,34 +203,11 @@ def add_property(
) -> None:
if property_type == StatsPropertyType.HYPERPARAMETERS:
assert isinstance(value, dict)
summary = self._dict_to_tensorboard("Hyperparameters", value)
summary = _dict_to_str(value, 0)
self._maybe_create_summary_writer(category)
if summary is not None:
self.summary_writers[category].add_summary(summary, 0)

def _dict_to_tensorboard(
self, name: str, input_dict: Dict[str, Any]
) -> Optional[bytes]:
"""
Convert a dict to a Tensorboard-encoded string.
:param name: The name of the text.
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
with tf.Session(config=generate_session_config()) as sess:
s_op = tf.summary.text(
name,
tf.convert_to_tensor(
[[str(x), str(input_dict[x])] for x in input_dict]
),
)
s = sess.run(s_op)
return s
except Exception:
logger.warning(
f"Could not write {name} summary for Tensorboard: {input_dict}"
)
return None
self.summary_writers[category].add_text("Hyperparameters", summary)
self.summary_writers[category].flush()


class StatsReporter:
Expand Down
20 changes: 7 additions & 13 deletions ml-agents/mlagents/trainers/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def test_stat_reporter_property():
)


@mock.patch("mlagents.tf_utils.tf.Summary")
@mock.patch("mlagents.tf_utils.tf.summary.FileWriter")
def test_tensorboard_writer(mock_filewriter, mock_summary):
@mock.patch("mlagents.trainers.stats.SummaryWriter")
def test_tensorboard_writer(mock_summary):
# Test write_stats
category = "category1"
with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
Expand All @@ -83,22 +82,17 @@ def test_tensorboard_writer(mock_filewriter, mock_summary):
basedir=base_dir, category=category
)
assert os.path.exists(filewriter_dir)
mock_filewriter.assert_called_once_with(filewriter_dir)
mock_summary.assert_called_once_with(filewriter_dir)

# Test that the filewriter was written to and the summary was added.
mock_summary.return_value.value.add.assert_called_once_with(
tag="key1", simple_value=1.0
)
mock_filewriter.return_value.add_summary.assert_called_once_with(
mock_summary.return_value, 10
)
mock_filewriter.return_value.flush.assert_called_once()
mock_summary.return_value.add_scalar.assert_called_once_with("key1", 1.0, 10)
mock_summary.return_value.flush.assert_called_once()

# Test hyperparameter writing - no good way to parse the TB string though.
tb_writer.add_property(
"category1", StatsPropertyType.HYPERPARAMETERS, {"example": 1.0}
)
assert mock_filewriter.return_value.add_summary.call_count > 1
assert mock_summary.return_value.add_text.call_count >= 1


def test_tensorboard_writer_clear(tmp_path):
Expand Down Expand Up @@ -153,7 +147,7 @@ def test_console_writer(self):
},
10,
)
# Test hyperparameter writing - no good way to parse the TB string though.
# Test hyperparameter writing
console_writer.add_property(
"category1", StatsPropertyType.HYPERPARAMETERS, {"example": 1.0}
)
Expand Down
2 changes: 1 addition & 1 deletion ml-agents/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def run(self):
"pyyaml>=3.1.0",
# Windows ver. of PyTorch doesn't work from PyPi
'torch>=1.6.0;platform_system!="Windows"',
"tensorboard>=1.14",
"tensorboard>=1.15",
"cattrs>=1.0.0",
"attrs>=19.3.0",
'pypiwin32==223;platform_system=="Windows"',
Expand Down
1 change: 1 addition & 0 deletions test_constraints_min_version.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Pillow==4.2.1
protobuf==3.6
tensorflow==1.14.0
h5py==2.9.0
tensorboard==1.15.0