
Commit 590d883

Felipe Mello (felipemello1) authored and committed

output_dir not in ckpt dir (meta-pytorch#2181)

Co-authored-by: Felipe Mello <[email protected]>

1 parent 608fbee

5 files changed, +126 -28 lines changed

tests/recipes/test_ppo_full_finetune_single_device.py

Lines changed: 2 additions & 2 deletions

@@ -358,7 +358,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc
     --config mistral/7B_full_ppo_low_memory \
     output_dir={tmpdir} \
     checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
-    checkpointer.checkpoint_dir='{policy_tmpdir}' \
+    checkpointer.checkpoint_dir='{ckpt_dir}' \
     checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
     checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
     checkpointer.output_dir={policy_tmpdir} \
@@ -367,7 +367,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc
     ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \
     ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\

-    value_checkpointer.checkpoint_dir='{value_tmpdir}' \
+    value_checkpointer.checkpoint_dir='{ckpt_dir}' \
     value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, epoch_folder_minus_one, model_ckpt_fname)}]\
     value_checkpointer.output_dir={value_tmpdir} \


tests/torchtune/training/checkpointing/test_checkpointer.py

Lines changed: 47 additions & 24 deletions

@@ -152,8 +152,11 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
         * embed_dim: 64
         * max_seq_len: 128
         """
-        checkpoint_file_1 = tmp_path / "llama2_hf_checkpoint_01.pt"
-        checkpoint_file_2 = tmp_path / "llama2_hf_checkpoint_02.pt"
+        checkpoint_dir = Path.joinpath(tmp_path, "checkpoint_dir")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_file_1 = checkpoint_dir / "llama2_hf_checkpoint_01.pt"
+        checkpoint_file_2 = checkpoint_dir / "llama2_hf_checkpoint_02.pt"

         torch.save(state_dict_1, checkpoint_file_1)
         torch.save(state_dict_2, checkpoint_file_2)
@@ -163,7 +166,7 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
             "num_attention_heads": 4,
             "num_key_value_heads": 4,
         }
-        config_file = Path.joinpath(tmp_path, "config.json")
+        config_file = Path.joinpath(checkpoint_dir, "config.json")
         with config_file.open("w") as f:
             json.dump(config, f)

@@ -174,23 +177,27 @@ def single_file_checkpointer(
         self, llama2_hf_checkpoints, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file, _ = llama2_hf_checkpoints
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file],
             model_type="LLAMA2",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     @pytest.fixture
     def multi_file_checkpointer(
         self, llama2_hf_checkpoints, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file_1, checkpoint_file_2],
             model_type="LLAMA2",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     def test_load_save_checkpoint_single_file(
@@ -242,7 +249,7 @@ def test_load_save_checkpoint_single_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file = Path.joinpath(
-            checkpoint_file.parent,
+            checkpoint_file.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         ).with_suffix(".safetensors")
@@ -306,12 +313,12 @@ def test_save_load_checkpoint_multiple_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file_1 = Path.joinpath(
-            checkpoint_file_1.parent,
+            checkpoint_file_1.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="2".zfill(5)),
         ).with_suffix(".safetensors")
         output_file_2 = Path.joinpath(
-            checkpoint_file_2.parent,
+            checkpoint_file_2.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="2".zfill(5), num_shards="2".zfill(5)),
         ).with_suffix(".safetensors")
@@ -338,12 +345,14 @@ def test_load_save_adapter_only(
         single_file_checkpointer.save_checkpoint(state_dict, epoch=2, adapter_only=True)

         output_file_1 = Path.joinpath(
-            tmp_path,
+            tmp_path / "output_dir",
             "epoch_2",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         )
         output_file_2 = Path.joinpath(
-            tmp_path, "epoch_2", f"{ADAPTER_MODEL_FNAME}.safetensors"
+            tmp_path / "output_dir",
+            "epoch_2",
+            f"{ADAPTER_MODEL_FNAME}.safetensors",
         )

         with pytest.raises(ValueError, match="Unable to load checkpoint from"):
@@ -437,12 +446,16 @@ def test_save_checkpoint_in_peft_format(

         # Load saved adapter weights and config from file for comparison
         adapter_weights_file = Path.joinpath(
-            checkpoint_file.parent, "epoch_1", f"{ADAPTER_MODEL_FNAME}.safetensors"
+            checkpoint_file.parent.parent / "output_dir",
+            "epoch_1",
+            f"{ADAPTER_MODEL_FNAME}.safetensors",
         )
         actual_adapter_state_dict = safe_torch_load(adapter_weights_file)

         adapter_config_file = Path.joinpath(
-            checkpoint_file.parent, "epoch_1", f"{ADAPTER_CONFIG_FNAME}.json"
+            checkpoint_file.parent.parent / "output_dir",
+            "epoch_1",
+            f"{ADAPTER_CONFIG_FNAME}.json",
         )
         with open(adapter_config_file, "r") as f:
             adapter_config = json.load(f)
@@ -558,7 +571,10 @@ def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
         * intermediate_dim: 256

         """
-        checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt"
+        checkpoint_dir = Path.joinpath(tmp_path, "checkpoint_dir")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_file = checkpoint_dir / "mistral_reward_model_hf_checkpoint.pt"

         torch.save(state_dict, checkpoint_file)

@@ -568,7 +584,7 @@ def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
             "num_key_value_heads": 4,
             "num_classes": 1,
         }
-        config_file = Path.joinpath(tmp_path, "config.json")
+        config_file = Path.joinpath(checkpoint_dir, "config.json")
         with config_file.open("w") as f:
             json.dump(config, f)

@@ -579,11 +595,13 @@ def single_file_checkpointer(
         self, mistral_reward_model_hf_checkpoint, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file = mistral_reward_model_hf_checkpoint
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file],
             model_type="REWARD",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     def test_load_save_checkpoint_single_file(
@@ -636,7 +654,7 @@ def test_load_save_checkpoint_single_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file = Path.joinpath(
-            checkpoint_file.parent,
+            checkpoint_file.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         ).with_suffix(".safetensors")
@@ -708,7 +726,10 @@ def gemma_hf_checkpoint(self, tmp_path, state_dict):
         * head_dim : 16

         """
-        checkpoint_file = tmp_path / "gemma_hf_checkpoint.pt"
+        checkpoint_dir = Path.joinpath(tmp_path, "checkpoint_dir")
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_file = checkpoint_dir / "gemma_hf_checkpoint.pt"

         torch.save(state_dict, checkpoint_file)

@@ -719,7 +740,7 @@ def gemma_hf_checkpoint(self, tmp_path, state_dict):
             "head_dim": _HEAD_DIM,
             "intermediate_size": _HIDDEN_DIM,
         }
-        config_file = Path.joinpath(tmp_path, "config.json")
+        config_file = Path.joinpath(checkpoint_dir, "config.json")
         with config_file.open("w") as f:
             json.dump(config, f)

@@ -730,11 +751,13 @@ def single_file_checkpointer(
         self, gemma_hf_checkpoint, tmp_path
     ) -> FullModelHFCheckpointer:
         checkpoint_file = gemma_hf_checkpoint
+        checkpoint_dir = str(Path.joinpath(tmp_path, "checkpoint_dir"))
+        output_dir = str(Path.joinpath(tmp_path, "output_dir"))
         return FullModelHFCheckpointer(
-            checkpoint_dir=tmp_path,
+            checkpoint_dir=checkpoint_dir,
             checkpoint_files=[checkpoint_file],
             model_type="GEMMA",
-            output_dir=tmp_path,
+            output_dir=output_dir,
         )

     def test_load_save_checkpoint_single_file(
@@ -788,7 +811,7 @@ def test_load_save_checkpoint_single_file(
         # assumes we know what the name of the file is. This is fine, breaking this logic
         # should be something we capture through this test
         output_file = Path.joinpath(
-            checkpoint_file.parent,
+            checkpoint_file.parent.parent / "output_dir",
             "epoch_1",
             SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
         ).with_suffix(".safetensors")

tests/torchtune/training/checkpointing/test_checkpointer_utils.py

Lines changed: 45 additions & 0 deletions

@@ -11,6 +11,7 @@
 import torch
 from torchtune.models.llama2 import llama2, llama2_classifier
 from torchtune.training.checkpointing._utils import (
+    check_outdir_not_in_ckptdir,
     FormattedCheckpointFiles,
     safe_torch_load,
     update_state_dict_for_classifier,
@@ -226,3 +227,47 @@ def test_build_checkpoint_filenames(self, expected_filenames):
         formatted_files = FormattedCheckpointFiles.from_dict(formatted_file_dict)
         actual_filenames = formatted_files.build_checkpoint_filenames()
         assert actual_filenames == expected_filenames
+
+
+class TestCheckOutdirNotInCkptdir:
+    def test_sibling_directories(self):
+        # Sibling directories should pass without raising an error
+        ckpt_dir = Path("/path/to/ckpt")
+        out_dir = Path("/path/to/output")
+        check_outdir_not_in_ckptdir(ckpt_dir, out_dir)
+
+    def test_ckpt_dir_in_output_dir(self):
+        # out_dir is a parent of ckpt_dir, should pass without raising an error
+        ckpt_dir = Path("/path/to/output/ckpt_dir")
+        out_dir = Path("/path/to/output")
+        check_outdir_not_in_ckptdir(ckpt_dir, out_dir)
+
+    def test_equal_directories(self):
+        # Equal directories should raise a ValueError
+        ckpt_dir = Path("/path/to/ckpt")
+        out_dir = Path("/path/to/ckpt")
+        with pytest.raises(
+            ValueError,
+            match="The output directory cannot be the same as or a subdirectory of the checkpoint directory.",
+        ):
+            check_outdir_not_in_ckptdir(ckpt_dir, out_dir)
+
+    def test_output_dir_in_ckpt_dir(self):
+        # out_dir is a subdirectory of ckpt_dir, should raise a ValueError
+        ckpt_dir = Path("/path/to/ckpt")
+        out_dir = Path("/path/to/ckpt/subdir")
+        with pytest.raises(
+            ValueError,
+            match="The output directory cannot be the same as or a subdirectory of the checkpoint directory.",
+        ):
+            check_outdir_not_in_ckptdir(ckpt_dir, out_dir)
+
+    def test_output_dir_ckpt_dir_few_levels_down(self):
+        # out_dir is a few levels down in ckpt_dir, should raise a ValueError
+        ckpt_dir = Path("/path/to/ckpt")
+        out_dir = Path("/path/to/ckpt/subdir/another_subdir")
+        with pytest.raises(
+            ValueError,
+            match="The output directory cannot be the same as or a subdirectory of the checkpoint directory.",
+        ):
+            check_outdir_not_in_ckptdir(ckpt_dir, out_dir)

torchtune/training/checkpointing/_checkpointer.py

Lines changed: 12 additions & 2 deletions

@@ -30,6 +30,7 @@
 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG_FNAME,
     ADAPTER_MODEL_FNAME,
+    check_outdir_not_in_ckptdir,
     copy_files,
     get_adapter_checkpoint_path,
     get_model_checkpoint_path,
@@ -162,7 +163,7 @@ def __init__(
         # TODO: support loading more than one file
         if len(checkpoint_files) != 1:
             raise ValueError(
-                "Currently we only support reading from a single torchtune checkpoint file. "
+                "Currently we only support reading from a single checkpoint file. "
                 f"Got {len(checkpoint_files)} files instead."
             )

@@ -177,6 +178,9 @@ def __init__(

         self._model_type = ModelType[model_type]
         self._output_dir = Path(output_dir)
+        check_outdir_not_in_ckptdir(
+            ckpt_dir=self._checkpoint_dir, out_dir=self._output_dir
+        )
         self._output_dir.mkdir(parents=True, exist_ok=True)

         # resume from adapter_model ckpt
@@ -422,6 +426,9 @@ def __init__(
         self._checkpoint_dir = Path(checkpoint_dir)
         self._model_type = ModelType[model_type]
         self._output_dir = Path(output_dir)
+        check_outdir_not_in_ckptdir(
+            ckpt_dir=self._checkpoint_dir, out_dir=self._output_dir
+        )
         self._output_dir.mkdir(parents=True, exist_ok=True)

         # weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
@@ -950,7 +957,7 @@ def __init__(
         # TODO: support loading more than one file
         if len(checkpoint_files) != 1:
             raise ValueError(
-                "Currently we only support reading from a single torchtune checkpoint file. "
+                "Currently we only support reading from a single checkpoint file. "
                 f"Got {len(checkpoint_files)} files instead."
             )

@@ -963,6 +970,9 @@ def __init__(
         )
         self._model_type = ModelType[model_type]
         self._output_dir = Path(output_dir)
+        check_outdir_not_in_ckptdir(
+            ckpt_dir=self._checkpoint_dir, out_dir=self._output_dir
+        )
         self._output_dir.mkdir(parents=True, exist_ok=True)

         # resume from adapter_model ckpt
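
With these additions, every checkpointer class validates the directory layout as soon as it is constructed. A minimal sketch of the new failure mode follows; the directory paths and checkpoint file name are invented for illustration, and the import path follows the torchtune.training.FullModelHFCheckpointer component referenced in the recipe config above:

from torchtune.training import FullModelHFCheckpointer

# Hypothetical directories: output_dir is nested inside checkpoint_dir, so the
# check_outdir_not_in_ckptdir call added to __init__ raises a ValueError before
# the output directory is created or any config files are copied into it.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/data/llama2/checkpoint_dir",
    checkpoint_files=["pytorch_model-00001-of-00001.bin"],
    model_type="LLAMA2",
    output_dir="/data/llama2/checkpoint_dir/output",  # inside checkpoint_dir -> ValueError
)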

torchtune/training/checkpointing/_utils.py

Lines changed: 20 additions & 0 deletions

@@ -572,3 +572,23 @@ def validate_checkpoint_files(
         )

     return checkpoint_paths
+
+
+def check_outdir_not_in_ckptdir(ckpt_dir: Path, out_dir: Path) -> bool:
+    """
+    Checks that the output directory is not equal to or a subdirectory of the checkpoint directory.
+    This is necessary to avoid making copies of copies when getting config files from ckpt_dir.
+    """
+
+    # Resolve the absolute paths to avoid issues with relative paths
+    _ckpt_dir = ckpt_dir.resolve()
+    _out_dir = out_dir.resolve()
+
+    # Check if out_dir is the same as ckpt_dir or a subdirectory of it
+    if _out_dir == _ckpt_dir or _ckpt_dir in _out_dir.parents:
+        raise ValueError(
+            "The output directory cannot be the same as or a subdirectory of the checkpoint directory. "
+            f"Found {ckpt_dir=} and {out_dir=}."
+        )
+
+    return True
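
Restating the helper's behavior outside the diff, here is a small usage sketch; the paths are illustrative only, and the import path is the one used by test_checkpointer_utils.py above:

from pathlib import Path

from torchtune.training.checkpointing._utils import check_outdir_not_in_ckptdir

check_outdir_not_in_ckptdir(Path("/runs/ckpt"), Path("/runs/out"))       # sibling dirs: returns True
check_outdir_not_in_ckptdir(Path("/runs/out/ckpt"), Path("/runs/out"))   # out_dir is a parent of ckpt_dir: also allowed
check_outdir_not_in_ckptdir(Path("/runs/ckpt"), Path("/runs/ckpt/out"))  # out_dir inside ckpt_dir: raises ValueError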
