Skip to content

Commit ea64e2b

Browse files
author
Chris Elion
authored
[MLA-1909] Match3 and Camera/RenderTexture sensor GC improvements (#5233)
1 parent 145793d commit ea64e2b

File tree

13 files changed

+223
-86
lines changed

13 files changed

+223
-86
lines changed

DevProject/Packages/packages-lock.json

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,16 @@
5757
"dependencies": {
5858
"com.unity.barracuda": "1.3.2-preview",
5959
"com.unity.modules.imageconversion": "1.0.0",
60-
"com.unity.modules.jsonserialize": "1.0.0",
61-
"com.unity.modules.physics": "1.0.0",
62-
"com.unity.modules.physics2d": "1.0.0"
60+
"com.unity.modules.jsonserialize": "1.0.0"
6361
}
6462
},
6563
"com.unity.ml-agents.extensions": {
6664
"version": "file:../../com.unity.ml-agents.extensions",
6765
"depth": 0,
6866
"source": "local",
6967
"dependencies": {
70-
"com.unity.ml-agents": "2.0.0-exp.1"
68+
"com.unity.ml-agents": "2.0.0-exp.1",
69+
"com.unity.modules.physics": "1.0.0"
7170
}
7271
},
7372
"com.unity.nuget.mono-cecil": {

com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
using System;
12
using System.Collections.Generic;
23
using Unity.MLAgents.Sensors;
34
using UnityEngine;
4-
using Debug = UnityEngine.Debug;
55

66
namespace Unity.MLAgents.Extensions.Match3
77
{
8-
98
/// <summary>
109
/// Delegate that provides integer values at a given (x,y) coordinate.
1110
/// </summary>
@@ -43,7 +42,7 @@ public enum Match3ObservationType
4342
/// Sensor for Match3 games. Can generate either vector, compressed visual,
4443
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
4544
/// </summary>
46-
public class Match3Sensor : ISensor, IBuiltInSensor
45+
public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
4746
{
4847
Match3ObservationType m_ObservationType;
4948
ObservationSpec m_ObservationSpec;
@@ -54,6 +53,9 @@ public class Match3Sensor : ISensor, IBuiltInSensor
5453
GridValueProvider m_GridValues;
5554
int m_OneHotSize;
5655

56+
Texture2D m_ObservationTexture;
57+
OneHotToTextureUtil m_TextureUtil;
58+
5759
/// <summary>
5860
/// Create a sensor for the GridValueProvider with the specified observation type.
5961
/// </summary>
@@ -164,7 +166,6 @@ public int Write(ObservationWriter writer)
164166

165167

166168
return offset;
167-
168169
}
169170

170171
/// <inheritdoc/>
@@ -173,8 +174,15 @@ public byte[] GetCompressedObservation()
173174
m_Board.CheckBoardSizes(m_MaxBoardSize);
174175
var height = m_MaxBoardSize.Rows;
175176
var width = m_MaxBoardSize.Columns;
176-
var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
177-
var converter = new OneHotToTextureUtil(height, width);
177+
if (ReferenceEquals(null, m_ObservationTexture))
178+
{
179+
m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
180+
}
181+
182+
if (ReferenceEquals(null, m_TextureUtil))
183+
{
184+
m_TextureUtil = new OneHotToTextureUtil(height, width);
185+
}
178186
var bytesOut = new List<byte>();
179187
var currentBoardSize = m_Board.GetCurrentBoardSize();
180188

@@ -185,17 +193,16 @@ public byte[] GetCompressedObservation()
185193
var numCellImages = (m_OneHotSize + 2) / 3;
186194
for (var i = 0; i < numCellImages; i++)
187195
{
188-
converter.EncodeToTexture(
196+
m_TextureUtil.EncodeToTexture(
189197
m_GridValues,
190-
tempTexture,
198+
m_ObservationTexture,
191199
3 * i,
192200
currentBoardSize.Rows,
193201
currentBoardSize.Columns
194202
);
195-
bytesOut.AddRange(tempTexture.EncodeToPNG());
203+
bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
196204
}
197205

198-
DestroyTexture(tempTexture);
199206
return bytesOut.ToArray();
200207
}
201208

@@ -234,16 +241,15 @@ public BuiltInSensorType GetBuiltInSensorType()
234241
return BuiltInSensorType.Match3Sensor;
235242
}
236243

237-
static void DestroyTexture(Texture2D texture)
244+
/// <summary>
245+
/// Clean up the owned Texture2D.
246+
/// </summary>
247+
public void Dispose()
238248
{
239-
if (Application.isEditor)
240-
{
241-
// Edit Mode tests complain if we use Destroy()
242-
Object.DestroyImmediate(texture);
243-
}
244-
else
249+
if (!ReferenceEquals(null, m_ObservationTexture))
245250
{
246-
Object.Destroy(texture);
251+
Utilities.DestroyTexture(m_ObservationTexture);
252+
m_ObservationTexture = null;
247253
}
248254
}
249255
}
@@ -274,7 +280,7 @@ public void EncodeToTexture(
274280
int channelOffset,
275281
int currentHeight,
276282
int currentWidth
277-
)
283+
)
278284
{
279285
var i = 0;
280286
// There's an implicit flip converting to PNG from texture, so make sure we

com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using Unity.MLAgents.Sensors;
23
using UnityEngine;
34

@@ -7,7 +8,7 @@ namespace Unity.MLAgents.Extensions.Match3
78
/// Sensor component for a Match3 game.
89
/// </summary>
910
[AddComponentMenu("ML Agents/Match 3 Sensor", (int)MenuGroup.Sensors)]
10-
public class Match3SensorComponent : SensorComponent
11+
public class Match3SensorComponent : SensorComponent, IDisposable
1112
{
1213
/// <summary>
1314
/// Name of the generated Match3Sensor object.
@@ -20,15 +21,38 @@ public class Match3SensorComponent : SensorComponent
2021
/// </summary>
2122
public Match3ObservationType ObservationType = Match3ObservationType.Vector;
2223

24+
private ISensor[] m_Sensors;
25+
2326
/// <inheritdoc/>
2427
public override ISensor[] CreateSensors()
2528
{
29+
// Clean up any existing sensors
30+
Dispose();
31+
2632
var board = GetComponent<AbstractBoard>();
2733
var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
2834
// This can be null if numSpecialTypes is 0
2935
var specialSensor = Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
30-
return specialSensor != null ? new ISensor[] { cellSensor, specialSensor } : new ISensor[] { cellSensor };
36+
m_Sensors = specialSensor != null
37+
? new ISensor[] { cellSensor, specialSensor }
38+
: new ISensor[] { cellSensor };
39+
return m_Sensors;
3140
}
3241

42+
/// <summary>
43+
/// Clean up the sensors created by CreateSensors().
44+
/// </summary>
45+
public void Dispose()
46+
{
47+
if (m_Sensors != null)
48+
{
49+
for (var i = 0; i < m_Sensors.Length; i++)
50+
{
51+
((Match3Sensor)m_Sensors[i]).Dispose();
52+
}
53+
54+
m_Sensors = null;
55+
}
56+
}
3357
}
3458
}

