basic validator implementation #1362
Conversation
First pass looks really good!
I left many detailed comments, please see if they make sense.
torchtitan/components/validate.py
Outdated
):
    self.job_config = job_config
    self.loss_fn = loss_fn
    self.model = model
I think we should pass model (model_parts) as an arg to validate, because it's changing.
torchtitan/components/validate.py
Outdated
job_config: JobConfig,
loss_fn: LossFunction,
model: nn.Module,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
let's make the order as close as possible to how you use them below in build_hf_validation_dataloader
seq_len: int = 2048
"""Sequence length for validation"""
set up a steps config, controlling how many iterations we run; default to -1, which means consuming all the data in the validation dataset
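A minimal sketch of such a field, assuming it sits in the Validation config dataclass (the name and docstring are hypothetical):

```python
from dataclasses import dataclass


@dataclass
class Validation:
    # Hypothetical knob: -1 consumes the whole validation dataset;
    # any positive value caps the number of validation iterations.
    steps: int = -1
```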
torchtitan/datasets/hf_datasets.py
Outdated
# path="tests/assets/c4_test",
# loader=lambda path: load_dataset(path, split="validation"),
# text_processor=_process_c4_text,
# ),
we should use path="allenai/c4" and loader=lambda path: load_dataset(path, name="en", split="validation"),
torchtitan/train.py
Outdated
@@ -319,6 +321,23 @@ def __init__(self, job_config: JobConfig):
    device_type,
)

# Build validator if validation is configured
self.validator = None
if (
if job_config.validation.enabled:
    assert self.train_spec.build_validator_fn is not None
    # build validator ...
torchtitan/train.py
Outdated
@@ -319,6 +321,23 @@ def __init__(self, job_config: JobConfig):
    device_type,
)

# Build validator if validation is configured
self.validator = None
I don't think you need this line, since it's already defined as an instance variable
torchtitan/components/validate.py
Outdated
for k, v in input_dict.items():
    if isinstance(v, torch.Tensor):
        input_dict[k] = v.to(device_type)
if isinstance(labels, torch.Tensor):
why do we need this if?
torchtitan/components/validate.py
Outdated
for batch_data, targets in self.validation_dataloader:
    input_dict, labels = batch_data, targets
Suggested change:
- for batch_data, targets in self.validation_dataloader:
-     input_dict, labels = batch_data, targets
+ for input_dict, labels in self.validation_dataloader:
torchtitan/components/validate.py
Outdated
logger.warning("No validation batches processed")

# Set model back to train mode
self.model.train()
let's put this as the last line of this method
I've cleaned up the code according to your comments and added support for the validation frequency and steps. I also left streaming=True in the c4_validation dataset, since otherwise it downloads the entire training dataset too. @tianyu-l
torchtitan/config_manager.py
Outdated
seq_len: int = 2048
"""Sequence length for validation"""

val_freq: int = 1
no need to have the val_ prefix, as it's not ambiguous under Validation
Suggested change:
- val_freq: int = 1
+ freq: int = 1
maybe default to 10
torchtitan/config_manager.py
Outdated
"""Frequency of validation""" | ||
|
||
val_steps: int = -1 | ||
"""Number of validation steps, -1 means all steps""" |
"""Number of validation steps, -1 means all steps""" | |
"""Number of validation steps, -1 means consuming all the data in the validation dataset""" |
torchtitan/datasets/hf_datasets.py
Outdated
dp_rank: int,
tokenizer: Tokenizer,
job_config: JobConfig,
infinite: bool = True,
I think we can remove this arg -- I don't think anyone wants to do multiple loops over the validation dataset
torchtitan/datasets/hf_datasets.py
Outdated
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
so you can always set it to False here
@@ -54,6 +54,7 @@ tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
disable_loss_parallel = true
revert this change?
torchtitan/train.py
Outdated
@@ -463,6 +477,12 @@ def train_step(
else:
    global_avg_loss = global_max_loss = loss.detach().item()

# Run validation if validator is available
as this is not part of the training step, let's move it outside train_step and into train, before self.checkpointer.save(...)
As my comment above, please check if validation breaks checkpointing if FSDP is used.
tests/integration_tests.py
Outdated
"--validation.dataset c4_test", | ||
], | ||
], | ||
"Validation test no parallelism", |
Technically this is not without parallelism -- you are doing data parallel for validation; however, you are not doing all-reduce on the loss, so the loss you print out will be different on each DP rank. Let's do that in this PR, following the code in model forward:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L451-L464
For that you'll need to pass in parallel_dims, world_mesh, and ft_manager when constructing Validator. I think the code will then support Tensor Parallel and Context Parallel, but not Pipeline Parallel yet, which we can do in a followup PR.
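A hedged sketch of that adaptation, mirroring the linked train.py lines; the dist_utils helpers and the "dp_cp" mesh key are assumptions borrowed from the training loop, not confirmed against this PR:

```python
import torch

# Inside Validator.validate(), after computing the local mean loss:
loss = torch.tensor(average_loss, device=device_type)
if (
    parallel_dims.dp_replicate_enabled
    or parallel_dims.dp_shard_enabled
    or parallel_dims.cp_enabled
):
    # dist_utils.dist_mean averages the tensor across the given mesh.
    global_avg_loss = dist_utils.dist_mean(loss, world_mesh["dp_cp"])
else:
    global_avg_loss = loss.item()
```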
model_parts: list[nn.Module],
) -> dict[str, float]:
    # Set model to eval mode
    model = model_parts[0]
add a TODO here claiming we only support data parallel for now.
Is there a reason to not support all parallelisms besides PP here?
torchtitan/components/validate.py
Outdated
num_val_steps = 0

with torch.no_grad():
    try:
I believe you don't need this try-catch, because StopIteration is automatically handled by the for loop.
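For reference, the for statement drives the iterator protocol itself, so the explicit try/except adds nothing; a toy illustration:

```python
it = iter([1, 2, 3])

# The for statement calls next(it) internally and treats StopIteration
# as the normal end-of-loop signal -- no try/except required.
for x in it:
    print(x)

# Roughly what the for loop does under the hood:
it = iter([1, 2, 3])
while True:
    try:
        x = next(it)
    except StopIteration:
        break
    print(x)
```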
Thanks for implementing this, this will be very useful!
You can take a look at these changes for some inspiration for addressing some of my comments.
torchtitan/train.py
Outdated
if self.job_config.validation.enabled and self.validator.should_validate(
    self.step
):
    validation_metrics = self.validator.validate(self.model_parts)
The validation metrics should be logged by self.metrics_processor.log() (to the terminal output and TensorBoard/W&B).
For logging to TB/W&B I agree we should use self.metrics_processor.logger
.
For terminal, there are two options, one is to do it locally in this function, the other is creating a new metrics processor like what you did. I personally think the latter tries to make the style consistent (which I appreciate), but sounds a bit overkill.
I agree that a separate metrics processor seems like overkill for this implementation. A third option is to also use self.metrics_processor.log to print the metrics in the terminal.
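A hedged sketch of routing the metric through the existing logger; the call shape and the surrounding variable names are assumptions, not the confirmed torchtitan API:

```python
# In the training loop, after the validator returns its metrics
# (validation_metrics and the logger signature are assumptions):
self.metrics_processor.logger.log(
    {"validation/loss": validation_metrics["validation/loss"]},
    step=self.step,
)
```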
# Build validator if validation is configured
if job_config.validation.enabled:
    assert self.train_spec.build_validator_fn is not None
Can you raise an error here if parallel_dims.pp_enabled?
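A sketch of the requested guard in the trainer's __init__ (the exact error message is made up):

```python
# Build validator if validation is configured
if job_config.validation.enabled:
    assert self.train_spec.build_validator_fn is not None
    if parallel_dims.pp_enabled:
        raise RuntimeError(
            "Validation is not yet supported with Pipeline Parallel"
        )
```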
torchtitan/datasets/hf_datasets.py
Outdated
@@ -49,6 +49,13 @@ class DatasetConfig:
    loader=lambda path: load_dataset(path, split="train"),
    text_processor=_process_c4_text,
),
"c4_validation": DatasetConfig(
    path="allenai/c4",
    loader=lambda path: load_dataset(
Nit: you can reuse _load_c4_dataset together with functools.partial here, by adding split as an argument to _load_c4_dataset.
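A sketch of that refactor; the DatasetConfig registry and _process_c4_text are assumed from the surrounding hf_datasets.py, and streaming=True carries over the earlier discussion:

```python
from functools import partial

from datasets import load_dataset


def _load_c4_dataset(dataset_path: str, split: str):
    # Streaming avoids downloading the full corpus for either split.
    return load_dataset(dataset_path, name="en", split=split, streaming=True)


DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=partial(_load_c4_dataset, split="train"),
        text_processor=_process_c4_text,
    ),
    "c4_validation": DatasetConfig(
        path="allenai/c4",
        loader=partial(_load_c4_dataset, split="validation"),
        text_processor=_process_c4_text,
    ),
}
```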
@@ -193,3 +200,34 @@ def build_hf_dataloader(
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )


def build_hf_validation_dataloader(
I don't think adding a new function for this is necessary; I would prefer replacing the job_config argument with dataset_name, dataset_path, batch_size, and seq_len. The reasoning is that for validation the function is also just returning a data loader based on a HF dataset; only the underlying dataset is different.
The reason we probably don't want to change this interface is: people plug in their own data loaders, and they want the interface to be general enough (see https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L133-L138). We used to have something closer to what you proposed, but changed it due to their requests. I think it's OK we make this compromise for that purpose.
@@ -657,6 +657,30 @@ class Experimental:
"""


@dataclass
class Validation:
    enabled: bool = False
You could remove this field and modify val_freq to offer an option for disabling validation, e.g., val_freq: int | None = 10, where validation is disabled if val_freq=None.
I think we can keep this enabled to be consistent with other configs in torchtitan -- this sounds more like a style thing?
# Compute average loss
if num_batches > 0:
    average_loss = total_loss / num_batches
else:
I think this code path should never be used; you could guarantee this (ignoring the case of an empty dataloader) by adding a __post_init__ to the Validation dataclass that verifies that all values are valid, e.g., val_steps > 0.
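A sketch of such a check; field names and defaults are taken from the discussion above rather than the final code:

```python
from dataclasses import dataclass


@dataclass
class Validation:
    enabled: bool = False
    dataset: str = "c4_validation"
    freq: int = 10
    steps: int = -1

    def __post_init__(self):
        # Fail fast on nonsensical values, so the averaging code can assume
        # at least one batch was processed (empty dataloaders aside).
        if self.steps == 0 or self.steps < -1:
            raise ValueError("steps must be -1 (all data) or a positive integer")
        if self.freq <= 0:
            raise ValueError("freq must be a positive integer")
```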
torchtitan/components/validate.py
Outdated
# Set model back to train mode
model.train()

return {"validation_loss": average_loss}
The average_loss is the local loss for each rank, but it should still be all-reduced across ranks.
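A minimal sketch of that reduction with a raw collective (default process group and NCCL's AVG op assumed; torchtitan also has helpers that wrap this):

```python
import torch
import torch.distributed as dist

# Average the per-rank mean loss across all data-parallel ranks.
loss_tensor = torch.tensor(average_loss, device=device_type)
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
global_avg_loss = loss_tensor.item()
```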
torchtitan/components/validate.py
Outdated
# Set model back to train mode
model.train()

return {"validation_loss": average_loss}
Could you change this to "validation/loss"? This is important for how wandb represents the metrics, and allows you to add more metrics to the same section via "validation/<your-new-metric>" later on.
torchtitan/components/validate.py
Outdated
total_loss += loss.item()
num_batches += 1

num_val_steps += 1
Is there a reason you use separate counters for num_batches and num_val_steps? Also, you could use this instead:

for step, (input_dict, labels) in enumerate(self.validation_dataloader):

Here, step replaces num_batches and num_val_steps. You would also have to change num_val_steps >= self.job_config.validation.val_steps to step > self.job_config.validation.val_steps above.
torchtitan/components/validate.py
Outdated
device_type = utils.device_type
num_val_steps = 0

with torch.no_grad():
Nit: you can also use this as a decorator instead, so you don't have to indent your code as much:

@torch.no_grad()
def validate(
model_parts: list[nn.Module],
) -> dict[str, float]:
    # Set model to eval mode
    model = model_parts[0]
Does this PR assume the model is not sharded, and will it handle the model being sharded later? If so, can we raise an exception if dp_shard > 1? If we do support FSDP, then you will need to be careful about the reshard_after_forward value to ensure the parameters are sharded before leaving validate(). Otherwise, checkpointing will be broken.
Great point, I forgot this.
Please include & adapt the following code in Validator.validate():
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L185-L189
cc @wwwjn
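A hedged adaptation of what that flux snippet does, not a verbatim copy; it assumes the model parts are wrapped by FSDP2's fully_shard, whose modules expose reshard():

```python
import torch
import torch.nn as nn


@torch.no_grad()
def validate(self, model_parts: list[nn.Module]) -> None:
    model = model_parts[0]
    model.eval()
    # ... run the validation forward passes ...
    # Force parameters back to their sharded state before returning, so a
    # subsequent checkpoint save sees sharded (not unsharded) parameters.
    if self.parallel_dims.dp_shard_enabled:
        for m in model_parts:
            m.reshard()  # assumes FSDP2 (fully_shard) wrapping
    model.train()
```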
To show the validation parallelism is correct, I followed https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#reproducibility-between-runs to compare training the same model with different parallelisms and show they have the same loss. Validation is run every 10 steps across 50 total training steps, with local batch size 8.
torchtitan/train.py
Outdated
@@ -216,7 +218,7 @@ def __init__(self, job_config: JobConfig):
if parallel_dims.pp_enabled:
    if not self.train_spec.pipelining_fn:
        raise RuntimeError(
-            f"Pipeline Parallel is enabled but {self.train_spec.name} "
+            f"pipeline parallel is enabled but {self.train_spec.name} "
why make this change?
tests/integration_tests.py
Outdated
"validation_fsdp_checkpoint", | ||
ngpu=4, | ||
), | ||
OverrideDefinitions( |
CI resources are valuable, so let's not add too many tests for validation alone. I think we can keep this one and remove the others.
tests/integration_tests.py
Outdated
[
    [
        "--validation.enabled",
        "--validation.dataset c4_validation",
let's use c4_test for CI testing, which is included in the repo so it doesn't need internet access. This way we won't be blocked by HF website issues.
torchtitan/components/validate.py
Outdated
# Set model back to train mode
model.train()

return {"validation_loss": global_avg_loss}
let's not return this if it's not used. For logging to TB/WandB, we can pass in the logger so you could do additional logging within this method.
torchtitan/components/validate.py
Outdated
accumulated_losses = []
device_type = utils.device_type
num_val_steps = 0
nit: renaming to num_steps sounds better, as we are already in validate.py
torchtitan/components/validate.py
Outdated
inputs = input_dict["input"]
predictions = model(inputs)

if self.parallel_dims.loss_parallel_enabled:
let's use the train_context for validation as well; it will perform the proper TP/CP forward pass for you
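A hedged sketch of that reuse, with names borrowed from train.py's train_step; how train_context and the optional CP context get plumbed into the Validator is an assumption:

```python
# self.train_context comes from dist_utils.get_train_context(...), and
# optional_context_parallel_ctx is built the same way as in train_step
# when context parallel is enabled.
with self.train_context(optional_context_parallel_ctx):
    predictions = model(inputs)
    loss = self.loss_fn(predictions, labels)
```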
torchtitan/components/validate.py
Outdated
global_avg_loss = loss

logger.info(
    f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_val_steps} batches"
)
we need to scale the loss, either manually or by reusing rescale_accumulated_loss in torchtitan/components/loss.py
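A sketch of the manual option, with variable names taken from the snippet above:

```python
import torch

# Average the accumulated per-step losses before the cross-rank reduction;
# this mirrors what rescale_accumulated_loss does for gradient accumulation.
loss = torch.sum(torch.stack(accumulated_losses)) / num_steps
```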
Force-pushed "…e integration tests" from d7163ee to a89f2c4
Looks good. Could you edit the PR summary with the latest test results (FSDP + CP + TP gives the same results as FSDP / 1-GPU results)?
torchtitan/train.py
Outdated
@@ -217,7 +219,7 @@ def __init__(self, job_config: JobConfig):
if parallel_dims.pp_enabled:
    if not self.train_spec.pipelining_fn:
        raise RuntimeError(
-            f"Pipeline Parallel is enabled but {self.train_spec.name} "
+            f"Pipeline parallel is enabled but {self.train_spec.name} "
I think it's OK / proper to use Tensor Parallel, Pipeline Parallel, Context Parallel as special terms
@@ -509,6 +509,31 @@ def build_test_list():
    "gradient_accumulation",
    ngpu=2,
),
OverrideDefinitions(
let's remove these two, and only keep 1 8GPU test with FSDP2 CP2 TP2
Looks awesome, thanks!
Update PR Summary:

Implements a validator that can be easily plugged into the training loop and configured from the job-specific config file.

Changes:
- Created a validation section in job_config with enabled, dataset, freq, and steps fields
- Created a builder function for the validator in train_spec
- Created a separate builder function for the validation dataset in hf_dataset.py
- Created the Validator class
- The Validator class initializes a build_validation_hf_loader but leaves this dataloader function unexposed to the train_spec
- The Validator class supports DDP, FSDP, CP, and TP (but not PP yet)
- Integrated the validation call into the training loop
- Added an integration test to test parallelization

Updated tests training the same base model weights from a seed checkpoint:

| FSDP=2 | FSDP=2,TP=4 |
| --- | --- |
| <img width="978" alt="Screenshot 2025-07-09 at 4 33 53 PM" src="https://github.com/user-attachments/assets/a1fa9fa7-df2f-4302-aa4a-d556a5699ba9" /> | <img width="978" alt="Screenshot 2025-07-09 at 4 33 53 PM" src="https://github.com/user-attachments/assets/a1fa9fa7-df2f-4302-aa4a-d556a5699ba9" /> |

| FSDP=2,CP=4 | FSDP=2,TP=2,CP=2 |
| --- | --- |
| <img width="972" alt="Screenshot 2025-07-09 at 4 39 35 PM" src="https://github.com/user-attachments/assets/56d62841-5841-4969-85b1-803705892465" /> | <img width="970" alt="Screenshot 2025-07-09 at 4 28 57 PM" src="https://github.com/user-attachments/assets/f7d33fa8-ca2c-48f1-931c-8d4c017a47ce" /> |