### Proposed refactor
Separate device stats monitoring into per-device callbacks that subclass `DeviceStatsMonitor`. This preserves the desired change from #9032 to consolidate the interface, but also allows for fine-grained control.
### Motivation
With #9032, all the accelerators were combined under a single `DeviceStatsMonitor` callback. This consolidated the API, but it also removed fine-grained control. For instance, the now-deprecated `GPUStatsMonitor` used to provide fine-grained control over which `nvidia-smi` stats were tracked: https://github.com/PyTorchLightning/pytorch-lightning/blob/86b177ebe5427725b35fde1a8808a7b59b8a277a/pytorch_lightning/callbacks/gpu_stats_monitor.py#L87-L95

However, the new interface defaults to torch memory stats, which provide less information than `nvidia-smi`: https://github.com/PyTorchLightning/pytorch-lightning/blob/86b177ebe5427725b35fde1a8808a7b59b8a277a/pytorch_lightning/accelerators/gpu.py#L73-L75

Regardless of whether GPU stats are changed to default to `nvidia-smi`, the user no longer has control over which metrics are monitored. Additionally, if #11795 is merged, there will be CPU stats monitoring on top of whatever accelerator is used.
### Pitch 1
If the user were allowed to specify which stats to monitor, the callback would need to look like:
```python
DeviceStatsMonitor(
    cpu_stats: Optional[Union[bool, Set[str]]] = None,
    gpu_stats: Optional[Union[bool, Set[str]]] = None,
    tpu_stats: Optional[Union[bool, Set[str]]] = None,
)
```
This builds on top of the suggestion in #11253 (comment), where the allowed values are:

- `None`: to know whether the user passed a value
- `bool`: to easily enable/disable a default set of stats
- `Set[str]`: to enable and show this specific set of stats
```python
# enable cpu stats + stats for the current accelerator
DeviceStatsMonitor(cpu_stats=True)

# enable these cpu stats + stats for the current accelerator
DeviceStatsMonitor(cpu_stats={"ram", "temp"})
```
One drawback: this design provides no argument validation via type checking or auto-complete.
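For example (the stat names here are hypothetical), a misspelled stat name would pass type checking and only fail, if at all, at runtime:

```python
# Both calls type-check, since any Set[str] is accepted; the typo below
# would not be caught by a type checker or surfaced by auto-complete.
DeviceStatsMonitor(cpu_stats={"ram", "temp"})        # intended
DeviceStatsMonitor(cpu_stats={"ram", "temprature"})  # typo, still type-checks
```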
### Pitch 2
Have a common interface via a base class:

```python
class DeviceStatsMonitor(Callback):
    ...
```

For each device, subclass `DeviceStatsMonitor` and allow for configuration:
```python
class GPUStatsMonitor(DeviceStatsMonitor):
    def __init__(
        self,
        memory_utilization: bool = True,
        gpu_utilization: bool = True,
        intra_step_time: bool = False,
        inter_step_time: bool = False,
        fan_speed: bool = False,
        temperature: bool = False,
    ) -> None:
        ...
```
Add a `CPUStatsMonitor`.
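A hypothetical signature for the CPU callback, mirroring the GPU one (the specific stat flags are illustrative, not an existing API):

```python
class CPUStatsMonitor(DeviceStatsMonitor):
    def __init__(
        self,
        ram_usage: bool = True,        # hypothetical stat flags,
        cpu_utilization: bool = True,  # chosen only for illustration
        temperature: bool = False,
    ) -> None:
        ...
```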
If you want to track CPU stats in addition to another accelerator's stats, you can now pass:
```python
trainer = Trainer(callbacks=[CPUStatsMonitor(), GPUStatsMonitor()])
```
### Pitch 3
Use a single `DeviceStatsMonitor` with the option to specify `cpu_stats=True`, and provide sensible default metrics. This will be a friendly generic interface for quickly tracking stats.
Other users should be able to access `get_device_stats()` from the accelerator class, and `get_device_stats` should take optional arguments for configuration (i.e., `get_device_stats()` with no arguments should be sufficient, but it should also accept additional optional arguments that vary per device). This allows the stats to be customized without needing to make each device callback unique and highly configurable.
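A minimal sketch of what this could look like on the GPU accelerator; the `use_nvidia_smi` and `metrics` options are hypothetical and not part of the current API:

```python
from typing import Any, Dict, List, Optional

from pytorch_lightning.accelerators import Accelerator


class GPUAccelerator(Accelerator):
    @staticmethod
    def get_device_stats(
        use_nvidia_smi: bool = False,         # hypothetical option
        metrics: Optional[List[str]] = None,  # hypothetical option
    ) -> Dict[str, Any]:
        # With no arguments, return a sensible default set of stats;
        # otherwise, collect only the requested metrics via the chosen backend.
        ...
```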
Currently (in my opinion) it is a pain to write a `Callback`, since you have to override multiple hooks even when you want the same or similar behavior in each one. I instead propose adding a new `DecoratedCallback` class, derived from `Callback`, that lets you use decorators to specify which hooks a method should be called for, without needing to define many one-line functions. I also think `_prefix_metric_keys` should be made a public utility.
The user could now do:
```python
class MyGPUStatsMonitor(DecoratedCallback):
    @pl_hook.on_train_batch_start
    @pl_hook.on_train_batch_end
    @pl_hook.on_val_batch_start
    @pl_hook.on_val_batch_end
    def log_batch_stats(
        self,
        key: str,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, key)
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
```
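For reference, a minimal sketch of how `pl_hook` and `DecoratedCallback` could be implemented; none of this exists in Lightning today, and the names `_HookMarker` and `_pl_hooks` are invented for illustration:

```python
import functools
from typing import Any, Callable

from pytorch_lightning.callbacks import Callback


class _HookMarker:
    """`@pl_hook.<hook_name>` tags the decorated method with that hook name."""

    def __getattr__(self, hook_name: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            # accumulate hook names on the function itself
            fn._pl_hooks = getattr(fn, "_pl_hooks", set()) | {hook_name}
            return fn

        return decorator


pl_hook = _HookMarker()


class DecoratedCallback(Callback):
    """Dispatches each tagged hook to the decorated method, passing the hook name as `key`."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for name in list(vars(cls)):  # snapshot: setattr below adds new attributes
            fn = vars(cls)[name]
            for hook_name in getattr(fn, "_pl_hooks", ()):

                def make_hook(f: Callable, key: str) -> Callable:
                    @functools.wraps(f)
                    def hook(self, *args: Any, **kw: Any) -> Any:
                        return f(self, key, *args, **kw)

                    return hook

                setattr(cls, hook_name, make_hook(fn, hook_name))
```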
The current alternative to this would be:
```python
class MyGPUStatsMonitor(Callback):
    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, "on_train_batch_start")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, "on_train_batch_end")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    # ...
```
Or, using a shared helper function:
```python
class MyGPUStatsMonitor(Callback):
    def _log_batch_stats(self, key: str, trainer: "pl.Trainer") -> None:
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, key)
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        self._log_batch_stats("on_train_batch_start", trainer)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        self._log_batch_stats("on_train_batch_end", trainer)

    # ...
```
By providing utilities that make it easy to get device metrics, and by making it faster (fewer lines of code) to create a `Callback`, it becomes less of a pain to migrate away from `DeviceStatsMonitor` when you need to customize.
cc @justusschock @awaelchli @akihironitta @rohitgr7 @tchaton @Borda @kaushikb11 @ananthsub @daniellepintz @edward-io @mauvilsa