
Python Dataflow for Group Manager #4926


Merged: 212 commits, merged on Mar 4, 2021

Commits
7282b82
Make buffer type-agnostic
Nov 14, 2020
24e95f1
Edit types of Apped method
Nov 14, 2020
e73897e
Change comment
Nov 14, 2020
3ac6ff7
Collaborative walljump
Nov 17, 2020
59183b2
Make collab env harder
Nov 17, 2020
faf3308
Merge branch 'develop-stack-walljump' into develop-centralizedcritic
Nov 20, 2020
6b463a7
Add group ID
Nov 20, 2020
b7a8afc
Add collab obs to trajectory
Nov 20, 2020
84362c8
Merge branch 'develop-multitype-buffer' into develop-centralizedcritic
Nov 21, 2020
5fdf354
Fix bug; add critic_obs to buffer
Nov 21, 2020
d41ad79
Set group ids for some envs
Nov 21, 2020
26cc813
Merge branch 'develop-multitype-buffer' into develop-unified-obs
Nov 21, 2020
b71f2d7
Pretty broken
Nov 21, 2020
d94b700
Less broken PPO
Dec 1, 2020
5f523bf
Update SAC, fix PPO batching
Dec 1, 2020
72a165b
Fix SAC interrupted condition and typing
Dec 1, 2020
8287b97
Fix SAC interrupted again
Dec 1, 2020
f66e481
Remove erroneous file
Dec 1, 2020
e9f2ccc
Fix multiple obs
Dec 2, 2020
33e50cb
Update curiosity reward provider
Dec 2, 2020
d8db247
Update GAIL and BC
Dec 3, 2020
77ab2e9
Merge branch 'develop-unified-obs' into develop-centralizedcritic
Dec 3, 2020
fc3fbea
Multi-input network
Dec 3, 2020
45d2454
Some minor tweaks but still broken
Dec 3, 2020
f990a3a
Get next critic observations into value estimate
Dec 9, 2020
4f3132e
Temporarily disable exporting
Dec 10, 2020
be8b7ea
Use Vince's ONNX export code
Dec 11, 2020
ccff057
Cleanup
Dec 11, 2020
438ba4c
Add walljump collab YAML
Dec 14, 2020
4b8a396
Lower max height
Dec 14, 2020
6370c0c
Merge branch 'master' into develop-centralizedcritic
Dec 14, 2020
131599f
Update prefab
Dec 14, 2020
1a8ad76
Update prefab
Dec 14, 2020
e848147
Collaborative Hallway
Dec 14, 2020
788360e
Set num teammates to 2
Dec 14, 2020
c5536a6
Add config and group ids to HallwayCollab
Dec 14, 2020
32473bf
Fix bug with hallway collab
Dec 15, 2020
ae83e62
Edits to HallwayCollab
Dec 15, 2020
d72eee8
Update onnx file meta
Dec 15, 2020
62e9b45
Make the env easier
Dec 15, 2020
1ebacc1
Remove prints
Dec 15, 2020
cb57bf0
Make Collab env harder
Dec 17, 2020
95b3522
Fix group ID
Dec 18, 2020
afd7476
Add cc to ghost trainer
Dec 18, 2020
292b6ce
Add comment to ghost trainer
Dec 18, 2020
112a9dc
Revert "Add comment to ghost trainer"
Dec 18, 2020
783db4c
Actually add comment to ghosttrainer
Dec 18, 2020
6c4ba1e
Scale size of CC network
Dec 21, 2020
d314478
Scale value network based on num agents
Dec 21, 2020
c7adb93
Add 3rd symbol to hallway collab
Dec 21, 2020
d2e315d
Make comms one-hot
Dec 21, 2020
5cf76e3
Fix S tag
Dec 23, 2020
8708f70
Merge branch 'master' into develop-centralizedcritic-mm
Jan 4, 2021
44fb8b5
Additional changes
Jan 4, 2021
56f9dbf
Some more fixes
Jan 4, 2021
a468075
Self-attention Centralized Critic
Jan 6, 2021
db184d9
separate entity encoder and RSA
andrewcoh Jan 11, 2021
32cbdee
clean up args in mha
andrewcoh Jan 11, 2021
c90472c
more cleanups
andrewcoh Jan 11, 2021
d429b53
fixed tests
andrewcoh Jan 11, 2021
44093f2
Merge branch 'develop-attention-refactor' into develop-centralizedcri…
Jan 11, 2021
1dc0059
Merge branch 'develop-attention-refactor' into develop-centralizedcri…
Jan 11, 2021
2b5b994
entity embeddings work with no max
Jan 11, 2021
cd84fe3
remove group id
Jan 11, 2021
eed2fce
very rough sketch for TeamManager interface
Jan 8, 2021
fe41094
One layer for entity embed
Jan 12, 2021
3822b18
Use 4 heads
Jan 12, 2021
3f4b2b5
add defaults to linear encoder, initialize ent encoders
andrewcoh Jan 12, 2021
c7c7d4c
Merge branch 'master' into develop-centralizedcritic-mm
Jan 12, 2021
f391b35
Merge branch 'develop-lin-enc-def' into develop-centralizedcritic-mm
Jan 12, 2021
f706a91
add team manager id to proto
Jan 12, 2021
cee5466
team manager for hallway
Jan 12, 2021
195978c
add manager to hallway
Jan 12, 2021
10f336e
send and process team manager id
Jan 12, 2021
f0bf657
remove print
Jan 12, 2021
e03c79e
Merge branch 'develop-centralizedcritic-mm' into develop-cc-teammanager
Jan 12, 2021
1118089
small cleanup
Jan 13, 2021
13a90b1
default behavior for baseTeamManager
Jan 13, 2021
36d1b5b
add back statsrecorder
Jan 13, 2021
376d500
update
Jan 13, 2021
dd8b5fb
Team manager prototype (#4850)
Jan 13, 2021
8673820
Remove statsrecorder
Jan 13, 2021
fb86a57
Fix AgentProcessor for TeamManager
Jan 13, 2021
1beea7d
Merge branch 'develop-centralizedcritic-mm' into develop-cc-teammanager
Jan 13, 2021
9e69790
team manager
Jan 13, 2021
3c2b9d1
New buffer layout, TeamObsUtil, pad dead agents
Jan 14, 2021
b4b9d72
Use NaNs to get masks for attention
Jan 14, 2021
7d5f3e3
Add team reward to buffer
Jan 15, 2021
b7c5533
Try subtract marginalized value
Jan 15, 2021
53e1277
Add Q function with attention
Jan 20, 2021
2134004
Some more progress - still broken
Jan 20, 2021
60c6071
use singular entity embedding (#4873)
andrewcoh Jan 20, 2021
47cfae4
I think it's running
Jan 20, 2021
d31da21
Actions added but untested
Jan 21, 2021
541d062
Fix issue with team_actions
Jan 22, 2021
d3c4372
Add next action and next team obs
Jan 22, 2021
3407478
separate forward into q_net and baseline
andrewcoh Jan 22, 2021
f84ca50
Merge branch 'develop-centralizedcritic-counterfact' into develop-coma2
andrewcoh Jan 22, 2021
287c1b9
might be right
andrewcoh Jan 22, 2021
f73ef80
forcing this to work
andrewcoh Jan 22, 2021
10a416a
buffer error
andrewcoh Jan 22, 2021
e716199
COMAA runs
andrewcoh Jan 23, 2021
45349b8
add lambda return and target network
andrewcoh Jan 23, 2021
9a6474e
no target net
andrewcoh Jan 24, 2021
04d9617
remove normalize advantages
andrewcoh Jan 24, 2021
5bbb222
add target network back
andrewcoh Jan 24, 2021
2868694
value estimator
andrewcoh Jan 24, 2021
c9b4e71
update coma config
andrewcoh Jan 24, 2021
a10caaf
add target net
andrewcoh Jan 24, 2021
44c616d
no target, increase lambda
andrewcoh Jan 24, 2021
ef01af4
remove prints
andrewcoh Jan 24, 2021
f329e1d
cloud config
andrewcoh Jan 24, 2021
fbd1749
use v return
andrewcoh Jan 25, 2021
908b1df
use target net
andrewcoh Jan 25, 2021
d4073ce
adding zombie to coma2 brnch
andrewcoh Jan 25, 2021
7d8f2b5
add callbacks
andrewcoh Jan 25, 2021
9452239
cloud run with coma2 of held out zombie test env
andrewcoh Jan 25, 2021
39adec6
target of baseline is returns_v
andrewcoh Jan 26, 2021
14bb6fd
remove target update
andrewcoh Jan 26, 2021
7cb5dbc
Add team dones
Jan 26, 2021
761a206
ntegrate teammate dones
andrewcoh Jan 26, 2021
3afae60
add value clipping
andrewcoh Jan 26, 2021
f0dfada
try again on cloud
andrewcoh Jan 26, 2021
c3d8d8e
clipping values and updated zombie
andrewcoh Jan 27, 2021
c3d84c5
update configs
andrewcoh Jan 27, 2021
f5419aa
remove value head clipping
andrewcoh Jan 27, 2021
d7a2386
update zombie config
andrewcoh Jan 27, 2021
cdc6dde
Add trust region to COMA updates
Jan 29, 2021
4f35048
Remove Q-net for perf
Jan 29, 2021
05c8ea1
Weight decay, regularizaton loss
Jan 29, 2021
a7f2fc2
Use same network
Jan 29, 2021
6d2be2c
add base team manager
Feb 1, 2021
b812da4
Remove reg loss, still stable
Feb 4, 2021
0c3dbff
Black format
Feb 4, 2021
09590ad
add team reward field to agent and proto
Feb 5, 2021
c982c06
set team reward
Feb 5, 2021
7e3d976
add maxstep to teammanager and hook to academy
Feb 5, 2021
c40fec0
check agent by agent.enabled
Feb 8, 2021
ffb3f0b
remove manager from academy when dispose
Feb 9, 2021
f87cfbd
move manager
Feb 9, 2021
8b8e916
put team reward in decision steps
Feb 9, 2021
6b71f5a
use 0 as default manager id
Feb 9, 2021
87e97dd
fix setTeamReward
Feb 9, 2021
d3d1dc1
change method name to GetRegisteredAgents
Feb 9, 2021
2ba09ca
address comments
Feb 9, 2021
5587e48
Merge branch 'develop-base-teammanager' into develop-agentprocessor-t…
Feb 9, 2021
7e51ad1
Merge branch 'develop-base-teammanager' into develop-agentprocessor-t…
Feb 9, 2021
f25b171
Revert C# env changes
Feb 9, 2021
128b09b
Remove a bunch of stuff from envs
Feb 9, 2021
4690c4e
Remove a bunch of extra files
Feb 9, 2021
dbdd045
Remove changes from base-teammanager
Feb 9, 2021
30c846f
Remove remaining files
Feb 9, 2021
dd7f867
Remove some unneeded changes
Feb 9, 2021
f36f696
Make buffer typing neater
Feb 9, 2021
a1b7e75
AgentProcessor fixes
Feb 9, 2021
236f398
Back out trainer changes
Feb 9, 2021
a22c621
use delegate to avoid agent-manager cyclic reference
Feb 9, 2021
2dc90a9
put team reward in decision steps
Feb 9, 2021
70207a3
fix unregister agents
Feb 10, 2021
49282f6
add teamreward to decision step
Feb 10, 2021
204b45b
typo
Feb 10, 2021
7eacfba
unregister on disabled
Feb 10, 2021
016ffd8
remove OnTeamEpisodeBegin
Feb 10, 2021
8b9d662
change name TeamManager to MultiAgentGroup
Feb 11, 2021
3fb14b9
more team -> group
Feb 11, 2021
4e4ecad
fix tests
Feb 11, 2021
492fd17
fix tests
Feb 11, 2021
7292672
Merge remote-tracking branch 'origin/develop-base-teammanager' into d…
Feb 11, 2021
78e052b
Use attention tests from master
Feb 11, 2021
81d8389
Revert "Use attention tests from master"
Feb 11, 2021
39f92c3
Use attention from master
Feb 11, 2021
1d500d6
Renaming fest
Feb 11, 2021
6418e05
Use NamedTuples instead of attrs classes
Feb 11, 2021
6da8dd3
Bug fixes
Feb 11, 2021
ad4a821
remove GroupMaxStep
Feb 12, 2021
9725aa5
add some doc
Feb 12, 2021
f5190fe
Fix mock brain
Feb 12, 2021
664ae89
np float32 fixes
Feb 12, 2021
8f696f4
more renaming
Feb 12, 2021
77557ca
Test for team obs in agentprocessor
Feb 12, 2021
6464cb6
Test for group and add team reward
Feb 12, 2021
cbfdfb3
doc improve
Feb 12, 2021
6badfb5
Merge branch 'master' into develop-base-teammanager
Feb 13, 2021
ef67f53
Merge branch 'master' into develop-base-teammanager
Feb 13, 2021
8e78dbd
Merge branch 'develop-base-teammanager' of https://github.com/Unity-T…
Feb 13, 2021
31ee1c4
store registered agents in set
Feb 16, 2021
1e4c837
remove unused step counts
Feb 17, 2021
cba26b2
Merge branch 'develop-base-teammanager' into develop-agentprocessor-t…
Feb 17, 2021
2113a43
Global group ids
Feb 17, 2021
0e28c07
Fix Trajectory test
Feb 19, 2021
6936004
Merge branch 'master' into develop-agentprocessor-teammanager
Feb 23, 2021
97d1b80
Remove duplicated files
Feb 23, 2021
9a00053
Add team methods to AgentAction
Feb 23, 2021
f879b61
Buffer fixes
Feb 23, 2021
6d7a604
Add test for GroupObs
Feb 24, 2021
587e3da
Change AgentAction back to 0 pad and add tests
Feb 24, 2021
fd4aa53
Addressed some comments
Feb 24, 2021
8dbea77
Address some comments
Feb 25, 2021
ec9e5ad
Add more comments
Feb 25, 2021
e1f48db
Rename internal function
Feb 25, 2021
d42896a
Move padding method to AgentBufferField
Feb 25, 2021
b3f2689
Merge branch 'main' into develop-agentprocessor-teammanager
Feb 25, 2021
b2100c1
Fix slicing typing and string printing in AgentBufferField
Feb 25, 2021
7b1f805
Fix to-flat and add tests
Feb 25, 2021
1ce50ef
Rename GroupmateStatus to AgentStatus
Mar 1, 2021
aea066d
Update comments
Mar 4, 2021
0860098
Added GroupId, GlobalGroupId, GlobalAgentId types
Mar 4, 2021
580683c
Update comment
Mar 4, 2021
5e73905
Make some agent processor properties internal
Mar 4, 2021
f1c8d45
Rename add_group_status
Mar 4, 2021
af0d353
Rename store_group_status, fix test
Mar 4, 2021
f3ef9ef
Rename clear_group_obs
Mar 4, 2021
Files changed
3 changes: 2 additions & 1 deletion ml-agents-envs/mlagents_envs/base_env.py
@@ -34,6 +34,7 @@
from mlagents_envs.exception import UnityActionException

AgentId = int
GroupId = int
BehaviorName = str


@@ -172,7 +173,7 @@ class TerminalStep(NamedTuple):
reward: float
interrupted: bool
agent_id: AgentId
group_id: int
group_id: GroupId
group_reward: float


211 changes: 159 additions & 52 deletions ml-agents/mlagents/trainers/agent_processor.py

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions ml-agents/mlagents/trainers/behavior_id_utils.py
@@ -1,5 +1,9 @@
from typing import NamedTuple
from urllib.parse import urlparse, parse_qs
from mlagents_envs.base_env import AgentId, GroupId

GlobalGroupId = str
GlobalAgentId = str


class BehaviorIdentifiers(NamedTuple):
@@ -46,8 +50,15 @@ def create_name_behavior_id(name: str, team_id: int) -> str:
return name + "?team=" + str(team_id)


def get_global_agent_id(worker_id: int, agent_id: int) -> str:
def get_global_agent_id(worker_id: int, agent_id: AgentId) -> GlobalAgentId:
"""
Create an agent id that is unique across environment workers using the worker_id.
"""
return f"${worker_id}-{agent_id}"
return f"agent_{worker_id}-{agent_id}"


def get_global_group_id(worker_id: int, group_id: GroupId) -> GlobalGroupId:
"""
Create a group id that is unique across environment workers when using the worker_id.
"""
return f"group_{worker_id}-{group_id}"
106 changes: 77 additions & 29 deletions ml-agents/mlagents/trainers/buffer.py
@@ -9,6 +9,10 @@

from mlagents_envs.exception import UnityException

# Elements in the buffer can be np.ndarray, or in the case of teammate obs, actions, rewards,
# a List of np.ndarray. This is done so that we don't have duplicated np.ndarrays, only references.
BufferEntry = Union[np.ndarray, List[np.ndarray]]


class BufferException(UnityException):
"""
@@ -21,8 +25,10 @@ class BufferException(UnityException):
class BufferKey(enum.Enum):
ACTION_MASK = "action_mask"
CONTINUOUS_ACTION = "continuous_action"
NEXT_CONT_ACTION = "next_continuous_action"
CONTINUOUS_LOG_PROBS = "continuous_log_probs"
DISCRETE_ACTION = "discrete_action"
NEXT_DISC_ACTION = "next_discrete_action"
DISCRETE_LOG_PROBS = "discrete_log_probs"
DONE = "done"
ENVIRONMENT_REWARDS = "environment_rewards"
@@ -34,11 +40,22 @@ class BufferKey(enum.Enum):
ADVANTAGES = "advantages"
DISCOUNTED_RETURNS = "discounted_returns"

GROUP_DONES = "group_dones"
GROUPMATE_REWARDS = "groupmate_reward"
GROUP_REWARD = "group_reward"
GROUP_CONTINUOUS_ACTION = "group_continuous_action"
GROUP_DISCRETE_ACTION = "group_discrete_aaction"
GROUP_NEXT_CONT_ACTION = "group_next_cont_action"
GROUP_NEXT_DISC_ACTION = "group_next_disc_action"


class ObservationKeyPrefix(enum.Enum):
OBSERVATION = "obs"
NEXT_OBSERVATION = "next_obs"

GROUP_OBSERVATION = "group_obs"
NEXT_GROUP_OBSERVATION = "next_group_obs"


class RewardSignalKeyPrefix(enum.Enum):
# Reward signals
@@ -73,16 +90,23 @@ def advantage_key(name: str) -> AgentBufferKey:

class AgentBufferField(list):
"""
AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its
AgentBufferField with the append method.
AgentBufferField is a list of numpy arrays, or List[np.ndarray] for group entries.
When an agent collects a field, you can add it to its AgentBufferField with the append method.
"""

def __init__(self):
def __init__(self, *args, **kwargs):
self.padding_value = 0
super().__init__()
super().__init__(*args, **kwargs)

def __str__(self):
return str(np.array(self).shape)
def __str__(self) -> str:
return f"AgentBufferField: {super().__str__()}"

def __getitem__(self, index):
return_data = super().__getitem__(index)
if isinstance(return_data, list):
return AgentBufferField(return_data)
else:
return return_data

def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
"""
@@ -95,31 +119,20 @@ def append(self, element: np.ndarray, padding_value: float = 0.0) -> None:
super().append(element)
self.padding_value = padding_value

def extend(self, data: np.ndarray) -> None:
"""
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
"""
self += list(np.array(data, dtype=np.float32))

def set(self, data):
def set(self, data: List[BufferEntry]) -> None:
"""
Sets the list of np.array to the input data
:param data: The np.array list to be set.
Sets the list of BufferEntry to the input data
:param data: The BufferEntry list to be set.
"""
# Make sure we convert incoming data to float32 if it's a float
dtype = None
if data is not None and len(data) and isinstance(data[0], float):
dtype = np.float32
self[:] = []
self[:] = list(np.array(data, dtype=dtype))
self[:] = data

def get_batch(
self,
batch_size: int = None,
training_length: Optional[int] = 1,
sequential: bool = True,
) -> np.ndarray:
) -> List[BufferEntry]:
"""
Retrieve the last batch_size elements of length training_length
from the list of np.array
@@ -150,13 +163,10 @@ def get_batch(
)
if batch_size * training_length > len(self):
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:], dtype=np.float32
)
return [padding] * (training_length - leftover) + self[:]

else:
return np.array(
self[len(self) - batch_size * training_length :], dtype=np.float32
)
return self[len(self) - batch_size * training_length :]
else:
# The sequences will have overlapping elements
if batch_size is None:
@@ -172,14 +182,52 @@
tmp_list: List[np.ndarray] = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list, dtype=np.float32)
return tmp_list

def reset_field(self) -> None:
"""
Resets the AgentBufferField
"""
self[:] = []

def padded_to_batch(
self, pad_value: np.float = 0, dtype: np.dtype = np.float32
) -> Union[np.ndarray, List[np.ndarray]]:
"""
Converts this AgentBufferField (which is a List[BufferEntry]) into a numpy array
with first dimension equal to the length of this AgentBufferField. If this AgentBufferField
contains a List[List[BufferEntry]] (i.e., in the case of group observations), return a List
containing numpy arrays or tensors, of length equal to the maximum length of an entry.
For entries with less than that length, the array will be padded with pad_value.
:param pad_value: Value to pad List AgentBufferFields, when there are less than the maximum
number of agents present.
:param dtype: Dtype of output numpy array.
:return: Numpy array or List of numpy arrays representing this AgentBufferField, where the first
dimension is equal to the length of the AgentBufferField.
"""
if len(self) > 0 and not isinstance(self[0], list):
return np.asanyarray(self, dtype=dtype)

shape = None
for _entry in self:
# _entry could be an empty list if there are no group agents in this
# step. Find the first non-empty list and use that shape.
if _entry:
shape = _entry[0].shape
break
# If there were no groupmate agents in the entire batch, return an empty List.
if shape is None:
return []

# Convert to numpy array while padding with 0's
new_list = list(
map(
lambda x: np.asanyarray(x, dtype=dtype),
itertools.zip_longest(*self, fillvalue=np.full(shape, pad_value)),
)
)
return new_list


class AgentBuffer(MutableMapping):
"""
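
To make the padding behaviour of padded_to_batch concrete, here is a small self-contained sketch of the same itertools.zip_longest logic with made-up group observations (two groupmates at step 0, one at step 1, none at step 2):

    import itertools
    from typing import List

    import numpy as np

    # Each buffer entry is the per-step list of one value per groupmate.
    field: List[List[np.ndarray]] = [
        [np.array([1.0, 1.0]), np.array([2.0, 2.0])],
        [np.array([3.0, 3.0])],
        [],
    ]

    shape = next(entry[0].shape for entry in field if entry)  # shape from first non-empty entry
    padded = [
        np.asanyarray(column, dtype=np.float32)
        for column in itertools.zip_longest(*field, fillvalue=np.full(shape, 0.0))
    ]
    # padded[0] stacks the first groupmate across all steps, padded[1] the second;
    # missing groupmates are filled with zero arrays, so both have shape (3, 2).

This is why the method returns a List of arrays whose length equals the maximum number of groupmates seen in the batch, rather than a single ragged ndarray.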
15 changes: 8 additions & 7 deletions ml-agents/mlagents/trainers/policy/policy.py
@@ -8,6 +8,7 @@
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.behavior_id_utils import GlobalAgentId


class UnityPolicyException(UnityException):
@@ -68,7 +69,7 @@ def make_empty_memory(self, num_agents):
return np.zeros((num_agents, self.m_size), dtype=np.float32)

def save_memories(
self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
self, agent_ids: List[GlobalAgentId], memory_matrix: Optional[np.ndarray]
) -> None:
if memory_matrix is None:
return
@@ -81,21 +82,21 @@ def save_memories(
for index, agent_id in enumerate(agent_ids):
self.memory_dict[agent_id] = memory_matrix[index, :]

def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
def retrieve_memories(self, agent_ids: List[GlobalAgentId]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.memory_dict:
memory_matrix[index, :] = self.memory_dict[agent_id]
return memory_matrix

def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
def retrieve_previous_memories(self, agent_ids: List[GlobalAgentId]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_memory_dict:
memory_matrix[index, :] = self.previous_memory_dict[agent_id]
return memory_matrix

def remove_memories(self, agent_ids):
def remove_memories(self, agent_ids: List[GlobalAgentId]) -> None:
for agent_id in agent_ids:
if agent_id in self.memory_dict:
self.memory_dict.pop(agent_id)
@@ -113,19 +114,19 @@ def make_empty_previous_action(self, num_agents: int) -> np.ndarray:
)

def save_previous_action(
self, agent_ids: List[str], action_tuple: ActionTuple
self, agent_ids: List[GlobalAgentId], action_tuple: ActionTuple
) -> None:
for index, agent_id in enumerate(agent_ids):
self.previous_action_dict[agent_id] = action_tuple.discrete[index, :]

def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
def retrieve_previous_action(self, agent_ids: List[GlobalAgentId]) -> np.ndarray:
action_matrix = self.make_empty_previous_action(len(agent_ids))
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_action_dict:
action_matrix[index, :] = self.previous_action_dict[agent_id]
return action_matrix

def remove_previous_action(self, agent_ids):
def remove_previous_action(self, agent_ids: List[GlobalAgentId]) -> None:
for agent_id in agent_ids:
if agent_id in self.previous_action_dict:
self.previous_action_dict.pop(agent_id)
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/ppo/trainer.py
@@ -180,7 +180,7 @@ def _update_policy(self):
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)

advantages = self.update_buffer[BufferKey.ADVANTAGES].get_batch()
advantages = np.array(self.update_buffer[BufferKey.ADVANTAGES].get_batch())
self.update_buffer[BufferKey.ADVANTAGES].set(
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
)
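
A minimal numpy sketch of the normalization step above, reflecting that get_batch() now returns a plain Python list rather than an np.ndarray (the advantage values are made up):

    import numpy as np

    raw_advantages = [1.0, -0.5, 2.0, 0.25]    # as returned by get_batch()
    advantages = np.array(raw_advantages)       # hence the explicit np.array(...) in the trainer
    normalized = (advantages - advantages.mean()) / (advantages.std() + 1e-10)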
20 changes: 17 additions & 3 deletions ml-agents/mlagents/trainers/tests/mock_brain.py
@@ -3,7 +3,7 @@

from mlagents.trainers.buffer import AgentBuffer, AgentBufferKey
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience
from mlagents_envs.base_env import (
DecisionSteps,
TerminalSteps,
@@ -20,6 +20,7 @@ def create_mock_steps(
observation_specs: List[ObservationSpec],
action_spec: ActionSpec,
done: bool = False,
grouped: bool = False,
) -> Tuple[DecisionSteps, TerminalSteps]:
"""
Creates a mock Tuple[DecisionSteps, TerminalSteps] with observations.
@@ -43,7 +44,8 @@
reward = np.array(num_agents * [1.0], dtype=np.float32)
interrupted = np.array(num_agents * [False], dtype=np.bool)
agent_id = np.arange(num_agents, dtype=np.int32)
group_id = np.array(num_agents * [0], dtype=np.int32)
_gid = 1 if grouped else 0
group_id = np.array(num_agents * [_gid], dtype=np.int32)
group_reward = np.array(num_agents * [0.0], dtype=np.float32)
behavior_spec = BehaviorSpec(observation_specs, action_spec)
if done:
@@ -78,6 +80,7 @@ def make_fake_trajectory(
action_spec: ActionSpec,
max_step_complete: bool = False,
memory_size: int = 10,
num_other_agents_in_group: int = 0,
) -> Trajectory:
"""
Makes a fake trajectory of length length. If max_step_complete,
@@ -117,6 +120,9 @@
memory = np.ones(memory_size, dtype=np.float32)
agent_id = "test_agent"
behavior_id = "test_brain"
group_status = []
for _ in range(num_other_agents_in_group):
group_status.append(AgentStatus(obs, reward, action, done))
experience = AgentExperience(
obs=obs,
reward=reward,
Expand All @@ -127,6 +133,8 @@ def make_fake_trajectory(
prev_action=prev_action,
interrupted=max_step,
memory=memory,
group_status=group_status,
group_reward=0,
)
steps_list.append(experience)
obs = []
@@ -142,10 +150,16 @@
prev_action=prev_action,
interrupted=max_step_complete,
memory=memory,
group_status=group_status,
group_reward=0,
)
steps_list.append(last_experience)
return Trajectory(
steps=steps_list, agent_id=agent_id, behavior_id=behavior_id, next_obs=obs
steps=steps_list,
agent_id=agent_id,
behavior_id=behavior_id,
next_obs=obs,
next_group_obs=[obs] * num_other_agents_in_group,
)


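
As a rough illustration of the group_status list assembled above, here is a local stand-in that mirrors the positional fields used in the AgentStatus(...) call (obs, reward, action, done); it is a hypothetical sketch, not the library class:

    from typing import List, NamedTuple

    import numpy as np

    class AgentStatusSketch(NamedTuple):
        obs: List[np.ndarray]
        reward: float
        action: np.ndarray
        done: bool

    obs = [np.zeros(4, dtype=np.float32)]
    # One entry per groupmate, e.g. with num_other_agents_in_group=2.
    group_status = [
        AgentStatusSketch(obs=obs, reward=1.0, action=np.zeros(2, dtype=np.float32), done=False)
        for _ in range(2)
    ]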