Extended & simplified n-to-1 kernel fusion via KernelConfig#46339
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| if self.kernel_config is not None: | ||
| from kernels import use_kernel_mapping | ||
|
|
||
| inherit_mapping = not self.kernel_config.use_local_kernel | ||
| with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): | ||
| kernelize(self, device=Device(type=self.device.type), mode=mode) | ||
| else: | ||
| kernelize(self, device=Device(type=self.device.type), mode=mode) |
There was a problem hiding this comment.
| if self.kernel_config is not None: | |
| from kernels import use_kernel_mapping | |
| inherit_mapping = not self.kernel_config.use_local_kernel | |
| with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): | |
| kernelize(self, device=Device(type=self.device.type), mode=mode) | |
| else: | |
| kernelize(self, device=Device(type=self.device.type), mode=mode) | |
| kernelize(self, device=Device(type=self.device.type), mode=mode, self.kernel_config) |
let's reduce surface as much as possible
There was a problem hiding this comment.
kernelize is defined in kernels. I can make a PR there, but for now it cannot be changed here.
There was a problem hiding this comment.
okay! we can also just create def kernelize to put in kernels utils!
| for module in meta_model.modules(): | ||
| module_cls = type(module) | ||
| if module_cls in seen: | ||
| continue | ||
| if not all(hasattr(module, name) for name in child_names): | ||
| continue | ||
| seen.add(module_cls) |
There was a problem hiding this comment.
I don't think we need to iterate over all the modules!
We could register like we do for the tp plan with explicit path, we like explicitness in general!
{ "layers.*.self_attn.q_proj" : XXXX} There was a problem hiding this comment.
MOST important comment IMO if the contract is more like this we have a lot of simplifications no?
There was a problem hiding this comment.
We already have this contract.
kernel_config = KernelConfig(
{
(
("RMSNorm", "model.layers.*.post_attention_layernorm"),
("MLP", "model.layers.*.mlp"),
): kernel_repo_id,
},
)I will update this loop
| kernel_config.kernel_mapping = new_mapping | ||
|
|
||
|
|
||
| def register_kernel_fusions( |
There was a problem hiding this comment.
let's do both in a single func!
| def _first_str_leaf(obj) -> str | None: | ||
| """Recursively extract the first string leaf from a potentially nested dict (device → mode → str).""" | ||
| if isinstance(obj, str): | ||
| return obj | ||
| if isinstance(obj, dict): | ||
| for v in obj.values(): | ||
| result = _first_str_leaf(v) | ||
| if result is not None: | ||
| return result | ||
| return None |
| ALLOW_ALL_KERNELS = False | ||
|
|
||
|
|
||
| def make_kernel_init_parent_class( |
There was a problem hiding this comment.
this is super important needs to be documented well:
- we replace the fused cls by identity
- thus we have to patch some inits, etc etc c
also do we even have to patch inits when the proper class replaces the one that holds them?
ArthurZucker
left a comment
There was a problem hiding this comment.
Much much better! Its just missing a piece of doc / update the doc for monkey patching, maybe some bench if you have but that's fine for another PR !
Ty for iterating its quite nice now!
| new_mapping: dict = {} | ||
|
|
||
| # We might need to instantiate the model on meta device. | ||
| # We do it lazily, only if we encounter a fused kernel. |
| else: | ||
| raise ValueError(f"Invalid hub repo {hub_repo!r} for layer {layer_name!r}") | ||
|
|
||
| repo_id, _, layer_name_in_repo = repo_str.partition(":") |
|
|
||
| if meta_model is None: | ||
| with torch.device("meta"): | ||
| meta_model = cls(config) |
There was a problem hiding this comment.
| meta_model = cls(config) | |
| meta__modules = cls(config).named_modules() |
we only need these
There was a problem hiding this comment.
ah maybe it gets updated but that's good, you can'tupdate twice so its even better in a way no? (to not re-compute the named modules)
| if self.kernel_config is not None: | ||
| from kernels import use_kernel_mapping | ||
|
|
||
| inherit_mapping = not self.kernel_config.use_local_kernel | ||
| with use_kernel_mapping(self.kernel_config.kernel_mapping, inherit_mapping=inherit_mapping): | ||
| kernelize(self, device=Device(type=self.device.type), mode=mode) | ||
| else: | ||
| kernelize(self, device=Device(type=self.device.type), mode=mode) |
There was a problem hiding this comment.
okay! we can also just create def kernelize to put in kernels utils!
|
For the |
…ace#46339) * feat: module fusion API for kernels * fix: improve __repr__ for fused modules * wip: integration to KernelConfig * wip: add temporary example * wip: pattern matching in KernelConfig and actual kernel repo * refactor: move relevant code to hub_kernels.py * docs: reformat docstring * refactor: remove comment * refactor: update example script for testing * wip: remove apply_fusions method * wip: add core feature for integration with the current fusing API * fix: move kernel mapping patching to kernelize * wip: update example script * wip: add transform_model method for WeightTransform * wip: conversion_mapping in Kernel * wip: remove things from __all__ * wip: remove imports * fix: remove register_fusion_pattern path * fix: remove unused attribute * wip: update experimentation script * refactor: add convert as abstract method * style: reformat hub_kernels.py * wip: transform_model API * wip: transform_model API, WeightTransform * wip: transform_model API, WeightConverter * wip: transform_model API, WeightConverter * wip: make transform_model idempotent * refactor: infer_kernel_fusion_transforms * style: regexs -> regexes * refactor: register_kernel_fusions * refactor: post transformation cleanup * style: fix comment * test: add TestApplyTransformsToMetaModel tests * test: add kernels test * test: fix hub_kernels package reload * style: ruff * refactor: do not create dynamic classes in test * refactor: no dynamic class creation in tests * refactor: test * fix: TYPE_CHECKING imports were broken * wip: get rid of transform_model methods * wip: move tests * wip: make conversion happen before fused module instantiation * refactor * wip: move conversion_mapping inside the init * wip: without any transform_model * wip: remove dead code * wip: api imrpovement * wip: refactor * wip: enable __init__ support in kernels * wip: fuse + init * clean: remove "dead" code * wip: use two classes in kernels * wip: remove docstring * test: add relevant tests * chore: remove experiment file * cleanup: remove helper function * cleanup: remove helper function * refactor: merge the two register kernel functions into one * cleanup: use explicit regex patterns to match for monkey patching * test: cleanup and update tests * doc: add docstring to make_parent_class_for_kernel_fusion
…ace#46339) * feat: module fusion API for kernels * fix: improve __repr__ for fused modules * wip: integration to KernelConfig * wip: add temporary example * wip: pattern matching in KernelConfig and actual kernel repo * refactor: move relevant code to hub_kernels.py * docs: reformat docstring * refactor: remove comment * refactor: update example script for testing * wip: remove apply_fusions method * wip: add core feature for integration with the current fusing API * fix: move kernel mapping patching to kernelize * wip: update example script * wip: add transform_model method for WeightTransform * wip: conversion_mapping in Kernel * wip: remove things from __all__ * wip: remove imports * fix: remove register_fusion_pattern path * fix: remove unused attribute * wip: update experimentation script * refactor: add convert as abstract method * style: reformat hub_kernels.py * wip: transform_model API * wip: transform_model API, WeightTransform * wip: transform_model API, WeightConverter * wip: transform_model API, WeightConverter * wip: make transform_model idempotent * refactor: infer_kernel_fusion_transforms * style: regexs -> regexes * refactor: register_kernel_fusions * refactor: post transformation cleanup * style: fix comment * test: add TestApplyTransformsToMetaModel tests * test: add kernels test * test: fix hub_kernels package reload * style: ruff * refactor: do not create dynamic classes in test * refactor: no dynamic class creation in tests * refactor: test * fix: TYPE_CHECKING imports were broken * wip: get rid of transform_model methods * wip: move tests * wip: make conversion happen before fused module instantiation * refactor * wip: move conversion_mapping inside the init * wip: without any transform_model * wip: remove dead code * wip: api imrpovement * wip: refactor * wip: enable __init__ support in kernels * wip: fuse + init * clean: remove "dead" code * wip: use two classes in kernels * wip: remove docstring * test: add relevant tests * chore: remove experiment file * cleanup: remove helper function * cleanup: remove helper function * refactor: merge the two register kernel functions into one * cleanup: use explicit regex patterns to match for monkey patching * test: cleanup and update tests * doc: add docstring to make_parent_class_for_kernel_fusion
What does this PR do?
Extends the
KernelConfigAPI with two orthogonal capabilities:Module fusion: specify how Transformers modules should be fused together before a custom kernel is applied (n-to-1 replacement).
Parameter transformation: handle cases where a kernel expects weights in a different layout than the original modeling (e.g. fused linears).
Compared to previous PR, this approach is more explicit and way simpler, putting much of the burden to the kernel authors.
How it works
The kernel author needs to define two classes:
KernelName: defines the forward pass, used by thekernelslibrary to kernelize the modelKernelNameLayout: defines theconversion_mappingas well as an__init__method. This is used to monkey-patch the modelHaving two classes because the
kernelslibrary prevents us from having stateful kernel classes.While it might not be as pleasing as having one big class, it separates concerns.
Script for the examples
Example 1: Parameter transformation, no fusion
In this case, the
KernelNameLayoutclass's__init__method has the same signature as the module being replaced.Example 2: Fusion and parameter transformation
Compared to the first example, here we will fuse two modules in the original model into one module.
Because of this, the
__init__method does not have the same signature, but rather take the instantiated modules it's fusing.