Skip to content

[MLA-1824] make SensorComponent return ISensor[] #5181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ public class BasicSensorComponent : SensorComponent
/// Creates a BasicSensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
public override ISensor[] CreateSensors()
{
return new BasicSensor(basicController);
return new ISensor[] { new BasicSensor(basicController) };
}
}

Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ public string SensorName


/// <inheritdoc/>
public override ISensor CreateSensor()
public override ISensor[] CreateSensors()
{
m_Sensor = new TestTextureSensor(TestTexture, SensorName, CompressionType);
if (ObservationStacks != 1)
{
return new StackingSensor(m_Sensor, ObservationStacks);
return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) };
}
return m_Sensor;
return new ISensor[] { m_Sensor };
}
}

149 changes: 65 additions & 84 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@

namespace Unity.MLAgents.Extensions.Match3
{

/// <summary>
/// Delegate that provides integer values at a given (x,y) coordinate.
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
public delegate int GridValueProvider(int x, int y);

/// <summary>
/// Type of observations to generate.
///
Expand Down Expand Up @@ -32,66 +40,68 @@ public enum Match3ObservationType

/// <summary>
/// Sensor for Match3 games. Can generate either vector, compressed visual,
/// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
/// and AbstractBoard.GetSpecialType() to determine the observation values.
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
/// </summary>
public class Match3Sensor : ISensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

private int m_Rows;
private int m_Columns;
private int m_NumCellTypes;
private int m_NumSpecialTypes;

private int SpecialTypeSize
{
get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; }
}
private GridValueProvider m_GridValues;
private int m_OneHotSize;

/// <summary>
/// Create a sensor for the board with the specified observation type.
/// Create a sensor for the GridValueProvider with the specified observation type.
/// </summary>
/// <param name="board"></param>
/// <param name="obsType"></param>
/// <param name="name"></param>
public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string name)
/// <remarks>
/// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
/// the constructor directly.
/// </remarks>
/// <param name="board">The abstract board. This is only used to get the size.</param>
/// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param>
/// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param>
/// <param name="obsType">Whether to produce vector or visual observations</param>
/// <param name="name">Name of the sensor.</param>
public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
{
m_Board = board;
m_Name = name;
m_Rows = board.Rows;
m_Columns = board.Columns;
m_NumCellTypes = board.NumCellTypes;
m_NumSpecialTypes = board.NumSpecialTypes;
m_GridValues = gvp;
m_OneHotSize = oneHotSize;

m_ObservationType = obsType;
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);

// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize];
// If we have 4 cell types and 2 special types (3 special size), we'd have
// [0, 1, 2, 3, -1, -1, 4, 5, 6]
for (var i = 0; i < m_NumCellTypes; i++)
{
m_SparseChannelMapping[i] = i;
}
? ObservationSpec.Vector(m_Rows * m_Columns * oneHotSize)
: ObservationSpec.Visual(m_Rows, m_Columns, oneHotSize);
}

for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++)
{
m_SparseChannelMapping[i] = -1;
}
/// <summary>
/// Create a sensor that encodes the board cells as observations.
/// </summary>
/// <param name="board">The abstract board.</param>
/// <param name="obsType">Whether to produce vector or visual observations</param>
/// <param name="name">Name of the sensor.</param>
/// <returns></returns>
public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
{
return new Match3Sensor(board, board.GetCellType, board.NumCellTypes, obsType, name);
}

for (var i = 0; i < SpecialTypeSize; i++)
{
m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes;
}
/// <summary>
/// Create a sensor that encodes the cell special types as observations.
/// </summary>
/// <param name="board">The abstract board.</param>
/// <param name="obsType">Whether to produce vector or visual observations</param>
/// <param name="name">Name of the sensor.</param>
/// <returns></returns>
public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
{
var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
}

/// <inheritdoc/>
Expand All @@ -103,14 +113,14 @@ public ObservationSpec GetObservationSpec()
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
{
Debug.LogWarning(
$"Board shape changes since sensor initialization. This may cause unexpected results. " +
$"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
$"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
);
}
// if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disabled for now, might return later (depending on how I handle multiple board sizes)

// {
// Debug.LogWarning(
// $"Board shape changes since sensor initialization. This may cause unexpected results. " +
// $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
// $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
// );
// }

