basic validator implementation #1362
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator

import torch
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.tools import utils
from torchtitan.tools.logging import logger


class BaseValidator:
    def __init__(self, job_config: JobConfig):
        self.job_config = job_config

    def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
        raise NotImplementedError("validate method not implemented")

    def should_validate(self, step: int) -> bool:
        return step % self.job_config.validation.freq == 0


class Validator(BaseValidator):
    """
    Simple validator focused on correctness and integration.

    Args:
        job_config: Job configuration
        validation_dataloader: The validation dataloader
        loss_fn: Loss function to use for validation
        model: The model to validate (single model, no parallelism)
    """

    validation_dataloader: BaseDataLoader

    def __init__(
        self,
        job_config: JobConfig,
        dp_world_size: int,
        dp_rank: int,
        tokenizer: Tokenizer,
        parallel_dims: ParallelDims,
        world_mesh: torch.distributed.DeviceMesh,
        loss_fn: LossFunction,
        validation_context: Generator[None, None, None],
        maybe_enable_amp: Generator[None, None, None],
    ):
        self.job_config = job_config
        self.parallel_dims = parallel_dims
        self.world_mesh = world_mesh
        self.loss_fn = loss_fn
        self.validation_dataloader = build_hf_validation_dataloader(
            job_config=job_config,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            tokenizer=tokenizer,
        )
        self.validation_context = validation_context
        self.maybe_enable_amp = maybe_enable_amp

    @torch.no_grad()
    def validate(
        self,
        model_parts: list[nn.Module],
    ) -> dict[str, float]:
        # Set model to eval mode
        # TODO: currently does not support pipeline parallelism
        model = model_parts[0]
Review comment: Is there a reason to not support all parallelisms besides PP here?

Review comment: Does this PR assume the model is not sharded and will handle the model being sharded later? If so, can we raise an exception if dp_shard > 1? If we do support FSDP, then you will need to be careful about

Reply: great point, I forgot this.
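A minimal sketch of such a fail-fast guard, assuming ParallelDims exposes a dp_shard_enabled flag (only cp_enabled and dp_cp_enabled appear in this diff); the message text is illustrative:

        # Hypothetical guard, per the review suggestion: refuse to validate an
        # FSDP-sharded model until resharding is handled explicitly.
        if self.parallel_dims.dp_shard_enabled:  # assumed ParallelDims attribute
            raise NotImplementedError(
                "validation does not yet support FSDP-sharded models (dp_shard > 1)"
            )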
        model.eval()

        accumulated_losses = []
        device_type = utils.device_type
        num_steps = 0

        for input_dict, labels in self.validation_dataloader:
            if (
                self.job_config.validation.steps != -1
                and num_steps >= self.job_config.validation.steps
            ):
                break

            for k, v in input_dict.items():
                input_dict[k] = v.to(device_type)
            inputs = input_dict["input"]
            labels = labels.to(device_type)

            optional_context_parallel_ctx = (
                dist_utils.create_context_parallel_ctx(
                    cp_mesh=self.world_mesh["cp"],
                    cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                    cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                    cp_no_restore_buffers={inputs, labels},
                    cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                )
                if self.parallel_dims.cp_enabled
                else None
            )

            with self.validation_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    predictions = model(inputs)
                    loss = self.loss_fn(predictions, labels)

            accumulated_losses.append(loss.detach())

            num_steps += 1

        # Compute average loss
Review comment: This assumes that the number of tokens is the same for every batch. Maybe either manually keep track of the total number of tokens or at least add a NOTE that highlights this assumption.
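A self-contained sketch of the token-weighted alternative the comment describes; the helper name and the use of labels.numel() as the per-batch token count are illustrative, not part of the PR:

import torch

def token_weighted_average(losses: list[torch.Tensor], token_counts: list[int]) -> torch.Tensor:
    # Weight each per-batch mean loss by that batch's token count so batches of
    # different sizes contribute proportionally to the overall average.
    total_tokens = max(sum(token_counts), 1)
    weighted = torch.stack([loss * n for loss, n in zip(losses, token_counts)])
    return weighted.sum() / total_tokens

# In the loop above, one would accumulate token_counts.append(labels.numel())
# next to accumulated_losses.append(loss.detach()).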
        loss = torch.sum(torch.stack(accumulated_losses))
        loss /= num_steps
        if self.parallel_dims.dp_cp_enabled:
            global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
        else:
            global_avg_loss = loss

Review comment: I think this code path (the else branch) should never be used; you could guarantee this (ignoring the case of an empty dataloader) by adding a
        logger.info(
            f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_steps} batches"
        )

        # Reshard after running the forward pass.
        # This is to ensure the model weights are sharded the same way for checkpoint saving.
        for module in model.modules():
            if isinstance(module, FSDPModule):
                module.reshard()

        # Set model back to train mode
        model.train()


def build_validator(
    job_config: JobConfig,
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    parallel_dims: ParallelDims,
    world_mesh: torch.distributed.DeviceMesh,
    loss_fn: LossFunction,
    validation_context: Generator[None, None, None],
    maybe_enable_amp: Generator[None, None, None],
) -> BaseValidator:
    """Build a simple validator focused on correctness."""
    return Validator(
        job_config=job_config,
        dp_world_size=dp_world_size,
        dp_rank=dp_rank,
        tokenizer=tokenizer,
        parallel_dims=parallel_dims,
        world_mesh=world_mesh,
        loss_fn=loss_fn,
        validation_context=validation_context,
        maybe_enable_amp=maybe_enable_amp,
    )
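For orientation, a rough sketch of how a training loop might wire these pieces together. Only build_validator, should_validate, and validate come from this diff; the trainer object and its attributes are assumptions for illustration:

def run_training_with_validation(trainer, total_steps: int) -> None:
    # Hypothetical wrapper, not part of this PR. `trainer` is assumed to expose the
    # objects that build_validator needs, plus a train_step() method and model_parts.
    validator = build_validator(
        job_config=trainer.job_config,
        dp_world_size=trainer.dp_world_size,
        dp_rank=trainer.dp_rank,
        tokenizer=trainer.tokenizer,
        parallel_dims=trainer.parallel_dims,
        world_mesh=trainer.world_mesh,
        loss_fn=trainer.loss_fn,
        validation_context=trainer.train_context,
        maybe_enable_amp=trainer.maybe_enable_amp,
    )
    for step in range(1, total_steps + 1):
        trainer.train_step()
        if trainer.job_config.validation.enabled and validator.should_validate(step):
            validator.validate(trainer.model_parts)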
@@ -665,6 +665,35 @@ class Experimental:
    """


@dataclass
class Validation:
    enabled: bool = False
    """Enable validation to default run validation after each training loop"""

Review comment: You could remove this field and modify
Reply: I think we can keep this

    dataset: str = "c4_validation"
    """Dataset to use for validation"""

    dataset_path: str | None = None
    """Path to dataset to use for validation"""

    local_batch_size: int = 8
    """Batch size for validation"""

    seq_len: int = 2048
    """Sequence length for validation"""
    freq: int = 10
    """Frequency of validation"""

    steps: int = -1
    """Number of steps to take in the validation set, -1 means consuming all the data in the validation dataset"""

    def __post_init__(self):
        assert (
            self.steps > 0 or self.steps == -1
        ), "validation steps must be positive or -1"
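A small, purely illustrative sketch of how these defaults interact with BaseValidator.should_validate from the validator file above:

# Illustrative only: exercises the __post_init__ guard and the freq/steps semantics.
cfg = Validation(enabled=True, freq=100, steps=-1)  # validate every 100 steps, consume the full set

try:
    Validation(steps=0)  # rejected by __post_init__
except AssertionError as err:
    print(err)  # "validation steps must be positive or -1"

# should_validate(step) reduces to this check:
print(500 % cfg.freq == 0)  # True -> run validation at step 500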

@dataclass
class JobConfig:
    """

@@ -689,6 +718,7 @@ class JobConfig:
    memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
    fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
    experimental: Experimental = field(default_factory=Experimental)
    validation: Validation = field(default_factory=Validation)

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)
@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from functools import partial
from typing import Any, Callable

import torch

@@ -20,9 +22,9 @@
from torchtitan.tools.logging import logger


-def _load_c4_dataset(dataset_path: str):
+def _load_c4_dataset(dataset_path: str, split: str):
    """Load C4 dataset with default configuration."""
-    return load_dataset(dataset_path, name="en", split="train", streaming=True)
+    return load_dataset(dataset_path, name="en", split=split, streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:

@@ -41,14 +43,19 @@ class DatasetConfig:
DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
-        loader=_load_c4_dataset,
+        loader=partial(_load_c4_dataset, split="train"),
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
    "c4_validation": DatasetConfig(
        path="allenai/c4",
        loader=partial(_load_c4_dataset, split="validation"),
        text_processor=_process_c4_text,
    ),
}


@@ -193,3 +200,33 @@ def build_hf_dataloader(
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )

Review comment: I don't think adding a new function for this is necessary; I would prefer replacing the
Reply: The reason we probably don't want to change this interface is: I think it's ok we make this compromise for that purpose.

def build_hf_validation_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
) -> ParallelAwareDataloader:
    """Build a validation data loader for HuggingFace datasets."""
    dataset_name = job_config.validation.dataset
    dataset_path = job_config.validation.dataset_path
    batch_size = job_config.validation.local_batch_size
    seq_len = job_config.validation.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=False,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
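One possible shape of the single-builder alternative the review comment above hints at (the exact proposal is cut off in the thread); the function name and parameter list are illustrative and this is not what the PR implements:

def build_hf_dataloader_for_split(
    dataset_name: str,
    dataset_path: str | None,
    batch_size: int,
    seq_len: int,
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    infinite: bool,
) -> ParallelAwareDataloader:
    # Sketch: a single entry point both training and validation could call with
    # explicit arguments, instead of a separate builder that reads job_config.validation.
    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )
    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )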
Review comment: let's remove these two, and only keep 1 8GPU test with FSDP2 CP2 TP2