
[Single File] Add GGUF support #9964


Merged · 51 commits · Dec 17, 2024 (diff below shows changes from 2 commits)

Commits
b5eeaa4
update
DN6 Oct 21, 2024
71897b1
update
DN6 Oct 21, 2024
89ea1ee
update
DN6 Oct 24, 2024
f0bcd94
update
DN6 Oct 24, 2024
60d1385
update
DN6 Oct 29, 2024
22ed0b0
update
DN6 Oct 31, 2024
2e6d340
update
DN6 Nov 3, 2024
b5f927c
update
DN6 Nov 11, 2024
b9666c7
Merge branch 'main' into gguf-support
DN6 Nov 11, 2024
6dc5d22
update
DN6 Nov 13, 2024
428e44b
update
DN6 Nov 15, 2024
d7f09f2
update
DN6 Nov 19, 2024
1649936
update
DN6 Nov 19, 2024
28d3a64
update
DN6 Nov 19, 2024
c34a451
update
DN6 Nov 21, 2024
84493db
update
DN6 Nov 21, 2024
50bd784
update
DN6 Nov 21, 2024
8f604b3
Merge branch 'main' into gguf-support
DN6 Dec 3, 2024
afd5d7d
update
DN6 Dec 4, 2024
e1b964a
Merge branch 'main' into gguf-support
sayakpaul Dec 4, 2024
0ed31bc
update
DN6 Dec 4, 2024
af381ad
update
DN6 Dec 4, 2024
52a1bcb
update
DN6 Dec 4, 2024
66ae46e
Merge branch 'gguf-support' of https://github.com/huggingface/diffuse…
DN6 Dec 4, 2024
67f1700
update
DN6 Dec 4, 2024
8abfa55
update
DN6 Dec 5, 2024
d4b88d7
update
DN6 Dec 5, 2024
30f13ed
update
DN6 Dec 5, 2024
9310035
update
DN6 Dec 5, 2024
e9303a0
update
DN6 Dec 5, 2024
e56c266
update
DN6 Dec 5, 2024
1209c3a
Update src/diffusers/quantizers/gguf/utils.py
DN6 Dec 5, 2024
db9b6f3
update
DN6 Dec 5, 2024
4c0360a
Merge branch 'gguf-support' of https://github.com/huggingface/diffuse…
DN6 Dec 5, 2024
aa7659b
Merge branch 'main' into gguf-support
DN6 Dec 5, 2024
78c7861
update
DN6 Dec 5, 2024
33eb431
update
DN6 Dec 5, 2024
9651ddc
update
DN6 Dec 5, 2024
746fd2f
update
DN6 Dec 5, 2024
e027d46
update
DN6 Dec 5, 2024
9db2396
update
DN6 Dec 6, 2024
7ee89f4
update
DN6 Dec 6, 2024
edf3e54
update
DN6 Dec 6, 2024
d3eb54f
update
DN6 Dec 6, 2024
82606cb
Merge branch 'main' into gguf-support
sayakpaul Dec 9, 2024
4f34f14
Update docs/source/en/quantization/gguf.md
DN6 Dec 11, 2024
090efdb
update
DN6 Dec 11, 2024
391b5a9
Merge branch 'main' into gguf-support
DN6 Dec 17, 2024
e67c25a
update
DN6 Dec 17, 2024
e710bde
update
DN6 Dec 17, 2024
f59e07a
update
DN6 Dec 17, 2024
11 changes: 10 additions & 1 deletion docs/source/en/quantization/gguf.md
@@ -13,10 +13,19 @@ specific language governing permissions and limitations under the License.

# GGUF

The GGUF file format is typically used to store models for inference with [GGML]() and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Support for loading GGUF checkpoint via Pipelines is currently not supported. The dequantizatation functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF)
The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.

The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.

Before starting, please install gguf in your environment:

```shell
pip install -U gguf
```

Since GGUF is a single-file format, we use `from_single_file` to load the model, passing in a `GGUFQuantizationConfig` at load time.

When using GGUF checkpoints, the quantized weights remain in a low-memory `dtype`, typically `torch.uint8`, and are dynamically dequantized and cast to the configured `compute_dtype` when running a forward pass through each module in the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype` for the forward pass of each module. The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF).

Review comment:
A lot of the pytorch dequantization code is based on the numpy code from llama.cpp written by @compilade - I believe he should be credited here as well :)


```python
import torch
# … (the rest of the example is collapsed in the diff view)
```
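The example body is collapsed in the diff above. A minimal sketch of the full flow described in this section, assuming `FluxTransformer2DModel` as the model class (it is not shown in this hunk) and reusing the Q2_K checkpoint referenced in the tests added by this PR, might look like the following; the prompt and output filename are illustrative.

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# GGUF checkpoint referenced in the tests added by this PR.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

# Load the prequantized transformer from the single GGUF file. `compute_dtype`
# is the dtype the weights are dequantized to during the forward pass.
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Use the quantized transformer inside the regular FLUX pipeline.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe("A cat holding a sign that says hello world").images[0]
image.save("flux-gguf.png")
```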
1 change: 1 addition & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -351,6 +351,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =

        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(model)
            model.hf_quantizer = hf_quantizer

        if torch_dtype is not None and hf_quantizer is None:
            model.to(torch_dtype)
14 changes: 14 additions & 0 deletions src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -24,6 +24,7 @@
from .utils import (
    GGML_QUANT_SIZES,
    GGUFParameter,
    _dequantize_gguf_and_restore_linear,
    _quant_shape_from_byte_shape,
    _replace_with_gguf_linear,
)
@@ -143,3 +144,16 @@ def is_serializable(self):
    @property
    def is_trainable(self) -> bool:
        return False

    def _dequantize(self, model):
        is_model_on_cpu = model.device.type == "cpu"
        if is_model_on_cpu:
            logger.info(
                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
            )
            model.to(torch.cuda.current_device())

        model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
        if is_model_on_cpu:
            model.to("cpu")
        return model
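For context, this `_dequantize` path is reached through the model-level `dequantize()` call exercised by the new test below. A minimal usage sketch, assuming the same FLUX transformer class and GGUF checkpoint as in the docs example:

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

model = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
)
# Replaces GGUFLinear modules with plain nn.Linear holding dequantized weights.
# If the model sits on CPU, _dequantize temporarily moves it to the current CUDA
# device for dequantization and then moves it back.
model.dequantize()
```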
57 changes: 57 additions & 0 deletions src/diffusers/quantizers/gguf/utils.py
@@ -13,6 +13,7 @@
# # limitations under the License.


import inspect
from contextlib import nullcontext

import gguf
@@ -23,7 +24,27 @@


if is_accelerate_available():
    import accelerate
    from accelerate import init_empty_weights
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module


# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
    r"""
    Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
    some changes
    """
    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    old_hook_attr = old_hook.__dict__
    filtered_old_hook_attr = {}
    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
    for k in old_hook_attr.keys():
        if k in old_hook_init_signature.parameters:
            filtered_old_hook_attr[k] = old_hook_attr[k]
    new_hook = old_hook_cls(**filtered_old_hook_attr)
    return new_hook


def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
@@ -59,6 +80,42 @@ def _should_convert_to_gguf(state_dict, prefix):
    return model


def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
    for name, module in model.named_children():
        if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
            device = module.weight.device
            bias = getattr(module, "bias", None)

            ctx = init_empty_weights if is_accelerate_available() else nullcontext
            with ctx():
                new_module = nn.Linear(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    device=device,
                )
            new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
            if bias is not None:
                new_module.bias = bias

            # Create a new hook and attach it in case we use accelerate
            if hasattr(module, "_hf_hook"):
                old_hook = module._hf_hook
                new_hook = _create_accelerate_new_hook(old_hook)

                remove_hook_from_module(module)
                add_hook_to_module(new_module, new_hook)

            new_module.to(device)
            model._modules[name] = new_module

        has_children = list(module.children())
        if has_children:
            _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)

    return model


# dequantize operations based on torch ports of GGUF dequantize_functions
# from City96
# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
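The ported dequantize functions themselves are collapsed in this diff. As a rough illustration only, not the actual GGUF kernels (quant types such as Q2_K use more elaborate per-block layouts), block-wise dequantization boils down to scaling each block of integer codes by its per-block scale:

```python
import torch


def dequantize_blocks_sketch(qweight: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # qweight: integer codes per block, shape (n_blocks, block_size)
    # scales: one floating-point scale per block, shape (n_blocks, 1)
    # Broadcast each block's scale over its codes, then flatten back to 1D.
    return (qweight.to(torch.float32) * scales).reshape(-1)
```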
21 changes: 20 additions & 1 deletion tests/quantization/gguf/test_gguf.py
@@ -3,6 +3,7 @@

import numpy as np
import torch
import torch.nn as nn

from diffusers import (
    FluxPipeline,
@@ -23,7 +24,7 @@


if is_gguf_available():
    from diffusers.quantizers.gguf.utils import GGUFParameter
    from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter


@nightly
@@ -112,6 +113,24 @@ def test_dtype_assignment(self):
        # This should work
        model.to("cuda")

    def test_dequantize_model(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
        model.dequantize()

        def _check_for_gguf_linear(model):
            has_children = list(model.children())
            if not has_children:
                return

            for name, module in model.named_children():
                if isinstance(module, nn.Linear):
                    assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
                    assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"

            for name, module in model.named_children():
                _check_for_gguf_linear(module)


class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"