diff --git a/cf_xarray/geometry.py b/cf_xarray/geometry.py index e6854448..fb520796 100644 --- a/cf_xarray/geometry.py +++ b/cf_xarray/geometry.py @@ -298,8 +298,10 @@ def encode_geometries(ds: xr.Dataset): geom_var_names = [ name for name, var in ds._variables.items() - if var.dtype == "O" and isinstance(var.data.flat[0], SHAPELY_TYPES) + if var.dtype == "geometry" + or (var.dtype == "O" and isinstance(var.data.flat[0], SHAPELY_TYPES)) ] + if not geom_var_names: return ds diff --git a/cf_xarray/tests/test_geometry.py b/cf_xarray/tests/test_geometry.py index db845a6c..deb784c2 100644 --- a/cf_xarray/tests/test_geometry.py +++ b/cf_xarray/tests/test_geometry.py @@ -507,5 +507,7 @@ def test_encode_decode(geometry_ds, polygon_geometry): ) multi_ds = xr.merge([polyds, geometry_ds[1]]) for ds in (geometry_ds[1], polygon_geometry.to_dataset(), geom_dim_ds, multi_ds): - roundtripped = decode_geometries(encode_geometries(ds)) + encoded = encode_geometries(ds) + assert len(encoded.data_vars) > len(ds.data_vars) + roundtripped = decode_geometries(encoded) xr.testing.assert_identical(ds, roundtripped)