Skip to content

[WIP] Observation Types #4825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions .yamato/protobuf-generation-test.yml
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
test_mac_protobuf_generation:
test_linux_protobuf_generation:
name: Protobuf Generation Tests
agent:
type: Unity::VM::osx
image: package-ci/mac:stable
flavor: b1.small
type: Unity::VM
image: package-ci/ubuntu:stable
flavor: b1.large
variables:
GRPC_VERSION: "1.14.1"
CS_PROTO_PATH: "com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects"
HOMEBREW_NO_AUTO_UPDATE: "1"
commands:
- |
brew install nuget
sudo apt-get update && sudo apt-get install -y python3-venv nuget
python3 -m venv venv && source venv/bin/activate
nuget install Grpc.Tools -Version $GRPC_VERSION -OutputDirectory protobuf-definitions/
python3 -m venv venv
. venv/bin/activate
python3 -m pip install --upgrade pip --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
python3 -m pip install grpcio==1.28.1 grpcio-tools==1.13.0 protobuf==3.11.3 six==1.14.0 mypy-protobuf==1.16.0 --progress-bar=off --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
cd protobuf-definitions
chmod +x Grpc.Tools.$GRPC_VERSION/tools/macosx_x64/protoc
chmod +x Grpc.Tools.$GRPC_VERSION/tools/macosx_x64/grpc_csharp_plugin
COMPILER=Grpc.Tools.$GRPC_VERSION/tools/macosx_x64 ./make.sh
pushd protobuf-definitions
chmod +x Grpc.Tools.$GRPC_VERSION/tools/linux_x64/protoc
chmod +x Grpc.Tools.$GRPC_VERSION/tools/linux_x64/grpc_csharp_plugin
COMPILER=Grpc.Tools.$GRPC_VERSION/tools/linux_x64 ./make.sh
popd
mkdir -p artifacts
touch artifacts/proto.patch
git diff --exit-code -- :/ ":(exclude,top)$CS_PROTO_PATH/*.meta" \
Expand Down
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
}
}
observationProto.Shape.AddRange(shape);

// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;
if (typeSensor != null)
{
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
}
else
{
observationProto.ObservationType = ObservationTypeProto.Default;
}
return observationProto;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,23 @@ static ObservationReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyK7AgoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKBAwoQT2JzZXJ2YXRp",
"b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg",
"ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv",
"dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE",
"IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u",
"RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD",
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUaGQoJRmxvYXREYXRh",
"EgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVz",
"c2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1M",
"QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUSRAoQb2JzZXJ2YXRp",
"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
"aW9uVHlwZVByb3RvGhkKCUZsb2F0RGF0YRIMCgRkYXRhGAEgAygCQhIKEG9i",
"c2VydmF0aW9uX2RhdGEqKQoUQ29tcHJlc3Npb25UeXBlUHJvdG8SCAoETk9O",
"RRAAEgcKA1BORxABKkYKFE9ic2VydmF0aW9uVHlwZVByb3RvEgsKB0RFRkFV",
"TFQQABIICgRHT0FMEAESCgoGUkVXQVJEEAISCwoHTUVTU0FHRRADQiWqAiJV",
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "ObservationType" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion
Expand All @@ -50,6 +53,13 @@ internal enum CompressionTypeProto {
[pbr::OriginalName("PNG")] Png = 1,
}

internal enum ObservationTypeProto {
[pbr::OriginalName("DEFAULT")] Default = 0,
[pbr::OriginalName("GOAL")] Goal = 1,
[pbr::OriginalName("REWARD")] Reward = 2,
[pbr::OriginalName("MESSAGE")] Message = 3,
}

#endregion