com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Collections.Generic;
22
using System.IO;
3+
using System.Reflection;
34
using NUnit.Framework;
45
using Unity.MLAgents.Extensions.Match3;
56
using UnityEngine;
@@ -244,6 +245,17 @@ public void TestVisualObservationsSpecial()
244245
};
245246
SensorTestHelper.CompareObservation(specialSensor, expectedObs3D);
246247
}
248+
249+
// Test that Dispose() cleans up the component and its sensors
250+
sensorComponent.Dispose();
251+
252+
var flags = BindingFlags.Instance | BindingFlags.NonPublic;
253+
var componentSensors = (ISensor[])typeof(Match3SensorComponent).GetField("m_Sensors", flags).GetValue(sensorComponent);
254+
Assert.IsNull(componentSensors);
255+
var cellTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
256+
Assert.IsNull(cellTexture);
257+
var specialTexture = (Texture2D)typeof(Match3Sensor).GetField("m_ObservationTexture", flags).GetValue(cellSensor);
258+
Assert.IsNull(specialTexture);
247259
}
248260

249261

com.unity.ml-agents/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are call
5252
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
5353
- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
5454
`IHeuristicProvider.Heuristic()`. (#5227)
55+
- `Agent` will now call `IDisposable.Dispose()` on all `ISensor`s that implement the `IDisposable` interface. (#5233)
56+
- `CameraSensor`, `RenderTextureSensor`, and `Match3Sensor` will now reuse their `Texture2D`s, reducing the
57+
amount of memory that needs to be allocated during runtime. (#5233)
58+
- Optimzed `ObservationWriter.WriteTexture()` so that it doesn't call `Texture2D.GetPixels32()` for `RGB24` textures.
59+
This results in much less memory being allocated during inference with `CameraSensor` and `RenderTextureSensor`. (#5233)
5560

5661
#### ml-agents / ml-agents-envs / gym-unity (Python)
5762
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ protected virtual void OnDisable()
542542
Academy.Instance.AgentForceReset -= _AgentReset;
543543
NotifyAgentDone(DoneReason.Disabled);
544544
}
545+
546+
CleanupSensors();
545547
m_Brain?.Dispose();
546548
OnAgentDisabled?.Invoke(this);
547549
m_Initialized = false;
@@ -1004,6 +1006,19 @@ internal void InitializeSensors()
10041006
#endif
10051007
}
10061008

1009+
void CleanupSensors()
1010+
{
1011+
// Dispose all attached sensor
1012+
for (var i = 0; i < sensors.Count; i++)
1013+
{
1014+
var sensor = sensors[i];
1015+
if (sensor is IDisposable disposableSensor)
1016+
{
1017+
disposableSensor.Dispose();
1018+
}
1019+
}
1020+
}
1021+
10071022
void InitializeActuators()
10081023
{
10091024
ActuatorComponent[] attachedActuators;

com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using UnityEngine;
23
using UnityEngine.Rendering;
34

@@ -6,7 +7,7 @@ namespace Unity.MLAgents.Sensors
67
/// <summary>
78
/// A sensor that wraps a Camera object to generate visual observations for an agent.
89
/// </summary>
9-
public class CameraSensor : ISensor, IBuiltInSensor
10+
public class CameraSensor : ISensor, IBuiltInSensor, IDisposable
1011
{
1112
Camera m_Camera;
1213
int m_Width;
@@ -15,6 +16,7 @@ public class CameraSensor : ISensor, IBuiltInSensor
1516
string m_Name;
1617
private ObservationSpec m_ObservationSpec;
1718
SensorCompressionType m_CompressionType;
19+
Texture2D m_Texture;
1820

1921
/// <summary>
2022
/// The Camera used for rendering the sensor observations.
@@ -34,7 +36,6 @@ public SensorCompressionType CompressionType
3436
set { m_CompressionType = value; }
3537
}
3638

37-
3839
/// <summary>
3940
/// Creates and returns the camera sensor.
4041
/// </summary>
@@ -56,6 +57,7 @@ public CameraSensor(
5657
var channels = grayscale ? 1 : 3;
5758
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
5859
m_CompressionType = compression;
60+
m_Texture = new Texture2D(width, height, TextureFormat.RGB24, false);
5961
}
6062

6163
/// <summary>
@@ -87,10 +89,9 @@ public byte[] GetCompressedObservation()
8789
{
8890
using (TimerStack.Instance.Scoped("CameraSensor.GetCompressedObservation"))
8991
{
90-
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
92+
ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
9193
// TODO support more types here, e.g. JPG
92-
var compressed = texture.EncodeToPNG();
93-
DestroyTexture(texture);
94+
var compressed = m_Texture.EncodeToPNG();
9495
return compressed;
9596
}
9697
}
@@ -104,9 +105,8 @@ public int Write(ObservationWriter writer)
104105
{
105106
using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor"))
106107
{
107-
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
108-
var numWritten = writer.WriteTexture(texture, m_Grayscale);
109-
DestroyTexture(texture);
108+
ObservationToTexture(m_Camera, m_Texture, m_Width, m_Height);
109+
var numWritten = writer.WriteTexture(m_Texture, m_Grayscale);
110110
return numWritten;
111111
}
112112
}
@@ -126,19 +126,17 @@ public CompressionSpec GetCompressionSpec()
126126
/// <summary>
127127
/// Renders a Camera instance to a 2D texture at the corresponding resolution.
128128
/// </summary>
129-
/// <returns>The 2D texture.</returns>
130129
/// <param name="obsCamera">Camera.</param>
130+
/// <param name="texture2D">Texture2D to render to.</param>
131131
/// <param name="width">Width of resulting 2D texture.</param>
132132
/// <param name="height">Height of resulting 2D texture.</param>
133-
/// <returns name="texture2D">Texture2D to render to.</returns>
134-
public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
133+
public static void ObservationToTexture(Camera obsCamera, Texture2D texture2D, int width, int height)
135134
{
136135
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
137136
{
138137
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
139138
}
140139

141-
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
142140
var oldRec = obsCamera.rect;
143141
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);
144142
var depth = 24;
@@ -163,40 +161,24 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he
163161
obsCamera.rect = oldRec;
164162
RenderTexture.active = prevActiveRt;
165163
RenderTexture.ReleaseTemporary(tempRt);
166-
return texture2D;
167164
}
168165

169-
/// <summary>
170-
/// Computes the observation shape for a camera sensor based on the height, width
171-
/// and grayscale flag.
172-
/// </summary>
173-
/// <param name="width">Width of the image captures from the camera.</param>
174-
/// <param name="height">Height of the image captures from the camera.</param>
175-
/// <param name="grayscale">Whether or not to convert the image to grayscale.</param>
176-
/// <returns>The observation shape.</returns>
177-
internal static int[] GenerateShape(int width, int height, bool grayscale)
166+
/// <inheritdoc/>
167+
public BuiltInSensorType GetBuiltInSensorType()
178168
{
179-
return new[] { height, width, grayscale ? 1 : 3 };
169+
return BuiltInSensorType.CameraSensor;
180170
}
181171

182-
static void DestroyTexture(Texture2D texture)
172+
/// <summary>
173+
/// Clean up the owned Texture2D.
174+
/// </summary>
175+
public void Dispose()
183176
{
184-
if (Application.isEditor)
177+
if (!ReferenceEquals(null, m_Texture))
185178
{
186-
// Edit Mode tests complain if we use Destroy()
187-
// TODO move to extension methods for UnityEngine.Object?
188-
Object.DestroyImmediate(texture);
179+
Utilities.DestroyTexture(m_Texture);
180+
m_Texture = null;
189181
}
190-
else
191-
{
192-
Object.Destroy(texture);
193-
}
194-
}
195-
196-
/// <inheritdoc/>
197-
public BuiltInSensorType GetBuiltInSensorType()
198-
{
199-
return BuiltInSensorType.CameraSensor;
200182
}
201183
}
202184
}

0 commit comments

Comments
 (0)