Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
b387190
feat: module fusion API for kernels
michaelbenayoun Apr 10, 2026
6bc9402
fix: improve __repr__ for fused modules
michaelbenayoun Apr 10, 2026
62d4454
wip: integration to KernelConfig
michaelbenayoun Apr 10, 2026
4082fe1
wip: add temporary example
michaelbenayoun Apr 10, 2026
ac4a699
wip: pattern matching in KernelConfig and actual kernel repo
michaelbenayoun Apr 13, 2026
e13111f
refactor: move relevant code to hub_kernels.py
michaelbenayoun Apr 13, 2026
d9d53f0
docs: reformat docstring
michaelbenayoun Apr 13, 2026
e1c7f3f
refactor: remove comment
michaelbenayoun Apr 13, 2026
db0b7f0
Merge branch 'main' into fused_kernels
michaelbenayoun Apr 13, 2026
e21d06e
Merge branch 'main' into fused_kernels
michaelbenayoun Apr 27, 2026
bd640ae
refactor: update example script for testing
michaelbenayoun Apr 27, 2026
323b000
wip: remove apply_fusions method
michaelbenayoun Apr 27, 2026
fe3002d
wip: add core feature for integration with the current fusing API
michaelbenayoun Apr 27, 2026
b541453
fix: move kernel mapping patching to kernelize
michaelbenayoun Apr 28, 2026
b3d73a7
wip: update example script
michaelbenayoun Apr 28, 2026
0845222
wip: add transform_model method for WeightTransform
michaelbenayoun May 1, 2026
ecfd97d
wip: conversion_mapping in Kernel
michaelbenayoun May 1, 2026
973a616
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 1, 2026
0f0a64b
wip: remove things from __all__
michaelbenayoun May 4, 2026
91177ae
wip: remove imports
michaelbenayoun May 4, 2026
2636c06
fix: remove register_fusion_pattern path
michaelbenayoun May 4, 2026
573d9f0
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 4, 2026
f7c15bd
fix: remove unused attribute
michaelbenayoun May 4, 2026
f9d4299
wip: update experimentation script
michaelbenayoun May 5, 2026
847dbd4
refactor: add convert as abstract method
michaelbenayoun May 5, 2026
4443c9a
style: reformat hub_kernels.py
michaelbenayoun May 5, 2026
51c59c9
wip: transform_model API
michaelbenayoun May 5, 2026
4c58503
wip: transform_model API, WeightTransform
michaelbenayoun May 5, 2026
a7f983f
wip: transform_model API, WeightConverter
michaelbenayoun May 5, 2026
b8d860f
wip: transform_model API, WeightConverter
michaelbenayoun May 5, 2026
3d5f353
wip: make transform_model idempotent
michaelbenayoun May 6, 2026
c35b513
refactor: infer_kernel_fusion_transforms
michaelbenayoun May 6, 2026
6e369c0
style: regexs -> regexes
michaelbenayoun May 6, 2026
731d0b7
refactor: register_kernel_fusions
michaelbenayoun May 6, 2026
4e56ad3
refactor: post transformation cleanup
michaelbenayoun May 6, 2026
96ac123
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 6, 2026
105a403
style: fix comment
michaelbenayoun May 7, 2026
b1c9645
test: add TestApplyTransformsToMetaModel tests
michaelbenayoun May 7, 2026
b10b864
test: add kernels test
michaelbenayoun May 7, 2026
d1bd06d
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 7, 2026
fb0a748
test: fix hub_kernels package reload
michaelbenayoun May 12, 2026
ded8b5f
style: ruff
michaelbenayoun May 12, 2026
6d6411b
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 12, 2026
318553e
refactor: do not create dynamic classes in test
michaelbenayoun May 12, 2026
2a24760
refactor: no dynamic class creation in tests
michaelbenayoun May 12, 2026
8136781
refactor: test
michaelbenayoun May 13, 2026
aa7743a
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 13, 2026
17f4a9e
fix: TYPE_CHECKING imports were broken
michaelbenayoun May 13, 2026
8ec88d1
Merge branch 'main' into extended_kernels_api
michaelbenayoun May 13, 2026
fabadca
wip: get rid of transform_model methods
michaelbenayoun May 19, 2026
e1f3a83
wip: move tests
michaelbenayoun May 19, 2026
68c3659
wip: make conversion happen before fused module instantiation
michaelbenayoun May 19, 2026
041182f
refactor
michaelbenayoun May 19, 2026
519e673
wip: move conversion_mapping inside the init
michaelbenayoun May 19, 2026
88e0aee
wip: without any transform_model
michaelbenayoun May 26, 2026
ad0e24e
wip: remove dead code
michaelbenayoun May 26, 2026
3924cf3
wip: api imrpovement
michaelbenayoun May 26, 2026
c948834
wip: refactor
michaelbenayoun May 26, 2026
e0c0366
wip: enable __init__ support in kernels
michaelbenayoun Jun 2, 2026
da0fdae
wip: fuse + init
michaelbenayoun Jun 2, 2026
06add71
clean: remove "dead" code
michaelbenayoun Jun 2, 2026
597bb8c
wip: use two classes in kernels
michaelbenayoun Jun 2, 2026
1489146
wip: remove docstring
michaelbenayoun Jun 2, 2026
024fd4c
Merge branch 'main' into extended_kernel_api_easy
michaelbenayoun Jun 2, 2026
b20cb71
test: add relevant tests
michaelbenayoun Jun 2, 2026
1fb1787
chore: remove experiment file
michaelbenayoun Jun 2, 2026
2e87c9f
cleanup: remove helper function
michaelbenayoun Jun 4, 2026
11ecc56
cleanup: remove helper function
michaelbenayoun Jun 4, 2026
d98489a
refactor: merge the two register kernel functions into one
michaelbenayoun Jun 4, 2026
89adf48
cleanup: use explicit regex patterns to match for monkey patching
michaelbenayoun Jun 4, 2026
c841433
test: cleanup and update tests
michaelbenayoun Jun 4, 2026
edf33ab
doc: add docstring to make_parent_class_for_kernel_fusion
michaelbenayoun Jun 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import re
import traceback
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
Expand Down Expand Up @@ -80,7 +80,7 @@ def build_glob_alternation(
return alternation, src_group_to_glob, tgt_group_to_glob


class ConversionOps:
class ConversionOps(ABC):
"""Base class for weight conversion operations."""

def __repr__(self):
Expand Down
181 changes: 176 additions & 5 deletions src/transformers/integrations/hub_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,39 @@
import importlib.metadata
import os
import re
import sys
from collections.abc import Callable
from contextlib import contextmanager
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING

from packaging import version as pkg_version

from ..conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
from ..monkey_patching import register_patch_mapping
from ..utils import ENV_VARS_TRUE_VALUES, logging
from ..utils.import_utils import is_kernels_available
from ..utils.import_utils import is_kernels_available, is_torch_available
from .flash_attention import flash_attention_forward


if TYPE_CHECKING:
from ..configuration_utils import PretrainedConfig
from ..modeling_utils import PreTrainedModel
from ..utils.kernel_config import KernelConfig

if is_torch_available():
import torch
import torch.nn as nn


logger = logging.get_logger(__name__)

try:
from kernels import (
Device,
LayerRepository,
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
Expand Down Expand Up @@ -269,6 +285,10 @@ class LayerRepository:
def __init__(self, *args, **kwargs):
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

class LocalLayerRepository:
def __init__(self, *args, **kwargs):
raise RuntimeError("LocalLayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

def replace_kernel_forward_from_hub(*args, **kwargs):
raise RuntimeError(
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
Expand Down Expand Up @@ -501,14 +521,165 @@ def allow_all_hub_kernels():
ALLOW_ALL_KERNELS = False


def make_parent_class_for_kernel_fusion(
parent_cls: type,
child_names: list[str],
kernel_cls: type,
) -> type:
"""
Create a new class that inherits from `parent_cls` and fuses the child modules specified in `child_names
with the provided `kernel_cls`.
The first child in `child_names` will be replaced with the `kernel_cls`, and the rest will be replaced with
`nn.Identity()` to keep the same interface.
"""
original_init = parent_cls.__init__

def patched_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
children = [getattr(self, name) for name in child_names]
kernel_instance = kernel_cls(*children)
setattr(self, child_names[0], kernel_instance)
for name in child_names[1:]:
setattr(self, name, nn.Identity())

patched_cls = type(f"Fused{parent_cls.__name__}", (parent_cls,), {"__init__": patched_init})
patched_cls.__qualname__ = f"Fused{parent_cls.__qualname__}"
return patched_cls


def register_kernel_replacements_and_fusions(
cls: "type[PreTrainedModel]",
config: "PretrainedConfig",
kernel_config: "KernelConfig",
) -> None:
if not hasattr(cls, "config_class") or not hasattr(cls.config_class, "model_type"):
raise ValueError(f"Model {cls.__name__} has no config_class or model_type.")
model_type = cls.config_class.model_type

patch_mapping: dict[str, type] = {}
new_mapping: dict = {}

# We might need to instantiate the model on meta device.
# We do it lazily, only if we encounter a fused kernel.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

meta_model = None

for layer_name, hub_repo in kernel_config.kernel_mapping.items():
if isinstance(hub_repo, dict):
if len(hub_repo.values()) != 1:
raise ValueError(
f"Expected exactly one kernel repo regardless of device/mode specificity, got {hub_repo}"
)
repo_str = next(iter(hub_repo.values()))
elif isinstance(hub_repo, str):
repo_str = hub_repo
else:
raise ValueError(f"Invalid hub repo {hub_repo!r} for layer {layer_name!r}")

repo_id, _, layer_name_in_repo = repo_str.partition(":")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

if not repo_id or not layer_name_in_repo:
raise ValueError(f"Invalid kernel repo string {repo_str!r} for layer {layer_name!r}")

if kernel_config.use_local_kernel:
package_name = repo_id.rstrip("/").split("/")[-1]
repo = LocalLayerRepository(
repo_path=Path(repo_id),
package_name=package_name,
layer_name=layer_name_in_repo,
)
else:
repo = LayerRepository(repo_id=repo_id, layer_name=layer_name_in_repo)

kernel_cls = repo.load()

if kernel_cls is None:
raise ValueError(f"Could not load kernel class from hub_repo={hub_repo!r}")

kernel_mod = sys.modules.get(kernel_cls.__module__)
layout_cls = getattr(kernel_mod, f"{kernel_cls.__name__}Layout", None) if kernel_mod else None

# Case 1: no fusion.
if isinstance(layer_name, str):
# No layout class: stateless kernel, leave for kernels.kernelize.
if layout_cls is None:
new_mapping[layer_name] = repo_str
continue

# Register the layout class as a monkey patch for the parent module containing the target layer.
layout_cls.kernel_layer_name = kernel_cls.__name__
patch_mapping[layer_name] = layout_cls

# Keep the original repo string so kernelize can replace the layout's forward.
new_mapping[kernel_cls.__name__] = repo_str

# Case 2: fusion.
elif isinstance(layer_name, tuple):
if layout_cls is None:
raise ValueError(
f"Fused kernel {kernel_cls.__name__!r} requires a companion layout class "
f"named '{kernel_cls.__name__}Layout' in the same module."
)

layout_cls.kernel_layer_name = kernel_cls.__name__

glob_patterns = [item[1] for item in layer_name]
parent_patterns = [p.rsplit(".", 1)[0] for p in glob_patterns]

if len(set(parent_patterns)) != 1:
raise ValueError(
f"All patterns for a fused kernel must share the same parent module, got {glob_patterns}"
)

parent_pattern = parent_patterns[0].replace("*", r"\w+")
child_names = [p.rsplit(".", 1)[1] for p in glob_patterns]

if meta_model is None:
with torch.device("meta"):
meta_model = cls(config)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
meta_model = cls(config)
meta__modules = cls(config).named_modules()

we only need these

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah maybe it gets updated but that's good, you can'tupdate twice so its even better in a way no? (to not re-compute the named modules)


matched_any = False
for name, module in meta_model.named_modules():
if not re.fullmatch(parent_pattern, name):
continue
if not all(hasattr(module, child) for child in child_names):
raise ValueError(
f"Module {name!r} does not have the expected child modules {child_names} required for "
f"the fused kernel {kernel_cls.__name__!r}"
)
matched_any = True
module_cls = type(module)
patch_mapping[module_cls.__name__] = make_parent_class_for_kernel_fusion(
module_cls, child_names, layout_cls
)

if not matched_any:
raise ValueError(
f"No module matched pattern {parent_pattern!r} for fused kernel {kernel_cls.__name__!r}. "
f"Provide the full dotted path from the model root."
)

register_patch_mapping(patch_mapping, overwrite=True)

if hasattr(layout_cls, "conversion_mapping"):
existing = get_checkpoint_conversion_mapping(model_type)
transforms = list(layout_cls.conversion_mapping)
if existing is not None:
transforms = existing + transforms
register_checkpoint_conversion_mapping(model_type, transforms, overwrite=True)

new_mapping[kernel_cls.__name__] = repo_str

kernel_config.kernel_mapping = new_mapping


__all__ = [
"LayerRepository",
"use_kernel_forward_from_hub",
"use_kernel_func_from_hub",
"get_kernel",
"lazy_load_kernel",
"register_kernel_mapping",
"register_kernel_mapping_transformers",
"register_kernel_replacements_and_fusions",
"replace_kernel_forward_from_hub",
"lazy_load_kernel",
"get_kernel",
"use_kernel_forward_from_hub",
"use_kernel_func_from_hub",
"use_kernelized_func",
] # type: ignore
33 changes: 23 additions & 10 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,9 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path

# Make kernel_config an attribute that can be used by the model.
self.kernel_config = None

# Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
# setting it recursively)
self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
Expand Down Expand Up @@ -3782,30 +3785,27 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None
raise ValueError(
"`use_kernels=True` requires kernels>=0.9.0. Please install the latest version with `pip install -U kernels`"
)
from kernels import use_kernel_mapping

from .integrations.hub_kernels import register_kernel_mapping_transformers

register_kernel_mapping_transformers()

if kernel_config is not None and isinstance(kernel_config, KernelConfig):
# Since kernel_config is a correct value, set it as an attribute of the model so it can be used.
self.kernel_config = kernel_config

# This will make sure the mapping is valid, and the layers are registered in the model
kernel_config.sanitize_kernel_mapping(self)

# This will create a compatible mapping for the model with the kernels library
kernel_config.create_compatible_mapping(self)

# This is a context manager to override the default kernel mapping
# We are calling kernelize inside this context manager using the use_kernels setter
# Param inherit_mapping should be False to avoid still loading kernel from remote
inherit_mapping = not kernel_config.use_local_kernel
with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
self.use_kernels = True
self.use_kernels = True
# We use the default kernel mapping in .integrations.hub_kernels
else:
self.use_kernels = True
self.kernel_config = None
else:
self.use_kernels = False
self.kernel_config = None

@classmethod
def from_pretrained(
Expand Down Expand Up @@ -4245,6 +4245,12 @@ def from_pretrained(

register_fusion_patches(cls, config, fusion_config)

# Kernel patches: single-layer replacement (stateful __init__) then fusions.
if kernel_config is not None and use_kernels:
from .integrations.hub_kernels import register_kernel_replacements_and_fusions

register_kernel_replacements_and_fusions(cls, config, kernel_config)

model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)

config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
Expand Down Expand Up @@ -4644,7 +4650,14 @@ def detach_hidden_kernels(module):
self.apply(attach_hidden_kernels)

mode = Mode.INFERENCE if not self.training else Mode.TRAINING if mode is None else mode
kernelize(self, device=Device(type=self.device.type), mode=mode)
if self.kernel_config is not None:
from kernels import use_kernel_mapping

inherit_mapping = not self.kernel_config.use_local_kernel
with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
kernelize(self, device=Device(type=self.device.type), mode=mode)
else:
kernelize(self, device=Device(type=self.device.type), mode=mode)
Comment on lines +4653 to +4660

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.kernel_config is not None:
from kernels import use_kernel_mapping
inherit_mapping = not self.kernel_config.use_local_kernel
with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
kernelize(self, device=Device(type=self.device.type), mode=mode)
else:
kernelize(self, device=Device(type=self.device.type), mode=mode)
kernelize(self, device=Device(type=self.device.type), mode=mode, self.kernel_config)

let's reduce surface as much as possible

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernelize is defined in kernels. I can make a PR there, but for now it cannot be changed here.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay! we can also just create def kernelize to put in kernels utils!

self._use_kernels = True

finally:
Expand Down
Loading
Loading