Skip to content

Commit 38b65b0

Browse files
authored
Revert "Deterministic actions python training (#5619)" (#5622)
This reverts commit 0de327c.
1 parent 0de327c commit 38b65b0

29 files changed

+66
-469
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ and this project adheres to
3030
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3131
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
3232

33-
34-
- Deterministic action selection is now supported during training and inference. (#5619)
35-
- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can
36-
be achieved by adding `deterministic: true` under `network_settings` of the run options configuration. (#5597)
37-
- Extra tensors are now serialized to support deterministic action selection in onnx. (#5593)
38-
- Support inference with deterministic action selection in editor (#5599)
3933
### Bug Fixes
4034
- Fixed a bug where the critics were not being normalized during training. (#5595)
4135
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)

com.unity.ml-agents/Editor/BehaviorParametersEditor.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
2525
const string k_BrainParametersName = "m_BrainParameters";
2626
const string k_ModelName = "m_Model";
2727
const string k_InferenceDeviceName = "m_InferenceDevice";
28-
const string k_DeterministicInference = "m_DeterministicInference";
2928
const string k_BehaviorTypeName = "m_BehaviorType";
3029
const string k_TeamIdName = "TeamId";
3130
const string k_UseChildSensorsName = "m_UseChildSensors";
@@ -69,7 +68,6 @@ public override void OnInspectorGUI()
6968
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
7069
EditorGUI.indentLevel++;
7170
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
72-
EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
7371
EditorGUI.indentLevel--;
7472
}
7573
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
@@ -158,7 +156,7 @@ void DisplayFailedModelChecks()
158156
{
159157
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
160158
barracudaModel, brainParameters, sensors, actuatorComponents,
161-
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
159+
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
162160
);
163161
foreach (var check in failedChecks)
164162
{

com.unity.ml-agents/Runtime/Academy.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,16 +616,14 @@ void EnvironmentReset()
616616
/// <param name="inferenceDevice">
617617
/// The inference device (CPU or GPU) the ModelRunner will use.
618618
/// </param>
619-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
620-
/// Deterministic. </param>
621619
/// <returns> The ModelRunner compatible with the input settings.</returns>
622620
internal ModelRunner GetOrCreateModelRunner(
623-
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false)
621+
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice)
624622
{
625623
var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice));
626624
if (modelRunner == null)
627625
{
628-
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference);
626+
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed);
629627
m_ModelRunners.Add(modelRunner);
630628
m_InferenceSeed++;
631629
}

com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs

Lines changed: 27 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,8 @@ public static int GetNumVisualInputs(this Model model)
112112
/// <param name="model">
113113
/// The Barracuda engine model for loading static parameters.
114114
/// </param>
115-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
116-
/// deterministic. </param>
117115
/// <returns>Array of the output tensor names of the model</returns>
118-
public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
116+
public static string[] GetOutputNames(this Model model)
119117
{
120118
var names = new List<string>();
121119

@@ -124,13 +122,13 @@ public static string[] GetOutputNames(this Model model, bool deterministicInfere
124122
return names.ToArray();
125123
}
126124

127-
if (model.HasContinuousOutputs(deterministicInference))
125+
if (model.HasContinuousOutputs())
128126
{
129-
names.Add(model.ContinuousOutputName(deterministicInference));
127+
names.Add(model.ContinuousOutputName());
130128
}
131-
if (model.HasDiscreteOutputs(deterministicInference))
129+
if (model.HasDiscreteOutputs())
132130
{
133-
names.Add(model.DiscreteOutputName(deterministicInference));
131+
names.Add(model.DiscreteOutputName());
134132
}
135133

136134
var modelVersion = model.GetVersion();
@@ -151,10 +149,8 @@ public static string[] GetOutputNames(this Model model, bool deterministicInfere
151149
/// <param name="model">
152150
/// The Barracuda engine model for loading static parameters.
153151
/// </param>
154-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
155-
/// deterministic. </param>
156152
/// <returns>True if the model has continuous action outputs.</returns>
157-
public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
153+
public static bool HasContinuousOutputs(this Model model)
158154
{
159155
if (model == null)
160156
return false;
@@ -164,13 +160,8 @@ public static bool HasContinuousOutputs(this Model model, bool deterministicInfe
164160
}
165161
else
166162
{
167-
bool hasStochasticOutput = !deterministicInference &&
168-
model.outputs.Contains(TensorNames.ContinuousActionOutput);
169-
bool hasDeterministicOutput = deterministicInference &&
170-
model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);
171-
172-
return (hasStochasticOutput || hasDeterministicOutput) &&
173-
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
163+
return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
164+
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
174165
}
175166
}
176167

@@ -203,10 +194,8 @@ public static int ContinuousOutputSize(this Model model)
203194
/// <param name="model">
204195
/// The Barracuda engine model for loading static parameters.
205196
/// </param>
206-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
207-
/// deterministic. </param>
208197
/// <returns>Tensor name of continuous action output.</returns>
209-
public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
198+
public static string ContinuousOutputName(this Model model)
210199
{
211200
if (model == null)
212201
return null;
@@ -216,7 +205,7 @@ public static string ContinuousOutputName(this Model model, bool deterministicIn
216205
}
217206
else
218207
{
219-
return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
208+
return TensorNames.ContinuousActionOutput;
220209
}
221210
}
222211

@@ -226,10 +215,8 @@ public static string ContinuousOutputName(this Model model, bool deterministicIn
226215
/// <param name="model">
227216
/// The Barracuda engine model for loading static parameters.
228217
/// </param>
229-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
230-
/// deterministic. </param>
231218
/// <returns>True if the model has discrete action outputs.</returns>
232-
public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
219+
public static bool HasDiscreteOutputs(this Model model)
233220
{
234221
if (model == null)
235222
return false;
@@ -239,12 +226,7 @@ public static bool HasDiscreteOutputs(this Model model, bool deterministicInfere
239226
}
240227
else
241228
{
242-
bool hasStochasticOutput = !deterministicInference &&
243-
model.outputs.Contains(TensorNames.DiscreteActionOutput);
244-
bool hasDeterministicOutput = deterministicInference &&
245-
model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
246-
return (hasStochasticOutput || hasDeterministicOutput) &&
247-
model.DiscreteOutputSize() > 0;
229+
return model.outputs.Contains(TensorNames.DiscreteActionOutput) && model.DiscreteOutputSize() > 0;
248230
}
249231
}
250232

@@ -297,10 +279,8 @@ public static int DiscreteOutputSize(this Model model)
297279
/// <param name="model">
298280
/// The Barracuda engine model for loading static parameters.
299281
/// </param>
300-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
301-
/// deterministic. </param>
302282
/// <returns>Tensor name of discrete action output.</returns>
303-
public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
283+
public static string DiscreteOutputName(this Model model)
304284
{
305285
if (model == null)
306286
return null;
@@ -310,7 +290,7 @@ public static string DiscreteOutputName(this Model model, bool deterministicInfe
310290
}
311291
else
312292
{
313-
return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
293+
return TensorNames.DiscreteActionOutput;
314294
}
315295
}
316296

@@ -336,11 +316,9 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
336316
/// The Barracuda engine model for loading static parameters.
337317
/// </param>
338318
/// <param name="failedModelChecks">Output list of failure messages</param>
339-
///<param name="deterministicInference"> Inference only: set to true if the action selection from model should be
340-
/// deterministic. </param>
319+
///
341320
/// <returns>True if the model contains all the expected tensors.</returns>
342-
/// TODO: add checks for deterministic actions
343-
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
321+
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
344322
{
345323
// Check the presence of model version
346324
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
@@ -365,9 +343,7 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
365343
// Check the presence of action output tensor
366344
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
367345
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
368-
!model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
369-
!model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
370-
!model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
346+
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
371347
{
372348
failedModelChecks.Add(
373349
FailedCheck.Warning("The model does not contain any Action Output Node.")
@@ -397,51 +373,22 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
397373
}
398374
else
399375
{
400-
if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
376+
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
377+
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
401378
{
402-
if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
403-
{
404-
failedModelChecks.Add(
405-
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
406-
);
407-
return false;
408-
}
409-
410-
else if (!model.HasContinuousOutputs(deterministicInference))
411-
{
412-
var actionType = deterministicInference ? "deterministic" : "stochastic";
413-
var actionName = deterministicInference ? "Deterministic" : "";
414-
failedModelChecks.Add(
415-
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
379+
failedModelChecks.Add(
380+
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
416381
);
417-
return false;
418-
}
382+
return false;
419383
}
420-
421-
if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
384+
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
385+
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
422386
{
423-
if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
424-
{
425-
failedModelChecks.Add(
426-
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
427-
);
428-
return false;
429-
}
430-
else if (!model.HasDiscreteOutputs(deterministicInference))
431-
{
432-
var actionType = deterministicInference ? "deterministic" : "stochastic";
433-
var actionName = deterministicInference ? "Deterministic" : "";
434-
failedModelChecks.Add(
435-
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
387+
failedModelChecks.Add(
388+
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
436389
);
437-
return false;
438-
}
439-
390+
return false;
440391
}
441-
442-
443-
444-
445392
}
446393
return true;
447394
}

com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,14 @@ public static FailedCheck CheckModelVersion(Model model)
122122
/// <param name="actuatorComponents">Attached actuator components</param>
123123
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
124124
/// <param name="behaviorType">BehaviorType of the Agent to check.</param>
125-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
126-
/// deterministic. </param>
127125
/// <returns>An IEnumerable of the checks that failed</returns>
128126
public static IEnumerable<FailedCheck> CheckModel(
129127
Model model,
130128
BrainParameters brainParameters,
131129
ISensor[] sensors,
132130
ActuatorComponent[] actuatorComponents,
133131
int observableAttributeTotalSize = 0,
134-
BehaviorType behaviorType = BehaviorType.Default,
135-
bool deterministicInference = false
132+
BehaviorType behaviorType = BehaviorType.Default
136133
)
137134
{
138135
List<FailedCheck> failedModelChecks = new List<FailedCheck>();
@@ -151,7 +148,7 @@ public static IEnumerable<FailedCheck> CheckModel(
151148
return failedModelChecks;
152149
}
153150

154-
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference);
151+
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
155152
if (!hasExpectedTensors)
156153
{
157154
return failedModelChecks;
@@ -184,7 +181,7 @@ public static IEnumerable<FailedCheck> CheckModel(
184181
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
185182
{
186183
failedModelChecks.AddRange(
187-
CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference)
184+
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
188185
);
189186
failedModelChecks.AddRange(
190187
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
@@ -198,7 +195,7 @@ public static IEnumerable<FailedCheck> CheckModel(
198195
);
199196

200197
failedModelChecks.AddRange(
201-
CheckOutputTensorPresence(model, memorySize, deterministicInference)
198+
CheckOutputTensorPresence(model, memorySize)
202199
);
203200
return failedModelChecks;
204201
}
@@ -321,17 +318,14 @@ ISensor[] sensors
321318
/// The memory size that the model is expecting.
322319
/// </param>
323320
/// <param name="sensors">Array of attached sensor components</param>
324-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
325-
/// Deterministic. </param>
326321
/// <returns>
327322
/// An IEnumerable of the checks that failed
328323
/// </returns>
329324
static IEnumerable<FailedCheck> CheckInputTensorPresence(
330325
Model model,
331326
BrainParameters brainParameters,
332327
int memory,
333-
ISensor[] sensors,
334-
bool deterministicInference = false
328+
ISensor[] sensors
335329
)
336330
{
337331
var failedModelChecks = new List<FailedCheck>();
@@ -362,7 +356,7 @@ static IEnumerable<FailedCheck> CheckInputTensorPresence(
362356
}
363357

364358
// If the model uses discrete control but does not have an input for action masks
365-
if (model.HasDiscreteOutputs(deterministicInference))
359+
if (model.HasDiscreteOutputs())
366360
{
367361
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
368362
{
@@ -382,19 +376,17 @@ static IEnumerable<FailedCheck> CheckInputTensorPresence(
382376
/// The Barracuda engine model for loading static parameters
383377
/// </param>
384378
/// <param name="memory">The memory size that the model is expecting.</param>
385-
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
386-
/// deterministic. </param>
387379
/// <returns>
388380
/// An IEnumerable of the checks that failed
389381
/// </returns>
390-
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false)
382+
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory)
391383
{
392384
var failedModelChecks = new List<FailedCheck>();
393385

394386
// If there is no Recurrent Output but the model is Recurrent.
395387
if (memory > 0)
396388
{
397-
var allOutputs = model.GetOutputNames(deterministicInference).ToList();
389+
var allOutputs = model.GetOutputNames().ToList();
398390
if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput))
399391
{
400392
failedModelChecks.Add(

0 commit comments

Comments
 (0)