Merged

Commits (showing changes from 14 of 30 commits)
8ca6317
add updated profiler
jeromeku Jun 13, 2024
e7c9102
tests/test_profiler.py
jeromeku Jun 13, 2024
27434e6
update llama2 7B configs and recipes with profiler
jeromeku Jun 13, 2024
c7b57ee
linting
jeromeku Jun 13, 2024
04dc380
update configs and recipes
jeromeku Jun 14, 2024
b647ce7
remove deprecated profiler
jeromeku Jun 14, 2024
7671ab8
remove reference profiler setup
jeromeku Jun 14, 2024
ff5ef2e
add documentation for profiler trace_handler
jeromeku Jun 14, 2024
23c2147
always return profiler config from setup_torch_profiler
jeromeku Jun 14, 2024
f9356c5
remove should_profile
jeromeku Jun 14, 2024
0342a7e
update tests with profiler cfg
jeromeku Jun 14, 2024
2d3dfdc
lint
jeromeku Jun 14, 2024
d16e942
fix recipe _set_profiler signature
jeromeku Jun 14, 2024
8fe3138
re-run lint
jeromeku Jun 15, 2024
230e2d3
surface default profiler opts in recipe
jeromeku Jun 15, 2024
0ab1f76
update scheduler setup
jeromeku Jun 17, 2024
e072c7c
update tests
jeromeku Jun 17, 2024
fd3b49b
lint
jeromeku Jun 17, 2024
3d3c4a2
simplify profiler check
jeromeku Jun 17, 2024
9e840f4
fix live docs
RdoubleA Jun 17, 2024
31945de
refactor profiler setup
jeromeku Jun 18, 2024
8add75a
fixup lora config
jeromeku Jun 18, 2024
84bb1a0
resolve merge conflicts
jeromeku Jun 18, 2024
a0b07e5
fix docs
jeromeku Jun 18, 2024
5354f62
add component to config
jeromeku Jun 18, 2024
c371429
change | to Union
jeromeku Jun 20, 2024
598778a
update configs with new profiler API
jeromeku Jun 25, 2024
6602b37
run precommit hooks
jeromeku Jun 25, 2024
9841323
fix profiler docstring
jeromeku Jun 25, 2024
21161ea
fix profiler test
jeromeku Jun 25, 2024
2 changes: 1 addition & 1 deletion docs/source/api_ref_utilities.rst
@@ -87,7 +87,7 @@ of your finetuning job.

get_memory_stats
log_memory_stats
profiler
setup_torch_profiler

.. _metric_logging_label:

25 changes: 25 additions & 0 deletions recipes/configs/llama2/7B_lora.yaml
@@ -80,3 +80,28 @@ log_peak_memory_stats: False
device: cuda
dtype: bf16
enable_activation_checkpointing: False

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  CPU: True
  CUDA: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options
  schedule:
    wait: 1
    warmup: 1
    active: 1
    repeat: 1
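For readers who want to see what these options translate to, here is a minimal sketch of the roughly equivalent raw `torch.profiler.profile` call; the actual wiring, defaults, and trace handler live in `setup_torch_profiler`, and the output path below is illustrative:

```
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# Rough equivalent of the YAML options above; setup_torch_profiler may apply
# additional defaults or overrides (e.g. its own trace handler).
profiler = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=1, repeat=1),
    profile_memory=False,
    with_stack=False,
    record_shapes=True,
    with_flops=False,
    on_trace_ready=tensorboard_trace_handler("./profiling_outputs"),  # illustrative path
)
```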
24 changes: 22 additions & 2 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -80,9 +80,29 @@ device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.utils.profiler                    # (removed)
  enabled: False
  output_dir: ${output_dir}/torchtune_perf_tracing.json    # (removed)

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  CPU: True
  CUDA: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options
  schedule:
    wait: 1
    warmup: 1
    active: 1
    repeat: 1
127 changes: 126 additions & 1 deletion recipes/lora_finetune_distributed.py
@@ -35,7 +35,12 @@
    validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

from torchtune.utils import (
    DEFAULT_TRACE_OPTS,
    FakeProfiler,
    PROFILER_KEY,
    setup_torch_profiler,
)
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -270,6 +275,118 @@ def setup(self, cfg: DictConfig) -> None:
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns FakeProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None), log_cfg=True)

    def _setup_profiler(
        self, cfg_profiler: DictConfig, log_cfg: bool = False
    ) -> torch.profiler.profile | FakeProfiler:
        """
        Parses the `profiler` section of the top-level `cfg` and sets up the profiler.

        Args:
            cfg_profiler: DictConfig - `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`)
            log_cfg: bool - whether to log the profiler config after profiler setup, which sets defaults and possibly
                overrides certain profiling options.

        NOTE: Since not all settings of the profiler can be parsed from the returned profiler object
        (such as the `schedule`), `log_cfg` can be used for easy logging / debugging of all profiler options post setup.

        Returns:
            profiler: torch.profiler.profile | FakeProfiler - FakeProfiler is a nullcontext with no-op methods
                for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` when the
                profiler is not enabled, so that the instrumented training loop does not need to change when
                profiling is disabled.

        The profiler config can be provided in configs under the `profiler` key with the following layout:
        ```
        profiler:
          enabled: bool

          # Output directory of trace artifacts
          output_dir: str

          # `torch.profiler.ProfilerActivity` types to trace
          CPU: bool
          CUDA: bool

          # Trace options
          profile_memory: bool
          with_stack: bool
          record_shapes: bool
          with_flops: bool

          # `torch.profiler.schedule` args
          schedule:
            wait: int
            warmup: int
            active: int
            repeat: int
        ```
        """
        # Check whether the `profiler` key is present in the config and that it is not empty;
        # if it is present, check that `enabled = True`
        if (cfg_profiler is not None and len(cfg_profiler) > 0) and cfg_profiler.get(
            "enabled", True
        ):
            enabled = True
        else:
            if self._is_rank_zero:
                log.info(" Profiling disabled.")
            return FakeProfiler()
Collaborator: very clean 👌
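For context, a minimal sketch of what a no-op stand-in like `FakeProfiler` could look like, based only on the `start`/`stop`/`step` contract described in the docstring above (the name `NoOpProfiler` and the details are hypothetical; torchtune's actual implementation may differ):

```
import contextlib


class NoOpProfiler(contextlib.nullcontext):
    """Drop-in stand-in for torch.profiler.profile when profiling is disabled."""

    def start(self) -> None:
        # No-op: the training loop can call start() unconditionally
        pass

    def stop(self) -> None:
        # No-op counterpart to start()
        pass

    def step(self) -> None:
        # No-op: called once per training step by the instrumented loop
        pass
```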


        # Set up profiler activities
Contributor: I believe all of these configs could be moved into the setup_torch_profiler util as optional args, and then setup_torch_profiler could be called directly with config.instantiate.
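A hedged sketch of what that suggestion could look like, assuming `setup_torch_profiler` accepted all of these options as keyword arguments (the `_component_` wiring below is hypothetical and is not what this diff implements):

```
from omegaconf import OmegaConf
from torchtune import config

# Hypothetical profiler section pointing straight at the setup util
cfg_profiler = OmegaConf.create(
    {
        "_component_": "torchtune.utils.setup_torch_profiler",
        "enabled": True,
        "cpu": True,
        "cuda": True,
        "profile_memory": False,
        "with_stack": False,
        "record_shapes": True,
        "with_flops": False,
    }
)

# config.instantiate resolves _component_ and forwards the remaining keys as kwargs,
# i.e. roughly setup_torch_profiler(enabled=True, cpu=True, ...)
profiler, profiler_cfg = config.instantiate(cfg_profiler)
```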

        cpu = cfg_profiler.get("CPU", False)
        cuda = cfg_profiler.get("CUDA", False)
        profile_memory = cfg_profiler.get(
            "profile_memory", DEFAULT_TRACE_OPTS["profile_memory"]
        )
        with_stack = cfg_profiler.get("with_stack", DEFAULT_TRACE_OPTS["with_stack"])
        record_shapes = cfg_profiler.get(
            "record_shapes", DEFAULT_TRACE_OPTS["record_shapes"]
        )
        with_flops = cfg_profiler.get("with_flops", DEFAULT_TRACE_OPTS["with_flops"])
        output_dir = cfg_profiler.get("output_dir", None)

        # Parse schedule specific args
        schedule_cfg = cfg_profiler.get("schedule", None)

        if schedule_cfg is None:
Collaborator: nit: you could just change the line above and remove this if-else logic:

schedule_cfg = cfg_profiler.get("schedule", DEFAULT_SCHEDULE)

If you take my suggestion below as well, then when schedule_cfg == DEFAULT_SCHEDULE nothing will be replaced and it becomes a no-op.

            wait = None
            warmup = None
            active = None
            repeat = None
        else:
Collaborator: I'm a bit confused here: if the user does not specify the schedule, the default gets used, but if they have a schedule field whose parameters are empty, then wait, etc. are set to None. If the profiler is provided in the config and is enabled, don't we want the schedule to always be the default unless overridden? Or am I misunderstanding?

Contributor (author): fixed

            wait = schedule_cfg.get("wait", None)
            warmup = schedule_cfg.get("warmup", None)
            active = schedule_cfg.get("active", None)
            repeat = schedule_cfg.get("repeat", None)
Contributor: Do these fields have to be nested under schedule_cfg? (I get that it's natural for a Hydra-like instantiate API, but if we are just passing them as flat args anyway, I don't know how much it buys us.)


        # Delegate setup of actual profiler and optionally return updated profiler config
        profiler, profiler_cfg = setup_torch_profiler(
            enabled=enabled,
            cpu=cpu,
            cuda=cuda,
            profile_memory=profile_memory,
            with_stack=with_stack,
            record_shapes=record_shapes,
            with_flops=with_flops,
            wait=wait,
            warmup=warmup,
            active=active,
            repeat=repeat,
            output_dir=output_dir,
        )

        if self._is_rank_zero:
            if log_cfg:
                log.info(f" Profiler config after instantiation: {profiler_cfg}")
            else:
                log.info(" Profiler instantiated.")

        return profiler

    def _setup_model(
        self,
        cfg_model: DictConfig,
@@ -563,6 +680,7 @@ def train(self) -> None:
        running_loss = 0
        num_tokens = 0

        self._profiler.start()
Contributor: It's really nice how cleanly this fits into our training loop. This is exactly what I was hoping for.

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):

@@ -646,9 +764,16 @@ def train(self) -> None:
                    num_tokens = 0
                    t0 = time.perf_counter()

                    # Step profiler
                    # Note that this is called within gradient accumulation block, hence
                    # will include multiple forward / backward passes if gradient accumulation > 1
                    self._profiler.step()

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

        self._profiler.stop()

    def cleanup(self) -> None:
        if self._is_rank_zero:
            self._metric_logger.close()
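As a closing note on how the `schedule` options interact with `step()`: each call to `prof.step()` advances the profiler's schedule by one step, so stepping once per optimizer step (as in the review comment above) makes the traced `active` window cover all accumulated forward/backward passes. A self-contained toy example of the pattern (generic `torch.profiler` usage, not recipe code):

```
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# wait=1, warmup=1, active=1: step 0 is skipped, step 1 warms up, step 2 is traced
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=1, repeat=1),
) as prof:
    for _ in range(4):
        x = torch.randn(64, 64)
        y = x @ x  # stand-in for a training step's work
        prof.step()
```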