Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions docs/source/tune_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ with a short description of each.
.. code-block:: bash

$ tune --help
usage: tune [-h] {download,ls,cp,run,validate} ...
usage: tune [-h] {download,ls,cp,run,validate,cat} ...

Welcome to the torchtune CLI!

options:
-h, --help show this help message and exit

subcommands:
{download,ls,cp,run,validate}
{download,ls,cp,run,validate,cat}
download Download a model from the Hugging Face Hub.
ls List all built-in recipes and configs
...
Expand Down Expand Up @@ -233,3 +233,72 @@ The ``tune validate <config>`` command will validate that your config is formatt
# If you've copied over a built-in config and want to validate custom changes
$ tune validate my_configs/llama3/8B_full.yaml
Config is well-formed!

.. _tune_cat_cli_label:

Inspect a config
---------------------

The ``tune cat <config>`` command pretty prints a configuration file, making it easy to use ``tune run`` with confidence. This command is useful for inspecting the structure and contents of a config file before running a recipe, ensuring that all parameters are correctly set.

You can also use the ``--sort`` option to print the config in sorted order, which can help in quickly locating specific keys.

.. list-table::
:widths: 30 60

* - \--sort
- Print the config in sorted order.

**Workflow Example**

1. **List all available configs:**

Use the ``tune ls`` command to list all the built-in recipes and configs within torchtune.

.. code-block:: bash

$ tune ls
RECIPE CONFIG
full_finetune_single_device llama2/7B_full_low_memory
code_llama2/7B_full_low_memory
llama3/8B_full_single_device
mistral/7B_full_low_memory
phi3/mini_full_low_memory
full_finetune_distributed llama2/7B_full
llama2/13B_full
llama3/8B_full
llama3/70B_full
...

2. **Inspect the contents of a config:**

Use the ``tune cat`` command to pretty print the contents of a specific config. This helps you understand the structure and parameters of the config.

.. code-block:: bash

$ tune cat llama2/7B_full
output_dir: /tmp/torchtune/llama2_7B/full
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model
max_seq_len: null
...

You can also print the config in sorted order:

.. code-block:: bash

$ tune cat llama2/7B_full --sort

3. **Run a recipe with parameter override:**

After inspecting the config, you can use the ``tune run`` command to run a recipe with the config. You can also override specific parameters directly from the command line. For example, to override the ``output_dir`` parameter:

.. code-block:: bash

$ tune run full_finetune_distributed --config llama2/7B_full output_dir=./

Learn more about config overrides :ref:`here <cli_override>`.

.. note::
You can find all the cat-able configs via the ``tune ls`` command.
81 changes: 81 additions & 0 deletions tests/torchtune/_cli/test_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import runpy
import sys

import pytest
from tests.common import TUNE_PATH


class TestTuneCatCommand:
    """This class tests the `tune cat` command."""

    def test_cat_valid_config(self, capsys, monkeypatch):
        """A built-in config name prints its YAML contents to stdout."""
        testargs = "tune cat llama2/7B_full".split()
        monkeypatch.setattr(sys, "argv", testargs)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        output = captured.out.rstrip("\n")

        # Check for key sections that should be in the YAML output
        assert "output_dir:" in output
        assert "tokenizer:" in output
        assert "model:" in output

    def test_cat_recipe_name_shows_error(self, capsys, monkeypatch):
        """Passing a recipe name prints a hint on stdout and does not exit."""
        testargs = "tune cat full_finetune_single_device".split()
        monkeypatch.setattr(sys, "argv", testargs)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        output = captured.out.rstrip("\n")

        assert "is a recipe, not a config" in output

    def test_cat_non_existent_config(self, capsys, monkeypatch):
        """An unknown name without a YAML suffix is rejected via parser.error."""
        testargs = "tune cat non_existent_config".split()
        monkeypatch.setattr(sys, "argv", testargs)

        with pytest.raises(SystemExit):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        err = captured.err.rstrip("\n")

        assert (
            "Invalid config format: 'non_existent_config'. Must be YAML (.yaml/.yml)"
            in err
        )

    def test_cat_invalid_yaml_file(self, capsys, monkeypatch, tmpdir):
        """A file that is not parseable YAML is reported on stderr."""
        invalid_yaml = tmpdir / "invalid.yaml"
        invalid_yaml.write_text("invalid: yaml: file", encoding="utf-8")

        # Build argv as a list: splitting a formatted string would break the
        # path apart if the temporary directory contains whitespace.
        testargs = ["tune", "cat", str(invalid_yaml)]
        monkeypatch.setattr(sys, "argv", testargs)

        with pytest.raises(SystemExit):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        err = captured.err.rstrip("\n")

        assert "Error parsing YAML file" in err

    def test_cat_external_yaml_file(self, capsys, monkeypatch, tmpdir):
        """A user-provided YAML file outside the registry is printed as-is."""
        valid_yaml = tmpdir / "external.yaml"
        valid_yaml.write_text("key: value", encoding="utf-8")

        # Build argv as a list so a temp path with whitespace stays one argument.
        testargs = ["tune", "cat", str(valid_yaml)]
        monkeypatch.setattr(sys, "argv", testargs)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        output = captured.out.rstrip("\n")

        assert "key: value" in output
128 changes: 128 additions & 0 deletions torchtune/_cli/cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import textwrap

