diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
index cddbc27439..f956d747ad 100644
--- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
+++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
@@ -15,9 +15,9 @@ public class BasicSensorComponent : SensorComponent
/// Creates a BasicSensor.
///
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
- return new BasicSensor(basicController);
+ return new ISensor[] { new BasicSensor(basicController) };
}
}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx
index 7d32504590..011c232473 100644
Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx differ
diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx
index d8ec297d15..c20ba1c999 100644
Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.onnx differ
diff --git a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs
index 504e0c3d19..4a12657505 100644
--- a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs
+++ b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs
@@ -23,14 +23,14 @@ public string SensorName
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
m_Sensor = new TestTextureSensor(TestTexture, SensorName, CompressionType);
if (ObservationStacks != 1)
{
- return new StackingSensor(m_Sensor, ObservationStacks);
+ return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) };
}
- return m_Sensor;
+ return new ISensor[] { m_Sensor };
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
index 647e56fddb..8e5e81557a 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
@@ -4,6 +4,14 @@
namespace Unity.MLAgents.Extensions.Match3
{
+
+ ///
+ /// Delegate that provides integer values at a given (x,y) coordinate.
+ ///
+ ///
+ ///
+ public delegate int GridValueProvider(int x, int y);
+
///
/// Type of observations to generate.
///
@@ -32,66 +40,68 @@ public enum Match3ObservationType
///
/// Sensor for Match3 games. Can generate either vector, compressed visual,
- /// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
- /// and AbstractBoard.GetSpecialType() to determine the observation values.
+ /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
///
public class Match3Sensor : ISensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
- private AbstractBoard m_Board;
private ObservationSpec m_ObservationSpec;
- private int[] m_SparseChannelMapping;
private string m_Name;
private int m_Rows;
private int m_Columns;
- private int m_NumCellTypes;
- private int m_NumSpecialTypes;
-
- private int SpecialTypeSize
- {
- get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; }
- }
+ private GridValueProvider m_GridValues;
+ private int m_OneHotSize;
///
- /// Create a sensor for the board with the specified observation type.
+ /// Create a sensor for the GridValueProvider with the specified observation type.
///
- ///
- ///
- ///
- public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string name)
+ ///
+ /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
+ /// the constructor directly.
+ ///
+ /// The abstract board. This is only used to get the size.
+ /// The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.
+ /// The number of possible values that the GridValueProvider can return.
+ /// Whether to produce vector or visual observations
+ /// Name of the sensor.
+ public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
{
- m_Board = board;
m_Name = name;
m_Rows = board.Rows;
m_Columns = board.Columns;
- m_NumCellTypes = board.NumCellTypes;
- m_NumSpecialTypes = board.NumSpecialTypes;
+ m_GridValues = gvp;
+ m_OneHotSize = oneHotSize;
m_ObservationType = obsType;
m_ObservationSpec = obsType == Match3ObservationType.Vector
- ? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
- : ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
-
- // See comment in GetCompressedObservation()
- var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
- m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize];
- // If we have 4 cell types and 2 special types (3 special size), we'd have
- // [0, 1, 2, 3, -1, -1, 4, 5, 6]
- for (var i = 0; i < m_NumCellTypes; i++)
- {
- m_SparseChannelMapping[i] = i;
- }
+ ? ObservationSpec.Vector(m_Rows * m_Columns * oneHotSize)
+ : ObservationSpec.Visual(m_Rows, m_Columns, oneHotSize);
+ }
- for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++)
- {
- m_SparseChannelMapping[i] = -1;
- }
+ ///
+ /// Create a sensor that encodes the board cells as observations.
+ ///
+ /// The abstract board.
+ /// Whether to produce vector or visual observations
+ /// Name of the sensor.
+ ///
+ public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
+ {
+ return new Match3Sensor(board, board.GetCellType, board.NumCellTypes, obsType, name);
+ }
- for (var i = 0; i < SpecialTypeSize; i++)
- {
- m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes;
- }
+ ///
+ /// Create a sensor that encodes the cell special types as observations.
+ ///
+ /// The abstract board.
+ /// Whether to produce vector or visual observations
+ /// Name of the sensor.
+ ///
+ public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
+ {
+ var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
+ return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
}
///
@@ -103,14 +113,14 @@ public ObservationSpec GetObservationSpec()
///
public int Write(ObservationWriter writer)
{
- if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
- {
- Debug.LogWarning(
- $"Board shape changes since sensor initialization. This may cause unexpected results. " +
- $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
- $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
- );
- }
+ // if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
+ // {
+ // Debug.LogWarning(
+ // $"Board shape changes since sensor initialization. This may cause unexpected results. " +
+ // $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
+ // $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
+ // );
+ // }
if (m_ObservationType == Match3ObservationType.Vector)
{
@@ -119,22 +129,13 @@ public int Write(ObservationWriter writer)
{
for (var c = 0; c < m_Columns; c++)
{
- var val = m_Board.GetCellType(r, c);
- for (var i = 0; i < m_NumCellTypes; i++)
+ var val = m_GridValues(r, c);
+
+ for (var i = 0; i < m_OneHotSize; i++)
{
writer[offset] = (i == val) ? 1.0f : 0.0f;
offset++;
}
-
- if (m_NumSpecialTypes > 0)
- {
- var special = m_Board.GetSpecialType(r, c);
- for (var i = 0; i < SpecialTypeSize; i++)
- {
- writer[offset] = (i == special) ? 1.0f : 0.0f;
- offset++;
- }
- }
}
}
@@ -148,22 +149,12 @@ public int Write(ObservationWriter writer)
{
for (var c = 0; c < m_Columns; c++)
{
- var val = m_Board.GetCellType(r, c);
- for (var i = 0; i < m_NumCellTypes; i++)
+ var val = m_GridValues(r, c);
+ for (var i = 0; i < m_OneHotSize; i++)
{
writer[r, c, i] = (i == val) ? 1.0f : 0.0f;
offset++;
}
-
- if (m_NumSpecialTypes > 0)
- {
- var special = m_Board.GetSpecialType(r, c);
- for (var i = 0; i < SpecialTypeSize; i++)
- {
- writer[offset] = (i == special) ? 1.0f : 0.0f;
- offset++;
- }
- }
}
}
@@ -185,17 +176,10 @@ public byte[] GetCompressedObservation()
// fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for
// the special types). Note that we have to also implement the sparse channel mapping.
// Optimize this it later.
- var numCellImages = (m_NumCellTypes + 2) / 3;
+ var numCellImages = (m_OneHotSize + 2) / 3;
for (var i = 0; i < numCellImages; i++)
{
- converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i);
- bytesOut.AddRange(tempTexture.EncodeToPNG());
- }
-
- var numSpecialImages = (SpecialTypeSize + 2) / 3;
- for (var i = 0; i < numSpecialImages; i++)
- {
- converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i);
+ converter.EncodeToTexture(m_GridValues, tempTexture, 3 * i);
bytesOut.AddRange(tempTexture.EncodeToPNG());
}
@@ -223,7 +207,7 @@ internal SensorCompressionType GetCompressionType()
///
public CompressionSpec GetCompressionSpec()
{
- return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping);
+ return new CompressionSpec(GetCompressionType());
}
///
@@ -265,9 +249,6 @@ internal class OneHotToTextureUtil
int m_Width;
private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };
- public delegate int GridValueProvider(int x, int y);
-
-
public OneHotToTextureUtil(int height, int width)
{
m_Colors = new Color[height * width];
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
index 4dbc1303c2..9007872e23 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
@@ -21,10 +21,20 @@ public class Match3SensorComponent : SensorComponent
public Match3ObservationType ObservationType = Match3ObservationType.Vector;
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
var board = GetComponent();
- return new Match3Sensor(board, ObservationType, SensorName);
+ var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
+ if (board.NumSpecialTypes > 0)
+ {
+ var specialSensor =
+ Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
+ return new ISensor[] { cellSensor, specialSensor };
+ }
+ else
+ {
+ return new ISensor[] { cellSensor };
+ }
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
index b8b2ac8017..c62ea60b70 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
@@ -16,9 +16,9 @@ public class ArticulationBodySensorComponent : SensorComponent
/// Creates a PhysicsBodySensor.
///
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
- return new PhysicsBodySensor(RootBody, Settings, sensorName);
+ return new ISensor[] {new PhysicsBodySensor(RootBody, Settings, sensorName)};
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
index 8224139d5e..43e66f0cac 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
@@ -267,9 +267,9 @@ public enum GridDepthType { Channel, ChannelHot };
private Color DebugDefaultColor = new Color(1f, 1f, 1f, 0.25f);
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
- return this;
+ return new ISensor[] { this };
}
///
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
index 03da767b8b..e201125bff 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
@@ -39,10 +39,10 @@ public class RigidBodySensorComponent : SensorComponent
/// Creates a PhysicsBodySensor.
///
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
- return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
+ return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName) };
}
///
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
index 80438f857e..72e20a85ae 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
@@ -13,6 +13,8 @@ public class Match3SensorTests
// Whether the expected PNG data should be written to a file.
// Only set this to true if the compressed observation format changes.
private bool WritePNGDataToFile = false;
+ private const string k_CellObservationPng = "match3obs";
+ private const string k_SpecialObservationPng = "match3obs_special";
[Test]
public void TestVectorObservations()
@@ -27,7 +29,7 @@ public void TestVectorObservations()
var sensorComponent = gameObj.AddComponent();
sensorComponent.ObservationType = Match3ObservationType.Vector;
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
var expectedShape = new InplaceArray(3 * 3 * 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
@@ -60,18 +62,35 @@ public void TestVectorObservationsSpecial()
var sensorComponent = gameObj.AddComponent();
sensorComponent.ObservationType = Match3ObservationType.Vector;
- var sensor = sensorComponent.CreateSensor();
+ var sensors = sensorComponent.CreateSensors();
+ var cellSensor = sensors[0];
+ var specialSensor = sensors[1];
- var expectedShape = new InplaceArray(3 * 3 * (2 + 3));
- Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
- var expectedObs = new float[]
{
- 1, 0, 1, 0, 0, /* (0, 0) */ 0, 1, 1, 0, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
- 1, 0, 0, 0, 1, /* (0, 2) */ 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 1, 0, 0, /* (0, 0) */
- 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 0, 1, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
- };
- SensorTestHelper.CompareObservation(sensor, expectedObs);
+ var expectedShape = new InplaceArray(3 * 3 * 2);
+ Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape);
+
+ var expectedObs = new float[]
+ {
+ 1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
+ 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+ 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+ };
+ SensorTestHelper.CompareObservation(cellSensor, expectedObs);
+ }
+ {
+ var expectedShape = new InplaceArray(3 * 3 * 3);
+ Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape);
+
+ var expectedObs = new float[]
+ {
+ 1, 0, 0, /* (0) */ 1, 0, 0, /* (1) */ 1, 0, 0, /* (0) */
+ 0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
+ 1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */
+ };
+ SensorTestHelper.CompareObservation(specialSensor, expectedObs);
+ }
}
[Test]
@@ -87,7 +106,7 @@ public void TestVisualObservations()
var sensorComponent = gameObj.AddComponent();
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
var expectedShape = new InplaceArray(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
@@ -130,28 +149,54 @@ public void TestVisualObservationsSpecial()
var sensorComponent = gameObj.AddComponent();
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
- var sensor = sensorComponent.CreateSensor();
+ var sensors = sensorComponent.CreateSensors();
+ var cellSensor = sensors[0];
+ var specialSensor = sensors[1];
- var expectedShape = new InplaceArray(3, 3, 2 + 3);
- Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+ {
+ var expectedShape = new InplaceArray(3, 3, 2);
+ Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape);
- Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);
+ Assert.AreEqual(SensorCompressionType.None, cellSensor.GetCompressionSpec().SensorCompressionType);
- var expectedObs = new float[]
- {
- 1, 0, 1, 0, 0, /* (0, 0) */ 0, 1, 1, 0, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
- 1, 0, 0, 0, 1, /* (0, 2) */ 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 1, 0, 0, /* (0, 0) */
- 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 0, 1, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
- };
- SensorTestHelper.CompareObservation(sensor, expectedObs);
+ var expectedObs = new float[]
+ {
+ 1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
+ 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+ 1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
+ };
+ SensorTestHelper.CompareObservation(cellSensor, expectedObs);
- var expectedObs3D = new float[,,]
+ var expectedObs3D = new float[,,]
+ {
+ {{1, 0}, {0, 1}, {1, 0}},
+ {{1, 0}, {1, 0}, {1, 0}},
+ {{1, 0}, {1, 0}, {1, 0}},
+ };
+ SensorTestHelper.CompareObservation(cellSensor, expectedObs3D);
+ }
{
- {{1, 0, 1, 0, 0}, {0, 1, 1, 0, 0}, {1, 0, 1, 0, 0}},
- {{1, 0, 0, 0, 1}, {1, 0, 1, 0, 0}, {1, 0, 1, 0, 0}},
- {{1, 0, 1, 0, 0}, {1, 0, 0, 1, 0}, {1, 0, 1, 0, 0}},
- };
- SensorTestHelper.CompareObservation(sensor, expectedObs3D);
+ var expectedShape = new InplaceArray(3, 3, 3);
+ Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape);
+
+ Assert.AreEqual(SensorCompressionType.None, specialSensor.GetCompressionSpec().SensorCompressionType);
+
+ var expectedObs = new float[]
+ {
+ 1, 0, 0, /* (0) */ 1, 0, 0, /* (1) */ 1, 0, 0, /* (0) */
+ 0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
+ 1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */
+ };
+ SensorTestHelper.CompareObservation(specialSensor, expectedObs);
+
+ var expectedObs3D = new float[,,]
+ {
+ {{1, 0, 0}, {1, 0, 0}, {1, 0, 0}},
+ {{0, 0, 1}, {1, 0, 0}, {1, 0, 0}},
+ {{1, 0, 0}, {0, 1, 0}, {1, 0, 0}},
+ };
+ SensorTestHelper.CompareObservation(specialSensor, expectedObs3D);
+ }
}
[Test]
@@ -167,7 +212,7 @@ public void TestCompressedVisualObservations()
var sensorComponent = gameObj.AddComponent();
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
var expectedShape = new InplaceArray(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
@@ -178,10 +223,10 @@ public void TestCompressedVisualObservations()
if (WritePNGDataToFile)
{
// Enable this if the format of the observation changes
- SavePNGs(pngData, "match3obs");
+ SavePNGs(pngData, k_CellObservationPng);
}
- var expectedPng = LoadPNGs("match3obs", 1);
+ var expectedPng = LoadPNGs(k_CellObservationPng, 1);
Assert.AreEqual(expectedPng, pngData);
}
@@ -204,22 +249,30 @@ public void TestCompressedVisualObservationsSpecial()
var sensorComponent = gameObj.AddComponent();
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
- var sensor = sensorComponent.CreateSensor();
-
- var expectedShape = new InplaceArray(3, 3, 2 + 3);
- Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+ var sensors = sensorComponent.CreateSensors();
- Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);
+ var paths = new[] { k_CellObservationPng, k_SpecialObservationPng };
+ var expectedChannels = new[] { 2, 3 };
- var concatenatedPngData = sensor.GetCompressedObservation();
- var pathPrefix = "match3obs_special";
- if (WritePNGDataToFile)
+ for (var i = 0; i < 2; i++)
{
- // Enable this if the format of the observation changes
- SavePNGs(concatenatedPngData, pathPrefix);
+ var sensor = sensors[i];
+ var expectedShape = new InplaceArray(3, 3, expectedChannels[i]);
+ Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
+
+ Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);
+
+ var pngData = sensor.GetCompressedObservation();
+ if (WritePNGDataToFile)
+ {
+ // Enable this if the format of the observation changes
+ SavePNGs(pngData, paths[i]);
+ }
+
+ var expectedPng = LoadPNGs(paths[i], 1);
+ Assert.AreEqual(expectedPng, pngData);
}
- var expectedPng = LoadPNGs(pathPrefix, 2);
- Assert.AreEqual(expectedPng, concatenatedPngData);
+
}
///
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta
index 16bf958b40..227ce7050f 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png.meta
@@ -1,9 +1,9 @@
fileFormatVersion: 2
guid: 3e1767bf6c63e46b1a16404dc1afe508
TextureImporter:
- fileIDToRecycleName: {}
+ internalIDToNameTable: []
externalObjects: {}
- serializedVersion: 9
+ serializedVersion: 11
mipmaps:
mipMapMode: 0
enableMipMap: 1
@@ -57,8 +57,9 @@ TextureImporter:
maxTextureSizeSet: 0
compressionQualitySet: 0
textureFormatSet: 0
+ applyGammaDecoding: 0
platformSettings:
- - serializedVersion: 2
+ - serializedVersion: 3
buildTarget: DefaultTexturePlatform
maxTextureSize: 2048
resizeAlgorithm: 0
@@ -69,6 +70,7 @@ TextureImporter:
allowsAlphaSplitting: 0
overridden: 0
androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
spriteSheet:
serializedVersion: 2
sprites: []
@@ -76,10 +78,12 @@ TextureImporter:
physicsShape: []
bones: []
spriteID:
+ internalID: 0
vertices: []
indices:
edges: []
weights: []
+ secondaryTextures: []
spritePackingTag:
pSDRemoveMatte: 0
pSDShowRemoveMatteOption: 0
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png
index 3495ebb395..cb3869d6b8 100644
Binary files a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png and b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png differ
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta
index 92f0f9a1cc..302444ad28 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special0.png.meta
@@ -1,9 +1,9 @@
fileFormatVersion: 2
guid: 2e4ca31cf9cff4505acbefe44b621d6f
TextureImporter:
- fileIDToRecycleName: {}
+ internalIDToNameTable: []
externalObjects: {}
- serializedVersion: 9
+ serializedVersion: 11
mipmaps:
mipMapMode: 0
enableMipMap: 1
@@ -57,8 +57,9 @@ TextureImporter:
maxTextureSizeSet: 0
compressionQualitySet: 0
textureFormatSet: 0
+ applyGammaDecoding: 0
platformSettings:
- - serializedVersion: 2
+ - serializedVersion: 3
buildTarget: DefaultTexturePlatform
maxTextureSize: 2048
resizeAlgorithm: 0
@@ -69,6 +70,7 @@ TextureImporter:
allowsAlphaSplitting: 0
overridden: 0
androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
spriteSheet:
serializedVersion: 2
sprites: []
@@ -76,10 +78,12 @@ TextureImporter:
physicsShape: []
bones: []
spriteID:
+ internalID: 0
vertices: []
indices:
edges: []
weights: []
+ secondaryTextures: []
spritePackingTag:
pSDRemoveMatte: 0
pSDShowRemoveMatteOption: 0
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png
deleted file mode 100644
index cb3869d6b8..0000000000
Binary files a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png and /dev/null differ
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png.meta
deleted file mode 100644
index 4f4bc4de91..0000000000
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs_special1.png.meta
+++ /dev/null
@@ -1,88 +0,0 @@
-fileFormatVersion: 2
-guid: fceb584222ca149cd984dad123c8ae25
-TextureImporter:
- fileIDToRecycleName: {}
- externalObjects: {}
- serializedVersion: 9
- mipmaps:
- mipMapMode: 0
- enableMipMap: 1
- sRGBTexture: 1
- linearTexture: 0
- fadeOut: 0
- borderMipMap: 0
- mipMapsPreserveCoverage: 0
- alphaTestReferenceValue: 0.5
- mipMapFadeDistanceStart: 1
- mipMapFadeDistanceEnd: 3
- bumpmap:
- convertToNormalMap: 0
- externalNormalMap: 0
- heightScale: 0.25
- normalMapFilter: 0
- isReadable: 0
- streamingMipmaps: 0
- streamingMipmapsPriority: 0
- grayScaleToAlpha: 0
- generateCubemap: 6
- cubemapConvolution: 0
- seamlessCubemap: 0
- textureFormat: 1
- maxTextureSize: 2048
- textureSettings:
- serializedVersion: 2
- filterMode: -1
- aniso: -1
- mipBias: -100
- wrapU: -1
- wrapV: -1
- wrapW: -1
- nPOTScale: 1
- lightmap: 0
- compressionQuality: 50
- spriteMode: 0
- spriteExtrude: 1
- spriteMeshType: 1
- alignment: 0
- spritePivot: {x: 0.5, y: 0.5}
- spritePixelsToUnits: 100
- spriteBorder: {x: 0, y: 0, z: 0, w: 0}
- spriteGenerateFallbackPhysicsShape: 1
- alphaUsage: 1
- alphaIsTransparency: 0
- spriteTessellationDetail: -1
- textureType: 0
- textureShape: 1
- singleChannelComponent: 0
- maxTextureSizeSet: 0
- compressionQualitySet: 0
- textureFormatSet: 0
- platformSettings:
- - serializedVersion: 2
- buildTarget: DefaultTexturePlatform
- maxTextureSize: 2048
- resizeAlgorithm: 0
- textureFormat: -1
- textureCompression: 1
- compressionQuality: 50
- crunchedCompression: 0
- allowsAlphaSplitting: 0
- overridden: 0
- androidETC2FallbackOverride: 0
- spriteSheet:
- serializedVersion: 2
- sprites: []
- outline: []
- physicsShape: []
- bones: []
- spriteID:
- vertices: []
- indices:
- edges: []
- weights: []
- spritePackingTag:
- pSDRemoveMatte: 0
- pSDShowRemoveMatteOption: 0
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs
index c88c494b0a..03b429595b 100644
--- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs
@@ -15,7 +15,7 @@ public void TestNullRootBody()
var gameObj = new GameObject();
var sensorComponent = gameObj.AddComponent();
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
SensorTestHelper.CompareObservation(sensor, new float[0]);
}
@@ -33,7 +33,7 @@ public void TestSingleBody()
UseLocalSpaceRotations = true
};
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
sensor.Update();
var expected = new[]
{
@@ -42,7 +42,7 @@ public void TestSingleBody()
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
- Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
+ Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
}
[Test]
@@ -89,7 +89,7 @@ public void TestBodiesWithJoint()
#endif
};
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
sensor.Update();
var expected = new[]
{
@@ -110,7 +110,7 @@ public void TestBodiesWithJoint()
#endif
};
SensorTestHelper.CompareObservation(sensor, expected);
- Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
+ Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings
@@ -119,7 +119,7 @@ public void TestBodiesWithJoint()
UseJointPositionsAndAngles = true,
};
- sensor = sensorComponent.CreateSensor();
+ sensor = sensorComponent.CreateSensors()[0];
sensor.Update();
expected = new[]
@@ -133,7 +133,7 @@ public void TestBodiesWithJoint()
0f, // joint2.force
};
SensorTestHelper.CompareObservation(sensor, expected);
- Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
+ Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
}
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs
index 3bf956c210..68300be98f 100644
--- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs
@@ -32,7 +32,7 @@ public void TestNullRootBody()
var gameObj = new GameObject();
var sensorComponent = gameObj.AddComponent();
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
SensorTestHelper.CompareObservation(sensor, new float[0]);
}
@@ -50,13 +50,13 @@ public void TestSingleRigidbody()
UseLocalSpaceRotations = true
};
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
sensor.Update();
// The root body is ignored since it always generates identity values
// and there are no other bodies to generate observations.
var expected = new float[0];
- Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
+ Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
SensorTestHelper.CompareObservation(sensor, expected);
}
@@ -95,7 +95,7 @@ public void TestBodiesWithJoint()
};
sensorComponent.VirtualRoot = virtualRoot;
- var sensor = sensorComponent.CreateSensor();
+ var sensor = sensorComponent.CreateSensors()[0];
sensor.Update();
// Note that the VirtualRoot is ignored from the observations
@@ -115,7 +115,7 @@ public void TestBodiesWithJoint()
-1f, 1f, 0f, // Attached vel
0f, -1f, 1f // Leaf vel
};
- Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
+ Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
SensorTestHelper.CompareObservation(sensor, expected);
// Update the settings to only process joint observations
@@ -125,7 +125,7 @@ public void TestBodiesWithJoint()
UseJointForces = true,
};
- sensor = sensorComponent.CreateSensor();
+ sensor = sensorComponent.CreateSensors()[0];
sensor.Update();
expected = new[]
@@ -136,7 +136,7 @@ public void TestBodiesWithJoint()
0f, 0f, 0f, // joint2.torque
};
SensorTestHelper.CompareObservation(sensor, expected);
- Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
+ Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
}
}
diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 7a22e50c63..3a8cdaac99 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to
## [Unreleased]
### Major Changes
-#### com.unity.ml-agents (C#)
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- The minimum supported Unity version was updated to 2019.4. (#5166)
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
@@ -23,6 +23,10 @@ and `IDimensionPropertiesSensor` interfaces were removed. (#5127)
- `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor`
interface was removed. (#5164)
- The abstract method `SensorComponent.GetObservationShape()` was no longer being called, so it has been removed. (#5172)
+- `SensorComponent.CreateSensor()` was replaced with `SensorComponent.CreateSensor()`, which returns an `ISensor[]`. (#5181)
+- `Match3Sensor` was refactored to produce cell and special type observations separately, and `Match3SensorComponent` now
+produces two `Match3Sensor`s (unless there are no special types). Previously trained models will have different observation
+sizes and will need to be retrained. (#5181)
#### ml-agents / ml-agents-envs / gym-unity (Python)
@@ -39,7 +43,7 @@ depend on the previous behavior, you can explicitly set the Agent's `InferenceDe
#### ml-agents / ml-agents-envs / gym-unity (Python)
### Bug Fixes
-#### com.unity.ml-agents (C#)
+#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index 580e1571c3..985725d0d5 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -968,7 +968,7 @@ internal void InitializeSensors()
sensors.Capacity += attachedSensorComponents.Length;
foreach (var component in attachedSensorComponents)
{
- sensors.Add(component.CreateSensor());
+ sensors.AddRange(component.CreateSensors());
}
// Support legacy CollectObservations
diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
index 2bf357b47e..b19903a973 100644
--- a/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
@@ -49,10 +49,10 @@ public int MaxNumObservables
private BufferSensor m_Sensor;
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
m_Sensor = new BufferSensor(MaxNumObservables, ObservableSize, m_SensorName);
- return m_Sensor;
+ return new ISensor[] { m_Sensor };
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
index 0c677ee585..41582d35c6 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
@@ -106,15 +106,15 @@ public int ObservationStacks
/// Creates the
///
/// The created object for this component.
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
if (ObservationStacks != 1)
{
- return new StackingSensor(m_Sensor, ObservationStacks);
+ return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) };
}
- return m_Sensor;
+ return new ISensor[] { m_Sensor };
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
index 47cfb5cfb7..50b9fd61a2 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
@@ -181,7 +181,7 @@ public virtual float GetEndVerticalOffset()
/// Returns an initialized raycast sensor.
///
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
var rayPerceptionInput = GetRayPerceptionInput();
@@ -190,10 +190,10 @@ public override ISensor CreateSensor()
if (ObservationStacks != 1)
{
var stackingSensor = new StackingSensor(m_RaySensor, ObservationStacks);
- return stackingSensor;
+ return new ISensor[] { stackingSensor };
}
- return m_RaySensor;
+ return new ISensor[] { m_RaySensor };
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
index 32c6bec07d..edcacf4ee4 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
@@ -82,14 +82,14 @@ public int ObservationStacks
}
///
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression);
if (ObservationStacks != 1)
{
- return new StackingSensor(m_Sensor, ObservationStacks);
+ return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) };
}
- return m_Sensor;
+ return new ISensor[] { m_Sensor };
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
index b270cdd646..38781eec7b 100644
--- a/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
@@ -10,9 +10,9 @@ namespace Unity.MLAgents.Sensors
public abstract class SensorComponent : MonoBehaviour
{
///
- /// Create the ISensor. This is called by the Agent when it is initialized.
+ /// Create the ISensors. This is called by the Agent when it is initialized.
///
- /// Created ISensor object.
- public abstract ISensor CreateSensor();
+ /// Created ISensor objects.
+ public abstract ISensor[] CreateSensors();
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
index b3f4801c56..1c65793ecd 100644
--- a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
+++ b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
@@ -100,10 +100,10 @@ public void TestRunModel()
var modelRunner = new ModelRunner(discreteONNXModel, actionSpec, InferenceDevice.Burst);
var info1 = new AgentInfo();
info1.episodeId = 1;
- modelRunner.PutObservations(info1, new[] { sensor_21_20_3.CreateSensor() }.ToList());
+ modelRunner.PutObservations(info1, new[] { sensor_21_20_3.CreateSensors()[0] }.ToList());
var info2 = new AgentInfo();
info2.episodeId = 2;
- modelRunner.PutObservations(info2, new[] { sensor_21_20_3.CreateSensor() }.ToList());
+ modelRunner.PutObservations(info2, new[] { sensor_21_20_3.CreateSensors()[0] }.ToList());
modelRunner.DecideBatch();
diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
index 059f1a0a13..f6c8f3124f 100644
--- a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
+++ b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
@@ -14,9 +14,9 @@ public class Test3DSensorComponent : SensorComponent
{
public ISensor Sensor;
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
- return Sensor;
+ return new ISensor[] { Sensor };
}
}
@@ -286,7 +286,13 @@ public void TestCheckModelValidContinuous(bool useDeprecatedNNModel)
var errors = BarracudaModelParamLoader.CheckModel(
model, validBrainParameters,
- new ISensor[] { new VectorSensor(8), sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]
+ new ISensor[]
+ {
+ new VectorSensor(8),
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
);
Assert.AreEqual(0, errors.Count()); // There should not be any errors
}
@@ -300,7 +306,7 @@ public void TestCheckModelValidDiscrete(bool useDeprecatedNNModel)
var errors = BarracudaModelParamLoader.CheckModel(
model, validBrainParameters,
- new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]
+ new ISensor[] { sensor_21_20_3.CreateSensors()[0] }, new ActuatorComponent[0]
);
Assert.AreEqual(0, errors.Count()); // There should not be any errors
}
@@ -313,7 +319,10 @@ public void TestCheckModelValidHybrid()
var errors = BarracudaModelParamLoader.CheckModel(
model, validBrainParameters,
- new ISensor[] { new VectorSensor(validBrainParameters.VectorObservationSize) }, new ActuatorComponent[0]
+ new ISensor[]
+ {
+ new VectorSensor(validBrainParameters.VectorObservationSize)
+ }, new ActuatorComponent[0]
);
Assert.AreEqual(0, errors.Count()); // There should not be any errors
}
@@ -328,7 +337,12 @@ public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNN
brainParameters.VectorObservationSize = 9; // Invalid observation
var errors = BarracudaModelParamLoader.CheckModel(
model, brainParameters,
- new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]
+ new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
);
Assert.Greater(errors.Count(), 0);
@@ -336,7 +350,12 @@ public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNN
brainParameters.NumStackedVectorObservations = 2;// Invalid stacking
errors = BarracudaModelParamLoader.CheckModel(
model, brainParameters,
- new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]
+ new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
);
Assert.Greater(errors.Count(), 0);
}
@@ -349,7 +368,13 @@ public void TestCheckModelThrowsVectorObservationDiscrete(bool useDeprecatedNNMo
var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
brainParameters.VectorObservationSize = 1; // Invalid observation
- var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]);
+ var errors = BarracudaModelParamLoader.CheckModel(
+ model, brainParameters, new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
}
@@ -383,12 +408,26 @@ public void TestCheckModelThrowsActionContinuous(bool useDeprecatedNNModel)
var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
brainParameters.ActionSpec = ActionSpec.MakeContinuous(3); // Invalid action
- var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
+ var errors = BarracudaModelParamLoader.CheckModel(
+ model, brainParameters, new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
brainParameters = GetContinuous2vis8vec2actionBrainParameters();
brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3); // Invalid SpaceType
- errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
+ errors = BarracudaModelParamLoader.CheckModel(
+ model, brainParameters, new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
}
@@ -400,12 +439,21 @@ public void TestCheckModelThrowsActionDiscrete(bool useDeprecatedNNModel)
var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3, 3); // Invalid action
- var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]);
+ var errors = BarracudaModelParamLoader.CheckModel(
+ model, brainParameters,
+ new ISensor[] { sensor_21_20_3.CreateSensors()[0] },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
brainParameters = GetContinuous2vis8vec2actionBrainParameters();
brainParameters.ActionSpec = ActionSpec.MakeContinuous(2); // Invalid SpaceType
- errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]);
+ errors = BarracudaModelParamLoader.CheckModel(
+ model,
+ brainParameters,
+ new ISensor[] { sensor_21_20_3.CreateSensors()[0] },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
}
@@ -416,12 +464,30 @@ public void TestCheckModelThrowsActionHybrid()
var brainParameters = GetHybridBrainParameters();
brainParameters.ActionSpec = new ActionSpec(3, new[] { 3 }); // Invalid discrete action size
- var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
+ var errors = BarracudaModelParamLoader.CheckModel(
+ model,
+ brainParameters,
+ new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
brainParameters = GetContinuous2vis8vec2actionBrainParameters();
brainParameters.ActionSpec = ActionSpec.MakeDiscrete(2); // Missing continuous action
- errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
+ errors = BarracudaModelParamLoader.CheckModel(
+ model,
+ brainParameters,
+ new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
}
@@ -429,7 +495,16 @@ public void TestCheckModelThrowsActionHybrid()
public void TestCheckModelThrowsNoModel()
{
var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
- var errors = BarracudaModelParamLoader.CheckModel(null, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
+ var errors = BarracudaModelParamLoader.CheckModel(
+ null,
+ brainParameters,
+ new ISensor[]
+ {
+ sensor_21_20_3.CreateSensors()[0],
+ sensor_20_22_3.CreateSensors()[0]
+ },
+ new ActuatorComponent[0]
+ );
Assert.Greater(errors.Count(), 0);
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
index df4a3cc961..563cac341d 100644
--- a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
+++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
@@ -68,7 +68,7 @@ public void CheckSetupRayPerceptionSensorComponent()
sensorComponent.RayLayerMask = 0;
sensorComponent.ObservationStacks = 2;
- sensorComponent.CreateSensor();
+ sensorComponent.CreateSensors();
}
#endif
}
diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
index edd16ee04a..89c790c126 100644
--- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
+++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
@@ -31,10 +31,16 @@ public class StackingComponent : SensorComponent
public SensorComponent wrappedComponent;
public int numStacks;
- public override ISensor CreateSensor()
+ public override ISensor[] CreateSensors()
{
- var wrappedSensor = wrappedComponent.CreateSensor();
- return new StackingSensor(wrappedSensor, numStacks);
+ var wrappedSensors = wrappedComponent.CreateSensors();
+ var sensorsOut = new ISensor[wrappedSensors.Length];
+ for (var i = 0; i < wrappedSensors.Length; i++)
+ {
+ sensorsOut[i] = new StackingSensor(wrappedSensors[i], numStacks);
+ }
+
+ return sensorsOut;
}
}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs
index c17f6cecb5..622c238695 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs
@@ -55,7 +55,7 @@ public void TestBufferSensorComponent()
bufferComponent.ObservableSize = 4;
bufferComponent.SensorName = "TestName";
- var sensor = bufferComponent.CreateSensor();
+ var sensor = bufferComponent.CreateSensors()[0];
var shape = sensor.GetObservationSpec().Shape;
Assert.AreEqual(shape[0], 20);
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
index 5bb2c74fe1..1ed056fd21 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
@@ -29,7 +29,7 @@ public void TestCameraSensorComponent()
cameraComponent.Grayscale = grayscale;
cameraComponent.CompressionType = compression;
- var sensor = cameraComponent.CreateSensor();
+ var sensor = cameraComponent.CreateSensors()[0];
var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs
index 751ed433ac..4cfa65a1e4 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs
@@ -106,7 +106,7 @@ public void TestRaycasts()
foreach (var castRadius in radii)
{
perception.SphereCastRadius = castRadius;
- var sensor = perception.CreateSensor();
+ var sensor = perception.CreateSensors()[0];
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
@@ -165,7 +165,7 @@ public void TestRaycastMiss()
perception.DetectableTags.Add(k_CubeTag);
perception.DetectableTags.Add(k_SphereTag);
- var sensor = perception.CreateSensor();
+ var sensor = perception.CreateSensors()[0];
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];
@@ -213,7 +213,7 @@ public void TestRayFilter()
}
perception.RayLayerMask = layerMask;
- var sensor = perception.CreateSensor();
+ var sensor = perception.CreateSensors()[0];
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];
@@ -259,7 +259,7 @@ public void TestRaycastsScaled()
foreach (var castRadius in radii)
{
perception.SphereCastRadius = castRadius;
- var sensor = perception.CreateSensor();
+ var sensor = perception.CreateSensors()[0];
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
@@ -308,7 +308,7 @@ public void TestRayZeroLength()
{
// Set the layer mask to either the default, or one that ignores the close cube's layer
- var sensor = perception.CreateSensor();
+ var sensor = perception.CreateSensors()[0];
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs
index d21e544dd7..c4dcc93ef5 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs
@@ -28,7 +28,7 @@ public void TestRenderTextureSensorComponent()
var expectedShape = new InplaceArray(height, width, grayscale ? 1 : 3);
- var sensor = renderTexComponent.CreateSensor();
+ var sensor = renderTexComponent.CreateSensors()[0];
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
}
diff --git a/docs/Migrating.md b/docs/Migrating.md
index 9971e96d6b..c386dbc723 100644
--- a/docs/Migrating.md
+++ b/docs/Migrating.md
@@ -91,6 +91,7 @@ public CompressionSpec GetCompressionSpec()
```
- The abstract method `SensorComponent.GetObservationShape()` was removed.
+- The abstract method `SensorComponent.CreateSensor()` was replaced with `CreateSensors()`, which returns an `ISensor[]`.
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations