-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Python Dataflow for Group Manager #4926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 196 commits
7282b82
24e95f1
e73897e
3ac6ff7
59183b2
faf3308
6b463a7
b7a8afc
84362c8
5fdf354
d41ad79
26cc813
b71f2d7
d94b700
5f523bf
72a165b
8287b97
f66e481
e9f2ccc
33e50cb
d8db247
77ab2e9
fc3fbea
45d2454
f990a3a
4f3132e
be8b7ea
ccff057
438ba4c
4b8a396
6370c0c
131599f
1a8ad76
e848147
788360e
c5536a6
32473bf
ae83e62
d72eee8
62e9b45
1ebacc1
cb57bf0
95b3522
afd7476
292b6ce
112a9dc
783db4c
6c4ba1e
d314478
c7adb93
d2e315d
5cf76e3
8708f70
44fb8b5
56f9dbf
a468075
db184d9
32cbdee
c90472c
d429b53
44093f2
1dc0059
2b5b994
cd84fe3
eed2fce
fe41094
3822b18
3f4b2b5
c7c7d4c
f391b35
f706a91
cee5466
195978c
10f336e
f0bf657
e03c79e
1118089
13a90b1
36d1b5b
376d500
dd8b5fb
8673820
fb86a57
1beea7d
9e69790
3c2b9d1
b4b9d72
7d5f3e3
b7c5533
53e1277
2134004
60c6071
47cfae4
d31da21
541d062
d3c4372
3407478
f84ca50
287c1b9
f73ef80
10a416a
e716199
45349b8
9a6474e
04d9617
5bbb222
2868694
c9b4e71
a10caaf
44c616d
ef01af4
f329e1d
fbd1749
908b1df
d4073ce
7d8f2b5
9452239
39adec6
14bb6fd
7cb5dbc
761a206
3afae60
f0dfada
c3d8d8e
c3d84c5
f5419aa
d7a2386
cdc6dde
4f35048
05c8ea1
a7f2fc2
6d2be2c
b812da4
0c3dbff
09590ad
c982c06
7e3d976
c40fec0
ffb3f0b
f87cfbd
8b8e916
6b71f5a
87e97dd
d3d1dc1
2ba09ca
5587e48
7e51ad1
f25b171
128b09b
4690c4e
dbdd045
30c846f
dd7f867
f36f696
a1b7e75
236f398
a22c621
2dc90a9
70207a3
49282f6
204b45b
7eacfba
016ffd8
8b9d662
3fb14b9
4e4ecad
492fd17
7292672
78e052b
81d8389
39f92c3
1d500d6
6418e05
6da8dd3
ad4a821
9725aa5
f5190fe
664ae89
8f696f4
77557ca
6464cb6
cbfdfb3
6badfb5
ef67f53
8e78dbd
31ee1c4
1e4c837
cba26b2
2113a43
0e28c07
6936004
97d1b80
9a00053
f879b61
6d7a604
587e3da
fd4aa53
8dbea77
ec9e5ad
e1f48db
d42896a
b3f2689
b2100c1
7b1f805
1ce50ef
aea066d
0860098
580683c
5e73905
f1c8d45
af0d353
f3ef9ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,4 +1,5 @@ | ||||||
import sys | ||||||
import numpy as np | ||||||
from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union | ||||||
from collections import defaultdict, Counter | ||||||
import queue | ||||||
|
@@ -14,12 +15,12 @@ | |||||
StatsAggregationMethod, | ||||||
EnvironmentStats, | ||||||
) | ||||||
from mlagents.trainers.trajectory import Trajectory, AgentExperience | ||||||
from mlagents.trainers.trajectory import GroupmateStatus, Trajectory, AgentExperience | ||||||
from mlagents.trainers.policy import Policy | ||||||
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs | ||||||
from mlagents.trainers.torch.action_log_probs import LogProbsTuple | ||||||
from mlagents.trainers.stats import StatsReporter | ||||||
from mlagents.trainers.behavior_id_utils import get_global_agent_id | ||||||
from mlagents.trainers.behavior_id_utils import get_global_agent_id, get_global_group_id | ||||||
|
||||||
T = TypeVar("T") | ||||||
|
||||||
|
@@ -49,6 +50,16 @@ def __init__( | |||||
""" | ||||||
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list) | ||||||
self.last_step_result: Dict[str, Tuple[DecisionStep, int]] = {} | ||||||
# current_group_obs is used to collect the last seen obs of all the agents in the same group, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# and assemble the next_group_obs passed to each Trajectory. | ||||||
self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict( | ||||||
lambda: defaultdict(list) | ||||||
) | ||||||
# group_status is used to collect the last seen status (obs, reward, action, done) of all the agents in the same group, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# and assemble the group_status passed to each AgentExperience. | ||||||
self.group_status: Dict[str, Dict[str, GroupmateStatus]] = defaultdict( | ||||||
lambda: defaultdict(None) | ||||||
) | ||||||
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while | ||||||
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1). | ||||||
self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {} | ||||||
|
@@ -88,19 +99,32 @@ def add_experiences( | |||||
if global_id in self.last_step_result: # Don't store if agent just reset | ||||||
self.last_take_action_outputs[global_id] = take_action_outputs | ||||||
|
||||||
# Iterate over all the terminal steps | ||||||
# Iterate over all the terminal steps, first gather all the teammate obs | ||||||
# and then create the AgentExperiences/Trajectories | ||||||
for terminal_step in terminal_steps.values(): | ||||||
self._gather_group_obs(terminal_step, worker_id) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a comment that |
||||||
for terminal_step in terminal_steps.values(): | ||||||
local_id = terminal_step.agent_id | ||||||
global_id = get_global_agent_id(worker_id, local_id) | ||||||
self._process_step( | ||||||
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id] | ||||||
terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id] | ||||||
) | ||||||
# Iterate over all the decision steps | ||||||
# Clear the last seen group obs and group status when agents die. | ||||||
self._clear_group_obs(global_id) | ||||||
|
||||||
# Clean the last experience dictionary for terminal steps | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. legacy code? |
||||||
for terminal_step in terminal_steps.values(): | ||||||
local_id = terminal_step.agent_id | ||||||
global_id = get_global_agent_id(worker_id, local_id) | ||||||
|
||||||
# Iterate over all the decision steps, first gather all the teammate obs | ||||||
# and then create the trajectories | ||||||
for ongoing_step in decision_steps.values(): | ||||||
self._gather_group_obs(ongoing_step, worker_id) | ||||||
for ongoing_step in decision_steps.values(): | ||||||
local_id = ongoing_step.agent_id | ||||||
global_id = get_global_agent_id(worker_id, local_id) | ||||||
self._process_step( | ||||||
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id] | ||||||
ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id] | ||||||
) | ||||||
|
||||||
for _gid in action_global_agent_ids: | ||||||
|
@@ -112,21 +136,65 @@ def add_experiences( | |||||
[_gid], take_action_outputs["action"] | ||||||
) | ||||||
|
||||||
def _gather_group_obs( | ||||||
self, step: Union[TerminalStep, DecisionStep], worker_id: int | ||||||
) -> None: | ||||||
global_agent_id = get_global_agent_id(worker_id, step.agent_id) | ||||||
stored_decision_step, idx = self.last_step_result.get( | ||||||
global_agent_id, (None, None) | ||||||
) | ||||||
stored_take_action_outputs = self.last_take_action_outputs.get( | ||||||
global_agent_id, None | ||||||
) | ||||||
if stored_decision_step is not None and stored_take_action_outputs is not None: | ||||||
if step.group_id > 0: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is using implicit assumption that 0 means no group. I'd add some comments or put it into a separate function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a comment - we can add a utility method too, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah I was thinking more like a utility method There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yeah that's not a bad idea. Let's do a separate PR to add this, and add the group ID to the documentation for the Python API |
||||||
global_group_id = get_global_group_id(worker_id, step.group_id) | ||||||
stored_actions = stored_take_action_outputs["action"] | ||||||
action_tuple = ActionTuple( | ||||||
continuous=stored_actions.continuous[idx], | ||||||
discrete=stored_actions.discrete[idx], | ||||||
) | ||||||
group_status = GroupmateStatus( | ||||||
obs=stored_decision_step.obs, | ||||||
reward=step.reward, | ||||||
action=action_tuple, | ||||||
done=isinstance(step, TerminalStep), | ||||||
) | ||||||
self.group_status[global_group_id][global_agent_id] = group_status | ||||||
self.current_group_obs[global_group_id][global_agent_id] = step.obs | ||||||
|
||||||
def _clear_group_obs(self, global_id: str) -> None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This clears not only the obs but also the status; the name and comment should reflect that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||||||
self._delete_in_nested_dict(self.current_group_obs, global_id) | ||||||
self._delete_in_nested_dict(self.group_status, global_id) | ||||||
|
||||||
def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make this a static or utils method. _safe_delete should also be static. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally don't think we should make the method static unless it's used elsewhere. Should we make a new place for utils like this? |
||||||
for _manager_id in list(nested_dict.keys()): | ||||||
_team_group = nested_dict[_manager_id] | ||||||
self._safe_delete(_team_group, key) | ||||||
if not _team_group: # if dict is empty | ||||||
self._safe_delete(nested_dict, _manager_id) | ||||||
|
||||||
def _process_step( | ||||||
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int | ||||||
self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int | ||||||
) -> None: | ||||||
terminated = isinstance(step, TerminalStep) | ||||||
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None)) | ||||||
stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None) | ||||||
global_agent_id = get_global_agent_id(worker_id, step.agent_id) | ||||||
global_group_id = get_global_group_id(worker_id, step.group_id) | ||||||
stored_decision_step, idx = self.last_step_result.get( | ||||||
global_agent_id, (None, None) | ||||||
) | ||||||
stored_take_action_outputs = self.last_take_action_outputs.get( | ||||||
global_agent_id, None | ||||||
) | ||||||
if not terminated: | ||||||
# Index is needed to grab from last_take_action_outputs | ||||||
self.last_step_result[global_id] = (step, index) | ||||||
self.last_step_result[global_agent_id] = (step, index) | ||||||
|
||||||
# This state is the consequence of a past action | ||||||
if stored_decision_step is not None and stored_take_action_outputs is not None: | ||||||
obs = stored_decision_step.obs | ||||||
if self.policy.use_recurrent: | ||||||
memory = self.policy.retrieve_memories([global_id])[0, :] | ||||||
memory = self.policy.retrieve_memories([global_agent_id])[0, :] | ||||||
else: | ||||||
memory = None | ||||||
done = terminated # Since this is an ongoing step | ||||||
|
@@ -143,7 +211,14 @@ def _process_step( | |||||
discrete=stored_action_probs.discrete[idx], | ||||||
) | ||||||
action_mask = stored_decision_step.action_mask | ||||||
prev_action = self.policy.retrieve_previous_action([global_id])[0, :] | ||||||
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :] | ||||||
|
||||||
# Assemble the statuses of the other agents in the group. If none are saved, this will be an empty list. | ||||||
group_statuses = [] | ||||||
for _id, _obs in self.group_status[global_group_id].items(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename variable _obs to _mate_status ? |
||||||
if _id != global_agent_id: | ||||||
group_statuses.append(_obs) | ||||||
|
||||||
experience = AgentExperience( | ||||||
obs=obs, | ||||||
reward=step.reward, | ||||||
|
@@ -154,35 +229,44 @@ def _process_step( | |||||
prev_action=prev_action, | ||||||
interrupted=interrupted, | ||||||
memory=memory, | ||||||
group_status=group_statuses, | ||||||
group_reward=step.group_reward, | ||||||
) | ||||||
# Add the value outputs if needed | ||||||
self.experience_buffers[global_id].append(experience) | ||||||
self.episode_rewards[global_id] += step.reward | ||||||
self.experience_buffers[global_agent_id].append(experience) | ||||||
self.episode_rewards[global_agent_id] += step.reward | ||||||
if not terminated: | ||||||
self.episode_steps[global_id] += 1 | ||||||
self.episode_steps[global_agent_id] += 1 | ||||||
|
||||||
# Add a trajectory segment to the buffer if terminal or the length has reached the time horizon | ||||||
if ( | ||||||
len(self.experience_buffers[global_id]) >= self.max_trajectory_length | ||||||
len(self.experience_buffers[global_agent_id]) | ||||||
>= self.max_trajectory_length | ||||||
or terminated | ||||||
): | ||||||
# Make next AgentExperience | ||||||
next_obs = step.obs | ||||||
next_group_obs = [] | ||||||
for _id, _exp in self.current_group_obs[global_group_id].items(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. _exp -> _obs |
||||||
if _id != global_agent_id: | ||||||
next_group_obs.append(_exp) | ||||||
|
||||||
trajectory = Trajectory( | ||||||
steps=self.experience_buffers[global_id], | ||||||
agent_id=global_id, | ||||||
steps=self.experience_buffers[global_agent_id], | ||||||
agent_id=global_agent_id, | ||||||
next_obs=next_obs, | ||||||
next_group_obs=next_group_obs, | ||||||
behavior_id=self.behavior_id, | ||||||
) | ||||||
for traj_queue in self.trajectory_queues: | ||||||
traj_queue.put(trajectory) | ||||||
self.experience_buffers[global_id] = [] | ||||||
self.experience_buffers[global_agent_id] = [] | ||||||
if terminated: | ||||||
# Record episode length. | ||||||
self.stats_reporter.add_stat( | ||||||
"Environment/Episode Length", self.episode_steps.get(global_id, 0) | ||||||
"Environment/Episode Length", | ||||||
self.episode_steps.get(global_agent_id, 0), | ||||||
) | ||||||
self._clean_agent_data(global_id) | ||||||
self._clean_agent_data(global_agent_id) | ||||||
|
||||||
def _clean_agent_data(self, global_id: str) -> None: | ||||||
""" | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a bunch of these new fields (and some of the old fields as well) should be made private.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call, made the ones that weren't used outside private