-
Notifications
You must be signed in to change notification settings - Fork 33.5k
Extended & simplified n-to-1 kernel fusion via KernelConfig #46339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b387190
6bc9402
62d4454
4082fe1
ac4a699
e13111f
d9d53f0
e1c7f3f
db0b7f0
e21d06e
bd640ae
323b000
fe3002d
b541453
b3d73a7
0845222
ecfd97d
973a616
0f0a64b
91177ae
2636c06
573d9f0
f7c15bd
f9d4299
847dbd4
4443c9a
51c59c9
4c58503
a7f983f
b8d860f
3d5f353
c35b513
6e369c0
731d0b7
4e56ad3
96ac123
105a403
b1c9645
b10b864
d1bd06d
fb0a748
ded8b5f
6d6411b
318553e
2a24760
8136781
aa7743a
17f4a9e
8ec88d1
fabadca
e1f3a83
68c3659
041182f
519e673
88e0aee
ad0e24e
3924cf3
c948834
e0c0366
da0fdae
06add71
597bb8c
1489146
024fd4c
b20cb71
1fb1787
2e87c9f
11ecc56
d98489a
89adf48
c841433
edf33ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -14,23 +14,39 @@ | |||||
| import importlib.metadata | ||||||
| import os | ||||||
| import re | ||||||
| import sys | ||||||
| from collections.abc import Callable | ||||||
| from contextlib import contextmanager | ||||||
| from pathlib import Path | ||||||
| from types import ModuleType | ||||||
| from typing import TYPE_CHECKING | ||||||
|
|
||||||
| from packaging import version as pkg_version | ||||||
|
|
||||||
| from ..conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping | ||||||
| from ..monkey_patching import register_patch_mapping | ||||||
| from ..utils import ENV_VARS_TRUE_VALUES, logging | ||||||
| from ..utils.import_utils import is_kernels_available | ||||||
| from ..utils.import_utils import is_kernels_available, is_torch_available | ||||||
| from .flash_attention import flash_attention_forward | ||||||
|
|
||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| from ..configuration_utils import PretrainedConfig | ||||||
| from ..modeling_utils import PreTrainedModel | ||||||
| from ..utils.kernel_config import KernelConfig | ||||||
|
|
||||||
| if is_torch_available(): | ||||||
| import torch | ||||||
| import torch.nn as nn | ||||||
|
|
||||||
|
|
||||||
| logger = logging.get_logger(__name__) | ||||||
|
|
||||||
| try: | ||||||
| from kernels import ( | ||||||
| Device, | ||||||
| LayerRepository, | ||||||
| LocalLayerRepository, | ||||||
| Mode, | ||||||
| register_kernel_mapping, | ||||||
| replace_kernel_forward_from_hub, | ||||||
|
|
@@ -269,6 +285,10 @@ class LayerRepository: | |||||
| def __init__(self, *args, **kwargs): | ||||||
| raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.") | ||||||
|
|
||||||
| class LocalLayerRepository: | ||||||
| def __init__(self, *args, **kwargs): | ||||||
| raise RuntimeError("LocalLayerRepository requires `kernels` to be installed. Run `pip install kernels`.") | ||||||
|
|
||||||
| def replace_kernel_forward_from_hub(*args, **kwargs): | ||||||
| raise RuntimeError( | ||||||
| "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`." | ||||||
|
|
@@ -501,14 +521,165 @@ def allow_all_hub_kernels(): | |||||
| ALLOW_ALL_KERNELS = False | ||||||
|
|
||||||
|
|
||||||
def make_parent_class_for_kernel_fusion(
    parent_cls: type,
    child_names: list[str],
    kernel_cls: type,
) -> type:
    """
    Create a new class that inherits from `parent_cls` and fuses the child modules specified in `child_names`
    with the provided `kernel_cls`.

    The subclass runs the parent's `__init__` unchanged, then looks up each child by name and passes them
    positionally (in `child_names` order) to `kernel_cls`. The first child in `child_names` is replaced with
    the resulting `kernel_cls` instance, and the rest are replaced with `nn.Identity()` so that every
    original attribute name still resolves to a module.

    Args:
        parent_cls: The class whose instances own the child modules to fuse.
        child_names: Attribute names of the children to fuse; must contain at least one name.
        kernel_cls: Callable invoked as `kernel_cls(*children)` to build the fused replacement.

    Returns:
        A subclass of `parent_cls` named `Fused<parent name>` with the patched `__init__`.

    Raises:
        ValueError: If `child_names` is empty.
    """
    if not child_names:
        # Fail at class-creation time instead of with an IndexError on first instantiation.
        raise ValueError("`child_names` must contain at least one child module name to fuse.")

    # Snapshot the names: the closure below must not observe later mutation of the caller's list.
    fused_names = tuple(child_names)
    original_init = parent_cls.__init__

    def patched_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        children = [getattr(self, name) for name in fused_names]
        setattr(self, fused_names[0], kernel_cls(*children))
        for name in fused_names[1:]:
            setattr(self, name, nn.Identity())

    patched_cls = type(f"Fused{parent_cls.__name__}", (parent_cls,), {"__init__": patched_init})
    patched_cls.__qualname__ = f"Fused{parent_cls.__qualname__}"
    return patched_cls
|
|
||||||
|
|
||||||
| def register_kernel_replacements_and_fusions( | ||||||
| cls: "type[PreTrainedModel]", | ||||||
| config: "PretrainedConfig", | ||||||
| kernel_config: "KernelConfig", | ||||||
| ) -> None: | ||||||
| if not hasattr(cls, "config_class") or not hasattr(cls.config_class, "model_type"): | ||||||
| raise ValueError(f"Model {cls.__name__} has no config_class or model_type.") | ||||||
| model_type = cls.config_class.model_type | ||||||
|
|
||||||
| patch_mapping: dict[str, type] = {} | ||||||
| new_mapping: dict = {} | ||||||
|
|
||||||
| # We might need to instantiate the model on meta device. | ||||||
| # We do it lazily, only if we encounter a fused kernel. | ||||||
| meta_model = None | ||||||
|
|
||||||
| for layer_name, hub_repo in kernel_config.kernel_mapping.items(): | ||||||
| if isinstance(hub_repo, dict): | ||||||
| if len(hub_repo.values()) != 1: | ||||||
| raise ValueError( | ||||||
| f"Expected exactly one kernel repo regardless of device/mode specificity, got {hub_repo}" | ||||||
| ) | ||||||
| repo_str = next(iter(hub_repo.values())) | ||||||
| elif isinstance(hub_repo, str): | ||||||
| repo_str = hub_repo | ||||||
| else: | ||||||
| raise ValueError(f"Invalid hub repo {hub_repo!r} for layer {layer_name!r}") | ||||||
|
|
||||||
| repo_id, _, layer_name_in_repo = repo_str.partition(":") | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nice |
||||||
| if not repo_id or not layer_name_in_repo: | ||||||
| raise ValueError(f"Invalid kernel repo string {repo_str!r} for layer {layer_name!r}") | ||||||
|
|
||||||
| if kernel_config.use_local_kernel: | ||||||
| package_name = repo_id.rstrip("/").split("/")[-1] | ||||||
| repo = LocalLayerRepository( | ||||||
| repo_path=Path(repo_id), | ||||||
| package_name=package_name, | ||||||
| layer_name=layer_name_in_repo, | ||||||
| ) | ||||||
| else: | ||||||
| repo = LayerRepository(repo_id=repo_id, layer_name=layer_name_in_repo) | ||||||
|
|
||||||
| kernel_cls = repo.load() | ||||||
|
|
||||||
| if kernel_cls is None: | ||||||
| raise ValueError(f"Could not load kernel class from hub_repo={hub_repo!r}") | ||||||
|
|
||||||
| kernel_mod = sys.modules.get(kernel_cls.__module__) | ||||||
| layout_cls = getattr(kernel_mod, f"{kernel_cls.__name__}Layout", None) if kernel_mod else None | ||||||
|
|
||||||
| # Case 1: no fusion. | ||||||
| if isinstance(layer_name, str): | ||||||
| # No layout class: stateless kernel, leave for kernels.kernelize. | ||||||
| if layout_cls is None: | ||||||
| new_mapping[layer_name] = repo_str | ||||||
| continue | ||||||
|
|
||||||
| # Register the layout class as a monkey patch for the parent module containing the target layer. | ||||||
| layout_cls.kernel_layer_name = kernel_cls.__name__ | ||||||
| patch_mapping[layer_name] = layout_cls | ||||||
|
|
||||||
| # Keep the original repo string so kernelize can replace the layout's forward. | ||||||
| new_mapping[kernel_cls.__name__] = repo_str | ||||||
|
|
||||||
| # Case 2: fusion. | ||||||
| elif isinstance(layer_name, tuple): | ||||||
| if layout_cls is None: | ||||||
| raise ValueError( | ||||||
| f"Fused kernel {kernel_cls.__name__!r} requires a companion layout class " | ||||||
| f"named '{kernel_cls.__name__}Layout' in the same module." | ||||||
| ) | ||||||
|
|
||||||
| layout_cls.kernel_layer_name = kernel_cls.__name__ | ||||||
|
|
||||||
| glob_patterns = [item[1] for item in layer_name] | ||||||
| parent_patterns = [p.rsplit(".", 1)[0] for p in glob_patterns] | ||||||
|
|
||||||
| if len(set(parent_patterns)) != 1: | ||||||
| raise ValueError( | ||||||
| f"All patterns for a fused kernel must share the same parent module, got {glob_patterns}" | ||||||
| ) | ||||||
|
|
||||||
| parent_pattern = parent_patterns[0].replace("*", r"\w+") | ||||||
| child_names = [p.rsplit(".", 1)[1] for p in glob_patterns] | ||||||
|
|
||||||
| if meta_model is None: | ||||||
| with torch.device("meta"): | ||||||
| meta_model = cls(config) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
we only need these
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, maybe it gets updated, but that's good: you can't update twice, so it's even better in a way, no? (To avoid re-computing the named modules.) |
||||||
|
|
||||||
| matched_any = False | ||||||
| for name, module in meta_model.named_modules(): | ||||||
| if not re.fullmatch(parent_pattern, name): | ||||||
| continue | ||||||
| if not all(hasattr(module, child) for child in child_names): | ||||||
| raise ValueError( | ||||||
| f"Module {name!r} does not have the expected child modules {child_names} required for " | ||||||
| f"the fused kernel {kernel_cls.__name__!r}" | ||||||
| ) | ||||||
| matched_any = True | ||||||
| module_cls = type(module) | ||||||
| patch_mapping[module_cls.__name__] = make_parent_class_for_kernel_fusion( | ||||||
| module_cls, child_names, layout_cls | ||||||
| ) | ||||||
|
|
||||||
| if not matched_any: | ||||||
| raise ValueError( | ||||||
| f"No module matched pattern {parent_pattern!r} for fused kernel {kernel_cls.__name__!r}. " | ||||||
| f"Provide the full dotted path from the model root." | ||||||
| ) | ||||||
|
|
||||||
| register_patch_mapping(patch_mapping, overwrite=True) | ||||||
|
|
||||||
| if hasattr(layout_cls, "conversion_mapping"): | ||||||
| existing = get_checkpoint_conversion_mapping(model_type) | ||||||
| transforms = list(layout_cls.conversion_mapping) | ||||||
| if existing is not None: | ||||||
| transforms = existing + transforms | ||||||
| register_checkpoint_conversion_mapping(model_type, transforms, overwrite=True) | ||||||
|
|
||||||
| new_mapping[kernel_cls.__name__] = repo_str | ||||||
|
|
||||||
| kernel_config.kernel_mapping = new_mapping | ||||||
|
|
||||||
|
|
||||||
| __all__ = [ | ||||||
| "LayerRepository", | ||||||
| "use_kernel_forward_from_hub", | ||||||
| "use_kernel_func_from_hub", | ||||||
| "get_kernel", | ||||||
| "lazy_load_kernel", | ||||||
| "register_kernel_mapping", | ||||||
| "register_kernel_mapping_transformers", | ||||||
| "register_kernel_replacements_and_fusions", | ||||||
| "replace_kernel_forward_from_hub", | ||||||
| "lazy_load_kernel", | ||||||
| "get_kernel", | ||||||
| "use_kernel_forward_from_hub", | ||||||
| "use_kernel_func_from_hub", | ||||||
| "use_kernelized_func", | ||||||
| ] # type: ignore | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1346,6 +1346,9 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs): | |||||||||||||||||||
| self.config = config | ||||||||||||||||||||
| self.name_or_path = config.name_or_path | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Make kernel_config an attribute that can be used by the model. | ||||||||||||||||||||
| self.kernel_config = None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid | ||||||||||||||||||||
| # setting it recursively) | ||||||||||||||||||||
| self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation( | ||||||||||||||||||||
|
|
@@ -3782,30 +3785,27 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None | |||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| "`use_kernels=True` requires kernels>=0.9.0. Please install the latest version with `pip install -U kernels`" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| from kernels import use_kernel_mapping | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from .integrations.hub_kernels import register_kernel_mapping_transformers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| register_kernel_mapping_transformers() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if kernel_config is not None and isinstance(kernel_config, KernelConfig): | ||||||||||||||||||||
| # Since kernel_config is a correct value, set it as an attribute of the model so it can be used. | ||||||||||||||||||||
| self.kernel_config = kernel_config | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # This will make sure the mapping is valid, and the layers are registered in the model | ||||||||||||||||||||
| kernel_config.sanitize_kernel_mapping(self) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # This will create a compatible mapping for the model with the kernels library | ||||||||||||||||||||
| kernel_config.create_compatible_mapping(self) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # This is a context manager to override the default kernel mapping | ||||||||||||||||||||
| # We are calling kernelize inside this context manager using the use_kernels setter | ||||||||||||||||||||
| # Param inherit_mapping should be False to avoid still loading kernel from remote | ||||||||||||||||||||
| inherit_mapping = not kernel_config.use_local_kernel | ||||||||||||||||||||
| with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): | ||||||||||||||||||||
| self.use_kernels = True | ||||||||||||||||||||
| self.use_kernels = True | ||||||||||||||||||||
| # We use the default kernel mapping in .integrations.hub_kernels | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| self.use_kernels = True | ||||||||||||||||||||
| self.kernel_config = None | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| self.use_kernels = False | ||||||||||||||||||||
| self.kernel_config = None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| @classmethod | ||||||||||||||||||||
| def from_pretrained( | ||||||||||||||||||||
|
|
@@ -4245,6 +4245,12 @@ def from_pretrained( | |||||||||||||||||||
|
|
||||||||||||||||||||
| register_fusion_patches(cls, config, fusion_config) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Kernel patches: single-layer replacement (stateful __init__) then fusions. | ||||||||||||||||||||
| if kernel_config is not None and use_kernels: | ||||||||||||||||||||
| from .integrations.hub_kernels import register_kernel_replacements_and_fusions | ||||||||||||||||||||
|
|
||||||||||||||||||||
| register_kernel_replacements_and_fusions(cls, config, kernel_config) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. | ||||||||||||||||||||
|
|
@@ -4644,7 +4650,14 @@ def detach_hidden_kernels(module): | |||||||||||||||||||
| self.apply(attach_hidden_kernels) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| mode = Mode.INFERENCE if not self.training else Mode.TRAINING if mode is None else mode | ||||||||||||||||||||
| kernelize(self, device=Device(type=self.device.type), mode=mode) | ||||||||||||||||||||
| if self.kernel_config is not None: | ||||||||||||||||||||
| from kernels import use_kernel_mapping | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inherit_mapping = not self.kernel_config.use_local_kernel | ||||||||||||||||||||
| with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): | ||||||||||||||||||||
| kernelize(self, device=Device(type=self.device.type), mode=mode) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| kernelize(self, device=Device(type=self.device.type), mode=mode) | ||||||||||||||||||||
|
Comment on lines
+4653
to
+4660
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
let's reduce surface as much as possible
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kernelize is defined in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay! we can also just create |
||||||||||||||||||||
| self._use_kernels = True | ||||||||||||||||||||
|
|
||||||||||||||||||||
| finally: | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice