Skip to content

Commit c75eacb

Browse files
author
Chris Elion
authored
[MLA-1141] Rigidbody and ArticulationBody sensors (#4192)
1 parent 5be56c4 commit c75eacb

23 files changed

+799
-48
lines changed

com.unity.ml-agents.extensions/LICENSE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
com.unity.ml-agents.extensions copyright © 2020 Unity Technologies
1+
com.unity.ml-agents.extensions copyright © 2020 Unity Technologies ApS
22

33
Licensed under the Unity Companion License for Unity-dependent projects -- see
44
[Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).

com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,20 @@
55

66
namespace Unity.MLAgents.Extensions.Sensors
77
{
8-
8+
/// <summary>
9+
/// Utility class to track a hierarchy of ArticulationBodies.
10+
/// </summary>
911
public class ArticulationBodyPoseExtractor : PoseExtractor
1012
{
1113
ArticulationBody[] m_Bodies;
1214

1315
public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
1416
{
17+
if (rootBody == null)
18+
{
19+
return;
20+
}
21+
1522
if (!rootBody.isRoot)
1623
{
1724
Debug.Log("Must pass ArticulationBody.isRoot");
@@ -38,23 +45,32 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
3845

3946
for (var i = 1; i < numBodies; i++)
4047
{
41-
var body = m_Bodies[i];
42-
var parent = body.GetComponentInParent<ArticulationBody>();
43-
parentIndices[i] = bodyToIndex[parent];
48+
var currentArticBody = m_Bodies[i];
49+
// Component.GetComponentInParent will consider the provided object as well.
50+
// So start looking from the parent.
51+
var currentGameObject = currentArticBody.gameObject;
52+
var parentGameObject = currentGameObject.transform.parent;
53+
var parentArticBody = parentGameObject.GetComponentInParent<ArticulationBody>();
54+
parentIndices[i] = bodyToIndex[parentArticBody];
4455
}
4556

4657
SetParentIndices(parentIndices);
4758
}
4859

60+
/// <inheritdoc/>
61+
protected override Vector3 GetLinearVelocityAt(int index)
62+
{
63+
return m_Bodies[index].velocity;
64+
}
65+
66+
/// <inheritdoc/>
4967
protected override Pose GetPoseAt(int index)
5068
{
5169
var body = m_Bodies[index];
5270
var go = body.gameObject;
5371
var t = go.transform;
5472
return new Pose { rotation = t.rotation, position = t.position };
5573
}
56-
57-
5874
}
5975
}
6076
#endif // UNITY_2020_1_OR_NEWER
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#if UNITY_2020_1_OR_NEWER
2+
using UnityEngine;
3+
using Unity.MLAgents.Sensors;
4+
5+
namespace Unity.MLAgents.Extensions.Sensors
6+
{
7+
public class ArticulationBodySensorComponent : SensorComponent
8+
{
9+
public ArticulationBody RootBody;
10+
11+
[SerializeField]
12+
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
13+
public string sensorName;
14+
15+
/// <summary>
16+
/// Creates a PhysicsBodySensor.
17+
/// </summary>
18+
/// <returns></returns>
19+
public override ISensor CreateSensor()
20+
{
21+
return new PhysicsBodySensor(RootBody, Settings, sensorName);
22+
}
23+
24+
/// <inheritdoc/>
25+
public override int[] GetObservationShape()
26+
{
27+
if (RootBody == null)
28+
{
29+
return new[] { 0 };
30+
}
31+
32+
// TODO static method in PhysicsBodySensor?
33+
// TODO only update PoseExtractor when body changes?
34+
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
35+
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
36+
return new[] { numTransformObservations };
37+
}
38+
}
39+
40+
}
41+
#endif // UNITY_2020_1_OR_NEWER

com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using UnityEngine;
2+
using Unity.MLAgents.Sensors;
3+
4+
namespace Unity.MLAgents.Extensions.Sensors
5+
{
6+
/// <summary>
7+
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
8+
/// </summary>
9+
public class PhysicsBodySensor : ISensor
10+
{
11+
int[] m_Shape;
12+
string m_SensorName;
13+
14+
PoseExtractor m_PoseExtractor;
15+
PhysicsSensorSettings m_Settings;
16+
17+
/// <summary>
18+
/// Construct a new PhysicsBodySensor
19+
/// </summary>
20+
/// <param name="rootBody"></param>
21+
/// <param name="settings"></param>
22+
/// <param name="sensorName"></param>
23+
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
24+
{
25+
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
26+
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
27+
m_Settings = settings;
28+
29+
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
30+
m_Shape = new[] { numTransformObservations };
31+
}
32+
33+
#if UNITY_2020_1_OR_NEWER
34+
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
35+
{
36+
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
37+
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
38+
m_Settings = settings;
39+
40+
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
41+
m_Shape = new[] { numTransformObservations };
42+
}
43+
#endif
44+
45+
/// <inheritdoc/>
46+
public int[] GetObservationShape()
47+
{
48+
return m_Shape;
49+
}
50+
51+
/// <inheritdoc/>
52+
public int Write(ObservationWriter writer)
53+
{
54+
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
55+
return numWritten;
56+
}
57+
58+
/// <inheritdoc/>
59+
public byte[] GetCompressedObservation()
60+
{
61+
return null;
62+
}
63+
64+
/// <inheritdoc/>
65+
public void Update()
66+
{
67+
if (m_Settings.UseModelSpace)
68+
{
69+
m_PoseExtractor.UpdateModelSpacePoses();
70+
}
71+
72+
if (m_Settings.UseLocalSpace)
73+
{
74+
m_PoseExtractor.UpdateLocalSpacePoses();
75+
}
76+
}
77+
78+
/// <inheritdoc/>
79+
public void Reset() {}
80+
81+
/// <inheritdoc/>
82+
public SensorCompressionType GetCompressionType()
83+
{
84+
return SensorCompressionType.None;
85+
}
86+
87+
/// <inheritdoc/>
88+
public string GetName()
89+
{
90+
return m_SensorName;
91+
}
92+
}
93+
}

com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
namespace Unity.MLAgents.Extensions.Sensors
66
{
7+
/// <summary>
8+
/// Settings that define the observations generated for physics-based sensors.
9+
/// </summary>
710
[Serializable]
811
public struct PhysicsSensorSettings
912
{
@@ -13,7 +16,7 @@ public struct PhysicsSensorSettings
1316
public bool UseModelSpaceTranslations;
1417

1518
/// <summary>
16-
/// Whether to use model space (relative to the root body) rotatoins as observations.
19+
/// Whether to use model space (relative to the root body) rotations as observations.
1720
/// </summary>
1821
public bool UseModelSpaceRotations;
1922

@@ -27,6 +30,16 @@ public struct PhysicsSensorSettings
2730
/// </summary>
2831
public bool UseLocalSpaceRotations;
2932

33+
/// <summary>
34+
/// Whether to use model space (relative to the root body) linear velocities as observations.
35+
/// </summary>
36+
public bool UseModelSpaceLinearVelocity;
37+
38+
/// <summary>
39+
/// Whether to use local space (relative to the parent body) linear velocities as observations.
40+
/// </summary>
41+
public bool UseLocalSpaceLinearVelocity;
42+
3043
/// <summary>
3144
/// Creates a PhysicsSensorSettings with reasonable default values.
3245
/// </summary>
@@ -45,15 +58,15 @@ public static PhysicsSensorSettings Default()
4558
/// </summary>
4659
public bool UseModelSpace
4760
{
48-
get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
61+
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
4962
}
5063

5164
/// <summary>
5265
/// Whether any local space observations are being used.
5366
/// </summary>
5467
public bool UseLocalSpace
5568
{
56-
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
69+
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
5770
}
5871

5972

@@ -70,6 +83,9 @@ public int TransformSize(int numTransforms)
7083
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
7184
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
7285

86+
obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
87+
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
88+
7389
return numTransforms * obsPerTransform;
7490
}
7591
}
@@ -89,8 +105,12 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
89105
var offset = baseOffset;
90106
if (settings.UseModelSpace)
91107
{
92-
foreach (var pose in poseExtractor.ModelSpacePoses)
108+
var poses = poseExtractor.ModelSpacePoses;
109+
var vels = poseExtractor.ModelSpaceVelocities;
110+
111+
for(var i=0; i<poseExtractor.NumPoses; i++)
93112
{
113+
var pose = poses[i];
94114
if(settings.UseModelSpaceTranslations)
95115
{
96116
writer.Add(pose.position, offset);
@@ -101,13 +121,22 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
101121
writer.Add(pose.rotation, offset);
102122
offset += 4;
103123
}
124+
if (settings.UseModelSpaceLinearVelocity)
125+
{
126+
writer.Add(vels[i], offset);
127+
offset += 3;
128+
}
104129
}
105130
}
106131

107132
if (settings.UseLocalSpace)
108133
{
109-
foreach (var pose in poseExtractor.LocalSpacePoses)
134+
var poses = poseExtractor.LocalSpacePoses;
135+
var vels = poseExtractor.LocalSpaceVelocities;
136+
137+
for(var i=0; i<poseExtractor.NumPoses; i++)
110138
{
139+
var pose = poses[i];
111140
if(settings.UseLocalSpaceTranslations)
112141
{
113142
writer.Add(pose.position, offset);
@@ -118,6 +147,11 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
118147
writer.Add(pose.rotation, offset);
119148
offset += 4;
120149
}
150+
if (settings.UseLocalSpaceLinearVelocity)
151+
{
152+
writer.Add(vels[i], offset);
153+
offset += 3;
154+
}
121155
}
122156
}
123157

0 commit comments

Comments
 (0)