diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 79a53fc517..19243db936 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -25,6 +25,7 @@ TensorBoard. Thanks to @brccabral for the contribution! (#4816) #### com.unity.ml-agents (C#) - Fix a compile warning about using an obsolete enum in `GrpcExtensions.cs`. (#4812) #### ml-agents / ml-agents-envs / gym-unity (Python) +- Fixed a bug that would cause an exception when `RunOptions` was deserialized via `pickle`. (#4842) ## [1.7.2-preview] - 2020-12-22 diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py index 3607e30caa..9f47bd567b 100644 --- a/ml-agents/mlagents/trainers/settings.py +++ b/ml-agents/mlagents/trainers/settings.py @@ -681,7 +681,13 @@ def structure(d: Mapping, t: type) -> Any: class DefaultTrainerDict(collections.defaultdict): def __init__(self, *args): - super().__init__(TrainerSettings, *args) + # Depending on how this is called, args may have the defaultdict + # callable at the start of the list or not. In particular, unpickling + # will pass [TrainerSettings]. + if args and args[0] == TrainerSettings: + super().__init__(*args) + else: + super().__init__(TrainerSettings, *args) def __missing__(self, key: Any) -> "TrainerSettings": if TrainerSettings.default_override is not None: diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py index ba06c3881d..52a0460718 100644 --- a/ml-agents/mlagents/trainers/tests/test_settings.py +++ b/ml-agents/mlagents/trainers/tests/test_settings.py @@ -1,5 +1,6 @@ import attr import cattr +import pickle import pytest import yaml @@ -516,3 +517,10 @@ def test_default_settings(): test1_settings.max_steps = 1 test1_settings.network_settings.hidden_units == default_settings_cls.network_settings.hidden_units check_if_different(test1_settings, default_settings_cls) + + +def test_pickle(): + # Make sure RunOptions is pickle-able. + run_options = RunOptions() + p = pickle.dumps(run_options) + pickle.loads(p)