Skip to content

Commit 5c04ebf

Browse files
blsqr and dcherian authored
Add NetCDF3 dtype coercion for unsigned integer types (#4018)
* In netcdf3 backend, also coerce unsigned integer dtypes
* Adjust test for netcdf3 roundtrip to include coercion. This might be a bit too general for what is required at this point, though ... 🤔
* Add test for failing dtype coercion
* Add What's New entry for issue #4014 and PR #4018
* Move netcdf3-specific test to NetCDF3Only class. Also uses a class variable for definition of netcdf3 formats now.
Co-authored-by: Deepak Cherian <[email protected]>
1 parent cb90d55 commit 5c04ebf

File tree

3 files changed

+53
-12
lines changed

3 files changed

+53
-12
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,9 @@ New Features
7070
the :py:class:`~core.accessor_dt.DatetimeAccessor` (:pull:`3935`). This
7171
feature requires cftime version 1.1.0 or greater. By
7272
`Spencer Clark <https://github.com/spencerkclark>`_.
73+
- For the netCDF3 backend, added dtype coercions for unsigned integer types.
74+
(:issue:`4014`, :pull:`4018`)
75+
By `Yunus Sevinchan <https://github.com/blsqr>`_
7376
- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases
7477
where the result of a computation could not be inferred automatically.
7578
By `Deepak Cherian <https://github.com/dcherian>`_

xarray/backends/netcdf3.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,14 @@
2828

2929
# These data-types aren't supported by netCDF3, so they are automatically
3030
# coerced instead as indicated by the "coerce_nc3_dtype" function
31-
_nc3_dtype_coercions = {"int64": "int32", "bool": "int8"}
31+
_nc3_dtype_coercions = {
32+
"int64": "int32",
33+
"uint64": "int32",
34+
"uint32": "int32",
35+
"uint16": "int16",
36+
"uint8": "int8",
37+
"bool": "int8",
38+
}
3239

3340
# encode all strings as UTF-8
3441
STRING_ENCODING = "utf-8"
@@ -37,12 +44,17 @@
3744
def coerce_nc3_dtype(arr):
3845
"""Coerce an array to a data type that can be stored in a netCDF-3 file
3946
40-
This function performs the following dtype conversions:
41-
int64 -> int32
42-
bool -> int8
43-
44-
Data is checked for equality, or equivalence (non-NaN values) with
45-
`np.allclose` with the default keyword arguments.
47+
This function performs the dtype conversions as specified by the
48+
``_nc3_dtype_coercions`` mapping:
49+
int64 -> int32
50+
uint64 -> int32
51+
uint32 -> int32
52+
uint16 -> int16
53+
uint8 -> int8
54+
bool -> int8
55+
56+
Data is checked for equality, or equivalence (non-NaN values) using
57+
``(cast_array == original_array).all()``.
4658
"""
4759
dtype = str(arr.dtype)
4860
if dtype in _nc3_dtype_coercions:

xarray/tests/test_backends.py

Lines changed: 31 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
save_mfdataset,
3131
)
3232
from xarray.backends.common import robust_getitem
33+
from xarray.backends.netcdf3 import _nc3_dtype_coercions
3334
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
3435
from xarray.backends.pydap_ import PydapDataStore
3536
from xarray.coding.variables import SerializationWarning
@@ -227,7 +228,27 @@ def __getitem__(self, key):
227228

228229

229230
class NetCDF3Only:
230-
pass
231+
netcdf3_formats = ("NETCDF3_CLASSIC", "NETCDF3_64BIT")
232+
233+
@requires_scipy
234+
def test_dtype_coercion_error(self):
235+
"""Failing dtype coercion should lead to an error"""
236+
for dtype, format in itertools.product(
237+
_nc3_dtype_coercions, self.netcdf3_formats
238+
):
239+
if dtype == "bool":
240+
# coerced upcast (bool to int8) ==> can never fail
241+
continue
242+
243+
# Using the largest representable value, create some data that will
244+
# no longer compare equal after the coerced downcast
245+
maxval = np.iinfo(dtype).max
246+
x = np.array([0, 1, 2, maxval], dtype=dtype)
247+
ds = Dataset({"x": ("t", x, {})})
248+
249+
with create_tmp_file(allow_cleanup_failure=False) as path:
250+
with pytest.raises(ValueError, match="could not safely cast"):
251+
ds.to_netcdf(path, format=format)
231252

232253

233254
class DatasetIOBase:
@@ -296,9 +317,14 @@ def test_write_store(self):
296317
def check_dtypes_roundtripped(self, expected, actual):
297318
for k in expected.variables:
298319
expected_dtype = expected.variables[k].dtype
299-
if isinstance(self, NetCDF3Only) and expected_dtype == "int64":
300-
# downcast
301-
expected_dtype = np.dtype("int32")
320+
321+
# For NetCDF3, the backend should perform dtype coercion
322+
if (
323+
isinstance(self, NetCDF3Only)
324+
and str(expected_dtype) in _nc3_dtype_coercions
325+
):
326+
expected_dtype = np.dtype(_nc3_dtype_coercions[str(expected_dtype)])
327+
302328
actual_dtype = actual.variables[k].dtype
303329
# TODO: check expected behavior for string dtypes more carefully
304330
string_kinds = {"O", "S", "U"}
@@ -2156,7 +2182,7 @@ def test_cross_engine_read_write_netcdf3(self):
21562182
valid_engines.add("scipy")
21572183

21582184
for write_engine in valid_engines:
2159-
for format in ["NETCDF3_CLASSIC", "NETCDF3_64BIT"]:
2185+
for format in self.netcdf3_formats:
21602186
with create_tmp_file() as tmp_file:
21612187
data.to_netcdf(tmp_file, format=format, engine=write_engine)
21622188
for read_engine in valid_engines:

0 commit comments

Comments
 (0)