Skip to content

basic validator implementation #1362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 10, 2025
Merged

basic validator implementation #1362

merged 9 commits into from
Jul 10, 2025

Conversation

wesleytruong
Copy link
Contributor

@wesleytruong wesleytruong commented Jul 2, 2025

Update PR Summary:
Implements a validator that can be easily plugged into the training loop and configured from the job specific config file.
Changes:

  • Created validation section in job_config with enabled, dataset, freq, and steps fields
  • Created a builder function for validator in train_spec
  • Created a separate builder function for validation dataset in hf_dataset.py
  • Created validator class
    • Validator class initializes a build_validation_hf_loader but leaves this dataloader function unexposed to the train_spec
    • Validator class supports ddp, fsdp, cp, and tp (but not pp yet)
  • Integrated validation call into training loop
  • Creates an integration test to test parallelization

Updated tests training the same base model weights from a seed checkpoint:

FSDP=2 FSDP=2,TP=4
Screenshot 2025-07-09 at 4 33 53 PM Screenshot 2025-07-09 at 4 33 53 PM
FSDP=2,CP=4 FSDP=2,TP=2,CP=2
Screenshot 2025-07-09 at 4 39 35 PM Screenshot 2025-07-09 at 4 28 57 PM

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 2, 2025
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First pass looks really good!
I left many detailed comments, please see if they make sense.

):
self.job_config = job_config
self.loss_fn = loss_fn
self.model = model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should pass model (model_parts) as an arg to validate, because it's changing

Comment on lines 42 to 52
job_config: JobConfig,
loss_fn: LossFunction,
model: nn.Module,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make the order as close as how you used them below in build_hf_validation_dataloader


seq_len: int = 2048
"""Sequence length for validation"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set up a steps config, controlling how many iterations we run, default to -1 which means consuming all the data in the validation dataset

Comment on lines 53 to 56
# path="tests/assets/c4_test",
# loader=lambda path: load_dataset(path, split="validation"),
# text_processor=_process_c4_text,
# ),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should use path="allenai/c4", and loader=lambda path: load_dataset(path, name="en", split="validation"),

@@ -319,6 +321,23 @@ def __init__(self, job_config: JobConfig):
device_type,
)

# Build validator if validation is configured
self.validator = None
if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if job_config.validation.enabled:
  assert self.train_spec.build_validator_fn is not None
  # build validator ...

@@ -319,6 +321,23 @@ def __init__(self, job_config: JobConfig):
device_type,
)

# Build validator if validation is configured
self.validator = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need this line, since it's already defined as instance variable

for k, v in input_dict.items():
if isinstance(v, torch.Tensor):
input_dict[k] = v.to(device_type)
if isinstance(labels, torch.Tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this if?

Comment on lines 70 to 71
for batch_data, targets in self.validation_dataloader:
input_dict, labels = batch_data, targets
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for batch_data, targets in self.validation_dataloader:
input_dict, labels = batch_data, targets
for input_dict, labels in self.validation_dataloader:

logger.warning("No validation batches processed")

# Set model back to train mode
self.model.train()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's put this as the last line of this method

@wesleytruong
Copy link
Contributor Author

I've cleaned up the code according to your comments and added support for the validation frequency and steps. I also left streaming=True in the c4_validation dataset since otherwise it downloads the entire training dataset too. @tianyu-l

seq_len: int = 2048
"""Sequence length for validation"""

val_freq: int = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to have the val_ prefix as it's not ambiguous under Validation

Suggested change
val_freq: int = 1
freq: int = 1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe default to 10

"""Frequency of validation"""

val_steps: int = -1
"""Number of validation steps, -1 means all steps"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Number of validation steps, -1 means all steps"""
"""Number of validation steps, -1 means consuming all the data in the validation dataset"""

dp_rank: int,
tokenizer: Tokenizer,
job_config: JobConfig,
infinite: bool = True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove this arg -- I don't think anyone wants to do multiple loops over the validation dataset

seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you can always set to False here

@@ -54,6 +54,7 @@ tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
disable_loss_parallel = true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this change?

@@ -463,6 +477,12 @@ def train_step(
else:
global_avg_loss = global_max_loss = loss.detach().item()

# Run validation if validator is available
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as this is not part of training step, let's put this outside train_step and put it in train before self.checkpointer.save(...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As my comment above, please check if validation breaks checkpointing if FSDP is used.

"--validation.dataset c4_test",
],
],
"Validation test no parallelism",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically this is not without parallelism -- you are doing data parallel for validation; however, you are not doing all-reduce on the loss, so the loss you print out would be different on each DP rank. Let's do that in this PR, following the code in model forward.
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L451-L464

For that you'll need to pass in parallel_dims world_mesh ft_manager when constructing Validator

I think then the code will support Tensor Parallel and Context Parallel but not Pipeline Parallel yet, which we can do in a followup PR.

