tensorrt_llm/_torch/models/modeling_llama.py (8 additions, 2 deletions)
@@ -11,7 +11,8 @@
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, MoEAllReduce)
+                                             AllReduceParams, AllReduceStrategy,
+                                             MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -646,7 +647,12 @@ def __init__(
                                             eps=config.rms_norm_eps,
                                             dtype=config.torch_dtype)
 
-        self.all_reduce = AllReduce(mapping=model_config.mapping)
+        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
+        self.all_reduce = AllReduce(
+            strategy=model_config.allreduce_strategy
+            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
+            mapping=model_config.mapping,
+        )
 
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
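
Note on the change above: get_sm_version() reports the GPU's SM (compute capability) as an integer, and 100 corresponds to Blackwell, so the configured model_config.allreduce_strategy is only honored on SM 100 and newer; older architectures fall back to AllReduceStrategy.NCCL to avoid the oneshot-kernel perf regression mentioned in the TODO. A minimal sketch of that selection logic, where pick_allreduce_strategy is a hypothetical helper added only for illustration (the diff inlines the equivalent expression in the layer's __init__):

from tensorrt_llm._torch.distributed import AllReduceStrategy
from tensorrt_llm._utils import get_sm_version


def pick_allreduce_strategy(configured_strategy):
    # Hypothetical helper mirroring the diff: honor the configured strategy
    # only on Blackwell (SM >= 100); otherwise fall back to NCCL so the
    # oneshot kernel is not used on pre-Blackwell GPUs.
    if get_sm_version() >= 100:
        return configured_strategy
    return AllReduceStrategy.NCCL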