From 01c5cf251b40041508934dd845556b2d2c8c7c62 Mon Sep 17 00:00:00 2001 From: Gaeros <163143799+elias-gaeros@users.noreply.github.com> Date: Thu, 30 May 2024 19:29:53 +0200 Subject: [PATCH 1/6] [LoRA] text encoder: read the ranks for all the attn modules * In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj * Allow missing adapters (UNet already supports this) --- src/diffusers/loaders/lora.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index e089d202ee75..d8c7fdee0f48 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -566,17 +566,18 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_B.weight" - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - if patch_mlp: - for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_B.weight" - rank_key_fc2 = f"{name}.fc2.lora_B.weight" - - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ('fc1', 'fc2'): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ From cb5bff392e1912494332390ccb5e5d2661717c40 Mon Sep 17 00:00:00 2001 From: Gaeros <163143799+elias-gaeros@users.noreply.github.com> Date: Fri, 31 May 2024 19:06:54 +0200 Subject: [PATCH 2/6] ruff format loaders.lora --- src/diffusers/loaders/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index d8c7fdee0f48..edb82d53ef47 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -573,7 +573,7 @@ def load_lora_into_text_encoder( rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ('fc1', 'fc2'): + for module in ("fc1", "fc2"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue From a74b84c251769c0051ca874142d6d32710948082 Mon Sep 17 00:00:00 2001 From: Gaeros <163143799+elias-gaeros@users.noreply.github.com> Date: Thu, 6 Jun 2024 16:28:47 +0200 Subject: [PATCH 3/6] [LoRA] add tests for partial text encoders LoRAs --- tests/lora/utils.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index d08a26645602..7823256a61c7 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -395,6 +395,65 @@ def test_simple_inference_with_text_lora_save_load(self): "Loading from saved checkpoints should give same results.", ) + def test_simple_inference_with_partial_text_lora(self): + """ + Tests a simple inference with lora attached on the text encoder + with 
different ranks and some adapters removed + and makes sure it works as expected + """ + for scheduler_cls in [DDIMScheduler, LCMScheduler]: + components, _, _ = self.get_dummy_components(scheduler_cls) + text_lora_config = LoraConfig( + r=4, + rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3}, + lora_alpha=4, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + init_lora_weights=False, + use_dora=False, + ) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + state_dict = { + f"text_encoder.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + } + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + state_dict.update( + { + f"text_encoder.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + } + ) + + # Discard half of the adapters. + rng = np.random.default_rng(0) + key2adapters = {k: k.rsplit(".", 2)[0] for k in state_dict.keys()} + adapters = list(set(key2adapters.values())) + adapters = set(rng.choice(adapters, size=len(adapters) // 2, replace=False)) + state_dict = {k: state_dict[k] for k, adapter in key2adapters.items() if adapter in adapters} + + # Unload lora and load it back using the pipe.load_lora_weights machinery + pipe.unload_lora_weights() + pipe.load_lora_weights(state_dict) + + output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue( + not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) + def test_simple_inference_save_pretrained(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained From eabe62236770d498952383101e8f06c1a483c157 Mon Sep 17 00:00:00 2001 From: Gaeros <163143799+elias-gaeros@users.noreply.github.com> Date: Thu, 6 Jun 2024 17:56:32 +0200 Subject: [PATCH 4/6] [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic --- tests/lora/utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 7823256a61c7..641acd7c2958 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -424,6 +424,7 @@ def test_simple_inference_with_partial_text_lora(self): state_dict = { f"text_encoder.{module_name}": param for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + if "text_model.encoder.layers.4" not in module_name } if self.has_two_text_encoders: @@ -433,25 +434,25 @@ def test_simple_inference_with_partial_text_lora(self): ) state_dict.update( { - f"text_encoder.{module_name}": param + f"text_encoder_2.{module_name}": param for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + if "text_model.encoder.layers.4" not in module_name } ) - # Discard half of the adapters. 
- rng = np.random.default_rng(0) - key2adapters = {k: k.rsplit(".", 2)[0] for k in state_dict.keys()} - adapters = list(set(key2adapters.values())) - adapters = set(rng.choice(adapters, size=len(adapters) // 2, replace=False)) - state_dict = {k: state_dict[k] for k, adapter in key2adapters.items() if adapter in adapters} + output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue( + not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) # Unload lora and load it back using the pipe.load_lora_weights machinery pipe.unload_lora_weights() pipe.load_lora_weights(state_dict) - output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" + not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), + "Removing adapters should change the output", ) def test_simple_inference_save_pretrained(self): From f9a4375148aec8024ef2d3409f025956832be2fc Mon Sep 17 00:00:00 2001 From: Gaeros <163143799+elias-gaeros@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:55:05 +0200 Subject: [PATCH 5/6] [LoRA] comment justifying test_simple_inference_with_partial_text_lora --- tests/lora/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 641acd7c2958..9a07727db931 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -403,6 +403,7 @@ def test_simple_inference_with_partial_text_lora(self): """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: components, _, _ = self.get_dummy_components(scheduler_cls) + # Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( r=4, rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3}, @@ -421,6 +422,8 @@ def test_simple_inference_with_partial_text_lora(self): pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` + # supports missing layers (PR#8324). state_dict = { f"text_encoder.{module_name}": param for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() From 8c3c264e548400a98980c0f4faac0d57b5489b9b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 16 Jun 2024 19:29:59 +0100 Subject: [PATCH 6/6] style --- src/diffusers/dependency_versions_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 9413be5e4eed..e11410cf6eec 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -40,7 +40,7 @@ "tensorboard": "tensorboard", "torch": "torch>=1.4", "torchvision": "torchvision", - "transformers": "transformers>=4.25.1", + "transformers": "transformers>=4.41.2", "urllib3": "urllib3<=2.0.0", "black": "black", }
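
Usage note (editorial addition, not part of the patch series): taken together, these commits let `load_lora_into_text_encoder` read per-module ranks for q_proj, k_proj, v_proj and out_proj, and tolerate text-encoder LoRAs whose adapters are missing for some layers. Below is a minimal sketch of the behavior the new test exercises, written against the public pipeline API; the checkpoint name is illustrative only, and the snippet assumes a PEFT/transformers version new enough for `add_adapter` (PATCH 6/6 bumps transformers to >=4.41.2).

    from diffusers import StableDiffusionPipeline
    from peft import LoraConfig
    from peft.utils import get_peft_model_state_dict

    # Illustrative checkpoint; any pipeline whose text encoder uses CLIP-style
    # q/k/v/out_proj attention projections should behave the same way.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Different rank per projection, mirroring the LoraConfig used in the new test.
    config = LoraConfig(
        r=4,
        rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
        lora_alpha=4,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        init_lora_weights=False,
    )
    pipe.text_encoder.add_adapter(config)

    # Build a diffusers-style state dict and deliberately drop one layer's
    # adapters to mimic a "partial" text-encoder LoRA, as the test does.
    state_dict = {
        f"text_encoder.{k}": v
        for k, v in get_peft_model_state_dict(pipe.text_encoder).items()
        if "text_model.encoder.layers.4" not in k
    }

    # Round-trip through the loader code path these patches extend: before
    # PATCH 1/6, only out_proj ranks were read and missing entries were not
    # tolerated by load_lora_into_text_encoder.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(state_dict)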