Commit 38ff8e4

Move get_world_size_and_rank to utils

1 parent: 9cfa288

17 files changed: +69, -48 lines
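The change is mechanical: every call site that previously used training.get_world_size_and_rank() now calls the same helper from the utils namespace. A minimal usage sketch of the new call site (the is_rank_zero pattern is the one used throughout the recipes below; the rest is illustrative, not taken from this diff):

from torchtune import utils

# Previously: from torchtune import training; training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()
is_rank_zero = rank == 0  # used to gate logging in the distributed recipes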

docs/source/api_ref_training.rst

Lines changed: 0 additions & 1 deletion

@@ -52,7 +52,6 @@ Utilities for enabling and working with distributed training.
 
     init_distributed
     is_distributed
-    get_world_size_and_rank
     gather_cpu_state_dict
 
 .. _ac_label:

docs/source/api_ref_utilities.rst

Lines changed: 1 addition & 0 deletions

@@ -18,3 +18,4 @@ Miscellaneous
     get_device
     get_logger
     torch_version_ge
+    get_world_size_and_rank

recipes/dev/early_exit_finetune_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -183,7 +183,7 @@ def __init__(self, cfg: DictConfig) -> None:
 
         # _is_rank_zero is used primarily for logging. In the future, the logger
         # should directly take care of this
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
         # Training cfg
@@ -646,7 +646,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -826,7 +826,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:

recipes/full_finetune_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -133,7 +133,7 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
         # Training cfg
@@ -619,7 +619,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -757,7 +757,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:

recipes/knowledge_distillation_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig) -> None:
                 "fp16 precision is not supported in this recipe. Please use fp32 or bf16."
             )
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
 
         self._is_rank_zero = rank == 0
 
@@ -646,7 +646,7 @@ def _setup_data(
         Map-style Datasets which fit into memory and an option for random shuffling.
         Samplers, iterable datasets, and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -815,7 +815,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         self._optimizer.zero_grad()

recipes/lora_dpo_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -131,7 +131,7 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
 
         self._is_rank_zero = rank == 0
 
@@ -492,7 +492,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -642,7 +642,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         self._optimizer.zero_grad()

recipes/lora_finetune_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -135,7 +135,7 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
 
         self._is_rank_zero = rank == 0
 
@@ -584,7 +584,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -746,7 +746,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         self._optimizer.zero_grad()

recipes/qat_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -144,7 +144,7 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
         self._is_rank_zero = rank == 0
 
         # Training cfg
@@ -591,7 +591,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -729,7 +729,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:

recipes/qat_lora_finetune_distributed.py

Lines changed: 3 additions & 3 deletions

@@ -149,7 +149,7 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
-        _, rank = training.get_world_size_and_rank()
+        _, rank = utils.get_world_size_and_rank()
 
         # _is_rank_zero is used primarily for logging. In the future, the logger
         # should directly take care of this
@@ -620,7 +620,7 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -784,7 +784,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = training.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
         # zero out the gradients before starting training
         self._optimizer.zero_grad()

tests/torchtune/training/test_distributed.py

Lines changed: 0 additions & 16 deletions

@@ -56,15 +56,6 @@ def _test_worker_fn(init_pg_explicit: bool) -> None:
             pg_backend == "gloo"
         ), f"Expected 'gloo' backend, but received {pg_backend}"
 
-    @staticmethod
-    def _test_world_size_with_cpu_device(expected_world_size: int) -> None:
-        training.init_distributed(backend="gloo")
-        world_size, _ = training.get_world_size_and_rank()
-        if world_size != expected_world_size:
-            raise AssertionError(
-                f"Expected different world size: received {world_size}, expected {expected_world_size}"
-            )
-
     def _test_launch_worker(
         self,
         get_pet_launch_config,
@@ -84,13 +75,6 @@ def test_init_from_env_dup(self, get_pet_launch_config) -> None:
         # trivial test case to ensure test passes with no exceptions
         assert True
 
-    def test_world_size_with_cpu(self, get_pet_launch_config) -> None:
-        desired_world_size = 4
-        lc = get_pet_launch_config(desired_world_size)
-        launcher.elastic_launch(lc, entrypoint=self._test_world_size_with_cpu_device)(
-            desired_world_size
-        )
-
     def test_validate_no_params_on_meta_device(self) -> None:
         with torch.device("meta"):
             model = torch.nn.Linear(3, 3)
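The CPU world-size test above is removed outright rather than relocated. Purely as a hedged sketch (hypothetical function name, assuming the helper keeps the same (world_size, rank) return shape at its new home), an equivalent check against the utils namespace could look like:

from torchtune import training, utils


def _check_world_size_with_cpu_device(expected_world_size: int) -> None:
    # Same gloo process-group setup the deleted test used
    training.init_distributed(backend="gloo")
    world_size, _ = utils.get_world_size_and_rank()  # moved from training.*
    if world_size != expected_world_size:
        raise AssertionError(
            f"Expected world size {expected_world_size}, received {world_size}"
        )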
