Skip to content

Commit b79d44e

Browse files
Exporting all the branches size instead of omly the sum (#5092)
1 parent fb459a7 commit b79d44e

File tree

3 files changed

+73
-7
lines changed

3 files changed

+73
-7
lines changed

com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ public static bool HasDiscreteOutputs(this Model model)
239239
else
240240
{
241241
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
242-
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0;
242+
(int)model.DiscreteOutputSize() > 0;
243243
}
244244
}
245245

@@ -262,7 +262,19 @@ public static int DiscreteOutputSize(this Model model)
262262
else
263263
{
264264
var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
265-
return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
265+
if (discreteOutputShape == null)
266+
{
267+
return 0;
268+
}
269+
else
270+
{
271+
int result = 0;
272+
for (int i = 0; i < discreteOutputShape.length; i++)
273+
{
274+
result += (int)discreteOutputShape[i];
275+
}
276+
return result;
277+
}
266278
}
267279
}
268280

com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,19 @@ static IEnumerable<FailedCheck> CheckOutputTensorShape(
716716
{
717717
failedModelChecks.Add(continuousError);
718718
}
719-
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
720-
var discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
719+
FailedCheck discreteError = null;
720+
var modelApiVersion = model.GetVersion();
721+
if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
722+
{
723+
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
724+
discreteError = CheckDiscreteActionOutputShapeLegacy(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
725+
}
726+
if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
727+
{
728+
var modeDiscreteBranches = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
729+
discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modeDiscreteBranches);
730+
}
731+
721732
if (discreteError != null)
722733
{
723734
failedModelChecks.Add(discreteError);
@@ -733,14 +744,58 @@ static IEnumerable<FailedCheck> CheckOutputTensorShape(
733744
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
734745
/// </param>
735746
/// <param name="actuatorComponents">Array of attached actuator components.</param>
747+
/// <param name="modelDiscreteBranches"> The Tensor of branch sizes.
748+
/// </param>
749+
/// <returns>
750+
/// If the Check failed, returns a string containing information about why the
751+
/// check failed. If the check passed, returns null.
752+
/// </returns>
753+
static FailedCheck CheckDiscreteActionOutputShape(
754+
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, Tensor modelDiscreteBranches)
755+
{
756+
757+
var discreteActionBranches = brainParameters.ActionSpec.BranchSizes.ToList();
758+
foreach (var actuatorComponent in actuatorComponents)
759+
{
760+
var actionSpec = actuatorComponent.ActionSpec;
761+
discreteActionBranches.AddRange(actionSpec.BranchSizes);
762+
}
763+
764+
if (modelDiscreteBranches.length != discreteActionBranches.Count)
765+
{
766+
return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
767+
$"{discreteActionBranches.Count} branches but the model contains {modelDiscreteBranches.length}."
768+
);
769+
}
770+
771+
for (int i = 0; i < modelDiscreteBranches.length; i++)
772+
{
773+
if (modelDiscreteBranches[i] != discreteActionBranches[i])
774+
{
775+
return FailedCheck.Warning($"The number of Discrete Actions of branch {i} does not match. " +
776+
$"Was expecting {discreteActionBranches[i]} but the model contains {modelDiscreteBranches[i]} "
777+
);
778+
}
779+
}
780+
return null;
781+
}
782+
783+
/// <summary>
784+
/// Checks that the shape of the discrete action output is the same in the
785+
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
786+
/// </summary>
787+
/// <param name="brainParameters">
788+
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
789+
/// </param>
790+
/// <param name="actuatorComponents">Array of attached actuator components.</param>
736791
/// <param name="modelSumDiscreteBranchSizes">
737792
/// The size of the discrete action output that is expected by the model.
738793
/// </param>
739794
/// <returns>
740795
/// If the Check failed, returns a string containing information about why the
741796
/// check failed. If the check passed, returns null.
742797
/// </returns>
743-
static FailedCheck CheckDiscreteActionOutputShape(
798+
static FailedCheck CheckDiscreteActionOutputShapeLegacy(
744799
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelSumDiscreteBranchSizes)
745800
{
746801
// TODO: check each branch size instead of sum of branch sizes

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,8 @@ def __init__(
320320
self.continuous_act_size_vector = torch.nn.Parameter(
321321
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False
322322
)
323-
# TODO: export list of branch sizes instead of sum
324323
self.discrete_act_size_vector = torch.nn.Parameter(
325-
torch.Tensor([sum(self.action_spec.discrete_branches)]), requires_grad=False
324+
torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False
326325
)
327326
self.act_size_vector_deprecated = torch.nn.Parameter(
328327
torch.Tensor(

0 commit comments

Comments
 (0)