-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Closed
Labels
Description
Hi, I'm exploring MLX on CUDA. I've run a simple test and I see 2x (or more) slower inference with MLX.
Here is an example output:
torch-fwd compiled=False: 0.143911 seconds
torch-fwd compiled=False: 0.042127 seconds
torch-fwd compiled=False: 0.042115 seconds
torch-fwd compiled=False: 0.042109 seconds
torch-fwd compiled=False: 0.042113 seconds
torch-fwd compiled=False: 0.042116 seconds
torch-fwd compiled=False: 0.042115 seconds
torch-fwd compiled=False: 0.042126 seconds
torch-fwd compiled=False: 0.042122 seconds
torch-fwd compiled=False: 0.042262 seconds
torch-fwd compiled=True: 0.416149 seconds
torch-fwd compiled=True: 0.042157 seconds
torch-fwd compiled=True: 0.042157 seconds
torch-fwd compiled=True: 0.042162 seconds
torch-fwd compiled=True: 0.042154 seconds
torch-fwd compiled=True: 0.042320 seconds
torch-fwd compiled=True: 0.042342 seconds
torch-fwd compiled=True: 0.042221 seconds
torch-fwd compiled=True: 0.042153 seconds
torch-fwd compiled=True: 0.042163 seconds
kwargs={'dtype': mlx.core.bfloat16, 'stream': Stream(Device(gpu, 0), 2)}
mlx-fwd: 1.070995 seconds
mlx-fwd: 0.092245 seconds
mlx-fwd: 0.092189 seconds
mlx-fwd: 0.092177 seconds
mlx-fwd: 0.092305 seconds
mlx-fwd: 0.092356 seconds
mlx-fwd: 0.092400 seconds
mlx-fwd: 0.092474 seconds
mlx-fwd: 0.092326 seconds
mlx-fwd: 0.092167 seconds
And here is the minimal example:
import time
from contextlib import contextmanager


@contextmanager
def timeit(label="Elapsed"):
    """Print the wall-clock time spent inside the managed block.

    Parameters
    ----------
    label : str
        Prefix for the printed timing line (default "Elapsed").
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # try/finally ensures the timing is printed even when the timed
        # body raises, so a failing benchmark run still reports its elapsed time.
        end = time.perf_counter()
        print(f"{label}: {end - start:.6f} seconds")
# Benchmark layer widths: input/output feature size and hidden feature size.
dim_main = 256
dim_hid = 768
def bench_mlx():
    """Time a compiled two-layer linear forward pass with MLX on the GPU."""
    import mlx.core as mx

    def two_linear(inp, proj_ab_W, proj_ab_b, lin_out_W, lin_out_b):
        # First linear: inp @ proj_ab_W.T + proj_ab_b.
        hidden = mx.addmm(proj_ab_b, inp, proj_ab_W.T)
        # Second linear projects back to the main width.
        return mx.addmm(lin_out_b, hidden, lin_out_W.T)

    layer = mx.compile(two_linear)

    kwargs = dict(
        dtype=mx.bfloat16,
        stream=mx.new_stream(mx.Device(type=mx.DeviceType.gpu)),
    )
    print(f"{kwargs=}")

    out = mx.random.normal([1, 1024, 1024, dim_main], **kwargs)
    for _ in range(10):
        with timeit("mlx-fwd"):
            # Weights are (re)allocated every call, matching the PyTorch side.
            for _ in range(10):
                out = layer(
                    inp=out,
                    proj_ab_W=mx.ones([dim_hid, dim_main], **kwargs),
                    proj_ab_b=mx.ones(dim_hid, **kwargs),
                    #
                    lin_out_W=mx.ones([dim_main, dim_hid], **kwargs),
                    lin_out_b=mx.ones([dim_main], **kwargs),
                )
            mx.eval(out)  # force sync
def bench_pytorch(compiled: bool):
    """Time the same two-layer linear forward pass with PyTorch on CUDA.

    compiled: when True, wrap the layer in ``torch.compile`` before timing.
    """
    import torch
    import torch.nn.functional as F

    def two_linear(inp, proj_ab_W, proj_ab_b, lin_out_W, lin_out_b) -> torch.Tensor:
        # Two stacked linear layers: main -> hidden -> main.
        hid = F.linear(inp, proj_ab_W, bias=proj_ab_b)
        return F.linear(hid, lin_out_W, bias=lin_out_b)

    layer = torch.compile(two_linear) if compiled else two_linear

    kwargs: dict = dict(
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    )
    out = torch.randn((1, 1024, 1024, dim_main), **kwargs)
    for _ in range(10):
        with timeit(f"torch-fwd {compiled=}"):
            # Weights are (re)allocated every call, matching the MLX side.
            for _ in range(10):
                out = layer(
                    inp=out,
                    proj_ab_W=torch.ones((dim_hid, dim_main), **kwargs),
                    proj_ab_b=torch.ones(dim_hid, **kwargs),
                    #
                    lin_out_W=torch.ones((dim_main, dim_hid), **kwargs),
                    lin_out_b=torch.ones(dim_main, **kwargs) * 10000,
                )
            out[0, 0, 0, 0].item()  # force sync
# Run eager PyTorch first, then compiled PyTorch; MLX is benchmarked last.
bench_pytorch(compiled=False)
bench_pytorch(compiled=True)
bench_mlx()

Testing setup:
- I use A100
- environment can be installed with
uv pip install "mlx[cuda]==0.26.3" torch==2.8.0
(It is not simple to find versions that work on both; I've checked with mlx==0.29 and performance did not change.)
Some observations:
- on fp32 speeds of mlx and pytorch are very close
- I've looked at profiler and I think they even use the same funcs from cudnn internally for bf16 matmuls.
Reactions are currently unavailable