 from typing import Callable, Optional

 import torch
+from torch.nn import functional as F

 from vllm import envs

+
+def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
+    # SwiGLU: the first half of the last dim is the gate, the second half
+    # is the up projection.
+    d = x.shape[-1] // 2
+    return F.silu(x[..., :d]) * x[..., d:]
+
+
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+
+    gating_output = gating_output.float()
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+    num_token = scores.shape[0]
+    if e_score_correction_bias is not None:
+        # Store original scores before applying correction bias. We use
+        # biased scores for expert selection but original scores for
+        # routing weights.
+        original_scores = scores
+        scores = scores + e_score_correction_bias.unsqueeze(0)
+        group_scores = (scores.view(num_token, num_expert_group,
+                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
+    else:
+        group_scores = scores.view(num_token, num_expert_group,
+                                   -1).max(dim=-1).values  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
+                           sorted=False)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = group_mask.unsqueeze(-1).expand(
+        num_token, num_expert_group,
+        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(),
+                                    float("-inf"))  # [n, e]
+
+    if e_score_correction_bias is not None:
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        # Use original unbiased scores for the routing weights.
+        topk_weights = original_scores.gather(1, topk_ids)
+    else:
+        topk_weights, topk_ids = torch.topk(tmp_scores,
+                                            k=topk,
+                                            dim=-1,
+                                            sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids.to(torch.int32)
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # DeepSeek-V2 uses grouped top-k.
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        return grouped_topk(hidden_states=hidden_states,
+                            gating_output=router_logits,
+                            topk=top_k,
+                            renormalize=renormalize,
+                            num_expert_group=num_expert_group,
+                            topk_group=topk_group,
+                            scoring_func=scoring_func,
+                            e_score_correction_bias=e_score_correction_bias)
+    elif custom_routing_function is None:
+        assert scoring_func == "softmax"
+        topk_weights = torch.nn.functional.softmax(router_logits,
+                                                   dim=1,
+                                                   dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
+        if renormalize:
+            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+        return topk_weights, topk_ids.to(torch.int32)
+    else:
+        return custom_routing_function(hidden_states=hidden_states,
+                                       gating_output=router_logits,
+                                       topk=top_k,
+                                       renormalize=renormalize)
+
+
 class IPEXFusedMOE:

     def __init__(self, layer: torch.nn.Module) -> None:
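
For reviewers who want to poke at the new module-level helpers, here is a minimal sketch (not part of the diff) that exercises `silu_and_mul` and the plain softmax path of `select_experts`; all shapes are illustrative assumptions.

```python
# Sketch (not part of the PR): exercises the new module-level helpers on
# toy shapes; assumes silu_and_mul and select_experts are in scope.
import torch

num_tokens, hidden_size, num_experts, top_k = 4, 8, 16, 2
x = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)

# Plain softmax routing (no grouping, no custom routing function).
weights, ids = select_experts(hidden_states=x,
                              router_logits=router_logits,
                              top_k=top_k,
                              use_grouped_topk=False,
                              renormalize=True)
assert weights.shape == (num_tokens, top_k) and ids.dtype == torch.int32
# With renormalize=True the per-token routing weights sum to 1.
assert torch.allclose(weights.sum(dim=-1), torch.ones(num_tokens))

# silu_and_mul halves the last dimension: SiLU(gate) * up.
gate_up = torch.randn(num_tokens, 2 * hidden_size)
assert silu_and_mul(gate_up).shape == (num_tokens, hidden_size)
```
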
@@ -56,113 +156,6 @@ class SGLFusedMOE:
     def __init__(self, layer: torch.nn.Module) -> None:
         pass

-    @staticmethod
-    def _grouped_topk(
-        hidden_states: torch.Tensor,
-        gating_output: torch.Tensor,
-        topk: int,
-        renormalize: bool,
-        num_expert_group: int = 0,
-        topk_group: int = 0,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        assert hidden_states.shape[0] == gating_output.shape[0], (
-            "Number of tokens mismatch")
-
-        gating_output = gating_output.float()
-        if scoring_func == "softmax":
-            scores = torch.softmax(gating_output, dim=-1)
-        elif scoring_func == "sigmoid":
-            scores = gating_output.sigmoid()
-        else:
-            raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-        num_token = scores.shape[0]
-        if e_score_correction_bias is not None:
-            # Store original scores before applying correction bias. We use
-            # biased scores for expert selection but original scores for
-            # routing weights
-            original_scores = scores
-            scores = scores + e_score_correction_bias.unsqueeze(0)
-            group_scores = (scores.view(num_token, num_expert_group,
-                                        -1).topk(2, dim=-1)[0].sum(dim=-1))
-        else:
-            group_scores = scores.view(num_token, num_expert_group,
-                                       -1).max(dim=-1).values  # [n, n_group]
-        group_idx = torch.topk(group_scores,
-                               k=topk_group,
-                               dim=-1,
-                               sorted=False)[1]  # [n, top_k_group]
-        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-        score_mask = group_mask.unsqueeze(-1).expand(
-            num_token, num_expert_group,
-            scores.shape[-1] // num_expert_group).reshape(num_token,
-                                                          -1)  # [n, e]
-        tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                        float("-inf"))  # [n, e]
-
-        if e_score_correction_bias is not None:
-            topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
-            # Use original unbiased scores for the routing weights
-            topk_weights = original_scores.gather(1, topk_ids)
-        else:
-            topk_weights, topk_ids = torch.topk(tmp_scores,
-                                                k=topk,
-                                                dim=-1,
-                                                sorted=False)
-
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1,
-                                                           keepdim=True)
-
-        return topk_weights, topk_ids.to(torch.int32)
-
-    @staticmethod
-    def _select_experts(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # DeepSeek-V2 uses grouped_topk
-        if use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = SGLFusedMOE._grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias)
-        elif custom_routing_function is None:
-            assert scoring_func == "softmax"
-            topk_weights = torch.nn.functional.softmax(router_logits,
-                                                       dim=1,
-                                                       dtype=torch.float32)
-            topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
-            if renormalize:
-                topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-            topk_ids = topk_ids.to(torch.int32)
-        else:
-            topk_weights, topk_ids = custom_routing_function(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize)
-
-        return topk_weights, topk_ids
-
     def __call__(
         self,
         layer: torch.nn.Module,
@@ -183,7 +176,7 @@ def __call__(
     ) -> torch.Tensor:
         assert activation == "silu", f"{activation} is not supported."
         assert not apply_router_weight_on_input
-        topk_weights, topk_ids = SGLFusedMOE._select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
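
The grouped path, which both backends now reach through the shared helper, can be exercised the same way. The sketch below (not part of the diff) assumes 16 experts split into 4 groups with 2 groups kept per token; all numbers are illustrative.

```python
# Sketch (not part of the PR): the grouped routing path used by
# DeepSeek-style models. The router keeps the best topk_group=2 of the
# num_expert_group=4 groups per token, then the top_k=2 experts overall.
import torch

num_tokens, hidden_size, num_experts = 4, 32, 16
hidden = torch.randn(num_tokens, hidden_size)
logits = torch.randn(num_tokens, num_experts)

weights, ids = select_experts(hidden_states=hidden,
                              router_logits=logits,
                              top_k=2,
                              use_grouped_topk=True,
                              renormalize=True,
                              topk_group=2,
                              num_expert_group=4)
assert weights.shape == (num_tokens, 2)
assert ids.dtype == torch.int32 and int(ids.max()) < num_experts
# With renormalize=True the two routing weights sum to 1 per token.
assert torch.allclose(weights.sum(dim=-1), torch.ones(num_tokens))
```
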
@@ -213,3 +206,80 @@ def __call__(
             True,
         )
         return x
+
+
+class CPUFusedMOE:
+
+    def __init__(self, layer: torch.nn.Module) -> None:
+        pass
+
+    def __call__(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", f"{activation} is not supported."
+        assert not apply_router_weight_on_input
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
+        len_experts = global_num_experts
+
+        # Count how many tokens are routed to each expert.
+        cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+        cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+        tokens_per_expert = cnts.sum(dim=0)
+        # Sort the flattened (token, slot) assignments by expert id so that
+        # each expert's tokens are contiguous.
+        idxs = topk_ids.view(-1).argsort()
+
+        sorted_tokens = x[idxs // topk_ids.shape[1]]
+        tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+        outputs = []
+        start_idx = 0
+        for i, num_tokens in enumerate(tokens_per_expert):
+            end_idx = start_idx + num_tokens
+            if num_tokens == 0:
+                continue
+            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+            layer_w13_weight = layer.w13_weight[i]
+            layer_w2_weight = layer.w2_weight[i]
+
+            # Gate/up projection, SwiGLU activation, then down projection.
+            gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+            gate_up = silu_and_mul(gate_up)
+            expert_out = F.linear(gate_up, layer_w2_weight)
+            outputs.append(expert_out)
+            start_idx = end_idx
+
+        outs = torch.cat(outputs,
+                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+        new_x = torch.empty_like(outs)
+
+        # Scatter the expert outputs back to the original token order, then
+        # weight each selected expert's output and sum over the top-k slots.
+        new_x[idxs] = outs
+        final_out = (new_x.view(
+            *topk_ids.shape, -1).type(topk_weights.dtype).mul_(
+                topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
+        return final_out
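
To run the new `CPUFusedMOE` end to end, a dummy layer only needs the `w13_weight`/`w2_weight` buffers that `__call__` reads. The sketch below is not from the PR; a `SimpleNamespace` stands in for the real `torch.nn.Module`, and the weight shapes are the convention the dispatch loop assumes.

```python
# Sketch (not part of the PR): drives CPUFusedMOE with a dummy layer.
# Assumed weight shapes:
#   w13_weight: [num_experts, 2 * intermediate_size, hidden_size]
#   w2_weight:  [num_experts, hidden_size, intermediate_size]
import types

import torch

num_experts, hidden_size, intermediate_size = 8, 16, 32
layer = types.SimpleNamespace(
    w13_weight=torch.randn(num_experts, 2 * intermediate_size, hidden_size),
    w2_weight=torch.randn(num_experts, hidden_size, intermediate_size),
)

moe = CPUFusedMOE(layer)
x = torch.randn(4, hidden_size)
out = moe(layer,
          x,
          use_grouped_topk=False,
          top_k=2,
          router_logits=torch.randn(4, num_experts),
          renormalize=True,
          global_num_experts=num_experts)
assert out.shape == x.shape  # one fused-MoE output row per input token
```

Note that `global_num_experts` must be passed explicitly here: the default of -1 would give `cnts` an invalid shape in the token-counting step.
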