TP SP examples improvement #1354

Open · wants to merge 2 commits into main

6 changes: 3 additions & 3 deletions distributed/tensor_parallelism/log_utils.py
@@ -17,6 +17,6 @@ def rank_log(_rank, logger, msg):

 def verify_min_gpu_count(min_gpus: int = 2) -> bool:
     """ verification that we have at least 2 gpus to run dist examples """
-    has_cuda = torch.cuda.is_available()
-    gpu_count = torch.cuda.device_count()
-    return has_cuda and gpu_count >= min_gpus
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
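
For context, the tensor parallel and sequence parallel examples call this helper before doing any distributed setup. A minimal, hypothetical usage sketch (the guard mirrors the GPU-check block already in the examples; the message text is illustrative only):

import sys
from log_utils import verify_min_gpu_count

_min_gpu_count = 2
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    # With the torch.accelerator-based check this now covers CUDA and
    # non-CUDA accelerators alike.
    print(f"Fewer than {_min_gpu_count} accelerators detected, exiting.")
    sys.exit(0)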
24 changes: 18 additions & 6 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
 import os
 import sys
 import torch
@@ -13,6 +14,7 @@

 from log_utils import rank_log, get_logger, verify_min_gpu_count

+from torch.distributed.tensor.debug import CommDebugMode

 # ---- GPU check ------------
 _min_gpu_count = 2
@@ -63,9 +65,10 @@ def forward(self, x):
 """
 logger = get_logger()

+device_type = torch.accelerator.current_accelerator().type
 # create a device mesh based on the given world_size.
 device_mesh = init_device_mesh(
-    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+    device_type=device_type, mesh_shape=(int(os.environ["WORLD_SIZE"]),)
 )

 _rank = device_mesh.get_rank()
@@ -75,7 +78,7 @@ def forward(self, x):
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

 # create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
-model = ToyModel().to("cuda")
+model = ToyModel().to(device_type)

 # Custom parallelization plan for the model
 sp_model = parallelize_module(
@@ -87,6 +90,8 @@ def forward(self, x):
     },
 )

+if torch.distributed.get_rank() == 0:
+    print(f"model {sp_model}")

 # Create a optimizer for the parallelized module.
 lr = 0.25
@@ -98,12 +103,19 @@ def forward(self, x):
 num_iters = 10
 rank_log(_rank, logger, "Sequence Parallel training starting...")

+
 for i in range(num_iters):
     # For SP, input can be different across all ranks.
-    inp = torch.rand(20, 10, device="cuda")
-    output = sp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    #inp = torch.rand(20, 10, device=device_type)
+    inp = torch.rand(1, 10, device=device_type)
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = sp_model(inp)
+        output.sum().backward()
+        optimizer.step()
     rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")

+    if i == 0:
+        print(f" rank{torch.distributed.get_rank()} {i} get_comm_counts {comm_mode.get_comm_counts()} get_sharding_info() {comm_mode.get_sharding_info()} generate_comm_debug_tracing_table {comm_mode.generate_comm_debug_tracing_table(noise_level=1)} ")
+
 rank_log(_rank, logger, "Sequence Parallel training completed!")
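
For readers who have not used CommDebugMode before, the pattern above boils down to the sketch below (it assumes `sp_model` and `inp` as defined in the example, and only shows methods already exercised in this diff):

from torch.distributed.tensor.debug import CommDebugMode

comm_mode = CommDebugMode()
with comm_mode:
    # Collectives issued by the parallelized module are recorded for both
    # the forward and the backward pass.
    output = sp_model(inp)
    output.sum().backward()

# Per-collective counts, e.g. all-gather and reduce-scatter for sequence parallel.
print(comm_mode.get_comm_counts())
# Parameter placements (Shard/Replicate) per parameter.
print(comm_mode.get_sharding_info())
# Human-readable per-module trace; noise_level controls verbosity.
print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))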
25 changes: 17 additions & 8 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
 import os
 import sys
 import torch
@@ -10,6 +11,7 @@
 )

 from log_utils import rank_log, get_logger, verify_min_gpu_count
+from torch.distributed.tensor.debug import CommDebugMode

 # ---- GPU check ------------
 _min_gpu_count = 2
@@ -76,8 +78,8 @@ def forward(self, x):

 # create a device mesh based on the given world_size.
 _world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+device_type = torch.accelerator.current_accelerator().type
+device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
 _rank = device_mesh.get_rank()


@@ -88,8 +90,8 @@ def forward(self, x):

 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
-tp_model = ToyModel().to("cuda")
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to(device_type)


 # Custom parallelization plan for the model
@@ -102,6 +104,9 @@ def forward(self, x):
     },
 )

+if torch.distributed.get_rank() == 0:
+    print(f"model {tp_model}")
+
 # Create an optimizer for the parallelized module.
 lr = 0.25
 optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
@@ -116,10 +121,14 @@ def forward(self, x):
     # For TP, input needs to be same across all TP ranks.
     # Setting the random seed is to mimic the behavior of dataloader.
     torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    output = tp_model(inp)
-    output.sum().backward()
-    optimizer.step()
+    inp = torch.rand(4, 10, device=device_type)
+    comm_mode = CommDebugMode()
Member:
Does this work on non-CUDA devices? It would be great to share some local logs of your tests.

Author:
Gladly. Please see attached logs for H100.

Starting PyTorch TP example on rank 3.
Starting PyTorch TP example on rank 0.
06/16/2025 05:55:00 PM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch TP example on rank 2.
Starting PyTorch TP example on rank 1.
model ToyModel(
  (in_proj): Linear(in_features=10, out_features=32, bias=True)
  (relu): ReLU()
  (out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:55:03 PM  Tensor Parallel training starting...
06/16/2025 05:55:03 PM  Tensor Parallel iter 0 completed
 rank3 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:55:03 PM  Tensor Parallel iter 1 completed
 rank0 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank2 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank1 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:55:03 PM  Tensor Parallel iter 2 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 3 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 4 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 5 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 6 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 7 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 8 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 9 completed
06/16/2025 05:55:04 PM  Tensor Parallel training completed!
[rank0]:[W616 17:55:04.791527408 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Starting PyTorch Sequence Parallel example on rank 0.
06/16/2025 05:53:21 PM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch Sequence Parallel example on rank 3.
Starting PyTorch Sequence Parallel example on rank 2.
Starting PyTorch Sequence Parallel example on rank 1.
model ToyModel(
  (in_proj): Linear(in_features=10, out_features=32, bias=True)
  (relu): ReLU()
  (out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:53:24 PM  Sequence Parallel training starting...
 rank2 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:53:25 PM  Sequence Parallel iter 0 completed
 rank0 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank1 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank3 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:53:25 PM  Sequence Parallel iter 1 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 2 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 3 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 4 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 5 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 6 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 7 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 8 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 9 completed
06/16/2025 05:53:25 PM  Sequence Parallel training completed!
[rank0]:[W616 17:53:25.948217933 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Member:
Sorry, I meant on non-CUDA devices: does this API work if you use MPS or CPU?

Author:
torch.accelerator works for CUDA and non-CUDA GPUs and accelerators. CommDebugMode is also a PyTorch feature, so it should work on all devices; if not, that would be a bug.
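
As a rough illustration of that point, a device-agnostic setup with a CPU fallback could look like the following sketch (it assumes a PyTorch build that ships torch.accelerator; the CPU fallback branch is not part of this PR):

import torch

if torch.accelerator.is_available():
    # Resolves to "cuda", "xpu", "mps", etc. depending on the build and hardware.
    device_type = torch.accelerator.current_accelerator().type
else:
    device_type = "cpu"

inp = torch.rand(4, 10, device=device_type)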

+    with comm_mode:
+        output = tp_model(inp)
+        output.sum().backward()
+        optimizer.step()
     rank_log(_rank, logger, f"Tensor Parallel iter {i} completed")
+    if i == 1:
+        print(f" rank{torch.distributed.get_rank()} {i} get_comm_counts {comm_mode.get_comm_counts()} get_sharding_info() {comm_mode.get_sharding_info()} generate_comm_debug_tracing_table {comm_mode.generate_comm_debug_tracing_table(noise_level=1)} ")

 rank_log(_rank, logger, "Tensor Parallel training completed!")
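
One observation from the attached logs: both runs end with a ProcessGroupNCCL warning that destroy_process_group() was not called before exit. A small addition at the end of each example would address it (a sketch only, not part of this diff):

import torch.distributed as dist

# Tear down the default process group explicitly so collective resources are
# released before the interpreter exits, silencing the warning seen in the logs.
if dist.is_initialized():
    dist.destroy_process_group()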