
Import dask and cubed only once #457


Merged · 10 commits into xarray-contrib:main · Aug 15, 2025

Conversation

eendebakpt (Contributor)

We recently found a performance bottleneck in the xarray groupby method. Part of the bottleneck was repeated attempted imports of cubed. In this PR we refactor the code to import cubed and dask only once.
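
The pattern is roughly the following (a minimal sketch of the idea, assuming a cached helper; the names here are illustrative and not the actual helpers added in this PR): attempt the optional import once, cache the result, and reuse it instead of re-trying the import on every call.

import functools


@functools.cache
def _get_cubed():
    # The import is attempted only on the first call; subsequent calls
    # return the cached result (the module, or None if cubed is absent).
    try:
        import cubed
    except ImportError:
        return None
    return cubed


def is_cubed_array(obj) -> bool:
    # Assumes cubed exposes an Array class; purely illustrative.
    cubed = _get_cubed()
    return cubed is not None and isinstance(obj, cubed.Array)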

dcherian (Collaborator) commented Aug 13, 2025

Nice, thanks. Looks like there's one more mypy thing to fix. I'm happy to immediately release once that is fixed.

> We recently found a performance bottleneck in the xarray groupby method.

I'm curious to learn if you've found other things. I spent quite some time optimizing this stuff a couple of years ago.

eendebakpt (Contributor, Author)

> Nice, thanks. Looks like there's one more mypy thing to fix. I'm happy to immediately release once that is fixed.

On my local system mypy is giving different output, so I am doing a bit of trial and error here.

> We recently found a performance bottleneck in the xarray groupby method.

> I'm curious to learn if you've found other things. I spent quite some time optimizing this stuff a couple of years ago.

Nothing specific. Our use case has relatively small arrays; maybe the optimizations were tuned and tested on larger datasets?
Here is a minimal example I extracted:

import xarray as xr
import time
import numpy as np


def method(ds: xr.Dataset):
    # group by "pauli", then take the per-"clifford" mean of the counts;
    # the result is discarded, we only care about the timing
    pauli_groups = ds.groupby("pauli")
    for _, sub in pauli_groups:
        da = sub.groupby("clifford").mean().counts


k = 12
mult = 30
index = np.arange(k*mult)
bitindex = np.arange(4)
counts = np.random.randint(0, 10, size=(4, index.size))
c = xr.DataArray(counts, coords={'bit_index': bitindex, 'index': index})
ds = xr.Dataset({'counts': c})
ds = ds.assign_coords({'clifford': xr.DataArray(
    np.tile(np.arange(k), mult), dims='index')})
p = xr.DataArray(np.tile(np.arange(4), (k//4) * mult), dims='index')
for ii in range(16):  # the value here has a large impact on the running time
    p.data[ii] = 1000 + ii
ds = ds.assign_coords({'pauli': p})
print(ds)

# time repeated calls to method
nn = 20
for _ in range(4):
    t0 = time.time()
    for _ in range(nn):
        method(ds)
    dt = time.time() - t0
    print(f'time {1e3*dt/nn:.2f} [ms / iteration]')

On my system a single call to method takes about 50 ms. I suspect that the equivalent calculation with plain numpy arrays would be much faster (not a fair comparison, since xarray has to keep track of much more data!). Profiling shows that the call to mean() actually takes most of the time.
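
For reference, here is a rough numpy-only equivalent of the grouped mean (a sketch under the assumption that both groupings are over the 'index' dimension; the helper name is made up and it is not a drop-in replacement for the xarray call):

import numpy as np


def grouped_mean_numpy(counts, pauli, clifford):
    # counts: (n_bits, n) array; pauli and clifford: length-n label arrays.
    # For each pauli label, compute the mean of counts per clifford label.
    results = {}
    for pv in np.unique(pauli):
        mask = pauli == pv
        sub = counts[:, mask]
        labels = clifford[mask]
        uniq = np.unique(labels)
        means = np.stack([sub[:, labels == cv].mean(axis=1) for cv in uniq], axis=1)
        results[pv] = (uniq, means)
    return results

Timing grouped_mean_numpy(counts, p.data, np.tile(np.arange(k), mult)) against the xarray version on the same labels should give an idea of how much of the ~50 ms is groupby/indexing overhead rather than arithmetic.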

dcherian merged commit cbcc035 into xarray-contrib:main on Aug 15, 2025. 17 checks passed.