## 📌 GEMMs in AI Workloads

General Matrix Multiplications (GEMMs) are the **primary compute primitive** used in AI models. Efficient implementations of GEMMs are readily available through vendor-tuned libraries such as cuBLAS and hipBLAS. Therefore, whenever possible, the goal is to **reduce computation to a matrix multiply**.

This document explains how **model-level parameters** like batch size, sequence length, and hidden dimension translate into GEMM shapes, and subsequently how these shapes map to specific **BLAS kernel calls**.

By the end of this post, you should be equipped to understand the GEMM shapes, counts, and BLAS calls involved in **any AI model you encounter**.

---

## From Model Dimensions to GEMM Shapes

### Linear Layers in LLMs

Let's begin by understanding how linear layers, such as the MLP **up projection**, correspond to GEMM calls.

The input tensor shape for this operation is:
```
X: [B, L, d_model]
```
Here:
- `B` represents the batch size.
- `L` denotes the sequence length.
- `d_model` is the input (or hidden) dimension.

This operation outputs a tensor with the shape:
```
Y: [B, L, d_ff]
```
The projection for each token can be expressed individually as:
```
Y[b, l, :] = X[b, l, :] @ Wᵀ
```
Where:
- `W` is the weight matrix with a shape of `[d_ff, d_model]`.

### Flattening for GEMM

To express this entire operation as a single GEMM, we flatten the batch and sequence dimensions of the input tensor:
```
X_flat: [B·L, d_model]
Wᵀ: [d_model, d_ff]
Y = X_flat @ Wᵀ
```
This flattening yields the following GEMM shape parameters:
- `M = B·L`
- `N = d_ff`
- `K = d_model`

Here, `K` is the **inner (shared) dimension** between the two matrices being multiplied.
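To make the flattening concrete, here is a small NumPy sketch (toy dimensions; the variable names and sizes are illustrative only, not any framework's actual implementation):

```python
import numpy as np

# Toy sizes; real models use much larger dimensions.
B, L, d_model, d_ff = 2, 8, 16, 64
X = np.random.rand(B, L, d_model).astype(np.float32)
W = np.random.rand(d_ff, d_model).astype(np.float32)

X_flat = X.reshape(B * L, d_model)  # [B·L, d_model]
Y_flat = X_flat @ W.T               # one GEMM with M=B·L, N=d_ff, K=d_model
Y = Y_flat.reshape(B, L, d_ff)      # back to [B, L, d_ff]
print(Y.shape)  # (2, 8, 64)
```

The per-token loop and the single flattened GEMM compute exactly the same values; the flattened form simply hands the BLAS library one large matrix multiply instead of many small ones.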

---

## Prefill vs Decode

Let's now contrast the GEMM behavior during the **prefill** and **decode** phases of inference, focusing on how the sequence length (`L`) changes and affects the GEMM shapes.

In both phases, the input tensor for an MLP computation within an LLM initially has a shape like:
```
X: [B, L, d_model]
```
Where `B` is the batch size, `L` is the sequence length, and `d_model` is the hidden size. This projects to an output shape of `[B, L, d_ff]`, where `d_ff` is the MLP's expansion size.

As established earlier, to process this with a single GEMM, the batch and sequence dimensions of the input are flattened. The input effectively becomes `[B·L, d_model]` for the GEMM `X_flat @ Wᵀ`.

The key difference between prefill and decode lies in the value of `L`:
- **Prefill Phase**: `L` is the actual input sequence length (which can be large).
- **Decode Phase**: `L` is `1`, as the model processes one new token per sequence to generate the next.

This difference in `L` directly impacts the `M` parameter of the GEMM `(M, N, K)`:

### GEMM Shape Summary

| Mode | Input Shape (Conceptual) | Flattened Input Shape | GEMM Shape `(M, N, K)` | Notes |
|----------|--------------------------|-----------------------|-------------------------------------|---------------------------------------|
| Prefill | `[B, L, d_model]` | `[B·L, d_model]` | `(B·L, d_ff, d_model)` | `L` is the prompt length |
| Decode | `[B, 1, d_model]` | `[B·1, d_model]` | `(B, d_ff, d_model)` | `L=1` for generating one token at a time |

Notice that in the decode phase, because `L=1`, the `M` parameter of the MLP GEMM becomes simply `B`. This means the computational cost of the MLP layers in decode remains constant per token regardless of the total sequence length generated so far. The dominant **O(L)** scaling cost during decode comes from the attention mechanism, not the MLPs.
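The table above can be reduced to a one-line rule. The helper below is hypothetical (not a library API), just encoding the `(M, N, K)` mapping for the two phases:

```python
def mlp_gemm_shape(B, L, d_model, d_ff):
    """(M, N, K) of the flattened MLP projection X_flat @ Wᵀ."""
    return (B * L, d_ff, d_model)

# Prefill: the whole prompt flows through at once.
prefill = mlp_gemm_shape(B=4, L=512, d_model=4096, d_ff=11008)
# Decode: one new token per sequence, so L=1 and M collapses to B.
decode = mlp_gemm_shape(B=4, L=1, d_model=4096, d_ff=11008)

print(prefill)  # (2048, 11008, 4096)
print(decode)   # (4, 11008, 4096)
```

Only `M` changes between the phases; `N` and `K` are fixed by the model architecture.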

### Real-world Example: LLaMA-2 7B

The table below shows actual data from **Tracelens** profiling, filtered specifically for MLP **up** and **gate** projection GEMMs in a LLaMA-2 7B model inference trace.

For this trace:
- `d_model = 4096`
- `d_ff = 11008`
- Batch size: `1` (`B=1`)
- Input length (for prefill): `597` (`L=597`)
- The trace included 36 decode steps.

| name | param: M | param: N | param: K | param: bias | counts |
|----------|-----------|-----------|-----------|---------------|--------|
| aten::mm | 1 | 11008 | 4096 | FALSE | 2304 |
| aten::mm | 597 | 11008 | 4096 | FALSE | 64 |

Interpreting these entries based on the `M` parameter:
- The entry with `param: M = 597` corresponds to the **prefill** phase GEMM (`B·L = 1·597`), which happens once per layer at the beginning of inference. Since there are 32 layers, each with an up and a gate projection, this GEMM is called `64` times (32 up + 32 gate).
- The entry with `param: M = 1` corresponds to the **decode** phase GEMM (`B = 1`, `L = 1`). These occur at each decode step for every layer. With 36 decode steps and 64 GEMMs per step (32 layers × 2 projections), this GEMM is called `36 × 64 = 2304` times.
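The call counts from the trace follow directly from the model structure; the arithmetic below just restates the reasoning above:

```python
layers = 32        # LLaMA-2 7B decoder layers
projections = 2    # up and gate projection per MLP
decode_steps = 36  # number of generated tokens in this trace

prefill_calls = layers * projections                # one prefill pass
decode_calls = decode_steps * layers * projections  # every step, every layer

print(prefill_calls, decode_calls)  # 64 2304
```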

---

## 🔁 Backward Pass GEMMs

Next, let's explore the backward pass during training. A forward pass GEMM operation like `Y = X @ Wᵀ + b` necessitates **two corresponding backward GEMMs** to compute gradients (the bias gradient is a simple reduction, not a GEMM):

```python
dX = dY @ W         # gradient w.r.t. the input  → shape [B·L, d_model]
dW = dYᵀ @ X        # gradient w.r.t. the weight → shape [d_ff, d_model]
db = dY.sum(dim=0)  # gradient w.r.t. the bias   → shape [d_ff]
```

### GEMM Shapes

| Operation | GEMM Shape `(M, N, K)` | Description |
|------------------|---------------------------------------------------|--------------------------------------|
| Forward | `(B·L, d_ff, d_model)` | `X @ Wᵀ` |
| Backward dX | `(B·L, d_model, d_ff)` | `dY @ W` (result of `[B·L, d_ff] @ [d_ff, d_model]`) |
| Backward dW | `(d_ff, d_model, B·L)` | `dYᵀ @ X` (result of `[d_ff, B·L] @ [B·L, d_model]`) |

Let's look closer at the backward GEMM shapes:
- For `dX = dY @ W`, the operation is `[B·L, d_ff] @ [d_ff, d_model]`, which results in a shape of `[B·L, d_model]`.
- For `dW = dYᵀ @ X`, the operation is `[d_ff, B·L] @ [B·L, d_model]`, yielding a shape of `[d_ff, d_model]`. The shared dimension `K` here is `B·L`, so the GEMM shape is `(d_ff, d_model, B·L)`.
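These three shapes are easy to verify numerically. The NumPy sketch below uses toy stand-in sizes for `B·L`, `d_model`, and `d_ff`:

```python
import numpy as np

BL, d_model, d_ff = 6, 8, 32       # toy stand-ins for B·L, d_model, d_ff
X  = np.random.rand(BL, d_model)
W  = np.random.rand(d_ff, d_model)
dY = np.random.rand(BL, d_ff)      # upstream gradient, same shape as Y

dX = dY @ W          # (BL, d_ff) @ (d_ff, d_model) → GEMM (BL, d_model, d_ff)
dW = dY.T @ X        # (d_ff, BL) @ (BL, d_model)   → GEMM (d_ff, d_model, BL)
db = dY.sum(axis=0)  # reduction, not a GEMM

print(dX.shape, dW.shape, db.shape)  # (6, 8) (32, 8) (32,)
```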

### Real-world Example: GPT-3-XL

This table presents **Tracelens** data from a single training step of **GPT-3-XL**.

For this example:
- `d_model = 2048`
- `d_ff = 8192`
- Batch size: `5`
- Sequence length: `2048`
- Thus, the flattened dimension is `M = 5 × 2048 = 10240`.

| name | param: M | param: N | param: K | count |
|-------------|-----------|-----------|-----------|--------|
| aten::addmm | 10240 | 8192 | 2048 | 24 |
| aten::mm | 10240 | 2048 | 8192 | 24 |
| aten::mm | 8192 | 2048 | 10240 | 24 |

We can interpret each entry based on the GEMM shapes:
- The `aten::addmm` call represents the forward pass GEMM (`X @ Wᵀ`).
- The first `aten::mm` call corresponds to the backward pass for `dX` (`dY @ W`).
- The second `aten::mm` call represents the backward pass for `dW` (`dYᵀ @ X`).

Each of these operations appears once per layer in the network. Given that GPT-3-XL has 24 layers, each of these GEMMs is called 24 times per training step, matching the `count` column in the table.
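Plugging the GPT-3-XL dimensions into the shape formulas reproduces all three trace rows:

```python
B, L, d_model, d_ff = 5, 2048, 2048, 8192
M = B * L  # 10240, the flattened dimension

forward = (M, d_ff, d_model)   # aten::addmm row
bwd_dX  = (M, d_model, d_ff)   # first aten::mm row
bwd_dW  = (d_ff, d_model, M)   # second aten::mm row

print(forward, bwd_dX, bwd_dW)
# (10240, 8192, 2048) (10240, 2048, 8192) (8192, 2048, 10240)
```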

---

## ⚙️ How PyTorch Calls BLAS

To fully grasp how PyTorch leverages BLAS for operations like GEMM, we must first understand the fundamental concept of **memory layout** for tensors and how BLAS libraries interpret the data buffers they receive.

### Memory Layout and Stride

Although tensors are presented as multi-dimensional arrays, their elements are stored in linear memory. For a 2D matrix, the two primary storage conventions are:

- **Row-major**: Elements of the same row are stored consecutively in memory. PyTorch adopts this as its default layout.
- **Column-major**: Elements of the same column are stored consecutively in memory. Many traditional BLAS libraries primarily optimize for this layout.

PyTorch's `.stride()` method provides insight into a tensor's memory arrangement. It returns a tuple where each value indicates the number of *elements* (not bytes) to skip in linear memory to move to the next index along that dimension.
- For a 2D tensor `T[i][j]` in **row-major** layout, `.stride()` is typically `(num_cols, 1)`. Moving to `T[i][j+1]` requires stepping 1 element, while moving to `T[i+1][j]` requires stepping `num_cols` elements.
- For a 2D tensor `T[i][j]` in **column-major** layout, `.stride()` is typically `(1, num_rows)`. Moving to `T[i+1][j]` requires stepping 1 element, while moving to `T[i][j+1]` requires stepping `num_rows` elements.
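A two-line PyTorch session makes the stride convention tangible (note that `.t()` returns a transposed *view*, which is why its strides come out column-major):

```python
import torch

T = torch.arange(12.0).reshape(3, 4)  # 3×4, row-major by default
print(T.stride())      # (4, 1): skip 4 elements per row step, 1 per column step
print(T.t().stride())  # (1, 4): the transposed view is column-major
```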

---

### BLAS Transpose and Row-Major Output

The core BLAS GEMM routine computes $C = \alpha \cdot op(A) \cdot op(B) + \beta \cdot C$, where $op(X)$ is either $X$ or $X^T$ depending on the `transA` and `transB` flags ('N' for No Transpose, 'T' for Transpose) passed to the function. BLAS expects an input matrix flagged 'N' to be in column-major layout. Crucially, the resulting matrix $C$ is written into the output buffer in **column-major** format.

PyTorch, however, uses row-major layout internally and wants the result of a GEMM operation to also be in row-major layout *without* an extra copy or transpose step outside of the BLAS call. PyTorch achieves this by cleverly leveraging the `trans` flags and the relationship between row-major and column-major layouts.

A matrix $M$ stored in row-major memory has the exact same element ordering as the matrix $M^T$ stored in column-major memory. PyTorch uses this identity. To get a row-major result $C$ from a BLAS call that outputs in column-major, PyTorch asks BLAS to compute $C^T$ and write it in column-major. Since $C^T$ in column-major is $C$ in row-major, the output buffer ends up containing the desired row-major $C$.

Mathematically, the operation $C = A @ B$ (where $A, B, C$ are desired in row-major) is equivalent to computing $C^T = (A @ B)^T = B^T @ A^T$. PyTorch therefore configures the BLAS call to compute $B^T @ A^T$ using the row-major data of $B$ and $A$.

Here's how the `transA` and `transB` flags work when passing **row-major data** to BLAS:
- Passing row-major data for matrix $M$ with `trans = 'T'` makes BLAS mathematically see $M$: BLAS first reads the buffer as column-major (seeing $M^T$), then applies the transpose, yielding $(M^T)^T = M$.
- Passing row-major data for matrix $M$ with `trans = 'N'` makes BLAS mathematically see $M^T$: BLAS reads the buffer as column-major and uses it as-is.

So, to compute $C^T = B^T @ A^T$ using row-major data for $B$ and $A$ and get $C$ row-major in the output buffer:
- Pass $B$'s row-major data as the first operand (`A_data` in the BLAS call). To make BLAS see $B^T$, use `transA = 'N'`.
- Pass $A$'s row-major data as the second operand (`B_data` in the BLAS call). To make BLAS see $A^T$, use `transB = 'N'`.
- The BLAS call becomes `gemm(transA='N', transB='N', ..., B_data, ..., A_data, ...)`. This computes $B^T @ A^T = C^T$. The result $C^T$ is written in column-major into the output buffer, which is precisely the desired $C$ in row-major.

This standard trick, swapping the operands and passing `transA='N'`, `transB='N'` on row-major inputs, is how PyTorch obtains a row-major result for a general matrix multiply `C = A @ B` where `A` and `B` are row-major.
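The buffer identity behind the trick can be verified directly with NumPy, which lets us choose the storage order explicitly:

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)   # row-major
B = np.arange(12.0).reshape(3, 4)  # row-major
C = A @ B                          # the result we want, row-major

# Compute Cᵀ = Bᵀ @ Aᵀ and store it column-major (Fortran order),
# which is what a BLAS routine would write into the output buffer:
Ct = np.asfortranarray(B.T @ A.T)

# The raw buffer of column-major Cᵀ is byte-for-byte the row-major C:
assert Ct.tobytes(order='A') == C.tobytes(order='C')
```

In other words, no data movement is needed after the BLAS call: the output buffer already *is* the row-major result.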

---

### Linear Layer: `Y = X @ Wᵀ`

For a linear layer computation `Y = X @ Wᵀ`, where `X` (`[M, K]`) and `W` (`[N, K]`) are in row-major layout, PyTorch wants `Y` (`[M, N]`) also in row-major. To achieve this with a BLAS routine outputting column-major, PyTorch configures BLAS to compute $Y^T = W @ X^T$.

This involves a BLAS call computing $op(A) @ op(B)$ where $op(A)$ is $W$ and $op(B)$ is $X^T$. Using the rule that row-major data with `trans='T'` yields the matrix ($M$) and `trans='N'` yields its transpose ($M^T$):
- BLAS operand A uses $W$'s row-major data. To see $W$: `transA = 'T'`.
- BLAS operand B uses $X$'s row-major data. To see $X^T$: `transB = 'N'`.

The BLAS call uses `(transA='T', transB='N')` with $W$'s data as the first operand and $X$'s data as the second. It computes $W @ X^T = Y^T$, writing the result in column-major, which PyTorch interprets as the desired row-major $Y$.
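The stride patterns from the GPT-3-XL trace can be reproduced in a few lines. We use the real `d_model`/`d_ff` but a tiny row count to keep the multiply cheap (the tensor names are illustrative):

```python
import torch

d_model, d_ff, rows = 2048, 8192, 4   # tiny M instead of 10240, to keep this cheap
X = torch.randn(rows, d_model)
W = torch.randn(d_ff, d_model)        # nn.Linear stores its weight as [d_ff, d_model]

print(X.stride())      # (2048, 1): row-major, like the trace's stride_A
print(W.t().stride())  # (1, 2048): column-major view, like the trace's stride_B
Y = X @ W.t()          # the linear-layer GEMM Y = X @ Wᵀ
print(Y.shape)         # torch.Size([4, 8192])
```

The `(1, 2048)` stride of `W.t()` is exactly why the trace shows a column-major `stride_B`: PyTorch never materializes the transpose, it just hands BLAS the view with the appropriate flag.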

### Backward Pass Operations

The backward pass similarly uses GEMMs configured to produce row-major gradients:

- **`dX = dY @ W`**: With `dY` (`[M, K]`) and `W` (`[K, N]`) row-major, we need `dX` (`[M, N]`) row-major. BLAS computes $dX^T = W^T @ dY^T$.
  - BLAS operand A uses $W$'s row-major data. Needs $W^T \implies$ `transA = 'N'`.
  - BLAS operand B uses $dY$'s row-major data. Needs $dY^T \implies$ `transB = 'N'`.
  - BLAS call uses `(transA='N', transB='N')` on $W$'s and $dY$'s data, computing $W^T @ dY^T$.

- **`dW = dYᵀ @ X`**: With `dY` (`[K, N]`) and `X` (`[K, M]`) row-major, we need `dW` (`[N, M]`) row-major. BLAS computes $dW^T = X^T @ dY$.
  - BLAS operand A uses $X$'s row-major data. Needs $X^T \implies$ `transA = 'N'`.
  - BLAS operand B uses $dY$'s row-major data. Needs $dY \implies$ `transB = 'T'`.
  - BLAS call uses `(transA='N', transB='T')` on $X$'s and $dY$'s data, computing $X^T @ dY$.

In summary, for PyTorch's row-major operations:
- Forward pass `Y = X @ Wᵀ` maps to BLAS calculating $W @ X^T$ using `(T, N)` flags on the row-major data of $W$ and $X$.
- Backward pass `dX = dY @ W` maps to BLAS calculating $W^T @ dY^T$ using `(N, N)` flags on the row-major data of $W$ and $dY$.
- Backward pass `dW = dYᵀ @ X` maps to BLAS calculating $X^T @ dY$ using `(N, T)` flags on the row-major data of $X$ and $dY$.

Let's revisit the GPT-3-XL GEMM table from **Tracelens**:

| name | param: M | param: N | param: K | param: bias | param: stride_A | param: stride_B | param: transpose |
|--------------|-----------|-----------|-----------|--------------|------------------|------------------|--------------------|
| aten::addmm | 10240 | 8192 | 2048 | TRUE | (2048, 1) | (1, 2048) | (True, False) |
| aten::mm | 10240 | 2048 | 8192 | FALSE | (8192, 1) | (2048, 1) | (False, False) |
| aten::mm | 8192 | 2048 | 10240 | FALSE | (1, 8192) | (2048, 1) | (False, True) |

This table shows how the `aten::addmm` (forward) and `aten::mm` (backward) calls map to underlying GEMM operations. The `param: M, N, K` values are likely the dimensions of the *PyTorch operation result* (`M × N` with inner dimension `K`). The `param: transpose` column records the `(transA, transB)` flags the wrapper passes to the BLAS call for its operands.

Let's interpret the trace entries based on our understanding that PyTorch uses row-major data and BLAS receives flags to mathematically interpret this data for computing the transpose of the desired result:

1. **`aten::addmm` (Forward)**: Corresponds to `Y = X @ Wᵀ`. Trace flags: `(True, False)`. This matches the `(T, N)` needed for BLAS to compute $W @ X^T$.
2. **First `aten::mm` (Backward dX)**: Corresponds to `dX = dY @ W`. Trace flags: `(False, False)`. This matches the `(N, N)` needed for BLAS to compute $W^T @ dY^T$.
3. **Second `aten::mm` (Backward dW)**: Corresponds to `dW = dYᵀ @ X`. Trace flags: `(False, True)`. This matches the `(N, T)` needed for BLAS to compute $X^T @ dY$.

This confirms how the trace flags correspond to the BLAS transpose configurations used with row-major input data to achieve row-major output via the $C^T = B^T A^T$ trick.

Let's summarize our understanding of the transpose flags by writing pseudocode for the logic:

*(Note: this Python snippet is a simplified view; PyTorch's actual implementation is more intricate, accounting for the specific GEMM variant and output requirements. Here `A` and `B` are the two operand views as BLAS should mathematically see them.)*

```python
def is_col_major(T):
    # Column-major: consecutive elements of a column are adjacent in memory.
    return T.stride(0) == 1 and T.stride(1) >= T.shape[0]

def get_blas_transpose_flags(A, B):
    # A column-major operand is what BLAS expects natively → 'N'.
    # A row-major operand is seen by BLAS as its transpose → 'T' to undo it.
    transA = 'N' if is_col_major(A) else 'T'
    transB = 'N' if is_col_major(B) else 'T'
    return transA, transB
```
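Applying this logic to the forward GEMM of the GPT-3-XL trace, with a lightweight stand-in object (the `View` class is purely illustrative) so it runs without PyTorch:

```python
class View:
    """Minimal stand-in exposing shape/stride like a 2-D tensor view."""
    def __init__(self, shape, strides):
        self.shape, self._strides = shape, strides
    def stride(self, dim):
        return self._strides[dim]

def is_col_major(T):
    return T.stride(0) == 1 and T.stride(1) >= T.shape[0]

def blas_flag(T):
    return 'N' if is_col_major(T) else 'T'

# Forward Y = X @ Wᵀ: BLAS computes Yᵀ = W @ Xᵀ.
W  = View((8192, 2048), (2048, 1))   # weight, row-major view        → 'T'
Xt = View((2048, 10240), (1, 2048))  # X seen transposed, col-major  → 'N'
print(blas_flag(W), blas_flag(Xt))   # T N, matching the trace's (True, False)
```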

## ⚠️ Edge Cases

- One common assumption is that flattening a tensor shape like `[B, L, d_model]` to `[B·L, d_model]` is a cost-free metadata operation. This is only true when the tensor's layout allows the batch and sequence dimensions to be merged without moving data; in practice, that means the tensor is contiguous.

- If the layout does not allow this, PyTorch may be forced to insert a **copy** or **transpose** operation to create a physically contiguous tensor that BLAS can work with efficiently.

- Furthermore, even if the tensor layout *could* theoretically be used by BLAS (e.g., certain striding patterns), some highly tuned BLAS libraries might lack kernels optimized for those specific layouts. In such instances, a **copy** or **transpose buffer** is inserted behind the scenes by PyTorch or the BLAS wrapper. Consequently, what the BLAS routine actually operates on might not be the original tensor directly, but rather a **temporary buffer** created for compatibility or performance.
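A small PyTorch experiment illustrates the first point: even when the last dimension is contiguous, a transposed view cannot be flattened for free:

```python
import torch

x = torch.randn(2, 4, 8).transpose(0, 1)  # [4, 2, 8]; last dim is contiguous
print(x.is_contiguous())                  # False: dims 0 and 1 cannot merge freely

try:
    x.view(-1, 8)                         # a metadata-only flatten is impossible here
except RuntimeError:
    print("view failed: flatten is not free for this layout")

y = x.reshape(-1, 8)                      # reshape silently falls back to a copy
print(y.is_contiguous())                  # True
```

This is exactly the hidden copy described above: `reshape` succeeds, but the GEMM downstream operates on a freshly materialized buffer, not on the original storage.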