basic validator implementation #1362


Merged (9 commits, Jul 10, 2025)
14 changes: 14 additions & 0 deletions tests/integration_tests.py
@@ -509,6 +509,20 @@ def build_test_list():
"gradient_accumulation",
ngpu=2,
),
OverrideDefinitions(
Contributor: let's remove these two, and only keep one 8-GPU test with FSDP2 CP2 TP2.

[
[
"--validation.enabled",
"--validation.dataset c4_test",
"--parallelism.data_parallel_replicate_degree=2",
"--parallelism.tensor_parallel_degree=2",
"--parallelism.context_parallel_degree=2",
],
],
"Validation test with fsdp, tp, cp",
"validation_fsdp_tp_cp",
ngpu=8,
),
]
return integration_tests_flavors

163 changes: 163 additions & 0 deletions torchtitan/components/validate.py
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator

import torch
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.tools import utils
from torchtitan.tools.logging import logger


class BaseValidator:
def __init__(self, job_config: JobConfig):
self.job_config = job_config

def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
raise NotImplementedError("validate method not implemented")

def should_validate(self, step: int) -> bool:
return step % self.job_config.validation.freq == 0


class Validator(BaseValidator):
"""
Simple validator focused on correctness and integration.

Args:
job_config: Job configuration
dp_world_size: Data parallel world size used to shard the validation dataloader
dp_rank: Data parallel rank of this process
tokenizer: Tokenizer used to build the validation dataloader
parallel_dims: Parallelism configuration for the run
world_mesh: Global device mesh
loss_fn: Loss function to use for validation
validation_context: Context manager wrapping the validation forward pass
maybe_enable_amp: Context manager enabling autocast when mixed precision is configured
"""

validation_dataloader: BaseDataLoader

def __init__(
self,
job_config: JobConfig,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
parallel_dims: ParallelDims,
world_mesh: torch.distributed.DeviceMesh,
loss_fn: LossFunction,
validation_context: Generator[None, None, None],
maybe_enable_amp: Generator[None, None, None],
):
self.job_config = job_config
self.parallel_dims = parallel_dims
self.world_mesh = world_mesh
self.loss_fn = loss_fn
self.validation_dataloader = build_hf_validation_dataloader(
job_config=job_config,
dp_world_size=dp_world_size,
dp_rank=dp_rank,
tokenizer=tokenizer,
)
self.validation_context = validation_context
self.maybe_enable_amp = maybe_enable_amp

@torch.no_grad()
def validate(
self,
model_parts: list[nn.Module],
) -> dict[str, float]:
# Set model to eval mode
# TODO: currently does not support pipeline parallelism
model = model_parts[0]
Contributor: add a TODO here noting that we only support data parallel for now.

Contributor: Is there a reason not to support all parallelisms besides PP here?

Contributor: Does this PR assume the model is not sharded and will handle a sharded model later? If so, can we raise an exception if dp_shard > 1? If we do support FSDP, then you will need to be careful about the reshard_after_forward value and ensure the parameters are sharded before leaving validate(). Otherwise, checkpointing will be broken.

Contributor: Great point, I forgot this. Please include & adapt the following code in Validator.validate():
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L185-L189
cc @wwwjn

model.eval()

accumulated_losses = []
device_type = utils.device_type
num_steps = 0

for input_dict, labels in self.validation_dataloader:
if (
self.job_config.validation.steps != -1
and num_steps >= self.job_config.validation.steps
):
break

for k, v in input_dict.items():
input_dict[k] = v.to(device_type)
inputs = input_dict["input"]
labels = labels.to(device_type)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=self.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if self.parallel_dims.cp_enabled
else None
)

with self.validation_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
predictions = model(inputs)
loss = self.loss_fn(predictions, labels)

accumulated_losses.append(loss.detach())

num_steps += 1

# Compute average loss
Contributor: This assumes that the number of tokens is the same for every batch. Maybe either manually keep track of the total number of tokens, or at least add a NOTE that highlights this assumption.

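A minimal sketch of the token-weighted bookkeeping the reviewer suggests; the helper below and the idea of collecting `labels.numel()` per batch are illustrative assumptions, not code from this PR:

```python
import torch


def token_weighted_average(losses: list[torch.Tensor], token_counts: list[int]) -> torch.Tensor:
    """Average per-batch mean losses, weighting each batch by its token count,
    so a smaller final batch does not skew the overall validation loss."""
    weights = torch.tensor(token_counts, dtype=torch.float32, device=losses[0].device)
    stacked = torch.stack(losses)  # shape: (num_batches,)
    # Undo each batch's internal mean by re-weighting with its token count.
    return (stacked * weights).sum() / weights.sum()
```

In validate() this would amount to appending `labels.numel()` alongside each `loss.detach()` and replacing the `loss /= num_steps` division below.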
loss = torch.sum(torch.stack(accumulated_losses))
loss /= num_steps
if self.parallel_dims.dp_cp_enabled:
global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
else:
Contributor: I think this code path should never be used. You could guarantee this (ignoring the case of an empty dataloader) by adding a __post_init__ to the Validation dataclass that verifies that all values are valid, e.g., val_steps > 0.

global_avg_loss = loss

logger.info(
f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_steps} batches"
)

# Reshard after running the forward pass.
# This ensures the model weights are sharded the same way for checkpoint saving.
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard()

# Set model back to train mode
model.train()


def build_validator(
job_config: JobConfig,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
parallel_dims: ParallelDims,
world_mesh: torch.distributed.DeviceMesh,
loss_fn: LossFunction,
validation_context: Generator[None, None, None],
maybe_enable_amp: Generator[None, None, None],
) -> BaseValidator:
"""Build a simple validator focused on correctness."""
return Validator(
job_config=job_config,
dp_world_size=dp_world_size,
dp_rank=dp_rank,
tokenizer=tokenizer,
parallel_dims=parallel_dims,
world_mesh=world_mesh,
loss_fn=loss_fn,
validation_context=validation_context,
maybe_enable_amp=maybe_enable_amp,
)
30 changes: 30 additions & 0 deletions torchtitan/config_manager.py
@@ -665,6 +665,35 @@ class Experimental:
"""


@dataclass
class Validation:
enabled: bool = False
Contributor: You could remove this field and modify val_freq to offer an option for disabling validation, e.g., val_freq: int | None = 10, where validation is disabled if val_freq=None.

Contributor: I think we can keep `enabled` to be consistent with other configs in torchtitan -- this sounds more like a style thing?

"""Enable validation; when enabled, validation runs every `freq` training steps"""

dataset: str = "c4_validation"
"""Dataset to use for validation"""

dataset_path: str | None = None
"""Path to dataset to use for validation"""

local_batch_size: int = 8
"""Batch size for validation"""

seq_len: int = 2048
"""Sequence length for validation"""

Contributor: set up a steps config controlling how many iterations we run, defaulting to -1, which means consuming all the data in the validation dataset.

freq: int = 10
"""Frequency of validation"""

steps: int = -1
"""Number of steps to take in the validation set, -1 means consuming all the data in the validation dataset"""

def __post_init__(self):
assert (
self.steps > 0 or self.steps == -1
), "validation steps must be positive or -1"

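For reference, the alternative floated in the review of the `enabled` field above (dropping the flag in favour of an optional frequency) would look roughly like the sketch below; it was not adopted, and the class name here is purely illustrative:

```python
from dataclasses import dataclass


@dataclass
class ValidationAlternative:
    # None disables validation entirely; an int N runs validation every N training steps.
    freq: int | None = None
    steps: int = -1

    def should_validate(self, step: int) -> bool:
        return self.freq is not None and step % self.freq == 0
```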

@dataclass
class JobConfig:
"""
@@ -689,6 +718,7 @@ class JobConfig:
memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
experimental: Experimental = field(default_factory=Experimental)
validation: Validation = field(default_factory=Validation)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
43 changes: 40 additions & 3 deletions torchtitan/datasets/hf_datasets.py
@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from functools import partial
from typing import Any, Callable

import torch
@@ -20,9 +22,9 @@
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
def _load_c4_dataset(dataset_path: str, split: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)
return load_dataset(dataset_path, name="en", split=split, streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
@@ -41,14 +43,19 @@ class DatasetConfig:
DATASETS = {
"c4": DatasetConfig(
path="allenai/c4",
loader=_load_c4_dataset,
loader=partial(_load_c4_dataset, split="train"),
text_processor=_process_c4_text,
),
"c4_test": DatasetConfig(
path="tests/assets/c4_test",
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
"c4_validation": DatasetConfig(
path="allenai/c4",
loader=partial(_load_c4_dataset, split="validation"),
text_processor=_process_c4_text,
),
}


@@ -193,3 +200,33 @@ def build_hf_dataloader(
dp_world_size=dp_world_size,
batch_size=batch_size,
)


def build_hf_validation_dataloader(
Contributor: I don't think adding a new function for this is necessary; I would prefer replacing the job_config argument with dataset_name, dataset_path, batch_size, and seq_len. The reasoning is that for validation the function is also just returning a data loader based on a HF dataset; only the underlying dataset will be different.

Contributor: The reason we probably don't want to change this interface is that people plug in their own data loaders and want the builder signature to stay general enough: https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L133-L138
We used to have something closer to what you proposed, but changed it due to their requests. I think it's ok to make this compromise for that purpose.

dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
job_config: JobConfig,
) -> ParallelAwareDataloader:
"""Build a validation data loader for HuggingFace datasets."""
dataset_name = job_config.validation.dataset
dataset_path = job_config.validation.dataset_path
batch_size = job_config.validation.local_batch_size
seq_len = job_config.validation.seq_len

hf_ds = HuggingFaceDataset(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=False,
)

return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
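A hypothetical standalone use of the new builder, just to show the shape of what it returns; the `job_config` and `tokenizer` objects are assumed to already exist and are not constructed here:

```python
# Build the validation dataloader for a single data-parallel rank and peek at one batch.
val_dataloader = build_hf_validation_dataloader(
    dp_world_size=1,
    dp_rank=0,
    tokenizer=tokenizer,
    job_config=job_config,
)

for input_dict, labels in val_dataloader:
    # input_dict["input"] and labels are token tensors, typically shaped (local_batch_size, seq_len).
    print(input_dict["input"].shape, labels.shape)
    break
```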
2 changes: 2 additions & 0 deletions torchtitan/models/llama3/__init__.py
@@ -9,6 +9,7 @@
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -81,5 +82,6 @@
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
)
)
6 changes: 6 additions & 0 deletions torchtitan/models/llama3/train_configs/debug_model.toml
@@ -71,3 +71,9 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

[validation]
enabled = false
dataset = "c4_validation"
freq = 5
steps = 10
3 changes: 3 additions & 0 deletions torchtitan/protocols/train_spec.py
@@ -23,6 +23,7 @@
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.components.validate import BaseValidator
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

@@ -80,6 +81,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
[OptimizersContainer, JobConfig], LRSchedulersContainer
]
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
ValidatorBuilder: TypeAlias = Callable[..., BaseValidator]


@dataclass
@@ -94,6 +96,7 @@ class TrainSpec:
build_dataloader_fn: DataLoaderBuilder
build_tokenizer_fn: TokenizerBuilder | None
build_loss_fn: LossFunctionBuilder
build_validator_fn: ValidatorBuilder | None = None
build_metrics_processor_fn: MetricsProcessorBuilder | None = None


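The TrainSpec hook above only registers the builder; a rough sketch of how a trainer could wire it into the training loop is shown below. All surrounding names (`train_spec`, `dp_degree`, `train_context`, `maybe_enable_amp`, `model_parts`, `step`) are assumptions for illustration, not code from this PR:

```python
# Build the validator once, only if validation is enabled and the model's
# TrainSpec registered a builder.
validator = None
if job_config.validation.enabled and train_spec.build_validator_fn is not None:
    validator = train_spec.build_validator_fn(
        job_config=job_config,
        dp_world_size=dp_degree,
        dp_rank=dp_rank,
        tokenizer=tokenizer,
        parallel_dims=parallel_dims,
        world_mesh=world_mesh,
        loss_fn=train_spec.build_loss_fn(job_config),
        validation_context=train_context,
        maybe_enable_amp=maybe_enable_amp,
    )

# Inside the training loop, run validation every `validation.freq` steps.
if validator is not None and validator.should_validate(step):
    validator.validate(model_parts)
```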