12
12
13
13
14
14
def create_action_model(inp_size, act_size, deterministic=False):
    """Build an ``ActionModel`` test fixture together with an all-ones mask.

    The ``ActionSpec`` is constructed with ``act_size`` as its first argument
    and a tuple of ``act_size`` entries, each equal to ``act_size``, as its
    second argument.

    Args:
        inp_size: Forwarded to ``ActionModel`` as its first argument.
        act_size: Used for both ``ActionSpec`` arguments and to size the mask.
        deterministic: Forwarded to ``ActionModel`` as ``deterministic``.

    Returns:
        A ``(action_model, mask)`` tuple, where ``mask`` is a tensor of ones
        with shape ``(1, act_size ** 2)``.
    """
    # The spec below receives act_size entries of value act_size, so the mask
    # needs act_size * act_size columns. `act_size * 2` only matches that when
    # act_size == 2.
    # NOTE(review): presumably one mask entry per discrete choice across all
    # branches -- confirm against ActionSpec / ActionModel.
    mask = torch.ones([1, act_size ** 2])
    action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
    action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
    return action_model, mask
@@ -45,13 +45,14 @@ def test_sample_action():
45
45
46
46
def test_deterministic_sample_action ():
47
47
inp_size = 4
48
- act_size = 2
48
+ act_size = 8
49
49
action_model , masks = create_action_model (inp_size , act_size , deterministic = True )
50
50
sample_inp = torch .ones ((1 , inp_size ))
51
51
dists = action_model ._get_dists (sample_inp , masks = masks )
52
52
agent_action1 = action_model ._sample_action (dists )
53
53
agent_action2 = action_model ._sample_action (dists )
54
54
agent_action3 = action_model ._sample_action (dists )
55
+
55
56
assert torch .equal (agent_action1 .continuous_tensor , agent_action2 .continuous_tensor )
56
57
assert torch .equal (agent_action1 .continuous_tensor , agent_action3 .continuous_tensor )
57
58
assert torch .equal (agent_action1 .discrete_tensor , agent_action2 .discrete_tensor )
@@ -63,14 +64,26 @@ def test_deterministic_sample_action():
63
64
agent_action1 = action_model ._sample_action (dists )
64
65
agent_action2 = action_model ._sample_action (dists )
65
66
agent_action3 = action_model ._sample_action (dists )
66
- assert not torch .equal (
67
+
68
+ chance_counter = 0
69
+
70
+ if not torch .equal (
67
71
agent_action1 .continuous_tensor , agent_action2 .continuous_tensor
68
- )
69
- assert not torch .equal (
72
+ ):
73
+ chance_counter += 1
74
+
75
+ if not torch .equal (
70
76
agent_action1 .continuous_tensor , agent_action3 .continuous_tensor
71
- )
72
- assert not torch .equal (agent_action1 .discrete_tensor , agent_action2 .discrete_tensor )
73
- assert not torch .equal (agent_action1 .discrete_tensor , agent_action3 .discrete_tensor )
77
+ ):
78
+ chance_counter += 1
79
+
80
+ assert chance_counter > 1
81
+ chance_counter = 0
82
+ if not torch .equal (agent_action1 .discrete_tensor , agent_action2 .discrete_tensor ):
83
+ chance_counter += 1
84
+ if not torch .equal (agent_action1 .discrete_tensor , agent_action3 .discrete_tensor ):
85
+ chance_counter += 1
86
+ assert chance_counter > 1
74
87
75
88
76
89
def test_get_probs_and_entropy ():
0 commit comments