
[bug-fix] Fix save/restore critic, add test #5062


Merged (3 commits, Mar 10, 2021)
ml-agents/mlagents/trainers/ppo/optimizer_torch.py (4 additions & 1 deletion)
@@ -232,7 +232,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         return update_stats

     def get_modules(self):
-        modules = {"Optimizer": self.optimizer}
+        modules = {
+            "Optimizer:value_optimizer": self.optimizer,
+            "Optimizer:critic": self._critic,
+        }
         for reward_provider in self.reward_signals.values():
             modules.update(reward_provider.get_modules())
         return modules
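
Editor's note on why this fixes the bug: the saver checkpoints the modules that registered objects report through get_modules(), so a critic that never appears in this dictionary is neither written to nor restored from a checkpoint. The sketch below shows the general key-based save/restore pattern; the helper names are hypothetical and the real TorchModelSaver may be organized differently.

import os

from mlagents.torch_utils import torch


def save_modules(modules: dict, path: str, step: int) -> None:
    # Every registered entry (networks and torch optimizers alike) exposes state_dict().
    state = {name: module.state_dict() for name, module in modules.items()}
    torch.save(state, os.path.join(path, f"checkpoint-{step}.pt"))


def load_modules(modules: dict, checkpoint_path: str) -> None:
    saved = torch.load(checkpoint_path)
    for name, module in modules.items():
        if name in saved:
            module.load_state_dict(saved[name])
        # An entry missing from the checkpoint (the critic, before this fix)
        # silently keeps its freshly initialized weights.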
ml-agents/mlagents/trainers/sac/optimizer_torch.py (2 additions & 1 deletion)
@@ -635,7 +635,8 @@ def update_reward_signals(

     def get_modules(self):
         modules = {
-            "Optimizer:value_network": self.q_network,
+            "Optimizer:q_network": self.q_network,
+            "Optimizer:value_network": self._critic,
             "Optimizer:target_network": self.target_network,
             "Optimizer:policy_optimizer": self.policy_optimizer,
             "Optimizer:value_optimizer": self.value_optimizer,
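The SAC change is the same omission plus a key cleanup: the Q networks move to an explicit "Optimizer:q_network" entry, and "Optimizer:value_network" now points at the shared critic, which previously was never saved. A quick way to inspect what gets checkpointed, sketched with the same construction the new test below uses (create_policy_mock and SACSettings are borrowed from that test):

from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.settings import SACSettings, TrainerSettings
from mlagents.trainers.tests.torch.test_policy import create_policy_mock

trainer_settings = TrainerSettings()
trainer_settings.hyperparameters = SACSettings()
policy = create_policy_mock(trainer_settings, use_discrete=False)
optimizer = TorchSACOptimizer(policy, trainer_settings)

for name, module in optimizer.get_modules().items():
    # Expect "Optimizer:q_network" and "Optimizer:value_network" (now the critic)
    # alongside the target network and the torch optimizers.
    print(name, type(module).__name__)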
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (59 additions & 2 deletions)
@@ -1,3 +1,4 @@
+from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 import pytest
 from unittest import mock
 import os
@@ -6,8 +7,9 @@
 from mlagents.torch_utils import torch, default_device
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
+from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
 from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
-from mlagents.trainers.settings import TrainerSettings
+from mlagents.trainers.settings import TrainerSettings, PPOSettings, SACSettings
 from mlagents.trainers.tests import mock_brain as mb
 from mlagents.trainers.tests.torch.test_policy import create_policy_mock
 from mlagents.trainers.torch.utils import ModelUtils
@@ -29,7 +31,7 @@ def test_register(tmp_path):
     assert model_saver.policy is not None


-def test_load_save(tmp_path):
+def test_load_save_policy(tmp_path):
     path1 = os.path.join(tmp_path, "runid1")
     path2 = os.path.join(tmp_path, "runid2")
     trainer_params = TrainerSettings()
@@ -62,6 +64,42 @@ def test_load_save(tmp_path):
     assert policy3.get_current_step() == 0


+@pytest.mark.parametrize(
+    "optimizer",
+    [(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
+    ids=["ppo", "sac"],
+)
+def test_load_save_optimizer(tmp_path, optimizer):
+    OptimizerClass, HyperparametersClass = optimizer
+
+    trainer_settings = TrainerSettings()
+    trainer_settings.hyperparameters = HyperparametersClass()
+    policy = create_policy_mock(trainer_settings, use_discrete=False)
+    optimizer = OptimizerClass(policy, trainer_settings)
+
+    # save at path 1
+    path1 = os.path.join(tmp_path, "runid1")
+    model_saver = TorchModelSaver(trainer_settings, path1)
+    model_saver.register(policy)
+    model_saver.register(optimizer)
+    model_saver.initialize_or_load()
+    policy.set_step(2000)
+    model_saver.save_checkpoint("MockBrain", 2000)
+
+    # create a new optimizer and policy
+    policy2 = create_policy_mock(trainer_settings, use_discrete=False)
+    optimizer2 = OptimizerClass(policy2, trainer_settings)
+
+    # load weights
+    model_saver2 = TorchModelSaver(trainer_settings, path1, load=True)
+    model_saver2.register(policy2)
+    model_saver2.register(optimizer2)
+    model_saver2.initialize_or_load()  # This is to load the optimizers
+
+    # Compare the two optimizers
+    _compare_two_optimizers(optimizer, optimizer2)
+
+
 # TorchPolicy.evaluate() returns log_probs instead of all_log_probs like tf does,
 # resulting in non-deterministic results for testing.
 # So use sample_actions here instead.
@@ -95,6 +133,25 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
     )


+def _compare_two_optimizers(opt1: TorchOptimizer, opt2: TorchOptimizer) -> None:
+    trajectory = mb.make_fake_trajectory(
+        length=10,
+        observation_specs=opt1.policy.behavior_spec.observation_specs,
+        action_spec=opt1.policy.behavior_spec.action_spec,
+        max_step_complete=True,
+    )
+    with torch.no_grad():
+        _, opt1_val_out, _ = opt1.get_trajectory_value_estimates(
+            trajectory.to_agentbuffer(), trajectory.next_obs, done=False
+        )
+        _, opt2_val_out, _ = opt2.get_trajectory_value_estimates(
+            trajectory.to_agentbuffer(), trajectory.next_obs, done=False
+        )
+
+    for opt1_val, opt2_val in zip(opt1_val_out.values(), opt2_val_out.values()):
+        np.testing.assert_array_equal(opt1_val, opt2_val)
+
+
 @pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
 @pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
 @pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
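
A possible follow-up check, not part of this PR: compare the restored critic parameters tensor-by-tensor instead of, or in addition to, comparing value estimates. The critic attribute below is an assumption; depending on the optimizer it may be exposed as a property or as _critic.

import numpy as np

from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer


def _compare_critic_parameters(opt1: TorchOptimizer, opt2: TorchOptimizer) -> None:
    # Sketch only: assumes both optimizers expose their critic network directly.
    params1 = opt1.critic.state_dict()
    params2 = opt2.critic.state_dict()
    assert params1.keys() == params2.keys()
    for key in params1:
        np.testing.assert_array_equal(
            params1[key].detach().cpu().numpy(), params2[key].detach().cpu().numpy()
        )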