[Hardware] Support platforms and plugin system #774
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,7 +81,7 @@ | |
| from vllm_omni.entrypoints.omni import Omni | ||
| from vllm_omni.inputs.data import OmniDiffusionSamplingParams | ||
| from vllm_omni.outputs import OmniRequestOutput | ||
| - from vllm_omni.utils.platform_utils import detect_device_type, is_npu | ||
| + from vllm_omni.platforms import current_omni_platform | ||
|
|
||
|
|
||
| def parse_args() -> argparse.Namespace: | ||
|
|
@@ -280,6 +280,16 @@ def parse_args() -> argparse.Namespace: | |
| action="store_true", | ||
| help="Disable torch.compile and force eager execution.", | ||
| ) | ||
| + parser.add_argument( | ||
| + "--vae_use_slicing", | ||
| + action="store_true", | ||
| + help="Enable VAE slicing for memory optimization.", | ||
| + ) | ||
| + parser.add_argument( | ||
| + "--vae_use_tiling", | ||
| + action="store_true", | ||
| + help="Enable VAE tiling for memory optimization.", | ||
| + ) | ||
| parser.add_argument( | ||
| "--enable-cpu-offload", | ||
| action="store_true", | ||
|
|
@@ -306,12 +316,8 @@ def main(): | |
| else: | ||
| input_image = input_images | ||
|
|
||
| - device = detect_device_type() | ||
| - generator = torch.Generator(device=device).manual_seed(args.seed) | ||
| + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) | ||
|
Collaborator
I understand the usage. One suggestion for discussion: this API is somewhat torch-style, so could we change it to vllm_omni.device_type for simplicity? And so on for the similar accessors.
Collaborator
Author
Sorry, I'm not sure I fully understood your point. Do |
||
|
|
||
| - # Enable VAE memory optimizations on NPU | ||
| - vae_use_slicing = is_npu() | ||
| - vae_use_tiling = is_npu() | ||
| parallel_config = DiffusionParallelConfig( | ||
| ulysses_degree=args.ulysses_degree, | ||
| ring_degree=args.ring_degree, | ||
|
|
@@ -344,8 +350,8 @@ def main(): | |
| # Initialize Omni with appropriate pipeline | ||
| omni = Omni( | ||
| model=args.model, | ||
| - vae_use_slicing=vae_use_slicing, | ||
| - vae_use_tiling=vae_use_tiling, | ||
| + vae_use_slicing=args.vae_use_slicing, | ||
| + vae_use_tiling=args.vae_use_tiling, | ||
| cache_backend=args.cache_backend, | ||
| cache_config=cache_config, | ||
| parallel_config=parallel_config, | ||
|
|
||
Here I moved the unit tests (UT) into the GPU queue. They don't consume GPU resources, so I think that's okay. The reason is that, after introducing the platform, torch._C ops are imported when it is initialized. If the tests stay in the CPU queue, they raise the error below:
See more details: https://buildkite.com/vllm/vllm-omni/builds/1913/steps/canvas?sid=019be051-4a0c-43ef-8b20-47464b363092
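
For reviewers skimming the example-script change: the diff replaces the NPU-detected defaults (`is_npu()`) with two opt-in `store_true` flags, so both VAE optimizations are now off unless requested on the command line. Below is a minimal, standalone sketch of that behavior using only `argparse`. The helper names `build_parser` and `vae_kwargs` are hypothetical and do not appear in the PR; only the two flag names and help strings are taken from the diff.

```python
from __future__ import annotations

import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for the two flags added to the script's parse_args()."""
    parser = argparse.ArgumentParser(description="VAE memory flags (illustrative only)")
    parser.add_argument(
        "--vae_use_slicing",
        action="store_true",  # defaults to False when the flag is absent
        help="Enable VAE slicing for memory optimization.",
    )
    parser.add_argument(
        "--vae_use_tiling",
        action="store_true",
        help="Enable VAE tiling for memory optimization.",
    )
    return parser


def vae_kwargs(argv: list[str]) -> dict[str, bool]:
    """Return the keyword arguments that the script would forward to Omni(...)."""
    args = build_parser().parse_args(argv)
    return {
        "vae_use_slicing": args.vae_use_slicing,
        "vae_use_tiling": args.vae_use_tiling,
    }


# No flags: both optimizations stay disabled, regardless of the device.
print(vae_kwargs([]))
# Explicit opt-in, e.g. for a memory-constrained run.
print(vae_kwargs(["--vae_use_slicing", "--vae_use_tiling"]))
```

One consequence worth noting: users who previously relied on slicing and tiling being switched on automatically on NPU would, under this sketch, need to pass both flags explicitly.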