1212
1313
1414def create_action_model (inp_size , act_size , deterministic = False ):
15- mask = torch .ones ([1 , act_size * 2 ])
15+ mask = torch .ones ([1 , act_size ** 2 ])
1616 action_spec = ActionSpec (act_size , tuple (act_size for _ in range (act_size )))
1717 action_model = ActionModel (inp_size , action_spec , deterministic = deterministic )
1818 return action_model , mask
@@ -45,13 +45,14 @@ def test_sample_action():
4545
4646def test_deterministic_sample_action ():
4747 inp_size = 4
48- act_size = 2
48+ act_size = 8
4949 action_model , masks = create_action_model (inp_size , act_size , deterministic = True )
5050 sample_inp = torch .ones ((1 , inp_size ))
5151 dists = action_model ._get_dists (sample_inp , masks = masks )
5252 agent_action1 = action_model ._sample_action (dists )
5353 agent_action2 = action_model ._sample_action (dists )
5454 agent_action3 = action_model ._sample_action (dists )
55+
5556 assert torch .equal (agent_action1 .continuous_tensor , agent_action2 .continuous_tensor )
5657 assert torch .equal (agent_action1 .continuous_tensor , agent_action3 .continuous_tensor )
5758 assert torch .equal (agent_action1 .discrete_tensor , agent_action2 .discrete_tensor )
@@ -63,14 +64,26 @@ def test_deterministic_sample_action():
6364 agent_action1 = action_model ._sample_action (dists )
6465 agent_action2 = action_model ._sample_action (dists )
6566 agent_action3 = action_model ._sample_action (dists )
66- assert not torch .equal (
67+
68+ chance_counter = 0
69+
70+ if not torch .equal (
6771 agent_action1 .continuous_tensor , agent_action2 .continuous_tensor
68- )
69- assert not torch .equal (
72+ ):
73+ chance_counter += 1
74+
75+ if not torch .equal (
7076 agent_action1 .continuous_tensor , agent_action3 .continuous_tensor
71- )
72- assert not torch .equal (agent_action1 .discrete_tensor , agent_action2 .discrete_tensor )
73- assert not torch .equal (agent_action1 .discrete_tensor , agent_action3 .discrete_tensor )
77+ ):
78+ chance_counter += 1
79+
80+ assert chance_counter > 1
81+ chance_counter = 0
82+ if not torch .equal (agent_action1 .discrete_tensor , agent_action2 .discrete_tensor ):
83+ chance_counter += 1
84+ if not torch .equal (agent_action1 .discrete_tensor , agent_action3 .discrete_tensor ):
85+ chance_counter += 1
86+ assert chance_counter > 1
7487
7588
7689def test_get_probs_and_entropy ():
0 commit comments