
Commit 642060f

Adding torch accelerator to ddp-tutorial-series example
Signed-off-by: dggaytan <[email protected]>
1 parent 3e2c3ae commit 642060f

3 files changed: 17 additions, 24 deletions


distributed/ddp-tutorial-series/multigpu.py

Lines changed: 1 addition & 2 deletions
@@ -25,8 +25,7 @@ def ddp_setup(rank, world_size):
         torch.accelerator.set_device_index(rank)
         print(f"Running on rank {rank} on device {device}")
     else:
-        device = torch.device("cpu")
-        print(f"Running on device {device}")
+        print(f"Multi-GPU environment not detected")
 
     backend = torch.distributed.get_default_backend_for_device(device)
     init_process_group(backend=backend, rank=rank, world_size=world_size)
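
For reference, ddp_setup in multigpu.py now reads roughly as sketched below. The is_available() guard, the current_accelerator() call, and the MASTER_ADDR/MASTER_PORT setup are assumptions reconstructed from the unchanged context lines, not part of this diff:

import os
import torch
from torch.distributed import init_process_group

def ddp_setup(rank: int, world_size: int):
    # Rendezvous info for the default process group (assumed; standard in this tutorial series).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    if torch.accelerator.is_available():
        device = torch.accelerator.current_accelerator()
        torch.accelerator.set_device_index(rank)  # bind this process to its local device
        print(f"Running on rank {rank} on device {device}")
    else:
        # Note: after this commit, `device` is only defined on the accelerator path.
        print(f"Multi-GPU environment not detected")
    # Pick the backend that matches the device (e.g. nccl for CUDA, gloo for CPU).
    backend = torch.distributed.get_default_backend_for_device(device)
    init_process_group(backend=backend, rank=rank, world_size=world_size)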

distributed/ddp-tutorial-series/multigpu_torchrun.py

Lines changed: 8 additions & 11 deletions
@@ -17,12 +17,11 @@ def ddp_setup():
         torch.accelerator.set_device_index(rank)
         print(f"Running on rank {rank} on device {device}")
     else:
-        device = torch.device("cpu")
-        print(f"Running on device {device}")
-
-    backend = torch.distributed.get_default_backend_for_device(device)
-    torch.distributed.init_process_group(backend=backend, device_id=device)
-    return device
+        print(f"Multi-GPU environment not detected")
+
+    backend = torch.distributed.get_default_backend_for_device(rank)
+    torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank)
+
 
 
 class Trainer:
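
This torchrun variant of ddp_setup takes no arguments because torchrun injects the rank and world size into each worker's environment; the unchanged lines above this hunk presumably derive rank from those variables. A minimal sketch of the environment torchrun provides (the variable names are set by torchrun itself; the reading code is illustrative):

import os

local_rank = int(os.environ["LOCAL_RANK"])   # device index on this node
global_rank = int(os.environ["RANK"])        # unique rank across all nodes
world_size = int(os.environ["WORLD_SIZE"])   # total number of worker processes
print(f"worker {global_rank}/{world_size} on local device {local_rank}")
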
@@ -33,7 +32,6 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
-        device: torch.device,
     ) -> None:
         self.gpu_id = int(os.environ["LOCAL_RANK"])
         self.model = model.to(self.gpu_id)
@@ -42,15 +40,14 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
-        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)
 
         self.model = DDP(self.model, device_ids=[self.gpu_id])
 
     def _load_snapshot(self, snapshot_path):
-        loc = str(self.device)
+        loc = str(torch.accelerator.current_accelerator())
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
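
With the device argument gone, the snapshot is mapped onto whatever accelerator the current process is bound to instead of a device object threaded through the constructor. A standalone sketch of the same load pattern (the snapshot keys mirror the tutorial's; an available accelerator is assumed):

import torch
import torch.nn as nn

model = nn.Linear(8, 4)

# Save a checkpoint in the same shape the tutorial uses.
torch.save({"MODEL_STATE": model.state_dict(), "EPOCHS_RUN": 10}, "snapshot.pt")

# Map every tensor onto the local accelerator on load, e.g. "cuda" or "xpu".
loc = str(torch.accelerator.current_accelerator())
snapshot = torch.load("snapshot.pt", map_location=loc)
model.load_state_dict(snapshot["MODEL_STATE"])
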
@@ -105,10 +102,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
 
 
 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    device = ddp_setup()
+    ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
     trainer.train(total_epochs)
     destroy_process_group()
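
Nothing about launching the script changes: torchrun still spawns one process per device and each process calls ddp_setup() on its own. A typical single-node invocation (the process count and the script's positional arguments are illustrative, not taken from this diff):

torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py <total_epochs> <save_every>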

distributed/ddp-tutorial-series/multinode.py

Lines changed: 8 additions & 11 deletions
@@ -17,12 +17,11 @@ def ddp_setup():
         torch.accelerator.set_device_index(rank)
         print(f"Running on rank {rank} on device {device}")
     else:
-        device = torch.device("cpu")
-        print(f"Running on device {device}")
-
-    backend = torch.distributed.get_default_backend_for_device(device)
-    torch.distributed.init_process_group(backend=backend, device_id=device)
-    return device
+        print(f"Multi-GPU environment not detected")
+
+    backend = torch.distributed.get_default_backend_for_device(rank)
+    torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank)
+
 
 class Trainer:
     def __init__(
@@ -32,7 +31,6 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
-        device: torch.device,
     ) -> None:
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
@@ -42,15 +40,14 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
-        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)
 
         self.model = DDP(self.model, device_ids=[self.local_rank])
 
     def _load_snapshot(self, snapshot_path):
-        loc = str(self.device)
+        loc = str(torch.accelerator.current_accelerator())
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -105,10 +102,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
 
 
 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    device = ddp_setup()
+    ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
     trainer.train(total_epochs)
     destroy_process_group()
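
For the multi-node script only the launch command differs: torchrun is started once per node with a shared rendezvous endpoint, and each worker again reads RANK and LOCAL_RANK from its environment before ddp_setup() picks the backend for the local accelerator. A representative invocation (host, port, counts, and the script's positional arguments are placeholders, not values from this commit):

torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=<master_host>:29500 multinode.py <total_epochs> <save_every>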
