-
Notifications
You must be signed in to change notification settings - Fork 20
Fix mypy errors in xarray.py, xrutils.py, cache.py #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c972e97
2e42456
64c7d77
b3d698a
6e4db03
afee7c4
ed752dd
6303f4a
ae8953a
8fba166
ae5561d
5145dc2
5d46140
05893a2
375c31b
6ba6da4
170467b
a3d63a2
cf0d6cd
bde6c52
3728858
657496d
68ac242
c306099
90b0149
332caf9
9740009
5c08114
21b641d
d5409ef
1accd73
a50bb6b
50c2ac2
db2ac1b
1921938
3cac4b0
43dabff
e73f6e8
2d62748
62cc554
41e97e9
fc36211
a5d41a5
bfb9c6e
7260660
b34c268
eaf93d2
9486184
b18d209
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Hashable, Iterable, Sequence | ||
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
@@ -19,7 +19,10 @@ | |
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric | ||
|
||
if TYPE_CHECKING: | ||
from xarray import DataArray, Dataset, Resample | ||
from xarray.core.resample import Resample | ||
from xarray.core.types import T_DataArray, T_Dataset | ||
|
||
Dims = Union[str, Iterable[Hashable], None] | ||
|
||
|
||
def _get_input_core_dims(group_names, dim, ds, grouper_dims): | ||
|
@@ -51,13 +54,13 @@ def lookup_order(dimension): | |
|
||
|
||
def xarray_reduce( | ||
obj: Dataset | DataArray, | ||
*by: DataArray | Iterable[str] | Iterable[DataArray], | ||
obj: T_Dataset | T_DataArray, | ||
*by: T_DataArray | Hashable, | ||
func: str | Aggregation, | ||
expected_groups=None, | ||
isbin: bool | Sequence[bool] = False, | ||
sort: bool = True, | ||
dim: Hashable = None, | ||
dim: Dims | ellipsis = None, | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
split_out: int = 1, | ||
fill_value=None, | ||
method: str = "map-reduce", | ||
|
@@ -203,8 +206,11 @@ def xarray_reduce( | |
if keep_attrs is None: | ||
keep_attrs = True | ||
|
||
if isinstance(isbin, bool): | ||
isbin = (isbin,) * nby | ||
if isinstance(isbin, Sequence): | ||
isbins = isbin | ||
else: | ||
isbins = (isbin,) * nby | ||
|
||
if expected_groups is None: | ||
expected_groups = (None,) * nby | ||
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list | ||
|
@@ -217,78 +223,86 @@ def xarray_reduce( | |
raise NotImplementedError | ||
|
||
# eventually drop the variables we are grouping by | ||
maybe_drop = [b for b in by if isinstance(b, str)] | ||
maybe_drop = [b for b in by if isinstance(b, Hashable)] | ||
unindexed_dims = tuple( | ||
b | ||
for b, isbin_ in zip(by, isbin) | ||
if isinstance(b, str) and not isbin_ and b in obj.dims and b not in obj.indexes | ||
for b, isbin_ in zip(by, isbins) | ||
if isinstance(b, Hashable) and not isbin_ and b in obj.dims and b not in obj.indexes | ||
) | ||
|
||
by: tuple[DataArray] = tuple(obj[g] if isinstance(g, str) else g for g in by) # type: ignore | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
by_da = tuple(obj[g] if isinstance(g, Hashable) else g for g in by) | ||
|
||
grouper_dims = [] | ||
for g in by: | ||
for g in by_da: | ||
for d in g.dims: | ||
if d not in grouper_dims: | ||
grouper_dims.append(d) | ||
|
||
if isinstance(obj, xr.DataArray): | ||
ds = obj._to_temp_dataset() | ||
else: | ||
if isinstance(obj, xr.Dataset): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this rearrangement was weird. Is it a mypy bug? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the error you get if you isinstance with DataArray: # obj: Union[T_Dataset, T_DataArray]
if isinstance(obj, xr.DataArray):
ds = obj._to_temp_dataset() # -> xr.Dataset
else:
ds = obj # error: Incompatible types in assignment (expression has type "Union[T_Dataset, T_DataArray]", variable has type "Dataset") My understanding is that mypy always uses the typing from the first time it is defined ( |
||
ds = obj | ||
else: | ||
ds = obj._to_temp_dataset() | ||
|
||
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) | ||
|
||
if dim is Ellipsis: | ||
if nby > 1: | ||
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.") | ||
dim = tuple(obj.dims) | ||
if by[0].name in ds.dims and not isbin[0]: | ||
dim = tuple(d for d in dim if d != by[0].name) | ||
name_ = by_da[0].name | ||
if name_ in ds.dims and not isbins[0]: | ||
dim_tuple = tuple(d for d in obj.dims if d != name_) | ||
else: | ||
dim_tuple = tuple(obj.dims) | ||
elif dim is not None: | ||
dim = _atleast_1d(dim) | ||
dim_tuple = _atleast_1d(dim) | ||
else: | ||
dim = tuple() | ||
dim_tuple = tuple() | ||
|
||
# broadcast all variables against each other along all dimensions in `by` variables | ||
# don't exclude `dim` because it need not be a dimension in any of the `by` variables! | ||
# in the case where dim is Ellipsis, and by.ndim < obj.ndim | ||
# then we also broadcast `by` to all `obj.dims` | ||
# TODO: avoid this broadcasting | ||
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim) | ||
ds, *by = xr.broadcast(ds, *by, exclude=exclude_dims) | ||
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple) | ||
ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims) | ||
|
||
if not dim: | ||
dim = tuple(by[0].dims) | ||
# all members of by_broad have the same dimensions | ||
# so we just pull by_broad[0].dims if dim is None | ||
if not dim_tuple: | ||
Illviljan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dim_tuple = tuple(by_broad[0].dims) | ||
|
||
if any(d not in grouper_dims and d not in obj.dims for d in dim): | ||
if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): | ||
raise ValueError(f"Cannot reduce over absent dimensions {dim}.") | ||
|
||
dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims) | ||
if dims_not_in_groupers == tuple(dim) and not any(isbin): | ||
dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims) | ||
if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins): | ||
# reducing along a dimension along which groups do not vary | ||
# This is really just a normal reduction. | ||
# This is not right when binning so we exclude. | ||
if skipna and isinstance(func, str): | ||
dsfunc = func[3:] | ||
if isinstance(func, str): | ||
dsfunc = func[3:] if skipna else func | ||
else: | ||
dsfunc = func | ||
raise NotImplementedError( | ||
"func must be a string when reducing along a dimension not present in `by`" | ||
) | ||
# TODO: skipna needs test | ||
result = getattr(ds, dsfunc)(dim=dim, skipna=skipna) | ||
result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna) | ||
if isinstance(obj, xr.DataArray): | ||
return obj._from_temp_dataset(result) | ||
else: | ||
return result | ||
|
||
axis = tuple(range(-len(dim), 0)) | ||
group_names = tuple(g.name if not binned else f"{g.name}_bins" for g, binned in zip(by, isbin)) | ||
|
||
group_shape = [None] * len(by) | ||
expected_groups = list(expected_groups) | ||
axis = tuple(range(-len(dim_tuple), 0)) | ||
|
||
# Set expected_groups and convert to index since we need coords, sizes | ||
# for output xarray objects | ||
for idx, (b, expect, isbin_) in enumerate(zip(by, expected_groups, isbin)): | ||
expected_groups = list(expected_groups) | ||
group_names: tuple[Any, ...] = () | ||
group_sizes: dict[Any, int] = {} | ||
for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)): | ||
group_name = b_.name if not isbin_ else f"{b_.name}_bins" | ||
group_names += (group_name,) | ||
|
||
if isbin_ and isinstance(expect, int): | ||
raise NotImplementedError( | ||
"flox does not support binning into an integer number of bins yet." | ||
|
@@ -297,13 +311,21 @@ def xarray_reduce( | |
if isbin_: | ||
raise ValueError( | ||
f"Please provided bin edges for group variable {idx} " | ||
f"named {group_names[idx]} in expected_groups." | ||
f"named {group_name} in expected_groups." | ||
) | ||
expected_groups[idx] = _get_expected_groups(b.data, sort=sort, raise_if_dask=True) | ||
|
||
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort=sort) | ||
group_shape = tuple(len(e) for e in expected_groups) | ||
group_sizes = dict(zip(group_names, group_shape)) | ||
expect_ = _get_expected_groups(b_.data, sort=sort, raise_if_dask=True) | ||
else: | ||
expect_ = expect | ||
expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0] | ||
|
||
# The if-check is for type hinting mainly, it narrows down the return | ||
# type of _convert_expected_groups_to_index to pure pd.Index: | ||
if expect_index is not None: | ||
expected_groups[idx] = expect_index | ||
group_sizes[group_name] = len(expect_index) | ||
else: | ||
# This will never be reached | ||
raise ValueError("expect_index cannot be None") | ||
Illviljan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def wrapper(array, *by, func, skipna, **kwargs): | ||
# Handle skipna here because I need to know dtype to make a good default choice. | ||
|
@@ -349,20 +371,20 @@ def wrapper(array, *by, func, skipna, **kwargs): | |
if isinstance(obj, xr.Dataset): | ||
# broadcasting means the group dim gets added to ds, so we check the original obj | ||
for k, v in obj.data_vars.items(): | ||
is_missing_dim = not (any(d in v.dims for d in dim)) | ||
is_missing_dim = not (any(d in v.dims for d in dim_tuple)) | ||
if is_missing_dim: | ||
missing_dim[k] = v | ||
|
||
input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims) | ||
input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims) | ||
input_core_dims += [input_core_dims[-1]] * (nby - 1) | ||
|
||
actual = xr.apply_ufunc( | ||
wrapper, | ||
ds.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims), | ||
*by, | ||
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims), | ||
*by_broad, | ||
input_core_dims=input_core_dims, | ||
# for xarray's test_groupby_duplicate_coordinate_labels | ||
exclude_dims=set(dim), | ||
exclude_dims=set(dim_tuple), | ||
output_core_dims=[group_names], | ||
dask="allowed", | ||
dask_gufunc_kwargs=dict(output_sizes=group_sizes), | ||
|
@@ -379,27 +401,27 @@ def wrapper(array, *by, func, skipna, **kwargs): | |
"engine": engine, | ||
"reindex": reindex, | ||
"expected_groups": tuple(expected_groups), | ||
"isbin": isbin, | ||
"isbin": isbins, | ||
"finalize_kwargs": finalize_kwargs, | ||
}, | ||
) | ||
|
||
# restore non-dim coord variables without the core dimension | ||
# TODO: shouldn't apply_ufunc handle this? | ||
for var in set(ds.variables) - set(ds.dims): | ||
if all(d not in ds[var].dims for d in dim): | ||
actual[var] = ds[var] | ||
for var in set(ds_broad.variables) - set(ds_broad.dims): | ||
if all(d not in ds_broad[var].dims for d in dim_tuple): | ||
actual[var] = ds_broad[var] | ||
|
||
for name, expect, by_ in zip(group_names, expected_groups, by): | ||
for name, expect, by_ in zip(group_names, expected_groups, by_broad): | ||
# Can't remove this till xarray handles IntervalIndex | ||
if isinstance(expect, pd.IntervalIndex): | ||
expect = expect.to_numpy() | ||
if isinstance(actual, xr.Dataset) and name in actual: | ||
actual = actual.drop_vars(name) | ||
# When grouping by MultiIndex, expect is an pd.Index wrapping | ||
# an object array of tuples | ||
if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex): | ||
levelnames = ds.indexes[name].names | ||
if name in ds_broad.indexes and isinstance(ds_broad.indexes[name], pd.MultiIndex): | ||
levelnames = ds_broad.indexes[name].names | ||
expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames) | ||
actual[name] = expect | ||
if Version(xr.__version__) > Version("2022.03.0"): | ||
|
@@ -414,18 +436,17 @@ def wrapper(array, *by, func, skipna, **kwargs): | |
|
||
if nby == 1: | ||
for var in actual: | ||
if isinstance(obj, xr.DataArray): | ||
template = obj | ||
else: | ||
if isinstance(obj, xr.Dataset): | ||
template = obj[var] | ||
else: | ||
template = obj | ||
|
||
if actual[var].ndim > 1: | ||
actual[var] = _restore_dim_order(actual[var], template, by[0]) | ||
actual[var] = _restore_dim_order(actual[var], template, by_broad[0]) | ||
|
||
if missing_dim: | ||
for k, v in missing_dim.items(): | ||
missing_group_dims = { | ||
dim: size for dim, size in group_sizes.items() if dim not in v.dims | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
missing_group_dims = {d: size for d, size in group_sizes.items() if d not in v.dims} | ||
# The expand_dims is for backward compat with xarray's questionable behaviour | ||
if missing_group_dims: | ||
actual[k] = v.expand_dims(missing_group_dims).variable | ||
|
@@ -439,9 +460,9 @@ def wrapper(array, *by, func, skipna, **kwargs): | |
|
||
|
||
def rechunk_for_cohorts( | ||
obj: DataArray | Dataset, | ||
obj: T_DataArray | T_Dataset, | ||
dim: str, | ||
labels: DataArray, | ||
labels: T_DataArray, | ||
force_new_chunk_at, | ||
chunksize: int | None = None, | ||
ignore_old_chunks: bool = False, | ||
|
@@ -486,7 +507,7 @@ def rechunk_for_cohorts( | |
) | ||
|
||
|
||
def rechunk_for_blockwise(obj: DataArray | Dataset, dim: str, labels: DataArray): | ||
def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray): | ||
""" | ||
Rechunks array so that group boundaries line up with chunk boundaries, allowing | ||
embarassingly parallel group reductions. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,3 +57,5 @@ per-file-ignores = | |
exclude= | ||
.eggs | ||
doc | ||
builtins = | ||
ellipsis |
Uh oh!
There was an error while loading. Please reload this page.