|
152 | 152 | " self.num_experts = cfg[\"num_experts\"]\n",
|
153 | 153 | " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
154 | 154 | "\n",
|
155 | | - "        meta_device = torch.device(\"meta\") # to reduce memory pressure and only load them when used (trades compute for memory)\n", |
156 | | - "        self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", |
157 | | - "                                   for _ in range(cfg[\"num_experts\"])])\n", |
158 | | - "        self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", |
159 | | - "                                   for _ in range(cfg[\"num_experts\"])])\n", |
160 | | - "        self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", |
161 | | - "                                   for _ in range(cfg[\"num_experts\"])])\n", |
| 155 | + " # meta device to reduce memory pressure when initializing the model before loading weights\n", |
| 156 | + " meta_device = torch.device(\"meta\")\n", |
| 157 | + " self.fc1 = nn.ModuleList([\n", |
| 158 | + " nn.Linear(\n", |
| 159 | + " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", |
| 160 | + " bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n", |
| 161 | + " for _ in range(cfg[\"num_experts\"])]\n", |
| 162 | + " )\n", |
| 163 | + " self.fc2 = nn.ModuleList([\n", |
| 164 | + " nn.Linear(\n", |
| 165 | + " cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n", |
| 166 | + " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", |
| 167 | + " )\n", |
| 168 | + " for _ in range(cfg[\"num_experts\"])]\n", |
| 169 | + " )\n", |
| 170 | + " self.fc3 = nn.ModuleList([\n", |
| 171 | + " nn.Linear(\n", |
| 172 | + " cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n", |
| 173 | + " bias=False, dtype=cfg[\"dtype\"], device=meta_device\n", |
| 174 | + " )\n", |
| 175 | + " for _ in range(cfg[\"num_experts\"])]\n", |
| 176 | + " )\n", |
162 | 177 | "\n",
|
163 | 178 | " def forward(self, x):\n",
|
164 | 179 | " b, seq_len, embed_dim = x.shape\n",
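
The block added above keeps each expert's `nn.Linear` on PyTorch's meta device, so constructing the module allocates no parameter memory; real storage only appears once checkpoint weights are assigned. A minimal sketch of that idea, assuming made-up config values and a stand-in loading loop (random tensors here, checkpoint tensors in practice), not code from this notebook:

import torch
import torch.nn as nn

# Placeholder config for illustration only (not the notebook's real values)
cfg = {"emb_dim": 8, "moe_intermediate_size": 16, "num_experts": 4, "dtype": torch.bfloat16}

meta_device = torch.device("meta")
fc1 = nn.ModuleList([
    nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"],
              bias=False, dtype=cfg["dtype"], device=meta_device)
    for _ in range(cfg["num_experts"])
])
print(fc1[0].weight.is_meta)  # True: parameters exist only as shapes, no storage yet

# A weight-loading step can materialize each expert by replacing the meta
# parameter with a real tensor (random stand-ins here)
for layer in fc1:
    real = torch.randn(layer.out_features, layer.in_features, dtype=cfg["dtype"])
    layer.weight = nn.Parameter(real)
print(fc1[0].weight.is_meta)  # False: experts now hold real weights

Alternatives are `module.to_empty(device=...)` followed by a regular `load_state_dict`, or `load_state_dict(..., assign=True)` on newer PyTorch versions; either way, the meta device avoids allocating and randomly initializing weights that would immediately be overwritten by the checkpoint.
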
|
|
194 | 209 | " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
|
195 | 210 | " # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
|
196 | 211 | " # y = torch.zeros_like(x)\n",
|
197 | | - "\n", |
| 212 | + " #\n", |
198 | 213 | " # for i in range(self.num_experts_per_tok):\n",
|
199 | 214 | " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
|
200 | 215 | " # expert_indices = topk_indices[..., i]\n",
|
201 | 216 | " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
|
202 | | - "\n", |
| 217 | + " #\n", |
203 | 218 | " # # For each expert, process only the tokens assigned to it\n",
|
204 | 219 | " # for e in range(self.num_experts):\n",
|
205 | 220 | " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
|
206 | 221 | " # if mask.any():\n",
|
207 | 222 | " # selected = x[mask] # (num_tokens_e, emb_dim)\n",
|
208 | | - "        #             # Compute FF for expert e\n", |
209 | 223 | " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
|
210 | | - "        #             # Scale by gating prob and scatter back\n", |
211 | 224 | " # y[mask] += prob[mask] * out\n",
|
212 | 225 | " # return y"
|
213 | 226 | ]
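
The commented-out lines kept at the end of the cell describe the naive routing path: take the top-k gate scores per token, renormalize them with a softmax, then let each expert run a SwiGLU feed-forward over only the tokens routed to it and scatter the scaled results back. A self-contained reconstruction of that loop, with random stand-in weights purely for illustration (not the notebook's loaded model), might look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

def moe_forward(x, gate, fc1, fc2, fc3, num_experts, num_experts_per_tok):
    scores = gate(x)                                  # (b, seq_len, num_experts)
    topk_scores, topk_indices = torch.topk(scores, num_experts_per_tok, dim=-1)
    topk_probs = torch.softmax(topk_scores, dim=-1)   # renormalize over the k chosen experts
    y = torch.zeros_like(x)
    for i in range(num_experts_per_tok):
        expert_indices = topk_indices[..., i]         # (b, seq_len), values in [0, num_experts)
        prob = topk_probs[..., i].unsqueeze(-1)       # (b, seq_len, 1)
        for e in range(num_experts):
            mask = (expert_indices == e)              # (b, seq_len) boolean mask
            if mask.any():
                selected = x[mask]                    # (num_tokens_e, emb_dim)
                out = fc3[e](F.silu(fc1[e](selected)) * fc2[e](selected))  # SwiGLU expert FF
                y[mask] += prob[mask] * out           # scale by gating prob, scatter back
    return y

# Toy usage with random weights (illustrative only)
emb_dim, hidden, num_experts, k = 8, 16, 4, 2
gate = nn.Linear(emb_dim, num_experts, bias=False)
fc1 = nn.ModuleList([nn.Linear(emb_dim, hidden, bias=False) for _ in range(num_experts)])
fc2 = nn.ModuleList([nn.Linear(emb_dim, hidden, bias=False) for _ in range(num_experts)])
fc3 = nn.ModuleList([nn.Linear(hidden, emb_dim, bias=False) for _ in range(num_experts)])
x = torch.randn(2, 5, emb_dim)
print(moe_forward(x, gate, fc1, fc2, fc3, num_experts, k).shape)  # torch.Size([2, 5, 8])

Grouping the work per expert rather than per token means each expert does one batched matmul over just its assigned tokens, which keeps the loop readable even though a fully vectorized dispatch would be faster.
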
|