Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 118 additions & 76 deletions tests/microbenchmarks/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
# isort: off
import torch
# isort: on
from cuda import cuda, cudart
from cuda import cudart

import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
from tensorrt_llm import Mapping
from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._utils import local_mpi_rank, local_mpi_size
from tensorrt_llm.functional import (AllReduceParams, AllReduceStrategy,
allreduce)
from tensorrt_llm.plugin.plugin import (current_all_reduce_helper,
init_all_reduce_helper)
from tensorrt_llm.runtime import Session
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy


def allreduce_benchmark(dtype: str,
test_range: str = "10,10000000,10",
no_header: bool = False):
test_range: str = "1,10000000,10",
no_header: bool = False,
enable_cudagraph: bool = False):
tllm.logger.set_level('error')
world_size = tllm.mpi_world_size()
rank = tllm.mpi_rank()
Expand All @@ -49,80 +49,120 @@ def allreduce_benchmark(dtype: str,

torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
min_size, max_size, ratio = [int(i) for i in test_range.split(",")]
inner_loop = 1000
inner_loop = 1200
outer_loop = 10

size = min_size
dtype_size = torch.finfo(torch_dtype).bits // 8
hidden_size = size
bs = 1
if mapping.rank == 0 and not no_header:
print(
f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<15}, {'duration (ms)':<10}"
f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<10}, {'fusion':<20}, {'version':<10}, {'duration (ms)':<10}"
)
while size < max_size:
input = torch.ones(size, dtype=torch_dtype, device="cuda")

for strategy in [
AllReduceStrategy.AUTO,
AllReduceStrategy.NCCL,
AllReduceStrategy.ONESHOT,
AllReduceStrategy.TWOSHOT,
]:
builder = tllm.Builder()
net = builder.create_network()
net.plugin_config.set_nccl_plugin(dtype)
init_all_reduce_helper()
_buffers, workspace = current_all_reduce_helper(
).allocate_workspace(mapping, size * dtype_size)

with tllm.net_guard(net):
tllm.default_trtnet()

x = Tensor(name='x',
shape=input.shape,
dtype=tllm.str_dtype_to_trt(dtype))

current_all_reduce_helper().set_workspace_tensor(mapping)

current = x
for _ in range(inner_loop):
current = allreduce(
current,
mapping.tp_group,
all_reduce_params=AllReduceParams(strategy=strategy))
current.mark_output('output', dtype)
feed_dict = {'x': input, 'all_reduce_workspace': workspace}
builder_config = builder.create_builder_config(precision=dtype)
engine = builder.build_engine(net, builder_config)
assert engine is not None, "Failed to build engine"
session = Session.from_serialized_engine(engine)

_, start = cuda.cuEventCreate(0)
_, stop = cuda.cuEventCreate(0)
runtimes = []

tllm.mpi_barrier()
output = torch.empty(input.shape, dtype=torch_dtype, device='cuda')
stream = torch.cuda.current_stream()
for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
runtimes.append(ms)

median_ms = sorted(runtimes)[len(runtimes) // 2]

allreduce_ref = (input * world_size)**inner_loop
assert torch.allclose(output, allreduce_ref)

if mapping.rank == 0:
print(
f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<15}, {median_ms:<10.2f}"
)
input = torch.ones((bs, hidden_size), dtype=torch_dtype, device="cuda")

for version in ["v1"]:
for fusion in [
AllReduceFusionOp.RESIDUAL_RMS_NORM, AllReduceFusionOp.NONE
]:
for strategy in [
AllReduceStrategy.NCCL,
AllReduceStrategy.ONESHOT,
AllReduceStrategy.TWOSHOT,
]:
if size >= 25600000 and fusion != AllReduceFusionOp.NONE:
continue
allreduce = AllReduce(mapping=mapping, strategy=strategy)
if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM:
norm_weight = torch.randn((hidden_size, ),
dtype=torch_dtype,
device="cuda")
norm = RMSNorm(hidden_size=hidden_size,
dtype=torch_dtype,
eps=1e-5).cuda()
norm.weight.data.copy_(norm_weight)
if version == "v1":
params = {
"all_reduce_params":
AllReduceParams(fusion_op=fusion,
residual=input,
norm_weight=norm.weight,
eps=norm.variance_epsilon)
}
else:
params = {
"reduce_fusion_inputs": [input, norm.weight],
"eps": norm.variance_epsilon,
"fusion_op": fusion
}
else:
if version == "v1":
params = {
"all_reduce_params":
AllReduceParams(fusion_op=fusion)
}
else:
continue

def func(input):
for _ in range(inner_loop):
input = allreduce(input, **params)
if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM:
input = input[0]
return input

start = [
torch.cuda.Event(enable_timing=True)
for _ in range(outer_loop)
]
stop = [
torch.cuda.Event(enable_timing=True)
for _ in range(outer_loop)
]
graph = torch.cuda.CUDAGraph()

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
if enable_cudagraph:
for _ in range(2):
func(input)
with torch.cuda.graph(graph, stream=stream):
output = func(input)
tllm.mpi_barrier()
delay_kernel(2000000, stream)
torch.cuda.profiler.start()
for i in range(outer_loop):
start[i].record(stream)
if enable_cudagraph:
graph.replay()
else:
output = func(input)
stop[i].record(stream)

torch.cuda.synchronize()
torch.cuda.profiler.stop()
runtimes = [
start[i].elapsed_time(stop[i])
for i in range(outer_loop)
]
median_ms = sorted(runtimes)[len(runtimes) // 2]

if fusion == AllReduceFusionOp.NONE:
allreduce_ref = (input * world_size)**inner_loop
torch.testing.assert_close(output, allreduce_ref)

if mapping.rank == 0:
print(
f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<10}, {fusion.name:<20}, {version:<10}, {median_ms:<10.2f}"
)

size *= ratio
if hidden_size * ratio > 4096:
bs *= ratio
else:
hidden_size *= ratio
assert size == bs * hidden_size


if __name__ == "__main__":
Expand All @@ -134,6 +174,8 @@ def allreduce_benchmark(dtype: str,
default="256,256000000,10", # 256 to 256M
help="min_size,max_size,multiplicative_ratio")
parser.add_argument("--no-header", action="store_true")
parser.add_argument("--enable-cudagraph", action="store_true")
args = parser.parse_args()

allreduce_benchmark(args.dtype, args.range, args.no_header)
allreduce_benchmark(args.dtype, args.range, args.no_header,
args.enable_cudagraph)