
Commit f5c1be8

fix janus demo with device dispatch (#2161)
1 parent d558645 commit f5c1be8

File tree: 7 files changed, +107 -50 lines changed


examples/diffusers/janus/demo/app.py

Lines changed: 17 additions & 10 deletions
@@ -1,13 +1,16 @@
-import gradio as gr
-import mindspore
-import mindnlp
 from mindnlp import core
+import gradio as gr
 from transformers import AutoConfig, AutoModelForCausalLM
-from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.models import VLChatProcessor
 from PIL import Image
 
 import numpy as np
 
+device = 'cpu'
+if core.npu.is_available():
+    device = 'npu'
+elif core.cuda.is_available():
+    device = 'cuda'
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-1.3B"
@@ -16,7 +19,8 @@
 language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
-                                              trust_remote_code=True, ms_dtype=mindspore.float16)
+                                              trust_remote_code=True)
+vl_gpt = vl_gpt.to(core.bfloat16).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
@@ -26,9 +30,12 @@
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
     # Clear CUDA cache before generating
+    core.cuda.empty_cache()
+
     # set seed
-    mindspore.manual_seed(seed)
+    core.manual_seed(seed)
     np.random.seed(seed)
+    core.cuda.manual_seed(seed)
 
     conversation = [
         {
@@ -42,9 +49,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(core.get_default_device(), dtype=mindspore.float16)
-
+    ).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
 
+    print(prepare_inputs)
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     outputs = vl_gpt.language_model.generate(
@@ -75,13 +82,13 @@ def generate(input_ids,
     # Clear CUDA cache before generating
     core.cuda.empty_cache()
 
-    tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int)
+    tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int).to(device)
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=core.int)
+    generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=core.int).to(device)
 
     pkv = None
     for i in range(image_token_num_per_image):
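
The recurring pattern in this commit is the device probe at the top of each demo script. As a minimal standalone sketch (assuming only the torch-style `core.npu` / `core.cuda` availability probes that the diff itself uses; the helper name is illustrative):

from mindnlp import core

def pick_device() -> str:
    # Prefer Ascend NPU, then CUDA, then fall back to CPU.
    if core.npu.is_available():
        return 'npu'
    if core.cuda.is_available():
        return 'cuda'
    return 'cpu'

device = pick_device()

Model weights and processed inputs must then land on the same device with a compatible dtype, which is why the diff pairs `vl_gpt.to(core.bfloat16).to(device)` with `.to(device, dtype=...)` on `prepare_inputs`.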

examples/diffusers/janus/demo/app_janusflow.py

Lines changed: 40 additions & 21 deletions
@@ -1,29 +1,40 @@
-import gradio as gr
-import mindspore
-import mindnlp
 from mindnlp import core
+import gradio as gr
 from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
 from PIL import Image
+from transformers import DynamicCache
 from diffusers.models import AutoencoderKL
 import numpy as np
 
+device = 'cpu'
+if core.npu.is_available():
+    device = 'npu'
+elif core.cuda.is_available():
+    device = 'cuda'
+
 # Load model and processor
 model_path = "deepseek-ai/JanusFlow-1.3B"
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
-vl_gpt = MultiModalityCausalLM.from_pretrained(model_path, ms_dtype=mindspore.float16)
-vl_gpt = vl_gpt.eval()
+vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
+vl_gpt = vl_gpt.to(core.bfloat16).to(device).eval()
 
 # remember to use bfloat16 dtype, this vae doesn't work with fp16
-vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", ms_dtype=mindspore.float16)
-vae = vae.eval()
+vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+vae = vae.to(core.bfloat16).to(device).eval()
 
+# Multimodal Understanding function
+@core.inference_mode()
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
+    # Clear CUDA cache before generating
+    core.cuda.empty_cache()
+
     # set seed
-    mindspore.manual_seed(seed)
+    core.manual_seed(seed)
     np.random.seed(seed)
+    core.cuda.manual_seed(seed)
 
     conversation = [
         {
@@ -37,9 +48,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(core.get_default_device(), mindspore.float16)
-
-
+    ).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
+
+
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     outputs = vl_gpt.language_model.generate(
@@ -60,13 +71,14 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     return answer
 
 
+@core.inference_mode()
 def generate(
     input_ids,
     cfg_weight: float = 2.0,
     num_inference_steps: int = 30
 ):
     # we generate 5 images at a time, *2 for CFG
-    tokens = core.stack([input_ids] * 10)
+    tokens = core.stack([input_ids] * 10).cuda()
     tokens[5:, 1:] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     print(inputs_embeds.shape)
@@ -76,10 +88,10 @@ def generate(
 
     # generate with rectified flow ode
     # step 1: encode with vision_gen_enc
-    z = core.randn((5, 4, 48, 48), dtype=mindspore.float16)
+    z = core.randn((5, 4, 48, 48), dtype=core.bfloat16).cuda()
 
     dt = 1.0 / num_inference_steps
-    dt = core.zeros_like(z).to(mindspore.float16) + dt
+    dt = core.zeros_like(z).cuda().to(core.bfloat16) + dt
 
     # step 2: run ode
     attention_mask = core.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
@@ -103,18 +115,21 @@
                                                   use_cache=True,
                                                   attention_mask=attention_mask,
                                                   past_key_values=None)
-            past_key_values = []
-            for kv_cache in past_key_values:
-                k, v = kv_cache[0], kv_cache[1]
-                past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
-            past_key_values = tuple(past_key_values)
+            past_key_values = DynamicCache.from_legacy_cache(outputs.past_key_values)
+
         else:
            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                  use_cache=True,
                                                  attention_mask=attention_mask,
                                                  past_key_values=past_key_values)
+            past_key_values = []
+            for kv_cache in outputs.past_key_values:
+                k, v = kv_cache[0], kv_cache[1]
+                past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
         hidden_states = outputs.last_hidden_state
-
+
         # transform hidden_states back to v
         hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
         hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
@@ -141,13 +156,17 @@ def unpack(dec, width, height, parallel_size=5):
     return visual_img
 
 
+@core.inference_mode()
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    num_inference_steps=30):
+    # Clear CUDA cache and avoid tracking gradients
+    core.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
-        mindspore.manual_seed(seed)
+        core.manual_seed(seed)
+        core.cuda.manual_seed(seed)
         np.random.seed(seed)
 
     with core.no_grad():
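
Note the cache fix in the ode-stepping hunk: the removed block reset `past_key_values` to an empty list and then iterated over that same empty list, so the prompt cache was never actually trimmed. The replacement iterates `outputs.past_key_values` and rewraps the result as a transformers `DynamicCache`. A standalone sketch of the trimming step (assuming the legacy tuple-of-(key, value) cache layout the diff manipulates, with tensors shaped (batch, heads, seq_len, head_dim); the helper name is illustrative):

from transformers import DynamicCache

def trim_prompt_cache(legacy_cache, prompt_len):
    # Keep only the prompt portion of each layer's KV entries.
    trimmed = []
    for k, v in legacy_cache:
        trimmed.append((k[:, :, :prompt_len, :], v[:, :, :prompt_len, :]))
    return DynamicCache.from_legacy_cache(tuple(trimmed))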

examples/diffusers/janus/demo/app_januspro.py

Lines changed: 37 additions & 16 deletions
@@ -1,31 +1,43 @@
-import gradio as gr
-import mindnlp
-import mindspore
 from mindnlp import core
+import gradio as gr
 from transformers import AutoConfig, AutoModelForCausalLM
-from janus.models import MultiModalityCausalLM, VLChatProcessor
-from janus.utils.io import load_pil_images
+from janus.models import VLChatProcessor
 from PIL import Image
 
 import numpy as np
+# import spaces # Import spaces for ZeroGPU compatibility
+
+device = 'cpu'
+if core.npu.is_available():
+    device = 'npu'
+elif core.cuda.is_available():
+    device = 'cuda'
+
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-    language_config=language_config,
-    trust_remote_code=True, ms_dtype=mindspore.float16)
+                                              language_config=language_config,
+                                              trust_remote_code=True)
+vl_gpt = vl_gpt.to(core.bfloat16).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
+@core.inference_mode()
+# @spaces.GPU(duration=120)
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
+    # Clear CUDA cache before generating
+    core.cuda.empty_cache()
+
     # set seed
-    mindspore.manual_seed(seed)
+    core.manual_seed(seed)
     np.random.seed(seed)
+    core.cuda.manual_seed(seed)
 
     conversation = [
         {
@@ -39,8 +51,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(core.get_default_device(), mindspore.float16)
-
+    ).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
+
+
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     outputs = vl_gpt.language_model.generate(
@@ -68,14 +81,16 @@ def generate(input_ids,
              cfg_weight: float = 5,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
+    # Clear CUDA cache before generating
+    core.cuda.empty_cache()
 
-    tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=mindspore.int32)
+    tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int).to(device)
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=mindspore.int32)
+    generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=core.int).to(device)
 
     pkv = None
     for i in range(image_token_num_per_image):
@@ -99,13 +114,13 @@ def generate(input_ids,
 
 
 
-    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=mindspore.int32),
+    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=core.int),
                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
 
-    return generated_tokens.to(dtype=mindspore.int32), patches
+    return generated_tokens.to(dtype=core.int), patches
 
 def unpack(dec, width, height, parallel_size=5):
-    dec = dec.to(mindspore.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = dec.to(core.float32).cpu().numpy().transpose(0, 2, 3, 1)
     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
 
     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
@@ -114,13 +129,19 @@ def unpack(dec, width, height, parallel_size=5):
     return visual_img
 
 
+
+@core.inference_mode()
+# @spaces.GPU(duration=120) # Specify a duration to avoid timeout
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    t2i_temperature=1.0):
+    # Clear CUDA cache and avoid tracking gradients
+    core.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
-        mindspore.manual_seed(seed)
+        core.manual_seed(seed)
+        core.cuda.manual_seed(seed)
         np.random.seed(seed)
     width = 384
     height = 384
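
For context, `generate()` builds its batch for classifier-free guidance by interleaving conditional and unconditional rows, which is why every tensor is `parallel_size * 2` tall and must now be created on (or moved to) the target device. A sketch of that layout, lifted from the diff (the helper name is illustrative):

from mindnlp import core

def build_cfg_tokens(input_ids, pad_id, parallel_size, device):
    # Even rows keep the full prompt (conditional); odd rows pad the
    # prompt body, keeping only BOS and the final token (unconditional).
    tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int).to(device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = pad_id
    return tokens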

examples/diffusers/janus/generation_inference.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
     model_path, language_config=language_config, trust_remote_code=True, ms_dtype=torch.float16
 )
-vl_gpt = vl_gpt.eval()
+vl_gpt = vl_gpt.eval().cuda()
 
 conversation = [
     {

examples/diffusers/janus/inference.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
     model_path, language_config=language_config, trust_remote_code=True, torch_dtype=torch.float16
 )
-vl_gpt = vl_gpt.eval()
+vl_gpt = vl_gpt.eval().npu()
 
 conversation = [
     {
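
These two scripts pin the model to a fixed backend (`.cuda()` in generation_inference.py, `.npu()` in inference.py). A device-agnostic variant in the style of the demo apps could look like the following sketch (not part of the commit; uses only the `.cuda()`/`.npu()` module methods the diff itself calls):

from mindnlp import core

def to_best_device(model):
    # Same preference order as the demo apps: NPU, then CUDA, then CPU.
    if core.npu.is_available():
        return model.npu()
    if core.cuda.is_available():
        return model.cuda()
    return model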

mindnlp/core/_prims/numpy.py

Lines changed: 9 additions & 0 deletions
@@ -103,6 +103,8 @@ def dyn_shape(self):
 __all__.append('dyn_shape')
 
 def cast(input, dtype):
+    if input.dtype == dtype:
+        return input
     out = input.asnumpy().astype(core.dtype2np[dtype])
     return core.Tensor.from_numpy(out)
 
@@ -150,6 +152,13 @@ def bitwise_and_scalar(input, other):
 
 __all__.append('bitwise_and_scalar')
 
+
+def bitwise_or_tensor(input, other):
+    out = np.bitwise_or(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('bitwise_or_tensor')
+
 def right_shift(input, other):
     out = np.right_shift(input.numpy(), other)
     return core.Tensor.from_numpy(out)
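
Two small prim changes here: `cast` gains a fast path that returns the input unchanged when the target dtype already matches (avoiding a pointless asnumpy/astype/from_numpy round-trip and buffer copy), and a `bitwise_or_tensor` prim is added. The new prim is a thin wrapper over NumPy; its semantics at the NumPy level (core.Tensor wrapping elided):

import numpy as np

a = np.array([0b1100, 0b0011], dtype=np.int32)
b = np.array([0b1010, 0b0101], dtype=np.int32)
print(np.bitwise_or(a, b))  # [14  7], i.e. 0b1110 and 0b0111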

mindnlp/core/_tensor.py

Lines changed: 2 additions & 1 deletion
@@ -920,7 +920,8 @@ def data(self, new_value):
         if isinstance(self, StubTensor) and isinstance(new_value, StubTensor):
             self.stub = new_value.stub
         else:
-            if self.device.type == 'cpu' and new_value.device.type == 'cpu' and self.shape == new_value.shape:
+            if self.device.type == 'cpu' and new_value.device.type == 'cpu' \
+                    and self.shape == new_value.shape and self.dtype == new_value.dtype:
                 src_ct = ctypes.c_void_p(new_value.data_ptr())
                 dst_ct = ctypes.c_void_p(self.data_ptr())
                 ctypes.memmove(dst_ct, src_ct, self.nbytes)
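
The added dtype check matters because equal shapes do not imply equal byte sizes: a raw `memmove` of `self.nbytes` between buffers of different dtypes would over- or under-copy and reinterpret raw bits. A quick illustration with NumPy buffers:

import numpy as np

dst = np.zeros(4, dtype=np.float32)  # 16 bytes
src = np.ones(4, dtype=np.float16)   # 8 bytes
assert dst.nbytes != src.nbytes
# memmove(dst, src, dst.nbytes) would read 16 bytes from an 8-byte buffer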
