diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs index cddbc27439..f956d747ad 100644 --- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs +++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs @@ -15,9 +15,9 @@ public class BasicSensorComponent : SensorComponent /// Creates a BasicSensor. /// /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { - return new BasicSensor(basicController); + return new ISensor[] { new BasicSensor(basicController) }; } } diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx index 7d32504590..011c232473 100644 Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx differ diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx index d8ec297d15..c20ba1c999 100644 Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx differ diff --git a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs index 504e0c3d19..4a12657505 100644 --- a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs +++ b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs @@ -23,14 +23,14 @@ public string SensorName /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { m_Sensor = new TestTextureSensor(TestTexture, SensorName, CompressionType); if (ObservationStacks != 1) { - return new StackingSensor(m_Sensor, ObservationStacks); + return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) }; } - return m_Sensor; + return new ISensor[] { m_Sensor }; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs index 647e56fddb..8e5e81557a 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs @@ -4,6 +4,14 @@ namespace Unity.MLAgents.Extensions.Match3 { + + /// + /// Delegate that provides integer values at a given (x,y) coordinate. + /// + /// + /// + public delegate int GridValueProvider(int x, int y); + /// /// Type of observations to generate. /// @@ -32,66 +40,68 @@ public enum Match3ObservationType /// /// Sensor for Match3 games. Can generate either vector, compressed visual, - /// or uncompressed visual observations. Uses AbstractBoard.GetCellType() - /// and AbstractBoard.GetSpecialType() to determine the observation values. + /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values. /// public class Match3Sensor : ISensor, IBuiltInSensor { private Match3ObservationType m_ObservationType; - private AbstractBoard m_Board; private ObservationSpec m_ObservationSpec; - private int[] m_SparseChannelMapping; private string m_Name; private int m_Rows; private int m_Columns; - private int m_NumCellTypes; - private int m_NumSpecialTypes; - - private int SpecialTypeSize - { - get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; } - } + private GridValueProvider m_GridValues; + private int m_OneHotSize; /// - /// Create a sensor for the board with the specified observation type. + /// Create a sensor for the GridValueProvider with the specified observation type. /// - /// - /// - /// - public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string name) + /// + /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling + /// the constructor directly. + /// + /// The abstract board. This is only used to get the size. + /// The GridValueProvider, should be either board.GetCellType or board.GetSpecialType. + /// The number of possible values that the GridValueProvider can return. + /// Whether to produce vector or visual observations + /// Name of the sensor. + public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name) { - m_Board = board; m_Name = name; m_Rows = board.Rows; m_Columns = board.Columns; - m_NumCellTypes = board.NumCellTypes; - m_NumSpecialTypes = board.NumSpecialTypes; + m_GridValues = gvp; + m_OneHotSize = oneHotSize; m_ObservationType = obsType; m_ObservationSpec = obsType == Match3ObservationType.Vector - ? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize)) - : ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize); - - // See comment in GetCompressedObservation() - var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3); - m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize]; - // If we have 4 cell types and 2 special types (3 special size), we'd have - // [0, 1, 2, 3, -1, -1, 4, 5, 6] - for (var i = 0; i < m_NumCellTypes; i++) - { - m_SparseChannelMapping[i] = i; - } + ? ObservationSpec.Vector(m_Rows * m_Columns * oneHotSize) + : ObservationSpec.Visual(m_Rows, m_Columns, oneHotSize); + } - for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++) - { - m_SparseChannelMapping[i] = -1; - } + /// + /// Create a sensor that encodes the board cells as observations. + /// + /// The abstract board. + /// Whether to produce vector or visual observations + /// Name of the sensor. + /// + public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name) + { + return new Match3Sensor(board, board.GetCellType, board.NumCellTypes, obsType, name); + } - for (var i = 0; i < SpecialTypeSize; i++) - { - m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes; - } + /// + /// Create a sensor that encodes the cell special types as observations. + /// + /// The abstract board. + /// Whether to produce vector or visual observations + /// Name of the sensor. + /// + public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name) + { + var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1; + return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name); } /// @@ -103,14 +113,14 @@ public ObservationSpec GetObservationSpec() /// public int Write(ObservationWriter writer) { - if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes) - { - Debug.LogWarning( - $"Board shape changes since sensor initialization. This may cause unexpected results. " + - $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " + - $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}" - ); - } + // if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes) + // { + // Debug.LogWarning( + // $"Board shape changes since sensor initialization. This may cause unexpected results. " + + // $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " + + // $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}" + // ); + // } if (m_ObservationType == Match3ObservationType.Vector) { @@ -119,22 +129,13 @@ public int Write(ObservationWriter writer) { for (var c = 0; c < m_Columns; c++) { - var val = m_Board.GetCellType(r, c); - for (var i = 0; i < m_NumCellTypes; i++) + var val = m_GridValues(r, c); + + for (var i = 0; i < m_OneHotSize; i++) { writer[offset] = (i == val) ? 1.0f : 0.0f; offset++; } - - if (m_NumSpecialTypes > 0) - { - var special = m_Board.GetSpecialType(r, c); - for (var i = 0; i < SpecialTypeSize; i++) - { - writer[offset] = (i == special) ? 1.0f : 0.0f; - offset++; - } - } } } @@ -148,22 +149,12 @@ public int Write(ObservationWriter writer) { for (var c = 0; c < m_Columns; c++) { - var val = m_Board.GetCellType(r, c); - for (var i = 0; i < m_NumCellTypes; i++) + var val = m_GridValues(r, c); + for (var i = 0; i < m_OneHotSize; i++) { writer[r, c, i] = (i == val) ? 1.0f : 0.0f; offset++; } - - if (m_NumSpecialTypes > 0) - { - var special = m_Board.GetSpecialType(r, c); - for (var i = 0; i < SpecialTypeSize; i++) - { - writer[offset] = (i == special) ? 1.0f : 0.0f; - offset++; - } - } } } @@ -185,17 +176,10 @@ public byte[] GetCompressedObservation() // fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for // the special types). Note that we have to also implement the sparse channel mapping. // Optimize this it later. - var numCellImages = (m_NumCellTypes + 2) / 3; + var numCellImages = (m_OneHotSize + 2) / 3; for (var i = 0; i < numCellImages; i++) { - converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i); - bytesOut.AddRange(tempTexture.EncodeToPNG()); - } - - var numSpecialImages = (SpecialTypeSize + 2) / 3; - for (var i = 0; i < numSpecialImages; i++) - { - converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i); + converter.EncodeToTexture(m_GridValues, tempTexture, 3 * i); bytesOut.AddRange(tempTexture.EncodeToPNG()); } @@ -223,7 +207,7 @@ internal SensorCompressionType GetCompressionType() /// public CompressionSpec GetCompressionSpec() { - return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping); + return new CompressionSpec(GetCompressionType()); } /// @@ -265,9 +249,6 @@ internal class OneHotToTextureUtil int m_Width; private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue }; - public delegate int GridValueProvider(int x, int y); - - public OneHotToTextureUtil(int height, int width) { m_Colors = new Color[height * width]; diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs index 4dbc1303c2..9007872e23 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs @@ -21,10 +21,20 @@ public class Match3SensorComponent : SensorComponent public Match3ObservationType ObservationType = Match3ObservationType.Vector; /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { var board = GetComponent(); - return new Match3Sensor(board, ObservationType, SensorName); + var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)"); + if (board.NumSpecialTypes > 0) + { + var specialSensor = + Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)"); + return new ISensor[] { cellSensor, specialSensor }; + } + else + { + return new ISensor[] { cellSensor }; + } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs index b8b2ac8017..c62ea60b70 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs @@ -16,9 +16,9 @@ public class ArticulationBodySensorComponent : SensorComponent /// Creates a PhysicsBodySensor. /// /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { - return new PhysicsBodySensor(RootBody, Settings, sensorName); + return new ISensor[] {new PhysicsBodySensor(RootBody, Settings, sensorName)}; } } diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs index 8224139d5e..43e66f0cac 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs @@ -267,9 +267,9 @@ public enum GridDepthType { Channel, ChannelHot }; private Color DebugDefaultColor = new Color(1f, 1f, 1f, 0.25f); /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { - return this; + return new ISensor[] { this }; } /// diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs index 03da767b8b..e201125bff 100644 --- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs @@ -39,10 +39,10 @@ public class RigidBodySensorComponent : SensorComponent /// Creates a PhysicsBodySensor. /// /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName; - return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName); + return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName) }; } /// diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs index 80438f857e..72e20a85ae 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs @@ -13,6 +13,8 @@ public class Match3SensorTests // Whether the expected PNG data should be written to a file. // Only set this to true if the compressed observation format changes. private bool WritePNGDataToFile = false; + private const string k_CellObservationPng = "match3obs"; + private const string k_SpecialObservationPng = "match3obs_special"; [Test] public void TestVectorObservations() @@ -27,7 +29,7 @@ public void TestVectorObservations() var sensorComponent = gameObj.AddComponent(); sensorComponent.ObservationType = Match3ObservationType.Vector; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; var expectedShape = new InplaceArray(3 * 3 * 2); Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); @@ -60,18 +62,35 @@ public void TestVectorObservationsSpecial() var sensorComponent = gameObj.AddComponent(); sensorComponent.ObservationType = Match3ObservationType.Vector; - var sensor = sensorComponent.CreateSensor(); + var sensors = sensorComponent.CreateSensors(); + var cellSensor = sensors[0]; + var specialSensor = sensors[1]; - var expectedShape = new InplaceArray(3 * 3 * (2 + 3)); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); - var expectedObs = new float[] { - 1, 0, 1, 0, 0, /* (0, 0) */ 0, 1, 1, 0, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */ - 1, 0, 0, 0, 1, /* (0, 2) */ 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 1, 0, 0, /* (0, 0) */ - 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 0, 1, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */ - }; - SensorTestHelper.CompareObservation(sensor, expectedObs); + var expectedShape = new InplaceArray(3 * 3 * 2); + Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape); + + var expectedObs = new float[] + { + 1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */ + 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ + 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ + }; + SensorTestHelper.CompareObservation(cellSensor, expectedObs); + } + { + var expectedShape = new InplaceArray(3 * 3 * 3); + Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape); + + var expectedObs = new float[] + { + 1, 0, 0, /* (0) */ 1, 0, 0, /* (1) */ 1, 0, 0, /* (0) */ + 0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ + 1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */ + }; + SensorTestHelper.CompareObservation(specialSensor, expectedObs); + } } [Test] @@ -87,7 +106,7 @@ public void TestVisualObservations() var sensorComponent = gameObj.AddComponent(); sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; var expectedShape = new InplaceArray(3, 3, 2); Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); @@ -130,28 +149,54 @@ public void TestVisualObservationsSpecial() var sensorComponent = gameObj.AddComponent(); sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual; - var sensor = sensorComponent.CreateSensor(); + var sensors = sensorComponent.CreateSensors(); + var cellSensor = sensors[0]; + var specialSensor = sensors[1]; - var expectedShape = new InplaceArray(3, 3, 2 + 3); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + { + var expectedShape = new InplaceArray(3, 3, 2); + Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape); - Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType); + Assert.AreEqual(SensorCompressionType.None, cellSensor.GetCompressionSpec().SensorCompressionType); - var expectedObs = new float[] - { - 1, 0, 1, 0, 0, /* (0, 0) */ 0, 1, 1, 0, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */ - 1, 0, 0, 0, 1, /* (0, 2) */ 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 1, 0, 0, /* (0, 0) */ - 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 0, 1, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */ - }; - SensorTestHelper.CompareObservation(sensor, expectedObs); + var expectedObs = new float[] + { + 1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */ + 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ + 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */ + }; + SensorTestHelper.CompareObservation(cellSensor, expectedObs); - var expectedObs3D = new float[,,] + var expectedObs3D = new float[,,] + { + {{1, 0}, {0, 1}, {1, 0}}, + {{1, 0}, {1, 0}, {1, 0}}, + {{1, 0}, {1, 0}, {1, 0}}, + }; + SensorTestHelper.CompareObservation(cellSensor, expectedObs3D); + } { - {{1, 0, 1, 0, 0}, {0, 1, 1, 0, 0}, {1, 0, 1, 0, 0}}, - {{1, 0, 0, 0, 1}, {1, 0, 1, 0, 0}, {1, 0, 1, 0, 0}}, - {{1, 0, 1, 0, 0}, {1, 0, 0, 1, 0}, {1, 0, 1, 0, 0}}, - }; - SensorTestHelper.CompareObservation(sensor, expectedObs3D); + var expectedShape = new InplaceArray(3, 3, 3); + Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape); + + Assert.AreEqual(SensorCompressionType.None, specialSensor.GetCompressionSpec().SensorCompressionType); + + var expectedObs = new float[] + { + 1, 0, 0, /* (0) */ 1, 0, 0, /* (1) */ 1, 0, 0, /* (0) */ + 0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ + 1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */ + }; + SensorTestHelper.CompareObservation(specialSensor, expectedObs); + + var expectedObs3D = new float[,,] + { + {{1, 0, 0}, {1, 0, 0}, {1, 0, 0}}, + {{0, 0, 1}, {1, 0, 0}, {1, 0, 0}}, + {{1, 0, 0}, {0, 1, 0}, {1, 0, 0}}, + }; + SensorTestHelper.CompareObservation(specialSensor, expectedObs3D); + } } [Test] @@ -167,7 +212,7 @@ public void TestCompressedVisualObservations() var sensorComponent = gameObj.AddComponent(); sensorComponent.ObservationType = Match3ObservationType.CompressedVisual; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; var expectedShape = new InplaceArray(3, 3, 2); Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); @@ -178,10 +223,10 @@ public void TestCompressedVisualObservations() if (WritePNGDataToFile) { // Enable this if the format of the observation changes - SavePNGs(pngData, "match3obs"); + SavePNGs(pngData, k_CellObservationPng); } - var expectedPng = LoadPNGs("match3obs", 1); + var expectedPng = LoadPNGs(k_CellObservationPng, 1); Assert.AreEqual(expectedPng, pngData); } @@ -204,22 +249,30 @@ public void TestCompressedVisualObservationsSpecial() var sensorComponent = gameObj.AddComponent(); sensorComponent.ObservationType = Match3ObservationType.CompressedVisual; - var sensor = sensorComponent.CreateSensor(); - - var expectedShape = new InplaceArray(3, 3, 2 + 3); - Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + var sensors = sensorComponent.CreateSensors(); - Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType); + var paths = new[] { k_CellObservationPng, k_SpecialObservationPng }; + var expectedChannels = new[] { 2, 3 }; - var concatenatedPngData = sensor.GetCompressedObservation(); - var pathPrefix = "match3obs_special"; - if (WritePNGDataToFile) + for (var i = 0; i < 2; i++) { - // Enable this if the format of the observation changes - SavePNGs(concatenatedPngData, pathPrefix); + var sensor = sensors[i]; + var expectedShape = new InplaceArray(3, 3, expectedChannels[i]); + Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); + + Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType); + + var pngData = sensor.GetCompressedObservation(); + if (WritePNGDataToFile) + { + // Enable this if the format of the observation changes + SavePNGs(pngData, paths[i]); + } + + var expectedPng = LoadPNGs(paths[i], 1); + Assert.AreEqual(expectedPng, pngData); } - var expectedPng = LoadPNGs(pathPrefix, 2); - Assert.AreEqual(expectedPng, concatenatedPngData); + } /// diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta index 16bf958b40..227ce7050f 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta @@ -1,9 +1,9 @@ fileFormatVersion: 2 guid: 3e1767bf6c63e46b1a16404dc1afe508 TextureImporter: - fileIDToRecycleName: {} + internalIDToNameTable: [] externalObjects: {} - serializedVersion: 9 + serializedVersion: 11 mipmaps: mipMapMode: 0 enableMipMap: 1 @@ -57,8 +57,9 @@ TextureImporter: maxTextureSizeSet: 0 compressionQualitySet: 0 textureFormatSet: 0 + applyGammaDecoding: 0 platformSettings: - - serializedVersion: 2 + - serializedVersion: 3 buildTarget: DefaultTexturePlatform maxTextureSize: 2048 resizeAlgorithm: 0 @@ -69,6 +70,7 @@ TextureImporter: allowsAlphaSplitting: 0 overridden: 0 androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 spriteSheet: serializedVersion: 2 sprites: [] @@ -76,10 +78,12 @@ TextureImporter: physicsShape: [] bones: [] spriteID: + internalID: 0 vertices: [] indices: edges: [] weights: [] + secondaryTextures: [] spritePackingTag: pSDRemoveMatte: 0 pSDShowRemoveMatteOption: 0 diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png index 3495ebb395..cb3869d6b8 100644 Binary files a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png and b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png differ diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta index 92f0f9a1cc..302444ad28 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta @@ -1,9 +1,9 @@ fileFormatVersion: 2 guid: 2e4ca31cf9cff4505acbefe44b621d6f TextureImporter: - fileIDToRecycleName: {} + internalIDToNameTable: [] externalObjects: {} - serializedVersion: 9 + serializedVersion: 11 mipmaps: mipMapMode: 0 enableMipMap: 1 @@ -57,8 +57,9 @@ TextureImporter: maxTextureSizeSet: 0 compressionQualitySet: 0 textureFormatSet: 0 + applyGammaDecoding: 0 platformSettings: - - serializedVersion: 2 + - serializedVersion: 3 buildTarget: DefaultTexturePlatform maxTextureSize: 2048 resizeAlgorithm: 0 @@ -69,6 +70,7 @@ TextureImporter: allowsAlphaSplitting: 0 overridden: 0 androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 spriteSheet: serializedVersion: 2 sprites: [] @@ -76,10 +78,12 @@ TextureImporter: physicsShape: [] bones: [] spriteID: + internalID: 0 vertices: [] indices: edges: [] weights: [] + secondaryTextures: [] spritePackingTag: pSDRemoveMatte: 0 pSDShowRemoveMatteOption: 0 diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png deleted file mode 100644 index cb3869d6b8..0000000000 Binary files a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png and /dev/null differ diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png.meta deleted file mode 100644 index 4f4bc4de91..0000000000 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png.meta +++ /dev/null @@ -1,88 +0,0 @@ -fileFormatVersion: 2 -guid: fceb584222ca149cd984dad123c8ae25 -TextureImporter: - fileIDToRecycleName: {} - externalObjects: {} - serializedVersion: 9 - mipmaps: - mipMapMode: 0 - enableMipMap: 1 - sRGBTexture: 1 - linearTexture: 0 - fadeOut: 0 - borderMipMap: 0 - mipMapsPreserveCoverage: 0 - alphaTestReferenceValue: 0.5 - mipMapFadeDistanceStart: 1 - mipMapFadeDistanceEnd: 3 - bumpmap: - convertToNormalMap: 0 - externalNormalMap: 0 - heightScale: 0.25 - normalMapFilter: 0 - isReadable: 0 - streamingMipmaps: 0 - streamingMipmapsPriority: 0 - grayScaleToAlpha: 0 - generateCubemap: 6 - cubemapConvolution: 0 - seamlessCubemap: 0 - textureFormat: 1 - maxTextureSize: 2048 - textureSettings: - serializedVersion: 2 - filterMode: -1 - aniso: -1 - mipBias: -100 - wrapU: -1 - wrapV: -1 - wrapW: -1 - nPOTScale: 1 - lightmap: 0 - compressionQuality: 50 - spriteMode: 0 - spriteExtrude: 1 - spriteMeshType: 1 - alignment: 0 - spritePivot: {x: 0.5, y: 0.5} - spritePixelsToUnits: 100 - spriteBorder: {x: 0, y: 0, z: 0, w: 0} - spriteGenerateFallbackPhysicsShape: 1 - alphaUsage: 1 - alphaIsTransparency: 0 - spriteTessellationDetail: -1 - textureType: 0 - textureShape: 1 - singleChannelComponent: 0 - maxTextureSizeSet: 0 - compressionQualitySet: 0 - textureFormatSet: 0 - platformSettings: - - serializedVersion: 2 - buildTarget: DefaultTexturePlatform - maxTextureSize: 2048 - resizeAlgorithm: 0 - textureFormat: -1 - textureCompression: 1 - compressionQuality: 50 - crunchedCompression: 0 - allowsAlphaSplitting: 0 - overridden: 0 - androidETC2FallbackOverride: 0 - spriteSheet: - serializedVersion: 2 - sprites: [] - outline: [] - physicsShape: [] - bones: [] - spriteID: - vertices: [] - indices: - edges: [] - weights: [] - spritePackingTag: - pSDRemoveMatte: 0 - pSDShowRemoveMatteOption: 0 - userData: - assetBundleName: - assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs index c88c494b0a..03b429595b 100644 --- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs @@ -15,7 +15,7 @@ public void TestNullRootBody() var gameObj = new GameObject(); var sensorComponent = gameObj.AddComponent(); - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; SensorTestHelper.CompareObservation(sensor, new float[0]); } @@ -33,7 +33,7 @@ public void TestSingleBody() UseLocalSpaceRotations = true }; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; sensor.Update(); var expected = new[] { @@ -42,7 +42,7 @@ public void TestSingleBody() 0f, 0f, 0f, 1f // LocalSpaceRotations }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); + Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]); } [Test] @@ -89,7 +89,7 @@ public void TestBodiesWithJoint() #endif }; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; sensor.Update(); var expected = new[] { @@ -110,7 +110,7 @@ public void TestBodiesWithJoint() #endif }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); + Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]); // Update the settings to only process joint observations sensorComponent.Settings = new PhysicsSensorSettings @@ -119,7 +119,7 @@ public void TestBodiesWithJoint() UseJointPositionsAndAngles = true, }; - sensor = sensorComponent.CreateSensor(); + sensor = sensorComponent.CreateSensors()[0]; sensor.Update(); expected = new[] @@ -133,7 +133,7 @@ public void TestBodiesWithJoint() 0f, // joint2.force }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); + Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]); } } } diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs index 3bf956c210..68300be98f 100644 --- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs @@ -32,7 +32,7 @@ public void TestNullRootBody() var gameObj = new GameObject(); var sensorComponent = gameObj.AddComponent(); - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; SensorTestHelper.CompareObservation(sensor, new float[0]); } @@ -50,13 +50,13 @@ public void TestSingleRigidbody() UseLocalSpaceRotations = true }; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; sensor.Update(); // The root body is ignored since it always generates identity values // and there are no other bodies to generate observations. var expected = new float[0]; - Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); + Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]); SensorTestHelper.CompareObservation(sensor, expected); } @@ -95,7 +95,7 @@ public void TestBodiesWithJoint() }; sensorComponent.VirtualRoot = virtualRoot; - var sensor = sensorComponent.CreateSensor(); + var sensor = sensorComponent.CreateSensors()[0]; sensor.Update(); // Note that the VirtualRoot is ignored from the observations @@ -115,7 +115,7 @@ public void TestBodiesWithJoint() -1f, 1f, 0f, // Attached vel 0f, -1f, 1f // Leaf vel }; - Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); + Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]); SensorTestHelper.CompareObservation(sensor, expected); // Update the settings to only process joint observations @@ -125,7 +125,7 @@ public void TestBodiesWithJoint() UseJointForces = true, }; - sensor = sensorComponent.CreateSensor(); + sensor = sensorComponent.CreateSensors()[0]; sensor.Update(); expected = new[] @@ -136,7 +136,7 @@ public void TestBodiesWithJoint() 0f, 0f, 0f, // joint2.torque }; SensorTestHelper.CompareObservation(sensor, expected); - Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]); + Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]); } } diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 7a22e50c63..3a8cdaac99 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to ## [Unreleased] ### Major Changes -#### com.unity.ml-agents (C#) +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) - The minimum supported Unity version was updated to 2019.4. (#5166) - Several breaking interface changes were made. See the [Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more @@ -23,6 +23,10 @@ and `IDimensionPropertiesSensor` interfaces were removed. (#5127) - `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor` interface was removed. (#5164) - The abstract method `SensorComponent.GetObservationShape()` was no longer being called, so it has been removed. (#5172) +- `SensorComponent.CreateSensor()` was replaced with `SensorComponent.CreateSensor()`, which returns an `ISensor[]`. (#5181) +- `Match3Sensor` was refactored to produce cell and special type observations separately, and `Match3SensorComponent` now +produces two `Match3Sensor`s (unless there are no special types). Previously trained models will have different observation +sizes and will need to be retrained. (#5181) #### ml-agents / ml-agents-envs / gym-unity (Python) @@ -39,7 +43,7 @@ depend on the previous behavior, you can explicitly set the Agent's `InferenceDe #### ml-agents / ml-agents-envs / gym-unity (Python) ### Bug Fixes -#### com.unity.ml-agents (C#) +#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 580e1571c3..985725d0d5 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -968,7 +968,7 @@ internal void InitializeSensors() sensors.Capacity += attachedSensorComponents.Length; foreach (var component in attachedSensorComponents) { - sensors.Add(component.CreateSensor()); + sensors.AddRange(component.CreateSensors()); } // Support legacy CollectObservations diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs index 2bf357b47e..b19903a973 100644 --- a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs @@ -49,10 +49,10 @@ public int MaxNumObservables private BufferSensor m_Sensor; /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { m_Sensor = new BufferSensor(MaxNumObservables, ObservableSize, m_SensorName); - return m_Sensor; + return new ISensor[] { m_Sensor }; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs index 0c677ee585..41582d35c6 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs @@ -106,15 +106,15 @@ public int ObservationStacks /// Creates the /// /// The created object for this component. - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression); if (ObservationStacks != 1) { - return new StackingSensor(m_Sensor, ObservationStacks); + return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) }; } - return m_Sensor; + return new ISensor[] { m_Sensor }; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs index 47cfb5cfb7..50b9fd61a2 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs @@ -181,7 +181,7 @@ public virtual float GetEndVerticalOffset() /// Returns an initialized raycast sensor. /// /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { var rayPerceptionInput = GetRayPerceptionInput(); @@ -190,10 +190,10 @@ public override ISensor CreateSensor() if (ObservationStacks != 1) { var stackingSensor = new StackingSensor(m_RaySensor, ObservationStacks); - return stackingSensor; + return new ISensor[] { stackingSensor }; } - return m_RaySensor; + return new ISensor[] { m_RaySensor }; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs index 32c6bec07d..edcacf4ee4 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs @@ -82,14 +82,14 @@ public int ObservationStacks } /// - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression); if (ObservationStacks != 1) { - return new StackingSensor(m_Sensor, ObservationStacks); + return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) }; } - return m_Sensor; + return new ISensor[] { m_Sensor }; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs index b270cdd646..38781eec7b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs @@ -10,9 +10,9 @@ namespace Unity.MLAgents.Sensors public abstract class SensorComponent : MonoBehaviour { /// - /// Create the ISensor. This is called by the Agent when it is initialized. + /// Create the ISensors. This is called by the Agent when it is initialized. /// - /// Created ISensor object. - public abstract ISensor CreateSensor(); + /// Created ISensor objects. + public abstract ISensor[] CreateSensors(); } } diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs index b3f4801c56..1c65793ecd 100644 --- a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs @@ -100,10 +100,10 @@ public void TestRunModel() var modelRunner = new ModelRunner(discreteONNXModel, actionSpec, InferenceDevice.Burst); var info1 = new AgentInfo(); info1.episodeId = 1; - modelRunner.PutObservations(info1, new[] { sensor_21_20_3.CreateSensor() }.ToList()); + modelRunner.PutObservations(info1, new[] { sensor_21_20_3.CreateSensors()[0] }.ToList()); var info2 = new AgentInfo(); info2.episodeId = 2; - modelRunner.PutObservations(info2, new[] { sensor_21_20_3.CreateSensor() }.ToList()); + modelRunner.PutObservations(info2, new[] { sensor_21_20_3.CreateSensors()[0] }.ToList()); modelRunner.DecideBatch(); diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs index 059f1a0a13..f6c8f3124f 100644 --- a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs @@ -14,9 +14,9 @@ public class Test3DSensorComponent : SensorComponent { public ISensor Sensor; - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { - return Sensor; + return new ISensor[] { Sensor }; } } @@ -286,7 +286,13 @@ public void TestCheckModelValidContinuous(bool useDeprecatedNNModel) var errors = BarracudaModelParamLoader.CheckModel( model, validBrainParameters, - new ISensor[] { new VectorSensor(8), sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0] + new ISensor[] + { + new VectorSensor(8), + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] ); Assert.AreEqual(0, errors.Count()); // There should not be any errors } @@ -300,7 +306,7 @@ public void TestCheckModelValidDiscrete(bool useDeprecatedNNModel) var errors = BarracudaModelParamLoader.CheckModel( model, validBrainParameters, - new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0] + new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, new ActuatorComponent[0] ); Assert.AreEqual(0, errors.Count()); // There should not be any errors } @@ -313,7 +319,10 @@ public void TestCheckModelValidHybrid() var errors = BarracudaModelParamLoader.CheckModel( model, validBrainParameters, - new ISensor[] { new VectorSensor(validBrainParameters.VectorObservationSize) }, new ActuatorComponent[0] + new ISensor[] + { + new VectorSensor(validBrainParameters.VectorObservationSize) + }, new ActuatorComponent[0] ); Assert.AreEqual(0, errors.Count()); // There should not be any errors } @@ -328,7 +337,12 @@ public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNN brainParameters.VectorObservationSize = 9; // Invalid observation var errors = BarracudaModelParamLoader.CheckModel( model, brainParameters, - new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0] + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] ); Assert.Greater(errors.Count(), 0); @@ -336,7 +350,12 @@ public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNN brainParameters.NumStackedVectorObservations = 2;// Invalid stacking errors = BarracudaModelParamLoader.CheckModel( model, brainParameters, - new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0] + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] ); Assert.Greater(errors.Count(), 0); } @@ -349,7 +368,13 @@ public void TestCheckModelThrowsVectorObservationDiscrete(bool useDeprecatedNNMo var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters(); brainParameters.VectorObservationSize = 1; // Invalid observation - var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]); + var errors = BarracudaModelParamLoader.CheckModel( + model, brainParameters, new ISensor[] + { + sensor_21_20_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); } @@ -383,12 +408,26 @@ public void TestCheckModelThrowsActionContinuous(bool useDeprecatedNNModel) var brainParameters = GetContinuous2vis8vec2actionBrainParameters(); brainParameters.ActionSpec = ActionSpec.MakeContinuous(3); // Invalid action - var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]); + var errors = BarracudaModelParamLoader.CheckModel( + model, brainParameters, new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); brainParameters = GetContinuous2vis8vec2actionBrainParameters(); brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3); // Invalid SpaceType - errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]); + errors = BarracudaModelParamLoader.CheckModel( + model, brainParameters, new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); } @@ -400,12 +439,21 @@ public void TestCheckModelThrowsActionDiscrete(bool useDeprecatedNNModel) var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters(); brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3, 3); // Invalid action - var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]); + var errors = BarracudaModelParamLoader.CheckModel( + model, brainParameters, + new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); brainParameters = GetContinuous2vis8vec2actionBrainParameters(); brainParameters.ActionSpec = ActionSpec.MakeContinuous(2); // Invalid SpaceType - errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]); + errors = BarracudaModelParamLoader.CheckModel( + model, + brainParameters, + new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); } @@ -416,12 +464,30 @@ public void TestCheckModelThrowsActionHybrid() var brainParameters = GetHybridBrainParameters(); brainParameters.ActionSpec = new ActionSpec(3, new[] { 3 }); // Invalid discrete action size - var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]); + var errors = BarracudaModelParamLoader.CheckModel( + model, + brainParameters, + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); brainParameters = GetContinuous2vis8vec2actionBrainParameters(); brainParameters.ActionSpec = ActionSpec.MakeDiscrete(2); // Missing continuous action - errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]); + errors = BarracudaModelParamLoader.CheckModel( + model, + brainParameters, + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); } @@ -429,7 +495,16 @@ public void TestCheckModelThrowsActionHybrid() public void TestCheckModelThrowsNoModel() { var brainParameters = GetContinuous2vis8vec2actionBrainParameters(); - var errors = BarracudaModelParamLoader.CheckModel(null, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]); + var errors = BarracudaModelParamLoader.CheckModel( + null, + brainParameters, + new ISensor[] + { + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }, + new ActuatorComponent[0] + ); Assert.Greater(errors.Count(), 0); } } diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs index df4a3cc961..563cac341d 100644 --- a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs @@ -68,7 +68,7 @@ public void CheckSetupRayPerceptionSensorComponent() sensorComponent.RayLayerMask = 0; sensorComponent.ObservationStacks = 2; - sensorComponent.CreateSensor(); + sensorComponent.CreateSensors(); } #endif } diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs index edd16ee04a..89c790c126 100644 --- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -31,10 +31,16 @@ public class StackingComponent : SensorComponent public SensorComponent wrappedComponent; public int numStacks; - public override ISensor CreateSensor() + public override ISensor[] CreateSensors() { - var wrappedSensor = wrappedComponent.CreateSensor(); - return new StackingSensor(wrappedSensor, numStacks); + var wrappedSensors = wrappedComponent.CreateSensors(); + var sensorsOut = new ISensor[wrappedSensors.Length]; + for (var i = 0; i < wrappedSensors.Length; i++) + { + sensorsOut[i] = new StackingSensor(wrappedSensors[i], numStacks); + } + + return sensorsOut; } } diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs index c17f6cecb5..622c238695 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs @@ -55,7 +55,7 @@ public void TestBufferSensorComponent() bufferComponent.ObservableSize = 4; bufferComponent.SensorName = "TestName"; - var sensor = bufferComponent.CreateSensor(); + var sensor = bufferComponent.CreateSensors()[0]; var shape = sensor.GetObservationSpec().Shape; Assert.AreEqual(shape[0], 20); diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs index 5bb2c74fe1..1ed056fd21 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs @@ -29,7 +29,7 @@ public void TestCameraSensorComponent() cameraComponent.Grayscale = grayscale; cameraComponent.CompressionType = compression; - var sensor = cameraComponent.CreateSensor(); + var sensor = cameraComponent.CreateSensors()[0]; var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3); Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(CameraSensor), sensor.GetType()); diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs index 751ed433ac..4cfa65a1e4 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs @@ -106,7 +106,7 @@ public void TestRaycasts() foreach (var castRadius in radii) { perception.SphereCastRadius = castRadius; - var sensor = perception.CreateSensor(); + var sensor = perception.CreateSensors()[0]; var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); @@ -165,7 +165,7 @@ public void TestRaycastMiss() perception.DetectableTags.Add(k_CubeTag); perception.DetectableTags.Add(k_SphereTag); - var sensor = perception.CreateSensor(); + var sensor = perception.CreateSensors()[0]; var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; @@ -213,7 +213,7 @@ public void TestRayFilter() } perception.RayLayerMask = layerMask; - var sensor = perception.CreateSensor(); + var sensor = perception.CreateSensors()[0]; var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; @@ -259,7 +259,7 @@ public void TestRaycastsScaled() foreach (var castRadius in radii) { perception.SphereCastRadius = castRadius; - var sensor = perception.CreateSensor(); + var sensor = perception.CreateSensors()[0]; var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); @@ -308,7 +308,7 @@ public void TestRayZeroLength() { // Set the layer mask to either the default, or one that ignores the close cube's layer - var sensor = perception.CreateSensor(); + var sensor = perception.CreateSensors()[0]; var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2); Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs); var outputBuffer = new float[expectedObs]; diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs index d21e544dd7..c4dcc93ef5 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs @@ -28,7 +28,7 @@ public void TestRenderTextureSensorComponent() var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3); - var sensor = renderTexComponent.CreateSensor(); + var sensor = renderTexComponent.CreateSensors()[0]; Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape); Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType()); } diff --git a/docs/Migrating.md b/docs/Migrating.md index 9971e96d6b..c386dbc723 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -91,6 +91,7 @@ public CompressionSpec GetCompressionSpec() ``` - The abstract method `SensorComponent.GetObservationShape()` was removed. +- The abstract method `SensorComponent.CreateSensor()` was replaced with `CreateSensors()`, which returns an `ISensor[]`. ## Migrating to Release 13 ### Implementing IHeuristic in your IActuator implementations