Skip to content
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
42fbf48
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
d619d0b
Fixed style
sankalp04 Jun 9, 2020
de9f908
Fixed more style
sankalp04 Jun 9, 2020
8075648
Nit changes
sankalp04 Jun 9, 2020
6957381
Fixed signature
sankalp04 Jun 9, 2020
fdecd97
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
91865e6
Fixed style
sankalp04 Jun 9, 2020
195267f
Fixed more style
sankalp04 Jun 9, 2020
c5bda4b
Nit changes
sankalp04 Jun 9, 2020
63db228
Fixed signature
sankalp04 Jun 9, 2020
3445556
Merge branch 'develop-checkpoint-conversion' of https://github.com/Un…
sankalp04 Jun 9, 2020
66d4916
Fixed tests, checkpoint management and style
sankalp04 Jun 15, 2020
e6260b4
Check checkpoint management
sankalp04 Jun 15, 2020
59b5d1d
Modify statement on artifacts
sankalp04 Jun 15, 2020
543eb20
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
33f21ac
Fixed style
sankalp04 Jun 9, 2020
be07986
Fixed more style
sankalp04 Jun 9, 2020
a773c32
Nit changes
sankalp04 Jun 9, 2020
7f52e2d
Fixed signature
sankalp04 Jun 9, 2020
9cdeba9
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
31ecce2
Fixed style
sankalp04 Jun 9, 2020
12fa5a6
Nit changes
sankalp04 Jun 9, 2020
f3e4578
Fixed tests, checkpoint management and style
sankalp04 Jun 15, 2020
8b67f78
Check checkpoint management
sankalp04 Jun 15, 2020
eb482fd
Modify statement on artifacts
sankalp04 Jun 15, 2020
19e67d7
Resolved conflicts
sankalp04 Jun 15, 2020
923974f
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
4768066
Fixed style
sankalp04 Jun 9, 2020
1699d04
Fixed more style
sankalp04 Jun 9, 2020
ec62db9
Nit changes
sankalp04 Jun 9, 2020
5bdbf0f
Fixed signature
sankalp04 Jun 9, 2020
e25ab31
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
3a77409
Fixed style
sankalp04 Jun 9, 2020
ce4351b
Nit changes
sankalp04 Jun 9, 2020
a2069cc
Fixed tests, checkpoint management and style
sankalp04 Jun 15, 2020
d3940c5
Check checkpoint management
sankalp04 Jun 15, 2020
e5429f4
Modify statement on artifacts
sankalp04 Jun 15, 2020
1f96c1e
Convert checkpoints to .nn format
sankalp04 Jun 5, 2020
44e84ae
Fixed more style
sankalp04 Jun 9, 2020
a9ac345
Nit changes
sankalp04 Jun 9, 2020
4a1217b
Fixed more style
sankalp04 Jun 9, 2020
e834fb9
Nit changes
sankalp04 Jun 9, 2020
3b36f65
Fixed tests, checkpoint management and style
sankalp04 Jun 15, 2020
23e058d
Check checkpoint management
sankalp04 Jun 15, 2020
a24c16f
Merge branch 'develop-checkpoint-conversion' of https://github.com/Un…
sankalp04 Jun 15, 2020
3d69cc8
refactor checkpoint management in trainer
sankalp04 Jun 18, 2020
4f77d86
Fix imports & test
sankalp04 Jun 18, 2020
377674d
Fix tests
sankalp04 Jun 18, 2020
6d7ec54
Group checkpoint save function for clearer API
sankalp04 Jun 18, 2020
55b2428
Merge branch 'master' into develop-checkpoint-conversion
sankalp04 Jun 19, 2020
a0a623c
Moved checkpoint logic to a different file
sankalp04 Jun 22, 2020
dac0b44
Fix variable names
sankalp04 Jun 22, 2020
82aaed4
Fixed file tracking
sankalp04 Jun 22, 2020
88e81fe
Move checkpointing into Policy (#4139)
Jun 30, 2020
f62546d
Track model creation times
sankalp04 Jul 2, 2020
47fec44
Refactor checkpoint management state to use GlobalTrainingStatus
sankalp04 Jul 8, 2020
6a4db12
Remove is_checkpoint argument
sankalp04 Jul 8, 2020
5c4fd07
Fixed bug with not initializing checkpoint_list
sankalp04 Jul 8, 2020
8baf21b
Fix docstrings
sankalp04 Jul 8, 2020
0317ea1
Add steps back to checkpoint/save
Jul 8, 2020
c1d507d
Remove update_parameter_state
Jul 8, 2020
cee3379
Add docstring for export_policy_model
Jul 8, 2020
15a41d0
Fix test call to export_policy_model
Jul 8, 2020
5343359
Rename to NNCheckpointManager
Jul 9, 2020
efdcd38
Refactor to avoid steps, brain_name in policy
Jul 9, 2020
c065f09
Merge remote-tracking branch 'origin/master' into develop-checkpoint-…
Jul 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/Training-ML-Agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ in the `results/<run-identifier>` folder:
blocks. See [Profiling in Python](Profiling-Python.md) for more information
on the timers generated.

These artifacts (except the `.nn` file) are updated throughout the training
process and finalized when training completes or is interrupted.
These artifacts are updated throughout the training
process and finalized when training is completed or is interrupted.

#### Stopping and Resuming Training

Expand Down
14 changes: 12 additions & 2 deletions ml-agents/mlagents/model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
class SerializationSettings(NamedTuple):
model_path: str
brain_name: str
checkpoint_path: str = ""
convert_to_barracuda: bool = True
convert_to_onnx: bool = True
onnx_opset: int = 9
Expand All @@ -72,15 +73,24 @@ def export_policy_model(
Exports latest saved model to .nn format for Unity embedding.
"""
frozen_graph_def = _make_frozen_graph(settings, graph, sess)
if not os.path.exists(settings.model_path):
os.makedirs(settings.model_path)
# Save frozen graph
frozen_graph_def_path = settings.model_path + "/frozen_graph_def.pb"
with gfile.GFile(frozen_graph_def_path, "wb") as f:
f.write(frozen_graph_def.SerializeToString())

# Convert to barracuda
if settings.convert_to_barracuda:
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
logger.info(f"Exported {settings.model_path}.nn file")
if settings.checkpoint_path:
tf2bc.convert(
frozen_graph_def_path,
os.path.join(settings.model_path, f"{settings.checkpoint_path}.nn"),
)
logger.info(f"Exported {settings.checkpoint_path}.nn file")
else:
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
logger.info(f"Exported {settings.model_path}.nn file")

# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED:
Expand Down
17 changes: 3 additions & 14 deletions ml-agents/mlagents/trainers/ghost/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def __init__(
self.current_policy_snapshot: Dict[str, List[float]] = {}

self.snapshot_counter: int = 0
self.policies: Dict[str, TFPolicy] = {}

# wrapped_training_team and learning team need to be separate
# in the situation where new agents are created destroyed
Expand Down Expand Up @@ -298,21 +297,11 @@ def end_episode(self):
"""
self.trainer.end_episode()

def save_model(self, name_behavior_id: str) -> None:
def save_model(self) -> None:
"""
Forwarding call to wrapped trainers save_model
Forwarding call to wrapped trainers save_model.
"""
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name
self.trainer.save_model(brain_name)

def export_model(self, name_behavior_id: str) -> None:
"""
Forwarding call to wrapped trainers export_model.
"""
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name
self.trainer.export_model(brain_name)
self.trainer.save_model()

def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
GlobalTrainingStatus.load_state(
os.path.join(run_logs_dir, "training_status.json")
)

# Configure CSV, Tensorboard Writers and StatsReporter
# We assume reward and episode length are needed in the CSV.
csv_writer = CSVWriter(
Expand Down
95 changes: 95 additions & 0 deletions ml-agents/mlagents/trainers/policy/checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# # Unity ML-Agents Toolkit
from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)


@attr.s(auto_attribs=True)
class Checkpoint:
    """
    Record describing a single saved model file.
    Instances are converted to plain dictionaries with attr.asdict before
    being stored in the training status.
    """

    # Step count associated with the saved model.
    steps: int
    # Path of the saved model file on disk.
    file_path: str
    # Reward associated with the model at save time; may be None.
    reward: Optional[float]
    # Timestamp for the record (callers in this change pass time.time()).
    creation_time: float


class CheckpointManager:
@staticmethod
def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
    """
    Fetches the checkpoint records tracked for a behavior.
    When nothing (or an empty list) is stored, a new empty list is
    registered in GlobalTrainingStatus and returned instead.
    :param behavior_name: The behavior whose checkpoint records are requested.
    :return: The list of checkpoint dictionaries for the behavior.
    """
    stored = GlobalTrainingStatus.get_parameter_state(
        behavior_name, StatusType.CHECKPOINTS
    )
    if stored:
        return stored
    # Register the new list before handing it out; presumably the status
    # keeps this same object, so in-place edits by callers are tracked.
    fresh: List[Dict[str, Any]] = []
    GlobalTrainingStatus.set_parameter_state(
        behavior_name, StatusType.CHECKPOINTS, fresh
    )
    return fresh

@staticmethod
def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
    """
    Removes the file of a checkpoint stored in checkpoint_list.
    If the file cannot be found, the miss is logged and no action is done.
    :param checkpoint: A checkpoint stored in checkpoint_list; its
        "file_path" entry names the file to delete.
    """
    file_path: str = checkpoint["file_path"]
    # Attempt the deletion directly rather than testing os.path.exists
    # first, so a file that disappears between the test and the removal
    # is reported as missing instead of raising.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        logger.info(f"Checkpoint at {file_path} could not be found.")
    else:
        logger.info(f"Removed checkpoint model {file_path}.")

@classmethod
def manage_checkpoint_list(
    cls, behavior_name: str, keep_checkpoints: int
) -> List[Dict[str, Any]]:
    """
    Ensures that the number of checkpoints stored is within the number
    of checkpoints the user defines. If the limit is hit, checkpoints are
    removed from the front of the list to create room for the next
    checkpoint to be inserted.
    :param behavior_name: The behavior whose checkpoint list is managed.
    :param keep_checkpoints: Number of checkpoints to record (user-defined).
        A value of zero or less leaves the list untouched.
    :return: The checkpoint list for the behavior, trimmed in place.
    """
    checkpoints = cls.get_checkpoints(behavior_name)
    # A non-positive limit disables trimming entirely.
    if keep_checkpoints <= 0:
        return checkpoints
    # Trim to strictly fewer than the limit so one more entry can be added.
    while len(checkpoints) >= keep_checkpoints:
        cls.remove_checkpoint(checkpoints.pop(0))
    return checkpoints

@classmethod
def track_checkpoint_info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would prefer add_checkpoint to track_checkpoint_info. Also, I would remove the keep_checkpoints argument and the call to manage_checkpoint_list. I don't think we should have track_checkpoint_info call manage_checkpoint_list under the hood; managing checkpoints should be separate from adding them.

cls, behavior_name: str, new_checkpoint: Checkpoint, keep_checkpoints: int
) -> None:
"""
Make room for new checkpoint if needed and insert new checkpoint information.
:param category: The category (usually behavior name) of the parameter.
:param value: The new checkpoint to be recorded.
:param keep_checkpoints: Number of checkpoints to record (user-defined).
"""
checkpoints = cls.manage_checkpoint_list(behavior_name, keep_checkpoints)
new_checkpoint_dict = attr.asdict(new_checkpoint)
checkpoints.append(new_checkpoint_dict)

@classmethod
def track_final_model_info(
cls, behavior_name: str, final_model: Checkpoint, keep_checkpoints: int
) -> None:
"""
Ensures number of checkpoints stored is within the max number of checkpoints
defined by the user and finally stores the information about the final
model (or intermediate model if training is interrupted).
:param category: The category (usually behavior name) of the parameter.
:param final_model_path: The file path of the final model.
:param keep_checkpoints: Number of checkpoints to record (user-defined).
"""
CheckpointManager.manage_checkpoint_list(behavior_name, keep_checkpoints)
final_model_dict = attr.asdict(final_model)
GlobalTrainingStatus.set_parameter_state(
behavior_name, StatusType.FINAL_MODEL, final_model_dict
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think we need a FINAL_MODEL StatusType.
The final model should simply be the last checkpoint in StatusType.CHECKPOINTS.
If you really want to keep it, rename FINAL_MODEL to FINAL_CHECKPOINT, since it is not a model.
You could also rename to LATEST_CHECKPOINT and modify it in track_checkpoint_info every time a model is added.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think from the user's perspective there is a difference between the resulting model and a checkpoint. This is actually referring to the model output at the end of training, when max_steps is reached. Maybe it would be clearer to rename the Checkpoint class to ModelInfo or similar?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpoint might be confused with a TensorFlow checkpoint.

This is actually referring to the model output at the end of training, when max_steps is reached.

I understand this, but I think what should happen at the end of training is that a new checkpoint/model info should be stored and then the last checkpoint/model info's .nn file is copied and given to the user at the end of training.

)
53 changes: 47 additions & 6 deletions ml-agents/mlagents/trainers/policy/tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
import abc
import os
import numpy as np
import time
from distutils.version import LooseVersion

from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.tf_utils import tf
from mlagents import tf_utils
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.policy.checkpoint_manager import Checkpoint, CheckpointManager
from mlagents_envs.exception import UnityException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.policy import Policy
Expand Down Expand Up @@ -70,6 +74,9 @@ def __init__(
self.sequence_length = 1
self.seed = seed
self.brain = brain
self.behavior_id = BehaviorIdentifiers.from_name_behavior_id(
self.brain.brain_name
)

self.act_size = brain.vector_action_space_size
self.vec_obs_size = brain.vector_observation_space_size
Expand Down Expand Up @@ -392,18 +399,52 @@ def get_update_vars(self):
"""
return list(self.update_dict.keys())

def save_model(self, steps):
def checkpoint(self, model_reward: Optional[float] = None) -> None:
"""
Saves the model
:param steps: The number of steps the model was trained for
:return:
Writes an intermediate checkpoint model to memory
model_reward: Mean reward of the reward buffer at the time of saving
"""
current_step = self.get_current_step()
with self.graph.as_default():
last_checkpoint = os.path.join(self.model_path, f"model-{steps}.ckpt")
self.saver.save(self.sess, last_checkpoint)
last_checkpoint = os.path.join(
self.model_path, f"model-{current_step}.ckpt"
)
if self.saver:
self.saver.save(self.sess, last_checkpoint)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It kind of bothers me that saving the checkpoint is not handled by the CheckpointManager. Can we try to put this logic in the CheckpointManager to make this code nicer? The CheckpointManager is already responsible for deleting those files, and I think classes should delete only what they create.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think maybe we view CheckpointManager differently. I view that as managing our own .nn checkpoints rather than implementation specific (e.g. TF/Torch) policy checkpoints. If we want to combine those we could make CheckpointManager non-static, so we can have a TFCheckpointManager that uses the saver. This aligns with @ervteng's suggestion that the Policy not know much about steps/rewards.

Does that make sense? I think it's not a bad way to go moving forward, but maybe in a follow-up PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpoint manager should be renamed IMO.

  • NNFileManager
  • ModelMetaData
  • NNInventory

tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
brain_name = self.behavior_id.brain_name
checkpoint_path = f"{brain_name}-{current_step}"
settings = SerializationSettings(self.model_path, brain_name, checkpoint_path)
export_policy_model(settings, self.graph, self.sess)
# Store steps and file_path
new_checkpoint = Checkpoint(
int(current_step),
os.path.join(self.model_path, f"{settings.checkpoint_path}.nn"),
model_reward,
time.time(),
)
# Record checkpoint information
CheckpointManager.track_checkpoint_info(
brain_name, new_checkpoint, self.keep_checkpoints
)

def save(self, model_reward: Optional[float] = None) -> None:
    """
    Registers the final model with the CheckpointManager and then exports
    it. Used on completion or interruption of training.
    :param model_reward: Mean reward of the reward buffer at the time of saving
    """
    step = self.get_current_step()
    name = self.behavior_id.brain_name
    export_settings = SerializationSettings(self.model_path, name)
    # The final model is recorded at "<model_path>.nn".
    record = Checkpoint(
        steps=int(step),
        file_path=f"{export_settings.model_path}.nn",
        reward=model_reward,
        creation_time=time.time(),
    )
    CheckpointManager.track_final_model_info(name, record, self.keep_checkpoints)
    export_policy_model(export_settings, self.graph, self.sess)

def update_normalization(self, vector_obs: np.ndarray) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/ppo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def add_policy(
if not isinstance(policy, NNPolicy):
raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am against adding self.policies to all of the trainers. Having a dictionary of policies only makes sense for the ghost trainer; having both self.policies and self.policy seems super redundant to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was already done. self.policies is defined in the base Trainer class. This change is remedying the fact that the policies weren't actually assigned in PPOTrainer and SACTrainer.

I agree we should change it, but wanted to avoid further expanding the scope of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😲

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should just make it consistent - since self.policies is more general, we should just have it as the standard across trainers. That way all trainers can share the get_policy method.

self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
Expand Down
19 changes: 14 additions & 5 deletions ml-agents/mlagents/trainers/sac/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,21 @@ def __init__(

self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer

def save_model(self, name_behavior_id: str) -> None:
def _checkpoint(self) -> None:
"""
Saves the model. Overrides the default save_model since we want to save
the replay buffer as well.
Writes a checkpoint model to memory
Overrides the default to save the replay buffer.
"""
self.policy.save_model(self.get_step)
super()._checkpoint()
if self.checkpoint_replay_buffer:
self.save_replay_buffer()

def save_model(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_model and _checkpoint have the same docstring. I doubt they do the same thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify what you meant here, I am not sure I understand the feedback. Is it that they are implementing the same thing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sankalp04 just that the docstring describing the functions and their params are the same (this was my mistake), and should say what the respective functions do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one is supposed to be called by the other. Maybe only one of them needs to be public?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you clarify @vincentpierre? I believe super().save_model() should result in the _checkpoint method being called

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, my bad. Sorry, I forgot to look at the base class.

"""
Saves the final training model to memory
Overrides the default to save the replay buffer.
"""
super().save_model()
if self.checkpoint_replay_buffer:
self.save_replay_buffer()

Expand Down Expand Up @@ -308,7 +317,6 @@ def add_policy(
) -> None:
"""
Adds policy to trainer.
:param brain_parameters: specifications for policy construction
"""
if self.policy:
logger.warning(
Expand All @@ -320,6 +328,7 @@ def add_policy(
if not isinstance(policy, NNPolicy):
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
self.optimizer = SACOptimizer(self.policy, self.trainer_settings)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_policy_conversion(tmpdir, rnn, visual, discrete):
use_discrete=discrete,
use_visual=visual,
)
policy.save_model(1000)
policy.checkpoint()
settings = SerializationSettings(
policy.model_path, os.path.join(tmpdir, policy.brain.brain_name)
)
Expand Down
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_config_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ def test_convert_behaviors(trainer_type, use_recurrent):
if trainer_type == TrainerType.PPO:
trainer_config = PPO_CONFIG
trainer_settings_type = PPOSettings
elif trainer_type == TrainerType.SAC:
else:
trainer_config = SAC_CONFIG
trainer_settings_type = SACSettings

old_config = yaml.load(trainer_config)
old_config = yaml.safe_load(trainer_config)
old_config[BRAIN_NAME]["use_recurrent"] = use_recurrent
new_config = convert_behaviors(old_config)

Expand Down
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/test_nn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_load_save(tmp_path):
policy = create_policy_mock(trainer_params, model_path=path1)
policy.initialize_or_load()
policy._set_step(2000)
policy.save_model(2000)
policy.checkpoint()

assert len(os.listdir(tmp_path)) > 0

Expand Down
Loading