Skip to content

Commit affc1ad

Browse files
committed
Add sgrid axes parsing
1 parent ade1362 commit affc1ad

File tree

5 files changed

+82
-1
lines changed

5 files changed

+82
-1
lines changed

cf_xarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import sgrid # noqa
12
from .accessor import CFAccessor # noqa
23
from .coding import ( # noqa
34
decode_compress_to_multi_index,

cf_xarray/accessor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from xarray.core.rolling import Coarsen, Rolling
3131
from xarray.core.weighted import Weighted
3232

33+
from . import sgrid
3334
from .criteria import cf_role_criteria, coordinate_criteria, regex
3435
from .helpers import _guess_bounds_1d, _guess_bounds_2d, bounds_to_vertices
3536
from .options import OPTIONS
@@ -300,6 +301,11 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
300301
units = getattr(var.data, "units", None)
301302
if units in expected:
302303
results.update((coord,))
304+
305+
if key in _AXIS_NAMES and "grid_topology" in obj.cf.cf_roles:
306+
sgrid_axes = sgrid.parse_axes(obj)
307+
results.update((search_in | set(obj.dims)) & sgrid_axes[key])
308+
303309
return list(results)
304310

305311

@@ -424,7 +430,7 @@ def _get_coords(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
424430
One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time',
425431
'area', 'volume'), or arbitrary measures, or standard names present in .coords
426432
"""
427-
return [k for k in _get_all(obj, key) if k in obj.coords]
433+
return [k for k in _get_all(obj, key) if k in obj.coords or k in obj.dims]
428434

429435

430436
def _variables(func: F) -> F:

cf_xarray/datasets.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,23 @@ def _create_inexact_bounds():
604604
"trajectory": ("trajectory", [0, 1], {"cf_role": "trajectory_id"}),
605605
},
606606
)
607+
608+
609+
roms_sgrid = xr.Dataset()
610+
roms_sgrid["grid"] = xr.DataArray(
611+
0,
612+
attrs=dict(
613+
cf_role="grid_topology",
614+
topology_dimension=2,
615+
node_dimensions="xi_psi eta_psi",
616+
face_dimensions="xi_rho: xi_psi (padding: both) eta_rho: eta_psi (padding: both)",
617+
edge1_dimensions="xi_u: xi_psi eta_u: eta_psi (padding: both)",
618+
edge2_dimensions="xi_v: xi_psi (padding: both) eta_v: eta_psi",
619+
node_coordinates="lon_psi lat_psi",
620+
face_coordinates="lon_rho lat_rho",
621+
edge1_coordinates="lon_u lat_u",
622+
edge2_coordinates="lon_v lat_v",
623+
vertical_dimensions="s_rho: s_w (padding: none)",
624+
),
625+
)
626+
roms_sgrid["u"] = (("xi_u", "eta_u"), np.ones((2, 2)))

cf_xarray/sgrid.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
def parse_axes(ds):
2+
import re
3+
4+
(gridvar,) = ds.cf.cf_roles["grid_topology"]
5+
grid = ds[gridvar]
6+
pattern = re.compile("\\s?(.*?):\\s*(.*?)\\s+(?:\\(padding:(.+?)\\))?")
7+
ndim = grid.attrs["topology_dimension"]
8+
axes_names = ["X", "Y", "Z"][:ndim]
9+
axes = dict(
10+
zip(
11+
axes_names,
12+
({k} for k in grid.attrs["node_dimensions"].split(" ")),
13+
)
14+
)
15+
for attr in ["face_dimensions", "edge1_dimensions", "edge2_dimensions", "foo"]:
16+
if attr in grid.attrs:
17+
matches = re.findall(pattern, grid.attrs[attr] + "\n")
18+
assert len(matches) == ndim, matches
19+
for ax, match in zip(axes_names, matches):
20+
axes[ax].update(set(match[:2]))
21+
22+
if ndim == 2 and "vertical_dimensions" in grid.attrs:
23+
matches = re.findall(pattern, grid.attrs["vertical_dimensions"] + "\n")
24+
assert len(matches) == 1
25+
axes["Z"] = set(matches[0][:2])
26+
27+
return axes

cf_xarray/tests/test_accessor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
multiple,
3131
pomds,
3232
popds,
33+
roms_sgrid,
3334
romsds,
3435
rotds,
3536
vert,
@@ -1799,3 +1800,29 @@ def plane(coords, slopex, slopey):
17991800
coords=["latitude", "longitude"], func=plane
18001801
)
18011802
assert_identical(expected, actual)
1803+
1804+
1805+
def test_sgrid():
1806+
from ..sgrid import parse_axes
1807+
1808+
positions = ["psi", "u", "v", "rho"]
1809+
expected = {
1810+
"X": {f"xi_{pos}" for pos in positions},
1811+
"Y": {f"eta_{pos}" for pos in positions},
1812+
"Z": {"s_rho", "s_w"},
1813+
}
1814+
assert parse_axes(roms_sgrid) == expected
1815+
assert roms_sgrid.cf.axes == {"X": ["xi_u"], "Y": ["eta_u"]}
1816+
1817+
assert_identical(roms_sgrid.cf["X"], roms_sgrid.xi_u)
1818+
assert_identical(roms_sgrid.cf["Y"], roms_sgrid.eta_u)
1819+
1820+
with pytest.raises(KeyError):
1821+
roms_sgrid.u.cf["X"]
1822+
with pytest.raises(KeyError):
1823+
roms_sgrid.u.cf["Y"]
1824+
1825+
roms_ = roms_sgrid.set_coords("grid")
1826+
assert roms_.u.cf.axes == {"X": ["xi_u"], "Y": ["eta_u"]}
1827+
assert_identical(roms_.u.cf["X"], roms_sgrid.xi_u)
1828+
assert_identical(roms_.u.cf["Y"], roms_sgrid.eta_u)

0 commit comments

Comments
 (0)