Commit 2b90698

support gated delta net
Parent: 876a046

17 files changed (+1502, −245 lines)

gpt_builders.py

Lines changed: 6 additions & 1 deletion
@@ -8,6 +8,9 @@
     get_gpt_layer_with_inference_spec,
     get_gpt_mtp_block_spec,
 )
+from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
+    is_linear_attention_variant,
+)
 from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
     get_gpt_heterogeneous_layer_spec,
 )
@@ -42,7 +45,7 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
     else:
         use_te = args.transformer_impl == "transformer_engine"
 
-    if args.num_experts:
+    if args.num_experts or is_linear_attention_variant(args.experimental_attention_variant):
         assert not (config.transformer_impl == "inference_optimized")
         # Define the decoder block spec
         transformer_layer_spec = get_gpt_decoder_block_spec(
@@ -117,6 +120,7 @@ def _get_transformer_layer_spec(use_te, config):
             args.moe_grouped_gemm,
             args.qk_layernorm,
             args.multi_latent_attention,
+            args.experimental_attention_variant,
             moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
             qk_l2_norm=args.qk_l2_norm,
             use_kitchen=config.use_kitchen,
@@ -135,6 +139,7 @@ def _get_transformer_layer_spec(use_te, config):
             args.moe_grouped_gemm,
             args.qk_layernorm,
             args.multi_latent_attention,
+            args.experimental_attention_variant,
             moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
             normalization=args.normalization,
             use_kitchen=config.use_kitchen,
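
The rewritten guard sends any linear-attention variant down the same decoder-block-spec path that MoE models already use. A minimal sketch of how the new predicate behaves, assuming --experimental-attention-variant defaults to None; only is_linear_attention_variant and the "gated_delta_net" value come from this commit, the other string is hypothetical:

    # Sketch: behaviour of the new guard for a few flag values.
    from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
        is_linear_attention_variant,
    )

    assert is_linear_attention_variant("gated_delta_net")      # routed through get_gpt_decoder_block_spec
    assert not is_linear_attention_variant("sliding_window")   # hypothetical variant name: plain layer spec
    assert not is_linear_attention_variant(None)               # assumed default: dense models are unchanged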

megatron/core/jit.py

Lines changed: 21 additions & 6 deletions
@@ -7,12 +7,27 @@
 jit_fuser = torch.jit.script
 # nvFuser is deprecated in PyTorch JIT starting from 2.2
 
-try:
-    if is_torch_min_version("2.2.0a0"):
-        jit_fuser = torch.compile
-except ImportError:
 
-    def noop_decorator(func):
-        return func
+def noop_decorator(func):
+    '''No-op decorator'''
+    return func
 
+
+def enable_jit_fuser():
+    '''Enable the JIT fuser'''
+    global jit_fuser
+    try:
+        if is_torch_min_version("2.2.0a0"):
+            jit_fuser = torch.compile
+    except ImportError:
+
+        jit_fuser = noop_decorator
+
+
+def disable_jit_fuser():
+    '''Disable the JIT fuser'''
+    global jit_fuser
     jit_fuser = noop_decorator
+
+
+enable_jit_fuser()
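
Because jit_fuser is now a module-level name that enable_jit_fuser() / disable_jit_fuser() rebind, decorators pick up whichever fuser is active at decoration time. A small usage sketch, assuming megatron.core.jit imports as shown and PyTorch >= 2.2 is installed; the decorated functions are illustrative only:

    import torch
    from megatron.core import jit

    jit.disable_jit_fuser()          # jit_fuser becomes the no-op decorator

    @jit.jit_fuser                   # resolved at decoration time, so this stays eager
    def scaled_add(x, y, alpha: float):
        return x + alpha * y

    jit.enable_jit_fuser()           # back to torch.compile on PyTorch >= 2.2.0a0

    @jit.jit_fuser                   # this one gets compiled
    def tanh_gelu(x):
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))

Accessing the decorator as jit.jit_fuser (rather than from megatron.core.jit import jit_fuser) matters here: the toggle functions rebind the module global, and a name imported before the toggle would keep pointing at the old object.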
megatron/core/models/gpt/experimental_attention_variant_module_specs.py (new file)

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+from typing import Optional
+
+from megatron.core.models.backends import BackendSpecProvider
+from megatron.core.ssm.gated_delta_net import GatedDeltaNet, GatedDeltaNetSubmodules
+from megatron.core.transformer.spec_utils import ModuleSpec
+
+
+def is_linear_attention_variant(experimental_attention_variant: str) -> bool:
+    """Check if the experimental attention variant is a linear attention variant."""
+    linear_attention_variants = ["gated_delta_net"]
+    return experimental_attention_variant in linear_attention_variants
+
+
+def get_gated_delta_net_module_spec_for_backend(
+    backend: BackendSpecProvider, normalization: Optional[str] = None
+) -> ModuleSpec:
+    """Helper function to get module spec for Linear Attention"""
+    rms_norm = normalization == "RMSNorm"
+    attention = ModuleSpec(
+        module=GatedDeltaNet,
+        submodules=GatedDeltaNetSubmodules(
+            in_proj=backend.column_parallel_layer_norm_linear(),
+            out_norm=backend.layer_norm(rms_norm=rms_norm, for_qk=False),
+            out_proj=backend.row_parallel_linear(),
+        ),
+        metainfo={"fuse_input_layernorm": True},
+    )
+    return attention
+
+
+def get_experimental_attention_variant_module_spec_for_backend(
+    backend: BackendSpecProvider,
+    sharded_state_dict_keys_map: dict,
+    experimental_attention_variant: Optional[str] = None,
+    qk_layernorm: Optional[bool] = False,
+    qk_l2_norm: Optional[bool] = False,
+    multi_latent_attention: Optional[bool] = False,
+    mla_down_proj_use_column_parallel: Optional[bool] = False,
+    normalization: Optional[str] = None,
+    fallback_to_eager_attn: Optional[bool] = False,
+) -> ModuleSpec:
+    """Helper function to get module spec for Attention"""
+    if experimental_attention_variant == "gated_delta_net":
+        return get_gated_delta_net_module_spec_for_backend(
+            backend=backend, normalization=normalization
+        )
+    else:
+        raise ValueError(
+            f"Invalid experimental attention variant: {experimental_attention_variant}"
+        )
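
A rough sketch of how the new helper might be wired into a layer spec. Only the helper's signature and the backend methods it calls come from this file; the TESpecProvider backend class and its import path are assumptions:

    # Hypothetical wiring: build the GatedDeltaNet attention spec for a TE-style backend.
    from megatron.core.models.backends import TESpecProvider  # assumed concrete BackendSpecProvider
    from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
        get_experimental_attention_variant_module_spec_for_backend,
        is_linear_attention_variant,
    )

    variant = "gated_delta_net"
    if is_linear_attention_variant(variant):
        attention_spec = get_experimental_attention_variant_module_spec_for_backend(
            backend=TESpecProvider(),
            sharded_state_dict_keys_map={},
            experimental_attention_variant=variant,
            normalization="RMSNorm",  # selects the RMSNorm out_norm inside GatedDeltaNet
        )
    # Any other variant string raises ValueError, so callers gate on
    # is_linear_attention_variant() first, exactly as gpt_builders.py now does.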
