
Commit d0d9d14

andrewcoh and Ervin Teng authored
Move the Critic into the Optimizer (#4939)
Co-authored-by: Ervin Teng <[email protected]>
1 parent 2e50682 commit d0d9d14

20 files changed: +421 -331 lines changed

ml-agents/mlagents/trainers/action_info.py

Lines changed: 10 additions & 2 deletions

@@ -6,12 +6,20 @@


 class ActionInfo(NamedTuple):
+    """
+    A NamedTuple containing actions and related quantities to the policy forward
+    pass. Additionally contains the agent ids in the corresponding DecisionStep
+    :param action: The action output of the policy
+    :param env_action: The possibly clipped action to be executed in the environment
+    :param outputs: Dict of all quantities associated with the policy forward pass
+    :param agent_ids: List of int agent ids in DecisionStep
+    """
     action: Any
     env_action: Any
-    value: Any
     outputs: ActionInfoOutputs
     agent_ids: List[AgentId]

     @staticmethod
     def empty() -> "ActionInfo":
-        return ActionInfo([], [], [], {}, [])
+        return ActionInfo([], [], {}, [])
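Since the critic no longer lives on the policy, the per-step value estimate is dropped from ActionInfo and the tuple shrinks to four fields. A minimal sketch of the new shape (the values below are illustrative, not from the commit):

import numpy as np
from mlagents.trainers.action_info import ActionInfo

action = np.array([[0.1, -0.2]])            # hypothetical policy output
info = ActionInfo(
    action=action,
    env_action=np.clip(action, -1, 1),      # possibly clipped action sent to the environment
    outputs={"action": action},             # everything from the policy forward pass
    agent_ids=[0],
)
assert ActionInfo.empty() == ActionInfo([], [], {}, [])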

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ def _process_step(
         if stored_decision_step is not None and stored_take_action_outputs is not None:
             obs = stored_decision_step.obs
             if self.policy.use_recurrent:
-                memory = self.policy.retrieve_memories([global_id])[0, :]
+                memory = self.policy.retrieve_previous_memories([global_id])[0, :]
             else:
                 memory = None
             done = terminated  # Since this is an ongoing step

ml-agents/mlagents/trainers/buffer.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ class BufferKey(enum.Enum):
     ENVIRONMENT_REWARDS = "environment_rewards"
     MASKS = "masks"
     MEMORY = "memory"
+    CRITIC_MEMORY = "critic_memory"
     PREV_ACTION = "prev_action"

     ADVANTAGES = "advantages"
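The new CRITIC_MEMORY key gives trainers a slot for the critic's recurrent state alongside each experience, separate from the actor's MEMORY. A rough sketch of filling it, assuming AgentBuffer still auto-creates a field on first access and AgentBufferField.append takes one entry per step:

import numpy as np
from mlagents.trainers.buffer import AgentBuffer, BufferKey

buffer = AgentBuffer()
mem_size = 128  # illustrative critic memory width
for _ in range(3):
    # one critic memory vector per experience in the trajectory
    buffer[BufferKey.CRITIC_MEMORY].append(np.zeros(mem_size, dtype=np.float32))

assert len(buffer[BufferKey.CRITIC_MEMORY]) == 3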

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 127 additions & 13 deletions

@@ -1,8 +1,9 @@
 from typing import Dict, Optional, Tuple, List
 from mlagents.torch_utils import torch
 import numpy as np
+import math

-from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
 from mlagents.trainers.trajectory import ObsUtil
 from mlagents.trainers.torch.components.bc.module import BCModule
 from mlagents.trainers.torch.components.reward_providers import create_reward_provider
@@ -26,6 +27,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
         self.global_step = torch.tensor(0)
         self.bc_module: Optional[BCModule] = None
         self.create_reward_signals(trainer_settings.reward_signals)
+        self.critic_memory_dict: Dict[str, torch.Tensor] = {}
         if trainer_settings.behavioral_cloning is not None:
             self.bc_module = BCModule(
                 self.policy,
@@ -35,6 +37,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
                 default_num_epoch=3,
             )

+    @property
+    def critic(self):
+        raise NotImplementedError
+
     def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         pass

@@ -49,25 +55,132 @@ def create_reward_signals(self, reward_signal_configs):
                 reward_signal, self.policy.behavior_spec, settings
             )

+    def _evaluate_by_sequence(
+        self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
+    ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
+        """
+        Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
+        intermediate memories for the critic.
+        :param tensor_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
+            observations for this trajectory.
+        :param initial_memory: The memory that precedes this trajectory. Of shape (1,1,<mem_size>), i.e.
+            what is returned as the output of a MemoryModule.
+        :return: A Tuple of the value estimates as a Dict of [name, tensor], an AgentBufferField of the initial
+            memories to be used during value function update, and the final memory at the end of the trajectory.
+        """
+        num_experiences = tensor_obs[0].shape[0]
+        all_next_memories = AgentBufferField()
+        # In the buffer, the 1st sequence are the ones that are padded. So if seq_len = 3 and
+        # trajectory is of length 10, the 1st sequence is [pad,pad,obs].
+        # Compute the number of elements in this padded seq.
+        leftover = num_experiences % self.policy.sequence_length
+
+        # Compute values for the potentially truncated initial sequence
+        seq_obs = []
+
+        first_seq_len = self.policy.sequence_length
+        for _obs in tensor_obs:
+            if leftover > 0:
+                first_seq_len = leftover
+            first_seq_obs = _obs[0:first_seq_len]
+            seq_obs.append(first_seq_obs)
+
+        # For the first sequence, the initial memory should be the one at the
+        # beginning of this trajectory.
+        for _ in range(first_seq_len):
+            all_next_memories.append(initial_memory.squeeze().detach().numpy())
+
+        init_values, _mem = self.critic.critic_pass(
+            seq_obs, initial_memory, sequence_length=first_seq_len
+        )
+        all_values = {
+            signal_name: [init_values[signal_name]]
+            for signal_name in init_values.keys()
+        }
+
+        # Evaluate the remaining sequences, carrying over _mem after each
+        # sequence
+        for seq_num in range(
+            1, math.ceil((num_experiences) / (self.policy.sequence_length))
+        ):
+            seq_obs = []
+            for _ in range(self.policy.sequence_length):
+                all_next_memories.append(_mem.squeeze().detach().numpy())
+            for _obs in tensor_obs:
+                start = seq_num * self.policy.sequence_length - (
+                    self.policy.sequence_length - leftover
+                )
+                end = (seq_num + 1) * self.policy.sequence_length - (
+                    self.policy.sequence_length - leftover
+                )
+                seq_obs.append(_obs[start:end])
+            values, _mem = self.critic.critic_pass(
+                seq_obs, _mem, sequence_length=self.policy.sequence_length
+            )
+            for signal_name, _val in values.items():
+                all_values[signal_name].append(_val)
+        # Create one tensor per reward signal
+        all_value_tensors = {
+            signal_name: torch.cat(value_list, dim=0)
+            for signal_name, value_list in all_values.items()
+        }
+        next_mem = _mem
+        return all_value_tensors, all_next_memories, next_mem
+
     def get_trajectory_value_estimates(
-        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
-    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
+        self,
+        batch: AgentBuffer,
+        next_obs: List[np.ndarray],
+        done: bool,
+        agent_id: str = "",
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
+        """
+        Get value estimates and memories for a trajectory, in batch form.
+        :param batch: An AgentBuffer that consists of a trajectory.
+        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
+            if this is not a terminal trajectory.
+        :param done: Set true if this is a terminal trajectory.
+        :param agent_id: Agent ID of the agent that this trajectory belongs to.
+        :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
+            the final value estimate as a Dict of [name, float], and optionally (if using memories)
+            an AgentBufferField of initial critic memories to be used during update.
+        """
         n_obs = len(self.policy.behavior_spec.observation_specs)
-        current_obs = ObsUtil.from_buffer(batch, n_obs)
+
+        if agent_id in self.critic_memory_dict:
+            memory = self.critic_memory_dict[agent_id]
+        else:
+            memory = (
+                torch.zeros((1, 1, self.critic.memory_size))
+                if self.policy.use_recurrent
+                else None
+            )

         # Convert to tensors
-        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
+        current_obs = [
+            ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
+        ]
         next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

-        memory = torch.zeros([1, 1, self.policy.m_size])
-
         next_obs = [obs.unsqueeze(0) for obs in next_obs]

-        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
-            current_obs, memory, sequence_length=batch.num_experiences
-        )
+        # If we're using LSTM, we want to get all the intermediate memories.
+        all_next_memories: Optional[AgentBufferField] = None
+        if self.policy.use_recurrent:
+            (
+                value_estimates,
+                all_next_memories,
+                next_memory,
+            ) = self._evaluate_by_sequence(current_obs, memory)
+        else:
+            value_estimates, next_memory = self.critic.critic_pass(
+                current_obs, memory, sequence_length=batch.num_experiences
+            )

-        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+        # Store the memory for the next trajectory
+        self.critic_memory_dict[agent_id] = next_memory
+
+        next_value_estimate, _ = self.critic.critic_pass(
             next_obs, next_memory, sequence_length=1
         )

@@ -79,5 +192,6 @@ def get_trajectory_value_estimates(
             for k in next_value_estimate:
                 if not self.reward_signals[k].ignore_done:
                     next_value_estimate[k] = 0.0
-
-        return value_estimates, next_value_estimate
+            if agent_id in self.critic_memory_dict:
+                self.critic_memory_dict.pop(agent_id)
+        return value_estimates, next_value_estimate, all_next_memories
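The slice arithmetic in _evaluate_by_sequence is the subtle part: the first sequence in the buffer is the padded one, so it contributes only leftover = trajectory_len % sequence_length real steps, and every later sequence is shifted back by sequence_length - leftover. The standalone sketch below (not ML-Agents code) reproduces that arithmetic for the case where the trajectory length is not an exact multiple of the sequence length; the per-step memories recorded along the way are what a trainer can store under the new CRITIC_MEMORY buffer key.

import math

def sequence_bounds(num_experiences: int, sequence_length: int):
    # Mirrors the slicing in _evaluate_by_sequence for a trajectory whose
    # length is not an exact multiple of sequence_length, so the short
    # "padded" sequence comes first.
    leftover = num_experiences % sequence_length
    first_seq_len = leftover if leftover > 0 else sequence_length
    bounds = [(0, first_seq_len)]
    for seq_num in range(1, math.ceil(num_experiences / sequence_length)):
        start = seq_num * sequence_length - (sequence_length - leftover)
        end = (seq_num + 1) * sequence_length - (sequence_length - leftover)
        bounds.append((start, end))
    return bounds

# 10 steps with sequence_length=3: leftover=1, so the critic sees
# obs[0:1], obs[1:4], obs[4:7], obs[7:10], covering every step exactly once.
print(sequence_bounds(10, 3))  # [(0, 1), (1, 4), (4, 7), (7, 10)]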

ml-agents/mlagents/trainers/policy/policy.py

Lines changed: 15 additions & 0 deletions

@@ -33,6 +33,7 @@ def __init__(
         self.network_settings: NetworkSettings = trainer_settings.network_settings
         self.seed = seed
         self.previous_action_dict: Dict[str, np.ndarray] = {}
+        self.previous_memory_dict: Dict[str, np.ndarray] = {}
         self.memory_dict: Dict[str, np.ndarray] = {}
         self.normalize = trainer_settings.network_settings.normalize
         self.use_recurrent = self.network_settings.memory is not None
@@ -72,6 +73,11 @@ def save_memories(
         if memory_matrix is None:
             return

+        # Pass old memories into previous_memory_dict
+        for agent_id in agent_ids:
+            if agent_id in self.memory_dict:
+                self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]
+
         for index, agent_id in enumerate(agent_ids):
             self.memory_dict[agent_id] = memory_matrix[index, :]

@@ -82,10 +88,19 @@ def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
                 memory_matrix[index, :] = self.memory_dict[agent_id]
         return memory_matrix

+    def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
+        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
+        for index, agent_id in enumerate(agent_ids):
+            if agent_id in self.previous_memory_dict:
+                memory_matrix[index, :] = self.previous_memory_dict[agent_id]
+        return memory_matrix
+
     def remove_memories(self, agent_ids):
         for agent_id in agent_ids:
             if agent_id in self.memory_dict:
                 self.memory_dict.pop(agent_id)
+            if agent_id in self.previous_memory_dict:
+                self.previous_memory_dict.pop(agent_id)

     def make_empty_previous_action(self, num_agents: int) -> np.ndarray:
         """

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 33 additions & 29 deletions

@@ -10,11 +10,7 @@
 from mlagents_envs.timers import timed

 from mlagents.trainers.settings import TrainerSettings
-from mlagents.trainers.torch.networks import (
-    SharedActorCritic,
-    SeparateActorCritic,
-    GlobalSteps,
-)
+from mlagents.trainers.torch.networks import SimpleActor, SharedActorCritic, GlobalSteps

 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.buffer import AgentBuffer
@@ -61,31 +57,40 @@ def __init__(
         )  # could be much simpler if TorchPolicy is nn.Module
         self.grads = None

-        reward_signal_configs = trainer_settings.reward_signals
-        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
-
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",
         }
         if separate_critic:
-            ac_class = SeparateActorCritic
+            self.actor = SimpleActor(
+                observation_specs=self.behavior_spec.observation_specs,
+                network_settings=trainer_settings.network_settings,
+                action_spec=behavior_spec.action_spec,
+                conditional_sigma=self.condition_sigma_on_obs,
+                tanh_squash=tanh_squash,
+            )
+            self.shared_critic = False
         else:
-            ac_class = SharedActorCritic
-        self.actor_critic = ac_class(
-            observation_specs=self.behavior_spec.observation_specs,
-            network_settings=trainer_settings.network_settings,
-            action_spec=behavior_spec.action_spec,
-            stream_names=reward_signal_names,
-            conditional_sigma=self.condition_sigma_on_obs,
-            tanh_squash=tanh_squash,
-        )
+            reward_signal_configs = trainer_settings.reward_signals
+            reward_signal_names = [
+                key.value for key, _ in reward_signal_configs.items()
+            ]
+            self.actor = SharedActorCritic(
+                observation_specs=self.behavior_spec.observation_specs,
+                network_settings=trainer_settings.network_settings,
+                action_spec=behavior_spec.action_spec,
+                stream_names=reward_signal_names,
+                conditional_sigma=self.condition_sigma_on_obs,
+                tanh_squash=tanh_squash,
+            )
+            self.shared_critic = True
+
         # Save the m_size needed for export
         self._export_m_size = self.m_size
         # m_size needed for training is determined by network, not trainer settings
-        self.m_size = self.actor_critic.memory_size
+        self.m_size = self.actor.memory_size

-        self.actor_critic.to(default_device())
+        self.actor.to(default_device())
         self._clip_action = not tanh_squash

     @property
@@ -115,7 +120,7 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
         """

         if self.normalize:
-            self.actor_critic.update_normalization(buffer)
+            self.actor.update_normalization(buffer)

     @timed
     def sample_actions(
@@ -132,7 +137,7 @@ def sample_actions(
         :param seq_len: Sequence length when using RNN.
         :return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories.
         """
-        actions, log_probs, entropies, memories = self.actor_critic.get_action_stats(
+        actions, log_probs, entropies, memories = self.actor.get_action_and_stats(
             obs, masks, memories, seq_len
         )
         return (actions, log_probs, entropies, memories)
@@ -144,11 +149,11 @@ def evaluate_actions(
         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
-    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
-        log_probs, entropies, value_heads = self.actor_critic.get_stats_and_value(
+    ) -> Tuple[ActionLogProbs, torch.Tensor]:
+        log_probs, entropies = self.actor.get_stats(
             obs, actions, masks, memories, seq_len
         )
-        return log_probs, entropies, value_heads
+        return log_probs, entropies

     @timed
     def evaluate(
@@ -210,7 +215,6 @@ def get_action(
         return ActionInfo(
             action=run_out.get("action"),
             env_action=run_out.get("env_action"),
-            value=run_out.get("value"),
             outputs=run_out,
             agent_ids=list(decision_requests.agent_id),
         )
@@ -239,13 +243,13 @@ def increment_step(self, n_steps):
         return self.get_current_step()

     def load_weights(self, values: List[np.ndarray]) -> None:
-        self.actor_critic.load_state_dict(values)
+        self.actor.load_state_dict(values)

     def init_load_weights(self) -> None:
         pass

     def get_weights(self) -> List[np.ndarray]:
-        return copy.deepcopy(self.actor_critic.state_dict())
+        return copy.deepcopy(self.actor.state_dict())

     def get_modules(self):
-        return {"Policy": self.actor_critic, "global_step": self.global_step}
+        return {"Policy": self.actor, "global_step": self.global_step}
