Skip to content

Commit cdc0290

Browse files
feat(turbo): Primus-Torchtitan supports the Primus-Turbo backend. (#118)
Co-authored-by: Xiaoming-AMD <xiaoming.peng@amd.com>
1 parent 9dfb128 commit cdc0290

File tree

19 files changed

+331
-41
lines changed

19 files changed

+331
-41
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,4 @@ jobs:
137137
rm -rf ${PRIMUS_WORKDIR}/Primus-Turbo
138138
- name: Clean Primus
139139
run: |
140-
rm -rf ${PRIMUS_WORKDIR}/Primus
140+
rm -rf ${PRIMUS_WORKDIR}/Primus

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
- id: check-added-large-files
1515
- id: check-merge-conflict
1616
- repo: https://github.com/pycqa/isort
17-
rev: 5.11.5
17+
rev: 5.13.2
1818
hooks:
1919
- id: isort
2020
args: ["--profile", "black"]

examples/torchtitan/configs/llama3.1_8B-pretrain.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,9 @@ modules:
1414
sink_level: null
1515
file_sink_level: DEBUG
1616
stderr_sink_level: INFO
17+
18+
# model:
19+
# converters: ["mx"]
20+
primus_turbo:
21+
enable_primus_turbo: false
22+
enable_attention_float8: false

primus/backends/torchtitan/__init__.py

Whitespace-only changes.

primus/backends/torchtitan/components/__init__.py

Whitespace-only changes.

primus/backends/torchtitan/components/quantization/__init__.py

Whitespace-only changes.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
import torch
8+
import torch.nn as nn
9+
from primus_turbo.pytorch.core.float8 import MXQuantConfig
10+
from primus_turbo.pytorch.modules import MXLinear
11+
from torchtitan.config_manager import JobConfig
12+
from torchtitan.distributed import ParallelDims
13+
from torchtitan.protocols.model_converter import (
14+
ModelConverter,
15+
register_model_converter,
16+
)
17+
from torchtitan.tools.logging import logger
18+
19+
20+
def replace_turbo_mxlinear_modules(model: nn.Module, config: MXQuantConfig):
    """Recursively swap every plain ``nn.Linear`` under *model* for ``MXLinear``.

    Children that are already ``MXLinear`` are not re-wrapped (they fall into
    the recursive branch instead); every other container is descended into
    depth-first and rewritten in place.

    Args:
        model: Root module whose direct and nested children are mutated.
        config: Quantization settings forwarded to ``MXLinear.from_float``.
    """
    for child_name, child in model.named_children():
        is_plain_linear = isinstance(child, torch.nn.Linear) and not isinstance(child, MXLinear)
        if not is_plain_linear:
            # Not a convertible leaf — keep walking the subtree.
            replace_turbo_mxlinear_modules(child, config)
            continue
        setattr(model, child_name, MXLinear.from_float(child, config))
27+
28+
29+
class PrimusTurboMXConverter(ModelConverter):
    """Model converter that swaps ``nn.Linear`` layers for Primus-Turbo ``MXLinear``.

    Registered under the converter name ``"primus_turbo_mx"`` so it can be
    enabled through torchtitan's ``model.converters`` configuration.
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        # NOTE(review): job_config and parallel_dims are currently unused —
        # the converter is unconditionally enabled and uses default settings.
        self.enabled = True
        # TODO: derive the quantization config from job_config.
        self.config = MXQuantConfig()

    def convert(self, model: nn.Module):
        """Replace eligible linear layers of *model* in place with MXLinear."""
        if not self.enabled:
            return

        replace_turbo_mxlinear_modules(model, self.config)

        logger.info("Swapped to MXLinear layers")

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
        """
        MXFP8 doesn't require any post-optimizer hooks at the moment
        """
        return


# Backward-compatible alias: the class was first published under a misspelled
# name ("Tubro"); keep the old spelling importable for existing callers.
PrimusTubroMXConverter = PrimusTurboMXConverter

register_model_converter(PrimusTurboMXConverter, "primus_turbo_mx")

primus/backends/torchtitan/models/__init__.py

Whitespace-only changes.

primus/backends/torchtitan/models/llama3/__init__.py

Whitespace-only changes.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
import torch
8+
from torchtitan.models.llama3.model import Attention as TTAttention
9+
from torchtitan.models.llama3.model import apply_rotary_emb
10+
11+
12+
class Attention(TTAttention):
    """Llama3 attention variant that hands un-expanded KV heads to ``self.sdpa``.

    Matches the upstream torchtitan ``Attention.forward`` except that the
    k/v heads are NOT repeated to ``n_heads`` before the SDPA call —
    presumably the backend's ``self.sdpa`` handles grouped-query attention
    natively (TODO: confirm against the Primus-Turbo sdpa contract).
    """

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        batch, seq_len, _ = x.shape

        # Project, then split into per-head views. The head count is inferred
        # with -1 because tensor parallelism may already have sharded the
        # projection outputs along the head dimension.
        queries = self.wq(x).view(batch, seq_len, -1, self.head_dim)
        keys = self.wk(x).view(batch, seq_len, -1, self.head_dim)
        values = self.wv(x).view(batch, seq_len, -1, self.head_dim)

        queries, keys = apply_rotary_emb(queries, keys, freqs_cis=freqs_cis)

        # Unlike upstream, k/v heads are passed as-is (no repeat_kv expansion
        # when n_kv_heads < n_heads); the backend sdpa receives them directly.
        attn_out = self.sdpa(queries, keys, values)

        return self.wo(attn_out.view(batch, seq_len, -1))

0 commit comments

Comments
 (0)