
Commit 5febcf8

MoE Nb readability improvements (#761)
1 parent f92b40e · commit 5febcf8

File tree (2 files changed, +48 −22 lines):

- ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb
- ch05/11_qwen3/standalone-qwen3-moe.ipynb


ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb

Lines changed: 24 additions & 11 deletions
@@ -152,13 +152,28 @@
  " self.num_experts = cfg[\"num_experts\"]\n",
  " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
  "\n",
- " meta_device = torch.device(\"meta\") # to reduce memory pressure and only load them when used (trades compute for memory)\n",
- " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
- " for _ in range(cfg[\"num_experts\"])])\n",
- " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
- " for _ in range(cfg[\"num_experts\"])])\n",
- " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
- " for _ in range(cfg[\"num_experts\"])])\n",
+ " # meta device to reduce memory pressure when initializing the model before loading weights\n",
+ " meta_device = torch.device(\"meta\")\n",
+ " self.fc1 = nn.ModuleList([\n",
+ " nn.Linear(\n",
+ " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+ " bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
+ " for _ in range(cfg[\"num_experts\"])]\n",
+ " )\n",
+ " self.fc2 = nn.ModuleList([\n",
+ " nn.Linear(\n",
+ " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+ " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+ " )\n",
+ " for _ in range(cfg[\"num_experts\"])]\n",
+ " )\n",
+ " self.fc3 = nn.ModuleList([\n",
+ " nn.Linear(\n",
+ " cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
+ " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+ " )\n",
+ " for _ in range(cfg[\"num_experts\"])]\n",
+ " )\n",
  "\n",
  " def forward(self, x):\n",
  " b, seq_len, embed_dim = x.shape\n",
@@ -194,20 +209,18 @@
  " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
  " # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
  " # y = torch.zeros_like(x)\n",
- "\n",
+ " #\n",
  " # for i in range(self.num_experts_per_tok):\n",
  " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
  " # expert_indices = topk_indices[..., i]\n",
  " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
- "\n",
+ " #\n",
  " # # For each expert, process only the tokens assigned to it\n",
  " # for e in range(self.num_experts):\n",
  " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
  " # if mask.any():\n",
  " # selected = x[mask] # (num_tokens_e, emb_dim)\n",
- " # # Compute FF for expert e\n",
  " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
- " # # Scale by gating prob and scatter back\n",
  " # y[mask] += prob[mask] * out\n",
  " # return y"
  ]
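
Aside (not part of this commit): the `meta` device used above creates the expert layers without allocating parameter memory; the parameters only become real tensors once checkpoint weights are assigned. A minimal sketch of that materialization step, assuming PyTorch 2.1+ where `load_state_dict(..., assign=True)` is available, with made-up layer sizes:

```python
import torch
import torch.nn as nn

# A single "expert" layer on the meta device: no parameter storage is allocated yet.
expert = nn.Linear(8, 16, bias=False, device=torch.device("meta"))

# Stand-in for a weight tensor loaded from a checkpoint (shape must match: out_features x in_features).
checkpoint = {"weight": torch.randn(16, 8)}

# assign=True replaces the meta parameters with the loaded tensors instead of copying into them.
expert.load_state_dict(checkpoint, assign=True)
print(expert.weight.device)  # now a real (CPU) tensor
```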

ch05/11_qwen3/standalone-qwen3-moe.ipynb

Lines changed: 24 additions & 11 deletions
@@ -152,13 +152,28 @@
  " self.num_experts = cfg[\"num_experts\"]\n",
  " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
  "\n",
- " meta_device = torch.device(\"meta\") # to reduce memory pressure and only load them when used (trades compute for memory)\n",
- " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
- " for _ in range(cfg[\"num_experts\"])])\n",
- " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
- " for _ in range(cfg[\"num_experts\"])])\n",
- " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
- " for _ in range(cfg[\"num_experts\"])])\n",
+ " # meta device to reduce memory pressure when initializing the model before loading weights\n",
+ " meta_device = torch.device(\"meta\")\n",
+ " self.fc1 = nn.ModuleList([\n",
+ " nn.Linear(\n",
+ " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+ " bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
+ " for _ in range(cfg[\"num_experts\"])]\n",
+ " )\n",
+ " self.fc2 = nn.ModuleList([\n",
+ " nn.Linear(\n",
+ " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+ " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+ " )\n",
+ " for _ in range(cfg[\"num_experts\"])]\n",
+ " )\n",
+ " self.fc3 = nn.ModuleList([\n",
+ " nn.Linear(\n",
+ " cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
+ " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+ " )\n",
+ " for _ in range(cfg[\"num_experts\"])]\n",
+ " )\n",
  "\n",
  " def forward(self, x):\n",
  " b, seq_len, embed_dim = x.shape\n",
@@ -194,20 +209,18 @@
  " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
  " # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
  " # y = torch.zeros_like(x)\n",
- "\n",
+ " #\n",
  " # for i in range(self.num_experts_per_tok):\n",
  " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
  " # expert_indices = topk_indices[..., i]\n",
  " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
- "\n",
+ " #\n",
  " # # For each expert, process only the tokens assigned to it\n",
  " # for e in range(self.num_experts):\n",
  " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
  " # if mask.any():\n",
  " # selected = x[mask] # (num_tokens_e, emb_dim)\n",
- " # # Compute FF for expert e\n",
  " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
- " # # Scale by gating prob and scatter back\n",
  " # y[mask] += prob[mask] * out\n",
  " # return y"
  ]
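
Aside (not part of this commit): the commented-out reference loop in both notebooks presumes a routing step that picks `num_experts_per_tok` experts per token and renormalizes their gate scores. A minimal sketch of that step with illustrative shapes (the notebooks take these values from `cfg`):

```python
import torch

# Illustrative dimensions only.
b, seq_len, emb_dim = 2, 4, 8
num_experts, num_experts_per_tok = 6, 2

x = torch.randn(b, seq_len, emb_dim)
gate = torch.nn.Linear(emb_dim, num_experts, bias=False)

scores = gate(x)  # (b, seq_len, num_experts) router logits
topk_scores, topk_indices = torch.topk(scores, num_experts_per_tok, dim=-1)
topk_probs = torch.softmax(topk_scores, dim=-1)  # renormalize over the selected experts only

print(topk_indices.shape)  # torch.Size([2, 4, 2])
print(topk_probs.shape)    # torch.Size([2, 4, 2])
```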
