from typing import Dict, Optional, Tuple, List
from mlagents.torch_utils import torch
import numpy as np
+ import math

- from mlagents.trainers.buffer import AgentBuffer
+ from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider
@@ -26,6 +27,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        self.global_step = torch.tensor(0)
        self.bc_module: Optional[BCModule] = None
        self.create_reward_signals(trainer_settings.reward_signals)
+         self.critic_memory_dict: Dict[str, torch.Tensor] = {}
        if trainer_settings.behavioral_cloning is not None:
            self.bc_module = BCModule(
                self.policy,
@@ -35,6 +37,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
                default_num_epoch=3,
            )

+     @property
+     def critic(self):
+         raise NotImplementedError
+
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        pass

@@ -49,25 +55,132 @@ def create_reward_signals(self, reward_signal_configs):
                reward_signal, self.policy.behavior_spec, settings
            )

+     def _evaluate_by_sequence(
+         self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
+     ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
+         """
+         Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
+         intermediate memories for the critic.
+         :param tensor_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
+             observations for this trajectory.
+         :param initial_memory: The memory that precedes this trajectory. Of shape (1,1,<mem_size>), i.e.
+             what is returned as the output of a MemoryModule.
+         :return: A Tuple of the value estimates as a Dict of [name, tensor], an AgentBufferField of the initial
+             memories to be used during value function update, and the final memory at the end of the trajectory.
+         """
+         num_experiences = tensor_obs[0].shape[0]
+         all_next_memories = AgentBufferField()
+         # In the buffer, the first sequence is the padded one. So if seq_len = 3 and the
+         # trajectory is of length 10, the first sequence is [pad, pad, obs].
+         # Compute the number of elements in this padded sequence.
+         leftover = num_experiences % self.policy.sequence_length
+
+         # Compute values for the potentially truncated initial sequence
+         seq_obs = []
+
+         first_seq_len = self.policy.sequence_length
+         for _obs in tensor_obs:
+             if leftover > 0:
+                 first_seq_len = leftover
+             first_seq_obs = _obs[0:first_seq_len]
+             seq_obs.append(first_seq_obs)
+
+         # For the first sequence, the initial memory should be the one at the
+         # beginning of this trajectory.
+         for _ in range(first_seq_len):
+             all_next_memories.append(initial_memory.squeeze().detach().numpy())
+
+         init_values, _mem = self.critic.critic_pass(
+             seq_obs, initial_memory, sequence_length=first_seq_len
+         )
+         all_values = {
+             signal_name: [init_values[signal_name]]
+             for signal_name in init_values.keys()
+         }
+
+         # Evaluate the remaining sequences, carrying over _mem after each one
+         for seq_num in range(
+             1, math.ceil((num_experiences) / (self.policy.sequence_length))
+         ):
+             seq_obs = []
+             for _ in range(self.policy.sequence_length):
+                 all_next_memories.append(_mem.squeeze().detach().numpy())
+             for _obs in tensor_obs:
+                 start = seq_num * self.policy.sequence_length - (
+                     self.policy.sequence_length - leftover
+                 )
+                 end = (seq_num + 1) * self.policy.sequence_length - (
+                     self.policy.sequence_length - leftover
+                 )
+                 seq_obs.append(_obs[start:end])
+             values, _mem = self.critic.critic_pass(
+                 seq_obs, _mem, sequence_length=self.policy.sequence_length
+             )
+             for signal_name, _val in values.items():
+                 all_values[signal_name].append(_val)
+         # Create one tensor per reward signal
+         all_value_tensors = {
+             signal_name: torch.cat(value_list, dim=0)
+             for signal_name, value_list in all_values.items()
+         }
+         next_mem = _mem
+         return all_value_tensors, all_next_memories, next_mem
+
    def get_trajectory_value_estimates(
-         self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
-     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
+         self,
+         batch: AgentBuffer,
+         next_obs: List[np.ndarray],
+         done: bool,
+         agent_id: str = "",
+     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
+         """
+         Get value estimates and memories for a trajectory, in batch form.
+         :param batch: An AgentBuffer that consists of a trajectory.
+         :param next_obs: the next observation (after the trajectory). Used for bootstrapping
+             if this is not a terminal trajectory.
+         :param done: Set true if this is a terminal trajectory.
+         :param agent_id: Agent ID of the agent that this trajectory belongs to.
+         :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
+             the final value estimate as a Dict of [name, float], and optionally (if using memories)
+             an AgentBufferField of initial critic memories to be used during update.
+         """
        n_obs = len(self.policy.behavior_spec.observation_specs)
-         current_obs = ObsUtil.from_buffer(batch, n_obs)
+
+         if agent_id in self.critic_memory_dict:
+             memory = self.critic_memory_dict[agent_id]
+         else:
+             memory = (
+                 torch.zeros((1, 1, self.critic.memory_size))
+                 if self.policy.use_recurrent
+                 else None
+             )

        # Convert to tensors
-         current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
+         current_obs = [
+             ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
+         ]
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

-         memory = torch.zeros([1, 1, self.policy.m_size])
-
        next_obs = [obs.unsqueeze(0) for obs in next_obs]

-         value_estimates, next_memory = self.policy.actor_critic.critic_pass(
-             current_obs, memory, sequence_length=batch.num_experiences
-         )
+         # If we're using LSTM, we want to get all the intermediate memories.
+         all_next_memories: Optional[AgentBufferField] = None
+         if self.policy.use_recurrent:
+             (
+                 value_estimates,
+                 all_next_memories,
+                 next_memory,
+             ) = self._evaluate_by_sequence(current_obs, memory)
+         else:
+             value_estimates, next_memory = self.critic.critic_pass(
+                 current_obs, memory, sequence_length=batch.num_experiences
+             )

-         next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+         # Store the memory for the next trajectory
+         self.critic_memory_dict[agent_id] = next_memory
+
+         next_value_estimate, _ = self.critic.critic_pass(
            next_obs, next_memory, sequence_length=1
        )

@@ -79,5 +192,6 @@ def get_trajectory_value_estimates(
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0
-
-         return value_estimates, next_value_estimate
+             if agent_id in self.critic_memory_dict:
+                 self.critic_memory_dict.pop(agent_id)
+         return value_estimates, next_value_estimate, all_next_memories
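For reference, here is a minimal standalone sketch (not part of the diff) of the index arithmetic in `_evaluate_by_sequence` for the case where the trajectory length is not a multiple of `sequence_length`: the first chunk has length `leftover`, and each later chunk is full-length but shifted back by `sequence_length - leftover`. The helper name and the concrete numbers below are illustrative only.

```python
import math


def split_bounds(num_experiences: int, sequence_length: int):
    """Illustrative only: reproduce the slice bounds computed in
    _evaluate_by_sequence when leftover > 0."""
    leftover = num_experiences % sequence_length
    # The first, shorter sequence covers indices [0, leftover)
    bounds = [(0, leftover)]
    # Each later sequence is full-length, shifted back by (sequence_length - leftover)
    for seq_num in range(1, math.ceil(num_experiences / sequence_length)):
        start = seq_num * sequence_length - (sequence_length - leftover)
        end = (seq_num + 1) * sequence_length - (sequence_length - leftover)
        bounds.append((start, end))
    return bounds


# A trajectory of 10 steps with sequence_length 3 splits into chunks of length 1, 3, 3, 3
print(split_bounds(10, 3))  # [(0, 1), (1, 4), (4, 7), (7, 10)]
```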
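And a toy, self-contained illustration (not ml-agents code) of the memory-carrying pattern that `critic_memory_dict` introduces: the critic memory is kept per `agent_id` between trajectories and dropped once a terminal trajectory has been processed, so the next episode starts from a fresh memory. The class and values below are hypothetical.

```python
from typing import Dict, Optional


class MemoryCarrier:
    def __init__(self) -> None:
        # An int stands in for the critic memory tensor
        self.critic_memory_dict: Dict[str, int] = {}

    def process_trajectory(self, agent_id: str, done: bool, next_memory: int) -> int:
        # Reuse the stored memory if this agent already has one, else start from "zeros"
        initial = self.critic_memory_dict.get(agent_id, 0)
        # ... the critic would be evaluated here, producing next_memory ...
        self.critic_memory_dict[agent_id] = next_memory
        if done:
            # Terminal trajectory: the next trajectory starts from scratch
            self.critic_memory_dict.pop(agent_id, None)
        return initial


carrier = MemoryCarrier()
print(carrier.process_trajectory("agent-0", done=False, next_memory=7))  # 0 (fresh)
print(carrier.process_trajectory("agent-0", done=True, next_memory=9))   # 7 (carried over)
print(carrier.process_trajectory("agent-0", done=False, next_memory=3))  # 0 (reset after done)
```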