Skip to content

Commit a7a22f4

Browse files
author
Ruo-Ping Dong
authored
Fix stacked grid sensor (#5335)
1 parent 3a75c3c commit a7a22f4

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Collections.Generic;
2+
using System.Linq;
23
using UnityEngine;
34

45
namespace Unity.MLAgents.Sensors
@@ -11,7 +12,7 @@ public class GridSensorComponent : SensorComponent
1112
{
1213
// dummy sensor only used for debug gizmo
1314
GridSensorBase m_DebugSensor;
14-
List<ISensor> m_Sensors;
15+
List<GridSensorBase> m_Sensors;
1516
internal BoxOverlapChecker m_BoxOverlapChecker;
1617

1718
[HideInInspector, SerializeField]
@@ -196,7 +197,6 @@ public int ObservationStacks
196197
/// <inheritdoc/>
197198
public override ISensor[] CreateSensors()
198199
{
199-
m_Sensors = new List<ISensor>();
200200
m_BoxOverlapChecker = new BoxOverlapChecker(
201201
m_CellScale,
202202
m_GridSize,
@@ -213,29 +213,33 @@ public override ISensor[] CreateSensors()
213213
m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
214214
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);
215215

216-
var gridSensors = GetGridSensors();
217-
if (gridSensors == null || gridSensors.Length < 1)
216+
m_Sensors = GetGridSensors().ToList();
217+
if (m_Sensors == null || m_Sensors.Count < 1)
218218
{
219219
throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors." +
220220
"If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
221221
}
222222

223-
foreach (var sensor in gridSensors)
223+
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
224+
m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;
225+
foreach (var sensor in m_Sensors)
224226
{
225-
if (ObservationStacks != 1)
226-
{
227-
m_Sensors.Add(new StackingSensor(sensor, ObservationStacks));
228-
}
229-
else
230-
{
231-
m_Sensors.Add(sensor);
232-
}
233227
m_BoxOverlapChecker.RegisterSensor(sensor);
234228
}
235229

236-
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
237-
((GridSensorBase)m_Sensors[0]).m_BoxOverlapChecker = m_BoxOverlapChecker;
238-
return m_Sensors.ToArray();
230+
if (ObservationStacks != 1)
231+
{
232+
var sensors = new ISensor[m_Sensors.Count];
233+
for (var i = 0; i < m_Sensors.Count; i++)
234+
{
235+
sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
236+
}
237+
return sensors;
238+
}
239+
else
240+
{
241+
return m_Sensors.ToArray();
242+
}
239243
}
240244

241245
/// <summary>
@@ -262,7 +266,7 @@ internal void UpdateSensor()
262266
m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
263267
foreach (var sensor in m_Sensors)
264268
{
265-
((GridSensorBase)sensor).CompressionType = m_CompressionType;
269+
sensor.CompressionType = m_CompressionType;
266270
}
267271
}
268272
}

com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public void TestCreateSensor()
8989
gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
9090

9191
gridSensorComponent.CreateSensors();
92-
var componentSensor = (List<ISensor>)typeof(GridSensorComponent).GetField("m_Sensors",
92+
var componentSensor = (List<GridSensorBase>)typeof(GridSensorComponent).GetField("m_Sensors",
9393
BindingFlags.Instance | BindingFlags.NonPublic).GetValue(gridSensorComponent);
9494
Assert.AreEqual(componentSensor.Count, 1);
9595
}
@@ -191,6 +191,17 @@ public void TestNoSensors()
191191
gridSensorComponent.CreateSensors();
192192
});
193193
}
194+
195+
[Test]
196+
public void TestStackedSensors()
197+
{
198+
testGo.tag = k_Tag2;
199+
string[] tags = { k_Tag1, k_Tag2 };
200+
gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
201+
gridSensorComponent.ObservationStacks = 3;
202+
var sensors = gridSensorComponent.CreateSensors();
203+
Assert.IsInstanceOf(typeof(StackingSensor), sensors[0]);
204+
}
194205
}
195206
}
196207
#endif

0 commit comments

Comments
 (0)