Skip to content

Commit cb8e220

Browse files
authored
Merge pull request #4825 from Unity-Technologies/sensor-types
[WIP] Observation Types
2 parents f8bc88d + eb0e76c commit cb8e220

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+484
-218
lines changed

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
411411
}
412412
}
413413
observationProto.Shape.AddRange(shape);
414+
415+
// Add the observation type, if any, to the observationProto
416+
var typeSensor = sensor as ITypedSensor;
417+
if (typeSensor != null)
418+
{
419+
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
420+
}
421+
else
422+
{
423+
observationProto.ObservationType = ObservationTypeProto.Default;
424+
}
414425
return observationProto;
415426
}
416427

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,23 @@ static ObservationReflection() {
2525
byte[] descriptorData = global::System.Convert.FromBase64String(
2626
string.Concat(
2727
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
28-
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyK7AgoQT2JzZXJ2YXRp",
28+
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKBAwoQT2JzZXJ2YXRp",
2929
"b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg",
3030
"ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv",
3131
"dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE",
3232
"IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u",
3333
"RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD",
34-
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUaGQoJRmxvYXREYXRh",
35-
"EgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVz",
36-
"c2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1M",
37-
"QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
34+
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUSRAoQb2JzZXJ2YXRp",
35+
"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
36+
"aW9uVHlwZVByb3RvGhkKCUZsb2F0RGF0YRIMCgRkYXRhGAEgAygCQhIKEG9i",
37+
"c2VydmF0aW9uX2RhdGEqKQoUQ29tcHJlc3Npb25UeXBlUHJvdG8SCAoETk9O",
38+
"RRAAEgcKA1BORxABKkYKFE9ic2VydmF0aW9uVHlwZVByb3RvEgsKB0RFRkFV",
39+
"TFQQABIICgRHT0FMEAESCgoGUkVXQVJEEAISCwoHTUVTU0FHRRADQiWqAiJV",
40+
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
3841
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
3942
new pbr::FileDescriptor[] { },
40-
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
41-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
43+
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
44+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "ObservationType" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
4245
}));
4346
}
4447
#endregion
@@ -50,6 +53,13 @@ internal enum CompressionTypeProto {
5053
[pbr::OriginalName("PNG")] Png = 1,
5154
}
5255

56+
internal enum ObservationTypeProto {
57+
[pbr::OriginalName("DEFAULT")] Default = 0,
58+
[pbr::OriginalName("GOAL")] Goal = 1,
59+
[pbr::OriginalName("REWARD")] Reward = 2,
60+
[pbr::OriginalName("MESSAGE")] Message = 3,
61+
}
62+
5363
#endregion
5464