#region Messages
Expand Down Expand Up @@ -82,6 +92,7 @@ public ObservationProto(ObservationProto other) : this() {
compressionType_ = other.compressionType_;
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
dimensionProperties_ = other.dimensionProperties_.Clone();
observationType_ = other.observationType_;
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;
Expand Down Expand Up @@ -162,6 +173,17 @@ public ObservationProto Clone() {
get { return dimensionProperties_; }
}

/// <summary>Field number for the "observation_type" field.</summary>
public const int ObservationTypeFieldNumber = 7;
private global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto observationType_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto ObservationType {
get { return observationType_; }
set {
observationType_ = value;
}
}

private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {
Expand Down Expand Up @@ -200,6 +222,7 @@ public bool Equals(ObservationProto other) {
if (!object.Equals(FloatData, other.FloatData)) return false;
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false;
if (ObservationType != other.ObservationType) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}
Expand All @@ -213,6 +236,7 @@ public override int GetHashCode() {
if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
hash ^= compressedChannelMapping_.GetHashCode();
hash ^= dimensionProperties_.GetHashCode();
if (ObservationType != 0) hash ^= ObservationType.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
Expand Down Expand Up @@ -242,6 +266,10 @@ public void WriteTo(pb::CodedOutputStream output) {
}
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
dimensionProperties_.WriteTo(output, _repeated_dimensionProperties_codec);
if (ObservationType != 0) {
output.WriteRawTag(56);
output.WriteEnum((int) ObservationType);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -262,6 +290,9 @@ public int CalculateSize() {
}
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
size += dimensionProperties_.CalculateSize(_repeated_dimensionProperties_codec);
if (ObservationType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ObservationType);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -279,6 +310,9 @@ public void MergeFrom(ObservationProto other) {
}
compressedChannelMapping_.Add(other.compressedChannelMapping_);
dimensionProperties_.Add(other.dimensionProperties_);
if (other.ObservationType != 0) {
ObservationType = other.ObservationType;
}
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;
Expand Down Expand Up @@ -334,6 +368,10 @@ public void MergeFrom(pb::CodedInputStream input) {
dimensionProperties_.AddEntriesFrom(input, _repeated_dimensionProperties_codec);
break;
}
case 56: {
observationType_ = (global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto) input.ReadEnum();
break;
}
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
namespace Unity.MLAgents.Sensors
{

/// <summary>
/// The ObservationType enum of the Sensor.
/// </summary>
internal enum ObservationType
{
// Collected observations are generic.
Default = 0,
// Collected observations contain goal information.
Goal = 1,
// Collected observations contain reward information.
Reward = 2,
// Collected observations are messages from other agents.
Message = 3,
}


/// <summary>
/// Sensor interface for sensors with variable types.
/// </summary>
internal interface ITypedSensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this interface be public? Are users expected to add ObservationType to their sensors or not for now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need to be, since any ISensor that isn't an ITypedSensor should be treated as Default type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intention is similar to making the DimensionProperty internal for now. To not change the public API until there are actual features to use it.

{
/// <summary>
/// Returns the ObservationType enum corresponding to the type of the sensor.
/// </summary>
/// <returns>The ObservationType enum</returns>
ObservationType GetObservationType();
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions docs/Python-API.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,14 @@ A `TerminalStep` has the following fields:

A `BehaviorSpec` has the following fields :

- `sensor_specs` is a List of `SensorSpec` objects : Each `SensorSpec`
- `observation_specs` is a List of `ObservationSpec` objects : Each `ObservationSpec`
corresponds to an observation's properties: `shape` is a tuple of ints that
corresponds to the shape of the observation (without the number of agents dimension).
`dimension_property` is a tuple of flags containing extra information about how the
data should be processed in the corresponding dimension. Note that the `SensorSpec`
have the same ordering as the ordering of observations in the DecisionSteps,
DecisionStep, TerminalSteps and TerminalStep.
data should be processed in the corresponding dimension. `observation_type` is an enum
corresponding to what type of observation is generating the data (i.e., default, goal,
etc). Note that the `ObservationSpec` have the same ordering as the ordering of observations
in the DecisionSteps, DecisionStep, TerminalSteps and TerminalStep.
- `action_spec` is an `ActionSpec` namedtuple that defines the number and types
of actions for the Agent.

Expand Down
16 changes: 8 additions & 8 deletions gym-unity/gym_unity/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,16 @@ def _preprocess_single(self, single_visual_obs: np.ndarray) -> np.ndarray:

def _get_n_vis_obs(self) -> int:
result = 0
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 3:
result += 1
return result

def _get_vis_obs_shape(self) -> List[Tuple]:
result: List[Tuple] = []
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
result.append(sen_spec.shape)
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 3:
result.append(obs_spec.shape)
return result

def _get_vis_obs_list(
Expand All @@ -261,9 +261,9 @@ def _get_vector_obs(

def _get_vec_obs_size(self) -> int:
result = 0
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 1:
result += sen_spec.shape[0]
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 1:
result += obs_spec.shape[0]
return result

def render(self, mode="rgb_array"):
Expand Down
6 changes: 3 additions & 3 deletions gym-unity/gym_unity/tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
TerminalSteps,
BehaviorMapping,
)
from mlagents.trainers.tests.dummy_config import create_sensor_specs_with_shapes
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes


def test_gym_wrapper():
Expand Down Expand Up @@ -227,8 +227,8 @@ def create_mock_group_spec(
obs_shapes = [(vector_observation_space_size,)]
for _ in range(number_visual_observations):
obs_shapes += [(8, 8, 3)]
sen_spec = create_sensor_specs_with_shapes(obs_shapes)
return BehaviorSpec(sen_spec, action_spec)
obs_spec = create_observation_specs_with_shapes(obs_shapes)
return BehaviorSpec(obs_spec, action_spec)


def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0):
Expand Down
30 changes: 24 additions & 6 deletions ml-agents-envs/mlagents_envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Any,
Mapping as MappingType,
)
from enum import IntFlag
from enum import IntFlag, Enum
import numpy as np

from mlagents_envs.exception import UnityActionException
Expand Down Expand Up @@ -137,7 +137,7 @@ def empty(spec: "BehaviorSpec") -> "DecisionSteps":
:param spec: The BehaviorSpec for the DecisionSteps
"""
obs: List[np.ndarray] = []
for sen_spec in spec.sensor_specs:
for sen_spec in spec.observation_specs:
obs += [np.zeros((0,) + sen_spec.shape, dtype=np.float32)]
return DecisionSteps(
obs=obs,
Expand Down Expand Up @@ -235,7 +235,7 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps":
:param spec: The BehaviorSpec for the TerminalSteps
"""
obs: List[np.ndarray] = []
for sen_spec in spec.sensor_specs:
for sen_spec in spec.observation_specs:
obs += [np.zeros((0,) + sen_spec.shape, dtype=np.float32)]
return TerminalSteps(
obs=obs,
Expand Down Expand Up @@ -458,31 +458,49 @@ class DimensionProperty(IntFlag):
VARIABLE_SIZE = 4


class SensorSpec(NamedTuple):
class ObservationType(Enum):
"""
An Enum which defines the type of information carried in the observation
of the agent.
"""

# Observation information is generic.
DEFAULT = 0
# Observation contains goal information for current task.
GOAL = 1
# Observation contains reward information for current task.
REWARD = 2
# Observation contains a message from another agent.
MESSAGE = 3


class ObservationSpec(NamedTuple):
"""
A NamedTuple containing information about the observation of Agents.
- shape is a Tuple of int : It corresponds to the shape of
an observation's dimensions.
- dimension_property is a Tuple of DimensionProperties flag, one flag for each
dimension.
- observation_type is an enum of ObservationType.
"""

shape: Tuple[int, ...]
dimension_property: Tuple[DimensionProperty, ...]
observation_type: ObservationType


class BehaviorSpec(NamedTuple):
"""
A NamedTuple containing information about the observation and action
spaces for a group of Agents under the same behavior.
- sensor_specs is a List of SensorSpec NamedTuple containing
- observation_specs is a List of ObservationSpec NamedTuple containing
information about the information of the Agent's observations such as their shapes.
The order of the SensorSpec is the same as the order of the observations of an
agent.
- action_spec is an ActionSpec NamedTuple.
"""

sensor_specs: List[SensorSpec]
observation_specs: List[ObservationSpec]
action_spec: ActionSpec


Expand Down
Loading