
Commit 0001f22

Merge branch 'pytorch:main' into patch-1
2 parents: 9c53135 + afd5bc4

File tree

16 files changed: +847 -44 lines


.github/workflows/build_docs.yaml

Lines changed: 6 additions & 6 deletions

@@ -28,7 +28,7 @@ jobs:
       python-version: ['3.11']
     steps:
       - name: Check out repo
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Setup conda env
         uses: conda-incubator/setup-miniconda@v2
         with:
@@ -50,7 +50,7 @@ jobs:
         run: |
           cd docs
           make html
-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
        with:
          name: Built-Docs
          path: docs/build/html/
@@ -61,9 +61,9 @@ jobs:
     if: ${{ github.repository_owner == 'pytorch' && github.event_name == 'pull_request' }}
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Download artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
        with:
          name: Built-Docs
          path: docs
@@ -87,12 +87,12 @@ jobs:
     environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docs-push' || '' }}
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
        with:
          ref: gh-pages
          persist-credentials: false
       - name: Download artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
        with:
          name: Built-Docs
          path: docs

docs/source/tune_cli.rst

Lines changed: 71 additions & 2 deletions

@@ -17,15 +17,15 @@ with a short description of each.
 .. code-block:: bash

   $ tune --help
-  usage: tune [-h] {download,ls,cp,run,validate} ...
+  usage: tune [-h] {download,ls,cp,run,validate,cat} ...

   Welcome to the torchtune CLI!

   options:
     -h, --help            show this help message and exit

   subcommands:
-    {download,ls,cp,run,validate}
+    {download,ls,cp,run,validate,cat}
       download            Download a model from the Hugging Face Hub.
       ls                  List all built-in recipes and configs
   ...
@@ -233,3 +233,72 @@ The ``tune validate <config>`` command will validate that your config is formatt
     # If you've copied over a built-in config and want to validate custom changes
     $ tune validate my_configs/llama3/8B_full.yaml
     Config is well-formed!
+
+.. _tune_cat_cli_label:
+
+Inspect a config
+----------------
+
+The ``tune cat <config>`` command pretty prints a configuration file, making it easy to use ``tune run`` with confidence. This command is useful for inspecting the structure and contents of a config file before running a recipe, ensuring that all parameters are correctly set.
+
+You can also use the ``--sort`` option to print the config in sorted order, which can help in quickly locating specific keys.
+
+.. list-table::
+   :widths: 30 60
+
+   * - \--sort
+     - Print the config in sorted order.
+
+**Workflow Example**
+
+1. **List all available configs:**
+
+   Use the ``tune ls`` command to list all the built-in recipes and configs within torchtune.
+
+   .. code-block:: bash
+
+      $ tune ls
+      RECIPE                         CONFIG
+      full_finetune_single_device    llama2/7B_full_low_memory
+                                     code_llama2/7B_full_low_memory
+                                     llama3/8B_full_single_device
+                                     mistral/7B_full_low_memory
+                                     phi3/mini_full_low_memory
+      full_finetune_distributed      llama2/7B_full
+                                     llama2/13B_full
+                                     llama3/8B_full
+                                     llama3/70B_full
+      ...
+
+2. **Inspect the contents of a config:**
+
+   Use the ``tune cat`` command to pretty print the contents of a specific config. This helps you understand the structure and parameters of the config.
+
+   .. code-block:: bash
+
+      $ tune cat llama2/7B_full
+      output_dir: /tmp/torchtune/llama2_7B/full
+      tokenizer:
+        _component_: torchtune.models.llama2.llama2_tokenizer
+        path: /tmp/Llama-2-7b-hf/tokenizer.model
+        max_seq_len: null
+      ...
+
+   You can also print the config in sorted order:
+
+   .. code-block:: bash
+
+      $ tune cat llama2/7B_full --sort
+
+3. **Run a recipe with parameter override:**
+
+   After inspecting the config, you can use the ``tune run`` command to run a recipe with the config. You can also override specific parameters directly from the command line. For example, to override the ``output_dir`` parameter:
+
+   .. code-block:: bash
+
+      $ tune run full_finetune_distributed --config llama2/7B_full output_dir=./
+
+   Learn more about config overrides :ref:`here <cli_override>`.
+
+.. note::
+    You can find all the cat-able configs via the ``tune ls`` command.
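The command's implementation is not part of this hunk, but the behavior documented above (load a YAML config, pretty print it, optionally with sorted keys) is simple to sketch. The following is an illustration only, assuming PyYAML, and is not torchtune's actual code:

# sketch_cat.py -- minimal sketch of the documented `tune cat` behavior.
# Illustration only; NOT torchtune's implementation.
import argparse

import yaml  # PyYAML, assumed installed

parser = argparse.ArgumentParser(prog="sketch_cat")
parser.add_argument("config", help="path to a YAML config file")
parser.add_argument("--sort", action="store_true", help="print keys in sorted order")
args = parser.parse_args()

with open(args.config) as f:
    cfg = yaml.safe_load(f)

# default_flow_style=False emits block-style (pretty) YAML;
# sort_keys mirrors the --sort flag documented above.
print(yaml.dump(cfg, default_flow_style=False, sort_keys=args.sort), end="")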

recipes/configs/code_llama2/evaluation.yaml

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,8 @@
 # To launch, run the following command:
 # tune run eleuther_eval --config code_llama2/evaluation

+output_dir: ./ # Not needed
+
 # Model arguments
 model:
   _component_: torchtune.models.code_llama2.code_llama2_7b

recipes/configs/llama3_2/evaluation.yaml

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,8 @@
 # To launch, run the following command:
 # tune run eleuther_eval --config llama3_2/evaluation

+output_dir: ./ # Not needed
+
 # Model Arguments
 model:
   _component_: torchtune.models.llama3_2.llama3_2_3b

recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml

Lines changed: 7 additions & 5 deletions

@@ -1,20 +1,22 @@
-# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py
+# Config for multi-device LoRA finetuning in lora_finetune_distributed_multi_dataset.py
 # using a Llama3.2 11B Vision Instruct model
 #
 # This config assumes that you've run the following command before launching:
 # tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct --ignore-patterns "original/consolidated*"
 #
 # To launch on 2 devices, run the following command from root:
-# tune run --nproc_per_node 2 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td
+# tune run --nproc_per_node 2 lora_finetune_distributed_multi_dataset --config llama3_2_vision/11B_lora_multi_dataset
 #
 # You can add specific overrides through the command line. For example
 # to override the checkpointer directory while launching training:
-# tune run --nproc_per_node 2 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+# tune run --nproc_per_node 2 lora_finetune_distributed_multi_dataset --config llama3_2_vision/11B_lora_multi_dataset checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
 #
 # This config works best when the model is being fine-tuned on 2+ GPUs.
 # For single device LoRA finetuning please use 11B_lora_single_device.yaml
 # or 11B_qlora_single_device.yaml

+output_dir: /tmp/torchtune/llama3_2_vision_11B/lora_multi_dataset # /tmp may be deleted by your system. Change it to your preference.
+
 # Model arguments
 model:
   _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b
@@ -44,7 +46,7 @@ checkpointer:
   filename_format: model-{}-of-{}.safetensors
   max_filename: "00005"
   recipe_checkpoint: null
-  output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+  output_dir: ${output_dir}
   model_type: LLAMA3_VISION
 resume_from_checkpoint: False
 save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -117,6 +119,6 @@ dtype: bf16
 output_dir: /tmp/lora-llama3.2-vision-finetune
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
-  log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+  log_dir: ${output_dir}/logs
 log_every_n_steps: 1
 log_peak_memory_stats: True
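The ${output_dir} values introduced above are OmegaConf-style interpolations (torchtune parses its configs with OmegaConf), so the checkpointer and logger paths follow the top-level output_dir automatically. A minimal sketch of how the resolution behaves, assuming OmegaConf is installed:

# Sketch of the ${output_dir} interpolation used in the config above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "output_dir": "/tmp/torchtune/llama3_2_vision_11B/lora_multi_dataset",
        "checkpointer": {"output_dir": "${output_dir}"},
        "metric_logger": {"log_dir": "${output_dir}/logs"},
    }
)

# Interpolations resolve against the top-level key on access, so overriding
# output_dir once (e.g. from the command line) moves every derived path with it.
print(cfg.checkpointer.output_dir)  # /tmp/torchtune/llama3_2_vision_11B/lora_multi_dataset
print(cfg.metric_logger.log_dir)    # /tmp/torchtune/llama3_2_vision_11B/lora_multi_dataset/logs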

tests/torchtune/_cli/test_cat.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import runpy
+import sys
+
+import pytest
+from tests.common import TUNE_PATH
+
+
+class TestTuneCatCommand:
+    """This class tests the `tune cat` command."""
+
+    def test_cat_valid_config(self, capsys, monkeypatch):
+        testargs = "tune cat llama2/7B_full".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        output = captured.out.rstrip("\n")
+
+        # Check for key sections that should be in the YAML output
+        assert "output_dir:" in output
+        assert "tokenizer:" in output
+        assert "model:" in output
+
+    def test_cat_recipe_name_shows_error(self, capsys, monkeypatch):
+        testargs = "tune cat full_finetune_single_device".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        output = captured.out.rstrip("\n")
+
+        assert "is a recipe, not a config" in output
+
+    def test_cat_non_existent_config(self, capsys, monkeypatch):
+        testargs = "tune cat non_existent_config".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+
+        with pytest.raises(SystemExit):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        err = captured.err.rstrip("\n")
+
+        assert (
+            "Invalid config format: 'non_existent_config'. Must be YAML (.yaml/.yml)"
+            in err
+        )
+
+    def test_cat_invalid_yaml_file(self, capsys, monkeypatch, tmpdir):
+        invalid_yaml = tmpdir / "invalid.yaml"
+        invalid_yaml.write_text("invalid: yaml: file", encoding="utf-8")
+
+        testargs = f"tune cat {invalid_yaml}".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+
+        with pytest.raises(SystemExit):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        err = captured.err.rstrip("\n")
+
+        assert "Error parsing YAML file" in err
+
+    def test_cat_external_yaml_file(self, capsys, monkeypatch, tmpdir):
+        valid_yaml = tmpdir / "external.yaml"
+        valid_yaml.write_text("key: value", encoding="utf-8")
+
+        testargs = f"tune cat {valid_yaml}".split()
+        monkeypatch.setattr(sys, "argv", testargs)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        captured = capsys.readouterr()
+        output = captured.out.rstrip("\n")
+
+        assert "key: value" in output
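These tests drive the CLI in-process: monkeypatch swaps in a fake sys.argv, runpy.run_path executes the tune entrypoint as __main__, and capsys captures the resulting stdout/stderr, so no subprocess is spawned. To run just this module, pytest can also be invoked programmatically; a small sketch:

# Equivalent to running `pytest -q tests/torchtune/_cli/test_cat.py`.
import pytest

exit_code = pytest.main(["-q", "tests/torchtune/_cli/test_cat.py"])
print("ok" if exit_code == 0 else f"pytest exited with code {exit_code}")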
