Skip to content

Commit fe3560e

Browse files
HasanAhmadQ7HasanAhmadQ7
authored andcommitted
BUG: .sel method fails when label is float differnt from coords float type
casting at a lower level closer to pandas.Index methods
1 parent d037c03 commit fe3560e

File tree

4 files changed

+29
-23
lines changed

4 files changed

+29
-23
lines changed

xarray/core/dataset.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,15 +1814,6 @@ def sel(self, indexers=None, method=None, tolerance=None, drop=False,
18141814
DataArray.sel
18151815
"""
18161816
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
1817-
casting_keys = self._get_casting_keys(indexers.keys())
1818-
from .dataarray import DataArray
1819-
for k in casting_keys:
1820-
if isinstance(indexers[k], (slice, DataArray, Variable)):
1821-
pass
1822-
else:
1823-
casting_type = getattr(self.coords[k].values.dtype, "type")
1824-
indexers[k] = casting_type(indexers[k])
1825-
18261817
pos_indexers, new_indexes = remap_label_indexers(
18271818
self, indexers=indexers, method=method, tolerance=tolerance)
18281819
result = self.isel(indexers=pos_indexers, drop=drop)
@@ -2331,17 +2322,6 @@ def interp_like(self, other, method='linear', assume_sorted=False,
23312322
ds = self.reindex(object_coords)
23322323
return ds.interp(numeric_coords, method, assume_sorted, kwargs)
23332324

2334-
# Helper method for sel()
2335-
def _get_casting_keys(self, indexers_keys):
2336-
casting_keys = []
2337-
coords_keys = self.coords.keys()
2338-
common_keys = list(set(indexers_keys) & set(coords_keys))
2339-
for k in common_keys:
2340-
coords_var = self.coords[k].values
2341-
if (isinstance(coords_var, np.ndarray) and
2342-
coords_var.dtype.kind == 'f'):
2343-
casting_keys.append(k)
2344-
return casting_keys
23452325

23462326
# Helper methods for rename()
23472327
def _rename_vars(self, name_dict, dims_dict):

xarray/core/indexing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import duck_array_ops, nputils, utils
1212
from .pycompat import dask_array_type, integer_types
13-
from .utils import is_dict_like
13+
from .utils import is_dict_like, maybe_cast_to_coords_dtype
1414

1515

1616
def expanded_indexer(key, ndim):
@@ -248,6 +248,8 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):
248248
'an associated coordinate.')
249249
pos_indexers[dim] = label
250250
else:
251+
coords_dtype = data_obj.coords[dim].values.dtype
252+
label = maybe_cast_to_coords_dtype(label, coords_dtype)
251253
idxr, new_idx = convert_label_indexer(index, label,
252254
dim, method, tolerance)
253255
pos_indexers[dim] = idxr

xarray/core/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index:
6666
return index
6767

6868

69+
def maybe_cast_to_coords_dtype(label, coords_dtype):
70+
if isinstance(label, float):
71+
label = coords_dtype.type(label)
72+
elif isinstance(label, list) and coords_dtype.kind == 'f':
73+
label = np.asarray(label, dtype=coords_dtype)
74+
return label
75+
76+
6977
def safe_cast_to_index(array: Any) -> pd.Index:
7078
"""Given an array, safely cast it to a pandas.Index.
7179

xarray/tests/test_dataarray.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,15 +817,31 @@ def test_sel_dataarray_datetime(self):
817817
result = array.sel(delta=slice(array.delta[0], array.delta[-1]))
818818
assert_equal(result, array)
819819

820-
def test_sel_float32(self):
820+
def test_sel_float(self):
821+
data_values = np.arange(4)
822+
823+
# case coords are float32 and label is list of floats
821824
float_values = [0., 0.111, 0.222, 0.333]
822825
coord_values = np.asarray(float_values, dtype='float32')
823-
data_values = np.arange(4)
824826
array = DataArray(data_values, [('float32_coord', coord_values)])
825827
expected = DataArray(data_values[1:3], [('float32_coord',
826828
coord_values[1:3])])
827829
actual = array.sel(float32_coord=float_values[1:3])
830+
# case coords are float16 and label is list of floats
831+
coord_values_16 = np.asarray(float_values, dtype='float16')
832+
expected_16 = DataArray(data_values[1:3], [('float16_coord',
833+
coord_values_16[1:3])])
834+
array_16 = DataArray(data_values, [('float16_coord', coord_values_16)])
835+
actual_16 = array_16.sel(float16_coord=float_values[1:3])
836+
837+
# case coord, label are scalars
838+
expected_scalar = DataArray(data_values[2], coords={
839+
'float32_coord': coord_values[2]})
840+
actual_scalar = array.sel(float32_coord=float_values[2])
841+
828842
assert_equal(expected, actual)
843+
assert_equal(expected_scalar, actual_scalar)
844+
assert_equal(expected_16, actual_16)
829845

830846
def test_sel_no_index(self):
831847
array = DataArray(np.arange(10), dims='x')

0 commit comments

Comments
 (0)