
Commit 5b8996f

Conv2D direct support (#744)

* Conv2DDirect for VAE stage
* Enable only for Vulkan, reduced duplicated code
* Cmake option to use conv2d direct
* conv2d direct always on for opencl
* conv direct as a flag
* fix merge typo
* Align conv2d behavior to flash attention's
* fix readme
* add conv2d direct for controlnet
* add conv2d direct for esrgan
* clean code, use enable_conv2d_direct/get_all_blocks
* format code

Co-authored-by: leejet <[email protected]>

1 parent f7f05fb commit 5b8996f
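At a glance: the commit adds two user-facing flags, `--diffusion-conv-direct` (applied to the UNet and, when one is loaded, the ControlNet) and `--vae-conv-direct` (applied to the VAE or TAESD), which switch every Conv2d block from the im2col-based `ggml_conv_2d` path to `ggml_conv_2d_direct`; `new_upscaler_ctx()` likewise gains a `bool direct` parameter so the ESRGAN upscaler can do the same.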

File tree: 11 files changed (+151 / -7 lines)

README.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -341,6 +341,10 @@ arguments:
   --diffusion-fa                     use flash attention in the diffusion model (for low vram)
                                      Might lower quality, since it implies converting k and v to f16.
                                      This might crash if it is not supported by the backend.
+  --diffusion-conv-direct            use Conv2d direct in the diffusion model
+                                     This might crash if it is not supported by the backend.
+  --vae-conv-direct                  use Conv2d direct in the vae model (should improve the performance)
+                                     This might crash if it is not supported by the backend.
   --control-net-cpu                  keep controlnet in cpu (for low vram)
   --canny                            apply canny preprocessor (edge detection)
   --color                            colors the logging tags according to level
```
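Taken together, a typical invocation that opts into both paths would look like `sd -m model.safetensors -p "a lovely cat" --diffusion-conv-direct --vae-conv-direct` (model path and prompt are placeholders). As the help text warns, either flag can crash on backends that do not implement direct Conv2d.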

control.hpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -323,6 +323,17 @@ struct ControlNet : public GGMLRunner {
         control_net.init(params_ctx, tensor_types, "");
     }
 
+    void enable_conv2d_direct() {
+        std::vector<GGMLBlock*> blocks;
+        control_net.get_all_blocks(blocks);
+        for (auto block : blocks) {
+            if (block->get_desc() == "Conv2d") {
+                auto conv_block = (Conv2d*)block;
+                conv_block->enable_direct();
+            }
+        }
+    }
+
     ~ControlNet() {
         free_control_ctx();
     }
```

esrgan.hpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -147,6 +147,17 @@ struct ESRGAN : public GGMLRunner {
         rrdb_net.init(params_ctx, tensor_types, "");
     }
 
+    void enable_conv2d_direct() {
+        std::vector<GGMLBlock*> blocks;
+        rrdb_net.get_all_blocks(blocks);
+        for (auto block : blocks) {
+            if (block->get_desc() == "Conv2d") {
+                auto conv_block = (Conv2d*)block;
+                conv_block->enable_direct();
+            }
+        }
+    }
+
     std::string get_desc() {
         return "esrgan";
     }
```

examples/cli/main.cpp

Lines changed: 14 additions & 1 deletion
```diff
@@ -97,6 +97,8 @@ struct SDParams {
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
     bool diffusion_flash_attn = false;
+    bool diffusion_conv_direct = false;
+    bool vae_conv_direct = false;
     bool canny_preprocess = false;
     bool color = false;
     int upscale_repeats = 1;
@@ -142,6 +144,8 @@ void print_params(SDParams params) {
     printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
     printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
     printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
+    printf(" diffusion Conv2d direct:%s\n", params.diffusion_conv_direct ? "true" : "false");
+    printf(" vae Conv2d direct:%s\n", params.vae_conv_direct ? "true" : "false");
     printf(" strength(control): %.2f\n", params.control_strength);
     printf(" prompt: %s\n", params.prompt.c_str());
     printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -232,6 +236,10 @@ void print_usage(int argc, const char* argv[]) {
     printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
     printf(" Might lower quality, since it implies converting k and v to f16.\n");
     printf(" This might crash if it is not supported by the backend.\n");
+    printf(" --diffusion-conv-direct use Conv2d direct in the diffusion model\n");
+    printf(" This might crash if it is not supported by the backend.\n");
+    printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)\n");
+    printf(" This might crash if it is not supported by the backend.\n");
     printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
     printf(" --canny apply canny preprocessor (edge detection)\n");
     printf(" --color colors the logging tags according to level\n");
@@ -422,6 +430,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--clip-on-cpu", "", true, &params.clip_on_cpu},
         {"", "--vae-on-cpu", "", true, &params.vae_on_cpu},
         {"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
+        {"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
+        {"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
         {"", "--canny", "", true, &params.canny_preprocess},
         {"-v", "--verbos", "", true, &params.verbose},
         {"", "--color", "", true, &params.color},
@@ -901,6 +911,8 @@ int main(int argc, const char* argv[]) {
         params.control_net_cpu,
         params.vae_on_cpu,
         params.diffusion_flash_attn,
+        params.diffusion_conv_direct,
+        params.vae_conv_direct,
         params.chroma_use_dit_mask,
         params.chroma_use_t5_mask,
         params.chroma_t5_mask_pad,
@@ -1012,7 +1024,8 @@ int main(int argc, const char* argv[]) {
     int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
     if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
         upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
-                                                        params.n_threads);
+                                                        params.n_threads,
+                                                        params.diffusion_conv_direct);
 
         if (upscaler_ctx == NULL) {
             printf("new_upscaler_ctx failed\n");
```
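The two new flags slot into the existing table-driven option parsing. A rough, hypothetical sketch of how such a bool-flag table can be consumed follows; the field meaning {short, long, value-name, value-to-store, target} is inferred from the entries above, not taken from the repository:

```cpp
#include <cstring>
#include <vector>

// Hypothetical entry mirroring the diff's
// {"", "--vae-conv-direct", "", true, &params.vae_conv_direct}.
struct BoolFlag {
    const char* short_name;  // e.g. "-v", or "" if none
    const char* long_name;   // e.g. "--vae-conv-direct"
    const char* value_name;  // unused for plain switches
    bool value;              // what to store when the flag is present
    bool* target;            // the SDParams field to set
};

// Minimal parse loop: any matching argv entry stores `value` into `target`.
static void parse_bool_flags(int argc, const char** argv, const std::vector<BoolFlag>& flags) {
    for (int i = 1; i < argc; i++) {
        for (const auto& f : flags) {
            if (std::strcmp(argv[i], f.long_name) == 0 ||
                (f.short_name[0] != '\0' && std::strcmp(argv[i], f.short_name) == 0)) {
                *f.target = f.value;
            }
        }
    }
}
```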

ggml_extend.hpp

Lines changed: 46 additions & 1 deletion
```diff
@@ -708,6 +708,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
     return x;
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
+                                                             struct ggml_tensor* x,
+                                                             struct ggml_tensor* w,
+                                                             struct ggml_tensor* b,
+                                                             int s0 = 1,
+                                                             int s1 = 1,
+                                                             int p0 = 0,
+                                                             int p1 = 0,
+                                                             int d0 = 1,
+                                                             int d1 = 1) {
+    x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+    if (b != NULL) {
+        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
+        // b = ggml_repeat(ctx, b, x);
+        x = ggml_add(ctx, x, b);
+    }
+    return x;
+}
+
 // w: [OC,IC, KD, 1 * 1]
 // x: [N, IC, IH, IW]
 // b: [OC,]
@@ -1377,6 +1396,19 @@ class GGMLBlock {
             tensors[prefix + pair.first] = pair.second;
         }
     }
+
+    virtual std::string get_desc() {
+        return "GGMLBlock";
+    }
+
+    void get_all_blocks(std::vector<GGMLBlock*>& result) {
+        result.push_back(this);
+        for (auto& block_iter : blocks) {
+            if (block_iter.second) {
+                block_iter.second->get_all_blocks(result);
+            }
+        }
+    }
 };
 
 class UnaryBlock : public GGMLBlock {
@@ -1466,6 +1498,7 @@ class Conv2d : public UnaryBlock {
     std::pair<int, int> padding;
     std::pair<int, int> dilation;
     bool bias;
+    bool direct = false;
 
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F16;
@@ -1492,13 +1525,25 @@ class Conv2d : public UnaryBlock {
           dilation(dilation),
           bias(bias) {}
 
+    void enable_direct() {
+        direct = true;
+    }
+
+    std::string get_desc() {
+        return "Conv2d";
+    }
+
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         struct ggml_tensor* b = NULL;
         if (bias) {
             b = params["bias"];
         }
-        return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+        if (direct) {
+            return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+        } else {
+            return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
+        }
     }
 };
```
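Two details here are easy to miss. First, in ggml the conv output's dimensions are ordered [W, H, OC, N], so the [OC] bias is reshaped to [1, 1, OC, 1] and `ggml_add` then broadcasts it across width, height, and batch (the commented-out `ggml_repeat` is the explicit form of the same operation). Second, `get_desc()` doubles as a lightweight type tag: because the block tree stores children behind `GGMLBlock` pointers, the new `get_all_blocks()` plus a string compare is what lets each runner find and flip its Conv2d blocks. A minimal, self-contained sketch of that pattern (names mirror the diff; the map-of-children layout is inferred from the traversal code, not copied from the repository):

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

struct GGMLBlockSketch {
    // named children, as implied by get_all_blocks() iterating pairs
    std::map<std::string, std::shared_ptr<GGMLBlockSketch>> blocks;

    virtual ~GGMLBlockSketch() = default;

    // default tag; Conv2d overrides this, giving a cheap manual RTTI
    virtual std::string get_desc() { return "GGMLBlock"; }

    // depth-first flatten of the block tree, including `this`
    void get_all_blocks(std::vector<GGMLBlockSketch*>& result) {
        result.push_back(this);
        for (auto& child : blocks) {
            if (child.second) {
                child.second->get_all_blocks(result);
            }
        }
    }
};

struct Conv2dSketch : GGMLBlockSketch {
    bool direct = false;
    std::string get_desc() override { return "Conv2d"; }
    void enable_direct() { direct = true; }
};

// What every enable_conv2d_direct() in this commit boils down to:
// flatten, filter by tag, downcast, flip the flag.
void enable_conv2d_direct(GGMLBlockSketch& root) {
    std::vector<GGMLBlockSketch*> all;
    root.get_all_blocks(all);
    for (auto* b : all) {
        if (b->get_desc() == "Conv2d") {
            static_cast<Conv2dSketch*>(b)->enable_direct();
        }
    }
}
```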

stable-diffusion.cpp

Lines changed: 16 additions & 0 deletions
```diff
@@ -374,6 +374,10 @@ class StableDiffusionGGML {
                                                          model_loader.tensor_storages_types,
                                                          version,
                                                          sd_ctx_params->diffusion_flash_attn);
+            if (sd_ctx_params->diffusion_conv_direct) {
+                LOG_INFO("Using Conv2d direct in the diffusion model");
+                std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
+            }
         }
 
         cond_stage_model->alloc_params_buffer();
@@ -395,6 +399,10 @@ class StableDiffusionGGML {
                 vae_decode_only,
                 false,
                 version);
+            if (sd_ctx_params->vae_conv_direct) {
+                LOG_INFO("Using Conv2d direct in the vae model");
+                first_stage_model->enable_conv2d_direct();
+            }
             first_stage_model->alloc_params_buffer();
             first_stage_model->get_param_tensors(tensors, "first_stage_model");
         } else {
@@ -403,6 +411,10 @@ class StableDiffusionGGML {
                 "decoder.layers",
                 vae_decode_only,
                 version);
+            if (sd_ctx_params->vae_conv_direct) {
+                LOG_INFO("Using Conv2d direct in the tae model");
+                tae_first_stage->enable_conv2d_direct();
+            }
         }
         // first_stage_model->get_param_tensors(tensors, "first_stage_model.");
 
@@ -415,6 +427,10 @@ class StableDiffusionGGML {
                 controlnet_backend = backend;
             }
             control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
+            if (sd_ctx_params->diffusion_conv_direct) {
+                LOG_INFO("Using Conv2d direct in the control net");
+                control_net->enable_conv2d_direct();
+            }
         }
 
         if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
```

stable-diffusion.h

Lines changed: 4 additions & 1 deletion
```diff
@@ -134,6 +134,8 @@ typedef struct {
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
     bool diffusion_flash_attn;
+    bool diffusion_conv_direct;
+    bool vae_conv_direct;
     bool chroma_use_dit_mask;
     bool chroma_use_t5_mask;
     int chroma_t5_mask_pad;
@@ -236,7 +238,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
 typedef struct upscaler_ctx_t upscaler_ctx_t;
 
 SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
-                                        int n_threads);
+                                        int n_threads,
+                                        bool direct);
 SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
 
 SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
```
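A caller updated for the new signature might look like the following sketch. The model path and pixel data are placeholders, and `sd_image_t`'s field order (width, height, channel, data) plus the caller-frees contract for the output buffer are assumptions about the surrounding API rather than part of this diff:

```cpp
#include <stdint.h>
#include <stdlib.h>

#include "stable-diffusion.h"

int main(void) {
    // Tiny placeholder RGB image; a real caller would load actual pixels.
    uint8_t* pixels = (uint8_t*)calloc(8 * 8 * 3, 1);
    sd_image_t input = {8, 8, 3, pixels};

    // The third argument is the Conv2d-direct toggle added by this commit.
    upscaler_ctx_t* ctx = new_upscaler_ctx("realesrgan_x4plus.pth", /*n_threads=*/4, /*direct=*/true);
    if (ctx == NULL) {
        return 1;
    }

    sd_image_t out = upscale(ctx, input, /*upscale_factor=*/4);
    // ... consume out.data (32 x 32 x 3 for this input) ...

    free(out.data);
    free_upscaler_ctx(ctx);
    free(pixels);
    return 0;
}
```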

tae.hpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -206,6 +206,17 @@ struct TinyAutoEncoder : public GGMLRunner {
         taesd.init(params_ctx, tensor_types, prefix);
     }
 
+    void enable_conv2d_direct() {
+        std::vector<GGMLBlock*> blocks;
+        taesd.get_all_blocks(blocks);
+        for (auto block : blocks) {
+            if (block->get_desc() == "Conv2d") {
+                auto conv_block = (Conv2d*)block;
+                conv_block->enable_direct();
+            }
+        }
+    }
+
     std::string get_desc() {
         return "taesd";
     }
```

unet.hpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -546,6 +546,18 @@ struct UNetModelRunner : public GGMLRunner {
         unet.init(params_ctx, tensor_types, prefix);
     }
 
+    void enable_conv2d_direct() {
+        std::vector<GGMLBlock*> blocks;
+        unet.get_all_blocks(blocks);
+        for (auto block : blocks) {
+            if (block->get_desc() == "Conv2d") {
+                LOG_DEBUG("block %s", block->get_desc().c_str());
+                auto conv_block = (Conv2d*)block;
+                conv_block->enable_direct();
+            }
+        }
+    }
+
     std::string get_desc() {
         return "unet";
     }
```

upscaler.cpp

Lines changed: 11 additions & 4 deletions
```diff
@@ -9,9 +9,12 @@ struct UpscalerGGML {
     std::shared_ptr<ESRGAN> esrgan_upscaler;
     std::string esrgan_path;
     int n_threads;
+    bool direct = false;
 
-    UpscalerGGML(int n_threads)
-        : n_threads(n_threads) {
+    UpscalerGGML(int n_threads,
+                 bool direct = false)
+        : n_threads(n_threads),
+          direct(direct) {
    }
 
     bool load_from_file(const std::string& esrgan_path) {
@@ -47,6 +50,9 @@ struct UpscalerGGML {
         }
         LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
         esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
+        if (direct) {
+            esrgan_upscaler->enable_conv2d_direct();
+        }
         if (!esrgan_upscaler->load_from_file(esrgan_path)) {
             return false;
         }
@@ -104,14 +110,15 @@ struct upscaler_ctx_t {
 };
 
 upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
-                                 int n_threads) {
+                                 int n_threads,
+                                 bool direct = false) {
     upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
     if (upscaler_ctx == NULL) {
         return NULL;
     }
     std::string esrgan_path(esrgan_path_c_str);
 
-    upscaler_ctx->upscaler = new UpscalerGGML(n_threads);
+    upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
     if (upscaler_ctx->upscaler == NULL) {
         return NULL;
     }
```
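Two things worth noting in this file: `direct` defaults to `false` in both the constructor and the `new_upscaler_ctx` definition, so existing library callers keep their old behavior, and the CLI (see examples/cli/main.cpp above) reuses `params.diffusion_conv_direct` for the upscaler rather than introducing a separate ESRGAN flag.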
