diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 1a49bd9775d4..2465bb292002 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -19,7 +19,7 @@ import os import re import traceback -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor @@ -80,7 +80,7 @@ def build_glob_alternation( return alternation, src_group_to_glob, tgt_group_to_glob -class ConversionOps: +class ConversionOps(ABC): """Base class for weight conversion operations.""" def __repr__(self): diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 76344e5453e5..3717d7ea06c0 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -14,23 +14,39 @@ import importlib.metadata import os import re +import sys from collections.abc import Callable from contextlib import contextmanager +from pathlib import Path from types import ModuleType +from typing import TYPE_CHECKING from packaging import version as pkg_version +from ..conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping +from ..monkey_patching import register_patch_mapping from ..utils import ENV_VARS_TRUE_VALUES, logging -from ..utils.import_utils import is_kernels_available +from ..utils.import_utils import is_kernels_available, is_torch_available from .flash_attention import flash_attention_forward +if TYPE_CHECKING: + from ..configuration_utils import PretrainedConfig + from ..modeling_utils import PreTrainedModel + from ..utils.kernel_config import KernelConfig + +if is_torch_available(): + import torch + import torch.nn as nn + + logger = logging.get_logger(__name__) try: from kernels import ( Device, LayerRepository, + LocalLayerRepository, Mode, register_kernel_mapping, replace_kernel_forward_from_hub, @@ -269,6 +285,10 @@ class LayerRepository: def __init__(self, *args, **kwargs): raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.") + class LocalLayerRepository: + def __init__(self, *args, **kwargs): + raise RuntimeError("LocalLayerRepository requires `kernels` to be installed. Run `pip install kernels`.") + def replace_kernel_forward_from_hub(*args, **kwargs): raise RuntimeError( "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`." @@ -501,14 +521,165 @@ def allow_all_hub_kernels(): ALLOW_ALL_KERNELS = False +def make_parent_class_for_kernel_fusion( + parent_cls: type, + child_names: list[str], + kernel_cls: type, +) -> type: + """ + Create a new class that inherits from `parent_cls` and fuses the child modules specified in `child_names + with the provided `kernel_cls`. + The first child in `child_names` will be replaced with the `kernel_cls`, and the rest will be replaced with + `nn.Identity()` to keep the same interface. + """ + original_init = parent_cls.__init__ + + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + children = [getattr(self, name) for name in child_names] + kernel_instance = kernel_cls(*children) + setattr(self, child_names[0], kernel_instance) + for name in child_names[1:]: + setattr(self, name, nn.Identity()) + + patched_cls = type(f"Fused{parent_cls.__name__}", (parent_cls,), {"__init__": patched_init}) + patched_cls.__qualname__ = f"Fused{parent_cls.__qualname__}" + return patched_cls + + +def register_kernel_replacements_and_fusions( + cls: "type[PreTrainedModel]", + config: "PretrainedConfig", + kernel_config: "KernelConfig", +) -> None: + if not hasattr(cls, "config_class") or not hasattr(cls.config_class, "model_type"): + raise ValueError(f"Model {cls.__name__} has no config_class or model_type.") + model_type = cls.config_class.model_type + + patch_mapping: dict[str, type] = {} + new_mapping: dict = {} + + # We might need to instantiate the model on meta device. + # We do it lazily, only if we encounter a fused kernel. + meta_model = None + + for layer_name, hub_repo in kernel_config.kernel_mapping.items(): + if isinstance(hub_repo, dict): + if len(hub_repo.values()) != 1: + raise ValueError( + f"Expected exactly one kernel repo regardless of device/mode specificity, got {hub_repo}" + ) + repo_str = next(iter(hub_repo.values())) + elif isinstance(hub_repo, str): + repo_str = hub_repo + else: + raise ValueError(f"Invalid hub repo {hub_repo!r} for layer {layer_name!r}") + + repo_id, _, layer_name_in_repo = repo_str.partition(":") + if not repo_id or not layer_name_in_repo: + raise ValueError(f"Invalid kernel repo string {repo_str!r} for layer {layer_name!r}") + + if kernel_config.use_local_kernel: + package_name = repo_id.rstrip("/").split("/")[-1] + repo = LocalLayerRepository( + repo_path=Path(repo_id), + package_name=package_name, + layer_name=layer_name_in_repo, + ) + else: + repo = LayerRepository(repo_id=repo_id, layer_name=layer_name_in_repo) + + kernel_cls = repo.load() + + if kernel_cls is None: + raise ValueError(f"Could not load kernel class from hub_repo={hub_repo!r}") + + kernel_mod = sys.modules.get(kernel_cls.__module__) + layout_cls = getattr(kernel_mod, f"{kernel_cls.__name__}Layout", None) if kernel_mod else None + + # Case 1: no fusion. + if isinstance(layer_name, str): + # No layout class: stateless kernel, leave for kernels.kernelize. + if layout_cls is None: + new_mapping[layer_name] = repo_str + continue + + # Register the layout class as a monkey patch for the parent module containing the target layer. + layout_cls.kernel_layer_name = kernel_cls.__name__ + patch_mapping[layer_name] = layout_cls + + # Keep the original repo string so kernelize can replace the layout's forward. + new_mapping[kernel_cls.__name__] = repo_str + + # Case 2: fusion. + elif isinstance(layer_name, tuple): + if layout_cls is None: + raise ValueError( + f"Fused kernel {kernel_cls.__name__!r} requires a companion layout class " + f"named '{kernel_cls.__name__}Layout' in the same module." + ) + + layout_cls.kernel_layer_name = kernel_cls.__name__ + + glob_patterns = [item[1] for item in layer_name] + parent_patterns = [p.rsplit(".", 1)[0] for p in glob_patterns] + + if len(set(parent_patterns)) != 1: + raise ValueError( + f"All patterns for a fused kernel must share the same parent module, got {glob_patterns}" + ) + + parent_pattern = parent_patterns[0].replace("*", r"\w+") + child_names = [p.rsplit(".", 1)[1] for p in glob_patterns] + + if meta_model is None: + with torch.device("meta"): + meta_model = cls(config) + + matched_any = False + for name, module in meta_model.named_modules(): + if not re.fullmatch(parent_pattern, name): + continue + if not all(hasattr(module, child) for child in child_names): + raise ValueError( + f"Module {name!r} does not have the expected child modules {child_names} required for " + f"the fused kernel {kernel_cls.__name__!r}" + ) + matched_any = True + module_cls = type(module) + patch_mapping[module_cls.__name__] = make_parent_class_for_kernel_fusion( + module_cls, child_names, layout_cls + ) + + if not matched_any: + raise ValueError( + f"No module matched pattern {parent_pattern!r} for fused kernel {kernel_cls.__name__!r}. " + f"Provide the full dotted path from the model root." + ) + + register_patch_mapping(patch_mapping, overwrite=True) + + if hasattr(layout_cls, "conversion_mapping"): + existing = get_checkpoint_conversion_mapping(model_type) + transforms = list(layout_cls.conversion_mapping) + if existing is not None: + transforms = existing + transforms + register_checkpoint_conversion_mapping(model_type, transforms, overwrite=True) + + new_mapping[kernel_cls.__name__] = repo_str + + kernel_config.kernel_mapping = new_mapping + + __all__ = [ "LayerRepository", - "use_kernel_forward_from_hub", - "use_kernel_func_from_hub", + "get_kernel", + "lazy_load_kernel", "register_kernel_mapping", "register_kernel_mapping_transformers", + "register_kernel_replacements_and_fusions", "replace_kernel_forward_from_hub", - "lazy_load_kernel", - "get_kernel", + "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "use_kernelized_func", ] # type: ignore diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 13c06973b4a9..c45edc977e48 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1346,6 +1346,9 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path + # Make kernel_config an attribute that can be used by the model. + self.kernel_config = None + # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid # setting it recursively) self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation( @@ -3782,30 +3785,27 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None raise ValueError( "`use_kernels=True` requires kernels>=0.9.0. Please install the latest version with `pip install -U kernels`" ) - from kernels import use_kernel_mapping - from .integrations.hub_kernels import register_kernel_mapping_transformers register_kernel_mapping_transformers() if kernel_config is not None and isinstance(kernel_config, KernelConfig): + # Since kernel_config is a correct value, set it as an attribute of the model so it can be used. + self.kernel_config = kernel_config + # This will make sure the mapping is valid, and the layers are registered in the model kernel_config.sanitize_kernel_mapping(self) # This will create a compatible mapping for the model with the kernels library kernel_config.create_compatible_mapping(self) - - # This is a context manager to override the default kernel mapping - # We are calling kernelize inside this context manager using the use_kernels setter - # Param inherit_mapping should be False to avoid still loading kernel from remote - inherit_mapping = not kernel_config.use_local_kernel - with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): - self.use_kernels = True + self.use_kernels = True # We use the default kernel mapping in .integrations.hub_kernels else: self.use_kernels = True + self.kernel_config = None else: self.use_kernels = False + self.kernel_config = None @classmethod def from_pretrained( @@ -4245,6 +4245,12 @@ def from_pretrained( register_fusion_patches(cls, config, fusion_config) + # Kernel patches: single-layer replacement (stateful __init__) then fusions. + if kernel_config is not None and use_kernels: + from .integrations.hub_kernels import register_kernel_replacements_and_fusions + + register_kernel_replacements_and_fusions(cls, config, kernel_config) + model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. @@ -4644,7 +4650,14 @@ def detach_hidden_kernels(module): self.apply(attach_hidden_kernels) mode = Mode.INFERENCE if not self.training else Mode.TRAINING if mode is None else mode - kernelize(self, device=Device(type=self.device.type), mode=mode) + if self.kernel_config is not None: + from kernels import use_kernel_mapping + + inherit_mapping = not self.kernel_config.use_local_kernel + with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): + kernelize(self, device=Device(type=self.device.type), mode=mode) + else: + kernelize(self, device=Device(type=self.device.type), mode=mode) self._use_kernels = True finally: diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index 1bd9a7c79792..12bcd19b343c 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -21,6 +21,8 @@ import types from unittest.mock import MagicMock, patch +import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig from transformers.integrations.hub_kernels import ( _HUB_KERNEL_MAPPING, @@ -31,6 +33,7 @@ ) from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.monkey_patching import clear_patch_mapping, get_patch_mapping, register_patch_mapping from transformers.testing_utils import ( TestCasePlus, cleanup, @@ -87,7 +90,14 @@ def tearDownClass(cls): except Exception as e: print(f"Could not clear kernel module cache: {e}") + def setUp(self): + self._pre_test_patch_mapping = get_patch_mapping() + def tearDown(self): + # Restore monkey patch state to avoid leaking kernel patches across tests. + clear_patch_mapping() + if self._pre_test_patch_mapping: + register_patch_mapping(self._pre_test_patch_mapping) # Free accelerator memory/cache and trigger GC cleanup(torch_device, gc_collect=True) @@ -207,6 +217,107 @@ def test_kernels_mapping(self): del model + @require_torch_accelerator + def test_kernel_fusion(self): + model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" + kernel_config = KernelConfig( + { + ( + ("RMSNorm", "model.layers.*.post_attention_layernorm"), + ("MLP", "model.layers.*.mlp"), + ): "michaelbenayoun/dummy-rmsnorm-mlp-with-transformations-and-init:RMSNormMLP", + } + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer("Hello, how are you?", return_tensors="pt") + + baseline = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, device_map=torch_device) + baseline.eval() + inputs = {k: v.to(torch_device) for k, v in inputs.items()} + with torch.no_grad(): + baseline_out = baseline(**inputs).logits + del baseline + + fused = AutoModelForCausalLM.from_pretrained( + model_id, use_kernels=True, kernel_config=kernel_config, device_map=torch_device + ) + fused.eval() + with torch.no_grad(): + fused_out = fused(**inputs).logits + + torch.testing.assert_close(baseline_out, fused_out, atol=1e-4, rtol=1e-4) + + decoder_layers = [ + (name, m) + for name, m in fused.named_modules() + if hasattr(m, "post_attention_layernorm") and hasattr(m, "mlp") + ] + self.assertTrue(len(decoder_layers) > 0, "No decoder layers found") + for name, layer in decoder_layers: + self.assertIsInstance( + layer.mlp, + torch.nn.Identity, + f"{name}.mlp should be nn.Identity after fusion", + ) + self.assertTrue( + hasattr(layer.post_attention_layernorm, "kernel_layer_name") + or hasattr(type(layer.post_attention_layernorm), "kernel_layer_name"), + f"{name}.post_attention_layernorm should carry kernel_layer_name after fusion", + ) + + del fused + + @require_torch_accelerator + def test_kernel_replacement_with_layout(self): + model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" + kernel_config = KernelConfig({"RMSNorm": "michaelbenayoun/dummy-rmsnorm-kernel-with-init:CustomRMSNorm"}) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer("Hello, how are you?", return_tensors="pt") + + baseline = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, device_map=torch_device) + baseline.eval() + inputs = {k: v.to(torch_device) for k, v in inputs.items()} + original_rmsnorm_cls = type(next(m for m in baseline.modules() if "RMSNorm" in type(m).__name__)) + with torch.no_grad(): + baseline_out = baseline(**inputs).logits + del baseline + + model = AutoModelForCausalLM.from_pretrained( + model_id, use_kernels=True, kernel_config=kernel_config, device_map=torch_device + ) + model.eval() + with torch.no_grad(): + model_out = model(**inputs).logits + + torch.testing.assert_close(baseline_out, model_out, atol=1e-4, rtol=1e-4) + + replaced = [m for m in model.modules() if hasattr(type(m), "kernel_layer_name")] + self.assertTrue(len(replaced) > 0, "No replaced kernel layout modules found") + for m in replaced: + self.assertNotIsInstance(m, original_rmsnorm_cls) + + del model + + def test_faulty_fusion_incomplete_pattern(self): + model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" + # "layers.*.post_attention_layernorm" is missing the leading "model." segment. + # re.fullmatch("layers.\w+", "model.layers.0") returns None, so no module + # is ever matched and the function raises ValueError. + kernel_config = KernelConfig( + { + ( + ("RMSNorm", "layers.*.post_attention_layernorm"), + ("MLP", "layers.*.mlp"), + ): "michaelbenayoun/dummy-rmsnorm-mlp-with-transformations-and-init:RMSNormMLP", + } + ) + with self.assertRaises(ValueError): + _ = AutoModelForCausalLM.from_pretrained( + model_id, use_kernels=True, kernel_config=kernel_config, device_map=torch_device + ) + def test_faulty_kernel_mapping_layer_name(self): kernel_config = KernelConfig(kernel_mapping={"RMSNorm1": "kernels-community/layer_norm:LlamaRMSNorm"}) with self.assertRaises(ValueError): @@ -225,24 +336,24 @@ def test_faulty_kernel_mapping_type(self): @require_kernels class TestKernelsEnv(TestCasePlus): def test_disable_hub_kernels(self): - with patch.dict(os.environ, {"USE_HUB_KERNELS": "OFF"}): - import importlib + import importlib - from transformers.integrations import hub_kernels + from transformers.integrations import hub_kernels + with patch.dict(os.environ, {"USE_HUB_KERNELS": "OFF"}): importlib.reload(hub_kernels) - self.assertFalse(hub_kernels._kernels_enabled) + importlib.reload(hub_kernels) def test_enable_hub_kernels(self): - with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}): - import importlib + import importlib - from transformers.integrations import hub_kernels + from transformers.integrations import hub_kernels + with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}): importlib.reload(hub_kernels) - self.assertTrue(hub_kernels._kernels_enabled) + importlib.reload(hub_kernels) @require_kernels