|
1 | 1 | from contextlib import suppress
|
2 | 2 |
|
3 | 3 | import numpy as np
|
| 4 | +import pandas as pd |
4 | 5 | import pytest
|
5 | 6 |
|
6 | 7 | from xarray import Variable
|
@@ -32,10 +33,30 @@ def test_vlen_dtype():
|
32 | 33 | @pytest.mark.parametrize("numpy_str_type", (np.str, np.str_))
|
33 | 34 | def test_numpy_str_handling(numpy_str_type):
|
34 | 35 | dtype = strings.create_vlen_dtype(numpy_str_type)
|
35 |
| - assert dtype.metadata["element_type"] == str |
| 36 | + assert dtype.metadata["element_type"] == numpy_str_type |
36 | 37 | assert strings.is_unicode_dtype(dtype)
|
37 | 38 | assert not strings.is_bytes_dtype(dtype)
|
38 |
| - assert strings.check_vlen_dtype(dtype) is str |
| 39 | + assert strings.check_vlen_dtype(dtype) is numpy_str_type |
| 40 | + |
| 41 | + |
| 42 | +@pytest.mark.parametrize("numpy_str_type", (np.str, np.str_)) |
| 43 | +def test_write_file_from_np_str(numpy_str_type): |
| 44 | + # should be moved elsewhere probably |
| 45 | + scenarios = [numpy_str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]] |
| 46 | + years = range(2015, 2100 + 1) |
| 47 | + tdf = pd.DataFrame( |
| 48 | + data=np.random.random((len(scenarios), len(years))), |
| 49 | + columns=years, |
| 50 | + index=scenarios, |
| 51 | + ) |
| 52 | + tdf.index.name = "scenario" |
| 53 | + tdf.columns.name = "year" |
| 54 | + tdf = tdf.stack() |
| 55 | + tdf.name = "tas" |
| 56 | + |
| 57 | + txr = tdf.to_xarray() |
| 58 | + |
| 59 | + txr.to_netcdf("test.nc") |
39 | 60 |
|
40 | 61 |
|
41 | 62 | def test_EncodedStringCoder_decode():
|
|
0 commit comments