Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 17 additions & 58 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,66 +25,25 @@ We are thrilled to release **Qwen-Image**, a 20B MMDiT image foundation model th
- Due to heavy traffic, if you'd like to experience our demo online, we also recommend visiting DashScope, WaveSpeed, and LibLib. Please find the links below in the community support.

## Quick Start

1. Make sure your transformers version is >=4.51.3 (required for Qwen2.5-VL support)

2. Install the latest version of diffusers
```
pip install git+https://github.com/huggingface/diffusers
git clone https://github.com/QwenLM/Qwen-Image
cd Qwen-Image
conda create -n Qwen-Image python=3.10
conda activate Qwen-Image
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu128
pip install -r requirements.txt
```

The following contains a code snippet illustrating how to use the model to generate images based on text prompts:

```python
from diffusers import DiffusionPipeline
import torch

model_name = "Qwen/Qwen-Image"

# Load the pipeline: bfloat16 on GPU, float32 when only a CPU is available.
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
    device = "cuda"
else:
    torch_dtype = torch.float32
    device = "cpu"

pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Quality suffix appended to the prompt, keyed by prompt language.
positive_magic = {
    "en": "Ultra HD, 4K, cinematic composition.", # for english prompt
    "zh": "超清,4K,电影级构图" # for chinese prompt
}

# Generate image
prompt = '''A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197".'''

negative_prompt = " " # Recommended if you don't use a negative prompt.


# Generate with different aspect ratios: label -> (width, height) in pixels.
aspect_ratios = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472)
}

width, height = aspect_ratios["16:9"]

image = pipe(
    prompt=prompt + positive_magic["en"],
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_inference_steps=50,
    true_cfg_scale=4.0,
    # Create the generator on the device selected above; a hard-coded "cuda"
    # generator fails on the CPU fallback path.
    generator=torch.Generator(device=device).manual_seed(42)
).images[0]

image.save("example.png")
Download the model weights:
```
python download.py
```
Quick start with the Gradio demo (requires about 39 GB of VRAM):
```
python app.py
```
Quick start in low-VRAM mode (requires about 22 GB of VRAM):
```
python app.py --vram low
```

## Show Cases
Expand Down
123 changes: 123 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import torch
import argparse
import os
import numpy as np
import datetime
import random
from diffusers import DiffusionPipeline
import gradio as gr
from optimum.quanto import freeze, qint8, quantize

# Command-line options controlling the Gradio server, the VRAM mode and an
# optional LoRA checkpoint. Parsed once at import time into `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--server_name",
    type=str,
    default="127.0.0.1",
    help="IP地址,局域网访问改为0.0.0.0",
)
parser.add_argument(
    "--server_port",
    type=int,
    default=7891,
    help="使用端口",
)
parser.add_argument(
    "--share",
    action="store_true",
    help="是否启用gradio共享",
)
parser.add_argument(
    "--mcp_server",
    action="store_true",
    help="是否启用mcp服务",
)
parser.add_argument(
    "--vram",
    type=str,
    default="high",
    choices=["low", "high"],
    help="显存模式",
)
# The literal string "None" (not the None object) means "no LoRA".
parser.add_argument(
    "--lora",
    type=str,
    default="None",
    help="lora模型路径",
)
args = parser.parse_args()


# Choose the compute device and a matching tensor dtype:
#   no CUDA                         -> cpu  / float32
#   CUDA, compute capability >= 8   -> cuda / bfloat16
#   CUDA, older compute capability  -> cuda / float16
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    dtype = torch.float32
elif torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16


# Largest value of a signed 32-bit integer; upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max

# Generated images are written here.
os.makedirs("outputs", exist_ok=True)

# Load the pipeline from the local weights directory.
model_id = "models/Qwen-Image"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)

# "--lora None" (the default string) means no LoRA weights are loaded.
if args.lora != "None":
    pipe.load_lora_weights(args.lora)
    print(f"加载{args.lora}")

# Low-VRAM mode additionally quantizes the transformer to int8 and freezes it;
# both modes then enable VAE tiling and model CPU offload.
if args.vram != "high":
    quantize(pipe.transformer, qint8)
    freeze(pipe.transformer)
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()


