-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Convert checkpoints to .NN #4127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 59 commits
42fbf48
d619d0b
de9f908
8075648
6957381
fdecd97
91865e6
195267f
c5bda4b
63db228
3445556
66d4916
e6260b4
59b5d1d
543eb20
33f21ac
be07986
a773c32
7f52e2d
9cdeba9
31ecce2
12fa5a6
f3e4578
8b67f78
eb482fd
19e67d7
923974f
4768066
1699d04
ec62db9
5bdbf0f
e25ab31
3a77409
ce4351b
a2069cc
d3940c5
e5429f4
1f96c1e
44e84ae
a9ac345
4a1217b
e834fb9
3b36f65
23e058d
a24c16f
3d69cc8
4f77d86
377674d
6d7ec54
55b2428
a0a623c
dac0b44
82aaed4
88e81fe
f62546d
47fec44
6a4db12
5c4fd07
8baf21b
0317ea1
c1d507d
cee3379
15a41d0
5343359
efdcd38
c065f09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# # Unity ML-Agents Toolkit | ||
from typing import Dict, Any, Optional, List | ||
import os | ||
import attr | ||
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType | ||
from mlagents_envs.logging_util import get_logger | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
@attr.s(auto_attribs=True)
class Checkpoint:
    """
    Metadata record for a saved model checkpoint. Instances are converted to
    plain dicts with attr.asdict before being stored in the global training
    status, so field names here double as the stored dictionary keys.
    """

    # Training step count at the time the checkpoint was written.
    steps: int
    # Path to the exported model file on disk.
    file_path: str
    # Mean reward at save time; None when no reward is available.
    reward: Optional[float]
    # Wall-clock timestamp (time.time()) when the checkpoint was created.
    creation_time: float
|
||
|
||
class CheckpointManager:
    """
    Tracks checkpoint metadata in the GlobalTrainingStatus store and keeps the
    number of recorded checkpoints within the user-defined limit, deleting the
    oldest checkpoint files from disk when room is needed.
    """

    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Returns the checkpoint list stored for a behavior, registering a new
        empty list in the global training status if none exists yet.
        :param behavior_name: The behavior name whose checkpoints are queried.
        :return: The stored (mutable) list of checkpoint dictionaries.
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Removes the model file referenced by a checkpoint record from disk.
        If the file cannot be found, no action is taken.
        :param checkpoint: A checkpoint dictionary stored in checkpoint_list.
        """
        file_path: str = checkpoint["file_path"]
        if os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Removed checkpoint model {file_path}.")
        else:
            logger.info(f"Checkpoint at {file_path} could not be found.")

    @classmethod
    def manage_checkpoint_list(
        cls, behavior_name: str, keep_checkpoints: int
    ) -> List[Dict[str, Any]]:
        """
        Ensures that the number of checkpoints stored is within the number of
        checkpoints the user defines. If the limit is hit, the oldest
        checkpoints are removed to create room for the next one.
        :param behavior_name: The behavior name whose checkpoints are managed.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        :return: The (possibly trimmed) checkpoint list.
        """
        checkpoints = cls.get_checkpoints(behavior_name)
        # A non-positive limit disables trimming entirely; otherwise evict the
        # oldest entries (front of the list) until there is room for one more.
        if keep_checkpoints > 0:
            while len(checkpoints) >= keep_checkpoints:
                cls.remove_checkpoint(checkpoints.pop(0))
        return checkpoints

    @classmethod
    def track_checkpoint_info(
        cls, behavior_name: str, new_checkpoint: Checkpoint, keep_checkpoints: int
    ) -> None:
        """
        Makes room for a new checkpoint if needed and records its information.
        The append mutates the list held by GlobalTrainingStatus in place.
        :param behavior_name: The behavior name the checkpoint belongs to.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        checkpoints = cls.manage_checkpoint_list(behavior_name, keep_checkpoints)
        checkpoints.append(attr.asdict(new_checkpoint))

    @classmethod
    def track_final_model_info(
        cls, behavior_name: str, final_model: Checkpoint, keep_checkpoints: int
    ) -> None:
        """
        Ensures the number of checkpoints stored is within the user-defined
        maximum and stores the information about the final model (or the
        intermediate model if training was interrupted).
        :param behavior_name: The behavior name the final model belongs to.
        :param final_model: Checkpoint record describing the final model.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        cls.manage_checkpoint_list(behavior_name, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_MODEL, attr.asdict(final_model)
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,14 @@ | |
import abc | ||
import os | ||
import numpy as np | ||
import time | ||
from distutils.version import LooseVersion | ||
|
||
from mlagents.model_serialization import SerializationSettings, export_policy_model | ||
from mlagents.tf_utils import tf | ||
from mlagents import tf_utils | ||
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers | ||
from mlagents.trainers.policy.checkpoint_manager import Checkpoint, CheckpointManager | ||
from mlagents_envs.exception import UnityException | ||
from mlagents_envs.logging_util import get_logger | ||
from mlagents.trainers.policy import Policy | ||
|
@@ -70,6 +74,9 @@ def __init__( | |
self.sequence_length = 1 | ||
self.seed = seed | ||
self.brain = brain | ||
self.behavior_id = BehaviorIdentifiers.from_name_behavior_id( | ||
self.brain.brain_name | ||
) | ||
|
||
self.act_size = brain.vector_action_space_size | ||
self.vec_obs_size = brain.vector_observation_space_size | ||
|
@@ -392,18 +399,52 @@ def get_update_vars(self): | |
""" | ||
return list(self.update_dict.keys()) | ||
|
||
def checkpoint(self, model_reward: Optional[float] = None) -> None:
    """
    Writes an intermediate checkpoint of the model to disk: a TF checkpoint
    and frozen graph under model_path, plus an exported .nn file, and records
    the checkpoint metadata with the CheckpointManager.
    :param model_reward: Mean reward of the reward buffer at the time of saving.
    """
    current_step = self.get_current_step()
    with self.graph.as_default():
        last_checkpoint = os.path.join(
            self.model_path, f"model-{current_step}.ckpt"
        )
        # self.saver may be unset (e.g. inference-only policies); skip the
        # TF checkpoint in that case but still export the graph below.
        if self.saver:
            self.saver.save(self.sess, last_checkpoint)
        tf.train.write_graph(
            self.graph, self.model_path, "raw_graph_def.pb", as_text=False
        )
    brain_name = self.behavior_id.brain_name
    checkpoint_path = f"{brain_name}-{current_step}"
    settings = SerializationSettings(self.model_path, brain_name, checkpoint_path)
    export_policy_model(settings, self.graph, self.sess)
    # Store steps, file path, reward and creation time for this checkpoint.
    new_checkpoint = Checkpoint(
        int(current_step),
        os.path.join(self.model_path, f"{settings.checkpoint_path}.nn"),
        model_reward,
        time.time(),
    )
    # Record checkpoint information (may evict the oldest checkpoint file).
    CheckpointManager.track_checkpoint_info(
        brain_name, new_checkpoint, self.keep_checkpoints
    )
