Description
What happened?
In my work, I produced a sequence of covariance matrices for a 2-D quantity. I wanted to extract the diagonal of each covariance matrix and then reshape that diagonal back to 2-D so I could plot it.
I could unstack the 2-D dimensions in the sequence of covariance matrices without issue, and I figured out how to extract the diagonals of the covariance matrices. However, unstacking the diagonals using the same procedure raised a ValueError.
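The diagonal extraction in the MCVE relies on xarray's vectorized `isel`. The same 1-D indexer, which carries its own dimension `diag`, is passed for both matrix dimensions, so elements are selected pointwise rather than as an outer product. A rough NumPy analogue follows. The names `cov`, `idx`, `diag_pointwise` and `diag_numpy` are hypothetical, and the data is an arbitrary stand-in for a stack of two 12 x 12 matrices.

```python
import numpy as np

# Hypothetical stand-in: a stack of two 12 x 12 "covariance" matrices
# with distinct entries, so the comparison below is meaningful.
cov = np.arange(2 * 12 * 12, dtype=float).reshape(2, 12, 12)

# Pointwise indexing: the same index array on both matrix axes picks
# cov[k, i, i] for every k and i (analogous to isel with a shared indexer).
idx = np.arange(cov.shape[1])
diag_pointwise = cov[:, idx, idx]

# The same values via NumPy's explicit diagonal along the last two axes.
diag_numpy = np.diagonal(cov, axis1=1, axis2=2)
```

Both arrays have shape `(2, 12)`, one diagonal per matrix in the sequence.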
What did you expect to happen?
I expected the sequence of one-dimensional diagonals to unstack into a sequence of two-dimensional fields that I could plot with pcolormesh.
I can get this result by first unstacking both matrix dimensions (producing a 5-D DataArray) and then extracting the diagonals, but I see no reason the other order should not also work.
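In plain NumPy terms, the two orders describe the same values. The sketch below assumes the row-major layout implied by the MCVE's coordinates (`dim0_0 = repeat(arange(3), 4)`, `dim0_1 = tile(arange(4), 3)`) and random stand-in data. It checks that "diagonal, then unstack" matches "unstack to 5-D, then take the diagonal". It does not use xarray, and all variable names are illustrative.

```python
import numpy as np

# Stand-in data: 2 covariance matrices over a 3 x 4 grid flattened to 12 points.
n0, n1 = 3, 4
rng = np.random.default_rng(0)
a = rng.standard_normal((2, n0 * n1, n0 * n1))
cov = a @ a.transpose(0, 2, 1)  # symmetric positive semi-definite

# Order 1: take each diagonal (length 12), then unstack it to a 3 x 4 field.
diag_then_unstack = np.diagonal(cov, axis1=1, axis2=2).reshape(2, n0, n1)

# Order 2: unstack both matrix axes first (5-D array), then take the
# diagonal pointwise over both grid axes at once.
cov5d = cov.reshape(2, n0, n1, n0, n1)
i0 = np.arange(n0)[:, None]  # broadcasts with i1 to shape (n0, n1)
i1 = np.arange(n1)[None, :]
unstack_then_diag = cov5d[:, i0, i1, i0, i1]
```

Both results have shape `(2, 3, 4)`, one 2-D field per matrix, and contain identical values.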
Minimal Complete Verifiable Example
import numpy as np
import xarray

# Working: take the diagonal, then unstack (no dim1, no dim0/adj_dim0 index coordinates)
test = xarray.DataArray(
    np.eye(12),
    dims=("dim0", "adj_dim0"),
    coords={
        "dim0_0": (("dim0",), np.repeat(np.arange(3), 4)),
        "dim0_1": (("dim0",), np.tile(np.arange(4), 3)),
    },
)
unstacked = test.set_index(dim0=["dim0_0", "dim0_1"]).unstack("dim0")
diag_index = xarray.DataArray(np.arange(test.shape[0]), dims=("diag",))
unstacked_diag = (
    test.isel(dim0=diag_index, adj_dim0=diag_index)
    .set_index(diag=["dim0_0", "dim0_1"])
    .unstack("diag")
)

# Not working:
test = xarray.DataArray(
    dims=("dim1", "dim0", "adj_dim0"),
    data=np.tile(np.eye(12), (2, 1, 1)),
    coords={
        "dim0": np.arange(12),
        "dim0_0": (("dim0",), np.repeat(np.arange(3), 4)),
        "dim0_1": (("dim0",), np.tile(np.arange(4), 3)),
        "adj_dim0": np.arange(12),
        "adj_dim0_0": (("adj_dim0",), np.repeat(np.arange(3), 4)),
        "adj_dim0_1": (("adj_dim0",), np.tile(np.arange(4), 3)),
    },
)
# Unstacking first (5-D result), then taking the diagonals, works:
unstacked = test.set_index(
    dim0=["dim0_0", "dim0_1"], adj_dim0=["adj_dim0_0", "adj_dim0_1"]
).unstack(["dim0", "adj_dim0"])
diag_index0 = xarray.DataArray(np.arange(unstacked.shape[1]), dims=("diag_0",))
diag_index1 = xarray.DataArray(np.arange(unstacked.shape[2]), dims=("diag_1",))
unstacked_diag = unstacked.isel(
    dim0_0=diag_index0,
    dim0_1=diag_index1,
    adj_dim0_0=diag_index0,
    adj_dim0_1=diag_index1,
)
# Taking the diagonal first, then unstacking, raises the ValueError below:
diag_index = xarray.DataArray(np.arange(test.shape[1]), dims=("diag",))
test.isel(dim0=diag_index, adj_dim0=diag_index).set_index(
    diag=["dim0_0", "dim0_1"]
).unstack("diag")
Relevant log output
Traceback (most recent call last):
  File "./test.py", line 50, in <module>
    ).unstack("diag")
  File "~/.conda/envs/plotting/lib/python3.10/site-packages/xarray/core/dataarray.py", line 2201, in unstack
    ds = self._to_temp_dataset().unstack(dim, fill_value, sparse)
  File "~/.conda/envs/plotting/lib/python3.10/site-packages/xarray/core/dataset.py", line 4214, in unstack
    result = result._unstack_once(dim, fill_value, sparse)
  File "~/.conda/envs/plotting/lib/python3.10/site-packages/xarray/core/dataset.py", line 4070, in _unstack_once
    variables[name] = var._unstack_once(
  File "~/.conda/envs/plotting/lib/python3.10/site-packages/xarray/core/variable.py", line 1690, in _unstack_once
    return self._replace(dims=new_dims, data=data)
  File "~/.conda/envs/plotting/lib/python3.10/site-packages/xarray/core/variable.py", line 978, in _replace
    return type(self)(dims, data, attrs, encoding, fastpath=True)
  File "~/.conda/envs/plotting/lib/python3.10/site-packages/xarray/core/variable.py", line 2668, in __init__
    raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")
ValueError: IndexVariable objects must be 1-dimensional
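One visible difference between the two cases is that the failing one has index coordinates `dim0` and `adj_dim0`. After the vectorized `isel`, those coordinates lie along `diag`. The traceback suggests that one of them is still an `IndexVariable` when `unstack` tries to reshape it to 2-D. Under that assumption (not confirmed here), a possible workaround is to drop them before building the MultiIndex. This is a sketch, not a confirmed fix. `diag` and `result` are illustrative names, and `drop_vars` is used for convenience.

```python
import numpy as np
import xarray

# Same failing input as in the MCVE.
test = xarray.DataArray(
    dims=("dim1", "dim0", "adj_dim0"),
    data=np.tile(np.eye(12), (2, 1, 1)),
    coords={
        "dim0": np.arange(12),
        "dim0_0": (("dim0",), np.repeat(np.arange(3), 4)),
        "dim0_1": (("dim0",), np.tile(np.arange(4), 3)),
        "adj_dim0": np.arange(12),
        "adj_dim0_0": (("adj_dim0",), np.repeat(np.arange(3), 4)),
        "adj_dim0_1": (("adj_dim0",), np.tile(np.arange(4), 3)),
    },
)

# Pointwise diagonal: result has dims ("dim1", "diag").
diag_index = xarray.DataArray(np.arange(test.shape[1]), dims=("diag",))
diag = test.isel(dim0=diag_index, adj_dim0=diag_index)

# Assumed workaround: drop the former dimension coordinates, which now
# lie along "diag", so unstack never has to reshape them to 2-D.
result = (
    diag.drop_vars(["dim0", "adj_dim0"])
    .set_index(diag=["dim0_0", "dim0_1"])
    .unstack("diag")
)
```

If the assumption holds, `result` has dims `("dim1", "dim0_0", "dim0_1")` and shape `(2, 3, 4)`, which is the layout pcolormesh needs for each `dim1` entry.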
Environment
INSTALLED VERSIONS
commit: None
python: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-1160.59.1.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.1
libnetcdf: 4.8.1
xarray: 2022.3.0
pandas: 1.4.2
numpy: 1.22.3
scipy: 1.8.0
netCDF4: 1.5.8
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.6.0
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: None
distributed: None
matplotlib: 3.5.1
cartopy: 0.20.2
seaborn: None
numbagg: None
fsspec: None
cupy: None
pint: None
sparse: None
setuptools: 61.3.1
pip: 22.0.4
conda: None
pytest: None
IPython: None
sphinx: None