def generate(
    prompt,
    negative_prompt,
    width,
    height,
    num_inference_steps,
    true_cfg_scale,
    seed_param,
):
    """Run the pipeline once and save the resulting image under ``outputs/``.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        width: Output image width.
        height: Output image height.
        num_inference_steps: Number of sampling steps.
        true_cfg_scale: Value forwarded as the pipeline's ``true_cfg_scale``.
        seed_param: Requested seed. ``None`` or any negative value selects a
            random seed in ``[0, MAX_SEED]``.

    Returns:
        A ``(path, seed)`` tuple: the path of the saved PNG and the integer
        seed that was actually used.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if seed_param is None or seed_param < 0:
        seed = random.randint(0, MAX_SEED)
    else:
        # The UI's number field can hand over a float (e.g. 42.0), while
        # torch.Generator.manual_seed only accepts an int.
        seed = int(seed_param)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        true_cfg_scale=true_cfg_scale,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    output_path = f"outputs/{timestamp}.png"
    image.save(output_path)
    return output_path, seed


# Gradio UI: generation controls in the left column, the resulting image and
# the seed that was used in the right column.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("""
<div>
<h2 style="font-size: 30px;text-align: center;">Qwen-Image</h2>
</div>
""")
    with gr.TabItem("Qwen-Image文生图"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="提示词", value="超清,4K,电影级构图,")
                negative_prompt = gr.Textbox(label="负面提示词", value="")
                width = gr.Slider(label="宽度(推荐1328x1328、1664x928、1472x1140)", minimum=256, maximum=2656, step=32, value=1328)
                height = gr.Slider(label="高度", minimum=256, maximum=2656, step=32, value=1328)
                num_inference_steps = gr.Slider(label="采样步数", minimum=1, maximum=100, step=1, value=50)
                true_cfg_scale = gr.Slider(label="true cfg scale", minimum=1, maximum=10, step=0.1, value=4.0)
                # precision=0 makes Gradio pass the seed as an int; without it
                # the callback receives a float, which torch's manual_seed
                # rejects.
                seed_param = gr.Number(label="种子,请输入正整数,-1为随机", value=-1, precision=0)
                generate_button = gr.Button("🎬 开始生成", variant='primary')
            with gr.Column():
                image_output = gr.Image(label="生成图片")
                seed_output = gr.Textbox(label="种子")

    # Clicking the button or pressing Enter in either prompt box triggers
    # a generation.
    gr.on(
        triggers=[generate_button.click, prompt.submit, negative_prompt.submit],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            num_inference_steps,
            true_cfg_scale,
            seed_param,
        ],
        outputs=[image_output, seed_output],
    )

if __name__ == "__main__":
    # Server settings come from the parsed command-line arguments; the demo
    # always asks Gradio to open a browser tab.
    launch_options = {
        "server_name": args.server_name,
        "server_port": args.server_port,
        "share": args.share,
        "mcp_server": args.mcp_server,
        "inbrowser": True,
    }
    demo.launch(**launch_options)
32 changes: 32 additions & 0 deletions download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from huggingface_hub import snapshot_download, login
import os

# Hugging Face repo id to download; the local folder is named after the part
# of the id that follows the last "/".
model_name = "Qwen/Qwen-Image"
folder_name = model_name.rsplit("/", 1)[-1]

def download_model(model_name):
    """Download every file of a Hugging Face repo into ``models/<repo name>``.

    Args:
        model_name: Repo id such as ``"Qwen/Qwen-Image"``. The part after the
            last ``/`` becomes the folder name under ``models/``.

    Any exception raised by ``snapshot_download`` propagates to the caller.
    """
    # Derive the target folder from the argument rather than a module-level
    # global, so a different repo id lands in its own directory.
    target_dir = f"models/{model_name.split('/')[-1]}"
    # NOTE(review): `resume_download` is intentionally not passed; recent
    # huggingface_hub releases document it as deprecated, with downloads
    # resuming automatically when possible — confirm against the pinned version.
    snapshot_download(
        repo_id=model_name,
        local_dir=target_dir,
        allow_patterns=[
            "*",
        ],
    )
    print("下载成功!")

# First attempt without credentials. If the error text mentions 401, ask for
# an access token, log in and retry once; any other failure is only reported.
try:
    download_model(model_name)
except Exception as error:
    message = str(error)
    if "401" not in message:
        print(f"下载失败: {message}")
    else:
        # Authentication required.
        print("需要 Hugging Face 登录,创建新令牌(选择'Read'权限即可)")
        token = input("请输入您的 Hugging Face 访问令牌 (https://huggingface.co/settings/tokens): ")

        # Log in and retry the download.
        login(token=token)
        download_model(model_name)
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
git+https://github.com/huggingface/diffusers
transformers
huggingface_hub
gradio
numpy==1.26.4
accelerate
optimum-quanto