Skip to content

Commit 458bc93

Browse files
committed
chore: make weight override more robust against ggml changes
Recently, GGML_TYPE_COUNT got bumped for the new GGML_TYPE_MXFP4 quant, getting it out-of-sync with SD_TYPE_COUNT. To make it easier to build stable-diffusion.cpp against different ggml versions, adjust the type conversions to consider both GGML_TYPE_COUNT and SD_TYPE_COUNT as limits.
1 parent 1c30154 commit 458bc93

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2156,7 +2156,7 @@ std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std
21562156
if (type_name == "f32") {
21572157
tensor_type = GGML_TYPE_F32;
21582158
} else {
2159-
for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
2159+
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
21602160
auto trait = ggml_get_type_traits((ggml_type)i);
21612161
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
21622162
tensor_type = (ggml_type)i;

stable-diffusion.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ class StableDiffusionGGML {
241241
}
242242

243243
LOG_INFO("Version: %s ", model_version_to_str[version]);
244-
ggml_type wtype = (ggml_type)sd_ctx_params->wtype;
244+
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
245+
? (ggml_type)sd_ctx_params->wtype;
246+
: GGML_TYPE_COUNT;
245247
if (wtype == GGML_TYPE_COUNT) {
246248
model_wtype = model_loader.get_sd_wtype();
247249
if (model_wtype == GGML_TYPE_COUNT) {
@@ -1211,11 +1213,13 @@ class StableDiffusionGGML {
12111213
#define NONE_STR "NONE"
12121214

12131215
const char* sd_type_name(enum sd_type_t type) {
1214-
return ggml_type_name((ggml_type)type);
1216+
if ((int) type < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT))
1217+
return ggml_type_name((ggml_type)type);
1218+
return NONE_STR;
12151219
}
12161220

12171221
enum sd_type_t str_to_sd_type(const char* str) {
1218-
for (int i = 0; i < SD_TYPE_COUNT; i++) {
1222+
for (int i = 0; i < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT); i++) {
12191223
auto trait = ggml_get_type_traits((ggml_type)i);
12201224
if (!strcmp(str, trait->type_name)) {
12211225
return (enum sd_type_t)i;

0 commit comments

Comments
 (0)