
Commit 6dbba73

Author: Ervin T
[change] Remove concatenate in discrete action probabilities to improve inference performance (#3598)
1 parent e91a8cc commit 6dbba73

File tree

com.unity.ml-agents/CHANGELOG.md
ml-agents/mlagents/trainers/distributions.py
ml-agents/mlagents/trainers/models.py
ml-agents/mlagents/trainers/ppo/optimizer.py
ml-agents/mlagents/trainers/sac/optimizer.py

5 files changed: +50 −38 lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - `DecisionRequester` has been made internal (you can still use the DecisionRequesterComponent from the inspector). `RepeatAction` was renamed `TakeActionsBetweenDecisions` for clarity. (#3555)
 - The `IFloatProperties` interface has been removed.
 - Fix #3579.
+- Improved inference performance for models with multiple action branches. (#3598)
 - Fixed an issue when using GAIL with less than `batch_size` number of demonstrations. (#3591)
 - The interfaces to the `SideChannel` classes (on C# and python) have changed to use new `IncomingMessage` and `OutgoingMessage` classes. These should make reading and writing data to the channel easier. (#3596)
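For context on the performance note above: before this change, the per-branch policy logits were concatenated into a single tensor and then immediately sliced back apart by the masking layer, putting an extra concat/slice round trip into the exported inference graph. Below is a minimal plain-NumPy sketch of that round trip; the names (`branch_logits`, `flat`, `resliced`) and sizes are illustrative only, and the real code operates on TensorFlow tensors via `tf.concat` and `ModelUtils.break_into_branches`.

```python
import numpy as np

act_size = [3, 2, 4]  # three discrete branches with 3, 2, and 4 actions
branch_logits = [np.random.rand(1, size).astype(np.float32) for size in act_size]

# Old pattern: concatenate the branch outputs, then slice them apart again downstream.
flat = np.concatenate(branch_logits, axis=1)
idx = [0] + list(np.cumsum(act_size))
resliced = [flat[:, idx[i] : idx[i + 1]] for i in range(len(act_size))]

# New pattern: pass the per-branch outputs straight through; no concat/slice round trip.
assert all(np.array_equal(a, b) for a, b in zip(branch_logits, resliced))
```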

ml-agents/mlagents/trainers/distributions.py

Lines changed: 2 additions & 3 deletions

@@ -214,12 +214,11 @@ def _create_policy_branches(
                     kernel_initializer=ModelUtils.scaled_init(0.01),
                 )
             )
-        unmasked_log_probs = tf.concat(policy_branches, axis=1)
-        return unmasked_log_probs
+        return policy_branches

     def _get_masked_actions_probs(
         self,
-        unmasked_log_probs: tf.Tensor,
+        unmasked_log_probs: List[tf.Tensor],
         act_size: List[int],
         action_masks: tf.Tensor,
     ) -> Tuple[tf.Tensor, tf.Tensor, np.ndarray]:

ml-agents/mlagents/trainers/models.py

Lines changed: 25 additions & 11 deletions

@@ -456,25 +456,39 @@ def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction:
         )

     @staticmethod
-    def create_discrete_action_masking_layer(all_logits, action_masks, action_size):
+    def break_into_branches(
+        concatenated_logits: tf.Tensor, action_size: List[int]
+    ) -> List[tf.Tensor]:
+        """
+        Takes a concatenated set of logits that represent multiple discrete action branches
+        and breaks it up into one Tensor per branch.
+        :param concatenated_logits: Tensor that represents the concatenated action branches
+        :param action_size: List of ints containing the number of possible actions for each branch.
+        :return: A List of Tensors containing one tensor per branch.
+        """
+        action_idx = [0] + list(np.cumsum(action_size))
+        branched_logits = [
+            concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
+            for i in range(len(action_size))
+        ]
+        return branched_logits
+
+    @staticmethod
+    def create_discrete_action_masking_layer(
+        branches_logits: List[tf.Tensor],
+        action_masks: tf.Tensor,
+        action_size: List[int],
+    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
         """
         Creates a masking layer for the discrete actions
-        :param all_logits: The concatenated unnormalized action probabilities for all branches
+        :param branches_logits: A List of the unnormalized action probabilities for each branch
         :param action_masks: The mask for the logits. Must be of dimension [None x total_number_of_action]
         :param action_size: A list containing the number of possible actions for each branch
         :return: The action output dimension [batch_size, num_branches], the concatenated
             normalized probs (after softmax)
             and the concatenated normalized log probs
         """
-        action_idx = [0] + list(np.cumsum(action_size))
-        branches_logits = [
-            all_logits[:, action_idx[i] : action_idx[i + 1]]
-            for i in range(len(action_size))
-        ]
-        branch_masks = [
-            action_masks[:, action_idx[i] : action_idx[i + 1]]
-            for i in range(len(action_size))
-        ]
+        branch_masks = ModelUtils.break_into_branches(action_masks, action_size)
         raw_probs = [
             tf.multiply(tf.nn.softmax(branches_logits[k]) + EPSILON, branch_masks[k])
             for k in range(len(action_size))
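The slicing that `break_into_branches` centralizes is plain cumulative-offset indexing. Below is a hedged NumPy analogue (the `split_branches` name and the toy sizes are illustrative, not from the repository) showing how a concatenated `[batch, sum(action_size)]` tensor maps to one array per branch:

```python
import numpy as np
from typing import List

def split_branches(concatenated: np.ndarray, action_size: List[int]) -> List[np.ndarray]:
    # Cumulative offsets, e.g. [3, 2, 4] -> [0, 3, 5, 9]
    action_idx = [0] + list(np.cumsum(action_size))
    return [
        concatenated[:, action_idx[i] : action_idx[i + 1]]
        for i in range(len(action_size))
    ]

# A batch of 2 agents, three branches with 3, 2, and 4 actions (9 values per agent).
logits = np.arange(18, dtype=np.float32).reshape(2, 9)
print([b.shape for b in split_branches(logits, [3, 2, 4])])  # [(2, 3), (2, 2), (2, 4)]
```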

ml-agents/mlagents/trainers/ppo/optimizer.py

Lines changed: 7 additions & 1 deletion

@@ -169,8 +169,14 @@ def _create_dc_critic(
             dtype=tf.float32,
             name="old_probabilities",
         )
+
+        # Break old log probs into separate branches
+        old_log_prob_branches = ModelUtils.break_into_branches(
+            self.all_old_log_probs, self.policy.act_size
+        )
+
         _, _, old_normalized_logits = ModelUtils.create_discrete_action_masking_layer(
-            self.all_old_log_probs, self.policy.action_masks, self.policy.act_size
+            old_log_prob_branches, self.policy.action_masks, self.policy.act_size
         )

         action_idx = [0] + list(np.cumsum(self.policy.act_size))

ml-agents/mlagents/trainers/sac/optimizer.py

Lines changed: 15 additions & 23 deletions

@@ -232,8 +232,9 @@ def _create_losses(

         for name in stream_names:
             if discrete:
-                _branched_mpq1 = self._apply_as_branches(
-                    self.policy_network.q1_pheads[name] * discrete_action_probs
+                _branched_mpq1 = ModelUtils.break_into_branches(
+                    self.policy_network.q1_pheads[name] * discrete_action_probs,
+                    self.act_size,
                 )
                 branched_mpq1 = tf.stack(
                     [
@@ -243,8 +244,9 @@ def _create_losses(
                 )
                 _q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0)

-                _branched_mpq2 = self._apply_as_branches(
-                    self.policy_network.q2_pheads[name] * discrete_action_probs
+                _branched_mpq2 = ModelUtils.break_into_branches(
+                    self.policy_network.q2_pheads[name] * discrete_action_probs,
+                    self.act_size,
                 )
                 branched_mpq2 = tf.stack(
                     [
@@ -282,11 +284,11 @@ def _create_losses(

             if discrete:
                 # We need to break up the Q functions by branch, and update them individually.
-                branched_q1_stream = self._apply_as_branches(
-                    self.policy.selected_actions * q1_streams[name]
+                branched_q1_stream = ModelUtils.break_into_branches(
+                    self.policy.selected_actions * q1_streams[name], self.act_size
                 )
-                branched_q2_stream = self._apply_as_branches(
-                    self.policy.selected_actions * q2_streams[name]
+                branched_q2_stream = ModelUtils.break_into_branches(
+                    self.policy.selected_actions * q2_streams[name], self.act_size
                 )

                 # Reduce each branch into scalar
@@ -344,7 +346,9 @@ def _create_losses(
         self.ent_coef = tf.exp(self.log_ent_coef)
         if discrete:
             # We also have to do a different entropy and target_entropy per branch.
-            branched_per_action_ent = self._apply_as_branches(per_action_entropy)
+            branched_per_action_ent = ModelUtils.break_into_branches(
+                per_action_entropy, self.act_size
+            )
             branched_ent_sums = tf.stack(
                 [
                     tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te
@@ -364,8 +368,8 @@ def _create_losses(
             # Same with policy loss, we have to do the loss per branch and average them,
             # so that larger branches don't get more weight.
             # The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q
-            branched_q_term = self._apply_as_branches(
-                discrete_action_probs * self.policy_network.q1_p
+            branched_q_term = ModelUtils.break_into_branches(
+                discrete_action_probs * self.policy_network.q1_p, self.act_size
             )

             branched_policy_loss = tf.stack(
@@ -444,18 +448,6 @@ def _create_losses(

         self.entropy = self.policy_network.entropy

-    def _apply_as_branches(self, concat_logits: tf.Tensor) -> List[tf.Tensor]:
-        """
-        Takes in a concatenated set of logits and breaks it up into a list of non-concatenated logits, one per
-        action branch
-        """
-        action_idx = [0] + list(np.cumsum(self.act_size))
-        branches_logits = [
-            concat_logits[:, action_idx[i] : action_idx[i + 1]]
-            for i in range(len(self.act_size))
-        ]
-        return branches_logits
-
     def _create_sac_optimizer_ops(self) -> None:
         """
         Creates the Adam optimizers and update ops for SAC, including
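The comments in this file describe the per-branch reduction pattern used throughout the SAC losses: split a `[batch, sum(act_size)]` term by branch, reduce within each branch, stack, and average across branches so branches with more actions do not dominate. A hedged NumPy sketch of that pattern (toy shapes and names, not the actual loss terms):

```python
import numpy as np
from typing import List

def split_branches(x: np.ndarray, action_size: List[int]) -> List[np.ndarray]:
    # Same cumulative-offset slicing as ModelUtils.break_into_branches.
    idx = [0] + list(np.cumsum(action_size))
    return [x[:, idx[i] : idx[i + 1]] for i in range(len(action_size))]

act_size = [3, 2, 4]
per_action_term = np.random.rand(5, sum(act_size)).astype(np.float32)  # batch of 5 agents

# Sum within each branch, stack to [num_branches, batch], then average over branches;
# this mirrors the tf.stack / tf.reduce_mean pattern in the hunks above.
per_branch = np.stack([b.sum(axis=1) for b in split_branches(per_action_term, act_size)])
per_agent = per_branch.mean(axis=0)
print(per_branch.shape, per_agent.shape)  # (3, 5) (5,)
```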
