@@ -112,10 +112,8 @@ public static int GetNumVisualInputs(this Model model)
112
112
/// <param name="model">
113
113
/// The Barracuda engine model for loading static parameters.
114
114
/// </param>
115
- /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
116
- /// deterministic. </param>
117
115
/// <returns>Array of the output tensor names of the model</returns>
118
- public static string [ ] GetOutputNames ( this Model model , bool deterministicInference = false )
116
+ public static string [ ] GetOutputNames ( this Model model )
119
117
{
120
118
var names = new List < string > ( ) ;
121
119
@@ -124,13 +122,13 @@ public static string[] GetOutputNames(this Model model, bool deterministicInfere
124
122
return names . ToArray ( ) ;
125
123
}
126
124
127
- if ( model . HasContinuousOutputs ( deterministicInference ) )
125
+ if ( model . HasContinuousOutputs ( ) )
128
126
{
129
- names . Add ( model . ContinuousOutputName ( deterministicInference ) ) ;
127
+ names . Add ( model . ContinuousOutputName ( ) ) ;
130
128
}
131
- if ( model . HasDiscreteOutputs ( deterministicInference ) )
129
+ if ( model . HasDiscreteOutputs ( ) )
132
130
{
133
- names . Add ( model . DiscreteOutputName ( deterministicInference ) ) ;
131
+ names . Add ( model . DiscreteOutputName ( ) ) ;
134
132
}
135
133
136
134
var modelVersion = model . GetVersion ( ) ;
@@ -151,10 +149,8 @@ public static string[] GetOutputNames(this Model model, bool deterministicInfere
151
149
/// <param name="model">
152
150
/// The Barracuda engine model for loading static parameters.
153
151
/// </param>
154
- /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
155
- /// deterministic. </param>
156
152
/// <returns>True if the model has continuous action outputs.</returns>
157
- public static bool HasContinuousOutputs ( this Model model , bool deterministicInference = false )
153
+ public static bool HasContinuousOutputs ( this Model model )
158
154
{
159
155
if ( model == null )
160
156
return false ;
@@ -164,13 +160,8 @@ public static bool HasContinuousOutputs(this Model model, bool deterministicInfe
164
160
}
165
161
else
166
162
{
167
- bool hasStochasticOutput = ! deterministicInference &&
168
- model . outputs . Contains ( TensorNames . ContinuousActionOutput ) ;
169
- bool hasDeterministicOutput = deterministicInference &&
170
- model . outputs . Contains ( TensorNames . DeterministicContinuousActionOutput ) ;
171
-
172
- return ( hasStochasticOutput || hasDeterministicOutput ) &&
173
- ( int ) model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) [ 0 ] > 0 ;
163
+ return model . outputs . Contains ( TensorNames . ContinuousActionOutput ) &&
164
+ ( int ) model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) [ 0 ] > 0 ;
174
165
}
175
166
}
176
167
@@ -203,10 +194,8 @@ public static int ContinuousOutputSize(this Model model)
203
194
/// <param name="model">
204
195
/// The Barracuda engine model for loading static parameters.
205
196
/// </param>
206
- /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
207
- /// deterministic. </param>
208
197
/// <returns>Tensor name of continuous action output.</returns>
209
- public static string ContinuousOutputName ( this Model model , bool deterministicInference = false )
198
+ public static string ContinuousOutputName ( this Model model )
210
199
{
211
200
if ( model == null )
212
201
return null ;
@@ -216,7 +205,7 @@ public static string ContinuousOutputName(this Model model, bool deterministicIn
216
205
}
217
206
else
218
207
{
219
- return deterministicInference ? TensorNames . DeterministicContinuousActionOutput : TensorNames . ContinuousActionOutput ;
208
+ return TensorNames . ContinuousActionOutput ;
220
209
}
221
210
}
222
211
@@ -226,10 +215,8 @@ public static string ContinuousOutputName(this Model model, bool deterministicIn
226
215
/// <param name="model">
227
216
/// The Barracuda engine model for loading static parameters.
228
217
/// </param>
229
- /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
230
- /// deterministic. </param>
231
218
/// <returns>True if the model has discrete action outputs.</returns>
232
- public static bool HasDiscreteOutputs ( this Model model , bool deterministicInference = false )
219
+ public static bool HasDiscreteOutputs ( this Model model )
233
220
{
234
221
if ( model == null )
235
222
return false ;
@@ -239,12 +226,7 @@ public static bool HasDiscreteOutputs(this Model model, bool deterministicInfere
239
226
}
240
227
else
241
228
{
242
- bool hasStochasticOutput = ! deterministicInference &&
243
- model . outputs . Contains ( TensorNames . DiscreteActionOutput ) ;
244
- bool hasDeterministicOutput = deterministicInference &&
245
- model . outputs . Contains ( TensorNames . DeterministicDiscreteActionOutput ) ;
246
- return ( hasStochasticOutput || hasDeterministicOutput ) &&
247
- model . DiscreteOutputSize ( ) > 0 ;
229
+ return model . outputs . Contains ( TensorNames . DiscreteActionOutput ) && model . DiscreteOutputSize ( ) > 0 ;
248
230
}
249
231
}
250
232
@@ -297,10 +279,8 @@ public static int DiscreteOutputSize(this Model model)
297
279
/// <param name="model">
298
280
/// The Barracuda engine model for loading static parameters.
299
281
/// </param>
300
- /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
301
- /// deterministic. </param>
302
282
/// <returns>Tensor name of discrete action output.</returns>
303
- public static string DiscreteOutputName ( this Model model , bool deterministicInference = false )
283
+ public static string DiscreteOutputName ( this Model model )
304
284
{
305
285
if ( model == null )
306
286
return null ;
@@ -310,7 +290,7 @@ public static string DiscreteOutputName(this Model model, bool deterministicInfe
310
290
}
311
291
else
312
292
{
313
- return deterministicInference ? TensorNames . DeterministicDiscreteActionOutput : TensorNames . DiscreteActionOutput ;
293
+ return TensorNames . DiscreteActionOutput ;
314
294
}
315
295
}
316
296
@@ -336,11 +316,9 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
336
316
/// The Barracuda engine model for loading static parameters.
337
317
/// </param>
338
318
/// <param name="failedModelChecks">Output list of failure messages</param>
339
- ///<param name="deterministicInference"> Inference only: set to true if the action selection from model should be
340
- /// deterministic. </param>
319
+ ///
341
320
/// <returns>True if the model contains all the expected tensors.</returns>
342
- /// TODO: add checks for deterministic actions
343
- public static bool CheckExpectedTensors ( this Model model , List < FailedCheck > failedModelChecks , bool deterministicInference = false )
321
+ public static bool CheckExpectedTensors ( this Model model , List < FailedCheck > failedModelChecks )
344
322
{
345
323
// Check the presence of model version
346
324
var modelApiVersionTensor = model . GetTensorByName ( TensorNames . VersionNumber ) ;
@@ -365,9 +343,7 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
365
343
// Check the presence of action output tensor
366
344
if ( ! model . outputs . Contains ( TensorNames . ActionOutputDeprecated ) &&
367
345
! model . outputs . Contains ( TensorNames . ContinuousActionOutput ) &&
368
- ! model . outputs . Contains ( TensorNames . DiscreteActionOutput ) &&
369
- ! model . outputs . Contains ( TensorNames . DeterministicContinuousActionOutput ) &&
370
- ! model . outputs . Contains ( TensorNames . DeterministicDiscreteActionOutput ) )
346
+ ! model . outputs . Contains ( TensorNames . DiscreteActionOutput ) )
371
347
{
372
348
failedModelChecks . Add (
373
349
FailedCheck . Warning ( "The model does not contain any Action Output Node." )
@@ -397,51 +373,22 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
397
373
}
398
374
else
399
375
{
400
- if ( model . outputs . Contains ( TensorNames . ContinuousActionOutput ) )
376
+ if ( model . outputs . Contains ( TensorNames . ContinuousActionOutput ) &&
377
+ model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) == null )
401
378
{
402
- if ( model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) == null )
403
- {
404
- failedModelChecks . Add (
405
- FailedCheck . Warning ( "The model uses continuous action but does not contain Continuous Action Output Shape Node." )
406
- ) ;
407
- return false ;
408
- }
409
-
410
- else if ( ! model . HasContinuousOutputs ( deterministicInference ) )
411
- {
412
- var actionType = deterministicInference ? "deterministic" : "stochastic" ;
413
- var actionName = deterministicInference ? "Deterministic" : "" ;
414
- failedModelChecks . Add (
415
- FailedCheck . Warning ( $ "The model uses { actionType } inference but does not contain { actionName } Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
379
+ failedModelChecks . Add (
380
+ FailedCheck . Warning ( "The model uses continuous action but does not contain Continuous Action Output Shape Node." )
416
381
) ;
417
- return false ;
418
- }
382
+ return false ;
419
383
}
420
-
421
- if ( model . outputs . Contains ( TensorNames . DiscreteActionOutput ) )
384
+ if ( model . outputs . Contains ( TensorNames . DiscreteActionOutput ) &&
385
+ model . GetTensorByName ( TensorNames . DiscreteActionOutputShape ) == null )
422
386
{
423
- if ( model . GetTensorByName ( TensorNames . DiscreteActionOutputShape ) == null )
424
- {
425
- failedModelChecks . Add (
426
- FailedCheck . Warning ( "The model uses discrete action but does not contain Discrete Action Output Shape Node." )
427
- ) ;
428
- return false ;
429
- }
430
- else if ( ! model . HasDiscreteOutputs ( deterministicInference ) )
431
- {
432
- var actionType = deterministicInference ? "deterministic" : "stochastic" ;
433
- var actionName = deterministicInference ? "Deterministic" : "" ;
434
- failedModelChecks . Add (
435
- FailedCheck . Warning ( $ "The model uses { actionType } inference but does not contain { actionName } Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
387
+ failedModelChecks . Add (
388
+ FailedCheck . Warning ( "The model uses discrete action but does not contain Discrete Action Output Shape Node." )
436
389
) ;
437
- return false ;
438
- }
439
-
390
+ return false ;
440
391
}
441
-
442
-
443
-
444
-
445
392
}
446
393
return true ;
447
394
}
0 commit comments