Merged

Commits (51)
b5eeaa4 update (DN6, Oct 21, 2024)
71897b1 update (DN6, Oct 21, 2024)
89ea1ee update (DN6, Oct 24, 2024)
f0bcd94 update (DN6, Oct 24, 2024)
60d1385 update (DN6, Oct 29, 2024)
22ed0b0 update (DN6, Oct 31, 2024)
2e6d340 update (DN6, Nov 3, 2024)
b5f927c update (DN6, Nov 11, 2024)
b9666c7 Merge branch 'main' into gguf-support (DN6, Nov 11, 2024)
6dc5d22 update (DN6, Nov 13, 2024)
428e44b update (DN6, Nov 15, 2024)
d7f09f2 update (DN6, Nov 19, 2024)
1649936 update (DN6, Nov 19, 2024)
28d3a64 update (DN6, Nov 19, 2024)
c34a451 update (DN6, Nov 21, 2024)
84493db update (DN6, Nov 21, 2024)
50bd784 update (DN6, Nov 21, 2024)
8f604b3 Merge branch 'main' into gguf-support (DN6, Dec 3, 2024)
afd5d7d update (DN6, Dec 4, 2024)
e1b964a Merge branch 'main' into gguf-support (sayakpaul, Dec 4, 2024)
0ed31bc update (DN6, Dec 4, 2024)
af381ad update (DN6, Dec 4, 2024)
52a1bcb update (DN6, Dec 4, 2024)
66ae46e Merge branch 'gguf-support' of https://github.com/huggingface/diffuse… (DN6, Dec 4, 2024)
67f1700 update (DN6, Dec 4, 2024)
8abfa55 update (DN6, Dec 5, 2024)
d4b88d7 update (DN6, Dec 5, 2024)
30f13ed update (DN6, Dec 5, 2024)
9310035 update (DN6, Dec 5, 2024)
e9303a0 update (DN6, Dec 5, 2024)
e56c266 update (DN6, Dec 5, 2024)
1209c3a Update src/diffusers/quantizers/gguf/utils.py (DN6, Dec 5, 2024)
db9b6f3 update (DN6, Dec 5, 2024)
4c0360a Merge branch 'gguf-support' of https://github.com/huggingface/diffuse… (DN6, Dec 5, 2024)
aa7659b Merge branch 'main' into gguf-support (DN6, Dec 5, 2024)
78c7861 update (DN6, Dec 5, 2024)
33eb431 update (DN6, Dec 5, 2024)
9651ddc update (DN6, Dec 5, 2024)
746fd2f update (DN6, Dec 5, 2024)
e027d46 update (DN6, Dec 5, 2024)
9db2396 update (DN6, Dec 6, 2024)
7ee89f4 update (DN6, Dec 6, 2024)
edf3e54 update (DN6, Dec 6, 2024)
d3eb54f update (DN6, Dec 6, 2024)
82606cb Merge branch 'main' into gguf-support (sayakpaul, Dec 9, 2024)
4f34f14 Update docs/source/en/quantization/gguf.md (DN6, Dec 11, 2024)
090efdb update (DN6, Dec 11, 2024)
391b5a9 Merge branch 'main' into gguf-support (DN6, Dec 17, 2024)
e67c25a update (DN6, Dec 17, 2024)
e710bde update (DN6, Dec 17, 2024)
f59e07a update (DN6, Dec 17, 2024)
2 changes: 1 addition & 1 deletion src/diffusers/__init__.py

@@ -31,7 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
34 changes: 33 additions & 1 deletion src/diffusers/loaders/single_file_model.py

@@ -17,8 +17,10 @@
from contextlib import nullcontext
from typing import Optional

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
@@ -202,6 +204,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)

if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
@@ -215,6 +218,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
local_files_only=local_files_only,
revision=revision,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)

else:
hf_quantizer = None

mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

@@ -295,8 +303,29 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
with ctx():
model = cls.from_config(diffusers_model_config)

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
keep_in_fp32_modules = [keep_in_fp32_modules]

else:
keep_in_fp32_modules = []

if hf_quantizer is not None:
hf_quantizer.preprocess_model(model=model, device_map=None, keep_in_fp32_modules=keep_in_fp32_modules)

if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
unexpected_keys = load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
)

else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
Expand All @@ -310,6 +339,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)

if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)

if torch_dtype is not None:
model.to(torch_dtype)
Member: We don't cast the model when hf_quantizer is not None, to preserve the data types set during preprocessing and postprocessing:

# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will

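Taken together, these changes let `from_single_file` accept a `quantization_config` and route state-dict loading through a quantizer. A minimal usage sketch, assuming a model class and checkpoint that support GGUF single-file loading (the checkpoint URL and `compute_dtype` value are illustrative, not taken from this diff):

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative GGUF single-file checkpoint; any supported GGUF transformer
# checkpoint should work the same way.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```

Because the quantizer sets parameter dtypes during pre- and post-processing, the final global cast is skipped when hf_quantizer is not None, as noted in the comment above.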
93 changes: 86 additions & 7 deletions src/diffusers/models/model_loading_utils.py

@@ -17,26 +17,30 @@
import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union

import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError
from tqdm import tqdm

from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
deprecate,
is_accelerate_available,
is_torch_available,
is_torch_version,
logging,
)
from ..utils.import_utils import is_gguf_available
Member: Might make sense to add the method to the __init__.py of utils?

logger = logging.get_logger(__name__)
@@ -140,6 +144,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
@@ -176,11 +182,9 @@ def load_model_dict_into_meta(
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
if hf_quantizer is None:
device = device or torch.device("cpu")
device = device or torch.device("cpu")
Member: This might have some consequences. If device is passed as 0 (which is a perfectly valid device id), then the device would be selected as "cpu", which is not what we want here, no? For bnb, we pass the param_device to be:

param_device = torch.cuda.current_device()

Member: @a-r-r-o-w has an open PR for this #10069

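The concern above is that Python's `or` treats a device id of `0` as falsy, so `device = device or torch.device("cpu")` silently remaps GPU 0 to CPU. A minimal illustration of the pitfall, with an explicit `None` check as one possible workaround (a sketch only, not the fix adopted in the referenced PR):

```python
import torch

device = 0  # a perfectly valid CUDA device id
print(device or torch.device("cpu"))  # -> cpu; GPU 0 is silently dropped

# An explicit None check keeps integer device ids intact.
device = torch.device("cpu") if device is None else device
print(device)  # -> 0
```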
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
empty_state_dict = model.state_dict()
@@ -211,14 +215,15 @@ def load_model_dict_into_meta(
set_module_kwargs["dtype"] = dtype

# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
if (
is_quant_method_bnb
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
elif not is_quant_method_bnb:
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -396,3 +401,77 @@ def _fetch_index_file_legacy(
index_file = None

return index_file


def _gguf_parse_value(_value, data_type):
if not isinstance(data_type, list):
data_type = [data_type]
if len(data_type) == 1:
data_type = data_type[0]
array_data_type = None
else:
if data_type[0] != 9:
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
data_type, array_data_type = data_type

if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
_value = int(_value[0])
elif data_type in [6, 12]:
_value = float(_value[0])
elif data_type in [7]:
_value = bool(_value[0])
elif data_type in [8]:
_value = array("B", list(_value)).tobytes().decode()
elif data_type in [9]:
_value = _gguf_parse_value(_value, array_data_type)
return _value


def read_field(reader, field):
value = reader.fields[field]
return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]


def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
Member: Same for these two. Additionally, read_field() sounds a bit ambiguous -- could do with a better name?

"""
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
attributes.

Args:
gguf_checkpoint_path (`str`):
The path to the GGUF file to load.
return_tensors (`bool`, defaults to `False`):
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
metadata in memory.
"""

if is_gguf_available() and is_torch_available():
import gguf
from gguf import GGUFReader

from ..quantizers.gguf.utils import GGUFParameter
else:
logger.error(
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
Collaborator: Do we need to check the gguf version as well (in addition to is_gguf_available)?

Member: Agree. Let's always suggest installing the latest stable build of gguf, like we do for bitsandbytes:

if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):

"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
)
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

reader = GGUFReader(gguf_checkpoint_path)
fields = reader.fields
reader_keys = list(fields.keys())

parsed_parameters = {}
for tensor in tqdm(reader.tensors):
name = tensor.name
quant_type = tensor.tensor_type

# if the tensor is a torch supported dtype do not use GGUFParameter
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
Member: We could create a NON_TORCH_GGUF_DTYPE enum or set with these two values (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) and use NON_TORCH_GGUF_DTYPE here instead.

weights = torch.from_numpy(tensor.data.copy())
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

if len(reader_keys) > 0:
Collaborator: Trying to understand this check here. I think that if, while iterating through the tensors, we also removed their names from reader_keys as we go, the check would make sense - but I didn't see any code that removes anything. So maybe we forgot to remove them, or is that not the case? Did I miss something?

Collaborator (Author): Copied this from transformers. But these aren't tensor keys, they're metadata keys:

['GGUF.version', 'GGUF.tensor_count', 'GGUF.kv_count', 'general.architecture', 'general.quantization_version', 'general.file_type']

This can probably just be removed since the info isn't too relevant.

logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

return parsed_parameters
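For reference, the metadata fields and per-tensor quantization types that `load_gguf_checkpoint` iterates over can be inspected directly with the `gguf` package. A minimal sketch (the file path is hypothetical):

```python
from gguf import GGUFReader

reader = GGUFReader("flux1-dev-Q4_0.gguf")  # hypothetical local GGUF file

# Metadata keys such as GGUF.version, general.architecture, general.file_type, ...
print(list(reader.fields.keys()))

# Each tensor carries a GGML quantization type; F16/F32 tensors are returned as
# plain torch tensors above, everything else is wrapped in GGUFParameter.
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.tensor_type.name, tensor.shape)
```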
2 changes: 2 additions & 0 deletions src/diffusers/quantizers/auto.py
@@ -19,12 +19,14 @@
from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod


AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"gguf": GGUFQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
5 changes: 4 additions & 1 deletion src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

@@ -204,7 +204,10 @@ def create_quantized_param(

module._parameters[tensor_name] = new_value

def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
Collaborator (Author): GGUF needs to access the tensor quant type to run a shape check, so this needs to change from passing in shapes to passing in params directly.

Member: Why not add this method to the gguf_quantizer.py file instead of modifying this? This would be a breaking change, no?

Member: I see you're already adding this to the GGUF quantizer class. So maybe it's okay to not modify this?

Collaborator: It definitely makes sense to ensure this method has the same signature across all quantizers; it will be confusing otherwise. In terms of the breaking change, I think it is okay, but we can deprecate it if we want to be extra cautious.

Member: I think no deprecation is fine since this method is called from load_model_dict_into_meta(). But let's make sure to run the tests to ensure nothing's breaking.

current_param_shape = current_param.shape
loaded_param_shape = loaded_param.shape

n = current_param_shape.numel()
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
if loaded_param_shape != inferred_shape:
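For context on the bnb branch of the shape check above: bitsandbytes stores pre-quantized 4-bit weights flattened, with two 4-bit values packed per uint8 byte, which is where the `((n + 1) // 2, 1)` expectation comes from. A small worked example of the arithmetic (a sketch only, no bitsandbytes call involved):

```python
# A 3072x3072 weight has n = 9,437,184 elements. Packed 4-bit storage keeps two
# values per byte in a flattened column, so the serialized checkpoint tensor is
# expected to have shape ((n + 1) // 2, 1); biases stay unpacked with shape (n,).
n = 3072 * 3072
packed_weight_shape = ((n + 1) // 2, 1)
print(packed_weight_shape)  # (4718592, 1)
```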
1 change: 1 addition & 0 deletions src/diffusers/quantizers/gguf/__init__.py
@@ -0,0 +1 @@
from .gguf_quantizer import GGUFQuantizer
96 changes: 96 additions & 0 deletions src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -0,0 +1,96 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ...utils import get_module_from_name
from ..base import DiffusersQuantizer
from .utils import GGUFParameter, _quant_shape_from_byte_shape, _replace_with_gguf_linear


if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin

from ...utils import (
is_gguf_available,
is_torch_available,
logging,
)


if is_torch_available():
import torch

if is_gguf_available():
import gguf

logger = logging.get_logger(__name__)


class GGUFQuantizer(DiffusersQuantizer):
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = True
Collaborator: So GGUF will always be pre_quantized? Does it not make sense to support converting it (like we do for bnb)?

Member: I would rather take this from the config and then default it to True. If it's otherwise, we error out, unless we support saving a model quantized with GGUF.

Member: Hmm, it seems like it's always going to be pre-quantized for this PR.

def check_quantized_param_shape(self, param_name, current_param, loaded_param):
loaded_param_shape = loaded_param.shape
current_param_shape = current_param.shape
quant_type = loaded_param.quant_type

block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]

inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
Member (@sayakpaul, Nov 25, 2024) suggested change:

- inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
+ inferred_shape = _infer_quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)

if inferred_shape != current_param_shape:
raise ValueError(
f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
)

return True

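The shape check relies on the GGML block layout: each quantization type stores block_size logical values in type_size bytes, so the last dimension of the packed uint8 tensor converts back via `dim // type_size * block_size`. A worked sketch using the Q4_0 layout; the helper below is a stand-in for `_quant_shape_from_byte_shape`, whose exact implementation is not shown in this diff:

```python
import gguf


def quant_shape_from_byte_shape(shape, type_size, block_size):
    # Only the last dimension is packed into bytes; leading dimensions are unchanged.
    return (*shape[:-1], shape[-1] // type_size * block_size)


# Q4_0 stores 32 values per 18-byte block (16 bytes of 4-bit quants + a 2-byte scale).
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]
print(quant_shape_from_byte_shape((3072, 1728), type_size, block_size))  # (3072, 3072)
```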
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if isinstance(param_value, GGUFParameter):
return True

return False

def create_quantized_param(
Member: For bitsandbytes, the bias is often left untouched and kept in the original dtype (torch.bfloat16, for example). Does the same apply to GGUF, or can we quantize everything?

Comment: The default logic in llama.cpp is here; the weights quantized with the comfy repo follow this logic, so biases and 1D tensors are left unquantized. They're FP32 internally in llama.cpp since they're fairly small (and the operation for numpy/gguf FP32 -> torch FP32 doesn't require any dequantization, so doing that might be ever so slightly faster. Same goes for FP16, but not BF16, due to not being a native numpy datatype).

Member: Checking with @DN6 if we're following the same.

Collaborator (Author): Biases and 1D tensors are unquantized. @city96 created the GGUF quants for Flux and SD3.* so they're the expert here :)

Member: Thanks! I don't remember any specific code / tests for checking that.

self,
model: "ModelMixin",
param_value: "torch.Tensor",
Member: It can be a union of a GGUF tensor and torch.Tensor, right?

param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
Member: Should state_dict and unexpected_keys not be used? That seems a bit weird, no?

Member: My bad, they're not needed here. I think we could just set them to Optional then, no?

):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
Member: You should also check for buffers in module._buffers, as they are not included in module._parameters.

module._parameters[tensor_name] = param_value

def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
_replace_with_gguf_linear(model, self.compute_dtype)

def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
return model

@property
def is_serializable(self):
return False

@property
def is_trainable(self) -> bool:
# Because we're mandating `bitsandbytes` 0.43.3.
return False
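
_process_model_before_weight_loading delegates the actual layer swap to `_replace_with_gguf_linear` from `quantizers/gguf/utils.py`, which is not shown in this diff. Conceptually, each `nn.Linear` is replaced by a variant that keeps the GGUF-packed weight and dequantizes it to `compute_dtype` at forward time. A rough sketch of that idea, assuming `self.weight` has been replaced with a `GGUFParameter` carrying its `quant_type`; `gguf.quants.dequantize` is used here only to keep the sketch self-contained and is not necessarily the routine the PR ships:

```python
import gguf
import torch
import torch.nn as nn


class GGUFLinearSketch(nn.Linear):
    """Keeps a GGUF-packed weight and dequantizes it on the fly at forward time."""

    def __init__(self, in_features, out_features, bias=False, compute_dtype=None, device=None):
        super().__init__(in_features, out_features, bias, device=device)
        self.compute_dtype = compute_dtype

    def forward(self, inputs):
        # self.weight is assumed to be a GGUFParameter (packed uint8 data + quant_type).
        dequantized = gguf.quants.dequantize(self.weight.data.cpu().numpy(), self.weight.quant_type)
        weight = torch.from_numpy(dequantized).to(inputs.device, self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
        return torch.nn.functional.linear(inputs, weight, bias)
```

Keeping the weights packed and dequantizing per forward pass trades compute for memory, which is what makes loading large GGUF checkpoints on smaller GPUs practical.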