Skip to content

Commit a89f2c4

Browse files
committed
Corrected TP and CP integration in validation; removed some repetitive integration tests
1 parent 5f3e434 commit a89f2c4

File tree

3 files changed

+40
-80
lines changed

3 files changed

+40
-80
lines changed

tests/integration_tests.py

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -513,45 +513,19 @@ def build_test_list():
513513
[
514514
[
515515
"--validation.enabled",
516-
"--validation.dataset c4_validation",
516+
"--validation.dataset c4_test",
517517
],
518518
],
519519
"Validation test no parallelism",
520520
"validation_no_parallel",
521521
ngpu=1,
522522
),
523-
OverrideDefinitions(
524-
[
525-
[
526-
"--validation.enabled",
527-
"--validation.dataset c4_validation",
528-
"--parallelism.data_parallel_shard_degree=1",
529-
"--parallelism.data_parallel_replicate_degree=4",
530-
]
531-
],
532-
"Validation test with DP",
533-
"validation_dp",
534-
ngpu=4,
535-
),
536-
OverrideDefinitions(
537-
[
538-
[
539-
"--validation.enabled",
540-
"--validation.dataset c4_validation",
541-
"--parallelism.data_parallel_shard_degree=2",
542-
"--parallelism.data_parallel_replicate_degree=2",
543-
]
544-
],
545-
"Validation test with FSDP",
546-
"validation_fsdp",
547-
ngpu=4,
548-
),
549523
OverrideDefinitions(
550524
[
551525
[
552526
"--checkpoint.enable_checkpoint",
553527
"--validation.enabled",
554-
"--validation.dataset c4_validation",
528+
"--validation.dataset c4_test",
555529
"--parallelism.data_parallel_shard_degree=2",
556530
"--parallelism.data_parallel_replicate_degree=2",
557531
]
@@ -560,37 +534,6 @@ def build_test_list():
560534
"validation_fsdp_checkpoint",
561535
ngpu=4,
562536
),
563-
OverrideDefinitions(
564-
[
565-
[
566-
"--validation.enabled",
567-
"--validation.dataset c4_validation",
568-
"--parallelism.data_parallel_shard_degree=2",
569-
"--parallelism.data_parallel_replicate_degree=1",
570-
"--parallelism.tensor_parallel_degree=2",
571-
"--parallelism.context_parallel_degree=2",
572-
]
573-
],
574-
"Validation test with FSDP, TP, CP",
575-
"validation_fsdp_tp_cp",
576-
ngpu=8,
577-
),
578-
OverrideDefinitions(
579-
[
580-
[
581-
"--checkpoint.enable_checkpoint",
582-
"--validation.enabled",
583-
"--validation.dataset c4_validation",
584-
"--parallelism.data_parallel_shard_degree=2",
585-
"--parallelism.data_parallel_replicate_degree=1",
586-
"--parallelism.tensor_parallel_degree=2",
587-
"--parallelism.context_parallel_degree=2",
588-
]
589-
],
590-
"Validation checkpoint test with FSDP, TP, CP",
591-
"validation_fsdp_tp_cp_checkpoint",
592-
ngpu=8,
593-
),
594537
]
595538
return integration_tests_flavors
596539

torchtitan/components/validate.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Generator
8+
79
import torch
810
import torch.nn as nn
911
from torch.distributed.fsdp import FSDPModule
10-
1112
from torch.distributed.tensor import DTensor
1213
from torchtitan.components.dataloader import BaseDataLoader
1314
from torchtitan.components.loss import LossFunction
@@ -52,6 +53,8 @@ def __init__(
5253
parallel_dims: ParallelDims,
5354
world_mesh: torch.distributed.DeviceMesh,
5455
loss_fn: LossFunction,
56+
validation_context: Generator[None, None, None],
57+
maybe_enable_amp: Generator[None, None, None],
5558
):
5659
self.job_config = job_config
5760
self.parallel_dims = parallel_dims
@@ -63,6 +66,8 @@ def __init__(
6366
dp_rank=dp_rank,
6467
tokenizer=tokenizer,
6568
)
69+
self.validation_context = validation_context
70+
self.maybe_enable_amp = maybe_enable_amp
6671

