Improved control of device stats callbacks #11796

@EricWiener

Description

Proposed refactor

Split device stats monitoring into a separate callback per device, each sub-classing DeviceStatsMonitor. This preserves the desired change from #9032 to consolidate the interface while also allowing fine-grained control.

Motivation

With #9032, all the accelerators were combined under a single DeviceStatsMonitor callback. This consolidated the API, but it also removed fine-grained control. For instance, the GPUStatsMonitor that is now being deprecated used to provide fine-grained control over the nvidia-smi stats that were tracked: https://github.com/PyTorchLightning/pytorch-lightning/blob/86b177ebe5427725b35fde1a8808a7b59b8a277a/pytorch_lightning/callbacks/gpu_stats_monitor.py#L87-L95

However, the new interface defaults to using torch memory stats (which provide less info than nvidia-smi): https://github.com/PyTorchLightning/pytorch-lightning/blob/86b177ebe5427725b35fde1a8808a7b59b8a277a/pytorch_lightning/accelerators/gpu.py#L73-L75

Regardless of whether GPU stats are changed to default to nvidia-smi, the user no longer has control over which metrics are monitored. Additionally, if #11795 is merged, CPU stats will be monitored on top of the stats for whatever accelerator is used.

Pitch 1

If the user were allowed to specify which stats to monitor, the callback would need to look like:

DeviceStatsMonitor(
    cpu_stats: Optional[Union[bool, Set[str]]] = None,
    gpu_stats: Optional[Union[bool, Set[str]]] = None,
    tpu_stats: Optional[Union[bool, Set[str]]] = None,
)

This builds on top of the suggestion in #11253 (comment) where the values allowed are:

  • None: To know if the user passed a value
  • bool: To easily enable/disable a default set of stats
  • Set[str]: To enable and show this specific set of stats.
# enable cpu stats + stats for the current accelerator
DeviceStatsMonitor(cpu_stats=True)

# enable these cpu stats + stats for the current accelerator
DeviceStatsMonitor(cpu_stats={"ram", "temp"})

A downside of this design is that the arguments get no validation via type checking or auto-complete.
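A minimal sketch (a hypothetical helper, not part of any real API) of how the callback could normalize the None / bool / Set[str] values above into a concrete set of metric names:

```python
from typing import Optional, Set, Union


def resolve_stats(
    value: Optional[Union[bool, Set[str]]],
    defaults: Set[str],
) -> Set[str]:
    """Normalize a user-supplied stats argument into a set of metric names."""
    if value is None:
        return set(defaults)  # user passed nothing: fall back to the defaults
    if value is False:
        return set()          # explicitly disabled
    if value is True:
        return set(defaults)  # explicitly enabled with the default set
    return set(value)         # explicit set of metric names


# "ram" and "temp" are placeholder metric names, for illustration only
print(resolve_stats(True, {"ram", "temp"}))
print(resolve_stats({"ram"}, {"ram", "temp"}))
```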

Pitch 2

Have a common interface via a base class:

class DeviceStatsMonitor(Callback):

For each device, sub-class DeviceStatsMonitor and allow for configuration:

class GPUStatsMonitor(DeviceStatsMonitor):
    def __init__(
        self,
        memory_utilization: bool = True,
        gpu_utilization: bool = True,
        intra_step_time: bool = False,
        inter_step_time: bool = False,
        fan_speed: bool = False,
        temperature: bool = False,
    ):
        ...

Add a CPUStatsMonitor.

To track both CPU stats and another accelerator's stats, you can now pass:

trainer = Trainer(callbacks=[CPUStatsMonitor(), GPUStatsMonitor()])
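The shape of Pitch 2 can be sketched standalone (these are simplified stand-ins, not the real pytorch_lightning classes): the base class owns the shared stats-fetching flow, and each device subclass only supplies its own stats and configuration.

```python
class DeviceStatsMonitor:
    """Base class: shared flow for fetching device stats."""

    def get_stats(self) -> dict:
        raise NotImplementedError

    def on_train_batch_end(self, *args) -> dict:
        # shared behavior: fetch device-specific stats (a real callback
        # would log them via trainer.logger instead of returning them)
        return self.get_stats()


class CPUStatsMonitor(DeviceStatsMonitor):
    def get_stats(self) -> dict:
        return {"cpu/ram_percent": 42.0}  # placeholder value


class GPUStatsMonitor(DeviceStatsMonitor):
    def __init__(self, memory_utilization: bool = True):
        self.memory_utilization = memory_utilization

    def get_stats(self) -> dict:
        return {"gpu/memory_utilization": 0.5} if self.memory_utilization else {}


# both callbacks can then be passed to the Trainer side by side
callbacks = [CPUStatsMonitor(), GPUStatsMonitor()]
```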

Pitch 3

Use a single DeviceStatsMonitor with the option to specify cpu_stats=True and provide sensible default metrics. This will be a friendly generic interface for quickly tracking stats.

Users who need more should be able to call get_device_stats() on the accelerator class, and get_device_stats should take optional arguments for configuration: get_device_stats() with no arguments should be sufficient, but it should also accept additional optional arguments that vary per device. This allows the stats to be customized without making each device callback unique and highly configurable.
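One possible shape for such a get_device_stats (use_nvidia_smi and metrics are assumed parameter names for illustration, not a confirmed API):

```python
from typing import Any, Dict, List, Optional


def get_device_stats(
    use_nvidia_smi: bool = False,
    metrics: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """A no-argument call works; optional per-device arguments customize it."""
    source = "nvidia-smi" if use_nvidia_smi else "torch.cuda"
    wanted = metrics if metrics is not None else ["memory.used"]  # assumed default
    # a real implementation would query the device; this only shows the shape
    return {name: f"<{name} via {source}>" for name in wanted}


print(get_device_stats())  # defaults suffice
print(get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization"]))
```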

Currently (in my opinion) it is a pain to write a Callback, since you have to override multiple hooks even when you want the same or similar behavior in each one. I instead propose adding a new DecoratedCallback class, derived from Callback, that lets you use decorators to specify which hooks a method handles, avoiding a pile of one-line functions. I also think _prefix_metric_keys should be made a public utility.

The user could now do:

class MyGPUStatsMonitor(DecoratedCallback):
    @pl_hook.on_train_batch_start
    @pl_hook.on_train_batch_end
    @pl_hook.on_val_batch_start
    @pl_hook.on_val_batch_end
    def log_batch_stats(
        self,
        key: str,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ):
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, key)
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
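One way DecoratedCallback and the pl_hook decorators could be implemented (a sketch of an assumed design, not existing Lightning code): each decorator tags the method with a hook name, and DecoratedCallback generates a thin wrapper per hook that passes the hook name in as key.

```python
class pl_hook:
    """Decorators that tag a method with the hook names it should serve."""

    @staticmethod
    def _tag(name):
        def deco(fn):
            fn._pl_hooks = getattr(fn, "_pl_hooks", []) + [name]
            return fn
        return deco


pl_hook.on_train_batch_start = pl_hook._tag("on_train_batch_start")
pl_hook.on_train_batch_end = pl_hook._tag("on_train_batch_end")


class DecoratedCallback:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # generate one wrapper per tagged hook, forwarding the hook name as `key`
        for attr in list(vars(cls).values()):
            for hook_name in getattr(attr, "_pl_hooks", []):
                def wrapper(self, *args, _fn=attr, _key=hook_name, **kw):
                    return _fn(self, _key, *args, **kw)
                setattr(cls, hook_name, wrapper)


class MyMonitorSketch(DecoratedCallback):
    @pl_hook.on_train_batch_start
    @pl_hook.on_train_batch_end
    def log_batch_stats(self, key, *args, **kwargs):
        return key  # a real callback would fetch, prefix, and log stats here


print(MyMonitorSketch().on_train_batch_start())  # on_train_batch_start
```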

The current alternative to this would be:

class MyGPUStatsMonitor(Callback):
    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ):
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, "on_train_batch_start")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ):
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, "on_train_batch_end")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    # ...

Or, if using a shared helper function:

class MyGPUStatsMonitor(Callback):
    def _log_batch_stats(self, key: str, trainer: "pl.Trainer"):
        stats = GPUAccelerator.get_device_stats(use_nvidia_smi=True, metrics=["gpu.utilization", ...])
        prefixed_device_stats = prefix_metric_keys(stats, key)
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ):
        self._log_batch_stats("on_train_batch_start", trainer)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ):
        self._log_batch_stats("on_train_batch_end", trainer)

    # ...

By providing utilities that make device metrics easy to fetch and by reducing the boilerplate needed to create a Callback, migrating away from DeviceStatsMonitor becomes less painful when customization is needed.

cc @justusschock @awaelchli @akihironitta @rohitgr7 @tchaton @Borda @kaushikb11 @ananthsub @daniellepintz @edward-io @mauvilsa
