Skip to content

Commit 1c4ceaf

Browse files
[🐛🔨 ] Fix sac target for continuous actions (#5372)
* Fix the target entropy for continuous SAC
* Lower the required steps of the test and remove an unnecessary unsqueeze
* Change the target from -dim(a)^2 to -dim(a) by removing implicit broadcasting
1 parent d3f5ca7 commit 1c4ceaf

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,8 @@ def sac_policy_loss(
385385
all_mean_q1 = mean_q1
386386
if self._action_spec.continuous_size > 0:
387387
cont_log_probs = log_probs.continuous_tensor
388-
batch_policy_loss += torch.mean(
389-
_cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
388+
batch_policy_loss += (
389+
_cont_ent_coef * torch.sum(cont_log_probs, dim=1) - all_mean_q1
390390
)
391391
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
392392

@@ -426,8 +426,8 @@ def sac_entropy_loss(
426426
if self._action_spec.continuous_size > 0:
427427
with torch.no_grad():
428428
cont_log_probs = log_probs.continuous_tensor
429-
target_current_diff = torch.sum(
430-
cont_log_probs + self.target_entropy.continuous, dim=1
429+
target_current_diff = (
430+
torch.sum(cont_log_probs, dim=1) + self.target_entropy.continuous
431431
)
432432
# We update all the _cont_ent_coef as one block
433433
entropy_loss += -1 * ModelUtils.masked_mean(

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def test_2d_sac(action_sizes):
256256
SAC_TORCH_CONFIG.hyperparameters, buffer_init_steps=2000
257257
)
258258
config = attr.evolve(
259-
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=6000
259+
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=3000
260260
)
261261
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.8)
262262

0 commit comments

Comments
 (0)