Adding torch accelerator to ddp-tutorial-series example #38
@@ -18,8 +18,14 @@ def ddp_setup(rank, world_size):
     """
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "12355"
-    torch.cuda.set_device(rank)
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if torch.accelerator.is_available():
+        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
+        torch.accelerator.set_device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+
+    backend = torch.distributed.get_default_backend_for_device(device)
+    init_process_group(backend=backend, rank=rank, world_size=world_size)
 
 class Trainer:
     def __init__(

Owner: you also need to remove the
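For reference, a minimal, self-contained sketch of what the patched spawn-based setup (presumably `multigpu.py`) amounts to. This is a sketch, not the full example: the Trainer/model wiring is replaced by a small `demo` function, and the diff's `is_available()` guard is dropped, so it assumes at least one accelerator is present (otherwise you would want a CPU fallback so `device` is always defined).

```python
# Sketch only: accelerator-agnostic DDP setup for the mp.spawn-launched example.
# Assumes torch >= 2.7 (torch.accelerator, get_default_backend_for_device).
import os

import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group, init_process_group


def ddp_setup(rank: int, world_size: int) -> torch.device:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Pick this rank's local accelerator (cuda, xpu, mps, ...) instead of hard-coding CUDA.
    device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
    torch.accelerator.set_device_index(rank)

    # Let PyTorch choose the matching communication backend (e.g. nccl for cuda).
    backend = torch.distributed.get_default_backend_for_device(device)
    init_process_group(backend=backend, rank=rank, world_size=world_size)
    return device


def demo(rank: int, world_size: int):
    device = ddp_setup(rank, world_size)
    print(f"rank {rank}/{world_size} initialized on {device}")
    destroy_process_group()


if __name__ == "__main__":
    world_size = torch.accelerator.device_count()
    mp.spawn(demo, args=(world_size,), nprocs=world_size)
```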
@@ -95,10 +101,10 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
 if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser(description='simple distributed training job')
-    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
-    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
+    parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model')
+    parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot')
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
 
-    world_size = torch.cuda.device_count()
+    world_size = torch.accelerator.device_count()
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
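One hedged caveat on the new defaults, which is not part of the diff: argparse only applies a `default` to a positional argument when it is made optional with `nargs='?'`; as written, `total_epochs` and `save_every` are still required on the command line (the test script does pass them explicitly as `10 1`). A small sketch of the optional form:

```python
# Sketch: for positional arguments, a default only takes effect with nargs='?';
# without it, argparse still treats the argument as required.
import argparse

parser = argparse.ArgumentParser(description='simple distributed training job')
parser.add_argument('total_epochs', nargs='?', default=50, type=int,
                    help='Total epochs to train the model (default: 50)')
parser.add_argument('save_every', nargs='?', default=5, type=int,
                    help='How often to save a snapshot (default: 5)')
parser.add_argument('--batch_size', default=32, type=int,
                    help='Input batch size on each device (default: 32)')

print(parser.parse_args([]))           # Namespace(total_epochs=50, save_every=5, batch_size=32)
print(parser.parse_args(['10', '1']))  # Namespace(total_epochs=10, save_every=1, batch_size=32)
```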
@@ -11,8 +11,18 @@
 
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
+        torch.accelerator.set_device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+    else:
+        print(f"Multi-GPU environment not detected")
+
+    backend = torch.distributed.get_default_backend_for_device(rank)
+    torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank)
 
 
 class Trainer:
     def __init__(

Owner: this line is not necessary.
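Likewise, a hedged sketch of the torchrun-launched setup (presumably `multigpu_torchrun.py`), with the "not detected" print dropped per the review comment. Two details differ from the diff as assumptions on my part: the `torch.device` is passed to `get_default_backend_for_device` and to `device_id` rather than the bare integer rank, since the device form works across recent releases; if the integer form is accepted in your PyTorch version, the diff is fine as written.

```python
# Sketch only: accelerator-agnostic setup for a torchrun launch, where torchrun
# provides LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT via the environment.
import os

import torch
from torch.distributed import destroy_process_group, init_process_group


def ddp_setup() -> torch.device:
    rank = int(os.environ["LOCAL_RANK"])
    if torch.accelerator.is_available():
        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
        torch.accelerator.set_device_index(rank)
    else:
        device = torch.device("cpu")  # CPU fallback; gets the gloo backend below

    backend = torch.distributed.get_default_backend_for_device(device)
    # device_id binds the process group to this rank's device; skip it on CPU.
    init_process_group(backend=backend, device_id=device if device.type != "cpu" else None)
    return device


if __name__ == "__main__":
    device = ddp_setup()
    print(f"rank {os.environ['RANK']} running on {device}")
    destroy_process_group()
```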
@@ -37,7 +47,7 @@ def __init__(
         self.model = DDP(self.model, device_ids=[self.gpu_id])
 
     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.gpu_id}"
+        loc = str(torch.accelerator.current_accelerator())
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
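A hedged aside on the `map_location` change: `str(torch.accelerator.current_accelerator())` yields the device type without an index (for example `"cuda"`), so in this setup the snapshot lands on whichever device index `set_device_index` selected earlier. If you prefer to make the per-rank target explicit, the same f-string pattern from `ddp_setup` works here too; this is an alternative sketch, not what the diff does:

```python
import torch

# Alternative sketch: a drop-in body for Trainer._load_snapshot that pins the
# map_location to this rank's device index instead of relying on the current device.
def _load_snapshot(self, snapshot_path):
    loc = f"{torch.accelerator.current_accelerator()}:{self.gpu_id}"  # e.g. "cuda:1"
    snapshot = torch.load(snapshot_path, map_location=loc)
    self.model.load_state_dict(snapshot["MODEL_STATE"])
    self.epochs_run = snapshot["EPOCHS_RUN"]
```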
@@ -103,8 +113,8 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
 if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser(description='simple distributed training job')
-    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
-    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
+    parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model')
+    parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot')
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
@@ -1 +1 @@
-torch>=1.11.0
+torch>=2.7
@@ -0,0 +1,11 @@
+# /bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+#   multigpu_torchrun.py
+#   multinode.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py} 10 1
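As a usage note on this launcher: both positional parameters default, so `bash run_example.sh` alone launches `example.py` on 2 local GPUs, while `bash run_example.sh multigpu_torchrun.py 2` expands to `torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=101 --rdzv_endpoint="localhost:5972" multigpu_torchrun.py 10 1`. The trailing `10 1` are the training script's own positional arguments, i.e. `total_epochs=10` and `save_every=1`.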
@@ -50,8 +50,16 @@ function distributed_tensor_parallelism() {
     uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
 }
 
+function distributed_ddp-tutorial-series() {
+    uv python multigpu.py 10 1 || error "ddp tutorial series multigpu example failed"
+    uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
+    uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
+    uv python single_gpu.py 10 1 || error "ddp tutorial series single gpu example failed"
+}
+
 function distributed_FSDP2() {
     uv run bash run_example.sh example.py || error "FSDP2 example failed"
 }
 
 function distributed_ddp() {

Owner: Update the README with the instructions to run the examples.
@@ -72,6 +80,7 @@ function distributed_rpc_rnn() {
 
 function run_all() {
     run distributed/tensor_parallelism
+    run distributed/ddp-tutorial-series
     run distributed/ddp
     run distributed/minGPT-ddp
     run distributed/rpc/ddp_rpc
LGTM

minor change

Also, I think it would be better if all args had a default value.

I added the default values for epochs and save every 👍