6772
@torch.no_grad()
6873
def validate(
@@ -76,44 +81,52 @@ def validate(
7681

7782
accumulated_losses = []
7883
device_type = utils.device_type
79-
num_val_steps = 0
84+
num_steps = 0
8085

8186
for input_dict, labels in self.validation_dataloader:
8287
if (
8388
self.job_config.validation.steps != -1
84-
and num_val_steps >= self.job_config.validation.steps
89+
and num_steps >= self.job_config.validation.steps
8590
):
8691
break
8792

8893
for k, v in input_dict.items():
8994
input_dict[k] = v.to(device_type)
90-
labels = labels.to(device_type)
91-
9295
inputs = input_dict["input"]
93-
predictions = model(inputs)
96+
labels = labels.to(device_type)
9497

95-
if self.parallel_dims.loss_parallel_enabled:
96-
if isinstance(predictions, torch.Tensor) and not isinstance(
97-
predictions, DTensor
98-
):
99-
predictions = DTensor.from_local(predictions, self.world_mesh["tp"])
100-
if isinstance(labels, torch.Tensor) and not isinstance(labels, DTensor):
101-
labels = DTensor.from_local(labels, self.world_mesh["tp"])
102-
loss = self.loss_fn(predictions, labels)
98+
optional_context_parallel_ctx = (
99+
dist_utils.create_context_parallel_ctx(
100+
cp_mesh=self.world_mesh["cp"],
101+
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
102+
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
103+
cp_no_restore_buffers={inputs, labels},
104+
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
105+
)
106+
if self.parallel_dims.cp_enabled
107+
else None
108+
)
109+
110+
with self.validation_context(optional_context_parallel_ctx):
111+
assert len(model_parts) == 1
112+
with self.maybe_enable_amp:
113+
predictions = model(inputs)
114+
loss = self.loss_fn(predictions, labels)
103115

104116
accumulated_losses.append(loss.detach())
105117

106-
num_val_steps += 1
118+
num_steps += 1
107119

108120
# Compute average loss
109121
loss = torch.sum(torch.stack(accumulated_losses))
122+
loss /= num_steps
110123
if self.parallel_dims.dp_cp_enabled:
111124
global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
112125
else:
113126
global_avg_loss = loss
114127

115128
logger.info(
116-
f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_val_steps} batches"
129+
f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_steps} batches"
117130
)
118131

119132
# Reshard after run forward pass
@@ -125,8 +138,6 @@ def validate(
125138
# Set model back to train mode
126139
model.train()
127140

128-
return {"validation_loss": global_avg_loss}
129-
130141

131142
def build_validator(
132143
job_config: JobConfig,
@@ -136,6 +147,8 @@ def build_validator(
136147
parallel_dims: ParallelDims,
137148
world_mesh: torch.distributed.DeviceMesh,
138149
loss_fn: LossFunction,
150+
validation_context: Generator[None, None, None],
151+
maybe_enable_amp: Generator[None, None, None],
139152
) -> BaseValidator:
140153
"""Build a simple validator focused on correctness."""
141154
return Validator(
@@ -146,4 +159,6 @@ def build_validator(
146159
parallel_dims=parallel_dims,
147160
world_mesh=world_mesh,
148161
loss_fn=loss_fn,
162+
validation_context=validation_context,
163+
maybe_enable_amp=maybe_enable_amp,
149164
)

torchtitan/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(self, job_config: JobConfig):
219219
if parallel_dims.pp_enabled:
220220
if not self.train_spec.pipelining_fn:
221221
raise RuntimeError(
222-
f"pipeline parallel is enabled but {self.train_spec.name} "
222+
f"Pipeline parallel is enabled but {self.train_spec.name} "
223223
f"does not support pipelining"
224224
)
225225

@@ -336,7 +336,9 @@ def __init__(self, job_config: JobConfig):
336336
tokenizer=tokenizer,
337337
parallel_dims=parallel_dims,
338338
world_mesh=world_mesh,
339-
loss_fn=self.loss_fn,
339+
loss_fn=self.train_spec.build_loss_fn(job_config),
340+
validation_context=self.train_context,
341+
maybe_enable_amp=self.maybe_enable_amp,
340342
)
341343

342344
logger.info(
@@ -525,7 +527,7 @@ def train(self):
525527
self.job_config.validation.enabled
526528
and self.validator.should_validate(self.step)
527529
):
528-
validation_metrics = self.validator.validate(self.model_parts)
530+
self.validator.validate(self.model_parts)
529531

530532
self.checkpointer.save(
531533
self.step, last_step=(self.step == job_config.training.steps)

0 commit comments

Comments (0)