from pathlib import Path
from typing import List, Optional

import yaml
from torchtune._cli.subcommand import Subcommand
from torchtune._recipe_registry import Config, get_all_recipes

ROOT = Path(__file__).parent.parent.parent


class Cat(Subcommand):
    """Holds all the logic for the `tune cat` subcommand."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Register the ``cat`` sub-parser and its arguments.

        Args:
            subparsers (argparse._SubParsersAction): The sub-parser collection
                of the top-level ``tune`` parser to which ``cat`` is added.
        """
        super().__init__()
        self._parser = subparsers.add_parser(
            "cat",
            prog="tune cat",
            help="Pretty print a config, making it easy to know which parameters you can override with `tune run`.",
            description="Pretty print a config, making it easy to know which parameters you can override with `tune run`.",
            epilog=textwrap.dedent(
                """\
                examples:
                    $ tune cat llama2/7B_full
                    output_dir: /tmp/torchtune/llama2_7B/full
                    tokenizer:
                        _component_: torchtune.models.llama2.llama2_tokenizer
                        path: /tmp/Llama-2-7b-hf/tokenizer.model
                        max_seq_len: null
                    ...

                    # Pretty print the config in sorted order
                    $ tune cat llama2/7B_full --sort

                    # Pretty print the contents of LOCALFILE.yaml
                    $ tune cat LOCALFILE.yaml

                You can now easily override a key based on your findings from `tune cat`:
                    $ tune run full_finetune_distributed --config llama2/7B_full output_dir=./

                Need to find all the "cat"-able configs? Try `tune ls`!
                """
            ),
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._parser.add_argument(
            "config_name", type=str, help="Name of the config to print"
        )
        self._parser.add_argument(
            "--sort", action="store_true", help="Print the config in sorted order"
        )
        self._parser.set_defaults(func=self._cat_cmd)

    def _get_all_recipes(self) -> List[str]:
        """Return the names of all registered recipes."""
        return [recipe.name for recipe in get_all_recipes()]

    def _get_config(self, config_str: str) -> Optional[Config]:
        """Look up a registered config by name.

        Args:
            config_str (str): Config name to search for.

        Returns:
            Optional[Config]: The first config whose ``name`` equals
            ``config_str`` across all recipes, or ``None`` if there is none.
        """
        # Search through all recipes
        for recipe in get_all_recipes():
            for config in recipe.configs:
                if config.name == config_str:
                    return config
        return None

    def _print_yaml_file(self, file: str, sort_keys: bool) -> None:
        """Load a YAML file and print it back out in block style.

        Nothing is printed when the loaded document is empty/falsy. Read and
        parse failures are reported through the sub-parser's ``error`` method.

        Args:
            file (str): Path of the YAML file to print.
            sort_keys (bool): Whether to emit mapping keys in sorted order.
        """
        try:
            # Read as UTF-8 explicitly so the result does not depend on the
            # platform's default locale encoding.
            with open(file, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            self._parser.error(f"Error parsing YAML file: {e}")
        except (OSError, UnicodeDecodeError) as e:
            # e.g. a directory named "*.yaml", a permission problem, or a file
            # that is not valid UTF-8: report it instead of dumping a traceback.
            self._parser.error(f"Error reading config file '{file}': {e}")
        else:
            if data:
                print(
                    yaml.dump(
                        data,
                        default_flow_style=False,
                        sort_keys=sort_keys,
                        indent=4,
                        width=80,
                        allow_unicode=True,
                    ),
                    end="",
                )

    def _cat_cmd(self, args: argparse.Namespace) -> None:
        """Display the contents of a configuration file.

        Handles both predefined configurations and direct file paths, ensuring:
        - Input is not a recipe name
        - File exists
        - File is YAML format

        A recipe name only prints a hint to stdout and returns normally; the
        other failures go through the sub-parser's ``error`` method.

        Args:
            args (argparse.Namespace): Command-line arguments containing the
                'config_name' and 'sort' attributes
        """
        config_str = args.config_name

        # Immediately handle recipe name case
        if config_str in self._get_all_recipes():
            print(
                f"'{config_str}' is a recipe, not a config. Please use a config name."
            )
            return

        # Resolve config path
        config = self._get_config(config_str)
        if config:
            config_path = ROOT / "recipes" / "configs" / config.file_path
        else:
            config_path = Path(config_str)
            # The suffix is validated before existence, so a bare unknown name
            # is reported as an invalid format rather than a missing file.
            if config_path.suffix.lower() not in {".yaml", ".yml"}:
                self._parser.error(
                    f"Invalid config format: '{config_path}'. Must be YAML (.yaml/.yml)"
                )
                # argparse's error() exits; the return is purely defensive.
                return

        if not config_path.exists():
            self._parser.error(f"Config '{config_str}' not found.")
            return

        self._print_yaml_file(str(config_path), args.sort)
3 changes: 3 additions & 0 deletions torchtune/_cli/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import argparse

from torchtune._cli.cat import Cat

from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List
Expand Down Expand Up @@ -33,6 +35,7 @@ def __init__(self):
Copy.create(subparsers)
Run.create(subparsers)
Validate.create(subparsers)
Cat.create(subparsers)

def parse_args(self) -> argparse.Namespace:
"""Parse CLI arguments"""
Expand Down