You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/blogs/tech_blog/blog2_DeepSeek_R1_MTP_Implementation_and_Optimization.md
+2-10Lines changed: 2 additions & 10 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -55,7 +55,7 @@ For the draft stage in MTP, there are two different MTP methods, MTP vanilla and
55
55
56
56
MTP Vanilla method is more similar to the MTP training, and it sequentially uses different MTP modules to predict multiple draft tokens. This method can support model checkpoints with weights of multiple different MTP modules. And each MTP module will have its own KV cache.
57
57
58
-
Figure 2 illustrates the MTP vanilla inference. In the context phase, assuming there are a total of four input tokens, we will get the output token $t_5$ and the hidden states after the main model forward. The output token will be appended to the input tokens, then we shift out the first token to get tokens from $t_2$ to $t_5$ as the input tokens of the first MTP module. The hidden states from the main model will be directly used as the input of the first MTP module to predict the first draft token. For the next several MTP modules, we will use the same method to prepare the inputs to predict the sequential draft tokens.
58
+
Figure 2 illustrates the MTP vanilla inference. In the context phase, assuming there are a total of four input tokens, we will get the output token $t_5$ and the hidden states after the main model forward. The output token will be appended to the input tokens, then we shift out the first token to get tokens from $t_2$ to $t_5$ as the input tokens of the first MTP module. The hidden states from the main model will be directly used as the input of the first MTP module to predict the first draft token. For the next few MTP modules, we'll append the newly generated draft token and the hidden states corresponding to the last input token to the input tokens and hidden states. Then we'll shift out the first token to prepare the inputs for the next MTP module. In this way, we can retain as much information as possible from the main model, which helps the draft layer make more accurate predictions.
59
59
60
60
In the generation phase, there will be a little difference. The predicted token $t_5$ and the draft tokens will be used as inputs for the main model. After the main model forward, we will do the verification to get the accepted tokens. In this example, assuming $j$ draft tokens $d_6$~$d_{j+5}$ are accepted. Then prepare the MTP module inputs. Different from the context phase, we will prepare input IDs and hidden states of a total of $K$ tokens before the last accepted token. In this example, the last accepted token is $t_{j+6}$. Then we can get the first draft token after the first MTP module forward. For the sequential MTP modules, we can prepare their inputs in a similar way to the MTP modules in the context phase, so all of those MTP modules have the same input sequence length. After predicting all of the draft tokens, we need to evict the keys/values of those rejected draft tokens from the main model's KV cache to ensure the subsequent calculation is correct.
61
61
@@ -72,7 +72,7 @@ MTP Eagle can be viewed as a variant of [Eagle](https://arxiv.org/pdf/2401.15077
72
72
73
73
Figure 3 gives an MTP Eagle example. In the context phase, the inputs of the first MTP module forward are the same as the MTP Vanilla. However, for the sequential MTP module forward, the first difference is that MTP Eagle uses the same MTP module to predict draft tokens and reuses the same KV cache. Another difference is that we only need to input the token ID and the hidden state of one token. The token is the last predicted draft token, while the hidden state is the corresponding hidden state in the last MTP module forward. In this way, we can predict total K draft tokens by using only one MTP module.
74
74
75
-
In the generation phase, the verification stage is the same as MTP Vanilla. After getting the accepted tokens, we will use the last accepted tokens and the corresponding hidden state as the inputs of the first MTP module forward. Compared with MTP Vanilla, it will be much easier to implement. And the sequential MTP module forwards use the same method as the context phase to prepare inputs. After predicting all of the draft tokens, we need to evict the keys/values of those rejected draft tokens from the main model's KV cache.
75
+
In the generation phase, the verification stage is the same as MTP Vanilla. Once we get the accepted tokens, we use all of them along with their corresponding hidden states as inputs for the first MTP module forward. Unlike MTP Vanilla, which needs to store past tokens and hidden states, this approach is much easier to implement. Subsequent MTP module forwards follow the same input preparation method as the context phase. After predicting all draft tokens, we need to evict the key/value pairs of any rejected draft tokens from the main model’s KV cache.
76
76
77
77
## MTP implementation in TensorRT-LLM
78
78
### Basic Implementation
@@ -241,14 +241,6 @@ TensorRT-LLM PyTorch backend can only support chain-based speculative decoding n
241
241
242
242
Another important method is Eagle3. From the [Eagle3 paper](https://arxiv.org/pdf/2503.01840), the promising results show that it can help greatly increase the acceptance rate by leveraging different levels’ hidden states to predict draft tokens. Since TensorRT-LLM already has [Eagle-3 support](https://github.com/NVIDIA/TensorRT-LLM/pull/3035) now, in the future, we also want to train an Eagle3 head to support DeepSeek-V3/R1+Eagle3 to achieve better speedup.
243
243
244
-
### Fix known issues
245
-
246
-
There are still some known issues, and we will fix them soon:
247
-
- The MTP vanilla path has a known accuracy issue. We will fix it and refactor the MTP vanilla implementation.
248
-
- The MTP Eagle is non-deterministic now.
249
-
- An accuracy issue when enabling MTP and attention DP together.
250
-
251
-
252
244
## Acknowledgment
253
245
254
246
This was a remarkable cross-team effort to support and optimize MTP in TensorRT-LLM. We would like to extend our gratitude to everyone who contributed to making this possible, as it involved a typical system/algorithm co-design approach spanning multiple technical layers—including kernel optimization, runtime enhancements, algorithmic improvements, and performance measurement & analysis. And a special thanks goes to the DeepSeek team for developing the MTP method, which lays down the foundation of this blog.
0 commit comments