- import gradio as gr
- import mindspore
- import mindnlp
from mindnlp import core
+ import gradio as gr
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
+ from transformers import DynamicCache
from diffusers.models import AutoencoderKL
import numpy as np

+ device = 'cpu'
+ if core.npu.is_available():
+     device = 'npu'
+ elif core.cuda.is_available():
+     device = 'cuda'
+
# Load model and processor
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

- vl_gpt = MultiModalityCausalLM.from_pretrained(model_path, ms_dtype=mindspore.float16)
- vl_gpt = vl_gpt.eval()
+ vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
+ vl_gpt = vl_gpt.to(core.bfloat16).to(device).eval()

# remember to use bfloat16 dtype, this vae doesn't work with fp16
- vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", ms_dtype=mindspore.float16)
- vae = vae.eval()
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+ vae = vae.to(core.bfloat16).to(device).eval()

+ # Multimodal Understanding function
+ @core.inference_mode()
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
+     # Clear CUDA cache before generating
+     core.cuda.empty_cache()
+
    # set seed
-     mindspore.manual_seed(seed)
+     core.manual_seed(seed)
    np.random.seed(seed)
+     core.cuda.manual_seed(seed)

    conversation = [
        {
@@ -37,9 +48,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
-     ).to(core.get_default_device(), mindspore.float16)
-
-
+     ).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
+
+
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
@@ -60,13 +71,14 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
    return answer


+ @core.inference_mode()
def generate(
    input_ids,
    cfg_weight: float = 2.0,
    num_inference_steps: int = 30
):
    # we generate 5 images at a time, *2 for CFG
-     tokens = core.stack([input_ids] * 10)
+     tokens = core.stack([input_ids] * 10).cuda()
    tokens[5:, 1:] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    print(inputs_embeds.shape)
@@ -76,10 +88,10 @@ def generate(

    # generate with rectified flow ode
    # step 1: encode with vision_gen_enc
-     z = core.randn((5, 4, 48, 48), dtype=mindspore.float16)
+     z = core.randn((5, 4, 48, 48), dtype=core.bfloat16).cuda()

    dt = 1.0 / num_inference_steps
-     dt = core.zeros_like(z).to(mindspore.float16) + dt
+     dt = core.zeros_like(z).cuda().to(core.bfloat16) + dt

    # step 2: run ode
    attention_mask = core.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
@@ -103,18 +115,21 @@ def generate(
                                                  use_cache=True,
                                                  attention_mask=attention_mask,
                                                  past_key_values=None)
-             past_key_values = []
-             for kv_cache in past_key_values:
-                 k, v = kv_cache[0], kv_cache[1]
-                 past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
-             past_key_values = tuple(past_key_values)
+             past_key_values = DynamicCache.from_legacy_cache(outputs.past_key_values)
+
        else:
            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                  use_cache=True,
                                                  attention_mask=attention_mask,
                                                  past_key_values=past_key_values)
+             past_key_values = []
+             for kv_cache in outputs.past_key_values:
+                 k, v = kv_cache[0], kv_cache[1]
+                 past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
+             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
        hidden_states = outputs.last_hidden_state
-
+
        # transform hidden_states back to v
        hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
        hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
@@ -141,13 +156,17 @@ def unpack(dec, width, height, parallel_size=5):
    return visual_img


+ @core.inference_mode()
def generate_image(prompt,
                   seed=None,
                   guidance=5,
                   num_inference_steps=30):
+     # Clear CUDA cache and avoid tracking gradients
+     core.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
-         mindspore.manual_seed(seed)
+         core.manual_seed(seed)
+         core.cuda.manual_seed(seed)
        np.random.seed(seed)

    with core.no_grad():
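
For reference, a minimal standalone sketch (not part of this commit) of the KV-cache handling the new `+` lines rely on: the legacy tuple-of-tuples cache returned by the language model is truncated to the prompt length and wrapped in a transformers `DynamicCache`, so the same prompt prefix can be reused at every rectified-flow step. It assumes `mindnlp.core` tensors behave like torch tensors and are accepted by `DynamicCache.from_legacy_cache`; the helper name and shapes are illustrative only.

```python
# Illustrative sketch only -- not part of this commit.
# Assumes mindnlp.core tensors behave like torch tensors and that
# transformers.DynamicCache.from_legacy_cache accepts them.
from mindnlp import core
from transformers import DynamicCache

def prompt_only_cache(legacy_cache, prompt_len):
    """Keep the first `prompt_len` positions of each (key, value) pair and
    wrap the result in a DynamicCache for reuse on later ODE steps."""
    trimmed = []
    for k, v in legacy_cache:
        # k, v: [batch, num_heads, seq_len, head_dim]
        trimmed.append((k[:, :, :prompt_len, :], v[:, :, :prompt_len, :]))
    return DynamicCache.from_legacy_cache(tuple(trimmed))

# Toy example: 2 layers, batch 10 (5 images * 2 for CFG), 4 heads, head_dim 8.
legacy = tuple(
    (core.randn((10, 4, 16, 8)), core.randn((10, 4, 16, 8))) for _ in range(2)
)
cache = prompt_only_cache(legacy, prompt_len=12)
```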