
Extended & simplified n-to-1 kernel fusion via KernelConfig #46339

Merged
michaelbenayoun merged 72 commits into huggingface:main from michaelbenayoun:extended_kernel_api_easy on Jun 9, 2026
@michaelbenayoun michaelbenayoun commented Jun 2, 2026

Member

What does this PR do?

Extends the KernelConfig API with two orthogonal capabilities:

  • Module fusion: specify how Transformers modules should be fused together before a custom kernel is applied (n-to-1 replacement).

  • Parameter transformation: handle cases where a kernel expects weights in a different layout than the original modeling (e.g. fused linears).

Compared to the previous PR, this approach is more explicit and much simpler, and shifts much of the burden onto kernel authors.

How it works

The kernel author needs to define two classes:

  • KernelName: defines the forward pass, used by the kernels library to kernelize the model
  • KernelNameLayout: defines the conversion_mapping as well as an __init__ method. This is used to monkey-patch the model

There are two classes because the kernels library prevents us from having stateful kernel classes.
While this may be less pleasing than one big class, it separates concerns.
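
To make the split concrete, here is a minimal pure-Python sketch of the idea. It is not the kernels library's actual mechanism: the names `ScaleLayout`, `ScaleKernel`, and `attach_forward` are hypothetical, and the class swap is just one simple way to combine a stateful class that owns `__init__` with a stateless class that only contributes `forward`.

```python
class ScaleLayout:
    """Stateful half: owns __init__ and the attributes (hypothetical example)."""

    def __init__(self, scale: float):
        self.scale = scale

    def forward(self, x: float) -> float:
        raise NotImplementedError("placeholder, replaced by the kernel's forward")


class ScaleKernel:
    """Stateless half: only a forward that reads attributes set by the layout."""

    def forward(self, x: float) -> float:
        return self.scale * x


def attach_forward(module, kernel_cls):
    """Give the instance a derived class whose forward comes from kernel_cls.

    Illustrative stand-in for kernelization; the real library may work differently.
    """
    patched = type(
        f"{type(module).__name__}With{kernel_cls.__name__}",
        (type(module),),
        {"forward": kernel_cls.forward},
    )
    module.__class__ = patched
    return module


module = attach_forward(ScaleLayout(scale=2.0), ScaleKernel)
print(module.forward(3.0))
```

The kernel class never defines `__init__`, so it carries no state of its own; everything `forward` reads was created by the layout class.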

Script for the examples

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig


model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, how are you?", return_tensors="pt")

# --- baseline: plain model, no fusion ---
print("=" * 60)
print("Loading baseline model (no fusion)...")
baseline = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", use_kernels=True)
baseline.eval()
inputs = {k: v.to(baseline.device) for k, v in inputs.items()}

with torch.no_grad():
    baseline_out = baseline(**inputs).logits
print("Baseline output shape:", baseline_out.shape)
# del baseline

# --- fused model ---
print("=" * 60)
print("Loading fused model...")

# Alternative kernel repos; only the uncommented assignment is used.
# kernel_repo_id = "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP"
# kernel_repo_id = "michaelbenayoun/dummy-rmsnorm-mlp-with-transformations:RMSNormMLP"
# kernel_repo_id = "michaelbenayoun/dummy-rmsnorm-mlp-with-transformations-and-init:RMSNormMLP"
kernel_repo_id = "michaelbenayoun/dummy-rmsnorm-kernel-with-init:CustomRMSNorm"
kernel_config = KernelConfig(
    {
        # (
        #     ("RMSNorm", "model.layers.*.post_attention_layernorm"),
        #     ("MLP",     "model.layers.*.mlp"),
        # ): kernel_repo_id,
        "RMSNorm": kernel_repo_id,
    },
)

fused_model = AutoModelForCausalLM.from_pretrained(
    model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda"
)
fused_model.eval()
print(fused_model)

with torch.no_grad():
    fused_out = fused_model(**inputs).logits
print("Fused output shape:", fused_out.shape)

# --- compare ---
print("=" * 60)
print("Max diff fused vs baseline:", (fused_out - baseline_out).abs().max().item())

Example 1: Parameter transformation, no fusion

In this case, the KernelNameLayout class's __init__ method has the same signature as the module being replaced.

import torch
import torch.nn as nn

from transformers.conversion_mapping import WeightRenaming

class CustomRMSNormLayout(nn.Module):
    conversion_mapping = [
        WeightRenaming(
            source_patterns=r"(.*(?:input_layernorm|post_attention_layernorm|norm)\.)weight",
            target_patterns=r"\1scale",
        ),
    ]

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pass  # replaced at runtime by kernelize


class CustomRMSNorm(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        print("This dummy kernel is used")
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.scale * hidden_states.to(input_dtype)


class layers:
    CustomRMSNorm = CustomRMSNorm
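
As a quick check of what the renaming pattern in `CustomRMSNormLayout` touches, the sketch below applies the same source and target patterns to a few checkpoint-style keys with plain `re.sub`. Using `re.sub` is an assumption about the matching semantics, and the key names are illustrative examples rather than keys read from the actual checkpoint.

```python
import re

# Pattern pair copied from CustomRMSNormLayout.conversion_mapping.
SOURCE = r"(.*(?:input_layernorm|post_attention_layernorm|norm)\.)weight"
TARGET = r"\1scale"


def rename_key(key: str) -> str:
    """Apply the rename with plain re.sub (an assumption about the semantics)."""
    return re.sub(SOURCE, TARGET, key)


# Illustrative key names, not read from a real checkpoint.
keys = [
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.self_attn.q_norm.weight",
    "model.norm.weight",
    "model.layers.0.mlp.gate_proj.weight",
]
renamed = {key: rename_key(key) for key in keys}
for old, new in renamed.items():
    print(f"{old} -> {new}")
```

Under these semantics the trailing `norm` alternative also matches keys such as `q_norm.weight`, while linear weights like `gate_proj.weight` are left unchanged.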

Example 2: Fusion and parameter transformation

Compared to the first example, here we fuse two modules of the original model into one module.
Because of this, the __init__ method does not have the same signature as either replaced module; instead, it takes the instantiated modules it is fusing.

import torch
import torch.nn as nn

from transformers import Concatenate, WeightConverter
from transformers.conversion_mapping import WeightRenaming


class RMSNormMLPLayout(nn.Module):
    conversion_mapping = [
        # norm.weight → scale (placed at post_attention_layernorm.scale)
        WeightRenaming(
            source_patterns=r"(.*post_attention_layernorm\.)weight",
            target_patterns=r"\1scale",
        ),
        # mlp.gate_proj + mlp.up_proj → post_attention_layernorm.gate_up_proj (concat)
        WeightConverter(
            ["mlp.gate_proj", "mlp.up_proj"],
            "post_attention_layernorm.gate_up_proj",
            [Concatenate(dim=0)],
        ),
        # mlp.down_proj.* → post_attention_layernorm.down_proj.*
        WeightRenaming(
            source_patterns=r"(.*\.)mlp\.(down_proj\..*)",
            target_patterns=r"\1post_attention_layernorm.\2",
        ),
    ]

    def __init__(self, norm, mlp):
        super().__init__()
        self.variance_epsilon = norm.variance_epsilon
        self.scale = nn.Parameter(torch.empty_like(norm.weight))
        self.gate_up_proj = nn.Linear(
            mlp.gate_proj.in_features,
            mlp.gate_proj.out_features + mlp.up_proj.out_features,
            bias=False,
            device=mlp.gate_proj.weight.device,
            dtype=mlp.gate_proj.weight.dtype,
        )
        self.down_proj = mlp.down_proj
        self.act_fn = mlp.act_fn

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pass


class RMSNormMLP(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.scale * hidden_states.to(input_dtype)
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


class layers:
    RMSNormMLP = RMSNormMLP
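
The pairing of `Concatenate(dim=0)` in the layout with `.chunk(2, dim=-1)` in the forward can be checked on toy numbers. The sketch below uses plain Python lists instead of tensors; the weights and the `matvec` helper are made up for illustration. The only assumption is the usual linear-layer layout of one weight row per output feature.

```python
def matvec(weight, x):
    """Multiply a [out_features, in_features] matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


# Toy weights, one row per output feature: 3 outputs, 2 inputs each.
gate_proj = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
up_proj = [[3.0, 0.0], [0.0, 4.0], [2.0, 2.0]]

# Concatenating along dim 0 stacks the rows: gate rows first, then up rows.
gate_up_proj = gate_proj + up_proj

x = [1.0, 2.0]
fused_out = matvec(gate_up_proj, x)

# Splitting the output into two halves mirrors .chunk(2, dim=-1) in the forward.
half = len(fused_out) // 2
gate, up = fused_out[:half], fused_out[half:]

print(gate == matvec(gate_proj, x), up == matvec(up_proj, x))
```

Stacking rows along the output dimension means one matrix product yields both projections back to back, which is why the fused forward can split the result in half along the last dimension.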

@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker left a comment

Collaborator

Much better!

Comment on lines +4654 to +4661
if self.kernel_config is not None:
    from kernels import use_kernel_mapping

    inherit_mapping = not self.kernel_config.use_local_kernel
    with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
        kernelize(self, device=Device(type=self.device.type), mode=mode)
else:
    kernelize(self, device=Device(type=self.device.type), mode=mode)

Collaborator

Suggested change
- if self.kernel_config is not None:
-     from kernels import use_kernel_mapping
-     inherit_mapping = not self.kernel_config.use_local_kernel
-     with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
-         kernelize(self, device=Device(type=self.device.type), mode=mode)
- else:
-     kernelize(self, device=Device(type=self.device.type), mode=mode)
+ kernelize(self, device=Device(type=self.device.type), mode=mode, self.kernel_config)

let's reduce surface as much as possible

Member Author

kernelize is defined in kernels. I can make a PR there, but for now it cannot be changed here.

Collaborator

okay! we can also just create def kernelize to put in kernels utils!

Comment thread src/transformers/integrations/hub_kernels.py Outdated
Comment thread src/transformers/integrations/hub_kernels.py Outdated
Comment thread src/transformers/integrations/hub_kernels.py Outdated
Comment on lines +697 to +703
for module in meta_model.modules():
    module_cls = type(module)
    if module_cls in seen:
        continue
    if not all(hasattr(module, name) for name in child_names):
        continue
    seen.add(module_cls)

Collaborator

I don't think we need to iterate over all the modules!
We could register like we do for the tp plan with explicit path, we like explicitness in general!

{ "layers.*.self_attn.q_proj" : XXXX} 

Collaborator

MOST important comment IMO: if the contract is more like this, we get a lot of simplifications, no?

Member Author

We already have this contract.

kernel_config = KernelConfig(
    {
        (
            ("RMSNorm", "model.layers.*.post_attention_layernorm"),
            ("MLP",     "model.layers.*.mlp"),
        ): kernel_repo_id,
    },
)

I will update this loop
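
For readers unfamiliar with explicit-path contracts, the sketch below shows one way a dotted pattern such as "model.layers.*.mlp" could be resolved against module names. The helpers `glob_to_regex` and `match_modules` are hypothetical, and the rule that `*` matches exactly one path segment is an assumption rather than the behavior implemented in this PR.

```python
import re


def glob_to_regex(pattern: str) -> re.Pattern:
    """Turn a dotted glob into a regex where `*` matches exactly one path segment.

    Hypothetical helper; the matching rule is an assumption, not the library's.
    """
    parts = [r"[^.]+" if part == "*" else re.escape(part) for part in pattern.split(".")]
    return re.compile(r"\.".join(parts))


def match_modules(pattern: str, module_names: list[str]) -> list[str]:
    """Return the module names that match the whole pattern."""
    regex = glob_to_regex(pattern)
    return [name for name in module_names if regex.fullmatch(name)]


# Illustrative module names in the style of named_modules() output.
names = [
    "model.layers.0.post_attention_layernorm",
    "model.layers.0.mlp",
    "model.layers.0.mlp.gate_proj",
    "model.layers.1.mlp",
    "model.norm",
]
print(match_modules("model.layers.*.mlp", names))
```

Matching on explicit paths selects the targets directly, so no scan that inspects each module's attributes is needed.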

kernel_config.kernel_mapping = new_mapping


def register_kernel_fusions(

Collaborator

let's do both in a single func!

Member Author

Done

Comment thread src/transformers/integrations/hub_kernels.py Outdated
Comment on lines +545 to +554
def _first_str_leaf(obj) -> str | None:
    """Recursively extract the first string leaf from a potentially nested dict (device → mode → str)."""
    if isinstance(obj, str):
        return obj
    if isinstance(obj, dict):
        for v in obj.values():
            result = _first_str_leaf(v)
            if result is not None:
                return result
    return None

Collaborator

to remove

Member Author

Done

ALLOW_ALL_KERNELS = False


def make_kernel_init_parent_class(

Collaborator

this is super important and needs to be documented well:

  • we replace the fused cls by identity
  • thus we have to patch some inits, etc.

Also, do we even have to patch inits when the proper class replaces the one that holds them?

Member Author

Done
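
The identity replacement mentioned in this thread can be illustrated with a small pure-Python sketch. All class names are hypothetical, and putting the fused module in the first slot is an assumption, loosely based on Example 2's conversion mapping, which places the fused weights under post_attention_layernorm.

```python
class Identity:
    """Pass-through stand-in for a module whose work moved into the fused module."""

    def __call__(self, x):
        return x


class Double:
    def __call__(self, x):
        return 2 * x


class AddOne:
    def __call__(self, x):
        return x + 1


class FusedDoubleAddOne:
    """Does the work of Double followed by AddOne in a single call."""

    def __call__(self, x):
        return 2 * x + 1


class Block:
    """Parent whose forward still calls both children, in order."""

    def __init__(self):
        self.norm = Double()
        self.mlp = AddOne()

    def __call__(self, x):
        return self.mlp(self.norm(x))


def fuse(block: Block) -> Block:
    # The first slot receives the fused module; the other becomes a no-op,
    # so the parent's forward keeps working without being rewritten.
    block.norm = FusedDoubleAddOne()
    block.mlp = Identity()
    return block


print(Block()(5), fuse(Block())(5))
```

Because the parent's call sequence is untouched, the n-to-1 replacement only has to swap the children.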

@ArthurZucker ArthurZucker left a comment

Collaborator

Much, much better! It's just missing a piece of doc (or an update to the doc for monkey patching), and maybe some benchmarks if you have them, but that's fine for another PR!

Thanks for iterating, it's quite nice now!

new_mapping: dict = {}

# We might need to instantiate the model on meta device.
# We do it lazily, only if we encounter a fused kernel.

Collaborator

Nice

else:
    raise ValueError(f"Invalid hub repo {hub_repo!r} for layer {layer_name!r}")

repo_id, _, layer_name_in_repo = repo_str.partition(":")

Collaborator

nice
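
For reference, `str.partition` always returns a 3-tuple, which is what lets the "repo:LayerName" strings from the PR description be split in one line. The sketch below wraps it in a hypothetical `split_repo` helper; returning `None` when no layer name is given is an illustrative choice, not necessarily what the PR does.

```python
def split_repo(repo_str: str):
    """Split 'org/repo:LayerName' into (repo_id, layer_name).

    Hypothetical helper: layer_name is None when the string has no ':' part.
    """
    repo_id, sep, layer_name = repo_str.partition(":")
    return repo_id, (layer_name if sep else None)


print(split_repo("michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP"))
print(split_repo("michaelbenayoun/dummy-rmsnorm-mlp"))
```

Unlike `split(":")`, `partition` never raises on unpacking when the separator is missing; the separator and tail simply come back as empty strings.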


if meta_model is None:
    with torch.device("meta"):
        meta_model = cls(config)

Collaborator

Suggested change
- meta_model = cls(config)
+ meta__modules = cls(config).named_modules()

we only need these

Collaborator

Ah, maybe it gets updated, but that's good: you can't update twice, so it's even better in a way, no? (to not re-compute the named modules)

@michaelbenayoun michaelbenayoun added this pull request to the merge queue Jun 9, 2026
Merged via the queue into huggingface:main with commit 5047f08 Jun 9, 2026
117 of 118 checks passed
@michaelbenayoun michaelbenayoun deleted the extended_kernel_api_easy branch June 9, 2026 12:21
@michaelbenayoun

Member Author

For the kernelize refactor, it is done here: #46520.

louzongzhi pushed a commit to louzongzhi/transformers that referenced this pull request Jun 10, 2026
…ace#46339)

* feat: module fusion API for kernels

* fix: improve __repr__ for fused modules

* wip: integration to KernelConfig

* wip: add temporary example

* wip: pattern matching in KernelConfig and actual kernel repo

* refactor: move relevant code to hub_kernels.py

* docs: reformat docstring

* refactor: remove comment

* refactor: update example script for testing

* wip: remove apply_fusions method

* wip: add core feature for integration with the current fusing API

* fix: move kernel mapping patching to kernelize

* wip: update example script

* wip: add transform_model method for WeightTransform

* wip: conversion_mapping in Kernel

* wip: remove things from __all__

* wip: remove imports

* fix: remove register_fusion_pattern path

* fix: remove unused attribute

* wip: update experimentation script

* refactor: add convert as abstract method

* style: reformat hub_kernels.py

* wip: transform_model API

* wip: transform_model API, WeightTransform

* wip: transform_model API, WeightConverter

* wip: transform_model API, WeightConverter

* wip: make transform_model idempotent

* refactor: infer_kernel_fusion_transforms

* style: regexs -> regexes

* refactor: register_kernel_fusions

* refactor: post transformation cleanup

* style: fix comment

* test: add TestApplyTransformsToMetaModel tests

* test: add kernels test

* test: fix hub_kernels package reload

* style: ruff

* refactor: do not create dynamic classes in test

* refactor: no dynamic class creation in tests

* refactor: test

* fix: TYPE_CHECKING imports were broken

* wip: get rid of transform_model methods

* wip: move tests

* wip: make conversion happen before fused module instantiation

* refactor

* wip: move conversion_mapping inside the init

* wip: without any transform_model

* wip: remove dead code

* wip: api imrpovement

* wip: refactor

* wip: enable __init__ support in kernels

* wip: fuse + init

* clean: remove "dead" code

* wip: use two classes in kernels

* wip: remove docstring

* test: add relevant tests

* chore: remove experiment file

* cleanup: remove helper function

* cleanup: remove helper function

* refactor: merge the two register kernel functions into one

* cleanup: use explicit regex patterns to match for monkey patching

* test: cleanup and update tests

* doc: add docstring to make_parent_class_for_kernel_fusion