Merged
166 commits
4f462b0
init
IlyasMoutawwakil Feb 4, 2025
7b51103
style
IlyasMoutawwakil Feb 4, 2025
9d7376e
is_hpu_available
IlyasMoutawwakil Feb 4, 2025
069b88a
fix
IlyasMoutawwakil Feb 4, 2025
cd3cbb9
import habana_frameworks.torch.distributed.hccl
IlyasMoutawwakil Feb 4, 2025
2493abe
style
IlyasMoutawwakil Feb 4, 2025
32cbc88
test
IlyasMoutawwakil Feb 4, 2025
5fd4de2
initialize dist proc group
IlyasMoutawwakil Feb 4, 2025
7f72745
revert
IlyasMoutawwakil Feb 5, 2025
f66c5df
set backend to hccl only if hccl initialization sets a local rank
IlyasMoutawwakil Feb 5, 2025
2a4130d
force backend hccl and multi_hpu type when sure of distributed launch
IlyasMoutawwakil Feb 5, 2025
fa1bc44
style
IlyasMoutawwakil Feb 5, 2025
d3e24c5
pass accelerator tests
IlyasMoutawwakil Feb 6, 2025
00cc283
pass big modeling tests with bigger atol/rtol for accelerators
IlyasMoutawwakil Feb 6, 2025
97081da
fix hpu device count and skip tests requiring hpu:x
IlyasMoutawwakil Feb 6, 2025
ddcb3ca
hpu autocast
IlyasMoutawwakil Feb 6, 2025
6de389c
hpu rng_state
IlyasMoutawwakil Feb 7, 2025
ae9a76b
hpu launch
IlyasMoutawwakil Feb 7, 2025
5b8b0b2
hpu special device placement
IlyasMoutawwakil Feb 7, 2025
a2f8040
hpu launch
IlyasMoutawwakil Feb 7, 2025
6abecdd
rng state
IlyasMoutawwakil Feb 7, 2025
7bc37dc
distributed data loop tests
IlyasMoutawwakil Feb 7, 2025
ef1de61
enforce non contiguity after device memory allocation
IlyasMoutawwakil Feb 7, 2025
1b6905e
pass fsdp tests
IlyasMoutawwakil Feb 7, 2025
defe3fa
enforce pt_hpu_lazy_mode=0 when fsdp testing
IlyasMoutawwakil Feb 7, 2025
9551ce3
pass cli tests
IlyasMoutawwakil Feb 10, 2025
9c84fe7
pass and document grad sync tests
IlyasMoutawwakil Feb 10, 2025
6f00591
pass kwargs handler and autocast tests
IlyasMoutawwakil Feb 10, 2025
c94bfbd
memory utils
IlyasMoutawwakil Feb 10, 2025
61235d3
found source of int64 errors
IlyasMoutawwakil Feb 10, 2025
0896a50
skip some modeling utils tests
IlyasMoutawwakil Feb 10, 2025
e974758
enable int64
IlyasMoutawwakil Feb 10, 2025
ee08748
skip optimizer tests
IlyasMoutawwakil Feb 10, 2025
6f0fbe4
pass checkpointing tests
IlyasMoutawwakil Feb 10, 2025
c5c50c6
pass accelerator tests with safetensors main
IlyasMoutawwakil Feb 10, 2025
34010c9
more hpu stuff
IlyasMoutawwakil Feb 10, 2025
9f75a6e
Merge branch 'main' into hpu-support
IlyasMoutawwakil Feb 10, 2025
e80b484
style
IlyasMoutawwakil Feb 10, 2025
5cacc31
remove PT_HPU_LAZY_MODE and PT_ENABLE_INT64_SUPPORT as they should be…
IlyasMoutawwakil Feb 15, 2025
f006c4e
start testing on gaudi2
IlyasMoutawwakil Feb 17, 2025
19e652a
support fp16 on gaudi2
IlyasMoutawwakil Feb 17, 2025
40d22b1
add testing order
IlyasMoutawwakil Feb 17, 2025
eb37c43
custom hpu fsdp env dict
IlyasMoutawwakil Feb 17, 2025
dc4ca51
fix torch trace malloc
IlyasMoutawwakil Feb 17, 2025
74b307a
test ddp half precision comm hooks
IlyasMoutawwakil Feb 17, 2025
5a6d5ef
fix
IlyasMoutawwakil Feb 17, 2025
5a1c0c9
fix
IlyasMoutawwakil Feb 17, 2025
50d9e71
remove lower bound for hpu
IlyasMoutawwakil Feb 17, 2025
f0579e8
use 0.72 as lower bound
IlyasMoutawwakil Feb 17, 2025
dfc82ec
lower lower bound
IlyasMoutawwakil Feb 17, 2025
176e3d2
order deepspeed tests
IlyasMoutawwakil Feb 17, 2025
6c688d0
fix
IlyasMoutawwakil Feb 17, 2025
b078e90
deepspeed_use_hpu
IlyasMoutawwakil Feb 17, 2025
0dcb46a
assert non lazy mode with offloaded optimizer
IlyasMoutawwakil Feb 18, 2025
5abb1a4
make patching torch with habana frameworks the default
IlyasMoutawwakil Feb 18, 2025
b63a6fa
less of require_non_hpu
IlyasMoutawwakil Feb 18, 2025
36f8794
skip test_multi_device_merge_fsdp_weights for now as it halts
IlyasMoutawwakil Feb 18, 2025
ab5cbb0
skip another flaky test
IlyasMoutawwakil Feb 18, 2025
e318161
format
IlyasMoutawwakil Feb 18, 2025
0c040c3
use habana_visible_modules
IlyasMoutawwakil Feb 18, 2025
6f5977e
patch torch hpu device count
IlyasMoutawwakil Feb 18, 2025
f1e196f
avoid setting HABANA_VISIBLE_MODULES
IlyasMoutawwakil Feb 18, 2025
2772b68
don't play with habana visible devices/modules
IlyasMoutawwakil Feb 18, 2025
7d1ef62
only with hpu
IlyasMoutawwakil Feb 18, 2025
427c313
fixes and skips
IlyasMoutawwakil Feb 18, 2025
be91183
skip
IlyasMoutawwakil Feb 18, 2025
5c0cd84
fix device ids and add some todos
IlyasMoutawwakil Feb 19, 2025
ae1431a
skip offloading with generate()
IlyasMoutawwakil Feb 19, 2025
d383ea5
fix
IlyasMoutawwakil Feb 19, 2025
0b62d52
reduced atol/rtol for hpu
IlyasMoutawwakil Feb 19, 2025
f2504a5
fix
IlyasMoutawwakil Feb 19, 2025
f5cf0d5
tag deepspeed tests that should run first
IlyasMoutawwakil Feb 19, 2025
ac434c2
enable a test path that was skipped
IlyasMoutawwakil Feb 19, 2025
1501105
revert a test that was customized for gaudi1
IlyasMoutawwakil Feb 19, 2025
8b5708e
some patching to enable HABANA_VISIBLE_MODULES
IlyasMoutawwakil Feb 19, 2025
8935766
fix zero3 test
IlyasMoutawwakil Feb 19, 2025
d8301cd
misc
IlyasMoutawwakil Feb 19, 2025
6ce9e3a
test DTensor TP
IlyasMoutawwakil Feb 19, 2025
42775d2
remove gaudi1
IlyasMoutawwakil Feb 19, 2025
788e95f
test
IlyasMoutawwakil Feb 20, 2025
03b391e
style
IlyasMoutawwakil Feb 20, 2025
2247739
comment
IlyasMoutawwakil Feb 20, 2025
07ba582
pass pad_across_processes
IlyasMoutawwakil Feb 20, 2025
647dfab
require_fp16
IlyasMoutawwakil Feb 20, 2025
8e63b29
pass memory utils test
IlyasMoutawwakil Feb 20, 2025
6b1d131
test_ddp_comm_hook
IlyasMoutawwakil Feb 20, 2025
7803291
skip half precision comm hooks on hpu
IlyasMoutawwakil Feb 20, 2025
2883ca1
fix
IlyasMoutawwakil Feb 20, 2025
007d4a8
is_fp16_available
IlyasMoutawwakil Feb 20, 2025
9c12fae
fp16
IlyasMoutawwakil Feb 20, 2025
324d6df
tp as part of integration tests
IlyasMoutawwakil Feb 20, 2025
839c6be
fix
IlyasMoutawwakil Feb 20, 2025
3e548f4
write_basic_config
IlyasMoutawwakil Feb 20, 2025
f67a898
safetensors
IlyasMoutawwakil Feb 20, 2025
f449d3f
local sgd and masked_fill_fwd_i64
IlyasMoutawwakil Feb 20, 2025
79ef8a5
fix num_processes in test_load_states_by_steps
IlyasMoutawwakil Feb 20, 2025
f772b76
fp8 support
IlyasMoutawwakil Feb 24, 2025
6218cec
test
IlyasMoutawwakil Feb 24, 2025
31872f6
Merge branch 'main' into hpu-support
IlyasMoutawwakil Feb 24, 2025
610c68b
fix
IlyasMoutawwakil Feb 24, 2025
347db07
add a workflow
IlyasMoutawwakil Feb 25, 2025
5fc5a2a
Update src/accelerate/accelerator.py
IlyasMoutawwakil Feb 25, 2025
dc7a773
review comments
IlyasMoutawwakil Feb 25, 2025
9606f0d
ci
IlyasMoutawwakil Feb 25, 2025
6b77bc4
style
IlyasMoutawwakil Feb 25, 2025
d556021
comments
IlyasMoutawwakil Feb 26, 2025
e2fe2cc
test
IlyasMoutawwakil Feb 26, 2025
05e6861
habana_frameworks.torch
IlyasMoutawwakil Feb 26, 2025
ef6192c
patch device count
IlyasMoutawwakil Feb 26, 2025
59b51e5
fix
IlyasMoutawwakil Feb 26, 2025
c6731f5
fix
IlyasMoutawwakil Feb 26, 2025
66ec449
require_fp8
IlyasMoutawwakil Feb 26, 2025
28dae91
fix
IlyasMoutawwakil Feb 27, 2025
ec9c562
fix
IlyasMoutawwakil Feb 27, 2025
53f99c3
gaudi 1
IlyasMoutawwakil Feb 27, 2025
5f9928d
remove unnecessary
IlyasMoutawwakil Feb 27, 2025
ddbece5
fixed masked fill error in transformers
IlyasMoutawwakil Feb 28, 2025
72bd312
style
IlyasMoutawwakil Feb 28, 2025
506d07e
balanced_memory pass on hpu
IlyasMoutawwakil Mar 3, 2025
ae67bcc
remove for now
IlyasMoutawwakil Mar 3, 2025
405b857
run first
IlyasMoutawwakil Mar 4, 2025
27be94c
Apply suggestions from code review
IlyasMoutawwakil Mar 5, 2025
4e0e966
Merge branch 'main' into hpu-support
IlyasMoutawwakil Mar 5, 2025
e2a8d85
style after merge
IlyasMoutawwakil Mar 5, 2025
03e2646
Update src/accelerate/accelerator.py
IlyasMoutawwakil Mar 6, 2025
3ed87c1
Update src/accelerate/utils/transformer_engine.py
IlyasMoutawwakil Mar 6, 2025
2dcab3e
Merge branch 'main' into hpu-support
IlyasMoutawwakil Mar 6, 2025
55b0d3c
empty cache review comments
IlyasMoutawwakil Mar 6, 2025
bd2afc3
test_script.py error messages
IlyasMoutawwakil Mar 6, 2025
75e5b81
AccelerateTestCase for accelerator state cleanup
IlyasMoutawwakil Mar 6, 2025
e5dfad4
test
IlyasMoutawwakil Mar 7, 2025
ed84e7b
add gaudi1 workflow
IlyasMoutawwakil Mar 7, 2025
a05e54a
fp8 availability
IlyasMoutawwakil Mar 7, 2025
eb0b3a3
fix
IlyasMoutawwakil Mar 7, 2025
7b2650a
reduce batch size
IlyasMoutawwakil Mar 7, 2025
9b227d8
concurrency
IlyasMoutawwakil Mar 7, 2025
8cf20cd
check cuda as well
IlyasMoutawwakil Mar 7, 2025
7c4897b
nits and comments
IlyasMoutawwakil Mar 7, 2025
d0485f1
mark fsdp tests that require_fp16
IlyasMoutawwakil Mar 7, 2025
c37aefd
style
IlyasMoutawwakil Mar 7, 2025
bdae68d
mark deepspeed fp16 tests
IlyasMoutawwakil Mar 7, 2025
d919931
update image
IlyasMoutawwakil Mar 7, 2025
efd2a27
fix
IlyasMoutawwakil Mar 9, 2025
394b687
updated
IlyasMoutawwakil Mar 9, 2025
4f76d2c
better msgs
IlyasMoutawwakil Mar 9, 2025
b3dd375
skip pippy
IlyasMoutawwakil Mar 9, 2025
17d43ab
test
IlyasMoutawwakil Mar 9, 2025
db16287
test on 2 device
IlyasMoutawwakil Mar 9, 2025
e359c01
support up to 1% relative error in test_accelerate
IlyasMoutawwakil Mar 9, 2025
e9cfca4
skip hpu fp16
IlyasMoutawwakil Mar 9, 2025
ac41600
allow for 1 byte difference
IlyasMoutawwakil Mar 9, 2025
8571ef4
revert torch_device change
IlyasMoutawwakil Mar 9, 2025
3115ee4
style
IlyasMoutawwakil Mar 9, 2025
7c6a44a
skip memory release since it's flaky
IlyasMoutawwakil Mar 9, 2025
e8f9a48
add accelerator state cleanup to fixture
IlyasMoutawwakil Mar 9, 2025
3face36
fix
IlyasMoutawwakil Mar 9, 2025
06c1f53
atol
IlyasMoutawwakil Mar 9, 2025
75aaabd
fix
IlyasMoutawwakil Mar 9, 2025
21fca86
more rtol
IlyasMoutawwakil Mar 10, 2025
a99c297
equal grad test
IlyasMoutawwakil Mar 10, 2025
81a37be
revert
IlyasMoutawwakil Mar 10, 2025
92775af
pass pippy on gaudi2 and skip on gaudi1
IlyasMoutawwakil Mar 10, 2025
ce13eeb
enable sd 1.5 test with require fp16
IlyasMoutawwakil Mar 10, 2025
04983cc
added warning on memory release
IlyasMoutawwakil Mar 10, 2025
5efbe8c
don't log warning in memory release as it requires PartialState to be…
IlyasMoutawwakil Mar 10, 2025
4847474
Apply suggestions from code review
IlyasMoutawwakil Mar 11, 2025
5 changes: 3 additions & 2 deletions src/accelerate/accelerator.py
@@ -168,7 +168,8 @@

class Accelerator:
"""
Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training.
Creates an instance of an accelerator for distributed training (on multi-GPU, TPU, HPU) or mixed precision
training.

Args:
device_placement (`bool`, *optional*, defaults to `True`):
@@ -529,7 +530,7 @@ def __init__(
DistributedType.DEEPSPEED,
DistributedType.MEGATRON_LM,
):
if self.device.type in ["cpu", "xpu"]:
if self.device.type in ["cpu", "xpu", "hpu"]:
self.native_amp = True
else:
self.native_amp = is_bf16_available(True)
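
The hunk above extends the native-AMP check so that an `hpu` device, like `cpu` and `xpu`, takes the `self.native_amp = True` path for mixed precision. A minimal sketch of what this enables from user code, assuming a machine with habana_frameworks installed (the tiny model and data below are illustrative placeholders, not part of this PR):

import torch
from accelerate import Accelerator

# bf16 on HPU goes through torch's native autocast, per the cpu/xpu/hpu branch above.
accelerator = Accelerator(mixed_precision="bf16")
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(4, 8, device=accelerator.device)
with accelerator.autocast():
    loss = model(x).sum()
accelerator.backward(loss)
optimizer.step()
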
16 changes: 15 additions & 1 deletion src/accelerate/state.py
@@ -38,6 +38,7 @@
is_datasets_available,
is_deepspeed_available,
is_fp8_available,
is_hpu_available,
is_ipex_available,
is_mlu_available,
is_mps_available,
@@ -64,6 +65,9 @@
if is_npu_available(check_device=False):
import torch_npu # noqa: F401

if is_hpu_available(check_device=False):
import habana_frameworks.torch # noqa: F401

logger = logging.getLogger(__name__)


@@ -210,6 +214,7 @@ def __init__(self, cpu: bool = False, **kwargs):
and not torch.distributed.is_initialized()
):
torch.distributed.init_process_group(backend=self.backend, **kwargs)

# XPU and CPU require special env configs to be set
if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
dist_information = get_cpu_distributed_information()
@@ -290,6 +295,7 @@ def __init__(self, cpu: bool = False, **kwargs):
'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
"will do this automatically."
)

# Important: This should be the *only* code outside of `self.initialized!`
self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)

@@ -369,6 +375,7 @@ def wait_for_everyone(self):
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_CPU,
DistributedType.MULTI_HPU,
DistributedType.DEEPSPEED,
DistributedType.FSDP,
):
@@ -704,6 +711,8 @@ def default_device(self) -> torch.device:
return torch.device("cuda")
elif is_xpu_available():
return torch.device("xpu")
elif is_hpu_available():
return torch.device("hpu")
else:
return torch.device("cpu")

@@ -720,6 +729,7 @@ def _prepare_backend(
elif is_torch_xla_available():
backend = "xla"
distributed_type = DistributedType.XLA

elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
if is_mlu_available():
backend = "cncl"
@@ -732,6 +742,9 @@
elif is_npu_available():
backend = "hccl"
distributed_type = DistributedType.MULTI_NPU
elif is_hpu_available():
backend = "hccl"
distributed_type = DistributedType.MULTI_HPU
elif torch.cuda.is_available():
if backend is None:
backend = "nccl"
@@ -776,7 +789,7 @@ def set_device(self):
self.device = torch.device("cpu") if self._cpu else self.default_device
return
device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla"):
if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu"):
raise ValueError(
f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
)
@@ -911,6 +924,7 @@ def __init__(
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
]:
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
self.distributed_type = DistributedType.FSDP
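
Taken together, the state.py hunks make device and backend resolution HPU-aware: `default_device` now falls back cuda -> xpu -> hpu -> cpu, and a launch with `LOCAL_RANK` set on a machine where `is_hpu_available()` is true resolves to the `hccl` backend and `DistributedType.MULTI_HPU`. A sketch of how this surfaces to user code, assuming a Gaudi machine and processes started via `accelerate launch` or `torchrun`:

from accelerate import PartialState
from accelerate.utils import DistributedType

state = PartialState()
print(state.device)            # hpu on a Gaudi box; cuda/xpu/cpu elsewhere
print(state.distributed_type)  # DistributedType.MULTI_HPU under a multi-process launch
if state.distributed_type == DistributedType.MULTI_HPU:
    print(state.backend)       # "hccl", as picked in _prepare_backend above
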
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -94,6 +94,7 @@
is_deepspeed_available,
is_dvclive_available,
is_fp8_available,
is_hpu_available,
is_import_timer_available,
is_ipex_available,
is_lomo_available,
1 change: 1 addition & 0 deletions src/accelerate/utils/constants.py
@@ -85,4 +85,5 @@
"MULTI_MUSA",
"MULTI_XPU",
"MULTI_CPU",
"MULTI_HPU",
]
6 changes: 6 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -40,6 +40,7 @@
from .environment import parse_flag_from_env, str_to_bool
from .imports import (
is_cuda_available,
is_hpu_available,
is_mlu_available,
is_msamp_available,
is_musa_available,
@@ -528,6 +529,7 @@ class DistributedType(str, enum.Enum):
- **MULTI_MUSA** -- Distributed on multiple MUSAs.
- **MULTI_NPU** -- Distributed on multiple NPUs.
- **MULTI_XPU** -- Distributed on multiple XPUs.
- **MULTI_HPU** -- Distributed on multiple HPUs.
- **DEEPSPEED** -- Using DeepSpeed.
- **XLA** -- Using TorchXLA.
"""
@@ -545,6 +547,7 @@ class DistributedType(str, enum.Enum):
TP = "TP"
XLA = "XLA"
MEGATRON_LM = "MEGATRON_LM"
MULTI_HPU = "MULTI_HPU"


class SageMakerDistributedType(str, enum.Enum):
@@ -646,6 +649,7 @@ class DynamoBackend(str, BaseEnum):
TORCHXLA_TRACE_ONCE = "TORCHXLA_TRACE_ONCE"
IPEX = "IPEX"
TVM = "TVM"
HPU_BACKEND = "HPU_BACKEND"


class LoggerType(BaseEnum):
@@ -1695,6 +1699,8 @@ def __post_init__(self):
device = torch.cuda.current_device()
elif is_xpu_available():
device = torch.xpu.current_device()
elif is_hpu_available():
device = torch.hpu.current_device()
else:
raise RuntimeError(
"There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'."
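
The dataclasses.py changes are additive enum and bookkeeping updates: `DistributedType.MULTI_HPU`, `DynamoBackend.HPU_BACKEND`, and an `is_hpu_available()` branch in the `__post_init__` device lookup. A small illustration of the new members; the last line is an assumption about how the dynamo backend string is routed to torch.compile, not something this diff shows:

from accelerate import Accelerator
from accelerate.utils import DistributedType, DynamoBackend

# Both enums subclass str, so members compare equal to their string values.
assert DistributedType.MULTI_HPU == "MULTI_HPU"
assert DynamoBackend.HPU_BACKEND == "HPU_BACKEND"

# Assumption: this routes torch.compile to the Gaudi backend shipped with habana_frameworks.
accelerator = Accelerator(dynamo_backend="hpu_backend")
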
22 changes: 22 additions & 0 deletions src/accelerate/utils/imports.py
@@ -378,6 +378,28 @@ def is_npu_available(check_device=False):
return hasattr(torch, "npu") and torch.npu.is_available()


@lru_cache
def is_hpu_available(check_device=False):
"Checks if `torch_hpu` is installed and potentially if a HPU is in the environment"
if importlib.util.find_spec("habana_frameworks") is None:
return False

import habana_frameworks.torch # noqa: F401
import habana_frameworks.torch.distributed.hccl as hccl # noqa: F401

if check_device:
try:
import habana_frameworks.torch.utils.experimental as htexp

if htexp.hpu.is_available():
_ = htexp.hpu.device_count()
return True
return False
except RuntimeError:
return False
return True


@lru_cache
def is_xpu_available(check_device=False):
"""
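
`is_hpu_available` first checks that habana_frameworks can be found, then imports habana_frameworks.torch (which patches torch with the `torch.hpu` namespace) and, with `check_device=True`, also verifies that a device actually responds. A short usage sketch (the CPU fallback message is illustrative only):

import torch

from accelerate.utils import is_hpu_available

if is_hpu_available(check_device=True):
    # habana_frameworks.torch was imported inside the helper, so torch.hpu exists here.
    device = torch.device("hpu")
    print(f"Found {torch.hpu.device_count()} HPU device(s)")
else:
    device = torch.device("cpu")
    print("No usable HPU found; falling back to CPU.")
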