Skip to content
1 change: 1 addition & 0 deletions cf_xarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import sgrid # noqa
from .accessor import CFAccessor # noqa
from .coding import ( # noqa
decode_compress_to_multi_index,
Expand Down
8 changes: 7 additions & 1 deletion cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from xarray.core.rolling import Coarsen, Rolling
from xarray.core.weighted import Weighted

from . import sgrid
from .criteria import cf_role_criteria, coordinate_criteria, regex
from .helpers import _guess_bounds_1d, _guess_bounds_2d, bounds_to_vertices
from .options import OPTIONS
Expand Down Expand Up @@ -300,6 +301,11 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
units = getattr(var.data, "units", None)
if units in expected:
results.update((coord,))

if key in _AXIS_NAMES and "grid_topology" in obj.cf.cf_roles:
sgrid_axes = sgrid.parse_axes(obj)
results.update((search_in | set(obj.dims)) & sgrid_axes[key])

return list(results)


Expand Down Expand Up @@ -424,7 +430,7 @@ def _get_coords(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time',
'area', 'volume'), or arbitrary measures, or standard names present in .coords
"""
return [k for k in _get_all(obj, key) if k in obj.coords]
return [k for k in _get_all(obj, key) if k in obj.coords or k in obj.dims]


def _variables(func: F) -> F:
Expand Down
20 changes: 20 additions & 0 deletions cf_xarray/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,23 @@ def _create_inexact_bounds():
"trajectory": ("trajectory", [0, 1], {"cf_role": "trajectory_id"}),
},
)


roms_sgrid = xr.Dataset()
roms_sgrid["grid"] = xr.DataArray(
0,
attrs=dict(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions="xi_psi eta_psi",
face_dimensions="xi_rho: xi_psi (padding: both) eta_rho: eta_psi (padding: both)",
edge1_dimensions="xi_u: xi_psi eta_u: eta_psi (padding: both)",
edge2_dimensions="xi_v: xi_psi (padding: both) eta_v: eta_psi",
node_coordinates="lon_psi lat_psi",
face_coordinates="lon_rho lat_rho",
edge1_coordinates="lon_u lat_u",
edge2_coordinates="lon_v lat_v",
vertical_dimensions="s_rho: s_w (padding: none)",
),
)
roms_sgrid["u"] = (("xi_u", "eta_u"), np.ones((2, 2)))
27 changes: 27 additions & 0 deletions cf_xarray/sgrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
def parse_axes(ds):
import re

(gridvar,) = ds.cf.cf_roles["grid_topology"]
grid = ds[gridvar]
pattern = re.compile("\\s?(.*?):\\s*(.*?)\\s+(?:\\(padding:(.+?)\\))?")
ndim = grid.attrs["topology_dimension"]
axes_names = ["X", "Y", "Z"][:ndim]
axes = dict(
zip(
axes_names,
({k} for k in grid.attrs["node_dimensions"].split(" ")),
)
)
for attr in ["face_dimensions", "edge1_dimensions", "edge2_dimensions", "foo"]:
if attr in grid.attrs:
matches = re.findall(pattern, grid.attrs[attr] + "\n")
assert len(matches) == ndim, matches
for ax, match in zip(axes_names, matches):
axes[ax].update(set(match[:2]))

if ndim == 2 and "vertical_dimensions" in grid.attrs:
matches = re.findall(pattern, grid.attrs["vertical_dimensions"] + "\n")
assert len(matches) == 1
axes["Z"] = set(matches[0][:2])

return axes
27 changes: 27 additions & 0 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
multiple,
pomds,
popds,
roms_sgrid,
romsds,
rotds,
vert,
Expand Down Expand Up @@ -1799,3 +1800,29 @@ def plane(coords, slopex, slopey):
coords=["latitude", "longitude"], func=plane
)
assert_identical(expected, actual)


def test_sgrid():
from ..sgrid import parse_axes

positions = ["psi", "u", "v", "rho"]
expected = {
"X": {f"xi_{pos}" for pos in positions},
"Y": {f"eta_{pos}" for pos in positions},
"Z": {"s_rho", "s_w"},
}
assert parse_axes(roms_sgrid) == expected
assert roms_sgrid.cf.axes == {"X": ["xi_u"], "Y": ["eta_u"]}

assert_identical(roms_sgrid.cf["X"], roms_sgrid.xi_u)
assert_identical(roms_sgrid.cf["Y"], roms_sgrid.eta_u)

with pytest.raises(KeyError):
roms_sgrid.u.cf["X"]
with pytest.raises(KeyError):
roms_sgrid.u.cf["Y"]

roms_ = roms_sgrid.set_coords("grid")
assert roms_.u.cf.axes == {"X": ["xi_u"], "Y": ["eta_u"]}
assert_identical(roms_.u.cf["X"], roms_sgrid.xi_u)
assert_identical(roms_.u.cf["Y"], roms_sgrid.eta_u)