4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from typing import Generator
8
+
7
9
import torch
8
10
import torch .nn as nn
9
11
from torch .distributed .fsdp import FSDPModule
10
-
11
12
from torch .distributed .tensor import DTensor
12
13
from torchtitan .components .dataloader import BaseDataLoader
13
14
from torchtitan .components .loss import LossFunction
@@ -52,6 +53,8 @@ def __init__(
52
53
parallel_dims : ParallelDims ,
53
54
world_mesh : torch .distributed .DeviceMesh ,
54
55
loss_fn : LossFunction ,
56
+ validation_context : Generator [None , None , None ],
57
+ maybe_enable_amp : Generator [None , None , None ],
55
58
):
56
59
self .job_config = job_config
57
60
self .parallel_dims = parallel_dims
@@ -63,6 +66,8 @@ def __init__(
63
66
dp_rank = dp_rank ,
64
67
tokenizer = tokenizer ,
65
68
)
69
+ self .validation_context = validation_context
70
+ self .maybe_enable_amp = maybe_enable_amp
66
71
67
72
@torch .no_grad ()
68
73
def validate (
@@ -76,44 +81,52 @@ def validate(
76
81
77
82
accumulated_losses = []
78
83
device_type = utils .device_type
79
- num_val_steps = 0
84
+ num_steps = 0
80
85
81
86
for input_dict , labels in self .validation_dataloader :
82
87
if (
83
88
self .job_config .validation .steps != - 1
84
- and num_val_steps >= self .job_config .validation .steps
89
+ and num_steps >= self .job_config .validation .steps
85
90
):
86
91
break
87
92
88
93
for k , v in input_dict .items ():
89
94
input_dict [k ] = v .to (device_type )
90
- labels = labels .to (device_type )
91
-
92
95
inputs = input_dict ["input" ]
93
- predictions = model ( inputs )
96
+ labels = labels . to ( device_type )
94
97
95
- if self .parallel_dims .loss_parallel_enabled :
96
- if isinstance (predictions , torch .Tensor ) and not isinstance (
97
- predictions , DTensor
98
- ):
99
- predictions = DTensor .from_local (predictions , self .world_mesh ["tp" ])
100
- if isinstance (labels , torch .Tensor ) and not isinstance (labels , DTensor ):
101
- labels = DTensor .from_local (labels , self .world_mesh ["tp" ])
102
- loss = self .loss_fn (predictions , labels )
98
+ optional_context_parallel_ctx = (
99
+ dist_utils .create_context_parallel_ctx (
100
+ cp_mesh = self .world_mesh ["cp" ],
101
+ cp_buffers = [inputs , labels ] + [m .freqs_cis for m in model_parts ],
102
+ cp_seq_dims = [1 , 1 ] + [0 for _ in model_parts ],
103
+ cp_no_restore_buffers = {inputs , labels },
104
+ cp_rotate_method = self .job_config .parallelism .context_parallel_rotate_method ,
105
+ )
106
+ if self .parallel_dims .cp_enabled
107
+ else None
108
+ )
109
+
110
+ with self .validation_context (optional_context_parallel_ctx ):
111
+ assert len (model_parts ) == 1
112
+ with self .maybe_enable_amp :
113
+ predictions = model (inputs )
114
+ loss = self .loss_fn (predictions , labels )
103
115
104
116
accumulated_losses .append (loss .detach ())
105
117
106
- num_val_steps += 1
118
+ num_steps += 1
107
119
108
120
# Compute average loss
109
121
loss = torch .sum (torch .stack (accumulated_losses ))
122
+ loss /= num_steps
110
123
if self .parallel_dims .dp_cp_enabled :
111
124
global_avg_loss = dist_utils .dist_mean (loss , self .world_mesh ["dp_cp" ])
112
125
else :
113
126
global_avg_loss = loss
114
127
115
128
logger .info (
116
- f"Validation completed. Average loss: { global_avg_loss :.4f} over { num_val_steps } batches"
129
+ f"Validation completed. Average loss: { global_avg_loss :.4f} over { num_steps } batches"
117
130
)
118
131
119
132
# Reshard after run forward pass
@@ -125,8 +138,6 @@ def validate(
125
138
# Set model back to train mode
126
139
model .train ()
127
140
128
- return {"validation_loss" : global_avg_loss }
129
-
130
141
131
142
def build_validator (
132
143
job_config : JobConfig ,
@@ -136,6 +147,8 @@ def build_validator(
136
147
parallel_dims : ParallelDims ,
137
148
world_mesh : torch .distributed .DeviceMesh ,
138
149
loss_fn : LossFunction ,
150
+ validation_context : Generator [None , None , None ],
151
+ maybe_enable_amp : Generator [None , None , None ],
139
152
) -> BaseValidator :
140
153
"""Build a simple validator focused on correctness."""
141
154
return Validator (
@@ -146,4 +159,6 @@ def build_validator(
146
159
parallel_dims = parallel_dims ,
147
160
world_mesh = world_mesh ,
148
161
loss_fn = loss_fn ,
162
+ validation_context = validation_context ,
163
+ maybe_enable_amp = maybe_enable_amp ,
149
164
)
0 commit comments