
Conversation

@jeromeku
Contributor

@jeromeku jeromeku commented Jun 13, 2024

@ebsmothers @RdoubleA

Opening a new PR in place of #1070, as the commit history from the prior PR was corrupted after rebasing on main.

Updates

  • Refactored setup_torch_profiler
    • this function now takes flattened args parsed from the top-level config
    • parsing of the top-level config is now handled by a _setup_profiler function defined in the recipe; see here for an example
  • Updated tests to reflect the modified API
  • Updated the following configs / recipes to demonstrate use of the updated profiler

Examples

Run single node LoRA finetune on Llama2

  tune run lora_finetune_single_device \
  --config llama2/7B_lora_single_device \
  max_steps_per_epoch=10 \
  gradient_accumulation_steps=1 \
  profiler.enabled=True \
  profiler.profile_memory=True \
  profiler.schedule.wait=2 \
  profiler.schedule.warmup=2 \
  profiler.schedule.active=1 \
  profiler.schedule.repeat=0 \
  profiler.output_dir=./profile-test

This will profile in a 5-step cycle for 10 total steps and export the trace results to profile-test.

Since the schedule yields two active steps across the 10-step run, there will be two corresponding sub-directories within the export directory, labeled iteration-5 and iteration-10 respectively, each containing: 1) a gzipped tensorboard / chrome trace, 2) a stack trace, and 3) an HTML memory timeline.
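
For reference, a minimal sketch (not the torchtune recipe itself) of how these overrides map onto the underlying torch.profiler API; train_step is a hypothetical stand-in for one batch of work:

from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

def train_step():
    # hypothetical stand-in for one batch of forward/backward/optimizer work
    pass

# wait=2, warmup=2, active=1, repeat=0 -> a 5-step cycle repeated for the whole run
prof_schedule = schedule(wait=2, warmup=2, active=1, repeat=0)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=prof_schedule,
    on_trace_ready=tensorboard_trace_handler("./profile-test"),
    profile_memory=True,
    with_stack=True,       # memory timelines / stack traces need this
    record_shapes=True,
) as prof:
    for step in range(10):
        train_step()
        prof.step()        # advances the wait/warmup/active schedule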

Run distributed LoRA finetune on Llama2

tune run --nnodes=1 --nproc_per_node=2 lora_finetune_distributed \
--config llama2/7B_lora \
max_steps_per_epoch=3 \
gradient_accumulation_steps=2 \
profiler.enabled=True \
profiler.profile_memory=True \
profiler.schedule.wait=1 \
profiler.schedule.warmup=1 \
profiler.schedule.active=1 \
profiler.output_dir=./profile-test

This should result in a single sub-directory within profile-test, iteration-3, containing the exported traces labeled by rank, along with the memory timeline for rank 0 (exporting memory timelines on multiple devices results in export errors).

Note that the chrome traces for the two examples will differ due to where the profiler is stepped within the recipe train loop: in the former, the profiler is stepped per batch, whereas in the latter it is stepped per optimizer step, which spans multiple batches when gradient accumulation > 1. This nuance is noted in the docstring for setup_torch_profiler.
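
A rough sketch of that difference (model, dataloader, optimizer, prof, and gradient_accumulation_steps are assumed to already exist; this is not the actual recipe code):

for idx, batch in enumerate(dataloader):
    loss = model(batch) / gradient_accumulation_steps
    loss.backward()

    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        prof.step()   # distributed recipe: one profiler step == one optimizer step (several batches)

    # prof.step()     # single-device recipe: stepping here instead makes one profiler step == one batch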

@pytorch-bot

pytorch-bot bot commented Jun 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1089

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 21161ea with merge base 2fe9a70:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 13, 2024
@jeromeku jeromeku mentioned this pull request Jun 13, 2024
@jeromeku jeromeku marked this pull request as ready for review June 13, 2024 17:56
self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None), log_cfg=True)

def _setup_profiler(
self, cfg: DictConfig, log_cfg: bool = False
Collaborator

nit: call this cfg_profiler, similar to the other setup methods; then it's very clear that it's the profiler field and not the full config

Contributor Author

Fixed

output_dir: ${output_dir}/torchtune_perf_tracing.json

#Output directory of trace artifacts
output_dir: /tmp/profiling_outputs
Collaborator

Would be great to reuse the output_dir we are already creating above so we don't make too many stray folders

Suggested change
output_dir: /tmp/profiling_outputs
output_dir: ${output_dir}/profiling_outputs

Contributor Author

Done

pbar.update(1)
pbar.set_description(
f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
with self._profiler as prof:
Collaborator

Why is this used as a context manager while in the other recipe we use start and stop? Is that purposeful, to show different ways to use it?

Contributor Author

Yes -- the idea is to showcase use both as a plain object and as a context manager.
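
For illustration, the two styles side by side (prof stands in for whatever the setup method returns; train_step and dataloader are hypothetical):

# Style 1: context manager (start/stop handled by __enter__/__exit__)
with prof:
    for batch in dataloader:
        train_step(batch)
        prof.step()

# Style 2: plain object with explicit start/stop around the loop
prof.start()
for batch in dataloader:
    train_step(batch)
    prof.step()
prof.stop()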

self._sampler.set_epoch(curr_epoch)

# Optionally profile the training loop
with self._profiler:
Collaborator

We don't need to define the profiler context twice, right?

Contributor Author

Yeah, this was a mistake. Removed the second context manager.

pass


def should_profile(cfg: DictConfig) -> bool:
Collaborator

Why not just check cfg.profiler.enabled directly and remove this util? This would reduce one level of indirection

Contributor Author

Done

log = get_logger("INFO")


def profiler(
Collaborator

It's not clear to me what the relationship is between this method and setup_torch_profiler; is this now deprecated? If so, let's just remove it entirely.

Contributor Author

Removed

output_dir,
metric="self_cuda_time_total",
row_limit=25,
):
Collaborator

Do you mind adding a docstring for this? :)

Contributor Author

Done

- if no schedule is specified, profiler will default to wait 10, warmup 5, active 3, repeat 1
- if a schedule is specified, profiler will validate that the schedule is valid and can be passed to `instantiate`
- certain options will be overridden (`with_stack` and `record_shapes`) depending on requirements of other options
- e.g., `profile_memory` requires `with_stack` and `record_shapes`
Collaborator

Thank you for listing out potential overrides, this is awesome 👏
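
As a sketch of the override behavior described above (flag and logger names assumed from the docstring excerpt, not the exact implementation):

if profile_memory and not (with_stack and record_shapes):
    log.warning(
        "profile_memory requires with_stack and record_shapes; overriding both to True"
    )
    with_stack = True
    record_shapes = True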

output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_dir = str(output_dir)
callback = partial(trace_handler, output_dir=output_dir)
Collaborator

A comment here for what the trace_handler is doing would be helpful

Contributor Author

Done
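
For context, a rough sketch (behavior assumed from the PR description, not the exact util) of what an on_trace_ready callback like trace_handler does: it runs each time an active window finishes and exports the collected trace.

from functools import partial

import torch

def trace_handler(prof: torch.profiler.profile, output_dir: str) -> None:
    # export a chrome trace for the active window that just ended,
    # keyed by the profiler's current step count
    prof.export_chrome_trace(f"{output_dir}/iteration-{prof.step_num}.json.gz")

callback = partial(trace_handler, output_dir="./profile-test")
# later passed to torch.profiler.profile(..., on_trace_ready=callback)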

@jeromeku
Contributor Author

@RdoubleA

Fixed all requests.

Collaborator

@RdoubleA RdoubleA left a comment


I think overall this looks really clean to me, I'll go ahead and run the CI. Tagging @ebsmothers and @rohan-varma for any additional thoughts here

Contributor

@ebsmothers ebsmothers left a comment


This is really shaping up nicely! A few more comments on how we expose defaults and consistency, but really love how non-invasive this is in the training loop itself.

on_trace_ready=callback,
)

profiler_cfg = DictConfig(
Contributor

Sorry I actually disagree with @RdoubleA's previous comment here. I think it's better to keep the return type of setup_torch_profiler simple, I don't expect to get back a Tuple when calling setup_torch_profiler, I expect to just get back a profiler. Anyways, don't want to cause too much thrash on this so don't worry about changing it back for now; I think that's something we can sort out separately. But do lmk if you have any thoughts on this

Collaborator

How about handling the config logging here instead of in the recipe? In fact, you can just pretty-print a regular dict instead of a DictConfig object. Then you don't need to return a Tuple.

Comment on lines 32 to 44
_DEFAULT_SCHEDULE: dict = {
"wait": 5,
"warmup": 5,
"active": 2,
"repeat": 1,
}

DEFAULT_TRACE_OPTS: dict = {
"profile_memory": False,
"with_stack": False,
"record_shapes": True,
"with_flops": False,
}
Contributor

I notice you use DEFAULT_TRACE_OPTS in the recipe to set defaults but not _DEFAULT_SCHEDULE. Is this just because it's more natural to use bools for default values than ints? My question is similar to my other comment below... ideally we can handle the default values consistently (I don't want to have half defined in the recipe and half defined in here).

Comment on lines 353 to 364
schedule_cfg = cfg_profiler.get("schedule", None)

if schedule_cfg is None:
wait = None
warmup = None
active = None
repeat = None
else:
wait = schedule_cfg.get("wait", None)
warmup = schedule_cfg.get("warmup", None)
active = schedule_cfg.get("active", None)
repeat = schedule_cfg.get("repeat", None)
Contributor

Do these fields have to be nested under schedule_cfg? (I get that it's natural for a Hydra-like instantiate API, but if we are just passing as flat args anyways idk how much it buys us)

running_loss = 0
num_tokens = 0

self._profiler.start()
Contributor

It's really nice how cleanly this fits into our training loop. This is exactly what I was hoping for

@jeromeku
Contributor Author

@ebsmothers

Surfaced all default profiler settings to the recipe per your suggestion.

warmup = DEFAULT_SCHEDULE["warmup"]
active = DEFAULT_SCHEDULE["active"]
repeat = DEFAULT_SCHEDULE["repeat"]
else:
Collaborator

I'm a bit confused here: if the user does not specify the schedule, the default will get used, but if they have a schedule field with empty parameters, then wait, etc. will be set to None. If the profiler is provided in the config and is enabled, don't we want the schedule to always fall back to the default unless overridden? Or am I misunderstanding?

Contributor Author

fixed

@codecov-commenter

Codecov Report

Attention: Patch coverage is 67.79661% with 95 lines in your changes missing coverage. Please review.

Project coverage is 66.60%. Comparing base (74fb5e4) to head (3d3c4a2).
Report is 9 commits behind head on main.

Files Patch % Lines
recipes/lora_finetune_single_device.py 0.00% 34 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 33 Missing ⚠️
torchtune/utils/_profiler.py 71.87% 27 Missing ⚠️
tests/test_profiler.py 99.23% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1089       +/-   ##
===========================================
+ Coverage   26.66%   66.60%   +39.94%     
===========================================
  Files         183      184        +1     
  Lines        8326     8615      +289     
===========================================
+ Hits         2220     5738     +3518     
+ Misses       6106     2877     -3229     

☔ View full report in Codecov by Sentry.

Collaborator

@RdoubleA RdoubleA left a comment


Thanks for pushing this through. A couple of nit comments, but nothing blocking and they can be done as follow-ups. There are a few checks failing on CI; lmk if you have issues fixing those. You can fix the lint by running pre-commit run --all-files locally.

if not enabled:
if self._is_rank_zero:
log.info(" Profiling disabled.")
return FakeProfiler()
Collaborator

very clean 👌
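
For readers following along, a minimal sketch of what a no-op stand-in like FakeProfiler could look like (method set assumed to mirror the pieces of torch.profiler.profile that the recipes call), which is what lets the training loop avoid enabled/disabled branches:

class FakeProfiler:
    """No-op drop-in returned when profiling is disabled (or on non-zero ranks)."""

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass

    def start(self):
        pass

    def stop(self):
        pass

    def step(self):
        pass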

if schedule_cfg is None:
schedule_cfg = DEFAULT_SCHEDULE
else:
schedule_cfg = {
Collaborator

nit: I think this can be simplified further:

for k in DEFAULT_SCHEDULE.keys():
    if k not in schedule_cfg:
        schedule_cfg[k] = DEFAULT_SCHEDULE[k]
        log.warning(f"Missing key {k} in schedule config, defaulting to {k} = {schedule_cfg[k]}")

# Check for schedule
# 1) If no schedule is provided, set to DEFAULT_SCHEDULE
# 2) else check for missing keys and warn if any are missing, setting these to defaults
if schedule_cfg is None:
Collaborator

nit: could just change the above line and remove this if-else logic

schedule_cfg = cfg_profiler.get("schedule", DEFAULT_SCHEDULE)

if you take my suggestion below as well, if schedule_cfg == DEFAULT_SCHEDULE then nothing will be replaced and it will be a no-op

@RdoubleA
Collaborator

@jeromeku went ahead and fixed the doc build since sphinx is a pain to deal with, though some bullets are not rendering correctly. mind taking a look? @joecummings

log.info(" Profiling disabled.")
return FakeProfiler()

# Set up profiler activities
Contributor

I believe all of these configs could be moved into the setup_torch_profiler util as optional args and then setup_torch_profiler could be called directly with config.instantiate

record_shapes: True
with_flops: False

`torch.profiler.schedule` options
Contributor

Missing a #

with_stack: bool = False,
record_shapes: bool = False,
with_flops: bool = False,
wait: Optional[int] = None,
Contributor

If we altered this to support config.instantiate(setup_torch_profiler) from the recipe, then all of the schedule args could be renamed to "schedule_warmup", "schedule_active", etc., allowing for a flat list of profiler configs. All of the config options that default to None would be optional in the config, and if they get a None value, they can be set to the default value in the init.

Contributor Author

What you're describing is more in line with my original PR, where most of the code in _setup_profiler was abstracted away from the user in setup_torch_profiler.

Per the comments from @RdoubleA and @ebsmothers, I refactored the original setup_torch_profiler such that most of the setup was extracted from setup_torch_profiler to the recipe, which admittedly results in more code duplication and config parsing on the part of the recipe writer / user.

Not sure what direction to take at this point.

Contributor

@pbontrager pbontrager Jun 17, 2024

I am sorry for the confusion. In the original version, there was a lot of config parsing happening inside of the util, which we don't want for a utility, since our principle is that each function should work as code and not assume that a config will be involved. So I agree with all the changes to remove the config parsing from the util. But I believe we can still get the same simplicity you originally proposed by defining all the arguments inside of the init instead of taking in a Dict. Something like this:

def setup_torch_profiler(
    *,
    enabled: bool = True,
    output_dir: str = "profiler_output",
    # `torch.profiler.ProfilerActivity` types to trace
    activity_types: List[str] = ACTIVITY_TYPES.keys(),
    # trace options passed to `torch.profiler.profile`
    profile_memory: bool = False,
    with_stack: bool = False,
    record_shapes: bool = True,
    with_flops: bool = True,
    # `torch.profiler.schedule` options
    schedule_wait: int = 5,
    schedule_warmup: int = 5,
    schedule_active: int = 2,
    schedule_repeat: int = 1,
    log_cfg: bool = False,
) -> ContextManager:
    ...

Then config.instantiate(setup_torch_profiler) could initialize it from the config automatically, and setup_torch_profiler would still work as a standalone function. Does this make sense? I'll tag @RdoubleA and @ebsmothers here to make sure I'm not contradicting them, because I don't want you to go back and forth on work.

Contributor

Yeah sorry for the thrash here @jeromeku. I think @pbontrager's last comment here generally makes sense to me. Basically we can define reasonable defaults for all params inside of setup_torch_profiler. That way we don't have to go through the work of parsing everything and setting defaults inside the recipe. Then the config would look like

profiler:
  _component_: torchtune.utils._profiler.setup_torch_profiler
  enabled: True
  cpu: True
  cuda: True
  ...

and in the recipe you can just call config.instantiate(cfg.profiler) in the recipe's _setup_profiler method. You then only have to check that the profiler field is defined, but by putting defaults in the setup_torch_profiler utility you don't really have to do a bunch of sanity checks on missing fields in the config.
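
A sketch of what that recipe-side hookup could look like (method and key names follow the discussion above rather than the final code; FakeProfiler is the no-op stand-in from this PR):

from torchtune import config

def _setup_profiler(self, cfg_profiler):
    # cfg_profiler is the `profiler:` block from the YAML config; defaults
    # live inside setup_torch_profiler, so missing keys need no handling here
    if cfg_profiler is None:
        return FakeProfiler()
    return config.instantiate(cfg_profiler)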

Contributor

@pbontrager pbontrager left a comment

Sorry for the late comments. This is a great addition and there has already been a lot of discussion on this PR, but I think it's important that we limit the amount of code this adds to the recipes, as it adds a lot of boilerplate for the user. If we make our profiler util handle all of its optional kwargs on init, then we should be able to directly utilize config.instantiate. @RdoubleA does this match with your understanding too?

torch.distributed.barrier()


class FakeProfiler:
Contributor

nit: I think we use Dummy throughout the library instead of Fake


enabled = cfg_profiler is not None and cfg_profiler.get("enabled", True)

if not enabled:
Contributor

The logs can be moved into the util as well. I think we'd want the util to take "enabled" as a kwarg as well so we can have it output the FakeProfiler if it's disabled or on the wrong rank.

@jeromeku
Contributor Author

@RdoubleA @pbontrager @ebsmothers
Refactored _setup_profiler in recipes to directly instantiate torchtune.utils.setup_torch_profiler.

Contributor

@pbontrager pbontrager left a comment

Thank you for the quick turnaround! I think this is looking good now; thank you for this great addition! (It should be good to go once the tests pass.)

@jeromeku
Contributor Author

@RdoubleA @ebsmothers @pbontrager

Any tips on how to fix the docs and recipe check failures?

@ebsmothers
Contributor

@jeromeku for the recipe failure I think it's because | is only available for typing from Python 3.10 onwards, and our CI jobs also run on 3.8 and 3.9. So for that maybe just switch to Union. Btw if you do have access to one or more GPUs and wanna run locally to get quicker signal you can run pytest tests -m integration_test to run the recipe tests (maybe you already did this and are just on python >= 3.10).

For the doc signal I am not sure but the triple backtick on L220 looks suspicious to me. Similarly there you can do a local build of the docs by following these instructions. (Admittedly some of the sphinx build errors can be quite annoying, if you are really banging your head against the wall here let me know and one of us can take a look at it.)

Also I guess linter signal is failing, just make sure to add newlines at the end of the config files for that one.
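
To make the typing point concrete (a generic example, not the PR's code): the X | Y union syntax in annotations requires Python 3.10+, while typing.Union / typing.Optional work on 3.8 and 3.9.

from typing import Optional

# works on Python 3.8/3.9/3.10+
def setup(wait: Optional[int] = None) -> Optional[int]:
    return wait

# equivalent, but only valid on Python >= 3.10:
# def setup(wait: int | None = None) -> int | None:
#     return wait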

@jeromeku
Contributor Author

jeromeku commented Jun 20, 2024

@ebsmothers

Thanks.

Linting is easy enough to fix (mistake on my part for not running pre-commit hooks on config files).

Regarding recipe tests, I'm getting failures on tests for recipes / configs that were not affected by changes in this PR:

FAILED tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results
FAILED tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss[llama3/8B_lora_single_device-llama3-tune-True]
FAILED tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss[llama3/8B_lora_single_device-llama3-tune-False]
FAILED tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[True-fp32]
FAILED tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[True-bf16]
FAILED tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[False-fp32]
FAILED tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[False-bf16]

Also, any config that uses the previous profiler config will no longer be valid, since the _component_ should now be torchtune.utils.setup_torch_profiler instead of torchtune.utils.profiler (shouldn't affect the tests, I don't think).

Should I go ahead and change these in all config files with a profiler section? (I only changed the ones for llama2/7B_lora and llama2/7B_lora_single_device as a demo.) I suppose any of the recipes which used the previous API will need to be changed as well.

@ebsmothers
Contributor

ebsmothers commented Jun 21, 2024

Also, any config that uses the previous profiler config will no longer be valid, since the component should now be torchtune.utils.setup_torch_profiler instead of torchtune.utils.profiler (shouldn't affect the tests, I don't think).
Should I go ahead and change these in all config files with a profiler section (I only changed the ones for llama2/7B_lora and llama2/7B_lora_single_device as a demo). I suppose any of the recipes which used the previous API will need to be changed as well.

@jeromeku good point, this will be BC-breaking for our LoRA recipes. Not the end of the world as we are actually planning to do a release soon (so now is actually the perfect time to get this in). I think updating any other configs containing profiler usage to point to the new API would be a good idea. I don't think any other recipes in the core repo use the profiler API though (but user recipes will need to change when they pull latest changes).

Re the failing tests, do you have a specific error message for any of them? I agree it's not caused by your changes so not a huge concern for you, but we do wanna get CI green to land. I can also take a pass on my end to repro (fwiw some tests run only locally and not in CI depending on which versions you have available, so you may see some test failures on our more bleeding-edge features locally)

Edit: To follow up on the test failures, I created a fresh conda env and I don't see any of these (though the eval failure is probably because lm_eval needs to be installed separately and is not part of our core install). Do you happen to have a more bleeding-edge version of ao installed? If so I suspect that's why some of the other ones are failing (our current install is pinned to 0.1 for the time being)

@jeromeku
Contributor Author

@ebsmothers

Updated all configs that have a profiler entry to use the new profiler API. Also, all integration tests are now passing (except for lm_eval, since I don't have that installed).

@ebsmothers
Contributor

@jeromeku thanks, this is looking good! I am gonna run a couple sanity checks on my end while waiting for the GPU tests to run. After that I think this should be good to merge.

@ebsmothers ebsmothers merged commit 52e3283 into meta-pytorch:main Jun 26, 2024
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
FlamingoPg pushed a commit to FlamingoPg/sgl-tune-eagle that referenced this pull request May 26, 2025