Skip to content

Commit 4b7d0c9

Browse files
Merge pull request #27 from Unity-Technologies/fix-internal-placeholder
made a better error if a placeholder is missing or if a placeholder is …
2 parents 53bf659 + caa5ba9 commit 4b7d0c9

File tree

1 file changed

+38
-9
lines changed

1 file changed

+38
-9
lines changed

unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public enum tensorType
5151
public string[] ObservationPlaceholderName;
5252
/// Modify only in inspector : Name of the action node
5353
public string ActionPlaceholderName = "action";
54-
#if ENABLE_TENSORFLOW
54+
#if ENABLE_TENSORFLOW
5555
TFGraph graph;
5656
TFSession session;
5757
bool hasRecurrent;
@@ -62,7 +62,7 @@ public enum tensorType
6262
float[,] inputState;
6363
List<float[,,,]> observationMatrixList;
6464
float[,] inputOldMemories;
65-
#endif
65+
#endif
6666

6767
/// Reference to the brain that uses this CoreBrainInternal
6868
public Brain brain;
@@ -190,13 +190,22 @@ public void DecideAction()
190190

191191
foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
192192
{
193-
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
193+
try
194194
{
195-
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
195+
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
196+
{
197+
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
198+
}
199+
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
200+
{
201+
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
202+
}
196203
}
197-
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
204+
catch
198205
{
199-
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
206+
throw new UnityAgentsException(string.Format(@"One of the TensorFlow placeholders could not be found.
207+
In brain {0}, there is no {1} placeholder named {2}.",
208+
brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
200209
}
201210
}
202211

@@ -212,6 +221,26 @@ public void DecideAction()
212221
runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
213222
}
214223

224+
TFTensor[] networkOutput;
225+
try
226+
{
227+
networkOutput = runner.Run();
228+
}
229+
catch (TFException e)
230+
{
231+
string errorMessage = e.Message;
232+
try
233+
{
234+
errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}",
235+
e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
236+
e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
237+
}
238+
finally
239+
{
240+
throw new UnityAgentsException(errorMessage);
241+
}
242+
243+
}
215244

216245
// Create the recurrent tensor
217246
if (hasRecurrent)
@@ -220,7 +249,7 @@ public void DecideAction()
220249

221250
runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
222251
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
223-
float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
252+
float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,];
224253

225254
int i = 0;
226255
foreach (int k in agentKeys)
@@ -241,7 +270,7 @@ public void DecideAction()
241270

242271
if (brain.brainParameters.actionSpaceType == StateType.continuous)
243272
{
244-
float[,] output = runner.Run()[0].GetValue() as float[,];
273+
float[,] output = networkOutput[0].GetValue() as float[,];
245274
int i = 0;
246275
foreach (int k in agentKeys)
247276
{
@@ -256,7 +285,7 @@ public void DecideAction()
256285
}
257286
else if (brain.brainParameters.actionSpaceType == StateType.discrete)
258287
{
259-
long[,] output = runner.Run()[0].GetValue() as long[,];
288+
long[,] output = networkOutput[0].GetValue() as long[,];
260289
int i = 0;
261290
foreach (int k in agentKeys)
262291
{

0 commit comments

Comments
 (0)