diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py
index 553a7187a..5b53eec44 100644
--- a/metaseq/checkpoint_utils.py
+++ b/metaseq/checkpoint_utils.py
@@ -9,23 +9,30 @@
 import os
 import re
 import socket
-from typing import Any, Dict, List, Optional, Tuple
 import shutil
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from omegaconf import OmegaConf
 
 from metaseq.dataclass.configs import CheckpointConfig
-from metaseq.dataclass.utils import overwrite_args_by_name
+from metaseq.dataclass.utils import overwrite_args_by_name, CheckpointPath
 from metaseq.distributed import utils as distributed_utils
 from metaseq.file_io import PathManager, torch_load_cpu
-from metaseq.launcher.opt_job_constants import ComputeEnvs
 
 logger = logging.getLogger(__name__)
 
 OPT_KEY = "last_optimizer_state"
 
+try:
+    from metaseq_internal import azure_utils
+except ImportError:
+    logger.warning(
+        "Proceeding without metaseq-internal installed! Please check if you need this! "
+        "It is required for loading from azure blob."
+    )
+
 
 def save_checkpoint(
     cfg: CheckpointConfig,
@@ -73,7 +80,7 @@
     save_for_updates = not end_of_epoch and (save_to_NFS or save_locally)
 
-    checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch
+    checkpoint_conds[f"checkpoint_{updates}_epoch_{epoch}{suffix}.pt"] = save_for_epoch
     checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates
     checkpoint_conds[f"checkpoint_last{suffix}.pt"] = (
         (training_finished and cfg.save_last_checkpoint)
@@ -105,198 +112,214 @@
     )
 
 
-def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
-    """
-    Load a checkpoint and restore the training iterator.
+def get_storage_type(path: str) -> str:
+    if path.startswith("nfs:"):
+        return "nfs"
+    elif "windows.net" in path:
+        return "azure_blob"
+    else:
+        return "local"
 
-    *passthrough_args* will be passed through to
-    ``trainer.get_train_iterator``.
-    """
-    reset_optimizer = cfg.reset_optimizer
-    reset_lr_scheduler = cfg.reset_lr_scheduler
-    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
-    reset_meters = cfg.reset_meters
-    reset_dataloader = cfg.reset_dataloader
+def get_checkpoint_steps(path: str) -> int:
+    match = re.search(r"checkpoint_(\d+)", path)
+    if match is None:
+        return 0
+    return int(match[1])
 
-    if cfg.finetune_from_model is not None and (
-        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
-    ):
-        raise ValueError(
-            "--finetune-from-model can not be set together with either --reset-optimizer"
-            " or reset_lr_scheduler or reset_meters or reset_dataloader"
-        )
-    suffix = trainer.checkpoint_suffix
-    default_restore_file = "checkpoint_last.pt"
-    # default to loading from restore file.
-    if cfg.restore_file == default_restore_file:
-        checkpoint_path_to_load = os.path.join(
-            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
-        )
-        first_launch = not PathManager.exists(checkpoint_path_to_load)
-        if cfg.finetune_from_model is not None and first_launch:
-            # if there is no last checkpoint to restore, start the finetune from pretrained model
-            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
-            reset_optimizer = True
-            reset_lr_scheduler = True
-            reset_meters = True
-            reset_dataloader = True
-            checkpoint_path_to_load = None
-            if PathManager.exists(cfg.finetune_from_model):
-                checkpoint_path_to_load = cfg.finetune_from_model
-            elif suffix is not None:  # check for sharded version
-                sharded_path = cfg.finetune_from_model.replace(".pt", suffix + ".pt")
-                if PathManager.exists(sharded_path):
-                    checkpoint_path_to_load = sharded_path
-            if checkpoint_path_to_load is None:
-                raise ValueError(
-                    f"--finetune-from-model {cfg.finetune_from_model} does not exist either as is or sharded"
+def get_all_checkpoints_from_directory(
+    directory: str, suffix: str, add_priority: float, storage_type: str
+) -> List[CheckpointPath]:
+    checkpoints = []
+    for candidate in os.listdir(directory):
+        steps = get_checkpoint_steps(candidate)
+        if steps == 0:
+            continue
+
+        # in scratch saved files are in this form: checkpoint_180-model_part-0-shard0.pt
+        if candidate.endswith(".pt"):
+            logger.info(f"found checkpoint file {candidate}")
+            if suffix not in candidate:
+                continue
+            checkpoints.append(
+                CheckpointPath(
+                    path=os.path.join(directory, candidate),
+                    storage_type=storage_type,
+                    priority=steps + add_priority,
                 )
+            )
+            continue
+        # nfs and cached files look like this: checkpoint_180/checkpoint-model_part-0-shard0.pt
+        expected_file_count = distributed_utils.get_global_world_size()
+        present_files = len(
+            [
+                f
+                for f in os.listdir(os.path.join(directory, candidate))
+                if not f.startswith("_")
+            ]
+        )
+        if present_files != expected_file_count:
             logger.info(
-                f"loading pretrained model from {checkpoint_path_to_load}: "
-                "optimizer, lr scheduler, meters, dataloader will be reset"
+                f"skipping checkpoint {candidate} in {directory} because it only has"
+                f" {present_files} files (expected {expected_file_count})"
             )
-    elif suffix is not None:
-        checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt")
-    else:
-        checkpoint_path_to_load = cfg.restore_file
+            continue
 
-    if cfg.restore_file != default_restore_file and cfg.finetune_from_model:
-        raise ValueError(
-            "--finetune-from-model and --restore-file (non-default value) "
-            "can not be specified together: " + str(cfg)
+        checkpoints.append(
+            CheckpointPath(
+                path=os.path.join(directory, candidate, f"checkpoint{suffix}.pt"),
+                storage_type=storage_type,
+                priority=steps + add_priority,
+            )
         )
+    return checkpoints
+
+
+def get_recent_checkpoint_from_azure_blob(
+    blob_url: str, suffix: str, add_priority: float
+) -> List[CheckpointPath]:
+    file_to_load = azure_utils.get_most_recent_ckpt(blob_url, suffix)
+    if file_to_load is None:
+        return []
+    steps = get_checkpoint_steps(file_to_load)
+    return [
+        CheckpointPath(
+            path=blob_url + "/" + file_to_load,
+            storage_type="azure_blob",
+            priority=steps + add_priority,
+        )
+    ]
 
-    # Azure logic
-    try:
-        from metaseq_internal import azure_utils
-        has_metaseq_internal = True
-    except ImportError:
-        has_metaseq_internal = False
-        logger.warning(
-            "Proceeding without metaseq-internal installed! Please check if you need this!"
+def get_checkpoint_to_finetune(
+    finetune_path: str, suffix: str, priority: float
+) -> CheckpointPath:
+    validated_path = None
+    if PathManager.exists(finetune_path):
+        validated_path = finetune_path
+    else:  # check for sharded version
+        sharded_path = finetune_path.replace(".pt", suffix + ".pt")
+        if PathManager.exists(sharded_path):
+            validated_path = sharded_path
+    if validated_path is None:
+        raise ValueError(
+            f"--finetune-from-model {finetune_path} does not exist either as is or sharded"
        )
+    return CheckpointPath(
+        path=validated_path,
+        storage_type=get_storage_type(validated_path),
+        priority=priority,
+        run_before_loading=[reset_for_finetuning],
+    )
 
-    # TODO(susanz): fix all of this spagetti, split out logic by env
-    # Note that we compare by value since ComputeEnvs may be imported from metaseq_internal
-    if cfg.cloud_upload_path:
-        if cfg.cloud_upload_path.startswith("nfs:"):
-            checkpoint_path_to_load = os.path.join(
-                cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
-            )
-            nfs_path = cfg.cloud_upload_path[4:]
-            filename = None
-            specific_restore_file_provided = cfg.restore_file != default_restore_file
-            slurm_was_restarted = int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0
-            restart_from_latest = slurm_was_restarted or (
-                cfg.finetune_from_model is None and not specific_restore_file_provided
+def reset_for_finetuning(cfg, checkpoint):
+    cfg.reset_optimizer = True
+    cfg.reset_lr_scheduler = True
+    cfg.reset_meters = True
+    cfg.reset_dataloader = True
+    logger.info(
+        "Resetting optimizer, lr scheduler, meters, and dataloader for fine-tuning!"
+    )
+
+
+def prepare_local_checkpoint_path(cfg: CheckpointConfig, trainer) -> str:
+    suffix = trainer.checkpoint_suffix
+
+    # collect all possible checkpoint paths
+    checkpoints = []
+    if cfg.finetune_from_model:
+        checkpoints.append(
+            get_checkpoint_to_finetune(cfg.finetune_from_model, suffix, 0)
+        )
+
+    if cfg.restore_file:
+        checkpoints.append(
+            CheckpointPath(
+                path=cfg.restore_file.replace(".pt", suffix + ".pt"),
+                priority=get_checkpoint_steps(cfg.restore_file) + 0.1,
+                storage_type=get_storage_type(cfg.restore_file),
            )
-            if restart_from_latest:
-                checkpoints = []
-                expected_file_count = distributed_utils.get_global_world_size()
-                for candidate in os.listdir(nfs_path):
-                    if candidate == "checkpoint_last":
-                        raise RuntimeError(
-                            "trying to restart a job that already wrote checkpoint_last"
-                        )
-                    m = re.match(r"checkpoint_(\d+)", candidate)
-                    if m:
-                        checkpoints.append((int(m[1]), candidate))
-                for _, candidate in sorted(checkpoints, reverse=True):
-                    present_files = len(
-                        [
-                            f
-                            for f in os.listdir(os.path.join(nfs_path, candidate))
-                            if not f.startswith("_")
-                        ]
-                    )
-                    if present_files == expected_file_count:
-                        filename = os.path.join(
-                            nfs_path, candidate, f"checkpoint{suffix}.pt"
-                        )
-                        break
-                    logger.info(
-                        f"skipping checkpoint {candidate} because it only has"
-                        f" {present_files} files (expected {expected_file_count})"
-                    )
-            else:
-                filename = cfg.restore_file.replace(".pt", suffix + ".pt")
-            if filename is not None:
-                logger.info(
-                    f"Copying checkpoint from nfs {filename} -> {checkpoint_path_to_load}"
-                )
-                shutil.copyfile(filename, checkpoint_path_to_load)
-            else:
-                logger.info(f"No NFS checkpoints found")
+        )
 
-        elif cfg.cluster_env == ComputeEnvs.AZURE.value and has_metaseq_internal:
-            if (
-                # --restore-file was not passed, always download latest checkpoint
-                (
-                    cfg.restore_file == default_restore_file
-                    and cfg.finetune_from_model is None
-                )
-                # --restore-file was passed, but we requeued, so download latest checkpoint
-                or int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0
-            ):
-                # download checkpoint into local save_dir
-                checkpoint_path_to_load = os.path.join(
-                    cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
-                )
-                azure_utils.download_recent_ckpt(
-                    cfg.cloud_upload_path, checkpoint_path_to_load, suffix + ".pt"
-                )
-            elif (
-                # --restore-file was passed and is a blob URL, download that checkpoint
-                cfg.restore_file != default_restore_file
-                and "windows.net" in cfg.restore_file
-            ):
-                blob_url = cfg.restore_file.replace(".pt", suffix + ".pt")
-                # download checkpoint into local save_dir
-                checkpoint_path_to_load = os.path.join(
-                    cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
+    if cfg.cloud_upload_path:
+        cloud_storage_type = get_storage_type(cfg.cloud_upload_path)
+        if cloud_storage_type == "nfs":
+            checkpoints.extend(
+                get_all_checkpoints_from_directory(
+                    cfg.cloud_upload_path[4:],
+                    suffix,
+                    add_priority=0.2,
+                    storage_type="nfs",
                )
-                azure_utils.download_specific_ckpt(blob_url, checkpoint_path_to_load)
-            else:
-                logger.info(
-                    f"Using checkpoint {checkpoint_path_to_load} even while on Azure"
+            )
+        elif cloud_storage_type == "azure_blob":
+            checkpoints.extend(
+                get_recent_checkpoint_from_azure_blob(
+                    cfg.cloud_upload_path, suffix, add_priority=0.2
                )
+            )
 
-        # RSC logic: --restore-file was passed, and we requeued
-        elif (
-            cfg.restore_file != default_restore_file
-            and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0
-        ):
-            # point checkpoint_path to the current checkpoint directory for loading, if it exists.
-            save_dir_last = os.path.join(
-                cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
+    checkpoints.extend(
+        get_all_checkpoints_from_directory(
+            cfg.save_dir, suffix, add_priority=0.3, storage_type="local"
        )
-            if PathManager.isfile(save_dir_last):
-                checkpoint_path_to_load = save_dir_last
+    )
 
-    logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}")
+    # get the most recent valid checkpoint
+    checkpoints.sort(key=lambda ckpt: ckpt.priority)
+    if len(checkpoints) == 0:
+        return ""
+    logger.info(
+        f"The following checkpoints were found to be ready to load: {str(checkpoints)}"
+    )
+    checkpoint = checkpoints[-1]
+    _ = [hook(cfg, checkpoint) for hook in checkpoint.run_before_loading]
+
+    if checkpoint.storage_type == "local":
+        return checkpoint.path
+
+    # copy cloud checkpoints to a local cache file
+    local_cache_dir = os.path.join(
+        cfg.save_dir, f"cached_checkpoint_{int(checkpoint.priority)}"
+    )
+    os.makedirs(local_cache_dir, exist_ok=True)
+    local_cache_file = os.path.join(local_cache_dir, f"checkpoint{suffix}.pt")
+
+    logger.info(f"Copying checkpoint from {checkpoint.path} -> {local_cache_file}")
+    if checkpoint.storage_type == "nfs":
+        shutil.copyfile(checkpoint.path, local_cache_file)
+    elif checkpoint.storage_type == "azure_blob":
+        azure_utils.download_specific_ckpt(checkpoint.path, local_cache_file)
+
+    return local_cache_file
+
+
+def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
+    """
+    Load a checkpoint and restore the training iterator.
+
+    *passthrough_args* will be passed through to
+    ``trainer.get_train_iterator``.
+    """
+
+    checkpoint_path_to_load = prepare_local_checkpoint_path(cfg, trainer)
+
+    logger.info(f"attempting to load checkpoint from: {checkpoint_path_to_load}")
     # make sure everyone is done downloading their checkpoints before we load
     distributed_utils.global_barrier()
     extra_state = trainer.load_checkpoint(
         checkpoint_path_to_load,
-        reset_optimizer,
-        reset_lr_scheduler,
-        optimizer_overrides,
-        reset_meters=reset_meters,
+        cfg.reset_optimizer,
+        cfg.reset_lr_scheduler,
+        ast.literal_eval(cfg.optimizer_overrides),
+        reset_meters=cfg.reset_meters,
     )
-    if reset_dataloader and int(os.environ.get("SLURM_RESTART_COUNT", 0)) > 0:
-        logger.info(
-            f"Disregarding --reset-dataloader since we are continuing past a requeue"
-        )
-        reset_dataloader = False
-    if extra_state is not None and not reset_dataloader:
+    if extra_state is not None and not cfg.reset_dataloader:
         # restore iterator from checkpoint
         itr_state = extra_state["train_iterator"]
         epoch_itr = trainer.get_train_iterator(
diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py
index fc31bcfe7..de4e10db6 100644
--- a/metaseq/dataclass/configs.py
+++ b/metaseq/dataclass/configs.py
@@ -452,11 +452,11 @@ class CheckpointConfig(MetaseqDataclass):
     save_dir: str = field(
         default="checkpoints", metadata={"help": "path to save checkpoints"}
     )
-    restore_file: str = field(
-        default="checkpoint_last.pt",
+    restore_file: Optional[str] = field(
+        default=None,
         metadata={
             "help": "filename from which to load checkpoint "
-            "(default: <save-dir>/checkpoint_last.pt"
+            "in the form nfs:path/to/dir/checkpoint.pt"
         },
     )
     finetune_from_model: Optional[str] = field(
diff --git a/metaseq/dataclass/utils.py b/metaseq/dataclass/utils.py
index 05f31df95..562dd8b50 100644
--- a/metaseq/dataclass/utils.py
+++ b/metaseq/dataclass/utils.py
@@ -9,7 +9,7 @@
 import os
 import re
 from argparse import ArgumentError, ArgumentParser, Namespace
-from dataclasses import _MISSING_TYPE, MISSING
+from dataclasses import dataclass, field, _MISSING_TYPE, MISSING
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
@@ -473,3 +473,11 @@ def merge_with_parent(dc: MetaseqDataclass, cfg: MetaseqDataclass):
     merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
     OmegaConf.set_struct(merged_cfg, True)
     return merged_cfg
+
+
+@dataclass
+class CheckpointPath:
+    path: str
+    storage_type: str
+    priority: float = 0
+    run_before_loading: list = field(default_factory=list)
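
Reviewer note: a minimal sketch (not part of the patch) of how the new priority scheme in prepare_local_checkpoint_path() picks a candidate. The paths and step counts below are invented; only get_checkpoint_steps() and CheckpointPath come from this diff.

# Hypothetical example only -- the file names and step counts are made up.
from metaseq.checkpoint_utils import get_checkpoint_steps
from metaseq.dataclass.utils import CheckpointPath

candidates = [
    # a checkpoint already in cfg.save_dir gets the 0.3 offset
    CheckpointPath(
        path="/save/checkpoint_500/checkpoint-model_part-0-shard0.pt",
        storage_type="local",
        priority=get_checkpoint_steps("checkpoint_500") + 0.3,  # 500.3
    ),
    # a newer checkpoint under cloud_upload_path (NFS) gets the 0.2 offset
    CheckpointPath(
        path="/mnt/shared/checkpoint_600/checkpoint-model_part-0-shard0.pt",
        storage_type="nfs",
        priority=get_checkpoint_steps("checkpoint_600") + 0.2,  # 600.2
    ),
]
candidates.sort(key=lambda ckpt: ckpt.priority)
# The last (highest-priority) entry wins: the step count dominates, and the
# fractional offset only breaks ties between copies of the same step that
# live on different storage (restore_file 0.1 < cloud 0.2 < local 0.3).
print(candidates[-1].path)  # -> the 600-step NFS checkpoint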