-
Notifications
You must be signed in to change notification settings - Fork 693
[WIP] 'tune cat' command for pretty printing configuration files #2298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
4a5f014
Add 'tune cat' command for pretty printing configuration files
Ankur-singh 8853bde
Enhance 'tune cat' command to support sorting and improve help descri…
Ankur-singh 20801e3
Update documentation to include 'tune cat' command and its usage
Ankur-singh 8c981a4
Update description and epilog
Ankur-singh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import runpy | ||
| import sys | ||
|
|
||
| import pytest | ||
| from tests.common import TUNE_PATH | ||
|
|
||
|
|
||
| class TestTuneCatCommand: | ||
| """This class tests the `tune cat` command.""" | ||
|
|
||
| def test_cat_valid_config(self, capsys, monkeypatch): | ||
| testargs = "tune cat llama2/7B_full".split() | ||
| monkeypatch.setattr(sys, "argv", testargs) | ||
| runpy.run_path(TUNE_PATH, run_name="__main__") | ||
|
|
||
| captured = capsys.readouterr() | ||
| output = captured.out.rstrip("\n") | ||
|
|
||
| # Check for key sections that should be in the YAML output | ||
| assert "output_dir:" in output | ||
| assert "tokenizer:" in output | ||
| assert "model:" in output | ||
|
|
||
| def test_cat_recipe_name_shows_error(self, capsys, monkeypatch): | ||
| testargs = "tune cat full_finetune_single_device".split() | ||
| monkeypatch.setattr(sys, "argv", testargs) | ||
| runpy.run_path(TUNE_PATH, run_name="__main__") | ||
|
|
||
| captured = capsys.readouterr() | ||
| output = captured.out.rstrip("\n") | ||
|
|
||
| assert "is a recipe, not a config" in output | ||
|
|
||
| def test_cat_non_existent_config(self, capsys, monkeypatch): | ||
| testargs = "tune cat non_existent_config".split() | ||
| monkeypatch.setattr(sys, "argv", testargs) | ||
|
|
||
| with pytest.raises(SystemExit): | ||
| runpy.run_path(TUNE_PATH, run_name="__main__") | ||
|
|
||
| captured = capsys.readouterr() | ||
| err = captured.err.rstrip("\n") | ||
|
|
||
| assert ( | ||
| "Invalid config format: 'non_existent_config'. Must be YAML (.yaml/.yml)" | ||
| in err | ||
| ) | ||
|
|
||
| def test_cat_invalid_yaml_file(self, capsys, monkeypatch, tmpdir): | ||
| invalid_yaml = tmpdir / "invalid.yaml" | ||
| invalid_yaml.write_text("invalid: yaml: file", encoding="utf-8") | ||
|
|
||
| testargs = f"tune cat {invalid_yaml}".split() | ||
| monkeypatch.setattr(sys, "argv", testargs) | ||
|
|
||
| with pytest.raises(SystemExit): | ||
| runpy.run_path(TUNE_PATH, run_name="__main__") | ||
|
|
||
| captured = capsys.readouterr() | ||
| err = captured.err.rstrip("\n") | ||
|
|
||
| assert "Error parsing YAML file" in err | ||
|
|
||
| def test_cat_external_yaml_file(self, capsys, monkeypatch, tmpdir): | ||
| valid_yaml = tmpdir / "external.yaml" | ||
| valid_yaml.write_text("key: value", encoding="utf-8") | ||
|
|
||
| testargs = f"tune cat {valid_yaml}".split() | ||
| monkeypatch.setattr(sys, "argv", testargs) | ||
| runpy.run_path(TUNE_PATH, run_name="__main__") | ||
|
|
||
| captured = capsys.readouterr() | ||
| output = captured.out.rstrip("\n") | ||
|
|
||
| assert "key: value" in output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import argparse | ||
| import textwrap | ||
|
|
||
| from pathlib import Path | ||
| from typing import List, Optional | ||
|
|
||
| import yaml | ||
| from torchtune._cli.subcommand import Subcommand | ||
| from torchtune._recipe_registry import Config, get_all_recipes | ||
|
|
||
| ROOT = Path(__file__).parent.parent.parent | ||
|
|
||
|
|
||
| class Cat(Subcommand): | ||
| """Holds all the logic for the `tune cat` subcommand.""" | ||
|
|
||
| def __init__(self, subparsers: argparse._SubParsersAction): | ||
| super().__init__() | ||
| self._parser = subparsers.add_parser( | ||
| "cat", | ||
| prog="tune cat", | ||
| help="Pretty print a config, making it easy to know which parameters you can override with `tune run`.", | ||
| description="Pretty print a config, making it easy to know which parameters you can override with `tune run`.", | ||
| epilog=textwrap.dedent( | ||
| """\ | ||
| examples: | ||
| $ tune cat llama2/7B_full | ||
| output_dir: /tmp/torchtune/llama2_7B/full | ||
| tokenizer: | ||
| _component_: torchtune.models.llama2.llama2_tokenizer | ||
| path: /tmp/Llama-2-7b-hf/tokenizer.model | ||
| max_seq_len: null | ||
| ... | ||
|
|
||
| # Pretty print the config in sorted order | ||
| $ tune cat llama2/7B_full --sort | ||
|
|
||
| # Pretty print the contents of LOCALFILE.yaml | ||
| $ tune cat LOCALFILE.yaml | ||
|
|
||
| You can now easily override a key based on your findings from `tune cat`: | ||
| $ tune run full_finetune_distributed --config llama2/7B_full output_dir=./ | ||
|
|
||
| Need to find all the "cat"-able configs? Try `tune ls`! | ||
| """ | ||
| ), | ||
| formatter_class=argparse.RawTextHelpFormatter, | ||
| ) | ||
| self._parser.add_argument( | ||
| "config_name", type=str, help="Name of the config to print" | ||
| ) | ||
| self._parser.set_defaults(func=self._cat_cmd) | ||
| self._parser.add_argument( | ||
| "--sort", action="store_true", help="Print the config in sorted order" | ||
| ) | ||
|
|
||
| def _get_all_recipes(self) -> List[str]: | ||
| return [recipe.name for recipe in get_all_recipes()] | ||
|
|
||
| def _get_config(self, config_str: str) -> Optional[Config]: | ||
| # Search through all recipes | ||
| for recipe in get_all_recipes(): | ||
| for config in recipe.configs: | ||
| if config.name == config_str: | ||
| return config | ||
|
|
||
| def _print_yaml_file(self, file: str, sort_keys: bool) -> None: | ||
| try: | ||
| with open(file, "r") as f: | ||
| data = yaml.safe_load(f) | ||
| if data: | ||
| print( | ||
| yaml.dump( | ||
| data, | ||
| default_flow_style=False, | ||
| sort_keys=sort_keys, | ||
| indent=4, | ||
| width=80, | ||
| allow_unicode=True, | ||
| ), | ||
| end="", | ||
| ) | ||
| except yaml.YAMLError as e: | ||
| self._parser.error(f"Error parsing YAML file: {e}") | ||
|
|
||
| def _cat_cmd(self, args: argparse.Namespace) -> None: | ||
| """Display the contents of a configuration file. | ||
|
|
||
| Handles both predefined configurations and direct file paths, ensuring: | ||
| - Input is not a recipe name | ||
| - File exists | ||
| - File is YAML format | ||
|
|
||
| Args: | ||
| args (argparse.Namespace): Command-line arguments containing 'config_name' attribute | ||
| """ | ||
| config_str = args.config_name | ||
|
|
||
| # Immediately handle recipe name case | ||
| if config_str in self._get_all_recipes(): | ||
| print( | ||
| f"'{config_str}' is a recipe, not a config. Please use a config name." | ||
| ) | ||
| return | ||
|
|
||
| # Resolve config path | ||
| config = self._get_config(config_str) | ||
| if config: | ||
| config_path = ROOT / "recipes" / "configs" / config.file_path | ||
| else: | ||
| config_path = Path(config_str) | ||
| if config_path.suffix.lower() not in {".yaml", ".yml"}: | ||
| self._parser.error( | ||
| f"Invalid config format: '{config_path}'. Must be YAML (.yaml/.yml)" | ||
| ) | ||
| return | ||
|
|
||
| if not config_path.exists(): | ||
| self._parser.error(f"Config '{config_str}' not found.") | ||
| return | ||
|
|
||
| self._print_yaml_file(str(config_path), args.sort) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add an example of
tune cat LOCALFILE.yaml