
Commit cb83655

Renamed parallelize_plan to tensor_parallel_plan (#2387)
1 parent: f67ccda


9 files changed (+15, -13 lines)


recipes/configs/llama3/70B_full.yaml

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system.
 
 # Parallelism
 tensor_parallel_dim: 1
-parallelize_plan:
+tensor_parallel_plan:
   _component_: torchtune.models.llama3.base_llama_tp_plan
 
 # Tokenizer
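
For context on what the renamed key points at: `torchtune.models.llama3.base_llama_tp_plan` resolves to a tensor parallel plan, i.e. a mapping from module names inside the model to `ParallelStyle` objects from torch.distributed. The sketch below is illustrative only; the module names are placeholders, not the ones defined by `base_llama_tp_plan`.

```python
# Illustrative sketch of what a tensor parallel plan is: a dict mapping
# module names to ParallelStyle objects. The names below are placeholders,
# not the ones used by torchtune's base_llama_tp_plan.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def example_tp_plan() -> dict:
    return {
        "tok_embeddings": RowwiseParallel(),  # shard the embedding table
        "output": ColwiseParallel(),          # shard the final projection
        # ...attention and MLP projections would be listed here as well
    }
```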

recipes/configs/llama3/70B_generation_distributed.yaml

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ output_dir: ./
 model:
   _component_: torchtune.models.llama3.llama3_70b
 
-parallelize_plan:
+tensor_parallel_plan:
   _component_: torchtune.models.llama3.base_llama_tp_plan
 
 # Transform arguments

recipes/configs/llama3_1/70B_full.yaml

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system.
 
 # Parallelism
 tensor_parallel_dim: 1
-parallelize_plan:
+tensor_parallel_plan:
   _component_: torchtune.models.llama3.base_llama_tp_plan
 
 # Tokenizer

recipes/configs/llama3_1/70B_generation_distributed.yaml

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ output_dir: ./
 model:
   _component_: torchtune.models.llama3_1.llama3_1_70b
 
-parallelize_plan:
+tensor_parallel_plan:
   _component_: torchtune.models.llama3.base_llama_tp_plan
 
 # Transform arguments

recipes/configs/llama3_3/70B_full.yaml

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system.
 
 # Parallelism
 tensor_parallel_dim: 1
-parallelize_plan:
+tensor_parallel_plan:
   _component_: torchtune.models.llama3.base_llama_tp_plan
 
 # Tokenizer

recipes/configs/llama3_3/70B_generation_distributed.yaml

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ output_dir: ./
 model:
   _component_: torchtune.models.llama3_3.llama3_3_70b
 
-parallelize_plan:
+tensor_parallel_plan:
   _component_: torchtune.models.llama3.base_llama_tp_plan
 
 # Transform arguments

recipes/dev/generate_v2_distributed.py

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ def setup(self, cfg: DictConfig) -> None:
         parallelize_module(
             model,
             tp_device_mesh,
-            parallelize_plan=config.instantiate(cfg.parallelize_plan),
+            parallelize_plan=config.instantiate(cfg.tensor_parallel_plan),
         )
 
         with training.set_default_dtype(self._dtype), self._device:
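
The call above follows the standard torch.distributed pattern: instantiate the plan from the config and hand it to `parallelize_module` over a "tp" device mesh. A minimal standalone sketch (not the recipe itself), assuming torch.distributed is already initialized, e.g. by torchrun:

```python
# Minimal sketch of the pattern used above. Assumes torch.distributed is
# already initialized and `tp_plan` is the instantiated tensor_parallel_plan.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_tensor_parallel(
    model: torch.nn.Module, tp_size: int, tp_plan: dict
) -> torch.nn.Module:
    # Build a 1-D tensor-parallel mesh and shard the model according to the plan.
    tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    return parallelize_module(model, tp_mesh, parallelize_plan=tp_plan)
```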

recipes/full_finetune_distributed.py

Lines changed: 6 additions & 4 deletions

@@ -145,11 +145,13 @@ def __init__(self, cfg: DictConfig) -> None:
         # Initialize distributed variables
         self.world_size, self.rank = utils.get_world_size_and_rank()
         self._is_rank_zero = self.rank == 0
-        self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None))
+        self.tensor_parallel_plan = config.instantiate(
+            cfg.get("tensor_parallel_plan", None)
+        )
         self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
-        if self.tensor_parallel_dim > 1 and self.parallelize_plan is None:
+        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
             raise ValueError(
-                "Parallelism plan need to be provided when tensor parallel is enabled."
+                "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
             )
         if self.world_size % self.tensor_parallel_dim != 0:
             raise ValueError(

@@ -549,7 +551,7 @@ def _setup_model(
         parallelize_module(
             model,
             device_mesh["tp"],
-            parallelize_plan=self.parallelize_plan,
+            parallelize_plan=self.tensor_parallel_plan,
         )
 
         # We currently have two versions of activation checkpointing in this recipe
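
The constructor change also tightens the guard: a plan is required whenever `tensor_parallel_dim > 1`, and the world size must divide evenly by the TP dimension. A standalone sketch of that validation, assuming an OmegaConf-style config like the YAML files above; `instantiate_plan` is a hypothetical stand-in for torchtune's `config.instantiate`, and the second error message is illustrative:

```python
# Standalone sketch of the guard performed in __init__. `instantiate_plan`
# stands in for torchtune's config.instantiate (hypothetical here); the
# second error message is illustrative, not the recipe's exact wording.
from omegaconf import DictConfig


def validate_tp_settings(cfg: DictConfig, world_size: int, instantiate_plan) -> None:
    tp_plan = instantiate_plan(cfg.get("tensor_parallel_plan", None))
    tp_dim = cfg.get("tensor_parallel_dim", 1)
    if tp_dim > 1 and tp_plan is None:
        raise ValueError(
            "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
        )
    if world_size % tp_dim != 0:
        raise ValueError("world_size must be divisible by tensor_parallel_dim")
```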

tests/recipes/test_full_finetune_distributed.py

Lines changed: 2 additions & 2 deletions

@@ -156,7 +156,7 @@ def test_loss_2d_parallel(
         tokenizer_path = Path(TOKENIZER_PATHS[model_type])
         ckpt_dir = ckpt_path.parent
         log_file = gen_log_file_name(tmpdir)
-        parallelize_plan = "torchtune.models.llama3.base_llama_tp_plan"
+        tp_plan = "torchtune.models.llama3.base_llama_tp_plan"
 
         # Config file needed for model conversion.
         write_hf_ckpt_config(ckpt_dir)

@@ -175,7 +175,7 @@
            tokenizer.path='{tokenizer_path}' \
            tokenizer.prompt_template=null \
            tensor_parallel_dim={tensor_parallel_dim} \
-           parallelize_plan._component_={parallelize_plan} \
+           tensor_parallel_plan._component_={tp_plan} \
            metric_logger.filename={log_file} \
        """.split()
        model_config = MODEL_TEST_CONFIGS[model_type]
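
The test exercises the rename through a dotted command-line override, `tensor_parallel_plan._component_=...`. The snippet below only illustrates how such dotted keys map onto a nested config, using OmegaConf's `from_dotlist`; torchtune's own CLI parsing may differ.

```python
# Illustration of the dotted-key convention used in the test override,
# via OmegaConf.from_dotlist. torchtune's actual CLI parsing may differ.
from omegaconf import OmegaConf

overrides = OmegaConf.from_dotlist(
    [
        "tensor_parallel_dim=2",
        "tensor_parallel_plan._component_=torchtune.models.llama3.base_llama_tp_plan",
    ]
)
print(OmegaConf.to_yaml(overrides))
# tensor_parallel_dim: 2
# tensor_parallel_plan:
#   _component_: torchtune.models.llama3.base_llama_tp_plan
```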
