@@ -253,8 +253,6 @@ def forward(
253
253
254
254
255
255
class MultiAgentNetworkBody (torch .nn .Module ):
256
- ATTENTION_EMBEDDING_SIZE = 128
257
-
258
256
"""
259
257
A network body that uses a self attention layer to handle state
260
258
and action input from a potentially variable number of agents that
@@ -293,17 +291,18 @@ def __init__(
293
291
+ self .action_spec .continuous_size
294
292
)
295
293
294
+ attention_embeding_size = self .h_size
296
295
self .obs_encoder = EntityEmbedding (
297
- obs_only_ent_size , None , self . ATTENTION_EMBEDDING_SIZE
296
+ obs_only_ent_size , None , attention_embeding_size
298
297
)
299
298
self .obs_action_encoder = EntityEmbedding (
300
- q_ent_size , None , self . ATTENTION_EMBEDDING_SIZE
299
+ q_ent_size , None , attention_embeding_size
301
300
)
302
301
303
- self .self_attn = ResidualSelfAttention (self . ATTENTION_EMBEDDING_SIZE )
302
+ self .self_attn = ResidualSelfAttention (attention_embeding_size )
304
303
305
304
self .linear_encoder = LinearEncoder (
306
- self . ATTENTION_EMBEDDING_SIZE ,
305
+ attention_embeding_size ,
307
306
network_settings .num_layers ,
308
307
self .h_size ,
309
308
kernel_gain = (0.125 / self .h_size ) ** 0.5 ,
0 commit comments