diff --git a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs index 84fa7d2c5b..f6049ba529 100644 --- a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs +++ b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs @@ -8,12 +8,12 @@ public class Ball3DAgent : Agent [Header("Specific to Ball3D")] public GameObject ball; Rigidbody m_BallRb; - FloatPropertiesChannel m_ResetParams; + EnvironmentParameters m_ResetParams; public override void Initialize() { m_BallRb = ball.GetComponent(); - m_ResetParams = SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; SetResetParameters(); } @@ -75,8 +75,8 @@ public override void Heuristic(float[] actionsOut) public void SetBall() { //Set the attributes of the ball by fetching the information from the academy - m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f); - var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f); + m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f); + var scale = m_ResetParams.GetWithDefault("scale", 1.0f); ball.transform.localScale = new Vector3(scale, scale, scale); } diff --git a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs index c9c8b0b512..674d9498d5 100644 --- a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs +++ b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs @@ -8,12 +8,12 @@ public class Ball3DHardAgent : Agent [Header("Specific to Ball3DHard")] public GameObject ball; Rigidbody m_BallRb; - FloatPropertiesChannel m_ResetParams; + EnvironmentParameters m_ResetParams; public override void Initialize() { m_BallRb = ball.GetComponent(); - m_ResetParams = SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; SetResetParameters(); } @@ -66,8 +66,8 @@ public 
override void OnEpisodeBegin() public void SetBall() { //Set the attributes of the ball by fetching the information from the academy - m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f); - var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f); + m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f); + var scale = m_ResetParams.GetWithDefault("scale", 1.0f); ball.transform.localScale = new Vector3(scale, scale, scale); } diff --git a/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs b/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs index 6e09fb86dc..6036a08299 100644 --- a/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs @@ -15,14 +15,14 @@ public class BouncerAgent : Agent int m_NumberJumps = 20; int m_JumpLeft = 20; - FloatPropertiesChannel m_ResetParams; + EnvironmentParameters m_ResetParams; public override void Initialize() { m_Rb = gameObject.GetComponent(); m_LookDir = Vector3.zero; - m_ResetParams = SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; SetResetParameters(); } @@ -121,7 +121,7 @@ void Update() public void SetTargetScale() { - var targetScale = m_ResetParams.GetPropertyWithDefault("target_scale", 1.0f); + var targetScale = m_ResetParams.GetWithDefault("target_scale", 1.0f); target.transform.localScale = new Vector3(targetScale, targetScale, targetScale); } diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index 83f38bbef5..de718fab82 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -29,13 +29,14 @@ public class FoodCollectorAgent : Agent public bool contribute; public bool useVectorObs; + 
EnvironmentParameters m_ResetParams; public override void Initialize() { m_AgentRb = GetComponent(); m_MyArea = area.GetComponent(); m_FoodCollecterSettings = FindObjectOfType(); - + m_ResetParams = Academy.Instance.EnvironmentParameters; SetResetParameters(); } @@ -271,12 +272,12 @@ void OnCollisionEnter(Collision collision) public void SetLaserLengths() { - m_LaserLength = SideChannelUtils.GetSideChannel().GetPropertyWithDefault("laser_length", 1.0f); + m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f); } public void SetAgentScale() { - float agentScale = SideChannelUtils.GetSideChannel().GetPropertyWithDefault("agent_scale", 1.0f); + float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f); gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale); } diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs index f4789b003d..0953ddd747 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs @@ -1,8 +1,6 @@ -using System; using UnityEngine; using UnityEngine.UI; using MLAgents; -using MLAgents.SideChannels; public class FoodCollectorSettings : MonoBehaviour { @@ -14,15 +12,15 @@ public class FoodCollectorSettings : MonoBehaviour public int totalScore; public Text scoreText; - StatsSideChannel m_statsSideChannel; + StatsRecorder m_Recorder; public void Awake() { Academy.Instance.OnEnvironmentReset += EnvironmentReset; - m_statsSideChannel = SideChannelUtils.GetSideChannel(); + m_Recorder = Academy.Instance.StatsRecorder; } - public void EnvironmentReset() + private void EnvironmentReset() { ClearObjects(GameObject.FindGameObjectsWithTag("food")); ClearObjects(GameObject.FindGameObjectsWithTag("badFood")); @@ -54,7 +52,7 @@ public void Update() // need to send every 
Update() call. if ((Time.frameCount % 100)== 0) { - m_statsSideChannel?.AddStat("TotalScore", totalScore); + m_Recorder.Add("TotalScore", totalScore); } } } diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index e2dd63cd1a..52fbf9fde1 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -29,6 +29,13 @@ public class GridAgent : Agent const int k_Left = 3; const int k_Right = 4; + EnvironmentParameters m_ResetParams; + + public override void Initialize() + { + m_ResetParams = Academy.Instance.EnvironmentParameters; + } + public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker) { // Mask the necessary actions if selected by the user. @@ -37,7 +44,7 @@ public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMaske // Prevents the agent from picking an action that would make it collide with a wall var positionX = (int)transform.position.x; var positionZ = (int)transform.position.z; - var maxPosition = (int)SideChannelUtils.GetSideChannel().GetPropertyWithDefault("gridSize", 5f) - 1; + var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1; if (positionX == 0) { diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs index 126cb00b0d..22669bea37 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs @@ -14,8 +14,6 @@ public class GridArea : MonoBehaviour public GameObject trueAgent; - FloatPropertiesChannel m_ResetParameters; - Camera m_AgentCam; public GameObject goalPref; @@ -30,9 +28,11 @@ public class GridArea : MonoBehaviour Vector3 m_InitialPosition; + EnvironmentParameters m_ResetParams; + public void Start() { - m_ResetParameters = 
SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; m_Objects = new[] { goalPref, pitPref }; @@ -50,23 +50,23 @@ public void Start() m_InitialPosition = transform.position; } - public void SetEnvironment() + private void SetEnvironment() { - transform.position = m_InitialPosition * (m_ResetParameters.GetPropertyWithDefault("gridSize", 5f) + 1); + transform.position = m_InitialPosition * (m_ResetParams.GetWithDefault("gridSize", 5f) + 1); var playersList = new List(); - for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numObstacles", 1); i++) + for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numObstacles", 1); i++) { playersList.Add(1); } - for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numGoals", 1f); i++) + for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numGoals", 1f); i++) { playersList.Add(0); } players = playersList.ToArray(); - var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f); + var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f); m_Plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f); m_Plane.transform.localPosition = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f); m_Sn.transform.localScale = new Vector3(1, 1, gridSize + 2); @@ -84,7 +84,7 @@ public void SetEnvironment() public void AreaReset() { - var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f); + var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f); foreach (var actor in actorObjs) { DestroyImmediate(actor); @@ -98,7 +98,7 @@ public void AreaReset() { numbers.Add(Random.Range(0, gridSize * gridSize)); } - var numbersA = Enumerable.ToArray(numbers); + var numbersA = numbers.ToArray(); for (var i = 0; i < players.Length; i++) { diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSettings.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSettings.cs index 
d2ae6f6dda..2baee93b3a 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSettings.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSettings.cs @@ -8,7 +8,7 @@ public class GridSettings : MonoBehaviour public void Awake() { - SideChannelUtils.GetSideChannel().RegisterCallback("gridSize", f => + Academy.Instance.EnvironmentParameters.RegisterCallback("gridSize", f => { MainCamera.transform.position = new Vector3(-(f - 1) / 2f, f * 1.25f, -(f - 1) / 2f); MainCamera.orthographicSize = (f + 5f) / 2f; diff --git a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs index 54ad57b84f..53d06e4a50 100644 --- a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs +++ b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs @@ -49,6 +49,8 @@ public class PushAgentBasic : Agent /// Renderer m_GroundRenderer; + private EnvironmentParameters m_ResetParams; + void Awake() { m_PushBlockSettings = FindObjectOfType(); @@ -70,6 +72,8 @@ public override void Initialize() // Starting material m_GroundMaterial = m_GroundRenderer.material; + m_ResetParams = Academy.Instance.EnvironmentParameters; + SetResetParameters(); } @@ -226,27 +230,23 @@ public override void OnEpisodeBegin() public void SetGroundMaterialFriction() { - var resetParams = SideChannelUtils.GetSideChannel(); - var groundCollider = ground.GetComponent(); - groundCollider.material.dynamicFriction = resetParams.GetPropertyWithDefault("dynamic_friction", 0); - groundCollider.material.staticFriction = resetParams.GetPropertyWithDefault("static_friction", 0); + groundCollider.material.dynamicFriction = m_ResetParams.GetWithDefault("dynamic_friction", 0); + groundCollider.material.staticFriction = m_ResetParams.GetWithDefault("static_friction", 0); } public void SetBlockProperties() { - var resetParams = SideChannelUtils.GetSideChannel(); - - var scale = 
resetParams.GetPropertyWithDefault("block_scale", 2); + var scale = m_ResetParams.GetWithDefault("block_scale", 2); //Set the scale of the block m_BlockRb.transform.localScale = new Vector3(scale, 0.75f, scale); // Set the drag of the block - m_BlockRb.drag = resetParams.GetPropertyWithDefault("block_drag", 0.5f); + m_BlockRb.drag = m_ResetParams.GetWithDefault("block_drag", 0.5f); } - public void SetResetParameters() + private void SetResetParameters() { SetGroundMaterialFriction(); SetBlockProperties(); diff --git a/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs b/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs index 2f67598829..da369334a6 100644 --- a/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs @@ -21,6 +21,8 @@ public class ReacherAgent : Agent // Frequency of the cosine deviation of the goal along the vertical dimension float m_DeviationFreq; + private EnvironmentParameters m_ResetParams; + /// /// Collect the rigidbodies of the reacher in order to resue them for /// observations and actions. 
@@ -30,6 +32,8 @@ public override void Initialize() m_RbA = pendulumA.GetComponent(); m_RbB = pendulumB.GetComponent(); + m_ResetParams = Academy.Instance.EnvironmentParameters; + SetResetParameters(); } @@ -110,10 +114,9 @@ public override void OnEpisodeBegin() public void SetResetParameters() { - var fp = SideChannelUtils.GetSideChannel(); - m_GoalSize = fp.GetPropertyWithDefault("goal_size", 5); - m_GoalSpeed = Random.Range(-1f, 1f) * fp.GetPropertyWithDefault("goal_speed", 1); - m_Deviation = fp.GetPropertyWithDefault("deviation", 0); - m_DeviationFreq = fp.GetPropertyWithDefault("deviation_freq", 0); + m_GoalSize = m_ResetParams.GetWithDefault("goal_size", 5); + m_GoalSpeed = Random.Range(-1f, 1f) * m_ResetParams.GetWithDefault("goal_speed", 1); + m_Deviation = m_ResetParams.GetWithDefault("deviation", 0); + m_DeviationFreq = m_ResetParams.GetWithDefault("deviation_freq", 0); } } diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs index 4b1ff4279c..ab6eb10264 100644 --- a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs +++ b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs @@ -44,8 +44,7 @@ public void Awake() Physics.defaultSolverVelocityIterations = solverVelocityIterations; // Make sure the Academy singleton is initialized first, since it will create the SideChannels. 
- var academy = Academy.Instance; - SideChannelUtils.GetSideChannel().RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); }); + Academy.Instance.EnvironmentParameters.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); }); } public void OnDestroy() diff --git a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs index bca66d8158..8ec3202a5f 100644 --- a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs +++ b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs @@ -49,6 +49,8 @@ public enum Position BehaviorParameters m_BehaviorParameters; Vector3 m_Transform; + private EnvironmentParameters m_ResetParams; + public override void Initialize() { m_Existential = 1f / MaxStep; @@ -73,7 +75,7 @@ public override void Initialize() m_LateralSpeed = 0.3f; m_ForwardSpeed = 1.3f; } - else + else { m_LateralSpeed = 0.3f; m_ForwardSpeed = 1.0f; @@ -91,6 +93,8 @@ public override void Initialize() area.playerStates.Add(playerState); m_PlayerIndex = area.playerStates.IndexOf(playerState); playerState.playerIndex = m_PlayerIndex; + + m_ResetParams = Academy.Instance.EnvironmentParameters; } public void MoveAgent(float[] act) @@ -214,7 +218,7 @@ public override void OnEpisodeBegin() { timePenalty = 0; - m_BallTouch = SideChannelUtils.GetSideChannel().GetPropertyWithDefault("ball_touch", 0); + m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0); if (team == Team.Purple) { transform.rotation = Quaternion.Euler(0f, -90f, 0f); diff --git a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs index a435f4eb9f..fcafa5f50b 100644 --- a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs +++ b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/SoccerFieldArea.cs @@ -32,6 +32,8 @@ public class SoccerFieldArea : MonoBehaviour 
[HideInInspector] public bool canResetBall; + private EnvironmentParameters m_ResetParams; + void Awake() { canResetBall = true; @@ -40,6 +42,8 @@ void Awake() m_BallController = ball.GetComponent(); m_BallController.area = this; ballStartingPos = ball.transform.position; + + m_ResetParams = Academy.Instance.EnvironmentParameters; } IEnumerator ShowGoalUI() @@ -76,7 +80,7 @@ public void ResetBall() ballRb.velocity = Vector3.zero; ballRb.angularVelocity = Vector3.zero; - var ballScale = SideChannelUtils.GetSideChannel().GetPropertyWithDefault("ball_scale", 0.015f); + var ballScale = m_ResetParams.GetWithDefault("ball_scale", 0.015f); ballRb.transform.localScale = new Vector3(ballScale, ballScale, ballScale); } } diff --git a/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs b/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs index 9d99f75c43..49d1c4c84c 100644 --- a/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs @@ -18,7 +18,7 @@ public class TennisAgent : Agent Rigidbody m_AgentRb; Rigidbody m_BallRb; float m_InvertMult; - FloatPropertiesChannel m_ResetParams; + EnvironmentParameters m_ResetParams; // Looks for the scoreboard based on the name of the gameObjects. 
// Do not modify the names of the Score GameObjects @@ -32,7 +32,7 @@ public override void Initialize() m_BallRb = ball.GetComponent(); var canvas = GameObject.Find(k_CanvasName); GameObject scoreBoard; - m_ResetParams = SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; if (invertX) { scoreBoard = canvas.transform.Find(k_ScoreBoardBName).gameObject; @@ -105,7 +105,7 @@ public override void OnEpisodeBegin() public void SetRacket() { - angle = m_ResetParams.GetPropertyWithDefault("angle", 55); + angle = m_ResetParams.GetWithDefault("angle", 55); gameObject.transform.eulerAngles = new Vector3( gameObject.transform.eulerAngles.x, gameObject.transform.eulerAngles.y, @@ -115,7 +115,7 @@ public void SetRacket() public void SetBall() { - scale = m_ResetParams.GetPropertyWithDefault("scale", .5f); + scale = m_ResetParams.GetWithDefault("scale", .5f); ball.transform.localScale = new Vector3(scale, scale, scale); } diff --git a/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs b/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs index 3593dc12ec..2d481028cc 100644 --- a/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs @@ -34,7 +34,7 @@ public class WalkerAgent : Agent Rigidbody m_ChestRb; Rigidbody m_SpineRb; - FloatPropertiesChannel m_ResetParams; + EnvironmentParameters m_ResetParams; public override void Initialize() { @@ -60,7 +60,7 @@ public override void Initialize() m_ChestRb = chest.GetComponent(); m_SpineRb = spine.GetComponent(); - m_ResetParams = SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; SetResetParameters(); } @@ -179,9 +179,9 @@ public override void OnEpisodeBegin() public void SetTorsoMass() { - m_ChestRb.mass = m_ResetParams.GetPropertyWithDefault("chest_mass", 8); - m_SpineRb.mass = m_ResetParams.GetPropertyWithDefault("spine_mass", 10); - m_HipsRb.mass = 
m_ResetParams.GetPropertyWithDefault("hip_mass", 15); + m_ChestRb.mass = m_ResetParams.GetWithDefault("chest_mass", 8); + m_SpineRb.mass = m_ResetParams.GetWithDefault("spine_mass", 10); + m_HipsRb.mass = m_ResetParams.GetWithDefault("hip_mass", 15); } public void SetResetParameters() diff --git a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs index db522b6746..69a099a322 100644 --- a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs +++ b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs @@ -42,7 +42,7 @@ public class WallJumpAgent : Agent Vector3 m_JumpTargetPos; Vector3 m_JumpStartingPos; - FloatPropertiesChannel m_FloatProperties; + EnvironmentParameters m_ResetParams; public override void Initialize() { @@ -57,7 +57,7 @@ public override void Initialize() spawnArea.SetActive(false); - m_FloatProperties = SideChannelUtils.GetSideChannel(); + m_ResetParams = Academy.Instance.EnvironmentParameters; } // Begin the jump sequence @@ -316,7 +316,7 @@ void ConfigureAgent(int config) { localScale = new Vector3( localScale.x, - m_FloatProperties.GetPropertyWithDefault("no_wall_height", 0), + m_ResetParams.GetWithDefault("no_wall_height", 0), localScale.z); wall.transform.localScale = localScale; SetModel("SmallWallJump", noWallBrain); @@ -325,15 +325,15 @@ void ConfigureAgent(int config) { localScale = new Vector3( localScale.x, - m_FloatProperties.GetPropertyWithDefault("small_wall_height", 4), + m_ResetParams.GetWithDefault("small_wall_height", 4), localScale.z); wall.transform.localScale = localScale; SetModel("SmallWallJump", smallWallBrain); } else { - var min = m_FloatProperties.GetPropertyWithDefault("big_wall_min_height", 8); - var max = m_FloatProperties.GetPropertyWithDefault("big_wall_max_height", 8); + var min = m_ResetParams.GetWithDefault("big_wall_min_height", 8); + var max = m_ResetParams.GetWithDefault("big_wall_max_height", 8); 
var height = min + Random.value * (max - min); localScale = new Vector3( localScale.x, diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 388cffa62a..18baf2bab0 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -13,10 +13,6 @@ and this project adheres to - The `--load` and `--train` command-line flags have been deprecated. Training now happens by default, and use `--resume` to resume training instead. (#3705) - The Jupyter notebooks have been removed from the repository. -- Introduced the `SideChannelUtils` to register, unregister and access side - channels. -- `Academy.FloatProperties` was removed, please use - `SideChannelUtils.GetSideChannel()` instead. - Removed the multi-agent gym option from the gym wrapper. For multi-agent scenarios, use the [Low Level Python API](../docs/Python-API.md). - The low level Python API has changed. You can look at the document @@ -38,6 +34,19 @@ and this project adheres to `AgentAction` and `AgentReset` have been removed. - The GhostTrainer has been extended to support asymmetric games and the asymmetric example environment Strikers Vs. Goalie has been added. +- The SideChannel API has changed (#3833, #3660) : + - Introduced the `SideChannelsManager` to register, unregister and access side + channels. + - `EnvironmentParameters` replaces the default `FloatProperties`. + You can access the `EnvironmentParameters` with + `Academy.Instance.EnvironmentParameters` on C# and create an + `EnvironmentParametersChannel` on Python. + - `SideChannel.OnMessageReceived` is now a protected method (was public) + - SideChannel IncomingMessages methods now take an optional default argument, + which is used when trying to read more data than the message contains. + - Added a feature to allow sending stats from C# environments to TensorBoard + (and other python StatsWriters).
To do this from your code, use + `Academy.Instance.StatsRecorder.Add(key, value)`(#3660) - CameraSensorComponent.m_Grayscale and RenderTextureSensorComponent.m_Grayscale were changed from `public` to `private` (#3808). - The `UnityEnv` class from the `gym-unity` package was renamed @@ -53,15 +62,9 @@ and this project adheres to - Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616) -- Added a feature to allow sending stats from C# environments to TensorBoard - (and other python StatsWriters). To do this from your code, use - `SideChannelUtils.GetSideChannel().AddStat(key, value)` - (#3660) - Renamed 'Generalization' feature to 'Environment Parameter Randomization'. - Timer files now contain a dictionary of metadata, including things like the package version numbers. -- SideChannel IncomingMessages methods now take an optional default argument, - which is used when trying to read more data than the message contains. - The way that UnityEnvironment decides the port was changed. If no port is specified, the behavior will depend on the `file_name` parameter. If it is `None`, 5004 (the editor port) will be used; otherwise 5005 (the base diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index cb2319a6f1..73dbeec163 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -42,7 +42,7 @@ void FixedUpdate() /// Access the Academy singleton through the /// property. The Academy instance is initialized the first time it is accessed (which will /// typically be by the first initialized in a scene). - /// + /// /// At initialization, the Academy attempts to connect to the Python training process through /// the external communicator. If successful, the training process can train /// instances. 
When you set an agent's setting @@ -62,7 +62,7 @@ public class Academy : IDisposable /// on each side, although we may allow some flexibility in the future. /// This should be incremented whenever a change is made to the communication protocol. /// - const string k_ApiVersion = "0.16.0"; + const string k_ApiVersion = "0.17.0"; /// /// Unity package version of com.unity.ml-agents. @@ -311,6 +311,31 @@ static int ReadPortFromArgs() } } + private EnvironmentParameters m_EnvironmentParameters; + private StatsRecorder m_StatsRecorder; + + /// + /// Returns the instance. If training + /// features such as Curriculum Learning or Environment Parameter Randomization are used, + /// then the values of the parameters generated from the training process can be + /// retrieved here. + /// + /// + public EnvironmentParameters EnvironmentParameters + { + get { return m_EnvironmentParameters; } + } + + /// + /// Returns the instance. This instance can be used + /// to record any statistics from the Unity environment. + /// + /// + public StatsRecorder StatsRecorder + { + get { return m_StatsRecorder; } + } + /// /// Initializes the environment, configures it and initializes the Academy. 
/// @@ -321,9 +346,9 @@ void InitializeEnvironment() EnableAutomaticStepping(); - SideChannelUtils.RegisterSideChannel(new EngineConfigurationChannel()); - SideChannelUtils.RegisterSideChannel(new FloatPropertiesChannel()); - SideChannelUtils.RegisterSideChannel(new StatsSideChannel()); + SideChannelsManager.RegisterSideChannel(new EngineConfigurationChannel()); + m_EnvironmentParameters = new EnvironmentParameters(); + m_StatsRecorder = new StatsRecorder(); // Try to launch the communicator by using the arguments passed at launch var port = ReadPortFromArgs(); @@ -477,7 +502,7 @@ public void EnvironmentStep() // If the communicator is not on, we need to clear the SideChannel sending queue if (!IsCommunicatorOn) { - SideChannelUtils.GetSideChannelMessage(); + SideChannelsManager.GetSideChannelMessage(); } using (TimerStack.Instance.Scoped("AgentAct")) @@ -531,7 +556,10 @@ public void Dispose() Communicator?.Dispose(); Communicator = null; - SideChannelUtils.UnregisterAllSideChannels(); + + m_EnvironmentParameters.Dispose(); + m_StatsRecorder.Dispose(); + SideChannelsManager.UnregisterAllSideChannels(); // unregister custom side channels if (m_ModelRunners != null) { diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs index 090ae08f18..bcf8b940ac 100644 --- a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs +++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs @@ -198,7 +198,7 @@ public void SubscribeBrain(string brainKey, BrainParameters brainParameters) void UpdateEnvironmentWithInput(UnityRLInputProto rlInput) { - SideChannelUtils.ProcessSideChannelData(rlInput.SideChannel.ToArray()); + SideChannelsManager.ProcessSideChannelData(rlInput.SideChannel.ToArray()); SendCommandEvent(rlInput.Command); } @@ -365,7 +365,7 @@ void SendBatchedMessageHelper() message.RlInitializationOutput = tempUnityRlInitializationOutput; } - byte[] messageAggregated = 
SideChannelUtils.GetSideChannelMessage(); + byte[] messageAggregated = SideChannelsManager.GetSideChannelMessage(); message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated); var input = Exchange(message); diff --git a/com.unity.ml-agents/Runtime/EnvironmentParameters.cs b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs new file mode 100644 index 0000000000..55cac9176d --- /dev/null +++ b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Generic; +using MLAgents.SideChannels; + +namespace MLAgents +{ + /// + /// A container for the Environment Parameters that may be modified during training. + /// The keys for those parameters are defined in the trainer configurations and the + /// values are generated from the training process in features such as Curriculum Learning + /// and Environment Parameter Randomization. + /// + /// One current assumption for all the environment parameters is that they are of type float. + /// + public sealed class EnvironmentParameters + { + /// + /// The side channel that is used to receive the new parameter values. + /// + readonly EnvironmentParametersChannel m_Channel; + + /// + /// Constructor. + /// + internal EnvironmentParameters() + { + m_Channel = new EnvironmentParametersChannel(); + SideChannelsManager.RegisterSideChannel(m_Channel); + } + + /// + /// Returns the parameter value for the specified key. Returns the default value provided + /// if this parameter key does not have a value. Only returns a parameter value if it is + /// of type float. + /// + /// The parameter key + /// Default value for this parameter. + /// + public float GetWithDefault(string key, float defaultValue) + { + return m_Channel.GetWithDefault(key, defaultValue); + } + + /// + /// Registers a callback action for the provided parameter key. Will overwrite any + /// existing action for that parameter.
The callback will be called whenever the parameter + /// receives a value from the training process. + /// + /// The parameter key + /// The callback action + public void RegisterCallback(string key, Action action) + { + m_Channel.RegisterCallback(key, action); + } + + /// + /// Returns a list of all the parameter keys that have received values. + /// + /// List of parameter keys. + public IList Keys() + { + return m_Channel.ListParameters(); + } + + internal void Dispose() + { + SideChannelsManager.UnregisterSideChannel(m_Channel); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs.meta b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta similarity index 83% rename from com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs.meta rename to com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta index 3b50458ea1..9e7a85f810 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs.meta +++ b/com.unity.ml-agents/Runtime/EnvironmentParameters.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 2506dff31271f49298fbff21e13fa8b6 +guid: 90ce0b26bef35484890eac0633b85eed MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs index 140010867a..530656e711 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs +++ b/com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs @@ -3,11 +3,21 @@ namespace MLAgents.SideChannels { + /// /// Side channel that supports modifying attributes specific to the Unity Engine. 
/// - public class EngineConfigurationChannel : SideChannel + internal class EngineConfigurationChannel : SideChannel { + private enum ConfigurationType : int + { + ScreenResolution = 0, + QualityLevel = 1, + TimeScale = 2, + TargetFrameRate = 3, + CaptureFrameRate = 4 + } + const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7"; /// @@ -20,21 +30,39 @@ internal EngineConfigurationChannel() } /// - public override void OnMessageReceived(IncomingMessage msg) + protected override void OnMessageReceived(IncomingMessage msg) { - var width = msg.ReadInt32(); - var height = msg.ReadInt32(); - var qualityLevel = msg.ReadInt32(); - var timeScale = msg.ReadFloat32(); - var targetFrameRate = msg.ReadInt32(); - - timeScale = Mathf.Clamp(timeScale, 1, 100); - - Screen.SetResolution(width, height, false); - QualitySettings.SetQualityLevel(qualityLevel, true); - Time.timeScale = timeScale; - Time.captureFramerate = 60; - Application.targetFrameRate = targetFrameRate; + var messageType = (ConfigurationType)msg.ReadInt32(); + switch (messageType) + { + case ConfigurationType.ScreenResolution: + var width = msg.ReadInt32(); + var height = msg.ReadInt32(); + Screen.SetResolution(width, height, false); + break; + case ConfigurationType.QualityLevel: + var qualityLevel = msg.ReadInt32(); + QualitySettings.SetQualityLevel(qualityLevel, true); + break; + case ConfigurationType.TimeScale: + var timeScale = msg.ReadFloat32(); + timeScale = Mathf.Clamp(timeScale, 1, 100); + Time.timeScale = timeScale; + break; + case ConfigurationType.TargetFrameRate: + var targetFrameRate = msg.ReadInt32(); + Application.targetFrameRate = targetFrameRate; + break; + case ConfigurationType.CaptureFrameRate: + var captureFrameRate = msg.ReadInt32(); + Time.captureFramerate = captureFrameRate; + break; + default: + Debug.LogWarning( + "Unknown engine configuration received from Python. 
Make sure" + + " your Unity and Python versions are compatible."); + break; + } } } } diff --git a/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs new file mode 100644 index 0000000000..55242b329c --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs @@ -0,0 +1,91 @@ +using System.Collections.Generic; +using System; +using UnityEngine; + +namespace MLAgents.SideChannels +{ + /// + /// Lists the different data types supported. + /// + internal enum EnvironmentDataTypes + { + Float = 0 + } + + /// + /// A side channel that manages the environment parameter values from Python. Currently + /// limited to parameters of type float. + /// + internal class EnvironmentParametersChannel : SideChannel + { + Dictionary m_Parameters = new Dictionary(); + Dictionary> m_RegisteredActions = + new Dictionary>(); + + const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400"; + + /// + /// Initializes the side channel. The constructor is internal because only one instance is + /// supported at a time, and is created by the Academy. + /// + internal EnvironmentParametersChannel() + { + ChannelId = new Guid(k_EnvParamsId); + } + + /// + protected override void OnMessageReceived(IncomingMessage msg) + { + var key = msg.ReadString(); + var type = msg.ReadInt32(); + if ((int)EnvironmentDataTypes.Float == type) + { + var value = msg.ReadFloat32(); + + m_Parameters[key] = value; + + Action action; + m_RegisteredActions.TryGetValue(key, out action); + action?.Invoke(value); + } + else + { + Debug.LogWarning("EnvironmentParametersChannel received an unknown data type."); + } + } + + /// + /// Returns the parameter value associated with the provided key. Returns the default + /// value if one doesn't exist. + /// + /// Parameter key. + /// Default value to return. 
+ /// + public float GetWithDefault(string key, float defaultValue) + { + float valueOut; + bool hasKey = m_Parameters.TryGetValue(key, out valueOut); + return hasKey ? valueOut : defaultValue; + } + + /// + /// Registers a callback for the associated parameter key. Will overwrite any existing + /// actions for this parameter key. + /// + /// The parameter key. + /// The callback. + public void RegisterCallback(string key, Action action) + { + m_RegisteredActions[key] = action; + } + + /// + /// Returns all parameter keys that have a registered value. + /// + /// + public IList ListParameters() + { + return new List(m_Parameters.Keys); + } + } +} diff --git a/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta new file mode 100644 index 0000000000..f118b1f99f --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a849760d5bec946b884984e35c66fcfa +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs index a61db071e4..2bcaf3c919 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs +++ b/com.unity.ml-agents/Runtime/SideChannels/FloatPropertiesChannel.cs @@ -29,7 +29,7 @@ public class FloatPropertiesChannel : SideChannel } /// - public override void OnMessageReceived(IncomingMessage msg) + protected override void OnMessageReceived(IncomingMessage msg) { var key = msg.ReadString(); var value = msg.ReadFloat32(); @@ -41,8 +41,12 @@ public override void OnMessageReceived(IncomingMessage msg) action?.Invoke(value); } - /// - public void SetProperty(string key, float value) + /// + /// 
Sets one of the float properties of the environment. This data will be sent to Python. + /// + /// The string identifier of the property. + /// The float value of the property. + public void Set(string key, float value) { m_FloatProperties[key] = value; using (var msgOut = new OutgoingMessage()) @@ -57,19 +61,35 @@ public void SetProperty(string key, float value) action?.Invoke(value); } - public float GetPropertyWithDefault(string key, float defaultValue) + /// + /// Get an Environment property with a default value. If there is a value for this property, + /// it will be returned, otherwise, the default value will be returned. + /// + /// The string identifier of the property. + /// The default value of the property. + /// + public float GetWithDefault(string key, float defaultValue) { float valueOut; bool hasKey = m_FloatProperties.TryGetValue(key, out valueOut); return hasKey ? valueOut : defaultValue; } + /// + /// Registers an action to be performed every time the property is changed. + /// + /// The string identifier of the property. + /// The action that will be performed. Takes a float as input. public void RegisterCallback(string key, Action action) { m_RegisteredActions[key] = action; } - public IList ListProperties() + /// + /// Returns a list of all the string identifiers of the properties currently present.
+ /// + /// The list of string identifiers + public IList Keys() { return new List(m_FloatProperties.Keys); } diff --git a/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs index 298e4d31da..4720c4c7d6 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs +++ b/com.unity.ml-agents/Runtime/SideChannels/RawBytesChannel.cs @@ -22,7 +22,7 @@ public RawBytesChannel(Guid channelId) } /// - public override void OnMessageReceived(IncomingMessage msg) + protected override void OnMessageReceived(IncomingMessage msg) { m_MessagesReceived.Add(msg.GetRawBytes()); } diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs index 4811796abc..8edaeb4655 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannel.cs @@ -7,6 +7,13 @@ namespace MLAgents.SideChannels /// Side channels provide an alternative mechanism of sending/receiving data from Unity /// to Python that is outside of the traditional machine learning loop. ML-Agents provides /// some specific implementations of side channels, but users can create their own. + /// + /// To create your own, you'll need to create two, new mirrored classes, one in Unity (by + /// extending ) and another in Python by extending a Python class + /// also called SideChannel. Then, within your project, use + /// and + /// to register and unregister your + /// custom side channel. /// public abstract class SideChannel { @@ -25,12 +32,20 @@ public Guid ChannelId protected set; } + internal void ProcessMessage(byte[] msg) + { + using (var incomingMsg = new IncomingMessage(msg)) + { + OnMessageReceived(incomingMsg); + } + } + /// /// Is called by the communicator every time a message is received from Python by the SideChannel. /// Can be called multiple times per simulation step if multiple messages were sent. 
/// /// The incoming message. - public abstract void OnMessageReceived(IncomingMessage msg); + protected abstract void OnMessageReceived(IncomingMessage msg); /// /// Queues a message to be sent to Python during the next simulation step. diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs b/com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs similarity index 76% rename from com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs rename to com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs index 3a90c5d299..09edd63ec9 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/SideChannelUtils.cs +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs @@ -9,9 +9,8 @@ namespace MLAgents.SideChannels /// Collection of static utilities for managing the registering/unregistering of /// and the sending/receiving of messages for all the channels. /// - public static class SideChannelUtils + public static class SideChannelsManager { - private static Dictionary RegisteredChannels = new Dictionary(); private struct CachedSideChannelMessage @@ -20,34 +19,34 @@ private struct CachedSideChannelMessage public byte[] Message; } - private static Queue m_CachedMessages = new Queue(); + private static readonly Queue m_CachedMessages = + new Queue(); /// - /// Registers a side channel to the communicator. The side channel will exchange - /// messages with its Python equivalent. + /// Register a side channel to begin sending and receiving messages. This method is + /// available for environments that have custom side channels. All built-in side + /// channels within the ML-Agents Toolkit are managed internally and do not need to + /// be explicitly registered/unregistered. A side channel may only be registered once. /// - /// The side channel to be registered. + /// The side channel to register. 
public static void RegisterSideChannel(SideChannel sideChannel) { var channelId = sideChannel.ChannelId; if (RegisteredChannels.ContainsKey(channelId)) { - throw new UnityAgentsException(string.Format( - "A side channel with type index {0} is already registered. You cannot register multiple " + - "side channels of the same id.", channelId)); + throw new UnityAgentsException( + $"A side channel with id {channelId} is already registered. " + + "You cannot register multiple side channels of the same id."); } // Process any messages that we've already received for this channel ID. var numMessages = m_CachedMessages.Count; - for (int i = 0; i < numMessages; i++) + for (var i = 0; i < numMessages; i++) { var cachedMessage = m_CachedMessages.Dequeue(); if (channelId == cachedMessage.ChannelId) { - using (var incomingMsg = new IncomingMessage(cachedMessage.Message)) - { - sideChannel.OnMessageReceived(incomingMsg); - } + sideChannel.ProcessMessage(cachedMessage.Message); } else { @@ -58,9 +57,17 @@ public static void RegisterSideChannel(SideChannel sideChannel) } /// - /// Unregisters a side channel from the communicator. + /// Unregister a side channel to stop sending and receiving messages. This method is + /// available for environments that have custom side channels. All built-in side + /// channels within the ML-Agents Toolkit are managed internally and do not need to + /// be explicitly registered/unregistered. Unregistering a side channel that has already + /// been unregistered (or never registered in the first place) has no negative side effects. + /// Note that unregistering a side channel may not stop the Python side + /// from sending messages, but it does mean that sent messages will not result in a call + /// to <see cref="SideChannel.OnMessageReceived"/>. Furthermore, + /// those messages will not be buffered and will, in essence, be lost. /// - /// The side channel to be unregistered. + /// The side channel to unregister.
public static void UnregisterSideChannel(SideChannel sideChannel) { if (RegisteredChannels.ContainsKey(sideChannel.ChannelId)) @@ -72,7 +79,7 @@ public static void UnregisterSideChannel(SideChannel sideChannel) /// /// Unregisters all the side channels from the communicator. /// - public static void UnregisterAllSideChannels() + internal static void UnregisterAllSideChannels() { RegisteredChannels = new Dictionary(); } @@ -83,7 +90,7 @@ public static void UnregisterAllSideChannels() /// /// /// - public static T GetSideChannel() where T: SideChannel + internal static T GetSideChannel() where T: SideChannel { foreach (var sc in RegisteredChannels.Values) { @@ -95,26 +102,6 @@ public static T GetSideChannel() where T: SideChannel return null; } - /// - /// Returns all SideChannels of Type T that are registered. Use if possible, - /// as that does not make any memory allocations. - /// - /// - /// - public static List GetSideChannels() where T: SideChannel - { - var output = new List(); - - foreach (var sc in RegisteredChannels.Values) - { - if (sc.GetType() == typeof(T)) - { - output.Add((T) sc); - } - } - return output; - } - /// /// Grabs the messages that the registered side channels will send to Python at the current step /// into a singe byte array. 
@@ -174,10 +161,7 @@ internal static void ProcessSideChannelData(Dictionary sideCh var cachedMessage = m_CachedMessages.Dequeue(); if (sideChannels.ContainsKey(cachedMessage.ChannelId)) { - using (var incomingMsg = new IncomingMessage(cachedMessage.Message)) - { - sideChannels[cachedMessage.ChannelId].OnMessageReceived(incomingMsg); - } + sideChannels[cachedMessage.ChannelId].ProcessMessage(cachedMessage.Message); } else { @@ -214,10 +198,7 @@ internal static void ProcessSideChannelData(Dictionary sideCh } if (sideChannels.ContainsKey(channelId)) { - using (var incomingMsg = new IncomingMessage(message)) - { - sideChannels[channelId].OnMessageReceived(incomingMsg); - } + sideChannels[channelId].ProcessMessage(message); } else { @@ -233,6 +214,5 @@ internal static void ProcessSideChannelData(Dictionary sideCh } } } - } } diff --git a/com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs.meta new file mode 100644 index 0000000000..251cc14632 --- /dev/null +++ b/com.unity.ml-agents/Runtime/SideChannels/SideChannelsManager.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ccc0d134445f947349c68a6d07e3cdc2 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs index 52a46499e4..5ab776d2bb 100644 --- a/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs +++ b/com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs @@ -2,38 +2,15 @@ namespace MLAgents.SideChannels { /// - /// Determines the behavior of how multiple stats within the same summary period are combined. + /// A Side Channel for sending data. 
/// - public enum StatAggregationMethod - { - /// - /// Values within the summary period are averaged before reporting. - /// Note that values from the same C# environment in the same step may replace each other. - /// - Average = 0, - - /// - /// Only the most recent value is reported. - /// To avoid conflicts between multiple environments, the ML Agents environment will only - /// keep stats from worker index 0. - /// - MostRecent = 1 - } - - /// - /// Add stats (key-value pairs) for reporting. The ML Agents environment will send these to a StatsReporter - /// instance, which means the values will appear in the Tensorboard summary, as well as trainer gauges. - /// Note that stats are only written every summary_frequency steps; See - /// for options on how multiple values are handled. - /// - public class StatsSideChannel : SideChannel + internal class StatsSideChannel : SideChannel { const string k_StatsSideChannelDefaultId = "a1d8f7b7-cec8-50f9-b78b-d3e165a78520"; /// - /// Initializes the side channel with the provided channel ID. - /// The constructor is internal because only one instance is - /// supported at a time, and is created by the Academy. + /// Initializes the side channel. The constructor is internal because only one instance is + /// supported at a time. /// internal StatsSideChannel() { @@ -41,18 +18,12 @@ internal StatsSideChannel() } /// - /// Add a stat value for reporting. This will appear in the Tensorboard summary and trainer gauges. - /// You can nest stats in Tensorboard with "/". - /// Note that stats are only written to Tensorboard each summary_frequency steps; if a stat is - /// received multiple times, only the most recent version is used. - /// To avoid conflicts between multiple environments, only stats from worker index 0 are used. + /// Add a stat value for reporting. /// /// The stat name. - /// The stat value. You can nest stats in Tensorboard by using "/". + /// The stat value. /// How multiple values should be treated. 
- public void AddStat( - string key, float value, StatAggregationMethod aggregationMethod = StatAggregationMethod.Average - ) + public void AddStat(string key, float value, StatAggregationMethod aggregationMethod) { using (var msg = new OutgoingMessage()) { @@ -64,7 +35,7 @@ public void AddStat( } /// - public override void OnMessageReceived(IncomingMessage msg) + protected override void OnMessageReceived(IncomingMessage msg) { throw new UnityAgentsException("StatsSideChannel should never receive messages."); } diff --git a/com.unity.ml-agents/Runtime/StatsRecorder.cs b/com.unity.ml-agents/Runtime/StatsRecorder.cs new file mode 100644 index 0000000000..a4c6af352e --- /dev/null +++ b/com.unity.ml-agents/Runtime/StatsRecorder.cs @@ -0,0 +1,71 @@ +using MLAgents.SideChannels; + +namespace MLAgents +{ + /// + /// Determines the behavior of how multiple stats within the same summary period are combined. + /// + public enum StatAggregationMethod + { + /// + /// Values within the summary period are averaged before reporting. + /// Note that values from the same C# environment in the same step may replace each other. + /// + Average = 0, + + /// + /// Only the most recent value is reported. + /// To avoid conflicts when training with multiple concurrent environments, only + /// stats from worker index 0 will be tracked. + /// + MostRecent = 1 + } + + /// + /// Add stats (key-value pairs) for reporting. These values will be sent to a StatsReporter + /// instance, which means the values will appear in the TensorBoard summary, as well as trainer + /// gauges. You can nest stats in TensorBoard by adding "/" in the name (e.g. "Agent/Health" + /// and "Agent/Wallet"). Note that stats are only written to TensorBoard each summary_frequency + /// steps (a trainer configuration). If a stat is received multiple times, within that period + /// then the values will be aggregated using the <see cref="StatAggregationMethod"/> provided.
+ /// + public sealed class StatsRecorder + { + /// + /// The side channel that is used to receive the new parameter values. + /// + readonly StatsSideChannel m_Channel; + + /// + /// Constructor. + /// + internal StatsRecorder() + { + m_Channel = new StatsSideChannel(); + SideChannelsManager.RegisterSideChannel(m_Channel); + } + + /// + /// Add a stat value for reporting. + /// + /// The stat name. + /// + /// The stat value. You can nest stats in TensorBoard by using "/". + /// + /// + /// How multiple values sent in the same summary window should be treated. + /// + public void Add( + string key, + float value, + StatAggregationMethod aggregationMethod = StatAggregationMethod.Average) + { + m_Channel.AddStat(key, value, aggregationMethod); + } + + internal void Dispose() + { + SideChannelsManager.UnregisterSideChannel(m_Channel); + } + } +} diff --git a/com.unity.ml-agents/Runtime/StatsRecorder.cs.meta b/com.unity.ml-agents/Runtime/StatsRecorder.cs.meta new file mode 100644 index 0000000000..bfc4addbb1 --- /dev/null +++ b/com.unity.ml-agents/Runtime/StatsRecorder.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d9add8900e8a746e6a4cb410cb27d664 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index fdf67b4f0d..6a0f8cd87c 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -210,7 +210,9 @@ public void TestAcademy() Assert.AreEqual(0, aca.EpisodeCount); Assert.AreEqual(0, aca.StepCount); Assert.AreEqual(0, aca.TotalStepCount); - Assert.AreNotEqual(null, SideChannelUtils.GetSideChannel()); + Assert.AreNotEqual(null, SideChannelsManager.GetSideChannel()); + Assert.AreNotEqual(null, SideChannelsManager.GetSideChannel()); + 
Assert.AreNotEqual(null, SideChannelsManager.GetSideChannel()); // Check that Dispose is idempotent aca.Dispose(); @@ -221,14 +223,20 @@ public void TestAcademy() [Test] public void TestAcademyDispose() { - var floatProperties1 = SideChannelUtils.GetSideChannel(); + var envParams1 = SideChannelsManager.GetSideChannel(); + var engineParams1 = SideChannelsManager.GetSideChannel(); + var statsParams1 = SideChannelsManager.GetSideChannel(); Academy.Instance.Dispose(); Academy.Instance.LazyInitialize(); - var floatProperties2 = SideChannelUtils.GetSideChannel(); + var envParams2 = SideChannelsManager.GetSideChannel(); + var engineParams2 = SideChannelsManager.GetSideChannel(); + var statsParams2 = SideChannelsManager.GetSideChannel(); Academy.Instance.Dispose(); - Assert.AreNotEqual(floatProperties1, floatProperties2); + Assert.AreNotEqual(envParams1, envParams2); + Assert.AreNotEqual(engineParams1, engineParams2); + Assert.AreNotEqual(statsParams1, statsParams2); } [Test] diff --git a/com.unity.ml-agents/Tests/Editor/SideChannelTests.cs b/com.unity.ml-agents/Tests/Editor/SideChannelTests.cs index 3a236bfcfe..0cfdd0794f 100644 --- a/com.unity.ml-agents/Tests/Editor/SideChannelTests.cs +++ b/com.unity.ml-agents/Tests/Editor/SideChannelTests.cs @@ -18,7 +18,7 @@ public TestSideChannel() ChannelId = new Guid("6afa2c06-4f82-11ea-b238-784f4387d1f7"); } - public override void OnMessageReceived(IncomingMessage msg) + protected override void OnMessageReceived(IncomingMessage msg) { messagesReceived.Add(msg.ReadInt32()); } @@ -45,8 +45,8 @@ public void TestIntegerSideChannel() intSender.SendInt(5); intSender.SendInt(6); - byte[] fakeData = SideChannelUtils.GetSideChannelMessage(dictSender); - SideChannelUtils.ProcessSideChannelData(dictReceiver, fakeData); + byte[] fakeData = SideChannelsManager.GetSideChannelMessage(dictSender); + SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData); Assert.AreEqual(intReceiver.messagesReceived[0], 4); 
Assert.AreEqual(intReceiver.messagesReceived[1], 5); @@ -67,8 +67,8 @@ public void TestRawBytesSideChannel() strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1)); strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2)); - byte[] fakeData = SideChannelUtils.GetSideChannelMessage(dictSender); - SideChannelUtils.ProcessSideChannelData(dictReceiver, fakeData); + byte[] fakeData = SideChannelsManager.GetSideChannelMessage(dictSender); + SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData); var messages = strReceiver.GetAndClearReceivedMessages(); @@ -90,31 +90,31 @@ public void TestFloatPropertiesSideChannel() var dictSender = new Dictionary { { propB.ChannelId, propB } }; propA.RegisterCallback(k1, f => { wasCalled++; }); - var tmp = propB.GetPropertyWithDefault(k2, 3.0f); + var tmp = propB.GetWithDefault(k2, 3.0f); Assert.AreEqual(tmp, 3.0f); - propB.SetProperty(k2, 1.0f); - tmp = propB.GetPropertyWithDefault(k2, 3.0f); + propB.Set(k2, 1.0f); + tmp = propB.GetWithDefault(k2, 3.0f); Assert.AreEqual(tmp, 1.0f); - byte[] fakeData = SideChannelUtils.GetSideChannelMessage(dictSender); - SideChannelUtils.ProcessSideChannelData(dictReceiver, fakeData); + byte[] fakeData = SideChannelsManager.GetSideChannelMessage(dictSender); + SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData); - tmp = propA.GetPropertyWithDefault(k2, 3.0f); + tmp = propA.GetWithDefault(k2, 3.0f); Assert.AreEqual(tmp, 1.0f); Assert.AreEqual(wasCalled, 0); - propB.SetProperty(k1, 1.0f); + propB.Set(k1, 1.0f); Assert.AreEqual(wasCalled, 0); - fakeData = SideChannelUtils.GetSideChannelMessage(dictSender); - SideChannelUtils.ProcessSideChannelData(dictReceiver, fakeData); + fakeData = SideChannelsManager.GetSideChannelMessage(dictSender); + SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData); Assert.AreEqual(wasCalled, 1); - var keysA = propA.ListProperties(); + var keysA = propA.Keys(); Assert.AreEqual(2, keysA.Count); Assert.IsTrue(keysA.Contains(k1)); 
Assert.IsTrue(keysA.Contains(k2)); - var keysB = propA.ListProperties(); + var keysB = propA.Keys(); Assert.AreEqual(2, keysB.Count); Assert.IsTrue(keysB.Contains(k1)); Assert.IsTrue(keysB.Contains(k2)); diff --git a/docs/Custom-SideChannels.md b/docs/Custom-SideChannels.md index e0a6568f2c..b93e4fe30e 100644 --- a/docs/Custom-SideChannels.md +++ b/docs/Custom-SideChannels.md @@ -3,7 +3,7 @@ You can create your own side channel in C# and Python and use it to communicate custom data structures between the two. This can be useful for situations in which the data to be sent is too complex or structured for the built-in -`FloatPropertiesChannel`, or is not related to any specific agent, and therefore +`EnvironmentParameters`, or is not related to any specific agent, and therefore inappropriate as an agent observation. ## Overview @@ -24,7 +24,7 @@ To send data from C# to Python, create an `OutgoingMessage` instance, add data t `base.QueueMessageToSend(msg)` method inside the side channel, and call the `OutgoingMessage.Dispose()` method. -To register a side channel on the Unity side, call `SideChannelUtils.RegisterSideChannel` with the side channel +To register a side channel on the Unity side, call `SideChannelsManager.RegisterSideChannel` with the side channel as only argument.
### Python side @@ -122,8 +122,8 @@ public class RegisterStringLogSideChannel : MonoBehaviour // When a Debug.Log message is created, we send it to the stringChannel Application.logMessageReceived += stringChannel.SendDebugStatementToPython; - // The channel must be registered with the SideChannelUtils class - SideChannelUtils.RegisterSideChannel(stringChannel); + // The channel must be registered with the SideChannelsManager class + SideChannelsManager.RegisterSideChannel(stringChannel); } public void OnDestroy() @@ -131,7 +131,7 @@ public class RegisterStringLogSideChannel : MonoBehaviour // De-register the Debug.Log callback Application.logMessageReceived -= stringChannel.SendDebugStatementToPython; if (Academy.IsInitialized){ - SideChannelUtils.UnregisterSideChannel(stringChannel); + SideChannelsManager.UnregisterSideChannel(stringChannel); } } diff --git a/docs/Migrating.md b/docs/Migrating.md index 3d9508ed77..94633e00df 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -33,6 +33,19 @@ double-check that the versions are in the same. The versions can be found in - The signature of `Agent.Heuristic()` was changed to take a `float[]` as a parameter, instead of returning the array. This was done to prevent a common source of error where users would return arrays of the wrong size. +- The SideChannel API has changed (#3833, #3660) : + - Introduced the `SideChannelsManager` to register, unregister and access side + channels. + - `EnvironmentParameters` replaces the default `FloatProperties`. + You can access the `EnvironmentParameters` with + `Academy.Instance.EnvironmentParameters` on C# and create an + `EnvironmentParametersChannel` on Python + - `SideChannel.OnMessageReceived` is now a protected method (was public) + - SideChannel IncomingMessages methods now take an optional default argument, + which is used when trying to read more data than the message contains.
+ - Added a feature to allow sending stats from C# environments to TensorBoard + (and other python StatsWriters). To do this from your code, use + `Academy.Instance.StatsRecorder.Add(key, value)`(#3660) - `num_updates` and `train_interval` for SAC have been replaced with `steps_per_update`. - The `UnityEnv` class from the `gym-unity` package was renamed `UnityToGymWrapper` and no longer creates the `UnityEnvironment`. Instead, @@ -51,18 +64,14 @@ double-check that the versions are in the same. The versions can be found in - To force-overwrite files from a pre-existing run, add the `--force` command-line flag. - The Jupyter notebooks have been removed from the repository. -- `Academy.FloatProperties` was removed. -- `Academy.RegisterSideChannel` and `Academy.UnregisterSideChannel` were - removed. -- Replace `Academy.FloatProperties` with - `SideChannelUtils.GetSideChannel()`. -- Replace `Academy.RegisterSideChannel` with - `SideChannelUtils.RegisterSideChannel()`. -- Replace `Academy.UnregisterSideChannel` with - `SideChannelUtils.UnregisterSideChannel`. - If your Agent class overrides `Heuristic()`, change the signature to `public override void Heuristic(float[] actionsOut)` and assign values to `actionsOut` instead of returning an array. +- If you used `SideChannels` you must: + - Replace `Academy.FloatProperties` with `Academy.Instance.EnvironmentParameters`. + - `Academy.RegisterSideChannel` and `Academy.UnregisterSideChannel` were + removed. Use `SideChannelsManager.RegisterSideChannel` and + `SideChannelsManager.UnregisterSideChannel` instead. - Set `steps_per_update` to be around equal to the number of agents in your environment, times `num_updates` and divided by `train_interval`. - Replace `UnityEnv` with `UnityToGymWrapper` in your code.
The constructor diff --git a/docs/Python-API.md b/docs/Python-API.md index 751d298ac0..8a32ccae7b 100644 --- a/docs/Python-API.md +++ b/docs/Python-API.md @@ -273,11 +273,12 @@ The `EngineConfiguration` side channel allows you to modify the time-scale, reso `EngineConfigurationChannel` has two methods : * `set_configuration_parameters` which takes the following arguments: - * `width`: Defines the width of the display. Default 80. - * `height`: Defines the height of the display. Default 80. - * `quality_level`: Defines the quality level of the simulation. Default 1. - * `time_scale`: Defines the multiplier for the deltatime in the simulation. If set to a higher value, time will pass faster in the simulation but the physics may perform unpredictably. Default 20. - * `target_frame_rate`: Instructs simulation to try to render at a specified frame rate. Default -1. + * `width`: Defines the width of the display. (Must be set alongside height) + * `height`: Defines the height of the display. (Must be set alongside width) + * `quality_level`: Defines the quality level of the simulation. + * `time_scale`: Defines the multiplier for the deltatime in the simulation. If set to a higher value, time will pass faster in the simulation but the physics may perform unpredictably. + * `target_frame_rate`: Instructs simulation to try to render at a specified frame rate. + * `capture_frame_rate` Instructs the simulation to consider time between updates to always be constant, regardless of the actual frame rate. * `set_configuration` with argument config which is an `EngineConfig` NamedTuple object. @@ -297,41 +298,34 @@ i = env.reset() ... ``` -#### FloatPropertiesChannel -The `FloatPropertiesChannel` will allow you to get and set pre-defined numerical values in the environment. This can be useful for adjusting environment-specific settings, or for reading non-agent related information from the environment. 
You can call `get_property` and `set_property` on the side channel to read and write properties. +#### EnvironmentParameters +The `EnvironmentParameters` will allow you to get and set pre-defined numerical values in the environment. This can be useful for adjusting environment-specific settings, or for reading non-agent related information from the environment. You can call `set_float_parameter` on the side channel to set a parameter from Python. -`FloatPropertiesChannel` has three methods: +`EnvironmentParametersChannel` has one method: - * `set_property` Sets a property in the Unity Environment. + * `set_float_parameter` Sets a float parameter in the Unity Environment. * key: The string identifier of the property. * value: The float value of the property. - * `get_property` Gets a property in the Unity Environment. If the property was not found, will return None. - * key: The string identifier of the property. - - * `list_properties` Returns a list of all the string identifiers of the properties - ```python from mlagents_envs.environment import UnityEnvironment -from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel +from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel -channel = FloatPropertiesChannel() +channel = EnvironmentParametersChannel() env = UnityEnvironment(side_channels=[channel]) -channel.set_property("parameter_1", 2.0) +channel.set_float_parameter("parameter_1", 2.0) i = env.reset() - -readout_value = channel.get_property("parameter_2") ...
``` Once a property has been modified in Python, you can access it in C# after the next call to `step` as follows: ```csharp -var sharedProperties = SideChannelUtils.GetSideChannel(); -float property1 = sharedProperties.GetPropertyWithDefault("parameter_1", 0.0f); +var envParameters = Academy.Instance.EnvironmentParameters; +float property1 = envParameters.GetWithDefault("parameter_1", 0.0f); ``` #### Custom side channels diff --git a/docs/Training-Curriculum-Learning.md b/docs/Training-Curriculum-Learning.md index 38885287ea..cb0203c8d2 100644 --- a/docs/Training-Curriculum-Learning.md +++ b/docs/Training-Curriculum-Learning.md @@ -40,8 +40,8 @@ the same environment. In order to define the curricula, the first step is to decide which parameters of the environment will vary. In the case of the Wall Jump environment, -the height of the wall is what varies. We define this as a `Shared Float Property` -that can be accessed in `SideChannelUtils.GetSideChannel()`, and by doing +the height of the wall is what varies. We define this as an `Environment Parameter` +that can be accessed in `Academy.Instance.EnvironmentParameters`, and by doing so it becomes adjustable via the Python API. Rather than adjusting it by hand, we will create a YAML file which describes the structure of the curricula.
Within it, we can specify which diff --git a/docs/Using-Tensorboard.md b/docs/Using-Tensorboard.md index dd424c20a9..9243969429 100644 --- a/docs/Using-Tensorboard.md +++ b/docs/Using-Tensorboard.md @@ -95,6 +95,6 @@ To get custom metrics from a C# environment into Tensorboard, you can use the StatsSideChannel: ```csharp -var statsSideChannel = SideChannelUtils.GetSideChannel(); -statsSideChannel.AddStat("MyMetric", 1.0); +var statsRecorder = Academy.Instance.StatsRecorder; +statsRecorder.Add("MyMetric", 1.0); ``` diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 178ebe1eb6..b0f05aa113 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -58,7 +58,7 @@ class UnityEnvironment(BaseEnv): # Currently we require strict equality between the communication protocol # on each side, although we may allow some flexibility in the future. # This should be incremented whenever a change is made to the communication protocol. - API_VERSION = "0.16.0" + API_VERSION = "0.17.0" # Default port that the editor listens on. If an environment executable # isn't specified, this port will be used.
diff --git a/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py b/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py index a8eef624c2..c9315433f9 100644 --- a/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py +++ b/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py @@ -1,7 +1,11 @@ from mlagents_envs.side_channel import SideChannel, OutgoingMessage, IncomingMessage -from mlagents_envs.exception import UnityCommunicationException +from mlagents_envs.exception import ( + UnityCommunicationException, + UnitySideChannelException, +) import uuid -from typing import NamedTuple +from typing import NamedTuple, Optional +from enum import IntEnum class EngineConfig(NamedTuple): @@ -10,10 +14,11 @@ class EngineConfig(NamedTuple): quality_level: int time_scale: float target_frame_rate: int + capture_frame_rate: int @staticmethod def default_config(): - return EngineConfig(80, 80, 1, 20.0, -1) + return EngineConfig(80, 80, 1, 20.0, -1, 60) class EngineConfigurationChannel(SideChannel): @@ -25,8 +30,16 @@ class EngineConfigurationChannel(SideChannel): - int qualityLevel; - float timeScale; - int targetFrameRate; + - int captureFrameRate; """ + class ConfigurationType(IntEnum): + SCREEN_RESOLUTION = 0 + QUALITY_LEVEL = 1 + TIME_SCALE = 2 + TARGET_FRAME_RATE = 3 + CAPTURE_FRAME_RATE = 4 + def __init__(self) -> None: super().__init__(uuid.UUID("e951342c-4f7e-11ea-b238-784f4387d1f7")) @@ -45,32 +58,67 @@ def on_message_received(self, msg: IncomingMessage) -> None: def set_configuration_parameters( self, - width: int = 80, - height: int = 80, - quality_level: int = 1, - time_scale: float = 20.0, - target_frame_rate: int = -1, + width: Optional[int] = None, + height: Optional[int] = None, + quality_level: Optional[int] = None, + time_scale: Optional[float] = None, + target_frame_rate: Optional[int] = None, + capture_frame_rate: Optional[int] = None, ) -> None: """ Sets the engine 
configuration. Takes as input the configurations of the engine. - :param width: Defines the width of the display. Default 80. - :param height: Defines the height of the display. Default 80. + :param width: Defines the width of the display. (Must be set alongside height) + :param height: Defines the height of the display. (Must be set alongside width) :param quality_level: Defines the quality level of the simulation. - Default 1. :param time_scale: Defines the multiplier for the deltatime in the simulation. If set to a higher value, time will pass faster in the - simulation but the physics might break. Default 20. + simulation but the physics might break. :param target_frame_rate: Instructs simulation to try to render at a - specified frame rate. Default -1. + specified frame rate. + :param capture_frame_rate: Instructs the simulation to consider time between + updates to always be constant, regardless of the actual frame rate. """ - msg = OutgoingMessage() - msg.write_int32(width) - msg.write_int32(height) - msg.write_int32(quality_level) - msg.write_float32(time_scale) - msg.write_int32(target_frame_rate) - super().queue_message_to_send(msg) + + if (width is None and height is not None) or ( + width is not None and height is None + ): + raise UnitySideChannelException( + "You cannot set the width/height of the screen resolution without also setting the height/width" + ) + + if width is not None and height is not None: + screen_msg = OutgoingMessage() + screen_msg.write_int32(self.ConfigurationType.SCREEN_RESOLUTION) + screen_msg.write_int32(width) + screen_msg.write_int32(height) + super().queue_message_to_send(screen_msg) + + if quality_level is not None: + quality_level_msg = OutgoingMessage() + quality_level_msg.write_int32(self.ConfigurationType.QUALITY_LEVEL) + quality_level_msg.write_int32(quality_level) + super().queue_message_to_send(quality_level_msg) + + if time_scale is not None: + time_scale_msg = OutgoingMessage() + 
time_scale_msg.write_int32(self.ConfigurationType.TIME_SCALE) + time_scale_msg.write_float32(time_scale) + super().queue_message_to_send(time_scale_msg) + + if target_frame_rate is not None: + target_frame_rate_msg = OutgoingMessage() + target_frame_rate_msg.write_int32(self.ConfigurationType.TARGET_FRAME_RATE) + target_frame_rate_msg.write_int32(target_frame_rate) + super().queue_message_to_send(target_frame_rate_msg) + + if capture_frame_rate is not None: + capture_frame_rate_msg = OutgoingMessage() + capture_frame_rate_msg.write_int32( + self.ConfigurationType.CAPTURE_FRAME_RATE + ) + capture_frame_rate_msg.write_int32(capture_frame_rate) + super().queue_message_to_send(capture_frame_rate_msg) def set_configuration(self, config: EngineConfig) -> None: """ diff --git a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py new file mode 100644 index 0000000000..958364b675 --- /dev/null +++ b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py @@ -0,0 +1,37 @@ +from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage +from mlagents_envs.exception import UnityCommunicationException +import uuid +from enum import IntEnum + + +class EnvironmentParametersChannel(SideChannel): + """ + This is the SideChannel for sending environment parameters to Unity. + You can send parameters to an environment with the command + set_float_parameter. + """ + + class EnvironmentDataTypes(IntEnum): + FLOAT = 0 + + def __init__(self) -> None: + channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400")) + super().__init__(channel_id) + + def on_message_received(self, msg: IncomingMessage) -> None: + raise UnityCommunicationException( + "The EnvironmentParametersChannel received a message from Unity, " + + "this should not have happend." 
+ ) + + def set_float_parameter(self, key: str, value: float) -> None: + """ + Sets a float environment parameter in the Unity Environment. + :param key: The string identifier of the parameter. + :param value: The float value of the parameter. + """ + msg = OutgoingMessage() + msg.write_string(key) + msg.write_int32(self.EnvironmentDataTypes.FLOAT) + msg.write_float32(value) + super().queue_message_to_send(msg) diff --git a/ml-agents/mlagents/trainers/env_manager.py b/ml-agents/mlagents/trainers/env_manager.py index bd0958279c..ded18ba3d6 100644 --- a/ml-agents/mlagents/trainers/env_manager.py +++ b/ml-agents/mlagents/trainers/env_manager.py @@ -72,11 +72,6 @@ def reset(self, config: Dict = None) -> int: def external_brains(self) -> Dict[BehaviorName, BrainParameters]: pass - @property - @abstractmethod - def get_properties(self) -> Dict[BehaviorName, float]: - pass - @abstractmethod def close(self): pass diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index a5cb4677af..93c90490f1 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -234,6 +234,13 @@ def _create_parser(): help="The target frame rate of the Unity environment(s). Equivalent to setting " "Application.targetFrameRate in Unity.", ) + eng_conf.add_argument( + "--capture-frame-rate", + default=60, + type=int, + help="The capture frame rate of the Unity environment(s). 
Equivalent to setting " + "Time.captureFramerate in Unity.", + ) return argparser @@ -268,6 +275,7 @@ class RunOptions(NamedTuple): quality_level: int = parser.get_default("quality_level") time_scale: float = parser.get_default("time_scale") target_frame_rate: int = parser.get_default("target_frame_rate") + capture_frame_rate: int = parser.get_default("capture_frame_rate") @staticmethod def from_argparse(args: argparse.Namespace) -> "RunOptions": @@ -353,11 +361,12 @@ def run_training(run_seed: int, options: RunOptions) -> None: options.env_path, options.no_graphics, run_seed, port, options.env_args ) engine_config = EngineConfig( - options.width, - options.height, - options.quality_level, - options.time_scale, - options.target_frame_rate, + width=options.width, + height=options.height, + quality_level=options.quality_level, + time_scale=options.time_scale, + target_frame_rate=options.target_frame_rate, + capture_frame_rate=options.capture_frame_rate, ) env_manager = SubprocessEnvManager(env_factory, engine_config, options.num_envs) maybe_meta_curriculum = try_create_meta_curriculum( diff --git a/ml-agents/mlagents/trainers/simple_env_manager.py b/ml-agents/mlagents/trainers/simple_env_manager.py index e3ea815553..eca037783e 100644 --- a/ml-agents/mlagents/trainers/simple_env_manager.py +++ b/ml-agents/mlagents/trainers/simple_env_manager.py @@ -5,7 +5,9 @@ from mlagents_envs.timers import timed from mlagents.trainers.action_info import ActionInfo from mlagents.trainers.brain import BrainParameters -from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel +from mlagents_envs.side_channel.environment_parameters_channel import ( + EnvironmentParametersChannel, +) from mlagents.trainers.brain_conversion_utils import behavior_spec_to_brain_parameters @@ -15,9 +17,9 @@ class SimpleEnvManager(EnvManager): This is generally only useful for testing; see SubprocessEnvManager for a production-quality implementation. 
""" - def __init__(self, env: BaseEnv, float_prop_channel: FloatPropertiesChannel): + def __init__(self, env: BaseEnv, env_params: EnvironmentParametersChannel): super().__init__() - self.shared_float_properties = float_prop_channel + self.env_params = env_params self.env = env self.previous_step: EnvironmentStep = EnvironmentStep.empty(0) self.previous_all_action_info: Dict[str, ActionInfo] = {} @@ -42,7 +44,7 @@ def _reset_env( ) -> List[EnvironmentStep]: # type: ignore if config is not None: for k, v in config.items(): - self.shared_float_properties.set_property(k, v) + self.env_params.set_float_parameter(k, v) self.env.reset() all_step_result = self._generate_all_results() self.previous_step = EnvironmentStep(all_step_result, 0, {}, {}) @@ -57,10 +59,6 @@ def external_brains(self) -> Dict[BehaviorName, BrainParameters]: ) return result - @property - def get_properties(self) -> Dict[BehaviorName, float]: - return self.shared_float_properties.get_property_dict_copy() - def close(self): self.env.close() diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py index dbb042ae36..35817b150a 100644 --- a/ml-agents/mlagents/trainers/subprocess_env_manager.py +++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py @@ -23,7 +23,9 @@ ) from mlagents.trainers.brain import BrainParameters from mlagents.trainers.action_info import ActionInfo -from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel +from mlagents_envs.side_channel.environment_parameters_channel import ( + EnvironmentParametersChannel, +) from mlagents_envs.side_channel.engine_configuration_channel import ( EngineConfigurationChannel, EngineConfig, @@ -113,7 +115,7 @@ def worker( env_factory: Callable[ [int, List[SideChannel]], UnityEnvironment ] = cloudpickle.loads(pickled_env_factory) - shared_float_properties = FloatPropertiesChannel() + env_parameters = EnvironmentParametersChannel() 
engine_configuration_channel = EngineConfigurationChannel() engine_configuration_channel.set_configuration(engine_configuration) stats_channel = StatsSideChannel() @@ -138,8 +140,7 @@ def external_brains(): try: env = env_factory( - worker_id, - [shared_float_properties, engine_configuration_channel, stats_channel], + worker_id, [env_parameters, engine_configuration_channel, stats_channel] ) while True: req: EnvironmentRequest = parent_conn.recv() @@ -167,12 +168,9 @@ def external_brains(): reset_timers() elif req.cmd == EnvironmentCommand.EXTERNAL_BRAINS: _send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains()) - elif req.cmd == EnvironmentCommand.GET_PROPERTIES: - reset_params = shared_float_properties.get_property_dict_copy() - _send_response(EnvironmentCommand.GET_PROPERTIES, reset_params) elif req.cmd == EnvironmentCommand.RESET: for k, v in req.payload.items(): - shared_float_properties.set_property(k, v) + env_parameters.set_float_parameter(k, v) env.reset() all_step_result = _generate_all_results() _send_response(EnvironmentCommand.RESET, all_step_result) @@ -295,11 +293,6 @@ def external_brains(self) -> Dict[BehaviorName, BrainParameters]: self.env_workers[0].send(EnvironmentCommand.EXTERNAL_BRAINS) return self.env_workers[0].recv().payload - @property - def get_properties(self) -> Dict[BehaviorName, float]: - self.env_workers[0].send(EnvironmentCommand.GET_PROPERTIES) - return self.env_workers[0].recv().payload - def close(self) -> None: logger.debug(f"SubprocessEnvManager closing.") self.step_queue.close() diff --git a/ml-agents/mlagents/trainers/tests/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/test_simple_rl.py index 61d658640d..3754dd5ca6 100644 --- a/ml-agents/mlagents/trainers/tests/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/test_simple_rl.py @@ -16,7 +16,9 @@ from mlagents.trainers.sampler_class import SamplerManager from mlagents.trainers.demo_loader import write_demo from mlagents.trainers.stats import 
StatsReporter, StatsWriter, StatsSummary -from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel +from mlagents_envs.side_channel.environment_parameters_channel import ( + EnvironmentParametersChannel, +) from mlagents_envs.communicator_objects.demonstration_meta_pb2 import ( DemonstrationMetaProto, ) @@ -142,7 +144,7 @@ def _check_environment_trains( # Make sure threading is turned off for determinism trainer_config["threading"] = False if env_manager is None: - env_manager = SimpleEnvManager(env, FloatPropertiesChannel()) + env_manager = SimpleEnvManager(env, EnvironmentParametersChannel()) trainer_factory = TrainerFactory( trainer_config=trainer_config, summaries_dir=dir,