Fix mypy typing errors in pytorch_lightning/callbacks/model_checkpoint.py #13617

Merged on Jul 20, 2022 (24 commits).

Commits:
970e5fa Remove callbacks.model_checkpoint in pyproject.toml (Jungwon-Lee, Jul 12, 2022)
fd6f244 fix mypy errors in model_checkpoint.py (Jungwon-Lee, Jul 12, 2022)
3f9a068 Merge branch 'master' into typing_model_checkpoint (Jungwon-Lee, Jul 12, 2022)
efd5f63 fix pep-error line for ckecking Union[int, bool] variable is True (Jungwon-Lee, Jul 12, 2022)
3bc6308 Merge branch 'typing_model_checkpoint' of https://github.com/BongYang… (Jungwon-Lee, Jul 12, 2022)
776574f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2022)
27fdfbd Self review (carmocca, Jul 12, 2022)
0639746 Self review (carmocca, Jul 12, 2022)
a6cfc2c Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 12, 2022)
673f2f1 Support tensors to reduce_bool_decision (carmocca, Jul 12, 2022)
e3da8ae Fix dtype (carmocca, Jul 12, 2022)
75fbd3d Standardize tensor annotation (carmocca, Jul 12, 2022)
3963931 Keep int dtype (carmocca, Jul 13, 2022)
9ba9c78 Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 13, 2022)
5f3e8e6 Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 14, 2022)
3cf4302 Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 19, 2022)
7c3c8a7 assert dirpath not None (carmocca, Jul 19, 2022)
c2311bf Undo reduce_bool_decision change. Cast to bool (carmocca, Jul 19, 2022)
5c0074b Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 19, 2022)
bf9e4cb Unused import (carmocca, Jul 19, 2022)
77c46c0 Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 19, 2022)
a2bd6a0 Merge branch 'master' into typing_model_checkpoint (carmocca, Jul 20, 2022)
9e4626a Merge branch 'master' into typing_model_checkpoint (otaj, Jul 20, 2022)
953a087 empty (Jul 20, 2022)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -47,7 +47,6 @@ warn_no_return = "False"
# the list can be generated with:
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.callbacks.model_checkpoint",
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.callbacks.stochastic_weight_avg",
2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/early_stopping.py
@@ -135,7 +135,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
# validation, then we run after validation instead of on train epoch end
self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
monitor_val = logs.get(self.monitor)

error_msg = (
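Note on the `early_stopping.py` change: the values Lightning stores in `trainer.callback_metrics` are tensors, so `logs.get(self.monitor)` yields `Optional[Tensor]` and the old `Dict[str, float]` annotation was inaccurate. A minimal standalone sketch (names illustrative, not part of the diff):

```python
from typing import Dict, Optional

import torch
from torch import Tensor

# The monitored metrics dict holds tensors, not floats:
logs: Dict[str, Tensor] = {"val_loss": torch.tensor(0.25)}
monitor_val: Optional[Tensor] = logs.get("val_loss")
print(monitor_val)  # tensor(0.2500)
```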
42 changes: 24 additions & 18 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -39,7 +39,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

log = logging.getLogger(__name__)
@@ -231,13 +231,14 @@ def __init__(
self._save_on_train_epoch_end = save_on_train_epoch_end
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score = None
self.best_k_models = {}
self.current_score: Optional[Tensor] = None
self.best_k_models: Dict[str, Tensor] = {}
self.kth_best_model_path = ""
self.best_model_score = None
self.best_model_score: Optional[Tensor] = None
self.best_model_path = ""
self.last_model_path = ""

self.kth_value: Tensor
self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
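The bare `self.kth_value: Tensor` is an annotation-only declaration: it tells mypy the attribute's type even though the first assignment happens later (here in `__init_monitor_mode`). A hypothetical sketch of the pattern:

```python
import torch
from torch import Tensor

class Tracker:  # illustrative class, not Lightning code
    def __init__(self, mode: str) -> None:
        # Annotation without assignment: mypy records the type; no value is set yet.
        self.kth_value: Tensor
        self._init_mode(mode)

    def _init_mode(self, mode: str) -> None:
        # The first real assignment, matching the declared type.
        self.kth_value = torch.tensor(float("inf") if mode == "min" else float("-inf"))

print(Tracker("min").kth_value)  # tensor(inf)
```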
@@ -256,6 +257,7 @@ def state_key(self) -> str:

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
self.__resolve_ckpt_dir(trainer)
assert self.dirpath is not None
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

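The added `assert self.dirpath is not None` is for the type checker: `dirpath` is `Optional`, but `__resolve_ckpt_dir` has just filled it in, and the assert narrows the type so the `__warn_if_dir_not_empty` call type-checks. A standalone sketch with assumed names:

```python
from typing import Optional

def warn_if_dir_not_empty(dirpath: str) -> None:  # stand-in for the real helper
    print(f"checking {dirpath}")

def setup(dirpath: Optional[str]) -> None:
    # Without the assert, mypy rejects passing Optional[str] where str is required.
    assert dirpath is not None
    warn_if_dir_not_empty(dirpath)

setup("/tmp/checkpoints")
```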
@@ -362,7 +364,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)

def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
if self.save_top_k == 0:
return

@@ -395,7 +397,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn

return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
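`Trainer.fast_dev_run` is typed `Union[int, bool]` (an int meaning "run that many batches"), so without the cast the `or` chain is inferred as `Union[int, bool]` and mypy rejects the declared `-> bool` return, which is what commit efd5f63 addressed. A small sketch with illustrative names:

```python
from typing import Union

def should_skip(fast_dev_run: Union[int, bool], sanity_checking: bool) -> bool:
    # Returning `fast_dev_run or sanity_checking` directly infers Union[int, bool];
    # the explicit bool() keeps the declared return type exact.
    return bool(fast_dev_run) or sanity_checking

print(should_skip(2, False))  # True: fast_dev_run=2 still disables checkpointing
```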
@@ -493,15 +495,15 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

return should_update_best_and_save

@classmethod
def _format_checkpoint_name(
cls,
filename: Optional[str],
metrics: Dict[str, _METRIC],
metrics: Dict[str, Tensor],
prefix: str = "",
auto_insert_metric_name: bool = True,
) -> str:
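`monitor_op` is a tensor comparison (`torch.lt`/`torch.gt`), so `should_update_best_and_save` starts out as a 0-dim tensor, while `reduce_boolean_decision` is annotated `decision: bool`; the `bool(...)` cast bridges the two (commit c2311bf, after the tensor-support attempt in 673f2f1 was undone). For example:

```python
import torch

current = torch.tensor(0.10)
kth_best = torch.tensor(0.25)

should_update = torch.lt(current, kth_best)
print(type(should_update))  # <class 'torch.Tensor'>: a 0-dim bool tensor, not a Python bool
print(bool(should_update))  # True, matching the `decision: bool` parameter
```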
@@ -522,7 +524,7 @@ def _format_checkpoint_name(
filename = filename.replace(group, f"{{0[{name}]")

if name not in metrics:
metrics[name] = 0
metrics[name] = torch.tensor(0)
filename = filename.format(metrics)

if prefix:
@@ -531,7 +533,7 @@ def _format_checkpoint_name(
return filename

def format_checkpoint_name(
self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.

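Defaulting missing metrics to `torch.tensor(0)` instead of `0` keeps the dict uniformly `Dict[str, Tensor]`. Filename templates still work because 0-dim tensors implement `__format__` by delegating to their Python scalar; a quick sketch (template illustrative):

```python
import torch

metrics = {"epoch": torch.tensor(3), "val_loss": torch.tensor(0.1234)}
# Scalar tensors accept format specs such as :02d and :.2f directly:
print("epoch={epoch:02d}-val_loss={val_loss:.2f}".format(**metrics))
# epoch=03-val_loss=0.12
```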
@@ -591,6 +593,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
elif trainer.loggers:
if len(trainer.loggers) == 1:
assert trainer.logger is not None
save_dir = trainer.logger.save_dir or trainer.default_root_dir
else:
save_dir = trainer.default_root_dir
@@ -613,7 +616,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def _get_metric_interpolated_filepath_name(
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(monitor_candidates)

@@ -624,7 +627,7 @@ def _get_metric_interpolated_filepath_name(

return filepath

def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
monitor_candidates = deepcopy(trainer.callback_metrics)
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
# or does not exist we overwrite it as it's likely an error
@@ -634,7 +637,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
return monitor_candidates

def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
if not self.save_last:
return

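The epoch/step handling in `_monitor_candidates` keeps the return type a clean `Dict[str, Tensor]`: a logged float tensor is cast back to int (commit 3963931 kept the int dtype), and a missing or non-tensor value is rebuilt from the trainer's counters. Isolated as a sketch with assumed names:

```python
import torch
from torch import Tensor

def normalize_counter(value: object, fallback: int) -> Tensor:
    # `self.log("epoch", 123)` stores a float tensor, so cast back to int;
    # anything else is overwritten from the trainer's counter.
    return value.int() if isinstance(value, Tensor) else torch.tensor(fallback)

print(normalize_counter(torch.tensor(3.0), 0))  # tensor(3, dtype=torch.int32)
print(normalize_counter(None, 7))               # tensor(7)
```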
@@ -651,16 +654,18 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)

def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
assert self.monitor
current = monitor_candidates.get(self.monitor)
if self.check_monitor_top_k(trainer, current):
assert current is not None
self._update_best_and_save(current, trainer, monitor_candidates)
elif self.verbose:
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
@@ -669,7 +674,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
trainer.strategy.remove_checkpoint(previous)

def _update_best_and_save(
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
) -> None:
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

@@ -691,11 +696,11 @@
if len(self.best_k_models) == k:
# monitor dict has reached k elements
_op = max if self.mode == "min" else min
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.kth_value = self.best_k_models[self.kth_best_model_path]

_op = min if self.mode == "min" else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.best_model_score = self.best_k_models[self.best_model_path]

if self.verbose:
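The `# type: ignore[arg-type]` comments are needed because `dict.get` is typed as returning `Optional[Tensor]`, which mypy will not accept as a `key` function even though every key is present. A sketch of an ignore-free alternative (illustrative data):

```python
from typing import Dict

import torch
from torch import Tensor

best_k_models: Dict[str, Tensor] = {
    "epoch=1.ckpt": torch.tensor(0.30),
    "epoch=2.ckpt": torch.tensor(0.12),
}

# A lambda indexes the dict directly, so its return type is Tensor, not Optional[Tensor]:
best_model_path = min(best_k_models, key=lambda k: best_k_models[k])
print(best_model_path)  # epoch=2.ckpt
```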
@@ -715,6 +720,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
file."""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
if filepath is None:
assert self.dirpath
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with self._fs.open(filepath, "w") as fp:
yaml.dump(best_k, fp)
2 changes: 1 addition & 1 deletion src/pytorch_lightning/core/module.py
@@ -532,7 +532,7 @@ def __to_tensor(self, value: numbers.Number) -> Tensor:
return torch.tensor(value, device=self.device)

@staticmethod
def __check_numel_1(value: torch.Tensor, name: str) -> None:
def __check_numel_1(value: Tensor, name: str) -> None:
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
@@ -285,7 +285,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
"""

def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
"""Reduce a boolean decision across all processes."""
return decision

def pre_backward(self, closure_loss: Tensor) -> None:
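The base `Strategy` is single-process, so it returns the decision unchanged. A distributed strategy instead reduces the votes across processes; a hedged sketch of the unanimity pattern used by the `TPUSpawn` override removed in `tpu_spawn.py` below (sum the int-encoded votes, require all of them), with the all-reduce simulated so the snippet runs standalone:

```python
from typing import List

import torch

def reduce_boolean_decision(all_votes: List[bool]) -> bool:
    # In a real strategy the sum comes from an all-reduce across processes;
    # here the per-process votes are passed in directly to keep the sketch runnable.
    summed = torch.tensor([int(v) for v in all_votes]).sum()
    return bool(summed == len(all_votes))

print(reduce_boolean_decision([True, True, True]))   # True: unanimous
print(reduce_boolean_decision([True, False, True]))  # False: one process disagreed
```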
12 changes: 3 additions & 9 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -169,19 +169,13 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.root_device)
decision = self.reduce(decision, reduce_op="sum")
decision = bool(decision == self.world_size)
return decision

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if not isinstance(output, Tensor):
output = torch.tensor(output, device=self.root_device)

_invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if _invalid_reduce_op or _invalid_reduce_op_str:
invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if invalid_reduce_op or invalid_reduce_op_str:
raise MisconfigurationException(
"Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
)
src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -529,7 +529,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
result_metric.meta.sync.should = should
cache = result_metric._computed
if cache is not None:
if not isinstance(cache, torch.Tensor):
if not isinstance(cache, Tensor):
raise ValueError(
f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
f" Found {cache}"
4 changes: 3 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
@@ -2705,7 +2705,9 @@ def loggers(self, loggers: Optional[List[Logger]]) -> None:
self._loggers = loggers if loggers else []

@property
def callback_metrics(self) -> dict:
def callback_metrics(self) -> Dict[str, Tensor]:
# TODO: the true typing return can include dictionaries as defined in
# `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
return self._logger_connector.callback_metrics

@property
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/distributed.py
@@ -99,7 +99,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
return gathered_result


def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result
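`_simple_gather_all_tensors` assumes every rank contributes a tensor of the same shape. A single-process sketch of the underlying `all_gather` pattern, assuming the gloo backend is available:

```python
import os

import torch
import torch.distributed as dist

# Single-process group so the pattern can run standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

result = torch.tensor([1.0, 2.0])
gathered = [torch.zeros_like(result) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, result)
print(gathered)  # [tensor([1., 2.])]

dist.destroy_process_group()
```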