Skip to content

Performance discrepancy vs torch on cuda #2566

@arogozhnikov

Description

@arogozhnikov

Hi, I'm exploring mlx on cuda. I've run a simple test and I see 2x (or more) slower inference with mlx.

Here is an example output:

torch-fwd compiled=False: 0.143911 seconds
torch-fwd compiled=False: 0.042127 seconds
torch-fwd compiled=False: 0.042115 seconds
torch-fwd compiled=False: 0.042109 seconds
torch-fwd compiled=False: 0.042113 seconds
torch-fwd compiled=False: 0.042116 seconds
torch-fwd compiled=False: 0.042115 seconds
torch-fwd compiled=False: 0.042126 seconds
torch-fwd compiled=False: 0.042122 seconds
torch-fwd compiled=False: 0.042262 seconds

torch-fwd compiled=True: 0.416149 seconds
torch-fwd compiled=True: 0.042157 seconds
torch-fwd compiled=True: 0.042157 seconds
torch-fwd compiled=True: 0.042162 seconds
torch-fwd compiled=True: 0.042154 seconds
torch-fwd compiled=True: 0.042320 seconds
torch-fwd compiled=True: 0.042342 seconds
torch-fwd compiled=True: 0.042221 seconds
torch-fwd compiled=True: 0.042153 seconds
torch-fwd compiled=True: 0.042163 seconds

kwargs={'dtype': mlx.core.bfloat16, 'stream': Stream(Device(gpu, 0), 2)}
mlx-fwd: 1.070995 seconds
mlx-fwd: 0.092245 seconds
mlx-fwd: 0.092189 seconds
mlx-fwd: 0.092177 seconds
mlx-fwd: 0.092305 seconds
mlx-fwd: 0.092356 seconds
mlx-fwd: 0.092400 seconds
mlx-fwd: 0.092474 seconds
mlx-fwd: 0.092326 seconds
mlx-fwd: 0.092167 seconds

And here is the minimal example:

import time
from contextlib import contextmanager


@contextmanager
def timeit(label="Elapsed"):
    """Print *label* and the wall-clock duration of the managed block.

    Usage: ``with timeit("step"): ...`` prints ``step: 0.123456 seconds``.
    """
    started = time.perf_counter()
    yield
    elapsed = time.perf_counter() - started
    print(f"{label}: {elapsed:.6f} seconds")


# Layer widths shared by both benchmarks: `dim_main` is the input/output
# feature size and `dim_hid` the hidden size of the two-linear-layer network.
dim_main = 256
dim_hid = 768


def bench_mlx():
    """Time 10 batches of 10 forward passes of a two-linear-layer net in MLX.

    Mirrors ``bench_pytorch``: bf16 tensors on the GPU stream, with
    ``mx.eval`` forcing synchronization once per timed batch.
    """
    import mlx.core as mx

    def two_linear(
        inp,
        proj_ab_W,
        proj_ab_b,
        lin_out_W,
        lin_out_b,
    ):
        # Same math as torch's F.linear applied twice:
        # out = (inp @ proj_ab_W.T + proj_ab_b) @ lin_out_W.T + lin_out_b
        hidden = mx.addmm(proj_ab_b, inp, proj_ab_W.T)
        return mx.addmm(lin_out_b, hidden, lin_out_W.T)

    layer = mx.compile(two_linear)

    # NOTE: keep the variable named `kwargs` — the f-string below prints it
    # with `{kwargs=}` and the label is part of the reported output.
    kwargs = dict(
        dtype=mx.bfloat16,
        stream=mx.new_stream(mx.Device(type=mx.DeviceType.gpu)),
    )
    print(f"{kwargs=}")

    res = mx.random.normal([1, 1024, 1024, dim_main], **kwargs)

    for _ in range(10):
        with timeit("mlx-fwd"):
            for _ in range(10):
                res = layer(
                    res,
                    mx.ones([dim_hid, dim_main], **kwargs),
                    mx.ones(dim_hid, **kwargs),
                    mx.ones([dim_main, dim_hid], **kwargs),
                    mx.ones([dim_main], **kwargs),
                )
            mx.eval(res)  # force sync


def bench_pytorch(compiled: bool):
    """Time 10 batches of 10 forward passes of a two-linear-layer net in torch.

    Args:
        compiled: if True, wrap the layer in ``torch.compile`` before timing
            (the first timed batch then includes compilation overhead).
    """
    import torch
    import torch.nn.functional as F

    def torch_simple_layer(
        inp,
        proj_ab_W,
        proj_ab_b,
        lin_out_W,
        lin_out_b,
    ) -> torch.Tensor:
        hid = F.linear(inp, proj_ab_W, bias=proj_ab_b)
        return F.linear(hid, lin_out_W, bias=lin_out_b)

    if compiled:
        torch_simple_layer = torch.compile(torch_simple_layer)

    kwargs: dict = dict(
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    )
    res = torch.randn((1, 1024, 1024, dim_main), **kwargs)
    for _ in range(10):
        with timeit(f"torch-fwd {compiled=}"):
            for _ in range(10):
                res = torch_simple_layer(
                    inp=res,
                    proj_ab_W=torch.ones((dim_hid, dim_main), **kwargs),
                    proj_ab_b=torch.ones(dim_hid, **kwargs),
                    #
                    lin_out_W=torch.ones((dim_main, dim_hid), **kwargs),
                    # FIX: dropped the stray `* 10000` scale that the mlx
                    # benchmark does not apply, so both frameworks time the
                    # exact same computation. (Timing is unaffected either
                    # way, but the compared expressions now match.)
                    lin_out_b=torch.ones(dim_main, **kwargs),
                )
            # Reading a scalar forces a device sync so the whole batch is
            # included in the measured interval.
            res[0, 0, 0, 0].item()


# Run the three benchmarks back to back: eager torch, compiled torch, then mlx.
bench_pytorch(compiled=False)
bench_pytorch(compiled=True)
bench_mlx()

Testing setup:

  • I use A100
  • the environment can be installed with `uv pip install mlx[cuda]==0.26.3 torch==2.8.0` (it is not easy to find versions that work for both frameworks; I've also checked with mlx==0.29 and performance did not change).

Some observations:

  • on fp32 speeds of mlx and pytorch are very close
  • I've looked at the profiler and I think they even use the same functions from cuDNN internally for bf16 matmuls.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions