@@ -162,26 +162,33 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
162
162
"""
163
163
dists = self ._get_dists (inputs , masks )
164
164
continuous_out , discrete_out , action_out_deprecated = None , None , None
165
- deter_continuous_out , deter_discrete_out = None , None # deterministic actions
165
+ deterministic_continuous_out , deterministic_discrete_out = (
166
+ None ,
167
+ None ,
168
+ ) # deterministic actions
166
169
if self .action_spec .continuous_size > 0 and dists .continuous is not None :
167
170
continuous_out = dists .continuous .exported_model_output ()
168
171
action_out_deprecated = continuous_out
169
- deter_continuous_out = dists .continuous .deterministic_sample ()
172
+ deterministic_continuous_out = dists .continuous .deterministic_sample ()
170
173
if self ._clip_action_on_export :
171
174
continuous_out = torch .clamp (continuous_out , - 3 , 3 ) / 3
172
175
action_out_deprecated = continuous_out
173
- deter_continuous_out = torch .clamp (deter_continuous_out , - 3 , 3 ) / 3
176
+ deterministic_continuous_out = (
177
+ torch .clamp (deterministic_continuous_out , - 3 , 3 ) / 3
178
+ )
174
179
if self .action_spec .discrete_size > 0 and dists .discrete is not None :
175
180
discrete_out_list = [
176
181
discrete_dist .exported_model_output ()
177
182
for discrete_dist in dists .discrete
178
183
]
179
184
discrete_out = torch .cat (discrete_out_list , dim = 1 )
180
185
action_out_deprecated = torch .cat (discrete_out_list , dim = 1 )
181
- deter_discrete_out_list = [
186
+ deterministic_discrete_out_list = [
182
187
discrete_dist .deterministic_sample () for discrete_dist in dists .discrete
183
188
]
184
- deter_discrete_out = torch .cat (deter_discrete_out_list , dim = 1 )
189
+ deterministic_discrete_out = torch .cat (
190
+ deterministic_discrete_out_list , dim = 1
191
+ )
185
192
186
193
# deprecated action field does not support hybrid action
187
194
if self .action_spec .continuous_size > 0 and self .action_spec .discrete_size > 0 :
@@ -190,8 +197,8 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
190
197
continuous_out ,
191
198
discrete_out ,
192
199
action_out_deprecated ,
193
- deter_continuous_out ,
194
- deter_discrete_out ,
200
+ deterministic_continuous_out ,
201
+ deterministic_discrete_out ,
195
202
)
196
203
197
204
def forward (
0 commit comments