
Add CUDA kernel for MXFP8 dim1 casting #2513


Merged · merged 1 commit into main from danielvegamyhre/stack/3 on Jul 15, 2025

Conversation

danielvegamyhre (Contributor) commented Jul 9, 2025

Add CUDA kernel for MXFP8 dim1 casting

Co-authored-by: Less Wright [email protected]

Summary

  • Add a CUDA kernel to do mxfp8 dim1 casting, which benchmarks show is ~1.4x faster than the existing Triton kernel; benchmarking Llama3 8b training with torchtitan shows this translating into a 1.5% - 2.5% e2e training speedup. This prototype was developed/explored in a personal repo, and this PR begins the migration into torchao.
  • We used this TE kernel as a starting point (big thanks to @lessw2020 for setting up the C++ extension with the parts we needed to iterate on: working mxfp8 quantization cpp extension danielvegamyhre/private-torchao#1).
  • Subsequent PRs will migrate the integration, torch.compile support, etc., to make reviewing easier. This PR only adds (1) the CUDA kernel, (2) the C++ extension that makes it usable in Python, (3) a numerical test, and (4) kernel benchmarking. A rough usage sketch follows below.
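To make this concrete, here is a rough, illustrative sketch of how the extension might be called from Python once built. The function name `quantize_dim1` and its signature are placeholders (not the actual binding added in this PR); only the module name `mxfp8_cuda` comes from the benchmark/test code below.

```python
# Illustrative only: `quantize_dim1` and its signature are placeholders,
# not the real binding exposed by the extension in this PR.
import torch
import mxfp8_cuda  # built from torchao/prototype/mxfp8_cuda (see "Test build" below)

x = torch.randn(16384, 16384, dtype=torch.bfloat16, device="cuda")

# Cast along dim1 (columns) to mxfp8: returns the fp8 data plus the e8m0
# block scales (one scale per 32-element block along dim1).
y_fp8, scales = mxfp8_cuda.quantize_dim1(x, block_size=32, scaling_mode="floor")
```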

Numerics changes

I made the following changes to get matching numerics for the columnwise/dim1 scaling path:

Performance improvements

I also made the following changes to improve perf for the columnwise/dim1 scaling path:

torchao integration, torch.compile support, and other changes

Test build

  • `cd ~/ao/torchao/prototype/mxfp8_cuda/`
  • `python setup.py install`
  • Note: this will be included in the main torchao setup.py after the subsequent integration PRs land (see above). A quick smoke test sketch follows below.
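As a quick smoke test after building, something like the following can be used (this assumes the extension is importable as `mxfp8_cuda`, the module name used by the benchmark and test code in this PR):

```python
# Smoke test for the locally built extension. Assumes the module name
# `mxfp8_cuda`, as used by the benchmark/test imports in this PR.
import torch

try:
    import mxfp8_cuda  # noqa: F401
    print("mxfp8_cuda extension loaded")
except ImportError as e:
    print(f"extension not built or not importable: {e}")

# The kernel targets Blackwell (SM100+), so also check the device capability.
major, minor = torch.cuda.get_device_capability()
assert (major, minor) >= (10, 0), "MXFP8 CUDA kernel requires SM100+ (Blackwell)"
```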

Test numerics

  • `pytest test/prototype/mx_formats/test_kernels.py -k cuda` (a stripped-down sketch of the comparison it performs follows below)
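For context, the check the numerics test performs is roughly the following. This is a stripped-down sketch: `cast_dim1_cuda` and `cast_dim1_reference` are placeholder callables standing in for the real kernel binding and the existing reference path.

```python
import torch

def assert_dim1_cast_matches(cast_dim1_cuda, cast_dim1_reference, M=256, K=256):
    # Placeholder callables: each returns (quantized_fp8_tensor, e8m0_scale_tensor)
    # for a dim1 (columnwise) cast of the input.
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

    y_d1, s_d1 = cast_dim1_cuda(x)
    y_d1_ref, s_d1_ref = cast_dim1_reference(x)

    # Check scales: exact, bitwise match (no tolerance).
    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

    # Check quantized values: exact, bitwise match (no tolerance).
    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)

    # Check the quantized tensor's memory layout matches the reference
    # (the scale tensor is deliberately laid out differently; see review below).
    assert y_d1.stride() == y_d1_ref.stride()
```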

Kernel microbenchmarks

  • The CUDA mx dim1 floor scaling kernel is ~1.4x faster than the Triton kernel and achieves correspondingly higher peak memory bandwidth utilization.
  • RCEIL scaling is even faster (since it uses hardware-native scaling instead of software scaling), but it is not a 1-to-1 comparison since the Triton kernel does floor scaling; see the sketch below for how the two modes differ.
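Roughly, the two modes differ in how the shared power-of-two (e8m0) scale is chosen from each block's absolute max. The sketch below is a simplified illustration only, assuming fp8e4m3 elements and an OCP MX-style floor formulation; real kernels operate on exponent bits and handle zeros, clamping, and the e8m0 range.

```python
import math

E4M3_MAX = 448.0   # max magnitude of float8_e4m3fn
E4M3_EMAX = 8      # exponent of the largest representable power of two (2**8)

def e8m0_scale_exponent(amax: float, mode: str) -> int:
    """Simplified shared-scale selection for one block (illustration only)."""
    if mode == "floor":
        # Floor scaling: round the shared exponent down; large elements in the
        # block may then need to be clamped to the fp8 max.
        return math.floor(math.log2(amax)) - E4M3_EMAX
    if mode == "rceil":
        # Rceil scaling: round up so amax / 2**exp always fits in fp8; on
        # Blackwell this can use hardware-native scale computation.
        return math.ceil(math.log2(amax / E4M3_MAX))
    raise ValueError(f"unknown mode: {mode}")

# Example where the two modes disagree:
# amax = 500 -> floor picks 2**0 (500 then clamps to 448), rceil picks 2**1.
print(e8m0_scale_exponent(500.0, "floor"), e8m0_scale_exponent(500.0, "rceil"))
```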

CUDA mx dim1 with floor scaling:

```
(ao) [[email protected] ~/ao/benchmarks/mx_formats (mxfp8-cuda)]$ CUDA_VISIBLE_DEVICES=3 python cast_bench.py --mode dim1_mx_cuda_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_cuda_floor
time_us 155.45600652694702
mem_bw_gbps 5234.245972084408
```

CUDA mx dim1 with rceil scaling:

```
(ao) [[email protected] ~/ao/benchmarks/mx_formats (mxfp8-cuda)]$ CUDA_VISIBLE_DEVICES=3 python cast_bench.py --mode dim1_mx_cuda_rceil
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_cuda_rceil
time_us 145.31199634075165
mem_bw_gbps 5599.640748805854
```

Triton mx dim1 (uses floor scaling):

```
(ao) [[email protected] ~/ao/benchmarks/mx_formats (mxfp8-cuda)]$ CUDA_VISIBLE_DEVICES=3 python cast_bench.py --mode dim1_mx_triton
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_triton
time_us 217.28000044822693
mem_bw_gbps 3744.914277988901
```
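A quick arithmetic check of the speedups implied by the timings above:

```python
# Times in microseconds from the benchmark runs above.
triton_us, cuda_floor_us, cuda_rceil_us = 217.28, 155.46, 145.31

print(f"CUDA floor vs Triton: {triton_us / cuda_floor_us:.2f}x")  # ~1.40x
print(f"CUDA rceil vs Triton: {triton_us / cuda_rceil_us:.2f}x")  # ~1.50x

# Bandwidth scales inversely with time for a fixed byte count:
# 3744.9 GB/s * 1.40 ~= 5234 GB/s, matching the reported mem_bw_gbps.
```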

E2E training benchmarks

Llama3.1 8b on 4 B200s with FSDP=4, torch.compile, per op SAC

| Recipe | Speedup over bf16 baseline (FSDP=4) |
|---|---|
| fp8 tensorwise | 1.185x |
| mxfp8 triton | 1.185x |
| mxfp8 cuda | 1.202x |

Note: I got a larger improvement of ~2.6% with FSDP=8 last night, but am going to rerun that benchmark a couple times later to confirm.

BF16: https://www.internalfb.com/phabricator/paste/view/P1864747221

```
(ao) [[email protected] ~/torchtitan (main)]$ python parse.py --log-file=bf16-log.txt

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 11401.0
Max Memory Usage: 50.22 GiB
```

FP8 tensorwise: https://www.internalfb.com/phabricator/paste/view/P1864745686

```
(ao) [[email protected] ~/torchtitan (main)]$ python parse.py --log-file=fp8-tensorwise-log.txt

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 13506.0
Max Memory Usage: 50.14 GiB
```

MXFP8 (triton): https://www.internalfb.com/phabricator/paste/view/P1864746140

```
(ao) [[email protected] ~/torchtitan (main)]$ python parse.py --log-file=triton-log.txt

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 13506.0
Max Memory Usage: 50.49 GiB
```

MXFP8 (cuda): https://www.internalfb.com/phabricator/paste/view/P1864746496

```
(ao) [[email protected] ~/torchtitan (main)]$ python parse.py --log-file=cuda-log.txt

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 13708.0
Max Memory Usage: 50.26 GiB
```
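For reference, the speedup-over-bf16 numbers in the table above follow directly from the median tokens/second reported by parse.py:

```python
# Median tokens/second from the parse.py outputs above.
bf16, fp8_tensorwise, mxfp8_triton, mxfp8_cuda = 11401.0, 13506.0, 13506.0, 13708.0

print(f"fp8 tensorwise: {fp8_tensorwise / bf16:.3f}x")  # 1.185x
print(f"mxfp8 triton:   {mxfp8_triton / bf16:.3f}x")    # 1.185x
print(f"mxfp8 cuda:     {mxfp8_cuda / bf16:.3f}x")      # 1.202x
print(f"mxfp8 cuda vs triton: +{(mxfp8_cuda / mxfp8_triton - 1) * 100:.1f}%")  # ~+1.5%
```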

Additional e2e training benchmarks on 2nd machine to confirm results

Additional e2e training benchmarks were run on @lessw2020's machine, confirming the speedup is reproducible.

The peak memory reduction of 4-5% when no AC (activation checkpointing) is used is also interesting to note.

Llama 3.1 8b, FSDP=4, torch.compile, no AC

| Kernel | TPS | Memory |
|---|---|---|
| cuda | 12910 | 108.47 |
| triton | 12622 | 113.02 |
| cuda % change vs triton | +2.28% | -4.03% |

Llama 3.1 8b, FSDP=8, torch.compile, no AC

| Kernel | TPS | Memory |
|---|---|---|
| cuda | 12992 | 129.45 |
| triton | 12761 | 136.61 |
| cuda % change vs triton | +1.81% | -5.24% |

danielvegamyhre added a commit that referenced this pull request Jul 9, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 828b1b0 to 5d1f777 on July 9, 2025 23:07
pytorch-bot bot commented Jul 9, 2025

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2513
Note: links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Unrelated Failure

As of commit 0e157a1 with merge base 2e2ce0b: four new job failures, plus one broken-trunk failure that was already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid the broken-trunk failure.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the `CLA Signed` label Jul 9, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft July 9, 2025 23:08
danielvegamyhre added a commit that referenced this pull request Jul 9, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 5d1f777 to 5df753a on July 9, 2025 23:47
danielvegamyhre added a commit that referenced this pull request Jul 9, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 5df753a to 6253dac on July 9, 2025 23:51
@danielvegamyhre danielvegamyhre added the `topic: not user facing` label Jul 9, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review July 9, 2025 23:55
danielvegamyhre (Contributor, Author) commented:

cc @vkuzo @drisspg for review

"dim0_mx",
"dim1_mx",
"dim1_mx_triton",
"dim1_mx_cuda_floor",
Reviewer (Contributor): nit: remove "floor" to match all the others, or add "floor" to all the others

Author: Added "floor" to the others; I like the more explicit naming.

@@ -194,6 +208,42 @@ def run(
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx_cuda_floor":
bench_fn = partial(
Reviewer (Contributor): nit: either refactor the other branches to use partial, or remove partial here, to keep the file's code style consistent


from torchao.prototype.mx_formats.kernels import (
triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

try:
import mxfp8_cuda
Reviewer (Contributor): what would it take to just have this available in torchao instead of requiring a separate import?

Author: Migrated to torchao/csrc/cuda/mx_kernels and updated setup.py to make it available as torchao.prototype.mxfp8_cuda. cc @drisspg as well

x_hp: torch.Tensor, block_size
x_hp: torch.Tensor,
block_size,
scaling_mode=None,
Reviewer (Contributor): default to floor?

from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx

scale_mode = (
Reviewer (Contributor): remove after defaulting to floor above?

"""
setup.py - Build configuration for MXFP8 PyTorch extension

This extension requires NVIDIA BlACKWELL architecture (SM100+) or newer.
Reviewer (Contributor): BLACKWELL instead of BlACKWELL

* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights
*reserved.
*
* See LICENSE for license information.
Reviewer (Contributor): should there be a LICENSE file?

Author: I updated the license headers per conversation with Supriya, see the update. There's no standalone LICENSE file, though.

.contiguous()
)

y_d1_ref, s_d1_ref = triton_to_mxfp8_dim1_reference(
Reviewer (Contributor): I think the reference should always be the native PyTorch code, and custom kernels should each match against the reference. This way, if there is a mismatch, it's very clear which kernels have a mismatch vs the native PyTorch code.

torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

# check quantized values
torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
Reviewer (Contributor): I think we should also test the memory layout of all the tensors vs the reference.

Author (danielvegamyhre, Jul 14, 2025): Updated to verify the memory layout of the quantized tensor is identical. However, I deliberately wrote the scale tensor in a different memory layout, to avoid uncoalesced global accesses. This is also why I modified the triton_scale_swizzle kernel to accept column-major inputs in danielvegamyhre/private-torchao#19.
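For illustration (shapes and dtypes here are arbitrary, not the ones used by the kernel), strides make the layout difference easy to see: a row-major (R, C) tensor has strides (C, 1), while a column-major one has strides (1, R):

```python
import torch

# Arbitrary illustrative shapes: a row-major quantized output and a
# column-major scale tensor (so the kernel's global writes stay coalesced).
y = torch.empty(4096, 8192, dtype=torch.float8_e4m3fn, device="cuda")
s = torch.empty(8192, 128, dtype=torch.uint8, device="cuda").t()  # shape (128, 8192)

print(y.stride())         # (8192, 1) -> row-major
print(s.stride())         # (1, 128)  -> column-major
print(s.is_contiguous())  # False: contiguous only in the transposed view
```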

danielvegamyhre added a commit that referenced this pull request Jul 10, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 6253dac to 4823b37 on July 10, 2025 20:38
danielvegamyhre added a commit that referenced this pull request Jul 11, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 4823b37 to a55e5ae on July 11, 2025 00:13
danielvegamyhre added a commit that referenced this pull request Jul 11, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from a55e5ae to 0065bcd on July 11, 2025 00:25
danielvegamyhre added a commit that referenced this pull request Jul 12, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 0065bcd to 97c77a7 on July 12, 2025 00:37
danielvegamyhre added a commit that referenced this pull request Jul 14, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 97c77a7 to 0a26808 on July 14, 2025 23:29
danielvegamyhre added a commit that referenced this pull request Jul 15, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 0a26808 to 21d5528 on July 15, 2025 04:19
):
print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"torch version: {torch.__version__}")
print(f"triton version: {triton.__version__}")
print(f"mode: {mode}")
assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
assert mode in (
"dim0_floor",
Reviewer (Contributor): the first three entries here do not do any rounding, so the old names were correct

Author: Ah, good catch, thanks. Updated.

danielvegamyhre added a commit that referenced this pull request Jul 15, 2025
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from 21d5528 to f62a2f0 on July 15, 2025 14:13
Co-authored-by: Less Wright <[email protected]>

stack-info: PR: #2513, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from f62a2f0 to 0e157a1 on July 15, 2025 14:28
danielvegamyhre (Contributor, Author) commented:

Confirmed the rocm wheel build CI failures are unrelated to this change. I created a test PR branching off of main that just adds a print statement to setup.py: #2545. As you can see there, the rocm wheel build fails in CI.

I created an issue to track this: #2546

@danielvegamyhre danielvegamyhre merged commit c011bad into main Jul 15, 2025
29 of 34 checks passed
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Jul 25, 2025
Stacked PRs:
 * -> #1427

---

make mxfp8 dim1 cast kernel configurable

## Summary
- We recently added a new CUDA kernel for the mxfp8 dim1 cast that is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564
- This PR updates the mxfp8 user-facing API, replacing the boolean flag `--mx.use_triton_for_dim1_cast=[true|false]` with `--mx.mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]`

## Test plan
- Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"`
- Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"`
- Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"`

## Limitations
- TP is not supported yet, as both the Triton kernel and the CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we have been discussing with Brian and will continue following up on.
idoh pushed a commit to idoh/torchtitan that referenced this pull request Jul 28, 2025 (same commit message as above)
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025 (same commit message as above)
joellidin pushed a commit to tplr-ai/torchtitan that referenced this pull request Aug 8, 2025 (same commit message as above)
joellidin pushed a commit to tplr-ai/torchtitan that referenced this pull request Aug 8, 2025 (same commit message as above)
Labels: CLA Signed · mx · topic: not user facing