
Commit 9cf0e84

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~ Fixed with pytorch/torchft#83.

**Issue 2:** ~~Group 0 and group 1 pass `should_commit`, but group 0 needs healing, which is wrong, and the healing process causes another hang.~~ Fixed with pytorch/torchft#83.

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 keeps printing~~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command.

This seems to be fixed; more tests are needed.

**Issue 5:** A hang occurs when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the lighthouse
3. Execute the following command in one terminal:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

4. Wait 10 seconds, then execute the following command in another terminal:

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 3e70806
Pull Request resolved: #834
Parent: ec82573
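
For orientation, a rough outline of the per-step TorchFT pattern that the issues above revolve around, using only names that appear on this page (`should_commit`, plus `set_state_dict_fns` and `participating_rank` from the diffs below). This is a hedged sketch, not code from this PR or from torchft:

```python
# Hedged outline of a fault-tolerant training step; `manager` stands for a
# torchft Manager, and `model`, `optimizer`, `batch` are placeholders.
def train_step(manager, model, optimizer, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()  # placeholder loss computation
    loss.backward()            # gradient sync across replica groups is assumed to happen here
    # A step only commits when the quorum agrees it is valid; a replica that is
    # healing (receiving a pending state_dict from a peer group, as in Issue 1)
    # is expected to skip the commit.
    if manager.should_commit():
        optimizer.step()
```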

10 files changed: +474 / -57 lines

run_train.sh

Lines changed: 3 additions & 0 deletions

```diff
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
```
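
The script-level default above matters because every replica group must point at the same lighthouse. As a minimal Python sketch of the same fallback logic, assuming the consuming process reads `TORCHFT_LIGHTHOUSE` from the environment (the helper name below is hypothetical, not part of this PR):

```python
import os

# Assumed default, copied from the shell fallback added to run_train.sh above.
DEFAULT_LIGHTHOUSE = "http://localhost:29510"


def resolve_lighthouse_addr() -> str:
    """Return the lighthouse address, falling back to the local default."""
    return os.environ.get("TORCHFT_LIGHTHOUSE", DEFAULT_LIGHTHOUSE)


if __name__ == "__main__":
    # Prints http://localhost:29510 unless the variable is exported.
    print(resolve_lighthouse_addr())
```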

scripts/estimate/estimation.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -15,6 +15,7 @@
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
+from torchtitan.components.ft import init_ft_manager
 from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
@@ -102,7 +103,6 @@ def estimate_memory(job_config: JobConfig):
         if not job_config.memory_estimation.disable_fake_mode
         else contextlib.nullcontext()
     ):
-
         logger.info(
             f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
         )
@@ -122,7 +122,8 @@ def estimate_memory(job_config: JobConfig):
         model.train()
 
         # build optimizer after applying parallelisms to the model
-        optimizers = build_optimizers([model], job_config)
+        ft_manager = init_ft_manager(job_config)
+        optimizers = build_optimizers([model], job_config, ft_manager)
         lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
```
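
This page doesn't include `torchtitan/components/ft.py` itself, so as a mental model only: a hedged sketch of what `init_ft_manager` plausibly returns, a thin wrapper that stays disabled unless fault tolerance is configured, which is why the estimation script can call it unconditionally. The field and flag names below are assumptions inferred from how the wrapper is used elsewhere in this commit.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FTManager:
    # Shape assumed from its call sites in this commit: CheckpointManager reads
    # `.enabled` and `.manager`; build_optimizers receives the whole wrapper.
    manager: Optional[Any] = None  # a torchft Manager when enabled, else None
    enabled: bool = False


def init_ft_manager(job_config) -> FTManager:
    """Sketch: return a disabled FTManager unless TorchFT is enabled in the config."""
    ft_config = getattr(job_config, "fault_tolerance", None)
    if ft_config is None or not getattr(ft_config, "enable", False):
        # Disabled: downstream code skips every TorchFT branch.
        return FTManager()
    # When enabled, the real implementation constructs a torchft Manager here;
    # that wiring lives in torchtitan/components/ft.py and is not shown on this page.
    raise NotImplementedError
```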

tests/unit_tests/test_checkpoint.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -61,10 +61,18 @@ class DummyJob:
     dump_folder: str = "dummy_folder"
 
 
+@dataclass
+class DummyFaultTolerance:
+    replica_id = 0
+    group_size = 1
+
+
 @dataclass
 class DummyJobConfig:
     checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
     job: DummyJob = field(default_factory=DummyJob)
+    fault_tolerance: DummyFaultTolerance = field(default_factory=DummyFaultTolerance)
+    ft_manager = None
 
 
 # Dummy instances to supply as constructor arguments.
@@ -103,13 +111,16 @@ def tearDown(self):
     def test_save(self, *_):
         """Test that calling save() writes a checkpoint file to disk."""
         job_config = DummyJobConfig(job=self.dummy_job)
+        ft_manager = mock.Mock()
+        ft_manager.enabled = False
         manager = CheckpointManager(
             dummy_dataloader,
             dummy_model_parts,
             dummy_optimizers,
             dummy_lr_schedulers,
             {"trainer": self.trainer_state},
             job_config,
+            ft_manager,
         )
         step = 20
         manager.save(curr_step=step, force=True)
@@ -133,13 +144,16 @@ def test_save(self, *_):
     def test_load(self, *_):
         """Test that load() properly reads the checkpoint file from disk and restores state."""
         job_config = DummyJobConfig(job=self.dummy_job)
+        ft_manager = mock.Mock()
+        ft_manager.enabled = False
         manager = CheckpointManager(
             dummy_dataloader,
             dummy_model_parts,
             dummy_optimizers,
             dummy_lr_schedulers,
             {"trainer": self.trainer_state},
             job_config,
+            ft_manager,
         )
         step = 30
         manager.save(curr_step=step, force=True)
@@ -171,13 +185,16 @@ def test_purge_stale_checkpoints_rank_zero(self, *_):
         """
         job_config = DummyJobConfig(job=self.dummy_job)
         job_config.checkpoint.keep_latest_k = 3
+        ft_manager = mock.Mock()
+        ft_manager.enabled = False
         manager = CheckpointManager(
             dummy_dataloader,
             dummy_model_parts,
             dummy_optimizers,
             dummy_lr_schedulers,
             {"trainer": self.trainer_state},
             job_config,
+            ft_manager,
         )
         steps = [10, 20, 30, 40, 50]
         for s in steps:
@@ -215,13 +232,16 @@ def test_purge_stale_checkpoints_rank_nonzero(self, *_):
         """
         job_config = DummyJobConfig(job=self.dummy_job)
         job_config.checkpoint.keep_latest_k = 3
+        ft_manager = mock.Mock()
+        ft_manager.enabled = False
         manager = CheckpointManager(
             dummy_dataloader,
             dummy_model_parts,
             dummy_optimizers,
             dummy_lr_schedulers,
             {"trainer": self.trainer_state},
             job_config,
+            ft_manager,
         )
         steps = [10, 20, 30, 40, 50]
         for s in steps:
@@ -252,13 +272,16 @@ def test_async_save_calls_async_wait(self, *_):
         # Set async_mode to "async" in the job configuration.
         job_config = DummyJobConfig(job=self.dummy_job)
         job_config.checkpoint.async_mode = "async"
+        ft_manager = mock.Mock()
+        ft_manager.enabled = False
         manager = CheckpointManager(
             dummy_dataloader,
             dummy_model_parts,
             dummy_optimizers,
             dummy_lr_schedulers,
             {"trainer": self.trainer_state},
             job_config,
+            ft_manager,
         )
         # First save: should schedule an async save.
         manager.save(curr_step=10, force=False)
```
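
Each test above repeats the same three lines to build a disabled FT stand-in. A small helper (hypothetical, not part of this PR) could factor that pattern out:

```python
from unittest import mock


def make_disabled_ft_manager() -> mock.Mock:
    """Return a stand-in FTManager whose `enabled` flag is False, so the
    CheckpointManager under test takes none of the TorchFT code paths."""
    ft_manager = mock.Mock()
    ft_manager.enabled = False
    return ft_manager


# Usage inside a test body:
# manager = CheckpointManager(
#     dummy_dataloader, dummy_model_parts, dummy_optimizers,
#     dummy_lr_schedulers, {"trainer": self.trainer_state}, job_config,
#     make_disabled_ft_manager(),
# )
```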

torchtitan/components/checkpoint.py

Lines changed: 115 additions & 27 deletions

```diff
@@ -31,6 +31,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
 
+from torchtitan.components.ft import FTManager
 from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.tools.logging import init_logger, logger
@@ -214,6 +215,19 @@ class CheckpointManager:
     3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
     with the assumption that all lr_schedulers have the same state_dict.
 
+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.
 
     Args:
         dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +237,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -233,16 +248,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: FTManager,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.fault_tolerance.replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -254,6 +294,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -264,7 +311,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +386,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -384,6 +440,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -467,10 +526,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -491,6 +576,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +664,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
            discovered_checkpoints = []
            for filename in os.listdir(self.folder):
```
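
To summarize the save-side control flow added above, a condensed sketch of how the per-replica and full persistent checkpoints interleave. This is not the actual method: it omits async modes, staging, and GC, and uses only names that appear in the diff (`_ft_save`, `_should_save`, `_create_checkpoint_id`, `save_with_gc`, `participating_rank`):

```python
def save_sketch(checkpointer, curr_step: int, force: bool = False) -> None:
    # Condensed view of CheckpointManager.save() above.
    if checkpointer.ft_manager:
        # Per-replica checkpoint: dataloader only, written on every call into
        # a folder keyed by the replica id (see _ft_folder()).
        checkpointer._ft_save(curr_step)

    if not checkpointer._should_save(curr_step, force):
        return

    if not checkpointer.ft_manager or checkpointer.ft_manager.participating_rank() == 0:
        # Full persistent checkpoint: model, optimizer, lr_scheduler, dataloader,
        # and train_state, saved by one replica only (save_with_gc is the
        # module-level helper used in the real method).
        save_with_gc(
            checkpointer.states,
            checkpoint_id=checkpointer._create_checkpoint_id(curr_step),
        )
    # All other replicas skip the full checkpoint entirely.
```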