if (m_ObservationType == Match3ObservationType.Vector)
{
Expand All @@ -119,22 +129,13 @@ public int Write(ObservationWriter writer)
{
for (var c = 0; c < m_Columns; c++)
{
var val = m_Board.GetCellType(r, c);
for (var i = 0; i < m_NumCellTypes; i++)
var val = m_GridValues(r, c);

for (var i = 0; i < m_OneHotSize; i++)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this loop exist to 0 out the rest of one-hots? Could it be simplified to something like:

writer[offset + val] = 1.0f;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, as long as we know the memory is initialized to zero, we could do that (plus an extra check that val < m_OneHotSize)

{
writer[offset] = (i == val) ? 1.0f : 0.0f;
offset++;
}

if (m_NumSpecialTypes > 0)
{
var special = m_Board.GetSpecialType(r, c);
for (var i = 0; i < SpecialTypeSize; i++)
{
writer[offset] = (i == special) ? 1.0f : 0.0f;
offset++;
}
}
}
}

Expand All @@ -148,22 +149,12 @@ public int Write(ObservationWriter writer)
{
for (var c = 0; c < m_Columns; c++)
{
var val = m_Board.GetCellType(r, c);
for (var i = 0; i < m_NumCellTypes; i++)
var val = m_GridValues(r, c);
for (var i = 0; i < m_OneHotSize; i++)
{
writer[r, c, i] = (i == val) ? 1.0f : 0.0f;
offset++;
}

if (m_NumSpecialTypes > 0)
{
var special = m_Board.GetSpecialType(r, c);
for (var i = 0; i < SpecialTypeSize; i++)
{
writer[offset] = (i == special) ? 1.0f : 0.0f;
offset++;
}
}
}
}

Expand All @@ -185,17 +176,10 @@ public byte[] GetCompressedObservation()
// fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for
// the special types). Note that we have to also implement the sparse channel mapping.
// Optimize this it later.
var numCellImages = (m_NumCellTypes + 2) / 3;
var numCellImages = (m_OneHotSize + 2) / 3;
for (var i = 0; i < numCellImages; i++)
{
converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i);
bytesOut.AddRange(tempTexture.EncodeToPNG());
}

var numSpecialImages = (SpecialTypeSize + 2) / 3;
for (var i = 0; i < numSpecialImages; i++)
{
converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i);
converter.EncodeToTexture(m_GridValues, tempTexture, 3 * i);
bytesOut.AddRange(tempTexture.EncodeToPNG());
}

Expand Down Expand Up @@ -223,7 +207,7 @@ internal SensorCompressionType GetCompressionType()
/// <inheritdoc/>
public CompressionSpec GetCompressionSpec()
{
return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping);
return new CompressionSpec(GetCompressionType());
}

/// <inheritdoc/>
Expand Down Expand Up @@ -265,9 +249,6 @@ internal class OneHotToTextureUtil
int m_Width;
private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };

public delegate int GridValueProvider(int x, int y);


public OneHotToTextureUtil(int height, int width)
{
m_Colors = new Color[height * width];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,20 @@ public class Match3SensorComponent : SensorComponent
public Match3ObservationType ObservationType = Match3ObservationType.Vector;

/// <inheritdoc/>
public override ISensor CreateSensor()
public override ISensor[] CreateSensors()
{
var board = GetComponent<AbstractBoard>();
return new Match3Sensor(board, ObservationType, SensorName);
var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
if (board.NumSpecialTypes > 0)
{
var specialSensor =
Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
return new ISensor[] { cellSensor, specialSensor };
}
else
{
return new ISensor[] { cellSensor };
}
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ public class ArticulationBodySensorComponent : SensorComponent
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
public override ISensor[] CreateSensors()
{
return new PhysicsBodySensor(RootBody, Settings, sensorName);
return new ISensor[] {new PhysicsBodySensor(RootBody, Settings, sensorName)};
}

}
Expand Down
4 changes: 2 additions & 2 deletions com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,9 @@ public enum GridDepthType { Channel, ChannelHot };
private Color DebugDefaultColor = new Color(1f, 1f, 1f, 0.25f);

/// <inheritdoc/>
public override ISensor CreateSensor()
public override ISensor[] CreateSensors()
{
return this;
return new ISensor[] { this };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ public class RigidBodySensorComponent : SensorComponent
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
public override ISensor[] CreateSensors()
{
var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName) };
}

/// <summary>
Expand Down
Loading