Skip to content

Commit d831aaf

Browse files
committed
Update base for Update on "[PP] Support OVERLAP_F_B computation type"
Some changes to validation code and visualizer to support a new computation type that will be used in DualPipeV (see pytorch/torchtitan#1447) cc awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta [ghstack-poisoned]
1 parent 3fa8905 commit d831aaf

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

test/distributed/pipelining/test_schedule_multiproc.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,15 +238,21 @@ def test_eval_inference_mode(self, ScheduleClass):
238238
stages, stage_modules, _ = self._create_multi_stage_pipeline(
239239
mod, stages_per_rank, n_stages
240240
)
241-
schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn, scale_grads=False)
241+
schedule = ScheduleClass(
242+
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
243+
)
242244
else:
243245
# Single-stage schedules
244246
mod, _, x, target, loss_fn = self._setup_models_and_data()
245247

246248
# Create single-stage pipeline
247-
stage, stage_module, _ = self._create_single_stage_pipeline(mod, x, num_microbatches)
249+
stage, stage_module, _ = self._create_single_stage_pipeline(
250+
mod, x, num_microbatches
251+
)
248252
stage_modules = [stage_module]
249-
schedule = ScheduleClass(stage, num_microbatches, loss_fn=loss_fn, scale_grads=False)
253+
schedule = ScheduleClass(
254+
stage, num_microbatches, loss_fn=loss_fn, scale_grads=False
255+
)
250256

251257
# Clear gradients and run eval
252258
self._zero_gradients(stage_modules)

0 commit comments

Comments
 (0)