## 📌 GEMMs in AI Workloads

General Matrix Multiplications (GEMMs) are the **primary compute primitive** in AI models. Efficient GEMM implementations are readily available through vendor-tuned libraries such as cuBLAS and hipBLAS, so whenever possible the goal is to **reduce a computation to a matrix multiply**.

This document explains how **model-level parameters** such as batch size, sequence length, and hidden dimension translate into GEMM shapes, and how those shapes map to specific **BLAS kernel calls**.

By the end of this post, you should be equipped to work out the GEMM shapes, counts, and BLAS calls involved in **any AI model you encounter**.

---

## From Model Dimensions to GEMM Shapes

### Linear Layers in LLMs

Let's begin by understanding how linear layers, such as the MLP **up projection**, correspond to GEMM calls.

The input tensor shape for this operation is:

```
X: [B, L, d_model]
```

Here:
- `B` is the batch size.
- `L` is the sequence length.
- `d_model` is the input (or hidden) dimension.

This operation outputs a tensor with the shape:

```
Y: [B, L, d_ff]
```

The projection for each token can be expressed individually as:

```
Y[b, l, :] = X[b, l, :] @ Wᵀ
```

where `W` is the weight matrix with shape `[d_ff, d_model]`.

### Flattening for GEMM

To express the entire operation as a single GEMM, we flatten the batch and sequence dimensions of the input tensor:

```
X_flat: [B·L, d_model]
Wᵀ:     [d_model, d_ff]
Y = X_flat @ Wᵀ
```

This flattening yields the following GEMM shape parameters:

- `param: M = B·L`
- `param: N = d_ff`
- `param: K = d_model`

Here, `K` is the **inner (shared) dimension** between the two matrices being multiplied.
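
The flattening above can be sketched in a few lines. This uses NumPy for illustration with small made-up dimensions; PyTorch's `view`/`matmul` behave analogously:

```python
import numpy as np

# Illustrative (made-up) dimensions
B, L, d_model, d_ff = 2, 4, 8, 16

X = np.random.rand(B, L, d_model).astype(np.float32)
W = np.random.rand(d_ff, d_model).astype(np.float32)  # [d_ff, d_model]

# Flatten batch and sequence dims, then one GEMM: [B*L, d_model] @ [d_model, d_ff]
X_flat = X.reshape(B * L, d_model)
Y = X_flat @ W.T                                      # [B*L, d_ff]

# GEMM shape parameters
M, N, K = B * L, d_ff, d_model
assert Y.shape == (M, N)

# Same result as applying the projection token by token
Y_ref = np.einsum('bld,fd->blf', X, W).reshape(M, N)
assert np.allclose(Y, Y_ref, atol=1e-5)
```
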

---

## Prefill vs Decode

Let's now contrast GEMM behavior during the **prefill** and **decode** phases of inference, focusing on how the sequence length (`L`) changes and affects the GEMM shapes.

In both phases, the input tensor for an MLP computation within an LLM has the shape:

```
X: [B, L, d_model]
```

where `B` is the batch size, `L` is the sequence length, and `d_model` is the hidden size. This projects to an output of shape `[B, L, d_ff]`, where `d_ff` is the MLP's expansion size.

As established earlier, to process this with a single GEMM, the batch and sequence dimensions of the input are flattened: the input effectively becomes `[B·L, d_model]` for the GEMM `X_flat @ Wᵀ`.

The key difference between prefill and decode lies in the value of `L`:
- **Prefill phase**: `L` is the actual input sequence length (which can be large).
- **Decode phase**: `L` is always `1`, as the model processes one token at a time to generate the next.

This difference in `L` directly impacts the `M` parameter of the GEMM `(M, N, K)`:

### GEMM Shape Summary

| Mode    | Input Shape (Conceptual) | Flattened Input Shape | GEMM Shape `(M, N, K)` | Notes                                    |
|---------|--------------------------|-----------------------|------------------------|------------------------------------------|
| Prefill | `[B, L, d_model]`        | `[B·L, d_model]`      | `(B·L, d_ff, d_model)` | `L` is the prompt length                 |
| Decode  | `[B, 1, d_model]`        | `[B·1, d_model]`      | `(B, d_ff, d_model)`   | `L=1` for generating one token at a time |

Notice that in the decode phase, because `L=1`, the `M` parameter of the MLP GEMM becomes simply `B`. The computational cost of the MLP layers in decode therefore stays constant per token, regardless of how many tokens have been generated so far. The dominant **O(L)** scaling cost during decode comes from the attention mechanism, not the MLPs.
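
The effect of `L` on the GEMM's `M` parameter can be checked with a few lines of plain Python, using the LLaMA-2 7B MLP dimensions (`d_model = 4096`, `d_ff = 11008`, prompt length 597):

```python
def mlp_gemm_shape(B, L, d_model, d_ff):
    """GEMM (M, N, K) for Y = X_flat @ W.T with X_flat: [B*L, d_model], W: [d_ff, d_model]."""
    return (B * L, d_ff, d_model)

d_model, d_ff = 4096, 11008

prefill = mlp_gemm_shape(B=1, L=597, d_model=d_model, d_ff=d_ff)  # full prompt in one pass
decode  = mlp_gemm_shape(B=1, L=1,   d_model=d_model, d_ff=d_ff)  # one token per step

assert prefill == (597, 11008, 4096)  # M = B·L
assert decode  == (1, 11008, 4096)    # M = B, since L = 1
```
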

### Real-world Example: LLaMA-2 7B

The table below shows actual data from **Tracelens** profiling, filtered specifically for the MLP **up** and **gate** projection GEMMs in a LLaMA-2 7B inference trace.

For this trace:
- `d_model = 4096`
- `d_ff = 11008`
- Batch size: `1` (`B=1`)
- Input length (for prefill): `597` (`L=597`)
- The trace included 36 decode steps.

| name     | param: M | param: N | param: K | param: bias | counts |
|----------|----------|----------|----------|-------------|--------|
| aten::mm | 1        | 11008    | 4096     | FALSE       | 2304   |
| aten::mm | 597      | 11008    | 4096     | FALSE       | 64     |

Interpreting these entries based on the `M` parameter:
- The entry with `param: M = 597` corresponds to the **prefill** phase GEMM (`B·L = 1·597`), which happens once per layer at the beginning of inference. Since there are 32 layers, this GEMM is called `64` times (32 up + 32 gate).
- The entry with `param: M = 1` corresponds to the **decode** phase GEMM (`B = 1`, since `L=1`). These occur at every layer of each decode step. With 36 decode steps and 64 GEMMs per step (32 layers × 2 projections), this GEMM is called `36 × 64 = 2304` times.
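
The call counts in the table can be sanity-checked with simple arithmetic:

```python
layers = 32
projections_per_layer = 2        # up + gate
decode_steps = 36

# Prefill: one GEMM per projection per layer, once per inference
prefill_calls = layers * projections_per_layer
assert prefill_calls == 64

# Decode: the same GEMMs repeat at every decode step
decode_calls = decode_steps * layers * projections_per_layer
assert decode_calls == 2304
```
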

---

## 🔁 Backward Pass GEMMs

Next, let's explore the backward pass during training. A forward pass GEMM operation like `Y = X @ Wᵀ + b` necessitates **two corresponding backward GEMMs** to compute gradients:

```python
dX = dY @ W        # Gradient with respect to the input  → shape: [B·L, d_model]
dW = dYᵀ @ X       # Gradient with respect to the weight → shape: [d_ff, d_model]
db = dY.sum(dim=0) # Gradient with respect to the bias   → shape: [d_ff]
```

### GEMM Shapes

| Operation   | GEMM Shape `(param: M, param: N, param: K)` | Description                                          |
|-------------|---------------------------------------------|------------------------------------------------------|
| Forward     | `(B·L, d_ff, d_model)`                      | `X @ Wᵀ`                                             |
| Backward dX | `(B·L, d_model, d_ff)`                      | `dY @ W` (result of `[B·L, d_ff] @ [d_ff, d_model]`) |
| Backward dW | `(d_ff, d_model, B·L)`                      | `dYᵀ @ X` (result of `[d_ff, B·L] @ [B·L, d_model]`) |

Let's look closer at the backward GEMM shapes:
- For `dX = dY @ W`, the operation is `[B·L, d_ff] @ [d_ff, d_model]`, which results in a shape of `[B·L, d_model]`.
- For `dW = dYᵀ @ X`, the operation is `[d_ff, B·L] @ [B·L, d_model]`, yielding a shape of `[d_ff, d_model]`. Note that here the **inner dimension** `K` is `B·L`, while `M = d_ff` and `N = d_model`.
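
These shapes can be verified numerically. A minimal NumPy sketch with made-up small dimensions:

```python
import numpy as np

BL, d_model, d_ff = 6, 4, 10     # illustrative sizes (B·L = 6)

X  = np.random.rand(BL, d_model)
W  = np.random.rand(d_ff, d_model)
dY = np.random.rand(BL, d_ff)    # upstream gradient, same shape as Y

dX = dY @ W                      # (M, N, K) = (B·L, d_model, d_ff)
dW = dY.T @ X                    # (M, N, K) = (d_ff, d_model, B·L)
db = dY.sum(axis=0)

assert dX.shape == (BL, d_model)
assert dW.shape == W.shape == (d_ff, d_model)
assert db.shape == (d_ff,)
```
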

### Real-world Example: GPT-3-XL

This table presents data from a **Tracelens** trace of a single training step of **GPT-3-XL**.

For this example:
- `d_model = 2048`
- `d_ff = 8192`
- Batch size: `5`
- Sequence length: `2048`
- Thus, `param: M` for the flattened dimension is `5 × 2048 = 10240`.

| name        | param: M | param: N | param: K | count |
|-------------|----------|----------|----------|-------|
| aten::addmm | 10240    | 8192     | 2048     | 24    |
| aten::mm    | 10240    | 2048     | 8192     | 24    |
| aten::mm    | 8192     | 2048     | 10240    | 24    |

We can interpret each entry based on the GEMM shapes:
- The `aten::addmm` call is the forward pass GEMM (`X @ Wᵀ`).
- The first `aten::mm` call is the backward pass for `dX` (`dY @ W`).
- The second `aten::mm` call is the backward pass for `dW` (`dYᵀ @ X`).

Each of these operations appears once per layer in the network. Since GPT-3-XL has 24 layers, each of these GEMMs is called 24 times per training step, matching the `count` column in the table.

---

## ⚙️ How PyTorch Calls BLAS

To understand how PyTorch leverages BLAS for GEMMs, we first need the concept of **memory layout** for tensors and how BLAS libraries interpret the data buffers they receive.

### Memory Layout and Stride

Although tensors are presented as multi-dimensional arrays, their elements are stored in linear memory. For a 2D matrix, the two primary storage conventions are:

- **Row-major**: Elements of the same row are stored consecutively in memory. PyTorch uses this as its default layout.
- **Column-major**: Elements of the same column are stored consecutively in memory. Traditional BLAS libraries expect, and are optimized for, this layout.

PyTorch's `.stride()` method exposes a tensor's memory arrangement. It returns a tuple where each value is the distance, in **elements** (not bytes), to step in linear memory to move to the next index along that dimension.
- For a 2D tensor `T[i][j]` in **row-major** layout, `.stride()` is typically `(num_cols, 1)`: moving to `T[i][j+1]` steps 1 element, while moving to `T[i+1][j]` steps `num_cols` elements.
- For a 2D tensor `T[i][j]` in **column-major** layout, `.stride()` is typically `(1, num_rows)`: moving to `T[i+1][j]` steps 1 element, while moving to `T[i][j+1]` steps `num_rows` elements.
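
This is easy to see in practice. NumPy (also row-major by default) reports strides in bytes, so dividing by the item size gives the element strides that PyTorch's `.stride()` would report:

```python
import numpy as np

T = np.arange(12, dtype=np.float32).reshape(3, 4)  # row-major by default

def element_strides(a):
    # NumPy strides are in bytes; convert to elements for
    # comparison with torch.Tensor.stride()
    return tuple(s // a.itemsize for s in a.strides)

assert element_strides(T) == (4, 1)    # row-major: (num_cols, 1)
assert element_strides(T.T) == (1, 4)  # transpose view is column-major: (1, num_rows)
```
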

---

### BLAS Transpose and Row-Major Output

The core BLAS GEMM routine computes $C = \alpha \cdot op(A) \cdot op(B) + \beta \cdot C$, where $op(X)$ is either $X$ or $X^T$ depending on the `transA` and `transB` flags (`'N'` for No Transpose, `'T'` for Transpose) passed to the function. BLAS expects input matrices flagged `'N'` to be in column-major layout. Crucially, the resulting matrix $C$ is written into the output buffer in **column-major** format.

PyTorch, however, uses row-major layout internally and wants the result of a GEMM to also be row-major, *without* an extra copy or transpose step outside the BLAS call. PyTorch achieves this by cleverly leveraging the `trans` flags and the relationship between row-major and column-major layouts.

A matrix $M$ stored in row-major memory has exactly the same element ordering as $M^T$ stored in column-major memory. PyTorch exploits this identity: to get a row-major result $C$ from a BLAS call that outputs column-major, it asks BLAS to compute $C^T$ and write it in column-major. Since $C^T$ in column-major is $C$ in row-major, the output buffer ends up containing the desired row-major $C$.

Mathematically, the operation $C = A @ B$ (with $A, B, C$ desired in row-major) is equivalent to computing $C^T = (A @ B)^T = B^T @ A^T$. PyTorch therefore configures the BLAS call to compute $B^T @ A^T$ from the row-major data of $B$ and $A$.

Here's how the `transA` and `transB` flags work when passing **row-major data** to BLAS:
- Passing row-major data for matrix $M$ with `trans = 'T'` makes BLAS mathematically treat the data as $M$ (BLAS reads the buffer as column-major, i.e. as $M^T$, and then transposes it).
- Passing row-major data for matrix $M$ with `trans = 'N'` makes BLAS mathematically treat the data as $M^T$ (BLAS reads the buffer as column-major and uses it as-is).

So, to compute $C^T = B^T @ A^T$ from row-major data for $B$ and $A$, and get $C$ row-major in the output buffer:
- Pass $B$'s row-major data as the first operand (`A_data` in the BLAS call). To make BLAS see $B^T$, use `transA = 'N'`.
- Pass $A$'s row-major data as the second operand (`B_data` in the BLAS call). To make BLAS see $A^T$, use `transB = 'N'`.
- The BLAS call becomes `gemm(transA='N', transB='N', ..., B_data, ..., A_data, ...)`. This computes $B^T @ A^T = C^T$, and the result is written in column-major into the output buffer, which is precisely the desired $C$ in row-major.

This standard trick, swapping the operands and using `transA='N'` and `transB='N'` on row-major inputs, is how PyTorch obtains a row-major result for a general matrix multiply `C = A @ B` with row-major `A` and `B`.
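
The layout identity behind this trick can be demonstrated with NumPy, which like PyTorch defaults to row-major:

```python
import numpy as np

A = np.random.rand(3, 4)
B = np.random.rand(4, 5)
C = A @ B                    # desired result, row-major

# Identity 1: C in row-major has the same linear memory as Cᵀ in column-major
assert np.array_equal(C.ravel(order='C'), C.T.ravel(order='F'))

# Identity 2: Bᵀ @ Aᵀ yields Cᵀ, so asking BLAS for Bᵀ @ Aᵀ written
# column-major fills the output buffer exactly as row-major C
C_T = B.T @ A.T
assert np.allclose(C_T.T, C)
```
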

---

### Linear Layer: `Y = X @ Wᵀ`

For a linear layer computation `Y = X @ Wᵀ`, where `X` (`[M, K]`) and `W` (`[N, K]`) are in row-major layout, PyTorch wants `Y` (`[M, N]`) in row-major as well. To achieve this with a BLAS routine that outputs column-major, PyTorch configures BLAS to compute $Y^T = W @ X^T$.

This is a BLAS call computing $op(A) @ op(B)$ where $op(A)$ is $W$ and $op(B)$ is $X^T$. Using the rule that row-major data with `trans='T'` yields the matrix ($M$) and `trans='N'` yields its transpose ($M^T$):
- BLAS operand A uses $W$'s row-major data. To see $W$, `transA = 'T'`.
- BLAS operand B uses $X$'s row-major data. To see $X^T$, `transB = 'N'`.

The BLAS call uses `(transA='T', transB='N')` with $W$'s data as the first operand and $X$'s data as the second. It computes $W @ X^T = Y^T$, writing the result in column-major, which PyTorch interprets as the desired row-major $Y$.

### Backward Pass Operations

The backward pass similarly uses GEMMs configured to produce row-major gradients:

- **`dX = dY @ W`**: With `dY` (`[M, K]`) and `W` (`[K, N]`) row-major, we need `dX` (`[M, N]`) row-major. BLAS computes $dX^T = W^T @ dY^T$.
  - BLAS operand A uses $W$'s row-major data. Needs $W^T \implies$ `transA = 'N'`.
  - BLAS operand B uses $dY$'s row-major data. Needs $dY^T \implies$ `transB = 'N'`.
  - The BLAS call uses `(transA='N', transB='N')` on $W$'s and $dY$'s data, computing $W^T @ dY^T$.

- **`dW = dYᵀ @ X`**: With `dY` (`[K, N]`) and `X` (`[K, M]`) row-major, we need `dW` (`[N, M]`) row-major. BLAS computes $dW^T = X^T @ dY$.
  - BLAS operand A uses $X$'s row-major data. Needs $X^T \implies$ `transA = 'N'`.
  - BLAS operand B uses $dY$'s row-major data. Needs $dY \implies$ `transB = 'T'`.
  - The BLAS call uses `(transA='N', transB='T')` on $X$'s and $dY$'s data, computing $X^T @ dY$.

In summary, for PyTorch's row-major operations:
- Forward pass `Y = X @ Wᵀ` maps to BLAS computing $W @ X^T$ using `(T, N)` flags on the row-major data of $W$ and $X$.
- Backward pass `dX = dY @ W` maps to BLAS computing $W^T @ dY^T$ using `(N, N)` flags on the row-major data of $W$ and $dY$.
- Backward pass `dW = dYᵀ @ X` maps to BLAS computing $X^T @ dY$ using `(N, T)` flags on the row-major data of $X$ and $dY$.
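
Each of these mappings can be checked numerically. A NumPy sketch with small, made-up shapes:

```python
import numpy as np

M, N, K = 3, 5, 4          # illustrative sizes
X  = np.random.rand(M, K)  # input
W  = np.random.rand(N, K)  # weight
dY = np.random.rand(M, N)  # upstream gradient

# Forward: Y = X @ Wᵀ, obtained by BLAS computing (W @ Xᵀ) and
# writing it column-major, i.e. its transpose row-major
assert np.allclose((W @ X.T).T, X @ W.T)

# Backward dX = dY @ W, obtained via (Wᵀ @ dYᵀ)
assert np.allclose((W.T @ dY.T).T, dY @ W)

# Backward dW = dYᵀ @ X, obtained via (Xᵀ @ dY)
assert np.allclose((X.T @ dY).T, dY.T @ X)
```
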

Let's revisit the GPT-3-XL GEMM table from **Tracelens**:

| name        | param: M | param: N | param: K | param: bias | param: stride_A | param: stride_B | param: transpose |
|-------------|----------|----------|----------|-------------|-----------------|-----------------|------------------|
| aten::addmm | 10240    | 8192     | 2048     | TRUE        | (2048, 1)       | (1, 2048)       | (True, False)    |
| aten::mm    | 10240    | 2048     | 8192     | FALSE       | (8192, 1)       | (2048, 1)       | (False, False)   |
| aten::mm    | 8192     | 2048     | 10240    | FALSE       | (1, 8192)       | (2048, 1)       | (False, True)    |

This table shows how `aten::addmm` (forward) and `aten::mm` (backward) calls map to underlying GEMM operations. The `param: M, N, K` values are likely the dimensions of the *PyTorch operation result* (`M × N`, with inner dimension `K`). The `param: transpose` column gives the `(transA, transB)` flags the wrapper passes to the BLAS call for its operands.

Let's interpret the trace entries, recalling that PyTorch hands row-major data to BLAS together with flags that make BLAS compute the transpose of the desired result:

1. **`aten::addmm` (Forward)**: Corresponds to `Y = X @ Wᵀ`. Trace flags: `(True, False)`. This matches the `(T, N)` needed for BLAS to compute $W @ X^T$.
2. **First `aten::mm` (Backward dX)**: Corresponds to `dX = dY @ W`. Trace flags: `(False, False)`. This matches the `(N, N)` needed for BLAS to compute $W^T @ dY^T$.
3. **Second `aten::mm` (Backward dW)**: Corresponds to `dW = dYᵀ @ X`. Trace flags: `(False, True)`. This matches the `(N, T)` needed for BLAS to compute $X^T @ dY$.

This confirms how the trace flags correspond to the BLAS transpose configurations used with row-major input data to achieve row-major output via the $C^T = B^T A^T$ trick.

Let's summarize our understanding of the transpose flags by writing pseudo-code for the logic:

*(Note: this Python snippet is a simplified view; PyTorch's actual implementation is more intricate, accounting for the specific GEMM variant and output requirements.)*

```python
def is_col_major(T):
    # Column-major: stride (1, num_rows); the >= allows a leading
    # dimension larger than the row count (padded storage).
    return T.stride(0) == 1 and T.stride(1) >= T.shape[0]

def get_blas_transpose_flags(A, B):
    # A column-major operand can be consumed by BLAS as-is ('N');
    # a row-major operand is presented to BLAS as its transpose ('T').
    transA = 'N' if is_col_major(A) else 'T'
    transB = 'N' if is_col_major(B) else 'T'
    return transA, transB
```

## ⚠️ Edge Cases

- A common assumption is that flattening a tensor shape like `[B, L, d_model]` to `[B·L, d_model]` is a cost-free metadata operation. This is only true if the tensor's memory layout allows the batch and sequence dimensions to be merged without moving data (e.g., the tensor is contiguous).

- If that is not the case, PyTorch may be forced to insert a **copy** or **transpose** operation to create a physically contiguous tensor that BLAS can work with efficiently.

- Furthermore, even if the tensor layout *could* theoretically be consumed by BLAS (e.g., certain strided patterns), some highly tuned BLAS libraries may lack kernels optimized for those layouts. In such cases, PyTorch or the BLAS wrapper inserts a **copy** or **transpose buffer** behind the scenes. Consequently, what the BLAS routine actually operates on may not be the original tensor, but a **temporary buffer** created for compatibility or performance.
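
NumPy shows the same effect: reshaping a contiguous array is a zero-copy metadata change, while reshaping a transposed (non-contiguous) array silently materializes a copy. (PyTorch's `view` would instead raise an error, forcing an explicit `.contiguous()` or `.reshape()`.)

```python
import numpy as np

x = np.random.rand(2, 3, 4)     # contiguous, row-major

flat = x.reshape(2 * 3, 4)      # free: just new shape metadata
assert np.shares_memory(x, flat)

y = x.transpose(1, 0, 2)        # non-contiguous view, shape (3, 2, 4)
flat_y = y.reshape(3 * 2, 4)    # forces a copy of the data
assert not np.shares_memory(y, flat_y)
```
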