-
Notifications
You must be signed in to change notification settings - Fork 4.3k
[MLA-1824] make SensorComponent return ISensor[] #5181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,14 @@ | |
|
||
namespace Unity.MLAgents.Extensions.Match3 | ||
{ | ||
|
||
/// <summary> | ||
/// Delegate that provides integer values at a given (x,y) coordinate. | ||
/// </summary> | ||
/// <param name="x"></param> | ||
/// <param name="y"></param> | ||
public delegate int GridValueProvider(int x, int y); | ||
|
||
/// <summary> | ||
/// Type of observations to generate. | ||
/// | ||
|
@@ -32,66 +40,68 @@ public enum Match3ObservationType | |
|
||
/// <summary> | ||
/// Sensor for Match3 games. Can generate either vector, compressed visual, | ||
/// or uncompressed visual observations. Uses AbstractBoard.GetCellType() | ||
/// and AbstractBoard.GetSpecialType() to determine the observation values. | ||
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values. | ||
/// </summary> | ||
public class Match3Sensor : ISensor, IBuiltInSensor | ||
{ | ||
private Match3ObservationType m_ObservationType; | ||
private AbstractBoard m_Board; | ||
private ObservationSpec m_ObservationSpec; | ||
private int[] m_SparseChannelMapping; | ||
private string m_Name; | ||
|
||
private int m_Rows; | ||
private int m_Columns; | ||
private int m_NumCellTypes; | ||
private int m_NumSpecialTypes; | ||
|
||
private int SpecialTypeSize | ||
{ | ||
get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; } | ||
} | ||
private GridValueProvider m_GridValues; | ||
private int m_OneHotSize; | ||
|
||
/// <summary> | ||
/// Create a sensor for the board with the specified observation type. | ||
/// Create a sensor for the GridValueProvider with the specified observation type. | ||
/// </summary> | ||
/// <param name="board"></param> | ||
/// <param name="obsType"></param> | ||
/// <param name="name"></param> | ||
public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string name) | ||
/// <remarks> | ||
/// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling | ||
/// the constructor directly. | ||
/// </remarks> | ||
/// <param name="board">The abstract board. This is only used to get the size.</param> | ||
/// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param> | ||
/// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param> | ||
/// <param name="obsType">Whether to produce vector or visual observations</param> | ||
/// <param name="name">Name of the sensor.</param> | ||
public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name) | ||
{ | ||
m_Board = board; | ||
m_Name = name; | ||
m_Rows = board.Rows; | ||
m_Columns = board.Columns; | ||
m_NumCellTypes = board.NumCellTypes; | ||
m_NumSpecialTypes = board.NumSpecialTypes; | ||
m_GridValues = gvp; | ||
m_OneHotSize = oneHotSize; | ||
|
||
m_ObservationType = obsType; | ||
m_ObservationSpec = obsType == Match3ObservationType.Vector | ||
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize)) | ||
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize); | ||
|
||
// See comment in GetCompressedObservation() | ||
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3); | ||
m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize]; | ||
// If we have 4 cell types and 2 special types (3 special size), we'd have | ||
// [0, 1, 2, 3, -1, -1, 4, 5, 6] | ||
for (var i = 0; i < m_NumCellTypes; i++) | ||
{ | ||
m_SparseChannelMapping[i] = i; | ||
} | ||
? ObservationSpec.Vector(m_Rows * m_Columns * oneHotSize) | ||
: ObservationSpec.Visual(m_Rows, m_Columns, oneHotSize); | ||
} | ||
|
||
for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++) | ||
{ | ||
m_SparseChannelMapping[i] = -1; | ||
} | ||
/// <summary> | ||
/// Create a sensor that encodes the board cells as observations. | ||
/// </summary> | ||
/// <param name="board">The abstract board.</param> | ||
/// <param name="obsType">Whether to produce vector or visual observations</param> | ||
/// <param name="name">Name of the sensor.</param> | ||
/// <returns></returns> | ||
public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name) | ||
{ | ||
return new Match3Sensor(board, board.GetCellType, board.NumCellTypes, obsType, name); | ||
} | ||
|
||
for (var i = 0; i < SpecialTypeSize; i++) | ||
{ | ||
m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes; | ||
} | ||
/// <summary> | ||
/// Create a sensor that encodes the cell special types as observations. | ||
/// </summary> | ||
/// <param name="board">The abstract board.</param> | ||
/// <param name="obsType">Whether to produce vector or visual observations</param> | ||
/// <param name="name">Name of the sensor.</param> | ||
/// <returns></returns> | ||
public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name) | ||
{ | ||
var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1; | ||
return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name); | ||
} | ||
|
||
/// <inheritdoc/> | ||
|
@@ -103,14 +113,14 @@ public ObservationSpec GetObservationSpec() | |
/// <inheritdoc/> | ||
public int Write(ObservationWriter writer) | ||
{ | ||
if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes) | ||
{ | ||
Debug.LogWarning( | ||
$"Board shape changes since sensor initialization. This may cause unexpected results. " + | ||
$"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " + | ||
$"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}" | ||
); | ||
} | ||
// if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes) | ||
// { | ||
// Debug.LogWarning( | ||
// $"Board shape changes since sensor initialization. This may cause unexpected results. " + | ||
// $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " + | ||
// $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}" | ||
// ); | ||
// } | ||
|
||
if (m_ObservationType == Match3ObservationType.Vector) | ||
{ | ||
|
@@ -119,22 +129,13 @@ public int Write(ObservationWriter writer) | |
{ | ||
for (var c = 0; c < m_Columns; c++) | ||
{ | ||
var val = m_Board.GetCellType(r, c); | ||
for (var i = 0; i < m_NumCellTypes; i++) | ||
var val = m_GridValues(r, c); | ||
|
||
for (var i = 0; i < m_OneHotSize; i++) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this loop exist to 0 out the rest of one-hots? Could it be simplified to something like:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, as long as we know the memory is initialized to zero, we could do that (plus an extra check that |
||
{ | ||
writer[offset] = (i == val) ? 1.0f : 0.0f; | ||
offset++; | ||
} | ||
|
||
if (m_NumSpecialTypes > 0) | ||
{ | ||
var special = m_Board.GetSpecialType(r, c); | ||
for (var i = 0; i < SpecialTypeSize; i++) | ||
{ | ||
writer[offset] = (i == special) ? 1.0f : 0.0f; | ||
offset++; | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
@@ -148,22 +149,12 @@ public int Write(ObservationWriter writer) | |
{ | ||
for (var c = 0; c < m_Columns; c++) | ||
{ | ||
var val = m_Board.GetCellType(r, c); | ||
for (var i = 0; i < m_NumCellTypes; i++) | ||
var val = m_GridValues(r, c); | ||
for (var i = 0; i < m_OneHotSize; i++) | ||
{ | ||
writer[r, c, i] = (i == val) ? 1.0f : 0.0f; | ||
offset++; | ||
} | ||
|
||
if (m_NumSpecialTypes > 0) | ||
{ | ||
var special = m_Board.GetSpecialType(r, c); | ||
for (var i = 0; i < SpecialTypeSize; i++) | ||
{ | ||
writer[offset] = (i == special) ? 1.0f : 0.0f; | ||
offset++; | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
@@ -185,17 +176,10 @@ public byte[] GetCompressedObservation() | |
// fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for | ||
// the special types). Note that we have to also implement the sparse channel mapping. | ||
// Optimize this it later. | ||
var numCellImages = (m_NumCellTypes + 2) / 3; | ||
var numCellImages = (m_OneHotSize + 2) / 3; | ||
for (var i = 0; i < numCellImages; i++) | ||
{ | ||
converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i); | ||
bytesOut.AddRange(tempTexture.EncodeToPNG()); | ||
} | ||
|
||
var numSpecialImages = (SpecialTypeSize + 2) / 3; | ||
for (var i = 0; i < numSpecialImages; i++) | ||
{ | ||
converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i); | ||
converter.EncodeToTexture(m_GridValues, tempTexture, 3 * i); | ||
bytesOut.AddRange(tempTexture.EncodeToPNG()); | ||
} | ||
|
||
|
@@ -223,7 +207,7 @@ internal SensorCompressionType GetCompressionType() | |
/// <inheritdoc/> | ||
public CompressionSpec GetCompressionSpec() | ||
{ | ||
return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping); | ||
return new CompressionSpec(GetCompressionType()); | ||
} | ||
|
||
/// <inheritdoc/> | ||
|
@@ -265,9 +249,6 @@ internal class OneHotToTextureUtil | |
int m_Width; | ||
private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue }; | ||
|
||
public delegate int GridValueProvider(int x, int y); | ||
|
||
|
||
public OneHotToTextureUtil(int height, int width) | ||
{ | ||
m_Colors = new Color[height * width]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Disabled for now, might return later (depending on how I handle multiple board sizes)