[Single File] Add GGUF support #9964
Changes from 14 commits
@@ -17,8 +17,10 @@
from contextlib import nullcontext
from typing import Optional

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
    SingleFileComponentError,

@@ -202,6 +204,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        quantization_config = kwargs.pop("quantization_config", None)

        if isinstance(pretrained_model_link_or_path_or_dict, dict):
            checkpoint = pretrained_model_link_or_path_or_dict

@@ -215,6 +218,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                local_files_only=local_files_only,
                revision=revision,
            )
        if quantization_config is not None:
            hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)

        else:
            hf_quantizer = None

        mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

@@ -295,8 +303,29 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
        with ctx():
            model = cls.from_config(diffusers_model_config)

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
        )
        if use_keep_in_fp32_modules:
            keep_in_fp32_modules = cls._keep_in_fp32_modules
            if not isinstance(keep_in_fp32_modules, list):
                keep_in_fp32_modules = [keep_in_fp32_modules]

        else:
            keep_in_fp32_modules = []

        if hf_quantizer is not None:
            hf_quantizer.preprocess_model(model=model, device_map=None, keep_in_fp32_modules=keep_in_fp32_modules)

        if is_accelerate_available():
-           unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+           unexpected_keys = load_model_dict_into_meta(
+               model,
+               diffusers_format_checkpoint,
+               dtype=torch_dtype,
+               hf_quantizer=hf_quantizer,
+               keep_in_fp32_modules=keep_in_fp32_modules,
+           )

        else:
            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -310,6 +339,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )

        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(model)

        if torch_dtype is not None:
            model.to(torch_dtype)
Reviewer comment: We don't cast the model when …
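To make the new loading path concrete, here is a hedged usage sketch. It assumes a `GGUFQuantizationConfig` with a `compute_dtype` argument is exposed alongside `GGUFQuantizer` (the config class is not shown in this diff), and the checkpoint URL is only an illustration.

```python
# Usage sketch only: GGUFQuantizationConfig and the checkpoint URL are assumptions,
# not taken verbatim from this diff.
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_url = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_0.gguf"  # illustrative
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```

Passing `quantization_config` is what triggers the `DiffusersAutoQuantizer.from_config(...)` branch added above; without it, `hf_quantizer` stays `None` and loading proceeds as before.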
@@ -17,26 +17,30 @@
import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union

import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError
from tqdm import tqdm

from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
    GGUF_FILE_EXTENSION,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_FILE_EXTENSION,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_torch_available,
    is_torch_version,
    logging,
)
from ..utils.import_utils import is_gguf_available

Reviewer comment: Might make sense to add the method to the …


logger = logging.get_logger(__name__)

@@ -140,6 +144,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == SAFETENSORS_FILE_EXTENSION:
        return safetensors.torch.load_file(checkpoint_file, device="cpu")
    elif file_extension == GGUF_FILE_EXTENSION:
        return load_gguf_checkpoint(checkpoint_file)
    else:
        weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
        return torch.load(

@@ -176,11 +182,9 @@ def load_model_dict_into_meta(
    hf_quantizer=None,
    keep_in_fp32_modules=None,
) -> List[str]:
-   if hf_quantizer is None:
-       device = device or torch.device("cpu")
+   device = device or torch.device("cpu")

Review discussion:
- This might have some consequences. If …
- @a-r-r-o-w has an open PR for this #10069

    dtype = dtype or torch.float32
    is_quantized = hf_quantizer is not None
-   is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    empty_state_dict = model.state_dict()

@@ -211,14 +215,15 @@ def load_model_dict_into_meta(
                set_module_kwargs["dtype"] = dtype

            # bnb params are flattened.
+           # gguf quants have a different shape based on the type of quantization applied
            if empty_state_dict[param_name].shape != param.shape:
                if (
-                   is_quant_method_bnb
+                   is_quantized
                    and hf_quantizer.pre_quantized
                    and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
                ):
-                   hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-               elif not is_quant_method_bnb:
+                   hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
+               else:
                    model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                    raise ValueError(
                        f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."

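With the extension dispatch added to `load_state_dict`, a `.gguf` path now routes transparently through `load_gguf_checkpoint`. A minimal sketch of the intended call, with hypothetical file names (the value bound to `GGUF_FILE_EXTENSION` is not part of this excerpt):

```python
# Sketch: both calls return a flat state dict; the .gguf one contains GGUFParameter
# tensors for quantized entries, while safetensors entries stay plain torch tensors.
sd_safetensors = load_state_dict("model.safetensors")  # hypothetical path
sd_gguf = load_state_dict("model-Q4_0.gguf")           # hypothetical path
```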
@@ -396,3 +401,77 @@ def _fetch_index_file_legacy(
        index_file = None

    return index_file


def _gguf_parse_value(_value, data_type):
    if not isinstance(data_type, list):
        data_type = [data_type]
    if len(data_type) == 1:
        data_type = data_type[0]
        array_data_type = None
    else:
        if data_type[0] != 9:
            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
        data_type, array_data_type = data_type

    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
        _value = int(_value[0])
    elif data_type in [6, 12]:
        _value = float(_value[0])
    elif data_type in [7]:
        _value = bool(_value[0])
    elif data_type in [8]:
        _value = array("B", list(_value)).tobytes().decode()
    elif data_type in [9]:
        _value = _gguf_parse_value(_value, array_data_type)
    return _value


def read_field(reader, field):
    value = reader.fields[field]
    return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
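As a quick illustration of how these metadata helpers are meant to be used (the file name is hypothetical, and `general.architecture` is a standard GGUF key that is not referenced in this diff):

```python
# Sketch: read a metadata field from a GGUF file with the helpers above.
# Requires the gguf package (>= 0.10.0) and a local .gguf file.
from gguf import GGUFReader

reader = GGUFReader("flux1-dev-Q4_0.gguf")  # hypothetical path
print(read_field(reader, "general.architecture"))  # e.g. ["flux"]
```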

def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):

Reviewer comment: Same for these two. Additionally, …

    """
    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
    attributes.

    Args:
        gguf_checkpoint_path (`str`):
            The path to the GGUF file to load.
        return_tensors (`bool`, defaults to `False`):
            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
            metadata in memory.
    """

    if is_gguf_available() and is_torch_available():
        import gguf
        from gguf import GGUFReader

        from ..quantizers.gguf.utils import GGUFParameter
    else:
        logger.error(
            "Loading a GGUF checkpoint in PyTorch requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
        )
        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

Review discussion:
- do we need to check the gguf version as well? (in addition to is_gguf_available)
- Agree. Let's always suggest installing the latest stable build of …

    reader = GGUFReader(gguf_checkpoint_path)
    fields = reader.fields
    reader_keys = list(fields.keys())

    parsed_parameters = {}
    for tensor in tqdm(reader.tensors):
        name = tensor.name
        quant_type = tensor.tensor_type

        # if the tensor is a torch supported dtype do not use GGUFParameter
        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]

Reviewer comment: We could create a …

        weights = torch.from_numpy(tensor.data.copy())
        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

    if len(reader_keys) > 0:
        logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

Review discussion:
- trying to understand this check here, I think maybe when we iterate through the tensors we also remove the names from the …
- Copied this from transformers. But these aren't tensor keys. They're metadata keys. This can probably just be removed since the info isn't too relevant.

    return parsed_parameters
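To see what this loader produces, a small inspection sketch. The file path is hypothetical, and it assumes `GGUFParameter` behaves like a `torch.Tensor` that exposes `quant_type`, as the loop above implies:

```python
# Sketch: quantized tensors come back as GGUFParameter carrying their quant_type,
# while F16/F32 tensors stay as plain torch tensors.
state_dict = load_gguf_checkpoint("flux1-dev-Q4_0.gguf")  # hypothetical path
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), getattr(tensor, "quant_type", None))
```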
@@ -204,7 +204,10 @@ def create_quantized_param(

        module._parameters[tensor_name] = new_value

-   def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+   def check_quantized_param_shape(self, param_name, current_param, loaded_param):

Review discussion:
- GGUF needs to access the tensor quant type to run a shape check. So this needs to change from passing in shapes to passing in params directly.
- Why not add this method to the …
- I see you're already adding this to the GGUF quantizer class. So, maybe okay to not modify this?
- definitely makes sense here to make sure this method has the same signature across all quantizers, it will be confusing otherwise
- I think no deprecation is fine since this method is called from …

+       current_param_shape = current_param.shape
+       loaded_param_shape = loaded_param.shape

        n = current_param_shape.numel()
        inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
        if loaded_param_shape != inferred_shape:

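A quick worked instance of the inferred-shape rule above (a sketch for intuition, not library code): bitsandbytes 4-bit packs two 4-bit values per byte and stores the packed weight as a flat column, while biases stay 1-D and unpacked.

```python
# For a (3072, 3072) weight quantized to 4-bit:
n = 3072 * 3072
print(((n + 1) // 2, 1))  # (4718592, 1) -> expected shape of the loaded packed weight
# For a bias with 3072 elements, the expected loaded shape is simply (3072,)
```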
@@ -0,0 +1 @@
from .gguf_quantizer import GGUFQuantizer |
@@ -0,0 +1,96 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ...utils import get_module_from_name
from ..base import DiffusersQuantizer
from .utils import GGUFParameter, _quant_shape_from_byte_shape, _replace_with_gguf_linear


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin

from ...utils import (
    is_gguf_available,
    is_torch_available,
    logging,
)


if is_torch_available():
    import torch

if is_gguf_available():
    import gguf

logger = logging.get_logger(__name__)


class GGUFQuantizer(DiffusersQuantizer):
    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

        self.compute_dtype = quantization_config.compute_dtype
        self.pre_quantized = True

Review discussion:
- so gguf will always be pre_quantized? it does not make sense to support converting it (like we do for bnb)?
- I would rather take this from the config and then default it to …
- Hmm seems like it's always going to be prequantized for this PR.

    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
        loaded_param_shape = loaded_param.shape
        current_param_shape = current_param.shape
        quant_type = loaded_param.quant_type

        block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]

        inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
        if inferred_shape != current_param_shape:
            raise ValueError(
                f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}"
            )

        return True
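For intuition, here is a minimal sketch of the packing arithmetic this check relies on. The helper name mirrors `_quant_shape_from_byte_shape`, but the body is an assumption based on the standard GGML block layout (for Q4_0, 32 weights are packed into 18 bytes), not the actual library code:

```python
import gguf

def quant_shape_from_byte_shape_sketch(shape, type_size, block_size):
    # The loaded tensor is byte-shaped; expand the last dim back to weight counts.
    return (*shape[:-1], shape[-1] // type_size * block_size)

block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]  # (32, 18)
print(quant_shape_from_byte_shape_sketch((3072, 1728), type_size, block_size))  # (3072, 3072)
```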

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        if isinstance(param_value, GGUFParameter):
            return True

        return False

    def create_quantized_param(

Review discussion:
- For bitsandbytes, the bias is often left untouched and is kept as the original …
- The default logic in llama.cpp is here; the weights quantized with the comfy repo follow this logic, so biases and 1d tensors are left unquantized. They're FP32 internally in llama.cpp since they're fairly small (and the operation for numpy/gguf FP32 -> torch FP32 doesn't require any dequantization, so doing that might be ever so slightly faster. Same goes for FP16, but not BF16 due to not being a native numpy datatype).
- Checking with @DN6 if we're following the same.
- Biases and 1D tensors are unquantized. @city96 created the GGUF quants for Flux and SD3.* so they're the expert here :)
- Thanks! I don't remember any specific code / tests for checking that.
- Ah seems like this: … Maybe having a test for this would be nice!

        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",

Reviewer comment: It can be a union of GGUF tensor and …

        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,

Review discussion:
- Should …
- My bad. Not needed here. I think we could just set them to …

    ):
        module, tensor_name = get_module_from_name(model, param_name)
        if tensor_name not in module._parameters:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

Reviewer comment: You should check for buffers also in …

        module._parameters[tensor_name] = param_value

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        _replace_with_gguf_linear(model, self.compute_dtype)

    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
        return model

    @property
    def is_serializable(self):
        return False

    @property
    def is_trainable(self) -> bool:
        # Because we're mandating `bitsandbytes` 0.43.3.
        return False
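`_replace_with_gguf_linear` itself is not shown in this diff. As a rough illustration of the approach (an assumption based on how similar quantizers swap layers), the idea is to walk the module tree and replace each `nn.Linear` with a GGUF-aware subclass whose forward pass dequantizes its `GGUFParameter` weight to `compute_dtype` before the matmul:

```python
# Illustrative sketch only, not the library implementation.
import torch
import torch.nn as nn

class GGUFLinearSketch(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, compute_dtype=torch.float32):
        super().__init__(in_features, out_features, bias)
        self.compute_dtype = compute_dtype
    # A real implementation would override forward() to dequantize the stored
    # GGUFParameter weight to compute_dtype before calling F.linear.

def replace_with_gguf_linear_sketch(model: nn.Module, compute_dtype: torch.dtype) -> nn.Module:
    # Recursively swap nn.Linear children for the GGUF-aware variant.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            new_layer = GGUFLinearSketch(
                child.in_features, child.out_features, child.bias is not None, compute_dtype
            )
            setattr(model, name, new_layer)
        else:
            replace_with_gguf_linear_sketch(child, compute_dtype)
    return model
```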