Commit f58675b
[CPU] add cpu fused moe pytorch native implementation (vllm-project#23146)
Signed-off-by: Tianyu Li <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
1 parent 7c04779 commit f58675b

2 files changed (+180 −110 lines)

vllm/model_executor/layers/fused_moe/cpu_fused_moe.py

Lines changed: 178 additions & 108 deletions
@@ -3,10 +3,110 @@
 from typing import Callable, Optional
 
 import torch
+from torch.nn import functional as F
 
 from vllm import envs
 
 
+def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
+    d = x.shape[-1] // 2
+    return F.silu(x[..., :d]) * x[..., d:]
+
+
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+
+    gating_output = gating_output.float()
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+    num_token = scores.shape[0]
+    if e_score_correction_bias is not None:
+        original_scores = scores
+        scores = scores + e_score_correction_bias.unsqueeze(0)
+        group_scores = (scores.view(num_token, num_expert_group,
+                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
+    else:
+        group_scores = scores.view(num_token, num_expert_group,
+                                   -1).max(dim=-1).values  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
+                           sorted=False)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = group_mask.unsqueeze(-1).expand(
+        num_token, num_expert_group,
+        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(),
+                                    float("-inf"))  # [n, e]
+
+    if e_score_correction_bias is not None:
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = original_scores.gather(1, topk_ids)
+    else:
+        topk_weights, topk_ids = torch.topk(tmp_scores,
+                                            k=topk,
+                                            dim=-1,
+                                            sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids.to(torch.int32)
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        return grouped_topk(hidden_states=hidden_states,
+                            gating_output=router_logits,
+                            topk=top_k,
+                            renormalize=renormalize,
+                            num_expert_group=num_expert_group,
+                            topk_group=topk_group,
+                            scoring_func=scoring_func,
+                            e_score_correction_bias=e_score_correction_bias)
+    elif custom_routing_function is None:
+        assert scoring_func == "softmax"
+        topk_weights = torch.nn.functional.softmax(router_logits,
+                                                   dim=1,
+                                                   dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
+        if renormalize:
+            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+        return topk_weights, topk_ids.to(torch.int32)
+    else:
+        return custom_routing_function(hidden_states=hidden_states,
+                                       gating_output=router_logits,
+                                       topk=top_k,
+                                       renormalize=renormalize)
+
+
 class IPEXFusedMOE:
 
     def __init__(self, layer: torch.nn.Module) -> None:
@@ -56,113 +156,6 @@ class SGLFusedMOE:
     def __init__(self, layer: torch.nn.Module) -> None:
         pass
 
-    @staticmethod
-    def _grouped_topk(
-        hidden_states: torch.Tensor,
-        gating_output: torch.Tensor,
-        topk: int,
-        renormalize: bool,
-        num_expert_group: int = 0,
-        topk_group: int = 0,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        assert hidden_states.shape[0] == gating_output.shape[0], (
-            "Number of tokens mismatch")
-
-        gating_output = gating_output.float()
-        if scoring_func == "softmax":
-            scores = torch.softmax(gating_output, dim=-1)
-        elif scoring_func == "sigmoid":
-            scores = gating_output.sigmoid()
-        else:
-            raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-        num_token = scores.shape[0]
-        if e_score_correction_bias is not None:
-            # Store original scores before applying correction bias. We use
-            # biased scores for expert selection but original scores for
-            # routing weights
-            original_scores = scores
-            scores = scores + e_score_correction_bias.unsqueeze(0)
-            group_scores = (scores.view(num_token, num_expert_group,
-                                        -1).topk(2, dim=-1)[0].sum(dim=-1))
-        else:
-            group_scores = scores.view(num_token, num_expert_group,
-                                       -1).max(dim=-1).values  # [n, n_group]
-        group_idx = torch.topk(group_scores,
-                               k=topk_group,
-                               dim=-1,
-                               sorted=False)[1]  # [n, top_k_group]
-        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-        score_mask = group_mask.unsqueeze(-1).expand(
-            num_token, num_expert_group,
-            scores.shape[-1] // num_expert_group).reshape(num_token,
-                                                          -1)  # [n, e]
-        tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                        float("-inf"))  # [n, e]
-
-        if e_score_correction_bias is not None:
-            topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
-            # Use original unbiased scores for the routing weights
-            topk_weights = original_scores.gather(1, topk_ids)
-        else:
-            topk_weights, topk_ids = torch.topk(tmp_scores,
-                                                k=topk,
-                                                dim=-1,
-                                                sorted=False)
-
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1,
-                                                           keepdim=True)
-
-        return topk_weights, topk_ids.to(torch.int32)
-
-    @staticmethod
-    def _select_experts(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # DeekSeekv2 uses grouped_top_k
-        if use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = SGLFusedMOE._grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias)
-        elif custom_routing_function is None:
-            assert scoring_func == "softmax"
-            topk_weights = torch.nn.functional.softmax(router_logits,
-                                                       dim=1,
-                                                       dtype=torch.float32)
-            topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
-            if renormalize:
-                topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-            topk_ids = topk_ids.to(torch.int32)
-        else:
-            topk_weights, topk_ids = custom_routing_function(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize)
-
-        return topk_weights, topk_ids
-
     def __call__(
         self,
         layer: torch.nn.Module,
@@ -183,7 +176,7 @@ def __call__(
     ) -> torch.Tensor:
         assert activation == "silu", f"{activation} is not supported."
         assert not apply_router_weight_on_input
-        topk_weights, topk_ids = SGLFusedMOE._select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -213,3 +206,80 @@ def __call__(
             True,
         )
         return x
+
+
+class CPUFusedMOE:
+
+    def __init__(self, layer: torch.nn.Module) -> None:
+        pass
+
+    def __call__(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", f"{activation} is not supported."
+        assert not apply_router_weight_on_input
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
+        len_experts = global_num_experts
+
+        cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+        cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+        tokens_per_expert = cnts.sum(dim=0)
+        idxs = topk_ids.view(-1).argsort()
+
+        sorted_tokens = x[idxs // topk_ids.shape[1]]
+        tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+        outputs = []
+        start_idx = 0
+        for i, num_tokens in enumerate(tokens_per_expert):
+            end_idx = start_idx + num_tokens
+            if num_tokens == 0:
+                continue
+            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+            layer_w13_weight = layer.w13_weight[i]
+            layer_w2_weight = layer.w2_weight[i]
+
+            gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+            gate_up = silu_and_mul(gate_up)
+            expert_out = F.linear(gate_up, layer_w2_weight)
+            outputs.append(expert_out)
+            start_idx = end_idx
+
+        outs = torch.cat(outputs,
+                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+        new_x = torch.empty_like(outs)
+
+        new_x[idxs] = outs
+        final_out = (new_x.view(
+            *topk_ids.shape, -1).type(topk_weights.dtype).mul_(
+                topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
+        return final_out
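
The commit promotes the top-k routing helpers to module-level functions so that both SGLFusedMOE and the new CPUFusedMOE share them. Below is a minimal sketch (not part of the commit) of calling the new select_experts helper on toy router logits via the plain softmax path, assuming a vLLM build that includes this commit so the import resolves; the tensor sizes are made up for illustration.

import torch

from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts

num_tokens, num_experts, top_k = 4, 8, 2
hidden_states = torch.randn(num_tokens, 16)
router_logits = torch.randn(num_tokens, num_experts)

# Softmax top-k path: use_grouped_topk=False, no custom routing function.
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=top_k,
    use_grouped_topk=False,
    renormalize=True,
)

print(topk_weights.shape)        # torch.Size([4, 2]), float32 routing weights
print(topk_ids.shape)            # torch.Size([4, 2]), int32 expert indices
print(topk_weights.sum(dim=-1))  # each row sums to 1 when renormalize=True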

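CPUFusedMOE itself is a pure-PyTorch dispatch: it buckets the flattened (token, expert) pairs by expert, runs one gate/up projection, silu_and_mul, and down projection per non-empty expert, then scatters the results back to token order and combines them with the routing weights. The following standalone sketch (not code from the commit) walks through that pattern with random stand-in tensors; w13 and w2 mirror the layer attribute names but are dummies, not real model weights.

import torch
from torch.nn import functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


num_tokens, hidden, inter, num_experts, top_k = 5, 8, 16, 4, 2
x = torch.randn(num_tokens, hidden)
w13 = torch.randn(num_experts, 2 * inter, hidden)  # fused gate+up projections
w2 = torch.randn(num_experts, hidden, inter)       # down projections
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
topk_weights = torch.rand(num_tokens, top_k)

# Sort the flattened (token, expert) pairs so tokens routed to the same
# expert are contiguous, then process one expert at a time.
cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // top_k]

outputs, start = [], 0
for e, n in enumerate(tokens_per_expert.tolist()):
    if n == 0:
        continue
    chunk = sorted_tokens[start:start + n]
    outputs.append(F.linear(silu_and_mul(F.linear(chunk, w13[e])), w2[e]))
    start += n

outs = torch.cat(outputs, dim=0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs  # undo the sort back to (token, slot) order
out = (new_x.view(num_tokens, top_k, hidden) *
       topk_weights.unsqueeze(-1)).sum(dim=1)
print(out.shape)  # torch.Size([5, 8])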
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 2 deletions
@@ -358,8 +358,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 use_prepack=True,
             )
         elif current_platform.is_cpu():
+            from vllm.model_executor.layers.fused_moe import cpu_fused_moe
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
-                from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                 from vllm.model_executor.layers.utils import (
                     check_cpu_sgl_kernel)
                 dtype_w13 = layer.w13_weight.dtype
@@ -382,7 +382,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 else:
                     layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
             else:
-                raise NotImplementedError("CPU MOE only supports x86 arch.")
+                layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
 
     def apply(
         self,
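
The layer.py change means CPU backend selection now ends in a fallback rather than an error for non-x86 architectures. A hedged paraphrase of the visible hunks follows (pick_cpu_fused_moe_backend is a hypothetical helper for illustration, not code from the commit, and the exact check_cpu_sgl_kernel condition lies outside the hunk).

def pick_cpu_fused_moe_backend(is_x86: bool, sgl_kernel_ok: bool) -> str:
    if is_x86:
        # x86 keeps the existing fast paths: SGL kernel when the dtype/shape
        # check passes, otherwise the IPEX path.
        return "SGLFusedMOE" if sgl_kernel_ok else "IPEXFusedMOE"
    # Other CPU architectures previously raised NotImplementedError; they now
    # fall back to the PyTorch-native CPUFusedMOE added in this commit.
    return "CPUFusedMOE"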
