diff --git a/ml-agents/mlagents/trainers/cli_utils.py b/ml-agents/mlagents/trainers/cli_utils.py index 3aa2f8b8f8..de420c42a4 100644 --- a/ml-agents/mlagents/trainers/cli_utils.py +++ b/ml-agents/mlagents/trainers/cli_utils.py @@ -96,7 +96,8 @@ def _create_parser() -> argparse.ArgumentParser: default=False, dest="deterministic", action=DetectDefaultStoreTrue, - help="Whether to select actions deterministically in policy. `dist.mean` for continuous action space, and `dist.argmax` for deterministic action space ", + help="Whether to select actions deterministically in policy. `dist.mean` for continuous action " + "space, and `dist.argmax` for deterministic action space ", ) argparser.add_argument( "--force", diff --git a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py b/ml-agents/mlagents/trainers/tests/torch/test_action_model.py index 2f81aaf034..a33f793927 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_action_model.py @@ -65,24 +65,22 @@ def test_deterministic_sample_action(): agent_action2 = action_model._sample_action(dists) agent_action3 = action_model._sample_action(dists) - chance_counter = 0 - - if not torch.equal( + assert not torch.equal( agent_action1.continuous_tensor, agent_action2.continuous_tensor - ): - chance_counter += 1 + ) - if not torch.equal( + assert not torch.equal( agent_action1.continuous_tensor, agent_action3.continuous_tensor - ): - chance_counter += 1 + ) - assert chance_counter > 1 chance_counter = 0 if not torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor): chance_counter += 1 if not torch.equal(agent_action1.discrete_tensor, agent_action3.discrete_tensor): chance_counter += 1 + if not torch.equal(agent_action2.discrete_tensor, agent_action3.discrete_tensor): + chance_counter += 1 + assert chance_counter > 1