
add cached load_lora_weight #524

Merged
marigoold merged 22 commits into main from dev_wy_cached_lora on Jan 26, 2024

Conversation

@marigoold (Collaborator) commented Jan 16, 2024

Add a cache for loaded LoRAs on top of diffusers' load_lora_weights, to avoid the cost of loading the same LoRA from disk again.

TODO:

  • support caching local LoRA files
  • support caching LoRAs downloaded from the hub
  • support unfusing LoRA
  • support custom offload devices
  • profile

In diffusers' original LoRA loading path, the biggest time cost is the parameter initialization of the LoRA modules. That step is not needed for inference, so it is the main optimization point. A minimal sketch of the idea follows.
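The sketch below fuses a LoRA delta directly into an existing Linear weight, so no LoRA module is ever constructed and no time is spent on its (unused) parameter initialization. This is not the PR's exact code; the helper name and the [out, rank] x [rank, in] weight layout are assumptions.

```python
import torch

# Hypothetical helper, not the PR's exact code: fuse W <- W + scale * (up @ down)
# in place, so no LoRA module (and no parameter init) is ever needed.
@torch.no_grad()
def fuse_lora_into_linear(
    linear: torch.nn.Linear,
    w_up: torch.Tensor,    # assumed shape [out_features, rank]
    w_down: torch.Tensor,  # assumed shape [rank, in_features]
    lora_scale: float = 1.0,
):
    dtype, device = linear.weight.dtype, linear.weight.device
    delta = lora_scale * torch.mm(
        w_up.to(device, torch.float32), w_down.to(device, torch.float32)
    )
    linear.weight.data += delta.to(dtype)
```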

This PR adds several ways of using LoRA to examples/text_to_image_sdxl_lora.py:

  1. Use only load_lora_weights. This changes the computation path of the Linear forward, and thus the computation graph. The upside is that no fusing is needed and the LoRA computation is deferred to inference time; the downside is lower inference performance.
  2. Use load_lora_weights plus fuse_lora. The upside is unchanged inference performance; the downside is that loading the LoRA takes some time.
  3. Use load_and_fuse_lora, developed in this PR, which minimizes the cost of loading and switching LoRAs while keeping inference performance intact. The idea is to add a cache that keeps a CPU-offloaded copy of each LoRA, so the next load reads it straight from memory instead of disk. The fuse step is also rewritten by hand to skip the LoRA modules' parameter initialization, which saves most of the time. (See the sketch of the three call patterns after this list.)
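A sketch of the three call patterns. The load_and_fuse_lora import path is inferred from the profile output below; its exact signature is an assumption.

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# 1. load_lora_weights only: LoRA stays on the Linear forward path,
#    no fusing needed, but inference is slower.
pipe.load_lora_weights("path/to/lora.safetensors")

# 2. load_lora_weights + fuse_lora: weights are merged, so inference speed
#    is unchanged, but loading and fusing take extra time.
pipe.load_lora_weights("path/to/lora.safetensors")
pipe.fuse_lora()

# 3. This PR: cached load-and-fuse. A CPU-offloaded copy of the LoRA is kept
#    in memory, so repeated loads skip the disk read, and the hand-written
#    fuse skips LoRA-module parameter initialization.
#    (Call signature is assumed here.)
from onediff.utils.lora import load_and_fuse_lora
load_and_fuse_lora(pipe, "path/to/lora.safetensors")
```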

Inference and loading speed profile results (loading the LoRA dict from memory):

$ python3 text_to_image_sdxl_lora.py
Loading pipeline components...: 100%|████████████████████████████████████| 7/7 [00:01<00:00,  5.57it/s]
[1] Elapsed time: 0.9750442989170551 seconds
100%|██████████████████████████████████████████████████████████████████| 30/30 [01:08<00:00,  2.28s/it]
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  6.26it/s]
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
Loading pipeline components...: 100%|████████████████████████████████████| 7/7 [00:01<00:00,  5.51it/s]
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:39<00:00,  1.32s/it]
[2] Elapsed time: 4.074353616917506 seconds
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  7.18it/s]
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
[3] Elapsed time: 0.7907805619761348 seconds
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  7.16it/s]
100%|██████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00,  7.14it/s]

The times for the three methods:

  1. 0.9750442989170551 seconds
  2. 4.074353616917506 seconds
  3. 0.7907805619761348 seconds

Speed of loading three LoRAs (no inference, from the LoRA dict):

$ python3 /data/home/wangyi/workspace/temp/test.py
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.38it/s]
[1] Elapsed time: 3.8003906158264726 seconds
[2] Elapsed time: 5.7611241028644145 seconds
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT.
[3] Elapsed time: 2.2499090780038387 seconds

The speeds for the three methods:

  1. 3.8003906158264726 seconds
  2. 5.7611241028644145 seconds
  3. 2.2499090780038387 seconds

Profiling the time breakdown shows the cost, from highest to lowest: getattr (a design issue in DualModule), linear fuse, linear unfuse.

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.258    0.258    1.390    1.390 /data/home/wangyi/workspace/diffusers/src/onediff/utils/lora.py:179(load_and_fuse_lora)
11999/7640    0.016    0.000    0.599    0.000 {built-in method builtins.getattr}
7996/4359    0.015    0.000    0.583    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:82(__getattr__)
     2322    0.025    0.000    0.500    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:120(__init__)
      722    0.058    0.000    0.322    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/utils/lora.py:30(linear_fuse_lora)
    11788    0.006    0.000    0.279    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:159(__init__)
    11788    0.016    0.000    0.273    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:21(__init__)
  1063466    0.160    0.000    0.160    0.000 {method 'replace' of 'str' objects}
    11788    0.006    0.000    0.145    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:157(get_mixed_dual_module)
    14110    0.136    0.000    0.145    0.000 /home/wangyi/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:437(__init__)
    11788    0.134    0.000    0.139    0.000 {built-in method builtins.__build_class__}
    23576    0.020    0.000    0.133    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:105(__setattr__)
    25978    0.067    0.000    0.127    0.000 /home/wangyi/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:1617(__setattr__)
      722    0.036    0.000    0.120    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/utils/lora.py:75(linear_unfuse_lora)
 1446/723    0.002    0.000    0.117    0.000 /data/home/wangyi/workspace/diffusers/src/onediff/infer_compiler/with_oneflow_compile.py:303(__getattr__)

marigoold marked this pull request as ready for review on January 22, 2024 05:08
):
    assert isinstance(self, torch.nn.Linear)
    if isinstance(self, DualModule):
        self = self._torch_module
@marigoold (Collaborator Author):

This is risky.

@marigoold (Collaborator Author):

Actually there is no risk here: what is passed in from outside is a DualModule obtained via getattr, which is a temporary object. @strint
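A generic illustration of the point (plain Python, nothing onediff-specific): rebinding a function-local name such as `self` never mutates the caller's object graph, so swapping the temporary wrapper for its inner module is safe.

```python
class Wrapper:  # stand-in for the temporary DualModule
    def __init__(self, inner):
        self._torch_module = inner

def use(obj):
    if isinstance(obj, Wrapper):
        obj = obj._torch_module  # rebinds the local name only
    return obj

inner = object()
w = Wrapper(inner)
assert use(w) is inner
assert w._torch_module is inner  # the wrapper itself is untouched
```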

Comment on lines +48 to +52
self.register_buffer("_lora_up", w_up.to(offload_device))
self.register_buffer(
    "_lora_down", state_dict["lora.down.weight"].to(offload_device)
)
self._lora_scale = lora_scale
@marigoold (Collaborator Author):

A gpu-to-gpu `.to(offload_device)` may not actually make a copy of the parameters.
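This is standard `Tensor.to` behavior: when the target device already matches, `.to()` returns the same tensor object rather than a copy, so the cached buffer could alias the live weights. A quick check (shown on CPU; the same holds for a same-GPU `.to()`):

```python
import torch

t = torch.randn(4, device="cpu")
assert t.to("cpu") is t           # same object, no copy made
cached = t.to("cpu", copy=True)   # force an independent copy for the cache
assert cached is not t
```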

rank = value_dict["lora.down.weight"].shape[0]

if isinstance(attn_processor, LoRACompatibleConv):
    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
@marigoold (Collaborator Author):

What does `low_cpu_mem_usage` do?

@marigoold (Collaborator Author):

What does `low_cpu_mem_usage` do?

If `low_cpu_mem_usage` is True here, torch's default device is switched to meta, so every tensor created by the subsequent initialization is a meta tensor.
Strangely, though, when I changed `low_cpu_mem_usage` to False, the time spent inside diffusers' `load_lora_weights` was about the same.
I then compared cpu uniform against meta uniform: the former calls straight into the C++ uniform implementation, while the latter still goes through a long Python call chain. There is a meta implementation of uniform in C++, but it does not seem to be reached (I built PyTorch myself with a cout added: the cpu uniform printed, the meta uniform did not).
I did not dig any deeper. The conclusion is that whatever this flag is set to has no effect on `linear_fuse_lora` and friends, so it can be removed.
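For reference, a minimal sketch of the meta-device trick that `init_empty_weights` (from accelerate) performs: parameters created under the context live on the meta device, so no memory is allocated and no initializer such as `uniform_` has to produce actual values.

```python
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta

# Roughly equivalent effect with plain torch (>= 2.0):
with torch.device("meta"):
    layer2 = torch.nn.Linear(4096, 4096)
print(layer2.weight.device)  # meta
```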

@marigoold (Collaborator Author) commented:

Move lora.py into the diffusers ext.

@strint (Collaborator) commented Jan 25, 2024:

Mark: remember to update the LoRA README and record the performance results there.

strint mentioned this pull request on Jan 26, 2024
marigoold merged commit f8484d1 into main on Jan 26, 2024
marigoold deleted the dev_wy_cached_lora branch on January 26, 2024 16:17