Skip to content

Commit b047e5d

Browse files
committed
clean up
1 parent 07c11d8 commit b047e5d

File tree

3 files changed

+0
-13
lines changed

3 files changed

+0
-13
lines changed

ml-agents/mlagents/trainers/torch/action_model.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
from mlagents.trainers.torch.agent_action import AgentAction
1010
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
1111
from mlagents_envs.base_env import ActionSpec
12-
from mlagents_envs import logging_util
13-
14-
logger = logging_util.get_logger(__name__)
1512

1613

1714
EPSILON = 1e-7 # Small value to avoid divide by zero
@@ -175,20 +172,15 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
175172
action_out_deprecated = continuous_out
176173
deter_continuous_out = torch.clamp(deter_continuous_out, -3, 3) / 3
177174
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
178-
logger.info(
179-
f"dist: {[discrete_dist.probs for discrete_dist in dists.discrete]}"
180-
) # TODO: remove
181175
discrete_out_list = [
182176
discrete_dist.exported_model_output()
183177
for discrete_dist in dists.discrete
184178
]
185-
logger.info(f"discretelist {discrete_out_list}") # TODO: remove
186179
discrete_out = torch.cat(discrete_out_list, dim=1)
187180
action_out_deprecated = torch.cat(discrete_out_list, dim=1)
188181
deter_discrete_out_list = [
189182
discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
190183
]
191-
logger.info(f"deterlist {deter_discrete_out_list}") # TODO: remove
192184
deter_discrete_out = torch.cat(deter_discrete_out_list, dim=1)
193185

194186
# deprecated action field does not support hybrid action

ml-agents/mlagents/trainers/torch/model_serialization.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,5 @@ def export_policy_model(self, output_filepath: str) -> None:
169169
input_names=self.input_names,
170170
output_names=self.output_names,
171171
dynamic_axes=self.dynamic_axes,
172-
verbose=True, # TODO: remove
173172
)
174173
logger.info(f"Exported {onnx_output_path}")

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,6 @@ def forward(
571571

572572
class SimpleActor(nn.Module, Actor):
573573
MODEL_EXPORT_VERSION = 3 # Corresponds to ModelApiVersion.MLAgents2_0
574-
is_stochastic_action_sampling = (
575-
True
576-
) # TODO: this should be a user input both for training and inference
577574

578575
def __init__(
579576
self,
@@ -585,7 +582,6 @@ def __init__(
585582
):
586583
super().__init__()
587584
self.action_spec = action_spec
588-
# self.is_continuous_int_deprecated = is_stochastic_action_sampling # TODO:
589585
self.version_number = torch.nn.Parameter(
590586
torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False
591587
)

0 commit comments

Comments
 (0)