From aeeadc7941a935c2ef66c12cc15d60d06ff15c17 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Dec 2024 16:12:11 +0100 Subject: [PATCH 1/2] update --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4a486fd4ce40..9796cd38f293 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -836,7 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P param_device = "cpu" # TODO (sayakpaul, SunMarc): remove this after model loading refactor elif is_quant_method_bnb: - param_device = torch.cuda.current_device() + param_device = torch.device(torch.cuda.current_device()) state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) From 126d84eac2f7a52d7787c80299cd037a490b42c9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 4 Dec 2024 22:31:12 +0100 Subject: [PATCH 2/2] apply review suggestion --- src/diffusers/models/model_loading_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 932a94571107..751117f8f247 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -176,6 +176,8 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: + if device is not None and not isinstance(device, (str, torch.device)): + raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.") if hf_quantizer is None: device = device or torch.device("cpu") dtype = dtype or torch.float32