Profiler v2 #1089
Changes from 14 commits
@@ -35,7 +35,12 @@

```python
    validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils import (
    DEFAULT_TRACE_OPTS,
    FakeProfiler,
    PROFILER_KEY,
    setup_torch_profiler,
)
from tqdm import tqdm

log = utils.get_logger("DEBUG")
```
@@ -270,6 +275,118 @@ def setup(self, cfg: DictConfig) -> None:

```python
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns FakeProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None), log_cfg=True)
```
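For context, the no-op stand-in described in the comment above can be as simple as a `nullcontext` subclass. The sketch below is illustrative only; the name and exact shape of torchtune's `FakeProfiler` may differ:

```python
import contextlib


class NoOpProfiler(contextlib.nullcontext):
    """Illustrative stand-in for torch.profiler.profile when profiling is disabled.

    Exposes no-op `start`, `stop`, and `step` methods so the training loop can
    call them unconditionally, whether or not a real profiler is active.
    """

    def start(self) -> None:
        pass

    def stop(self) -> None:
        pass

    def step(self) -> None:
        pass
```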
````python
    def _setup_profiler(
        self, cfg_profiler: DictConfig, log_cfg: bool = False
    ) -> torch.profiler.profile | FakeProfiler:
        """
        Parses the `profiler` section of the top-level `cfg` and sets up the profiler.

        Args:
            cfg_profiler: DictConfig - `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`)
            log_cfg: bool - whether to log the profiler config after profiler setup, which sets defaults and possibly
            overrides certain profiling options.

            NOTE: Since not all settings of the profiler can be parsed from the returned profiler object
            (such as the `schedule`), `log_cfg` can be used for easy logging / debugging of all profiler options post setup.

        Returns:
            profiler: torch.profiler.profile | FakeProfiler - FakeProfiler is a nullcontext with no-op methods
            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if the profiler is not enabled,
            so that the instrumented training loop does not need to change when profiling is disabled.
            profiler_cfg: Optional[DictConfig]

        The profiler config can be provided in configs under the `profiler` key with the following layout:
        ```
        profiler:
          enabled: bool

          # Output directory of trace artifacts
          output_dir: str

          # `torch.profiler.ProfilerActivity` types to trace
          CPU: bool
          CUDA: bool

          # Trace options
          profile_memory: bool
          with_stack: bool
          record_shapes: bool
          with_flops: bool

          # `torch.profiler.schedule` args
          schedule:
            wait: int
            warmup: int
            active: int
            repeat: int
        ```
        """
````
```python
        # Check whether the `profiler` key is present in the config and that it is not empty;
        # if it is present, check that `enabled = True`
        if (cfg_profiler is not None and len(cfg_profiler) > 0) and cfg_profiler.get(
            "enabled", True
        ):
            enabled = True
        else:
            if self._is_rank_zero:
                log.info(" Profiling disabled.")
            return FakeProfiler()

        # Set up profiler activities
        cpu = cfg_profiler.get("CPU", False)
        cuda = cfg_profiler.get("CUDA", False)
        profile_memory = cfg_profiler.get(
            "profile_memory", DEFAULT_TRACE_OPTS["profile_memory"]
        )
        with_stack = cfg_profiler.get("with_stack", DEFAULT_TRACE_OPTS["with_stack"])
        record_shapes = cfg_profiler.get(
            "record_shapes", DEFAULT_TRACE_OPTS["record_shapes"]
        )
        with_flops = cfg_profiler.get("with_flops", DEFAULT_TRACE_OPTS["with_flops"])
        output_dir = cfg_profiler.get("output_dir", None)

        # Parse schedule-specific args
        schedule_cfg = cfg_profiler.get("schedule", None)

        if schedule_cfg is None:
            wait = None
            warmup = None
            active = None
            repeat = None
        else:
            wait = schedule_cfg.get("wait", None)
            warmup = schedule_cfg.get("warmup", None)
            active = schedule_cfg.get("active", None)
            repeat = schedule_cfg.get("repeat", None)
```

Collaborator: I'm a bit confused here. If the user does not specify the schedule, the default will get used; but if they have a `schedule` field whose parameters are empty, then `wait`, etc. will be set to `None`. If the profiler is provided in the config and is enabled, don't we want the schedule to always be the default unless overridden? Or am I misunderstanding?

Contributor (author): Fixed.
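For readers following the thread above, the question is only about where schedule defaults get applied. One illustrative way to make a missing, empty, or partial `schedule` section all fall back to defaults (the default values below are made up for the example; the actual fix in the PR may differ) is:

```python
from typing import Any, Dict, Mapping, Optional

# Hypothetical defaults, for illustration only; torchtune's real defaults may differ.
DEFAULT_SCHEDULE: Dict[str, int] = {"wait": 5, "warmup": 5, "active": 2, "repeat": 1}


def resolve_schedule(schedule_cfg: Optional[Mapping[str, Any]]) -> Dict[str, int]:
    """Merge a possibly-missing or partially-specified schedule section with defaults."""
    schedule_cfg = schedule_cfg or {}
    return {key: schedule_cfg.get(key, default) for key, default in DEFAULT_SCHEDULE.items()}


# A missing section, an empty section, and a partial section all resolve sensibly:
assert resolve_schedule(None) == DEFAULT_SCHEDULE
assert resolve_schedule({}) == DEFAULT_SCHEDULE
assert resolve_schedule({"active": 4})["active"] == 4
```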
```python
        # Delegate setup of actual profiler and optionally return updated profiler config
        profiler, profiler_cfg = setup_torch_profiler(
            enabled=enabled,
            cpu=cpu,
            cuda=cuda,
            profile_memory=profile_memory,
            with_stack=with_stack,
            record_shapes=record_shapes,
            with_flops=with_flops,
            wait=wait,
            warmup=warmup,
            active=active,
            repeat=repeat,
            output_dir=output_dir,
        )

        if self._is_rank_zero:
            if log_cfg:
                log.info(f" Profiler config after instantiation: {profiler_cfg}")
            else:
                log.info(" Profiler instantiated.")

        return profiler

    def _setup_model(
        self,
        cfg_model: DictConfig,
```
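For readers unfamiliar with the helper, `setup_torch_profiler` presumably wraps the standard `torch.profiler` construction. A rough sketch of the kind of object it hands back, assumed from the arguments above rather than taken from torchtune's actual implementation (values illustrative):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# Activities mirror the CPU / CUDA flags parsed from the config.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

profiler = profile(
    activities=activities,
    # wait / warmup / active / repeat correspond to the `schedule` section of the config.
    schedule=schedule(wait=5, warmup=5, active=2, repeat=1),
    # Write Chrome/TensorBoard-compatible traces to the configured output_dir.
    on_trace_ready=tensorboard_trace_handler("./profiler_output"),
    profile_memory=False,
    with_stack=False,
    record_shapes=True,
    with_flops=False,
)
```

The recipe then drives this object with `start()`, `step()`, and `stop()`, exactly as shown in the training-loop hunks below.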
@@ -563,6 +680,7 @@ def train(self) -> None:

```python
        running_loss = 0
        num_tokens = 0

        self._profiler.start()
```

Contributor: It's really nice how cleanly this fits into our training loop. This is exactly what I was hoping for.
```python
        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
```
@@ -646,9 +764,16 @@ def train(self) -> None:

```python
                num_tokens = 0
                t0 = time.perf_counter()

                # Step profiler
                # Note that this is called within the gradient accumulation block, hence
                # will include multiple forward / backward passes if gradient accumulation > 1
                self._profiler.step()

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

        self._profiler.stop()

    def cleanup(self) -> None:
        if self._is_rank_zero:
            self._metric_logger.close()
```
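To make the `wait` / `warmup` / `active` / `repeat` semantics concrete, here is a small standalone `torch.profiler` loop, independent of the recipe, showing how `step()` advances the schedule (the workload and numbers are chosen only for illustration):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule


def dummy_train_step() -> None:
    # Stand-in for one optimizer step's worth of work.
    a = torch.randn(512, 512)
    _ = a @ a


# wait=1: skip 1 step; warmup=1: profile but discard 1 step; active=2: record 2 steps;
# repeat=1: run this cycle once.
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=lambda prof: print(prof.key_averages().table(row_limit=5)),
) as prof:
    for _ in range(8):
        dummy_train_step()
        prof.step()  # advances the schedule; the recipe calls self._profiler.step() analogously
```

In the recipe above, `start()` and `stop()` take the place of the context manager, and `step()` is invoked once per gradient-accumulated optimizer step.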