Skip to content

Commit 04aa9b1

Browse files
reeselevineneha-ha
authored andcommitted
ggml webgpu: faster normal quant and some k-quant matrix operations, better shader parameter handling (ggml-org#20173)
* K quant speedup (ggml-org#20) * Basic JIT compilation for mul_mat, get_rows, and scale (ggml-org#17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local> Co-authored-by: Reese Levine <reeselevine1@gmail.com> * Start work on all-encompassing shader library * refactor argmax, set_rows * Refactor all but flashattention, mat mul * no gibberish, all k quants added, merged * vec memory fix * q6_k matching metal on my machine, tests passing * Set tile size for q6_k separately * Separate out fast shaders --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com> * Move towards writeBuffer for params * Move away from multiple buffers for set_rows errors, remove host buffer for parameter buffers, minor cleanups * Remove extra file * Formatting --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
1 parent 3ac9e6c commit 04aa9b1

5 files changed

Lines changed: 1237 additions & 256 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,20 @@
4242
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
4343

4444
// Matrix-vector multiplication parameters
45-
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
45+
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
46+
4647
// Must be multiple of 4 to work with vectorized paths, and must divide
4748
// mul_mat_vec wg size
48-
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
49-
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
49+
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
50+
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
51+
52+
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
53+
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
54+
55+
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
56+
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
57+
// Requires at least two (and multiple of 2) k-quant blocks per tile
58+
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
5059

5160
// default size for legacy matrix multiplication
5261
#define WEBGPU_MUL_MAT_WG_SIZE 256
@@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
199208
bool src_overlap;
200209

201210
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
202-
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
211+
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
212+
src_overlap == other.src_overlap;
203213
}
204214
};
205215

@@ -749,36 +759,25 @@ class ggml_webgpu_shader_lib {
749759
std::vector<std::string> defines;
750760
std::string variant = "mul_mat_vec";
751761

752-
// src1 type (vector)
753-
switch (context.src1->type) {
754-
case GGML_TYPE_F32:
755-
defines.push_back("SRC1_INNER_TYPE=f32");
756-
variant += "_f32";
757-
break;
758-
case GGML_TYPE_F16:
759-
defines.push_back("SRC1_INNER_TYPE=f16");
760-
variant += "_f16";
761-
break;
762-
default:
763-
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
764-
}
765-
766762
// src0 type (matrix row)
767763
switch (context.src0->type) {
768764
case GGML_TYPE_F32:
769765
defines.push_back("SRC0_INNER_TYPE=f32");
770766
defines.push_back("MUL_ACC_FLOAT");
767+
variant += "_f32";
771768
break;
772769
case GGML_TYPE_F16:
773770
defines.push_back("SRC0_INNER_TYPE=f16");
774771
defines.push_back("MUL_ACC_FLOAT");
772+
variant += "_f16";
775773
break;
776774
default:
777775
{
778776
// Quantized types: use helpers but accumulate in f16
779777
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
780778
std::string src0_name = src0_traits->type_name;
781779
std::string type_upper = src0_name;
780+
variant += "_" + src0_name;
782781
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
783782

784783
defines.push_back("BYTE_HELPERS");
@@ -790,12 +789,35 @@ class ggml_webgpu_shader_lib {
790789
}
791790
}
792791

792+
// src1 type (vector)
793+
switch (context.src1->type) {
794+
case GGML_TYPE_F32:
795+
defines.push_back("SRC1_INNER_TYPE=f32");
796+
variant += "_f32";
797+
break;
798+
case GGML_TYPE_F16:
799+
defines.push_back("SRC1_INNER_TYPE=f16");
800+
variant += "_f16";
801+
break;
802+
default:
803+
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
804+
}
805+
793806
// VEC/SCALAR controls
794807
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
795808

796809
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
797-
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
798-
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
810+
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
811+
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
812+
813+
if (key.src0_type >= GGML_TYPE_Q2_K) {
814+
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
815+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
816+
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
817+
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
818+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
819+
}
820+
799821
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
800822
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
801823
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
@@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
10611083

10621084
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
10631085
ggml_webgpu_binary_pipeline_key key = {
1064-
.type = context.dst->type,
1065-
.op = context.dst->op,
1066-
.inplace = context.inplace,
1067-
.overlap = context.overlap,
1086+
.type = context.dst->type,
1087+
.op = context.dst->op,
1088+
.inplace = context.inplace,
1089+
.overlap = context.overlap,
10681090
.src_overlap = context.src_overlap,
10691091
};
10701092

0 commit comments

Comments
 (0)