|
30 | 30 | save_mfdataset,
|
31 | 31 | )
|
32 | 32 | from xarray.backends.common import robust_getitem
|
| 33 | +from xarray.backends.netcdf3 import _nc3_dtype_coercions |
33 | 34 | from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
|
34 | 35 | from xarray.backends.pydap_ import PydapDataStore
|
35 | 36 | from xarray.coding.variables import SerializationWarning
|
@@ -227,7 +228,27 @@ def __getitem__(self, key):
|
227 | 228 |
|
228 | 229 |
|
229 | 230 | class NetCDF3Only:
|
230 |
| - pass |
| 231 | + netcdf3_formats = ("NETCDF3_CLASSIC", "NETCDF3_64BIT") |
| 232 | + |
| 233 | + @requires_scipy |
| 234 | + def test_dtype_coercion_error(self): |
| 235 | + """Failing dtype coercion should lead to an error""" |
| 236 | + for dtype, format in itertools.product( |
| 237 | + _nc3_dtype_coercions, self.netcdf3_formats |
| 238 | + ): |
| 239 | + if dtype == "bool": |
| 240 | + # coerced upcast (bool to int8) ==> can never fail |
| 241 | + continue |
| 242 | + |
| 243 | + # Using the largest representable value, create some data that will |
| 244 | + # no longer compare equal after the coerced downcast |
| 245 | + maxval = np.iinfo(dtype).max |
| 246 | + x = np.array([0, 1, 2, maxval], dtype=dtype) |
| 247 | + ds = Dataset({"x": ("t", x, {})}) |
| 248 | + |
| 249 | + with create_tmp_file(allow_cleanup_failure=False) as path: |
| 250 | + with pytest.raises(ValueError, match="could not safely cast"): |
| 251 | + ds.to_netcdf(path, format=format) |
231 | 252 |
|
232 | 253 |
|
233 | 254 | class DatasetIOBase:
|
@@ -296,9 +317,14 @@ def test_write_store(self):
|
296 | 317 | def check_dtypes_roundtripped(self, expected, actual):
|
297 | 318 | for k in expected.variables:
|
298 | 319 | expected_dtype = expected.variables[k].dtype
|
299 |
| - if isinstance(self, NetCDF3Only) and expected_dtype == "int64": |
300 |
| - # downcast |
301 |
| - expected_dtype = np.dtype("int32") |
| 320 | + |
| 321 | + # For NetCDF3, the backend should perform dtype coercion |
| 322 | + if ( |
| 323 | + isinstance(self, NetCDF3Only) |
| 324 | + and str(expected_dtype) in _nc3_dtype_coercions |
| 325 | + ): |
| 326 | + expected_dtype = np.dtype(_nc3_dtype_coercions[str(expected_dtype)]) |
| 327 | + |
302 | 328 | actual_dtype = actual.variables[k].dtype
|
303 | 329 | # TODO: check expected behavior for string dtypes more carefully
|
304 | 330 | string_kinds = {"O", "S", "U"}
|
@@ -2156,7 +2182,7 @@ def test_cross_engine_read_write_netcdf3(self):
|
2156 | 2182 | valid_engines.add("scipy")
|
2157 | 2183 |
|
2158 | 2184 | for write_engine in valid_engines:
|
2159 |
| - for format in ["NETCDF3_CLASSIC", "NETCDF3_64BIT"]: |
| 2185 | + for format in self.netcdf3_formats: |
2160 | 2186 | with create_tmp_file() as tmp_file:
|
2161 | 2187 | data.to_netcdf(tmp_file, format=format, engine=write_engine)
|
2162 | 2188 | for read_engine in valid_engines:
|
|
0 commit comments