Commit aac2ee6

[πŸ› πŸ”¨ ] set_action_for_agent expects a ActionTuple with batch size 1. (#5208)
* [Bug Fix] set_action_for_agent expects a ActionTuple with batch size 1. * moving a line around
1 parent 21548e0 commit aac2ee6
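Why the change matters: ActionTuple arrays are 2-D, shaped (number of agents, action size). Before this commit, set_action_for_agent validated the per-agent action with n_agents=None, which made _validate_action expect flat (action_size,) arrays; after the change it validates against a batch of exactly one agent. A minimal shape sketch (the sizes here are illustrative, not tied to any particular environment):

import numpy as np

continuous_size = 2  # illustrative continuous action size

# Old per-agent expectation: a flat vector of shape (continuous_size,)
old_style = np.ones((continuous_size,), dtype=np.float32)

# New per-agent expectation: a batch of one, shape (1, continuous_size)
new_style = np.ones((1, continuous_size), dtype=np.float32)
assert new_style.shape == (1, continuous_size)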

3 files changed
+60 -13 lines changed

ml-agents-envs/mlagents_envs/base_env.py

Lines changed: 3 additions & 11 deletions
@@ -410,28 +410,20 @@ def random_action(self, n_agents: int) -> ActionTuple:
         return ActionTuple(continuous=_continuous, discrete=_discrete)
 
     def _validate_action(
-        self, actions: ActionTuple, n_agents: Optional[int], name: str
+        self, actions: ActionTuple, n_agents: int, name: str
     ) -> ActionTuple:
         """
         Validates that action has the correct action dim
         for the correct number of agents and ensures the type.
         """
-        _expected_shape = (
-            (n_agents, self.continuous_size)
-            if n_agents is not None
-            else (self.continuous_size,)
-        )
+        _expected_shape = (n_agents, self.continuous_size)
         if actions.continuous.shape != _expected_shape:
             raise UnityActionException(
                 f"The behavior {name} needs a continuous input of dimension "
                 f"{_expected_shape} for (<number of agents>, <action size>) but "
                 f"received input of dimension {actions.continuous.shape}"
             )
-        _expected_shape = (
-            (n_agents, self.discrete_size)
-            if n_agents is not None
-            else (self.discrete_size,)
-        )
+        _expected_shape = (n_agents, self.discrete_size)
         if actions.discrete.shape != _expected_shape:
             raise UnityActionException(
                 f"The behavior {name} needs a discrete input of dimension "

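To make the new contract concrete, here is a simplified, standalone sketch of what the validation now checks; it mirrors the logic above but is not the library code, and continuous_size / discrete_size stand in for the ActionSpec fields:

import numpy as np

def check_action_shapes(continuous, discrete, n_agents, continuous_size, discrete_size):
    # After this change the expected shapes are always 2-D: (n_agents, action size).
    if continuous.shape != (n_agents, continuous_size):
        raise ValueError(f"expected continuous shape {(n_agents, continuous_size)}, got {continuous.shape}")
    if discrete.shape != (n_agents, discrete_size):
        raise ValueError(f"expected discrete shape {(n_agents, discrete_size)}, got {discrete.shape}")

# set_actions validates with n_agents = number of agents in the current step;
# set_action_for_agent now validates with n_agents = 1.
check_action_shapes(
    np.ones((1, 2), dtype=np.float32),
    np.zeros((1, 0), dtype=np.int32),
    n_agents=1, continuous_size=2, discrete_size=0,
)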
ml-agents-envs/mlagents_envs/environment.py

Lines changed: 2 additions & 2 deletions
@@ -365,9 +365,9 @@ def set_action_for_agent(
         if behavior_name not in self._env_state:
             return
         action_spec = self._env_specs[behavior_name].action_spec
-        num_agents = len(self._env_state[behavior_name][0])
-        action = action_spec._validate_action(action, None, behavior_name)
+        action = action_spec._validate_action(action, 1, behavior_name)
         if behavior_name not in self._env_actions:
+            num_agents = len(self._env_state[behavior_name][0])
             self._env_actions[behavior_name] = action_spec.empty_action(num_agents)
         try:
             index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][
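In user code, the change means each per-agent action must be an ActionTuple with a single row, the same pattern exercised by the new test below. A minimal sketch under the same assumptions as that test (the 3DBall behavior has two continuous actions and no discrete actions, and the registry can download the environment binary):

import numpy as np
from mlagents_envs.registry import default_registry
from mlagents_envs.base_env import ActionTuple

env = default_registry["3DBall"].make(no_graphics=True)
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
decision_steps, _ = env.get_steps(behavior_name)
for agent_id in decision_steps.agent_id:
    action_tuple = ActionTuple()
    # Batch size 1: one row of continuous actions for this single agent.
    action_tuple.add_continuous(np.ones((1, 2), dtype=np.float32))
    env.set_action_for_agent(behavior_name, agent_id, action_tuple)
env.step()
env.close()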
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+from mlagents_envs.registry import default_registry
+from mlagents_envs.side_channel.engine_configuration_channel import (
+    EngineConfigurationChannel,
+)
+from mlagents_envs.base_env import ActionTuple
+import numpy as np
+
+BALL_ID = "3DBall"
+
+
+def test_set_action_single_agent():
+    engine_config_channel = EngineConfigurationChannel()
+    env = default_registry[BALL_ID].make(
+        base_port=6000,
+        worker_id=0,
+        no_graphics=True,
+        side_channels=[engine_config_channel],
+    )
+    engine_config_channel.set_configuration_parameters(time_scale=100)
+    for _ in range(3):
+        env.reset()
+        behavior_name = list(env.behavior_specs.keys())[0]
+        d, t = env.get_steps(behavior_name)
+        for _ in range(50):
+            for agent_id in d.agent_id:
+                action = np.ones((1, 2))
+                action_tuple = ActionTuple()
+                action_tuple.add_continuous(action)
+                env.set_action_for_agent(behavior_name, agent_id, action_tuple)
+            env.step()
+            d, t = env.get_steps(behavior_name)
+    env.close()
+
+
+def test_set_action_multi_agent():
+    engine_config_channel = EngineConfigurationChannel()
+    env = default_registry[BALL_ID].make(
+        base_port=6001,
+        worker_id=0,
+        no_graphics=True,
+        side_channels=[engine_config_channel],
+    )
+    engine_config_channel.set_configuration_parameters(time_scale=100)
+    for _ in range(3):
+        env.reset()
+        behavior_name = list(env.behavior_specs.keys())[0]
+        d, t = env.get_steps(behavior_name)
+        for _ in range(50):
+            action = np.ones((len(d), 2))
+            action_tuple = ActionTuple()
+            action_tuple.add_continuous(action)
+            env.set_actions(behavior_name, action_tuple)
+            env.step()
+            d, t = env.get_steps(behavior_name)
+    env.close()