model_parts: list[nn.Module],
) -> dict[str, float]:
# Set model to eval mode
model = model_parts[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a TODO: here claiming we only support data parallel for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to not support all parallelisms besides PP here?

num_val_steps = 0

with torch.no_grad():
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you don't need this try-catch because StopIteration will be automatically captured by for loop safely.

Copy link
Contributor

@runame runame left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for implementing this, this will be very useful!

You can take a look at these changes for some inspiration for addressing some of my comments.

if self.job_config.validation.enabled and self.validator.should_validate(
self.step
):
validation_metrics = self.validator.validate(self.model_parts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation metrics should be logged by self.metrics_processor.log() (to the terminal output and Tensorboard/wandb).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For logging to TB/W&B I agree we should use self.metrics_processor.logger.
For terminal, there are two options, one is to do it locally in this function, the other is creating a new metrics processor like what you did. I personally think the latter tries to make the style consistent (which I appreciate), but sounds a bit overkill.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that a separate metrics processor seems like overkill for this implementation. A third option is to also use self.metrics_processor.log to print the metrics in the terminal.

# Build validator if validation is configured
if job_config.validation.enabled:
assert self.train_spec.build_validator_fn is not None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you raise an error here if parallel_dims.pp_enabled?

@@ -49,6 +49,13 @@ class DatasetConfig:
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
"c4_validation": DatasetConfig(
path="allenai/c4",
loader=lambda path: load_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: you can reuse _load_c4_dataset together with functools.partial here by adding split as an argument to _load_c4_dataset.

@@ -193,3 +200,34 @@ def build_hf_dataloader(
dp_world_size=dp_world_size,
batch_size=batch_size,
)


def build_hf_validation_dataloader(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think adding a new function for this is necessary; I would prefer replacing the job_config argument with dataset_name, dataset_path, batch_size, and seq_len. The reasoning is that for validation the function is also just returning a data loader based on a HF dataset, just the underlying dataset will be different.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we probably don't want to change this interface is:
people plug in their own data loader, and they want it to be general enough https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L133-L138
We used to have something closer to what you proposed, but changed due to their requests.

I think it's ok we make this compromise for that purpose.

@@ -657,6 +657,30 @@ class Experimental:
"""


@dataclass
class Validation:
enabled: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could remove this field and modify val_freq to offer an option for disabling validation, e.g., val_freq: int | None = 10, where validation is disabled if val_freq=None.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can keep this enabled to be consistent with other configs in torchtitan -- this sounds more like a style thing?

# Compute average loss
if num_batches > 0:
average_loss = total_loss / num_batches
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this code path should never be used, you could guarantee this (ignoring the case of an empty dataloader) by adding a __post_init__ to the Validation dataclass that verifies that all values are valid, e.g., val_steps > 0.

# Set model back to train mode
model.train()

return {"validation_loss": average_loss}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The average_loss is the local loss for each rank, but should still be all-reduced across ranks.

# Set model back to train mode
model.train()

return {"validation_loss": average_loss}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change this to "validation/loss"? This is important for how wandb represents the metrics and allows you to add more metrics to the same section via "validation/<you-new-metric>" later on.

total_loss += loss.item()
num_batches += 1

num_val_steps += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you use separate counters for num_batches and num_val_steps? Also, you could use this instead:

for step, (input_dict, labels) in enumerate(self.validation_dataloader):

Here, step replaces num_batches and num_val_steps. You would also have to change num_val_steps >= self.job_config.validation.val_steps to step > self.job_config.validation.val_steps above.

device_type = utils.device_type
num_val_steps = 0

with torch.no_grad():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: you can also use this as a decorator instead, so you don't have to indent your code as much.

@torch.no_grad()
def validate(

model_parts: list[nn.Module],
) -> dict[str, float]:
# Set model to eval mode
model = model_parts[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this PR assume the model is not sharded and will handle model being sharded later? If so can we raise an exception if dp_shard > 1? If we do support FSDP, then you will need to be careful about reshard_after_forward value as ensure the parameters are sharded before leaving validate(). Otherwise, checkpointing will be broken.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great point, I forgot this.
Please include & adapt the following code in Validator.validate()
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L185-L189
cc @wwwjn

@@ -463,6 +477,12 @@ def train_step(
else:
global_avg_loss = global_max_loss = loss.detach().item()

# Run validation if validator is available
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As my comment above, please check if validation breaks checkpointing if FSDP is used.

@wesleytruong
Copy link
Contributor Author

@tianyu-l @runame
I made some changes to address the concerns, mainly adding support for the parallelisms besides pp in validate, and some organizational/stylistic changes. The metrics is still handled locally but will be implemented in more detail in a future pr.

@wesleytruong
Copy link
Contributor Author

wesleytruong commented Jul 8, 2025

To show the validation parallelism is correct I followed https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#reproducibility-between-runs, to compare training the same model with different parallelisms and show they have the same loss. The validation is ran every 10 steps across 50 total training steps and local batch size 8.
FSDP=2:
Screenshot 2025-07-08 at 3 28 17 PM
FSDP=2, TP=2:
Screenshot 2025-07-08 at 3 30 17 PM
FSDP=2, TP=2, CP=2:
Screenshot 2025-07-08 at 3 34 06 PM

@@ -216,7 +218,7 @@ def __init__(self, job_config: JobConfig):
if parallel_dims.pp_enabled:
if not self.train_spec.pipelining_fn:
raise RuntimeError(
f"Pipeline Parallel is enabled but {self.train_spec.name} "
f"pipeline parallel is enabled but {self.train_spec.name} "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why making this change?

"validation_fsdp_checkpoint",
ngpu=4,
),
OverrideDefinitions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI resource is valuable, so let's not add too many tests for validation alone.
I think we can keep this one and remove others.

[
[
"--validation.enabled",
"--validation.dataset c4_validation",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use c4_test for CI testing which is included in the repo so don't need internet access. This way we won't be blocked by HF website issue.

# Set model back to train mode
model.train()

return {"validation_loss": global_avg_loss}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not return this if it's not used. For logging to TB/WandB, we can pass in the logger so you could do additional logging within this method.


accumulated_losses = []
device_type = utils.device_type
num_val_steps = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename to num_steps sounds better as we are already in validate.py

inputs = input_dict["input"]
predictions = model(inputs)

if self.parallel_dims.loss_parallel_enabled:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use the train_context for validation as well, it will perform proper TP/CP forward pass for you

global_avg_loss = loss

logger.info(
f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_val_steps} batches"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to scale the loss, either manually, or by reusing rescale_accumulated_loss in torchtitan/components/loss.py

@wesleytruong wesleytruong force-pushed the basic_validator_interface branch from d7163ee to a89f2c4 Compare July 9, 2025 17:26
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, could you edit the PR summary with latest test results (FSDP + CP + TP gives the same results as FSDP / 1GPU results)

@@ -217,7 +219,7 @@ def __init__(self, job_config: JobConfig):
if parallel_dims.pp_enabled:
if not self.train_spec.pipelining_fn:
raise RuntimeError(
f"Pipeline Parallel is enabled but {self.train_spec.name} "
f"Pipeline parallel is enabled but {self.train_spec.name} "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's OK / proper to use Tensor Parallel, Pipeline Parallel, Context Parallel as special terms

@@ -509,6 +509,31 @@ def build_test_list():
"gradient_accumulation",
ngpu=2,
),
OverrideDefinitions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove these two, and only keep 1 8GPU test with FSDP2 CP2 TP2

@wesleytruong wesleytruong changed the title non parallelized basic validator implementation [WIP] basic validator implementation Jul 9, 2025
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome, thanks!

@tianyu-l tianyu-l merged commit acd5ba8 into main Jul 10, 2025
7 checks passed
@tianyu-l tianyu-l deleted the basic_validator_interface branch July 10, 2025 03:46
@tianyu-l tianyu-l linked an issue Jul 10, 2025 that may be closed by this pull request
idoh pushed a commit to idoh/torchtitan that referenced this pull request Jul 28, 2025
Update PR Summary:
Implements a validator that can be easily plugged into the training loop
and configured from the job specific config file.
Changes:
- Created validation section in job_config with enabled, dataset, freq,
and steps fields
- Created a builder function for validator in train_spec
- Created a separate builder function for validation dataset in
hf_dataset.py
- Created validator class
- Validator class initializes a build_validation_hf_loader but leaves
this dataloader function unexposed to the train_spec
  - Validator class supports ddp, fsdp, cp, and tp (but not pp yet)
- Integrated validation call into training loop
- Creates an integration test to test parallelization

Updated tests training the same base model weights from a seed
checkpoint:

| FSDP=2 | FSDP=2,TP=4 |
| --- | --- |
| <img width="978" alt="Screenshot 2025-07-09 at 4 33 53 PM"
src="https://github.com/user-attachments/assets/a1fa9fa7-df2f-4302-aa4a-d556a5699ba9"
/> | <img width="978" alt="Screenshot 2025-07-09 at 4 33 53 PM"
src="https://github.com/user-attachments/assets/a1fa9fa7-df2f-4302-aa4a-d556a5699ba9"
/> |

| FSDP=2,CP=4 | FSDP=2,TP=2,CP=2 |
| --- | --- |
| <img width="972" alt="Screenshot 2025-07-09 at 4 39 35 PM"
src="https://github.com/user-attachments/assets/56d62841-5841-4969-85b1-803705892465"
/> | <img width="970" alt="Screenshot 2025-07-09 at 4 28 57 PM"
src="https://github.com/user-attachments/assets/f7d33fa8-ca2c-48f1-931c-8d4c017a47ce"
/> |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Evaluation] Minimal support for downstream tasks
6 participants