Skip to content

Commit 604d7c1

Browse files
committed
Added a more stable test.
1 parent 98da4b1 commit 604d7c1

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

ml-agents/mlagents/trainers/tests/torch/test_action_model.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def create_action_model(inp_size, act_size, deterministic=False):
15-
mask = torch.ones([1, act_size * 2])
15+
mask = torch.ones([1, act_size ** 2])
1616
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
1717
action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
1818
return action_model, mask
@@ -45,13 +45,14 @@ def test_sample_action():
4545

4646
def test_deterministic_sample_action():
4747
inp_size = 4
48-
act_size = 2
48+
act_size = 8
4949
action_model, masks = create_action_model(inp_size, act_size, deterministic=True)
5050
sample_inp = torch.ones((1, inp_size))
5151
dists = action_model._get_dists(sample_inp, masks=masks)
5252
agent_action1 = action_model._sample_action(dists)
5353
agent_action2 = action_model._sample_action(dists)
5454
agent_action3 = action_model._sample_action(dists)
55+
5556
assert torch.equal(agent_action1.continuous_tensor, agent_action2.continuous_tensor)
5657
assert torch.equal(agent_action1.continuous_tensor, agent_action3.continuous_tensor)
5758
assert torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor)
@@ -63,14 +64,26 @@ def test_deterministic_sample_action():
6364
agent_action1 = action_model._sample_action(dists)
6465
agent_action2 = action_model._sample_action(dists)
6566
agent_action3 = action_model._sample_action(dists)
66-
assert not torch.equal(
67+
68+
chance_counter = 0
69+
70+
if not torch.equal(
6771
agent_action1.continuous_tensor, agent_action2.continuous_tensor
68-
)
69-
assert not torch.equal(
72+
):
73+
chance_counter += 1
74+
75+
if not torch.equal(
7076
agent_action1.continuous_tensor, agent_action3.continuous_tensor
71-
)
72-
assert not torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor)
73-
assert not torch.equal(agent_action1.discrete_tensor, agent_action3.discrete_tensor)
77+
):
78+
chance_counter += 1
79+
80+
assert chance_counter > 1
81+
chance_counter = 0
82+
if not torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor):
83+
chance_counter += 1
84+
if not torch.equal(agent_action1.discrete_tensor, agent_action3.discrete_tensor):
85+
chance_counter += 1
86+
assert chance_counter > 1
7487

7588

7689
def test_get_probs_and_entropy():

ml-agents/mlagents/trainers/torch/distributions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def _mask_branch(
225225
# We do -1 * tensor + constant instead of constant - tensor because it seems
226226
# Barracuda might swap the inputs of a "Sub" operation
227227
logits = logits * allow_mask - 1e8 * block_mask
228+
228229
return logits
229230

230231
def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:

0 commit comments

Comments
 (0)