Skip to content

Commit c3e74ee

Browse files
author
Ruo-Ping Dong
authored
Make OverlapChecker an interface (#5324)
1 parent 894bffa commit c3e74ee

File tree

7 files changed

+107
-41
lines changed

7 files changed

+107
-41
lines changed

com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
namespace Unity.MLAgents.Sensors
55
{
6-
internal class BoxOverlapChecker
6+
/// <summary>
7+
/// The grid perception strategy that uses box overlap to detect objects.
8+
/// </summary>
9+
internal class BoxOverlapChecker : IGridPerception
710
{
811
Vector3 m_CellScale;
912
Vector3Int m_GridSize;
@@ -84,17 +87,14 @@ void InitCellLocalPositions()
8487
}
8588
}
8689

87-
/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
88-
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
89-
/// <param name="cellIndex">The index of the cell</param>
90-
Vector3 GetCellLocalPosition(int cellIndex)
90+
public Vector3 GetCellLocalPosition(int cellIndex)
9191
{
9292
float x = (cellIndex / m_GridSize.z - m_CellCenterOffset.x) * m_CellScale.x;
9393
float z = (cellIndex % m_GridSize.z - m_CellCenterOffset.z) * m_CellScale.z;
9494
return new Vector3(x, 0, z);
9595
}
9696

97-
internal Vector3 GetCellGlobalPosition(int cellIndex)
97+
public Vector3 GetCellGlobalPosition(int cellIndex)
9898
{
9999
if (m_RotateWithAgent)
100100
{
@@ -106,16 +106,12 @@ internal Vector3 GetCellGlobalPosition(int cellIndex)
106106
}
107107
}
108108

109-
internal Quaternion GetGridRotation()
109+
public Quaternion GetGridRotation()
110110
{
111111
return m_RotateWithAgent ? m_CenterObject.transform.rotation : Quaternion.identity;
112112
}
113113

114-
/// <summary>
115-
/// Perceive the latest grid status. Call OverlapBoxNonAlloc once to detect colliders.
116-
/// Then parse the collider arrays according to all available gridSensor delegates.
117-
/// </summary>
118-
internal void Update()
114+
public void Perceive()
119115
{
120116
#if MLA_UNITY_PHYSICS_MODULE
121117
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
@@ -135,10 +131,7 @@ internal void Update()
135131
#endif
136132
}
137133

138-
/// <summary>
139-
/// Same as Update(), but only load data for debug gizmo.
140-
/// </summary>
141-
internal void UpdateGizmo()
134+
public void UpdateGizmo()
142135
{
143136
#if MLA_UNITY_PHYSICS_MODULE
144137
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
@@ -246,7 +239,7 @@ void ParseCollidersAll(Collider[] foundColliders, int numFound, int cellIndex, V
246239
}
247240
#endif
248241

249-
internal void RegisterSensor(GridSensorBase sensor)
242+
public void RegisterSensor(GridSensorBase sensor)
250243
{
251244
#if MLA_UNITY_PHYSICS_MODULE
252245
if (sensor.GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
@@ -260,7 +253,7 @@ internal void RegisterSensor(GridSensorBase sensor)
260253
#endif
261254
}
262255

263-
internal void RegisterDebugSensor(GridSensorBase debugSensor)
256+
public void RegisterDebugSensor(GridSensorBase debugSensor)
264257
{
265258
#if MLA_UNITY_PHYSICS_MODULE
266259
GridOverlapDetectedDebug += debugSensor.ProcessDetectedObject;

com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable
3232
string[] m_DetectableTags;
3333
SensorCompressionType m_CompressionType;
3434
ObservationSpec m_ObservationSpec;
35-
internal BoxOverlapChecker m_BoxOverlapChecker;
35+
internal IGridPerception m_GridPerception;
3636

3737
// Buffers
3838
float[] m_PerceptionBuffer;
@@ -299,9 +299,9 @@ public void Update()
299299
ResetPerceptionBuffer();
300300
using (TimerStack.Instance.Scoped("GridSensor.Update"))
301301
{
302-
if (m_BoxOverlapChecker != null)
302+
if (m_GridPerception != null)
303303
{
304-
m_BoxOverlapChecker.Update();
304+
m_GridPerception.Perceive();
305305
}
306306
}
307307
}

com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class GridSensorComponent : SensorComponent
1313
// dummy sensor only used for debug gizmo
1414
GridSensorBase m_DebugSensor;
1515
List<GridSensorBase> m_Sensors;
16-
internal BoxOverlapChecker m_BoxOverlapChecker;
16+
internal IGridPerception m_GridPerception;
1717

1818
[HideInInspector, SerializeField]
1919
protected internal string m_SensorName = "GridSensor";
@@ -197,7 +197,7 @@ public int ObservationStacks
197197
/// <inheritdoc/>
198198
public override ISensor[] CreateSensors()
199199
{
200-
m_BoxOverlapChecker = new BoxOverlapChecker(
200+
m_GridPerception = new BoxOverlapChecker(
201201
m_CellScale,
202202
m_GridSize,
203203
m_RotateWithAgent,
@@ -211,7 +211,7 @@ public override ISensor[] CreateSensors()
211211

212212
// debug data is positive int value and will trigger data validation exception if SensorCompressionType is not None.
213213
m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
214-
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);
214+
m_GridPerception.RegisterDebugSensor(m_DebugSensor);
215215

216216
m_Sensors = GetGridSensors().ToList();
217217
if (m_Sensors == null || m_Sensors.Count < 1)
@@ -221,10 +221,10 @@ public override ISensor[] CreateSensors()
221221
}
222222

223223
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
224-
m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;
224+
m_Sensors[0].m_GridPerception = m_GridPerception;
225225
foreach (var sensor in m_Sensors)
226226
{
227-
m_BoxOverlapChecker.RegisterSensor(sensor);
227+
m_GridPerception.RegisterSensor(sensor);
228228
}
229229

230230
if (ObservationStacks != 1)
@@ -262,8 +262,8 @@ internal void UpdateSensor()
262262
{
263263
if (m_Sensors != null)
264264
{
265-
m_BoxOverlapChecker.RotateWithAgent = m_RotateWithAgent;
266-
m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
265+
m_GridPerception.RotateWithAgent = m_RotateWithAgent;
266+
m_GridPerception.ColliderMask = m_ColliderMask;
267267
foreach (var sensor in m_Sensors)
268268
{
269269
sensor.CompressionType = m_CompressionType;
@@ -275,22 +275,22 @@ void OnDrawGizmos()
275275
{
276276
if (m_ShowGizmos)
277277
{
278-
if (m_BoxOverlapChecker == null || m_DebugSensor == null)
278+
if (m_GridPerception == null || m_DebugSensor == null)
279279
{
280280
return;
281281
}
282282

283283
m_DebugSensor.ResetPerceptionBuffer();
284-
m_BoxOverlapChecker.UpdateGizmo();
284+
m_GridPerception.UpdateGizmo();
285285
var cellColors = m_DebugSensor.PerceptionBuffer;
286-
var rotation = m_BoxOverlapChecker.GetGridRotation();
286+
var rotation = m_GridPerception.GetGridRotation();
287287

288288
var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z);
289289
var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
290290
var oldGizmoMatrix = Gizmos.matrix;
291291
for (var i = 0; i < m_DebugSensor.PerceptionBuffer.Length; i++)
292292
{
293-
var cellPosition = m_BoxOverlapChecker.GetCellGlobalPosition(i);
293+
var cellPosition = m_GridPerception.GetCellGlobalPosition(i);
294294
var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale);
295295
Gizmos.matrix = oldGizmoMatrix * cubeTransform;
296296
var colorIndex = cellColors[i] - 1;
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using System;
2+
using UnityEngine;
3+
4+
namespace Unity.MLAgents.Sensors
5+
{
6+
/// <summary>
7+
/// An interface for GridSensor perception that defines the grid cells and collider detecting strategies.
8+
/// </summary>
9+
internal interface IGridPerception
10+
{
11+
bool RotateWithAgent
12+
{
13+
get;
14+
set;
15+
}
16+
17+
LayerMask ColliderMask
18+
{
19+
get;
20+
set;
21+
}
22+
23+
/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
24+
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
25+
/// <param name="cellIndex">The index of the cell</param>
26+
Vector3 GetCellLocalPosition(int cellIndex);
27+
28+
/// <summary>
29+
/// Converts the index of the cell to the 3D point (y is zero) in world space
30+
/// based on the result from GetCellLocalPosition()
31+
/// </summary>
32+
/// <returns>Vector3 of the position of the center of the cell in world space</returns>
33+
/// <param name="cellIndex">The index of the cell</param>
34+
Vector3 GetCellGlobalPosition(int cellIndex);
35+
36+
Quaternion GetGridRotation();
37+
38+
/// <summary>
39+
/// Perceive the latest grid status. Detect colliders for each cell, parse the collider arrays,
40+
/// then trigger registered sensors to encode and update with the new grid status.
41+
/// </summary>
42+
void Perceive();
43+
44+
/// <summary>
45+
/// Same as Perceive(), but only load data for debug gizmo.
46+
/// </summary>
47+
void UpdateGizmo();
48+
49+
/// <summary>
50+
/// Register a sensor to this GridPerception to receive the grid perception results.
51+
/// When the GridPerception perceive a new observation, registered sensors will be triggered
52+
/// to encode the new observation and update its data.
53+
/// </summary>
54+
void RegisterSensor(GridSensorBase sensor);
55+
56+
/// <summary>
57+
/// Register an internal debug sensor.
58+
/// Debug sensors will only be triggered when drawing debug gizmos.
59+
/// </summary>
60+
void RegisterDebugSensor(GridSensorBase debugSensor);
61+
}
62+
}

com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public void TestBufferResize()
155155
testGo.transform.position = Vector3.zero;
156156
testObjects.Add(testGo);
157157
var boxOverlap = TestBoxOverlapChecker.CreateChecker(agentGameObject: testGo, centerObject: testGo, initialColliderBufferSize: 2, maxColliderBufferSize: 5);
158-
boxOverlap.Update();
158+
boxOverlap.Perceive();
159159
Assert.AreEqual(2, boxOverlap.ColliderBuffer.Length);
160160

161161
for (var i = 0; i < 3; i++)
@@ -165,7 +165,7 @@ public void TestBufferResize()
165165
boxGo.AddComponent<BoxCollider>();
166166
testObjects.Add(boxGo);
167167
}
168-
boxOverlap.Update();
168+
boxOverlap.Perceive();
169169
Assert.AreEqual(4, boxOverlap.ColliderBuffer.Length);
170170

171171
for (var i = 0; i < 2; i++)
@@ -175,7 +175,7 @@ public void TestBufferResize()
175175
boxGo.AddComponent<BoxCollider>();
176176
testObjects.Add(boxGo);
177177
}
178-
boxOverlap.Update();
178+
boxOverlap.Perceive();
179179
Assert.AreEqual(5, boxOverlap.ColliderBuffer.Length);
180180

181181
Object.DestroyImmediate(testGo);
@@ -212,7 +212,7 @@ public void TestParseCollidersClosest()
212212
testObjects.Add(boxGo);
213213
}
214214

215-
boxOverlap.Update();
215+
boxOverlap.Perceive();
216216
helper.Verify(1, new List<GameObject> { testObjects[0] });
217217

218218
Object.DestroyImmediate(testGo);
@@ -249,7 +249,7 @@ public void TestParseCollidersAll()
249249
testObjects.Add(boxGo);
250250
}
251251

252-
boxOverlap.Update();
252+
boxOverlap.Perceive();
253253
helper.Verify(3, testObjects);
254254

255255
Object.DestroyImmediate(testGo);
@@ -293,7 +293,7 @@ public void TestOnlyOneChecker()
293293
foreach (var sensor in sensors)
294294
{
295295
var gridsensor = (GridSensorBase)sensor;
296-
if (gridsensor.m_BoxOverlapChecker != null)
296+
if (gridsensor.m_GridPerception != null)
297297
{
298298
numChecker += 1;
299299
}

com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,9 @@ public void TestMultipleSensors()
175175
string[] tags = { k_Tag1, k_Tag2 };
176176
gridSensorComponent.SetComponentParameters(tags, useOneHotTag: true, useGridSensorBase: true, useTestingGridSensor: true);
177177
var gridSensors = gridSensorComponent.CreateSensors();
178-
Assert.IsNotNull(((GridSensorBase)gridSensors[0]).m_BoxOverlapChecker);
179-
Assert.IsNull(((GridSensorBase)gridSensors[1]).m_BoxOverlapChecker);
180-
Assert.IsNull(((GridSensorBase)gridSensors[2]).m_BoxOverlapChecker);
178+
Assert.IsNotNull(((GridSensorBase)gridSensors[0]).m_GridPerception);
179+
Assert.IsNull(((GridSensorBase)gridSensors[1]).m_GridPerception);
180+
Assert.IsNull(((GridSensorBase)gridSensors[2]).m_GridPerception);
181181
}
182182

183183
[Test]

0 commit comments

Comments
 (0)