
Commit f391b35

Author: Ervin Teng
Merge branch 'develop-lin-enc-def' into develop-centralizedcritic-mm
2 parents: c7c7d4c + 3f4b2b5

File tree

2 files changed: +19 −6 lines

ml-agents/mlagents/trainers/torch/attention.py

Lines changed: 7 additions & 1 deletion

@@ -117,7 +117,13 @@ def __init__(
         self.self_size = 0
         self.ent_encoders = torch.nn.ModuleList(
             [
-                LinearEncoder(self.self_size + ent_size, 1, embedding_size)
+                LinearEncoder(
+                    self.self_size + ent_size,
+                    1,
+                    embedding_size,
+                    kernel_init=Initialization.Normal,
+                    kernel_gain=(0.125 / embedding_size) ** 0.5,
+                )
                 for ent_size in self.entity_sizes
             ]
         )
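In effect, each per-entity encoder in the attention module is now built with an explicit normal initialization and a gain of (0.125 / embedding_size) ** 0.5 rather than LinearEncoder's defaults. The sketch below illustrates the resulting construction in isolation; the self_size, entity_sizes, and embedding_size values are placeholders (not the trainer's actual values), and it assumes LinearEncoder and Initialization are importable from mlagents.trainers.torch.layers as on this branch.

import torch
from mlagents.trainers.torch.layers import LinearEncoder, Initialization

# Placeholder values standing in for the attention module's attributes.
self_size = 6
entity_sizes = [8, 4]
embedding_size = 64

# One single-layer LinearEncoder per entity type, using a normal init whose
# weights are scaled down by the gain (0.125 / embedding_size) ** 0.5.
ent_encoders = torch.nn.ModuleList(
    [
        LinearEncoder(
            self_size + ent_size,
            1,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        for ent_size in entity_sizes
    ]
)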

ml-agents/mlagents/trainers/torch/layers.py

Lines changed: 12 additions & 5 deletions

@@ -120,14 +120,21 @@ class LinearEncoder(torch.nn.Module):
     Linear layers.
     """

-    def __init__(self, input_size: int, num_layers: int, hidden_size: int):
+    def __init__(
+        self,
+        input_size: int,
+        num_layers: int,
+        hidden_size: int,
+        kernel_init: Initialization = Initialization.KaimingHeNormal,
+        kernel_gain: float = 1.0,
+    ):
         super().__init__()
         self.layers = [
             linear_layer(
                 input_size,
                 hidden_size,
-                kernel_init=Initialization.KaimingHeNormal,
-                kernel_gain=1.0,
+                kernel_init=kernel_init,
+                kernel_gain=kernel_gain,
             )
         ]
         self.layers.append(Swish())
@@ -136,8 +143,8 @@ def __init__(self, input_size: int, num_layers: int, hidden_size: int):
             linear_layer(
                 hidden_size,
                 hidden_size,
-                kernel_init=Initialization.KaimingHeNormal,
-                kernel_gain=1.0,
+                kernel_init=kernel_init,
+                kernel_gain=kernel_gain,
             )
         )
         self.layers.append(Swish())
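On the LinearEncoder side, the new kernel_init and kernel_gain parameters default to the previously hard-coded values, so existing callers keep their behavior while the attention module can override both. Below is a short usage sketch under that assumption; the sizes and the shape check are illustrative only, not taken from the trainer code.

import torch
from mlagents.trainers.torch.layers import LinearEncoder, Initialization

# Default arguments reproduce the old behavior: KaimingHeNormal init, gain 1.0.
default_enc = LinearEncoder(input_size=32, num_layers=2, hidden_size=128)

# Overriding the new arguments, as the attention module now does.
scaled_enc = LinearEncoder(
    32,
    2,
    128,
    kernel_init=Initialization.Normal,
    kernel_gain=(0.125 / 128) ** 0.5,
)

x = torch.randn(4, 32)        # illustrative batch of 4 observation vectors
print(scaled_enc(x).shape)    # expected: torch.Size([4, 128])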
