1
1
using System . Collections . Generic ;
2
+ using System . Linq ;
2
3
using UnityEngine ;
3
4
4
5
namespace Unity . MLAgents . Sensors
@@ -11,7 +12,7 @@ public class GridSensorComponent : SensorComponent
11
12
{
12
13
// dummy sensor only used for debug gizmo
13
14
GridSensorBase m_DebugSensor ;
14
- List < ISensor > m_Sensors ;
15
+ List < GridSensorBase > m_Sensors ;
15
16
internal BoxOverlapChecker m_BoxOverlapChecker ;
16
17
17
18
[ HideInInspector , SerializeField ]
@@ -196,7 +197,6 @@ public int ObservationStacks
196
197
/// <inheritdoc/>
197
198
public override ISensor [ ] CreateSensors ( )
198
199
{
199
- m_Sensors = new List < ISensor > ( ) ;
200
200
m_BoxOverlapChecker = new BoxOverlapChecker (
201
201
m_CellScale ,
202
202
m_GridSize ,
@@ -213,29 +213,33 @@ public override ISensor[] CreateSensors()
213
213
m_DebugSensor = new GridSensorBase ( "DebugGridSensor" , m_CellScale , m_GridSize , m_DetectableTags , SensorCompressionType . None ) ;
214
214
m_BoxOverlapChecker . RegisterDebugSensor ( m_DebugSensor ) ;
215
215
216
- var gridSensors = GetGridSensors ( ) ;
217
- if ( gridSensors == null || gridSensors . Length < 1 )
216
+ m_Sensors = GetGridSensors ( ) . ToList ( ) ;
217
+ if ( m_Sensors == null || m_Sensors . Count < 1 )
218
218
{
219
219
throw new UnityAgentsException ( "GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors." +
220
220
"If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor." ) ;
221
221
}
222
222
223
- foreach ( var sensor in gridSensors )
223
+ // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
224
+ m_Sensors [ 0 ] . m_BoxOverlapChecker = m_BoxOverlapChecker ;
225
+ foreach ( var sensor in m_Sensors )
224
226
{
225
- if ( ObservationStacks != 1 )
226
- {
227
- m_Sensors . Add ( new StackingSensor ( sensor , ObservationStacks ) ) ;
228
- }
229
- else
230
- {
231
- m_Sensors . Add ( sensor ) ;
232
- }
233
227
m_BoxOverlapChecker . RegisterSensor ( sensor ) ;
234
228
}
235
229
236
- // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
237
- ( ( GridSensorBase ) m_Sensors [ 0 ] ) . m_BoxOverlapChecker = m_BoxOverlapChecker ;
238
- return m_Sensors . ToArray ( ) ;
230
+ if ( ObservationStacks != 1 )
231
+ {
232
+ var sensors = new ISensor [ m_Sensors . Count ] ;
233
+ for ( var i = 0 ; i < m_Sensors . Count ; i ++ )
234
+ {
235
+ sensors [ i ] = new StackingSensor ( m_Sensors [ i ] , ObservationStacks ) ;
236
+ }
237
+ return sensors ;
238
+ }
239
+ else
240
+ {
241
+ return m_Sensors . ToArray ( ) ;
242
+ }
239
243
}
240
244
241
245
/// <summary>
@@ -262,7 +266,7 @@ internal void UpdateSensor()
262
266
m_BoxOverlapChecker . ColliderMask = m_ColliderMask ;
263
267
foreach ( var sensor in m_Sensors )
264
268
{
265
- ( ( GridSensorBase ) sensor ) . CompressionType = m_CompressionType ;
269
+ sensor . CompressionType = m_CompressionType ;
266
270
}
267
271
}
268
272
}
0 commit comments