Skip to content

Commit bc5c79e

Browse files
nbren12crusaderky
andauthored
Improve typehints of xr.Dataset.__getitem__ (#4144)
* Improve typehints of xr.Dataset.__getitem__ Resolves #4125 * Add overload for Mapping behavior Sadly this is not working with my version of mypy. See python/mypy#7328 * Overload only Hashable inputs Given mypy's use of overloads, I think this is all we can do. If the argument is not Hashable, then return the Union type as before. * Lint * Quote the DataArray to avoid error in py3.6 * Code review Co-authored-by: crusaderky <[email protected]>
1 parent 2ba5300 commit bc5c79e

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repos:
1616
hooks:
1717
- id: flake8
1818
- repo: https://github.com/pre-commit/mirrors-mypy
19-
rev: v0.761 # Must match ci/requirements/*.yml
19+
rev: v0.780 # Must match ci/requirements/*.yml
2020
hooks:
2121
- id: mypy
2222
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194

ci/requirements/py38.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- isort
2323
- lxml # Optional dep of pydap
2424
- matplotlib
25-
- mypy=0.761 # Must match .pre-commit-config.yaml
25+
- mypy=0.780 # Must match .pre-commit-config.yaml
2626
- nc-time-axis
2727
- netcdf4
2828
- numba

xarray/core/dataset.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TypeVar,
2828
Union,
2929
cast,
30+
overload,
3031
)
3132

3233
import numpy as np
@@ -1241,13 +1242,25 @@ def loc(self) -> _LocIndexer:
12411242
"""
12421243
return _LocIndexer(self)
12431244

1244-
def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]":
1245+
# FIXME https://github.com/python/mypy/issues/7328
1246+
@overload
1247+
def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore
1248+
...
1249+
1250+
@overload
1251+
def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore
1252+
...
1253+
1254+
@overload
1255+
def __getitem__(self, key: Any) -> "Dataset":
1256+
...
1257+
1258+
def __getitem__(self, key):
12451259
"""Access variables or coordinates this dataset as a
12461260
:py:class:`~xarray.DataArray`.
12471261
12481262
Indexing with a list of names will return a new ``Dataset`` object.
12491263
"""
1250-
# TODO(shoyer): type this properly: https://github.com/python/mypy/issues/7328
12511264
if utils.is_dict_like(key):
12521265
return self.isel(**cast(Mapping, key))
12531266

xarray/core/weighted.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ class Weighted:
7272
def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
7373
...
7474

75-
@overload # noqa: F811
76-
def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811
75+
@overload
76+
def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
7777
...
7878

79-
def __init__(self, obj, weights): # noqa: F811
79+
def __init__(self, obj, weights):
8080
"""
8181
Create a Weighted object
8282

0 commit comments

Comments
 (0)