Skip to content

Commit 608fbee

Browse files
felipemello1Felipe Mello
authored andcommitted
change saving logic (meta-pytorch#2182)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 37bf22e commit 608fbee

File tree

2 files changed

+34
-26
lines changed

2 files changed

+34
-26
lines changed

torchtune/training/checkpointing/_checkpointer.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from torchtune.training.checkpointing._utils import (
3131
ADAPTER_CONFIG_FNAME,
3232
ADAPTER_MODEL_FNAME,
33-
BASE_MODEL_DIRNAME,
3433
copy_files,
3534
get_adapter_checkpoint_path,
3635
get_model_checkpoint_path,
@@ -180,14 +179,6 @@ def __init__(
180179
self._output_dir = Path(output_dir)
181180
self._output_dir.mkdir(parents=True, exist_ok=True)
182181

183-
# save all files in input_dir, except model weights and mapping, to output_dir
184-
# this is useful to preserve the tokenizer, configs, license, etc.
185-
copy_files(
186-
self._checkpoint_dir,
187-
Path.joinpath(self._output_dir, BASE_MODEL_DIRNAME),
188-
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
189-
)
190-
191182
# resume from adapter_model ckpt
192183
self._adapter_checkpoint = get_adapter_checkpoint_path(
193184
output_dir=self._output_dir,
@@ -331,6 +322,14 @@ def save_checkpoint(
331322
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
332323
)
333324

325+
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
326+
# So its easy to run inference with the model using this epoch's checkpoint
327+
copy_files(
328+
self._checkpoint_dir,
329+
Path.joinpath(self._output_dir, f"epoch_{epoch}"),
330+
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
331+
)
332+
334333
# If the recipe state needs to be output, first remove the model state dict
335334
if intermediate_checkpoint:
336335
_ = state_dict.pop(training.MODEL_KEY, None)
@@ -435,14 +434,6 @@ def __init__(
435434
Path.joinpath(self._checkpoint_dir, "config.json").read_text()
436435
)
437436

438-
# save all files in input_dir, except model weights and mapping, to output_dir
439-
# this is useful to preserve the tokenizer, configs, license, etc.
440-
copy_files(
441-
self._checkpoint_dir,
442-
Path.joinpath(self._output_dir, BASE_MODEL_DIRNAME),
443-
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
444-
)
445-
446437
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
447438
# This json file is produced and saved in the download step.
448439
# contents are {"repo_id": "some_model/some_model_version"}
@@ -873,6 +864,14 @@ def save_checkpoint(
873864
f"saved to {output_path}"
874865
)
875866

867+
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
868+
# So its easy to run inference with the model using this epoch's checkpoint
869+
copy_files(
870+
self._checkpoint_dir,
871+
Path.joinpath(self._output_dir, f"epoch_{epoch}"),
872+
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
873+
)
874+
876875
# If the recipe state needs to be output, first remove the model state dict
877876
# and if it exists, remove the adapter state dict as well
878877
if intermediate_checkpoint:
@@ -966,14 +965,6 @@ def __init__(
966965
self._output_dir = Path(output_dir)
967966
self._output_dir.mkdir(parents=True, exist_ok=True)
968967

969-
# save all files in input_dir, except model weights and mapping, to output_dir
970-
# this is useful to preserve the tokenizer, configs, license, etc.
971-
copy_files(
972-
self._checkpoint_dir,
973-
Path.joinpath(self._output_dir, BASE_MODEL_DIRNAME),
974-
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
975-
)
976-
977968
# resume from adapter_model ckpt
978969
self._adapter_checkpoint = get_adapter_checkpoint_path(
979970
output_dir=self._output_dir,
@@ -1126,6 +1117,14 @@ def save_checkpoint(
11261117
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
11271118
)
11281119

1120+
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
1121+
# So its easy to run inference with the model using this epoch's checkpoint
1122+
copy_files(
1123+
self._checkpoint_dir,
1124+
Path.joinpath(self._output_dir, f"epoch_{epoch}"),
1125+
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
1126+
)
1127+
11291128
# If the recipe state needs to be output, first remove the model state dict
11301129
# and if it exists, remove the adapter state dict as well
11311130
if intermediate_checkpoint:

torchtune/training/checkpointing/_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
# standardize checkpointing
3939
SHARD_FNAME = "ft-model-{cpt_idx}-of-{num_shards}"
4040
RECIPE_STATE_DIRNAME = "recipe_state"
41-
BASE_MODEL_DIRNAME = "base_model"
4241

4342
# Needed when setting up output dir in checkpointing
4443
REPO_ID_FNAME = "original_repo_id"
@@ -334,6 +333,7 @@ def copy_files(
334333
output_dir: Union[str, Path],
335334
*,
336335
ignore_suffixes: Optional[List[str]] = None,
336+
max_file_size_mb: int = 100,
337337
) -> None:
338338
"""
339339
Copies files from the input directory to the output directory, preserving the directory structure.
@@ -346,6 +346,7 @@ def copy_files(
346346
output_dir (Union[str, Path]): The path to the output directory where files should be copied.
347347
ignore_suffixes (Optional[List[str]]): A list of file suffixes to exclude from copying.
348348
Defaults to ['.pt', '.bin', '.safetensors'] if not provided.
349+
max_file_size_mb (int): The maximum file size in megabytes to copy. Defaults to 100 MB.
349350
Returns:
350351
None
351352
Example:
@@ -355,6 +356,7 @@ def copy_files(
355356
already exist in the destination or have the specified suffixes.
356357
"""
357358

359+
max_file_size = max_file_size_mb * 1024 * 1024
358360
for root, dirs, files in os.walk(input_dir):
359361

360362
# Filter out directories that start with '.'. E.g. ".cache/"
@@ -381,6 +383,13 @@ def copy_files(
381383
src_file = os.path.join(root, file)
382384
dest_file = os.path.join(dest_dir, file)
383385

386+
# Check the file size
387+
if os.path.getsize(src_file) > max_file_size:
388+
print(
389+
f"Skipping copying {src_file} to {output_dir} as it exceeds the size limit of {max_file_size_mb} MiB."
390+
)
391+
continue
392+
384393
# Copy the file if it doesn't already exist in the destination
385394
if not os.path.exists(dest_file):
386395
shutil.copy2(src_file, dest_file)

0 commit comments

Comments
 (0)