@@ -716,8 +716,19 @@ static IEnumerable<FailedCheck> CheckOutputTensorShape(
716
716
{
717
717
failedModelChecks . Add ( continuousError ) ;
718
718
}
719
- var modelSumDiscreteBranchSizes = model . DiscreteOutputSize ( ) ;
720
- var discreteError = CheckDiscreteActionOutputShape ( brainParameters , actuatorComponents , modelSumDiscreteBranchSizes ) ;
719
+ FailedCheck discreteError = null ;
720
+ var modelApiVersion = model . GetVersion ( ) ;
721
+ if ( modelApiVersion == ( int ) ModelApiVersion . MLAgents1_0 )
722
+ {
723
+ var modelSumDiscreteBranchSizes = model . DiscreteOutputSize ( ) ;
724
+ discreteError = CheckDiscreteActionOutputShapeLegacy ( brainParameters , actuatorComponents , modelSumDiscreteBranchSizes ) ;
725
+ }
726
+ if ( modelApiVersion == ( int ) ModelApiVersion . MLAgents2_0 )
727
+ {
728
+ var modeDiscreteBranches = model . GetTensorByName ( TensorNames . DiscreteActionOutputShape ) ;
729
+ discreteError = CheckDiscreteActionOutputShape ( brainParameters , actuatorComponents , modeDiscreteBranches ) ;
730
+ }
731
+
721
732
if ( discreteError != null )
722
733
{
723
734
failedModelChecks . Add ( discreteError ) ;
@@ -733,14 +744,58 @@ static IEnumerable<FailedCheck> CheckOutputTensorShape(
733
744
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
734
745
/// </param>
735
746
/// <param name="actuatorComponents">Array of attached actuator components.</param>
747
+ /// <param name="modelDiscreteBranches"> The Tensor of branch sizes.
748
+ /// </param>
749
+ /// <returns>
750
+ /// If the Check failed, returns a string containing information about why the
751
+ /// check failed. If the check passed, returns null.
752
+ /// </returns>
753
+ static FailedCheck CheckDiscreteActionOutputShape (
754
+ BrainParameters brainParameters , ActuatorComponent [ ] actuatorComponents , Tensor modelDiscreteBranches )
755
+ {
756
+
757
+ var discreteActionBranches = brainParameters . ActionSpec . BranchSizes . ToList ( ) ;
758
+ foreach ( var actuatorComponent in actuatorComponents )
759
+ {
760
+ var actionSpec = actuatorComponent . ActionSpec ;
761
+ discreteActionBranches . AddRange ( actionSpec . BranchSizes ) ;
762
+ }
763
+
764
+ if ( modelDiscreteBranches . length != discreteActionBranches . Count )
765
+ {
766
+ return FailedCheck . Warning ( "Discrete Action Size of the model does not match. The BrainParameters expect " +
767
+ $ "{ discreteActionBranches . Count } branches but the model contains { modelDiscreteBranches . length } ."
768
+ ) ;
769
+ }
770
+
771
+ for ( int i = 0 ; i < modelDiscreteBranches . length ; i ++ )
772
+ {
773
+ if ( modelDiscreteBranches [ i ] != discreteActionBranches [ i ] )
774
+ {
775
+ return FailedCheck . Warning ( $ "The number of Discrete Actions of branch { i } does not match. " +
776
+ $ "Was expecting { discreteActionBranches [ i ] } but the model contains { modelDiscreteBranches [ i ] } "
777
+ ) ;
778
+ }
779
+ }
780
+ return null ;
781
+ }
782
+
783
+ /// <summary>
784
+ /// Checks that the shape of the discrete action output is the same in the
785
+ /// model and in the Brain Parameters. Tests the models created with the API of version 1.X
786
+ /// </summary>
787
+ /// <param name="brainParameters">
788
+ /// The BrainParameters that are used verify the compatibility with the InferenceEngine
789
+ /// </param>
790
+ /// <param name="actuatorComponents">Array of attached actuator components.</param>
736
791
/// <param name="modelSumDiscreteBranchSizes">
737
792
/// The size of the discrete action output that is expected by the model.
738
793
/// </param>
739
794
/// <returns>
740
795
/// If the Check failed, returns a string containing information about why the
741
796
/// check failed. If the check passed, returns null.
742
797
/// </returns>
743
- static FailedCheck CheckDiscreteActionOutputShape (
798
+ static FailedCheck CheckDiscreteActionOutputShapeLegacy (
744
799
BrainParameters brainParameters , ActuatorComponent [ ] actuatorComponents , int modelSumDiscreteBranchSizes )
745
800
{
746
801
// TODO: check each branch size instead of sum of branch sizes
0 commit comments