
Commit 8c01b76

POCA Attention will use h_size for embedding size and not 128 (#5281)
1 parent d06488d commit 8c01b76

File tree: 1 file changed, +5 −6 lines

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 5 additions & 6 deletions
@@ -253,8 +253,6 @@ def forward(
 
 
 class MultiAgentNetworkBody(torch.nn.Module):
-    ATTENTION_EMBEDDING_SIZE = 128
-
     """
     A network body that uses a self attention layer to handle state
     and action input from a potentially variable number of agents that
@@ -293,17 +291,18 @@ def __init__(
             + self.action_spec.continuous_size
         )
 
+        attention_embeding_size = self.h_size
         self.obs_encoder = EntityEmbedding(
-            obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+            obs_only_ent_size, None, attention_embeding_size
         )
         self.obs_action_encoder = EntityEmbedding(
-            q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+            q_ent_size, None, attention_embeding_size
         )
 
-        self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)
+        self.self_attn = ResidualSelfAttention(attention_embeding_size)
 
         self.linear_encoder = LinearEncoder(
-            self.ATTENTION_EMBEDDING_SIZE,
+            attention_embeding_size,
             network_settings.num_layers,
             self.h_size,
             kernel_gain=(0.125 / self.h_size) ** 0.5,
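For context, here is a minimal sketch of what the change means for tensor sizing. The class names mirror the diff, but the bodies are simplified stand-ins, not the real ml-agents implementations (which live under mlagents.trainers.torch): the entity embeddings and the residual self-attention block now share the trainer's hidden_units width (h_size) instead of a hardcoded 128.

import torch
import torch.nn as nn

# Simplified stand-ins for the layers named in the diff; hypothetical
# bodies for illustration only.
class EntityEmbedding(nn.Module):
    def __init__(self, entity_size: int, max_entities, embedding_size: int):
        super().__init__()
        self.proj = nn.Linear(entity_size, embedding_size)

    def forward(self, entities: torch.Tensor) -> torch.Tensor:
        # (batch, num_entities, entity_size) -> (batch, num_entities, embedding_size)
        return self.proj(entities)

class ResidualSelfAttention(nn.Module):
    def __init__(self, embedding_size: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(embedding_size, num_heads=4, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        # Residual add, then pool over the entity dimension.
        return (x + out).mean(dim=1)

h_size = 256  # network_settings.hidden_units
obs_encoder = EntityEmbedding(entity_size=32, max_entities=None, embedding_size=h_size)
self_attn = ResidualSelfAttention(h_size)

entities = torch.randn(8, 3, 32)           # 8 envs, 3 agents, 32-dim entity obs
pooled = self_attn(obs_encoder(entities))  # shape (8, h_size), not (8, 128)
assert pooled.shape == (8, h_size)

Tying the attention width to h_size means a single hidden_units setting now scales the whole POCA network's capacity, instead of the attention path silently capping at 128.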
