diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 58d08b1bdc..696069a5ba 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [Unreleased]
 ### Major Changes
+- The `--load` and `--train` command-line flags have been deprecated. Training now happens by default;
+  use `--resume` to resume training instead. (#3705)
 - The Jupyter notebooks have been removed from the repository.
 - Introduced the `SideChannelUtils` to register, unregister and access side channels.
 - `Academy.FloatProperties` was removed, please use `SideChannelUtils.GetSideChannel()` instead.
@@ -23,6 +25,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Environment subprocesses now close immediately on timeout or wrong API version. (#3679)
 - Fixed an issue in the gym wrapper that would raise an exception if an Agent called EndEpisode multiple times in the same step. (#3700)
 - Fixed an issue where exceptions from environments provided a returncode of 0. (#3680)
+- Running `mlagents-learn` with the same `--run-id` twice will no longer overwrite the existing files. (#3705)
 - Fixed an issue where logging output was not visible; logging levels are now set consistently (#3703).

 ## [0.15.0-preview] - 2020-03-18
diff --git a/docs/Getting-Started.md b/docs/Getting-Started.md
index 9372d7c970..1b796ca471 100644
--- a/docs/Getting-Started.md
+++ b/docs/Getting-Started.md
@@ -197,19 +197,17 @@ which accepts arguments used to configure both training and inference phases.
 2. Navigate to the folder where you cloned the ML-Agents toolkit repository.
    **Note**: If you followed the default [installation](Installation.md), then
    you should be able to run `mlagents-learn` from any directory.
-3. Run `mlagents-learn <trainer-config-path> --run-id=<run-identifier> --train`
+3. Run `mlagents-learn <trainer-config-path> --run-id=<run-identifier>`
    where:
    - `<trainer-config-path>` is the relative or absolute filepath of the
      trainer configuration. The defaults used by example environments included
     in `MLAgentsSDK` can be found in `config/trainer_config.yaml`.
    - `<run-identifier>` is a string used to separate the results of different
-     training runs
-   - `--train` tells `mlagents-learn` to run a training session (rather
-     than inference)
+     training runs. Make sure to use one that hasn't been used already!
 4. If you cloned the ML-Agents repo, then you can simply run

    ```sh
-   mlagents-learn config/trainer_config.yaml --run-id=firstRun --train
+   mlagents-learn config/trainer_config.yaml --run-id=firstRun
    ```

 5. When the message _"Start training by pressing the Play button in the Unity
@@ -219,7 +217,6 @@ which accepts arguments used to configure both training and inference phases.
    **Note**: If you're using Anaconda, don't forget to activate the ml-agents
    environment first.

-The `--train` flag tells the ML-Agents toolkit to run in training mode.
 The `--time-scale=100` sets the `Time.TimeScale` value in Unity.

 **Note**: You can train using an executable rather than the Editor. To do so,
@@ -330,8 +327,14 @@ Either wait for the training process to close the window or press Ctrl+C at the
 command-line prompt. If you close the window manually, the `.nn` file
 containing the trained model is not exported into the ml-agents folder.

-You can press Ctrl+C to stop the training, and your trained model will be at
-`models/<run-identifier>/<behavior_name>.nn` where
+If you've quit training early using Ctrl+C and want to resume it, run the
+same command again, appending the `--resume` flag:
+
+```sh
+mlagents-learn config/trainer_config.yaml --run-id=firstRun --resume
+```
+
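+To start the `firstRun` id over from scratch instead, overwriting the previously
+saved model and summaries, you can add the new `--force` flag. An illustrative
+example, using the same configuration file as above:
+
+```sh
+mlagents-learn config/trainer_config.yaml --run-id=firstRun --force
+```
+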
+Your trained model will be at `models/<run-identifier>/<behavior_name>.nn` where
 `<behavior_name>` is the name of the `Behavior Name` of the agents corresponding to the model.
 (**Note:** There is a known bug on Windows that causes the saving of the
 model to fail when you early terminate the training, it's recommended to wait until Step
diff --git a/docs/Learning-Environment-Create-New.md b/docs/Learning-Environment-Create-New.md
index 15eee98a61..069507d974 100644
--- a/docs/Learning-Environment-Create-New.md
+++ b/docs/Learning-Environment-Create-New.md
@@ -418,7 +418,7 @@ in this simple environment, speeds up training.
 To train in the editor, run the following Python command from a Terminal or Console
 window before pressing play:

-    mlagents-learn config/config.yaml --run-id=RollerBall-1 --train
+    mlagents-learn config/config.yaml --run-id=RollerBall-1

 (where `config.yaml` is a copy of `trainer_config.yaml` that you have edited
 to change the `batch_size` and `buffer_size` hyperparameters for your trainer.)
diff --git a/docs/Learning-Environment-Executable.md b/docs/Learning-Environment-Executable.md
index 5a2eec55d7..e9c560cdde 100644
--- a/docs/Learning-Environment-Executable.md
+++ b/docs/Learning-Environment-Executable.md
@@ -76,27 +76,25 @@ env = UnityEnvironment(file_name=<env_name>)
    followed the default [installation](Installation.md), then navigate to the
    `ml-agents/` folder.
 3. Run
-   `mlagents-learn <trainer-config-file> --env=<env_name> --run-id=<run-identifier> --train`
+   `mlagents-learn <trainer-config-file> --env=<env_name> --run-id=<run-identifier>`
    Where:
    * `<trainer-config-file>` is the file path of the trainer configuration yaml
    * `<env_name>` is the name and path to the executable you exported from Unity
      (without extension)
    * `<run-identifier>` is a string used to separate the results of different
      training runs
-   * And the `--train` tells `mlagents-learn` to run a training session (rather
-     than inference)

 For example, if you are training with a 3DBall executable you exported to the
 the directory where you installed the ML-Agents Toolkit, run:

 ```sh
-mlagents-learn ../config/trainer_config.yaml --env=3DBall --run-id=firstRun --train
+mlagents-learn ../config/trainer_config.yaml --env=3DBall --run-id=firstRun
 ```

 And you should see something like

 ```console
-ml-agents$ mlagents-learn config/trainer_config.yaml --env=3DBall --run-id=first-run --train
+ml-agents$ mlagents-learn config/trainer_config.yaml --env=3DBall --run-id=first-run

                         ▄▄▄▓▓▓▓
diff --git a/docs/Migrating.md b/docs/Migrating.md
index 663b4ca4b0..18f8fb2efe 100644
--- a/docs/Migrating.md
+++ b/docs/Migrating.md
@@ -10,6 +10,13 @@ The versions can be found in
 ## Migrating from 0.15 to latest

 ### Important changes
+* The `--load` and `--train` command-line flags have been deprecated and replaced with `--resume` and `--inference`.
+* Running with the same `--run-id` twice will now throw an error.
+
+### Steps to Migrate
+* Replace the `--load` flag with `--resume` when calling `mlagents-learn`, and don't use the `--train` flag, as training
+  will happen by default. To run without training, use `--inference`.
+* To force-overwrite files from a pre-existing run, add the `--force` command-line flag.
 * The Jupyter notebooks have been removed from the repository.
 * `Academy.FloatProperties` was removed.
 * `Academy.RegisterSideChannel` and `Academy.UnregisterSideChannel` were removed.
@@ -19,7 +26,6 @@ The versions can be found in
 * Replace `Academy.RegisterSideChannel` with `SideChannelUtils.RegisterSideChannel()`.
 * Replace `Academy.UnregisterSideChannel` with `SideChannelUtils.UnregisterSideChannel`.

-
 ## Migrating from 0.14 to 0.15

 ### Important changes
diff --git a/docs/Training-Curriculum-Learning.md b/docs/Training-Curriculum-Learning.md
index 126bfa6715..38885287ea 100644
--- a/docs/Training-Curriculum-Learning.md
+++ b/docs/Training-Curriculum-Learning.md
@@ -110,7 +110,7 @@ for our curricula and PPO will train using Curriculum Learning. For example,
 to train agents in the Wall Jump environment with curriculum learning, we can run:

 ```sh
-mlagents-learn config/trainer_config.yaml --curriculum=config/curricula/wall_jump.yaml --run-id=wall-jump-curriculum --train
+mlagents-learn config/trainer_config.yaml --curriculum=config/curricula/wall_jump.yaml --run-id=wall-jump-curriculum
 ```

 We can then keep track of the current lessons and progresses via TensorBoard.
diff --git a/docs/Training-Environment-Parameter-Randomization.md b/docs/Training-Environment-Parameter-Randomization.md
index f98716dd0e..e812557b1d 100644
--- a/docs/Training-Environment-Parameter-Randomization.md
+++ b/docs/Training-Environment-Parameter-Randomization.md
@@ -165,7 +165,7 @@ sampling setup, we would run

 ```sh
 mlagents-learn config/trainer_config.yaml --sampler=config/3dball_randomize.yaml
---run-id=3D-Ball-randomize --train
+--run-id=3D-Ball-randomize
 ```

 We can observe progress and metrics via Tensorboard.
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 64ae7026fe..a8c7deafcb 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -43,7 +43,7 @@ training options. The basic command for training is:

 ```sh
-mlagents-learn <trainer-config-file> --env=<env_name> --run-id=<run-identifier> --train
+mlagents-learn <trainer-config-file> --env=<env_name> --run-id=<run-identifier>
 ```

 where
@@ -68,7 +68,7 @@ contains agents ready to train. To perform the training:
    environment you built in step 1:

 ```sh
-mlagents-learn config/trainer_config.yaml --env=../../projects/Cats/CatsOnBicycles.app --run-id=cob_1 --train
+mlagents-learn config/trainer_config.yaml --env=../../projects/Cats/CatsOnBicycles.app --run-id=cob_1
 ```

 During a training session, the training program prints out and saves updates at
@@ -92,9 +92,27 @@ under the assigned run-id — in the cats example, the path to the model would be
 `models/cob_1/CatsOnBicycles_cob_1.nn`.

 While this example used the default training hyperparameters, you can edit the
-[training_config.yaml file](#training-config-file) with a text editor to set
+[trainer_config.yaml file](#training-config-file) with a text editor to set
 different values.

+To interrupt training and save the current progress, hit Ctrl+C once and wait for the
+model to be saved out.
+
+### Loading an Existing Model
+
+If you've quit training early using Ctrl+C, you can resume the training run by running
+`mlagents-learn` again, specifying the same `<run-identifier>` and appending the `--resume` flag
+to the command.
+
+You can also use this mode to run inference with an already-trained model in Python.
+Append both the `--resume` and `--inference` flags to do this. Note that if you want to run
+inference in Unity, you should use the
+[Unity Inference Engine](Getting-Started.md#running-a-pre-trained-model).
+
+If you've already trained a model using the specified `<run-identifier>` and `--resume` is not
+specified, you will not be able to continue with training. Use `--force` to force ML-Agents to
+overwrite the existing data.
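+
+For example, reusing the cats run from above (same configuration and executable;
+the commands below are illustrative):
+
+```sh
+# Resume the interrupted cob_1 run where it left off.
+mlagents-learn config/trainer_config.yaml --env=../../projects/Cats/CatsOnBicycles.app --run-id=cob_1 --resume
+
+# Run the saved cob_1 model in Python inference mode, without further training.
+mlagents-learn config/trainer_config.yaml --env=../../projects/Cats/CatsOnBicycles.app --run-id=cob_1 --resume --inference
+
+# Discard the previous results and train cob_1 again from scratch.
+mlagents-learn config/trainer_config.yaml --env=../../projects/Cats/CatsOnBicycles.app --run-id=cob_1 --force
+```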
+
 ### Command Line Training Options

 In addition to passing the path of the Unity executable containing your training
@@ -115,7 +133,7 @@ environment, you can set the following command line options when invoking
   training. Defaults to 0.
 * `--num-envs=<n>`: Specifies the number of concurrent Unity environment instances to collect
   experiences from when training. Defaults to 1.
-* `--run-id=<path>`: Specifies an identifier for each training run. This
+* `--run-id=<run-identifier>`: Specifies an identifier for each training run. This
   identifier is used to name the subdirectories in which the trained model and
   summary statistics are saved as well as the saved model itself. The default id
   is "ppo". If you use TensorBoard to view the training statistics, always set a
@@ -137,13 +155,15 @@ environment, you can set the following command line options when invoking
   will use the port `(base_port + worker_id)`, where the `worker_id` is sequential IDs
   given to each instance from 0 to `num_envs - 1`. Default is 5005. __Note:__ When
   training using the Editor rather than an executable, the base port will be ignored.
-* `--train`: Specifies whether to train model or only run in inference mode.
-  When training, **always** use the `--train` option.
-* `--load`: If set, the training code loads an already trained model to
+* `--inference`: Specifies whether to only run in inference mode. Omit to train the model.
+  To load an existing model, specify a run-id and combine with `--resume`.
+* `--resume`: If set, the training code loads an already trained model to
   initialize the neural network before training. The learning code looks for the
   model in `models/<run-id>/` (which is also where it saves models at the end of
-  training). When not set (the default), the neural network weights are randomly
-  initialized and an existing model is not loaded.
+  training). This option only works when the models exist and have the same behavior names
+  as the current agents in your scene.
+* `--force`: Attempting to train a model with a run-id that has been used before will
+  throw an error. Use `--force` to force-overwrite this run-id's summary and model data.
 * `--no-graphics`: Specify this option to run the Unity executable in
   `-batchmode` and doesn't initialize the graphics driver. Use this only if your
   training doesn't involve visual observations (reading from Pixels). See
diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py
index b1f93637ad..bf1cc720c7 100644
--- a/ml-agents/mlagents/trainers/learn.py
+++ b/ml-agents/mlagents/trainers/learn.py
@@ -12,7 +12,11 @@ from mlagents import tf_utils
 from mlagents.trainers.trainer_controller import TrainerController
 from mlagents.trainers.meta_curriculum import MetaCurriculum
-from mlagents.trainers.trainer_util import load_config, TrainerFactory
+from mlagents.trainers.trainer_util import (
+    load_config,
+    TrainerFactory,
+    handle_existing_directories,
+)
 from mlagents.trainers.stats import (
     TensorboardWriter,
     CSVWriter,
@@ -68,7 +72,22 @@ def _create_parser():
         default=False,
         dest="load_model",
         action="store_true",
-        help="Whether to load the model or randomly initialize",
+        help=argparse.SUPPRESS,  # Deprecated but still usable for now.
+    )
+    argparser.add_argument(
+        "--resume",
+        default=False,
+        dest="resume",
+        action="store_true",
+        help="Resumes training from a checkpoint. Specify a --run-id to use this option.",
+    )
+    argparser.add_argument(
+        "--force",
+        default=False,
+        dest="force",
+        action="store_true",
+        help="Force-overwrite existing models and summaries for a run-id that has been used "
+        "before.",
     )
     argparser.add_argument(
         "--run-id",
@@ -86,7 +105,15 @@ def _create_parser():
         default=False,
         dest="train_model",
         action="store_true",
-        help="Whether to train model, or only run inference",
+        help=argparse.SUPPRESS,
+    )
+    argparser.add_argument(
+        "--inference",
+        default=False,
+        dest="inference",
+        action="store_true",
+        help="Run in Python inference mode (don't train). Use with --resume to load a model trained with an "
+        "existing run-id.",
     )
     argparser.add_argument(
         "--base-port",
@@ -168,7 +195,10 @@ class RunOptions(NamedTuple):
     env_path: Optional[str] = parser.get_default("env_path")
     run_id: str = parser.get_default("run_id")
     load_model: bool = parser.get_default("load_model")
+    resume: bool = parser.get_default("resume")
+    force: bool = parser.get_default("force")
     train_model: bool = parser.get_default("train_model")
+    inference: bool = parser.get_default("inference")
     save_freq: int = parser.get_default("save_freq")
     keep_checkpoints: int = parser.get_default("keep_checkpoints")
     base_port: int = parser.get_default("base_port")
@@ -205,7 +235,8 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
             argparse_args["sampler_config"] = load_config(
                 argparse_args["sampler_file_path"]
             )
-
+        # Keep deprecated --load working, TODO: remove
+        argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"]
         # Since argparse accepts file paths in the config options which don't exist in CommandLineOptions,
         # these keys will need to be deleted to use the **/splat operator below.
         argparse_args.pop("sampler_file_path")
@@ -249,7 +280,10 @@ def run_training(run_seed: int, options: RunOptions) -> None:
             "Environment/Episode Length",
         ],
     )
-    tb_writer = TensorboardWriter(summaries_dir)
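+    # Fail fast if the directories for this run-id conflict with the --resume /
+    # --force flags, before the rest of the training setup happens.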
+    handle_existing_directories(
+        model_path, summaries_dir, options.resume, options.force
+    )
+    tb_writer = TensorboardWriter(summaries_dir, clear_past_data=not options.resume)
     gauge_write = GaugeWriter()
     console_writer = ConsoleWriter()
     StatsReporter.add_writer(tb_writer)
@@ -282,8 +316,8 @@ def run_training(run_seed: int, options: RunOptions) -> None:
         options.run_id,
         model_path,
         options.keep_checkpoints,
-        options.train_model,
-        options.load_model,
+        not options.inference,
+        options.resume,
         run_seed,
         maybe_meta_curriculum,
         options.multi_gpu,
@@ -296,7 +330,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
         options.run_id,
         options.save_freq,
         maybe_meta_curriculum,
-        options.train_model,
+        not options.inference,
         run_seed,
         sampler_manager,
         resampling_interval,
@@ -424,6 +458,17 @@ def run_cli(options: RunOptions) -> None:
     logger.debug("Configuration for this run:")
     logger.debug(json.dumps(options._asdict(), indent=4))

+    # Options deprecation warnings
+    if options.load_model:
+        logger.warning(
+            "The --load option has been deprecated. Please use the --resume option instead."
+        )
+    if options.train_model:
+        logger.warning(
+            "The --train option has been deprecated. Train mode is now the default. Use "
+            "--inference to run in inference mode."
+        )
+
     run_seed = options.seed
     if options.cpu:
         os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py
index 9d0c02a57a..828b52ff83 100644
--- a/ml-agents/mlagents/trainers/policy/tf_policy.py
+++ b/ml-agents/mlagents/trainers/policy/tf_policy.py
@@ -115,10 +115,11 @@ def _load_graph(self):
         logger.info("Loading Model for brain {}".format(self.brain.brain_name))
         ckpt = tf.train.get_checkpoint_state(self.model_path)
         if ckpt is None:
-            logger.info(
-                "The model {0} could not be found. Make "
+            raise UnityPolicyException(
+                "The model {0} could not be loaded. Make "
                 "sure you specified the right "
-                "--run-id".format(self.model_path)
+                "--run-id and that the previous run you are resuming from had the same "
+                "behavior names.".format(self.model_path)
             )
         self.saver.restore(self.sess, ckpt.model_checkpoint_path)
diff --git a/ml-agents/mlagents/trainers/stats.py b/ml-agents/mlagents/trainers/stats.py
index c13e6c5256..ad5bf4c3e1 100644
--- a/ml-agents/mlagents/trainers/stats.py
+++ b/ml-agents/mlagents/trainers/stats.py
@@ -174,14 +174,17 @@ def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:


 class TensorboardWriter(StatsWriter):
-    def __init__(self, base_dir: str):
+    def __init__(self, base_dir: str, clear_past_data: bool = False):
         """
         A StatsWriter that writes to a Tensorboard summary.
         :param base_dir: The directory within which to place all the summaries. Tensorboard files will be written
             to a {base_dir}/{category} directory.
+        :param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir
+            and category.
         """
         self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
         self.base_dir: str = base_dir
+        self._clear_past_data = clear_past_data

     def write_stats(
         self, category: str, values: Dict[str, StatsSummary], step: int
Deleting.".format(file_name) + ) + full_fname = os.path.join(directory_name, file_name) + try: + os.remove(full_fname) + except OSError: + logger.warning( + "{} was left over from a previous run and " + "not deleted.".format(full_fname) + ) + def add_property( self, category: str, property_type: StatsPropertyType, value: Any ) -> None: diff --git a/ml-agents/mlagents/trainers/tests/test_learn.py b/ml-agents/mlagents/trainers/tests/test_learn.py index fd291de711..969d62ebd0 100644 --- a/ml-agents/mlagents/trainers/tests/test_learn.py +++ b/ml-agents/mlagents/trainers/tests/test_learn.py @@ -15,6 +15,7 @@ def basic_options(extra_args=None): return parse_command_line(args) +@patch("mlagents.trainers.learn.handle_existing_directories") @patch("mlagents.trainers.learn.TrainerFactory") @patch("mlagents.trainers.learn.SamplerManager") @patch("mlagents.trainers.learn.SubprocessEnvManager") @@ -26,6 +27,7 @@ def test_run_training( subproc_env_mock, sampler_manager_mock, trainer_factory_mock, + handle_dir_mock, ): mock_env = MagicMock() mock_env.external_brain_names = [] @@ -45,11 +47,14 @@ def test_run_training( "ppo", 50000, None, - False, + True, 0, sampler_manager_mock.return_value, None, ) + handle_dir_mock.assert_called_once_with( + "./models/ppo", "./summaries", False, False + ) StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py @@ -79,11 +84,11 @@ def test_commandline_args(mock_file): assert opt.sampler_config is None assert opt.keep_checkpoints == 5 assert opt.lesson == 0 - assert opt.load_model is False + assert opt.resume is False + assert opt.inference is False assert opt.run_id == "ppo" assert opt.save_freq == 50000 assert opt.seed == -1 - assert opt.train_model is False assert opt.base_port == 5005 assert opt.num_envs == 1 assert opt.no_graphics is False @@ -97,7 +102,8 @@ def test_commandline_args(mock_file): "--sampler=./mysample", "--keep-checkpoints=42", "--lesson=3", - "--load", + "--resume", + "--inference", "--run-id=myawesomerun", "--save-freq=123456", "--seed=7890", @@ -115,15 +121,15 @@ def test_commandline_args(mock_file): assert opt.sampler_config == {} assert opt.keep_checkpoints == 42 assert opt.lesson == 3 - assert opt.load_model is True assert opt.run_id == "myawesomerun" assert opt.save_freq == 123456 assert opt.seed == 7890 - assert opt.train_model is True assert opt.base_port == 4004 assert opt.num_envs == 2 assert opt.no_graphics is True assert opt.debug is True + assert opt.inference is True + assert opt.resume is True @patch("builtins.open", new_callable=mock_open, read_data="{}") diff --git a/ml-agents/mlagents/trainers/tests/test_stats.py b/ml-agents/mlagents/trainers/tests/test_stats.py index 20c5803a40..632c0abb9c 100644 --- a/ml-agents/mlagents/trainers/tests/test_stats.py +++ b/ml-agents/mlagents/trainers/tests/test_stats.py @@ -4,6 +4,7 @@ import tempfile import unittest import csv +import time from mlagents.trainers.stats import ( StatsReporter, @@ -75,7 +76,7 @@ def test_tensorboard_writer(mock_filewriter, mock_summary): # Test write_stats category = "category1" with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir: - tb_writer = TensorboardWriter(base_dir) + tb_writer = TensorboardWriter(base_dir, clear_past_data=False) statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1) tb_writer.write_stats("category1", {"key1": statssummary1}, 10) @@ -102,6 +103,26 @@ def test_tensorboard_writer(mock_filewriter, mock_summary): assert mock_filewriter.return_value.add_summary.call_count > 1 +def 
+        for file_name in os.listdir(directory_name):
+            if file_name.startswith("events.out"):
+                logger.warning(
+                    "{} was left over from a previous run. Deleting.".format(file_name)
+                )
+                full_fname = os.path.join(directory_name, file_name)
+                try:
+                    os.remove(full_fname)
+                except OSError:
+                    logger.warning(
+                        "{} was left over from a previous run and "
+                        "not deleted.".format(full_fname)
+                    )
+
     def add_property(
         self, category: str, property_type: StatsPropertyType, value: Any
     ) -> None:
diff --git a/ml-agents/mlagents/trainers/tests/test_learn.py b/ml-agents/mlagents/trainers/tests/test_learn.py
index fd291de711..969d62ebd0 100644
--- a/ml-agents/mlagents/trainers/tests/test_learn.py
+++ b/ml-agents/mlagents/trainers/tests/test_learn.py
@@ -15,6 +15,7 @@ def basic_options(extra_args=None):
     return parse_command_line(args)


+@patch("mlagents.trainers.learn.handle_existing_directories")
 @patch("mlagents.trainers.learn.TrainerFactory")
 @patch("mlagents.trainers.learn.SamplerManager")
 @patch("mlagents.trainers.learn.SubprocessEnvManager")
@@ -26,6 +27,7 @@ def test_run_training(
     subproc_env_mock,
     sampler_manager_mock,
     trainer_factory_mock,
+    handle_dir_mock,
 ):
     mock_env = MagicMock()
     mock_env.external_brain_names = []
@@ -45,11 +47,14 @@ def test_run_training(
         "ppo",
         50000,
         None,
-        False,
+        True,
         0,
         sampler_manager_mock.return_value,
         None,
     )
+    handle_dir_mock.assert_called_once_with(
+        "./models/ppo", "./summaries", False, False
+    )
     StatsReporter.writers.clear()  # make sure there aren't any writers as added by learn.py


@@ -79,11 +84,11 @@ def test_commandline_args(mock_file):
     assert opt.sampler_config is None
     assert opt.keep_checkpoints == 5
     assert opt.lesson == 0
-    assert opt.load_model is False
+    assert opt.resume is False
+    assert opt.inference is False
     assert opt.run_id == "ppo"
     assert opt.save_freq == 50000
     assert opt.seed == -1
-    assert opt.train_model is False
     assert opt.base_port == 5005
     assert opt.num_envs == 1
     assert opt.no_graphics is False
@@ -97,7 +102,8 @@ def test_commandline_args(mock_file):
         "--sampler=./mysample",
         "--keep-checkpoints=42",
         "--lesson=3",
-        "--load",
+        "--resume",
+        "--inference",
         "--run-id=myawesomerun",
         "--save-freq=123456",
         "--seed=7890",
@@ -115,15 +121,15 @@ def test_commandline_args(mock_file):
     assert opt.sampler_config == {}
     assert opt.keep_checkpoints == 42
     assert opt.lesson == 3
-    assert opt.load_model is True
     assert opt.run_id == "myawesomerun"
     assert opt.save_freq == 123456
     assert opt.seed == 7890
-    assert opt.train_model is True
     assert opt.base_port == 4004
     assert opt.num_envs == 2
     assert opt.no_graphics is True
     assert opt.debug is True
+    assert opt.inference is True
+    assert opt.resume is True


 @patch("builtins.open", new_callable=mock_open, read_data="{}")
diff --git a/ml-agents/mlagents/trainers/tests/test_stats.py b/ml-agents/mlagents/trainers/tests/test_stats.py
index 20c5803a40..632c0abb9c 100644
--- a/ml-agents/mlagents/trainers/tests/test_stats.py
+++ b/ml-agents/mlagents/trainers/tests/test_stats.py
@@ -4,6 +4,7 @@
 import tempfile
 import unittest
 import csv
+import time

 from mlagents.trainers.stats import (
     StatsReporter,
@@ -75,7 +76,7 @@ def test_tensorboard_writer(mock_filewriter, mock_summary):
     # Test write_stats
     category = "category1"
     with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
-        tb_writer = TensorboardWriter(base_dir)
+        tb_writer = TensorboardWriter(base_dir, clear_past_data=False)
         statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
         tb_writer.write_stats("category1", {"key1": statssummary1}, 10)
@@ -102,6 +103,26 @@ def test_tensorboard_writer(mock_filewriter, mock_summary):
     assert mock_filewriter.return_value.add_summary.call_count > 1


+def test_tensorboard_writer_clear(tmp_path):
+    tb_writer = TensorboardWriter(tmp_path, clear_past_data=False)
+    statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
+    tb_writer.write_stats("category1", {"key1": statssummary1}, 10)
+    # TB has some sort of timeout before making a new file
+    time.sleep(1.0)
+    assert len(os.listdir(os.path.join(tmp_path, "category1"))) > 0
+
+    # See if creating a new one doesn't delete it
+    tb_writer = TensorboardWriter(tmp_path, clear_past_data=False)
+    tb_writer.write_stats("category1", {"key1": statssummary1}, 10)
+    assert len(os.listdir(os.path.join(tmp_path, "category1"))) > 1
+    time.sleep(1.0)
+
+    # See if creating a new one deletes old ones
+    tb_writer = TensorboardWriter(tmp_path, clear_past_data=True)
+    tb_writer.write_stats("category1", {"key1": statssummary1}, 10)
+    assert len(os.listdir(os.path.join(tmp_path, "category1"))) == 1
+
+
 def test_csv_writer():
     # Test write_stats
     category = "category1"
diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_util.py b/ml-agents/mlagents/trainers/tests/test_trainer_util.py
index 2d90284cc8..0ab2bc1775 100644
--- a/ml-agents/mlagents/trainers/tests/test_trainer_util.py
+++ b/ml-agents/mlagents/trainers/tests/test_trainer_util.py
@@ -1,6 +1,7 @@
 import pytest
 import yaml
 import io
+import os
 from unittest.mock import patch

 from mlagents.trainers import trainer_util
@@ -335,3 +336,24 @@ def test_load_config_invalid_yaml():
     with pytest.raises(TrainerConfigError):
         fp = io.StringIO(file_contents)
         _load_config(fp)
+
+
+def test_existing_directories(tmp_path):
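+    # tmp_path is the built-in pytest fixture that supplies a fresh temporary
+    # directory for each test invocation.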
+ """ + + model_path_exists = os.path.isdir(model_path) + + if model_path_exists: + if not resume and not force: + raise UnityTrainerException( + "Previous data from this run-id was found. " + "Either specify a new run-id, use --resume to resume this run, " + "or use the --force parameter to overwrite existing data." + ) + else: + if resume: + raise UnityTrainerException( + "Previous data from this run-id was not found. " + "Train a new run by removing the --resume flag." + )