|
||
def save(self, model_reward: Optional[float] = None) -> None:
    """
    Exports the final model on completion (or an intermediate model on
    interruption) of training and records it via the CheckpointManager.
    :param model_reward: Mean reward of the reward buffer at the time of saving.
    """
    current_step = self.get_current_step()
    brain_name = self.behavior_id.brain_name
    settings = SerializationSettings(self.model_path, brain_name)
    final_model = Checkpoint(
        int(current_step), f"{settings.model_path}.nn", model_reward, time.time()
    )
    # Trim stored checkpoints to the user limit and record the final model.
    CheckpointManager.track_final_model_info(
        brain_name, final_model, self.keep_checkpoints
    )
    export_policy_model(settings, self.graph, self.sess)
|
||
def update_normalization(self, vector_obs: np.ndarray) -> None: | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -229,6 +229,7 @@ def add_policy( | |
if not isinstance(policy, NNPolicy): | ||
raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()") | ||
self.policy = policy | ||
self.policies[parsed_behavior_id.behavior_id] = policy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am against adding self.policies to all of the trainers. Having a dictionary of policies only makes sense for the ghost trainer, having both self.policies and self.policy seems super redundant to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was already done. I agree we should change it, but wanted to avoid further expanding the scope of this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😲 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO we should just make it consistent - since |
||
self.optimizer = PPOOptimizer(self.policy, self.trainer_settings) | ||
for _reward_signal in self.optimizer.reward_signals.keys(): | ||
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,12 +76,21 @@ def __init__( | |
|
||
self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer | ||
|
||
def _checkpoint(self) -> None:
    """
    Writes an intermediate checkpoint of the model to disk.
    Overrides the default to also save the replay buffer when
    save_replay_buffer is enabled in the hyperparameters.
    """
    super()._checkpoint()
    if self.checkpoint_replay_buffer:
        self.save_replay_buffer()
|
||
def save_model(self) -> None:
    """
    Saves the final training model to disk.
    Overrides the default to also save the replay buffer when
    save_replay_buffer is enabled in the hyperparameters.
    """
    super().save_model()
    if self.checkpoint_replay_buffer:
        self.save_replay_buffer()
|
||
|
@@ -308,7 +317,6 @@ def add_policy( | |
) -> None: | ||
""" | ||
Adds policy to trainer. | ||
:param brain_parameters: specifications for policy construction | ||
""" | ||
if self.policy: | ||
logger.warning( | ||
|
@@ -320,6 +328,7 @@ def add_policy( | |
if not isinstance(policy, NNPolicy): | ||
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()") | ||
self.policy = policy | ||
self.policies[parsed_behavior_id.behavior_id] = policy | ||
self.optimizer = SACOptimizer(self.policy, self.trainer_settings) | ||
for _reward_signal in self.optimizer.reward_signals.keys(): | ||
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) | ||
|
Uh oh!
There was an error while loading. Please reload this page.