Skip to content

Commit dc78afa

Browse files
committed
uses rather than
1 parent e4fa084 commit dc78afa

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

cf_xarray/accessor.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,15 +1586,15 @@ def get_bounds_dim_name(self, key: str) -> str:
15861586
assert self._obj.sizes[bounds_dim] in [2, 4]
15871587
return bounds_dim
15881588

1589-
def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
1589+
def add_bounds(self, keys: Union[str, Iterable[str]]):
15901590
"""
15911591
Returns a new object with bounds variables. The bounds values are guessed assuming
15921592
equal spacing on either side of a coordinate label.
15931593
15941594
Parameters
15951595
----------
1596-
dims : Hashable or Iterable[Hashable]
1597-
Either a single dimension name or a list of dimension names.
1596+
keys : str or Iterable[str]
1597+
Either a single key or a list of keys corresponding to dimensions.
15981598
15991599
Returns
16001600
-------
@@ -1609,12 +1609,16 @@ def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
16091609
The bounds variables are automatically named f"{dim}_bounds" where ``dim``
16101610
is a dimension name.
16111611
"""
1612-
if isinstance(dims, Hashable):
1613-
dimensions = (dims,)
1614-
else:
1615-
dimensions = dims
1612+
if isinstance(keys, str):
1613+
keys = [keys]
1614+
1615+
dimensions = set()
1616+
for key in keys:
1617+
dimensions |= set(
1618+
apply_mapper(_get_dims, self._obj, key, error=False, default=[key])
1619+
)
16161620

1617-
bad_dims: Set[Hashable] = set(dimensions) - set(self._obj.dims)
1621+
bad_dims: Set[str] = dimensions - set(self._obj.dims)
16181622
if bad_dims:
16191623
raise ValueError(
16201624
f"{bad_dims!r} are not dimensions in the underlying object."

cf_xarray/tests/test_accessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,10 @@ def test_add_bounds(obj, dims):
634634
assert added[dim].attrs["bounds"] == name
635635
assert_allclose(added[name].reset_coords(drop=True), expected[dim])
636636

637+
# Test multiple dimensions
638+
assert not {"x1_bounds", "x2_bounds"} <= set(multiple.variables)
639+
assert {"x1_bounds", "x2_bounds"} <= set(multiple.cf.add_bounds("X").variables)
640+
637641

638642
def test_bounds():
639643
ds = airds.copy(deep=True).cf.add_bounds("lat")

0 commit comments

Comments
 (0)