@@ -51,7 +51,7 @@ public enum tensorType
     public string[] ObservationPlaceholderName;
     /// Modify only in inspector: Name of the action node
     public string ActionPlaceholderName = "action";
-#if ENABLE_TENSORFLOW
+#if ENABLE_TENSORFLOW
     TFGraph graph;
     TFSession session;
     bool hasRecurrent;
@@ -62,7 +62,7 @@ public enum tensorType
     float[,] inputState;
     List<float[,,,]> observationMatrixList;
     float[,] inputOldMemories;
-#endif
+#endif

     /// Reference to the brain that uses this CoreBrainInternal
     public Brain brain;
@@ -190,13 +190,22 @@ public void DecideAction()

         foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
         {
-            if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+            try
             {
-                runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+                {
+                    runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                }
+                else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+                {
+                    runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                }
             }
-            else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+            catch
             {
-                runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                throw new UnityAgentsException(string.Format(@"One of the TensorFlow placeholders could not be found.
+In brain {0}, there is no {1} placeholder named {2}.",
+                    brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
             }
         }

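The bare `catch` here guards the graph lookup: in TensorFlowSharp, `TFGraph`'s string indexer returns `null` when no operation has that name, so the `[0]` that follows throws, and the catch converts that into a `UnityAgentsException` naming the brain, the placeholder type, and the scoped node name. A minimal sketch of that failure mode, assuming TensorFlowSharp's indexer behavior (the `does_not_exist` name is hypothetical):

```csharp
using System;
using TensorFlow; // TensorFlowSharp

class MissingPlaceholderSketch
{
    static void Main()
    {
        using (var graph = new TFGraph())
        {
            // No operation named "does_not_exist" was ever added, so the
            // string indexer returns null rather than throwing here.
            TFOperation op = graph["does_not_exist"];
            try
            {
                var output = op[0]; // NullReferenceException on the null op
            }
            catch
            {
                // The diff rethrows with the brain, type, and node name;
                // this just prints an equivalent message.
                Console.WriteLine(
                    "There is no FloatingPoint placeholder named does_not_exist.");
            }
        }
    }
}
```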
@@ -212,6 +221,26 @@ public void DecideAction()
             runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
         }

+        TFTensor[] networkOutput;
+        try
+        {
+            networkOutput = runner.Run();
+        }
+        catch (TFException e)
+        {
+            string errorMessage = e.Message;
+            try
+            {
+                errorMessage = string.Format(@"The TensorFlow graph needs an input for {0} of type {1}",
+                    e.Message.Split(new string[] { "Node: " }, 0)[1].Split('=')[0],
+                    e.Message.Split(new string[] { "dtype=" }, 0)[1].Split(',')[0]);
+            }
+            finally
+            {
+                throw new UnityAgentsException(errorMessage);
+            }
+
+        }

         // Create the recurrent tensor
         if (hasRecurrent)
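The two `Split` calls in the new `catch (TFException e)` block are easier to follow in isolation. Below is a self-contained sketch of the same parsing, run against a hypothetical TensorFlow error message (the exact wording varies across TensorFlow versions, so the sample string is an assumption); the bare `0` passed to `Split` in the diff is the literal form of `StringSplitOptions.None`:

```csharp
using System;

class ParseSketch
{
    static void Main()
    {
        // Hypothetical TFException message for an unfed placeholder.
        string message =
            "You must feed a value for placeholder tensor 'state' " +
            "[[Node: state=Placeholder[dtype=DT_FLOAT, shape=[]]]]";

        // Text after "Node: " up to the first '=' is the node name.
        string node = message.Split(
            new string[] { "Node: " }, StringSplitOptions.None)[1].Split('=')[0];
        // Text after "dtype=" up to the first ',' is the expected type.
        string dtype = message.Split(
            new string[] { "dtype=" }, StringSplitOptions.None)[1].Split(',')[0];

        // Prints: The TensorFlow graph needs an input for state of type DT_FLOAT
        Console.WriteLine(string.Format(
            "The TensorFlow graph needs an input for {0} of type {1}", node, dtype));
    }
}
```

The surrounding `try`/`finally` means a `UnityAgentsException` is still thrown with the raw message when the text does not match this shape and the parsing itself throws.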
@@ -220,7 +249,7 @@ public void DecideAction()

             runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
             runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
-            float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
+            float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,];

             int i = 0;
             foreach (int k in agentKeys)
@@ -241,7 +270,7 @@ public void DecideAction()

         if (brain.brainParameters.actionSpaceType == StateType.continuous)
         {
-            float[,] output = runner.Run()[0].GetValue() as float[,];
+            float[,] output = networkOutput[0].GetValue() as float[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
@@ -256,7 +285,7 @@ public void DecideAction()
         }
         else if (brain.brainParameters.actionSpaceType == StateType.discrete)
         {
-            long[,] output = runner.Run()[0].GetValue() as long[,];
+            long[,] output = networkOutput[0].GetValue() as long[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
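The remaining hunks all make the same substitution: rather than calling `runner.Run()` at each read site, which executes the session again every time, the single `networkOutput` array from the one `Run()` above is indexed. In TensorFlowSharp, `Run()` returns one `TFTensor` per `Fetch()`, in registration order, which is presumably why the action output is read at index 0 and the recurrent state at index 1. A toy sketch of that contract (the constant nodes stand in for the real action and recurrent outputs):

```csharp
using System;
using TensorFlow; // TensorFlowSharp

class SingleRunSketch
{
    static void Main()
    {
        using (var graph = new TFGraph())
        using (var session = new TFSession(graph))
        {
            TFOutput action = graph.Const(1.0f, "action");
            TFOutput memory = graph.Const(2.0f, "recurrent_out");

            var runner = session.GetRunner();
            runner.Fetch(action); // becomes index 0 of the result
            runner.Fetch(memory); // becomes index 1 of the result

            TFTensor[] networkOutput = runner.Run(); // one session execution
            Console.WriteLine(networkOutput[0].GetValue()); // 1
            Console.WriteLine(networkOutput[1].GetValue()); // 2
        }
    }
}
```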