Skip to content

Commit fb7849f

Browse files
committed
Renaming
1 parent 085e56e commit fb7849f

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

ml-agents/mlagents/trainers/torch/action_model.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,26 +162,33 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
162162
"""
163163
dists = self._get_dists(inputs, masks)
164164
continuous_out, discrete_out, action_out_deprecated = None, None, None
165-
deter_continuous_out, deter_discrete_out = None, None # deterministic actions
165+
deterministic_continuous_out, deterministic_discrete_out = (
166+
None,
167+
None,
168+
) # deterministic actions
166169
if self.action_spec.continuous_size > 0 and dists.continuous is not None:
167170
continuous_out = dists.continuous.exported_model_output()
168171
action_out_deprecated = continuous_out
169-
deter_continuous_out = dists.continuous.deterministic_sample()
172+
deterministic_continuous_out = dists.continuous.deterministic_sample()
170173
if self._clip_action_on_export:
171174
continuous_out = torch.clamp(continuous_out, -3, 3) / 3
172175
action_out_deprecated = continuous_out
173-
deter_continuous_out = torch.clamp(deter_continuous_out, -3, 3) / 3
176+
deterministic_continuous_out = (
177+
torch.clamp(deterministic_continuous_out, -3, 3) / 3
178+
)
174179
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
175180
discrete_out_list = [
176181
discrete_dist.exported_model_output()
177182
for discrete_dist in dists.discrete
178183
]
179184
discrete_out = torch.cat(discrete_out_list, dim=1)
180185
action_out_deprecated = torch.cat(discrete_out_list, dim=1)
181-
deter_discrete_out_list = [
186+
deterministic_discrete_out_list = [
182187
discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
183188
]
184-
deter_discrete_out = torch.cat(deter_discrete_out_list, dim=1)
189+
deterministic_discrete_out = torch.cat(
190+
deterministic_discrete_out_list, dim=1
191+
)
185192

186193
# deprecated action field does not support hybrid action
187194
if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
@@ -190,8 +197,8 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
190197
continuous_out,
191198
discrete_out,
192199
action_out_deprecated,
193-
deter_continuous_out,
194-
deter_discrete_out,
200+
deterministic_continuous_out,
201+
deterministic_discrete_out,
195202
)
196203

197204
def forward(

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -675,21 +675,21 @@ def forward(
675675
cont_action_out,
676676
disc_action_out,
677677
action_out_deprecated,
678-
deter_cont_action_out,
679-
deter_disc_action_out,
678+
deterministic_cont_action_out,
679+
deterministic_disc_action_out,
680680
) = self.action_model.get_action_out(encoding, masks)
681681
export_out = [self.version_number, self.memory_size_vector]
682682
if self.action_spec.continuous_size > 0:
683683
export_out += [
684684
cont_action_out,
685685
self.continuous_act_size_vector,
686-
deter_cont_action_out,
686+
deterministic_cont_action_out,
687687
]
688688
if self.action_spec.discrete_size > 0:
689689
export_out += [
690690
disc_action_out,
691691
self.discrete_act_size_vector,
692-
deter_disc_action_out,
692+
deterministic_disc_action_out,
693693
]
694694
if self.network_body.memory_size > 0:
695695
export_out += [memories_out]

0 commit comments

Comments
 (0)