
Conversation


@IlyasMoutawwakil IlyasMoutawwakil commented Feb 26, 2025

What does this PR do?

This PR introduces upstream support for HPU torch device/backend:

  • HPU is the torch device name for Intel Gaudi accelerators, powerful and energy-efficient ASICs for AI workloads.
  • Gaudi1 has been available on AWS since 2021; Gaudi2/Gaudi3 are available on Intel Dev Cloud and soon on IBM Cloud.
  • The documentation for the torch HPU device is available here.

This PR focuses on enabling out-of-the-box support in eager mode (PT_HPU_LAZY_MODE=0), while optimum-habana will continue to provide optimized paths that make use of lazy mode and the advanced features of the SynapseAI software stack.
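
For illustration, out-of-the-box eager-mode usage on a Gaudi machine looks roughly like the sketch below; the habana_frameworks import paths follow the usual Habana PyTorch bridge layout and are an assumption here, not something introduced by this PR:

import os

# eager mode, as targeted by this PR (lazy mode remains the domain of optimum-habana);
# the flag must be set before the Habana bridge is imported
os.environ["PT_HPU_LAZY_MODE"] = "0"

import torch
import habana_frameworks.torch.hpu as hthpu  # exposes the "hpu" device to torch

assert hthpu.is_available()

device = torch.device("hpu")
x = torch.randn(4, 4, device=device)
y = (x @ x).relu()  # executed eagerly on the Gaudi accelerator
print(y.device)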

This is one of three related PRs:

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as draft February 27, 2025 16:03
@IlyasMoutawwakil
Member Author

@ArthurZucker @muellerzr PR is ready for review. I made sure the Trainer, FSDP, and DeepSpeed tests ran successfully on both Gaudi1 and Gaudi2, in single-device and multi-device settings.
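
For reference, a minimal single-device smoke test of that path could look like the following sketch; the tiny model id is the one reused later in this thread, and the dummy dataset and hyperparameters are illustrative rather than taken from the actual trainer/FSDP/DeepSpeed suites:

import torch
from torch.utils.data import Dataset

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

class DummyLMDataset(Dataset):
    """Random token ids reused as labels, just enough to drive a few training steps."""

    def __init__(self, vocab_size, n_samples=16, seq_len=8):
        self.data = torch.randint(0, vocab_size, (n_samples, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"input_ids": self.data[idx], "labels": self.data[idx].clone()}

model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-Llama3ForCausalLM")

args = TrainingArguments(
    output_dir="/tmp/hpu_trainer_smoke",
    per_device_train_batch_size=2,
    max_steps=5,
    report_to="none",
)

# On a Gaudi machine with this PR (and the Habana PyTorch bridge installed),
# Trainer/accelerate should pick up the "hpu" device without extra flags.
trainer = Trainer(model=model, args=args, train_dataset=DummyLMDataset(model.config.vocab_size))
trainer.train()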

Comment on lines 306 to 308
# the file doesn't exist in the repo
if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"):
    raise unittest.SkipTest("FSDP CPU offloading script not found!")
Member Author


Couldn't find this file; is this test still relevant?

Collaborator


no idea cc @muellerzr

Contributor


I think it's meant to be:

from functools import partial

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# verify we have FSDP activation checkpointing support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model_id = "HuggingFaceM4/tiny-random-Llama3ForCausalLM"

model = AutoModelForCausalLM.from_pretrained(model_id)

model.train()
model.gradient_checkpointing_enable()

# wrap the model (FSDP when launched with an FSDP accelerate config)
accelerator = Accelerator()
model = accelerator.prepare(model)

# apply non-reentrant activation checkpointing to every Llama decoder layer
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

print(model)

# dummy forward pass on device 0
rand_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)

model(rand_input)

Was referenced in #31161 but never actually added? 😅
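
(Side note: given the file name fsdp_cpu_offloading.py, the test presumably wants the CPU-offloading variant of the wrapper above; a guess at that variant, based only on the snippet in this thread, would be:)

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)

# same non-reentrant checkpoint wrapper as above, but offloading activations to CPU,
# which is what the script's file name suggests the test is meant to exercise
cpu_offload_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=True,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)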

Member Author


Should I leave it for another PR? The file path utils/testing_scripts/fsdp_cpu_offloading.py doesn't make sense in the transformers repo.

Collaborator

@ArthurZucker ArthurZucker left a comment


Nice! I'm just missing a bit of documentation on:

  • what HPU is
  • how anyone can run on HPU

But that's it!


Contributor

@muellerzr muellerzr left a comment


Thanks! Added a note about our apparently missing test file 👀


Collaborator

@ArthurZucker ArthurZucker left a comment


Let's go!

Contributor

@muellerzr muellerzr left a comment


Everything looks good from the Trainer side in my eyes. The only thing we may want is to add an accelerate import check to flag it as a requirement (the release will go live tonight).
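
For illustration, the kind of import/version guard being discussed usually looks like the sketch below; the helper name and placeholder minimum version are made up here, not what the PR actually added:

import importlib.metadata
import importlib.util

from packaging import version

MIN_ACCELERATE_VERSION = "0.0.0"  # placeholder only; the actual floor is the accelerate release mentioned in this thread

def is_accelerate_recent_enough(min_version: str = MIN_ACCELERATE_VERSION) -> bool:
    """Return True if accelerate is installed and at least `min_version`."""
    if importlib.util.find_spec("accelerate") is None:
        return False
    installed = version.parse(importlib.metadata.version("accelerate"))
    return installed >= version.parse(min_version)

if not is_accelerate_recent_enough():
    raise ImportError("Training on HPU with Trainer requires a recent accelerate release.")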

@IlyasMoutawwakil
Member Author

IlyasMoutawwakil commented Mar 11, 2025

only thing we may want is to add an accelerate import check to flag

Added! Target version is 1.50, right? @muellerzr

@ArthurZucker ArthurZucker merged commit 89f6956 into main Mar 12, 2025
20 of 24 checks passed
@ArthurZucker ArthurZucker deleted the hpu-support branch March 12, 2025 08:08
@regisss regisss mentioned this pull request Apr 29, 2025