5565
#region Messages
@@ -82,6 +92,7 @@ public ObservationProto(ObservationProto other) : this() {
8292
compressionType_ = other.compressionType_;
8393
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
8494
dimensionProperties_ = other.dimensionProperties_.Clone();
95+
observationType_ = other.observationType_;
8596
switch (other.ObservationDataCase) {
8697
case ObservationDataOneofCase.CompressedData:
8798
CompressedData = other.CompressedData;
@@ -162,6 +173,17 @@ public ObservationProto Clone() {
162173
get { return dimensionProperties_; }
163174
}
164175

176+
/// <summary>Field number for the "observation_type" field.</summary>
177+
public const int ObservationTypeFieldNumber = 7;
178+
private global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto observationType_ = 0;
179+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
180+
public global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto ObservationType {
181+
get { return observationType_; }
182+
set {
183+
observationType_ = value;
184+
}
185+
}
186+
165187
private object observationData_;
166188
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
167189
public enum ObservationDataOneofCase {
@@ -200,6 +222,7 @@ public bool Equals(ObservationProto other) {
200222
if (!object.Equals(FloatData, other.FloatData)) return false;
201223
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
202224
if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false;
225+
if (ObservationType != other.ObservationType) return false;
203226
if (ObservationDataCase != other.ObservationDataCase) return false;
204227
return Equals(_unknownFields, other._unknownFields);
205228
}
@@ -213,6 +236,7 @@ public override int GetHashCode() {
213236
if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
214237
hash ^= compressedChannelMapping_.GetHashCode();
215238
hash ^= dimensionProperties_.GetHashCode();
239+
if (ObservationType != 0) hash ^= ObservationType.GetHashCode();
216240
hash ^= (int) observationDataCase_;
217241
if (_unknownFields != null) {
218242
hash ^= _unknownFields.GetHashCode();
@@ -242,6 +266,10 @@ public void WriteTo(pb::CodedOutputStream output) {
242266
}
243267
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
244268
dimensionProperties_.WriteTo(output, _repeated_dimensionProperties_codec);
269+
if (ObservationType != 0) {
270+
output.WriteRawTag(56);
271+
output.WriteEnum((int) ObservationType);
272+
}
245273
if (_unknownFields != null) {
246274
_unknownFields.WriteTo(output);
247275
}
@@ -262,6 +290,9 @@ public int CalculateSize() {
262290
}
263291
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
264292
size += dimensionProperties_.CalculateSize(_repeated_dimensionProperties_codec);
293+
if (ObservationType != 0) {
294+
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ObservationType);
295+
}
265296
if (_unknownFields != null) {
266297
size += _unknownFields.CalculateSize();
267298
}
@@ -279,6 +310,9 @@ public void MergeFrom(ObservationProto other) {
279310
}
280311
compressedChannelMapping_.Add(other.compressedChannelMapping_);
281312
dimensionProperties_.Add(other.dimensionProperties_);
313+
if (other.ObservationType != 0) {
314+
ObservationType = other.ObservationType;
315+
}
282316
switch (other.ObservationDataCase) {
283317
case ObservationDataOneofCase.CompressedData:
284318
CompressedData = other.CompressedData;
@@ -334,6 +368,10 @@ public void MergeFrom(pb::CodedInputStream input) {
334368
dimensionProperties_.AddEntriesFrom(input, _repeated_dimensionProperties_codec);
335369
break;
336370
}
371+
case 56: {
372+
observationType_ = (global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto) input.ReadEnum();
373+
break;
374+
}
337375
}
338376
}
339377
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
namespace Unity.MLAgents.Sensors
2+
{
3+
4+
/// <summary>
5+
/// The ObservationType enum of the Sensor.
6+
/// </summary>
7+
internal enum ObservationType
8+
{
9+
// Collected observations are generic.
10+
Default = 0,
11+
// Collected observations contain goal information.
12+
Goal = 1,
13+
// Collected observations contain reward information.
14+
Reward = 2,
15+
// Collected observations are messages from other agents.
16+
Message = 3,
17+
}
18+
19+
20+
/// <summary>
21+
/// Sensor interface for sensors with variable types.
22+
/// </summary>
23+
internal interface ITypedSensor
24+
{
25+
/// <summary>
26+
/// Returns the ObservationType enum corresponding to the type of the sensor.
27+
/// </summary>
28+
/// <returns>The ObservationType enum</returns>
29+
ObservationType GetObservationType();
30+
}
31+
}

com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/Python-API.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,14 @@ A `TerminalStep` has the following fields:
227227

228228
A `BehaviorSpec` has the following fields :
229229

230-
- `sensor_specs` is a List of `SensorSpec` objects : Each `SensorSpec`
230+
- `observation_specs` is a List of `ObservationSpec` objects : Each `ObservationSpec`
231231
corresponds to an observation's properties: `shape` is a tuple of ints that
232232
corresponds to the shape of the observation (without the number of agents dimension).
233233
`dimension_property` is a tuple of flags containing extra information about how the
234-
data should be processed in the corresponding dimension. Note that the `SensorSpec`
235-
have the same ordering as the ordering of observations in the DecisionSteps,
236-
DecisionStep, TerminalSteps and TerminalStep.
234+
data should be processed in the corresponding dimension. `observation_type` is an enum
235+
corresponding to what type of observation is generating the data (i.e., default, goal,
236+
etc). Note that the `ObservationSpec` have the same ordering as the ordering of observations
237+
in the DecisionSteps, DecisionStep, TerminalSteps and TerminalStep.
237238
- `action_spec` is an `ActionSpec` namedtuple that defines the number and types
238239
of actions for the Agent.
239240

gym-unity/gym_unity/envs/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,16 +229,16 @@ def _preprocess_single(self, single_visual_obs: np.ndarray) -> np.ndarray:
229229

230230
def _get_n_vis_obs(self) -> int:
231231
result = 0
232-
for sen_spec in self.group_spec.sensor_specs:
233-
if len(sen_spec.shape) == 3:
232+
for obs_spec in self.group_spec.observation_specs:
233+
if len(obs_spec.shape) == 3:
234234
result += 1
235235
return result
236236

237237
def _get_vis_obs_shape(self) -> List[Tuple]:
238238
result: List[Tuple] = []
239-
for sen_spec in self.group_spec.sensor_specs:
240-
if len(sen_spec.shape) == 3:
241-
result.append(sen_spec.shape)
239+
for obs_spec in self.group_spec.observation_specs:
240+
if len(obs_spec.shape) == 3:
241+
result.append(obs_spec.shape)
242242
return result
243243

244244
def _get_vis_obs_list(
@@ -261,9 +261,9 @@ def _get_vector_obs(
261261

262262
def _get_vec_obs_size(self) -> int:
263263
result = 0
264-
for sen_spec in self.group_spec.sensor_specs:
265-
if len(sen_spec.shape) == 1:
266-
result += sen_spec.shape[0]
264+
for obs_spec in self.group_spec.observation_specs:
265+
if len(obs_spec.shape) == 1:
266+
result += obs_spec.shape[0]
267267
return result
268268

269269
def render(self, mode="rgb_array"):

gym-unity/gym_unity/tests/test_gym.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
TerminalSteps,
1212
BehaviorMapping,
1313
)
14-
from mlagents.trainers.tests.dummy_config import create_sensor_specs_with_shapes
14+
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
1515

1616

1717
def test_gym_wrapper():
@@ -227,8 +227,8 @@ def create_mock_group_spec(
227227
obs_shapes = [(vector_observation_space_size,)]
228228
for _ in range(number_visual_observations):
229229
obs_shapes += [(8, 8, 3)]
230-
sen_spec = create_sensor_specs_with_shapes(obs_shapes)
231-
return BehaviorSpec(sen_spec, action_spec)
230+
obs_spec = create_observation_specs_with_shapes(obs_shapes)
231+
return BehaviorSpec(obs_spec, action_spec)
232232

233233

234234
def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0):

ml-agents-envs/mlagents_envs/base_env.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
Any,
2929
Mapping as MappingType,
3030
)
31-
from enum import IntFlag
31+
from enum import IntFlag, Enum
3232
import numpy as np
3333

3434
from mlagents_envs.exception import UnityActionException
@@ -137,7 +137,7 @@ def empty(spec: "BehaviorSpec") -> "DecisionSteps":
137137
:param spec: The BehaviorSpec for the DecisionSteps
138138
"""
139139
obs: List[np.ndarray] = []
140-
for sen_spec in spec.sensor_specs:
140+
for sen_spec in spec.observation_specs:
141141
obs += [np.zeros((0,) + sen_spec.shape, dtype=np.float32)]
142142
return DecisionSteps(
143143
obs=obs,
@@ -235,7 +235,7 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps":
235235
:param spec: The BehaviorSpec for the TerminalSteps
236236
"""
237237
obs: List[np.ndarray] = []
238-
for sen_spec in spec.sensor_specs:
238+
for sen_spec in spec.observation_specs:
239239
obs += [np.zeros((0,) + sen_spec.shape, dtype=np.float32)]
240240
return TerminalSteps(
241241
obs=obs,
@@ -458,31 +458,49 @@ class DimensionProperty(IntFlag):
458458
VARIABLE_SIZE = 4
459459

460460

461-
class SensorSpec(NamedTuple):
461+
class ObservationType(Enum):
462+
"""
463+
An Enum which defines the type of information carried in the observation
464+
of the agent.
465+
"""
466+
467+
# Observation information is generic.
468+
DEFAULT = 0
469+
# Observation contains goal information for current task.
470+
GOAL = 1
471+
# Observation contains reward information for current task.
472+
REWARD = 2
473+
# Observation contains a message from another agent.
474+
MESSAGE = 3
475+
476+
477+
class ObservationSpec(NamedTuple):
462478
"""
463479
A NamedTuple containing information about the observation of Agents.
464480
- shape is a Tuple of int : It corresponds to the shape of
465481
an observation's dimensions.
466482
- dimension_property is a Tuple of DimensionProperties flag, one flag for each
467483
dimension.
484+
- observation_type is an enum of ObservationType.
468485
"""
469486

470487
shape: Tuple[int, ...]
471488
dimension_property: Tuple[DimensionProperty, ...]
489+
observation_type: ObservationType
472490

473491

474492
class BehaviorSpec(NamedTuple):
475493
"""
476494
A NamedTuple containing information about the observation and action
477495
spaces for a group of Agents under the same behavior.
478-
- sensor_specs is a List of SensorSpec NamedTuple containing
496+
- observation_specs is a List of ObservationSpec NamedTuple containing
479497
information about the information of the Agent's observations such as their shapes.
480498
The order of the SensorSpec is the same as the order of the observations of an
481499
agent.
482500
- action_spec is an ActionSpec NamedTuple.
483501
"""
484502

485-
sensor_specs: List[SensorSpec]
503+
observation_specs: List[ObservationSpec]
486504
action_spec: ActionSpec
487505

488506

0 commit comments

Comments
 (0)