Skip to content

Commit 03f7e79

Browse files
author
Ervin T
authored
[refactor] Add --tensorflow, enable Torch as default setting (#4582)
* Add --tensorflow option * Switch framework to Pytorch default * Update changelog * Re-add --torch * Edit warning
1 parent fe0cfbf commit 03f7e79

File tree

5 files changed

+33
-4
lines changed

5 files changed

+33
-4
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ and this project adheres to
1212
### Major Changes
1313
#### com.unity.ml-agents (C#)
1414
#### ml-agents / ml-agents-envs / gym-unity (Python)
15+
- PyTorch trainers are now the default. See the
16+
[installation docs](https://github.com/Unity-Technologies/ml-agents/blob/mastere/docs/Installation.md) for
17+
more information on installing PyTorch. For the time being, TensorFlow is still available;
18+
you can use the TensorFlow backend by adding `--tensorflow` to the CLI, or
19+
adding `framework: tensorflow` in the configuration YAML. (#4517)
1520

1621
### Minor Changes
1722
#### com.unity.ml-agents (C#)

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,15 @@ def _create_parser() -> argparse.ArgumentParser:
172172
"--torch",
173173
default=False,
174174
action=DetectDefaultStoreTrue,
175-
help="(Experimental) Use the PyTorch framework instead of TensorFlow. Install PyTorch "
176-
"before using this option",
175+
help="Use the PyTorch framework. Note that this option is not required anymore as PyTorch is the"
176+
"default framework, and will be removed in the next release.",
177+
)
178+
argparser.add_argument(
179+
"--tensorflow",
180+
default=False,
181+
action=DetectDefaultStoreTrue,
182+
help="(Deprecated) Use the TensorFlow framework instead of PyTorch. Install TensorFlow "
183+
"before using this option.",
177184
)
178185

179186
eng_conf = argparser.add_argument_group(title="Engine Configuration")

ml-agents/mlagents/trainers/learn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
136136
init_path=maybe_init_path,
137137
multi_gpu=False,
138138
force_torch="torch" in DetectDefault.non_default_args,
139+
force_tensorflow="tensorflow" in DetectDefault.non_default_args,
139140
)
140141
# Create controller and begin training.
141142
tc = TrainerController(

ml-agents/mlagents/trainers/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def _set_default_hyperparameters(self):
620620
threaded: bool = True
621621
self_play: Optional[SelfPlaySettings] = None
622622
behavioral_cloning: Optional[BehavioralCloningSettings] = None
623-
framework: FrameworkType = FrameworkType.TENSORFLOW
623+
framework: FrameworkType = FrameworkType.PYTORCH
624624

625625
cattr.register_structure_hook(
626626
Dict[RewardSignalType, RewardSignalSettings], RewardSignalSettings.structure

ml-agents/mlagents/trainers/trainer/trainer_factory.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
init_path: str = None,
2828
multi_gpu: bool = False,
2929
force_torch: bool = False,
30+
force_tensorflow: bool = False,
3031
):
3132
"""
3233
The TrainerFactory generates the Trainers based on the configuration passed as
@@ -45,7 +46,9 @@ def __init__(
4546
:param init_path: Path from which to load model.
4647
:param multi_gpu: If True, multi-gpu will be used. (currently not available)
4748
:param force_torch: If True, the Trainers will all use the PyTorch framework
48-
instead of the TensorFlow framework.
49+
instead of what is specified in the config YAML.
50+
:param force_tensorflow: If True, thee Trainers will all use the TensorFlow
51+
framework.
4952
"""
5053
self.trainer_config = trainer_config
5154
self.output_path = output_path
@@ -57,6 +60,7 @@ def __init__(
5760
self.multi_gpu = multi_gpu
5861
self.ghost_controller = GhostController()
5962
self._force_torch = force_torch
63+
self._force_tf = force_tensorflow
6064

6165
def generate(self, behavior_name: str) -> Trainer:
6266
if behavior_name not in self.trainer_config.keys():
@@ -67,6 +71,18 @@ def generate(self, behavior_name: str) -> Trainer:
6771
trainer_settings = self.trainer_config[behavior_name]
6872
if self._force_torch:
6973
trainer_settings.framework = FrameworkType.PYTORCH
74+
logger.warning(
75+
"Note that specifying --torch is not required anymore as PyTorch is the default framework."
76+
)
77+
if self._force_tf:
78+
trainer_settings.framework = FrameworkType.TENSORFLOW
79+
logger.warning(
80+
"Setting the framework to TensorFlow. TensorFlow trainers will be deprecated in the future."
81+
)
82+
if self._force_torch:
83+
logger.warning(
84+
"Both --torch and --tensorflow CLI options were specified. Using TensorFlow."
85+
)
7086
return TrainerFactory._initialize_trainer(
7187
trainer_settings,
7288
behavior_name,

0 commit comments

Comments
 (0)