Skip to content

Commit 46dc3a2

Browse files
committed
Support mypy prototype (#650)
* mypy support prototype * remove unnecessary method * Update tests/core/static/pandas_dataframe.py * use sys.executable * update deps * pylint * update mypy nox * add pandas-stubs as core requirement * add koalas and modin accessors * fix import error * improve coverage * remove typevar * attempt to fix modin issue
1 parent 02063c8 commit 46dc3a2

File tree

13 files changed

+372
-8
lines changed

13 files changed

+372
-8
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[BASIC]
22
good-names=
3+
T,
34
F,
45
logger,
56
df,

environment.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- hypothesis >= 5.41.1
1212
- numpy >= 1.9.0
1313
- pandas
14+
- pandas-stubs
1415
- scipy
1516
- wrapt
1617
- pyyaml >=5.1
@@ -26,7 +27,7 @@ dependencies:
2627

2728
# modin extra
2829
- modin
29-
- ray
30+
- ray <= 1.7.0
3031

3132
# dask extra
3233
- dask

noxfile.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,16 @@ def install_extras(
179179
extra: str = "core",
180180
force_pip: bool = False,
181181
pandas: str = "latest",
182+
pandas_stubs: bool = True,
182183
) -> None:
183184
"""Install dependencies."""
184185
specs, pip_specs = [], []
185186
pandas_version = "" if pandas == "latest" else f"=={pandas}"
186187
for spec in REQUIRES[extra].values():
188+
if spec == "pandas-stubs" and not pandas_stubs:
189+
# this is a temporary measure until all pandas-related mypy errors
190+
# are addressed
191+
continue
187192
if spec.split("==")[0] in ALWAYS_USE_PIP:
188193
pip_specs.append(spec)
189194
else:
@@ -297,7 +302,7 @@ def lint(session: Session) -> None:
297302
@nox.session(python=PYTHON_VERSIONS)
298303
def mypy(session: Session) -> None:
299304
"""Type-check using mypy."""
300-
install_extras(session, extra="all")
305+
install_extras(session, extra="all", pandas_stubs=False)
301306
args = session.posargs or SOURCE_PATHS
302307
session.run("mypy", "--follow-imports=silent", *args, silent=True)
303308

pandera/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,19 @@
6767
from . import dask_accessor
6868
except ImportError:
6969
pass
70+
71+
72+
try:
73+
import databricks.koalas
74+
75+
from . import koalas_accessor
76+
except ImportError:
77+
pass
78+
79+
80+
try:
81+
import modin.pandas
82+
83+
from . import modin_accessor
84+
except ImportError:
85+
pass

pandera/decorators.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,17 @@ def _check_arg(arg_name: str, arg_value: Any) -> Any:
536536
pass-through.
537537
"""
538538
schema, optional = annotated_schemas.get(arg_name, (None, None))
539-
if schema and not (optional and arg_value is None):
539+
if (
540+
schema
541+
and not (optional and arg_value is None)
542+
# the pandera.schema attribute should only be available when
543+
# schema.validate has been called on the DF. There's probably
544+
# a better way of doing this
545+
and (
546+
arg_value.pandera.schema is None
547+
or arg_value.pandera.schema != schema
548+
)
549+
):
540550
try:
541551
return schema.validate(
542552
arg_value, head, tail, sample, random_state, lazy, inplace

pandera/koalas_accessor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Register koalas accessor for pandera schema metadata."""
2+
3+
from databricks.koalas.extensions import (
4+
register_dataframe_accessor,
5+
register_series_accessor,
6+
)
7+
8+
from pandera.pandas_accessor import (
9+
PanderaDataFrameAccessor,
10+
PanderaSeriesAccessor,
11+
)
12+
13+
register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
14+
register_series_accessor("pandera")(PanderaSeriesAccessor)

pandera/modin_accessor.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Custom accessor functionality for modin.
2+
3+
Source code adapted from koalas implementation:
4+
https://koalas.readthedocs.io/en/latest/_modules/databricks/koalas/extensions.html#register_dataframe_accessor
5+
"""
6+
7+
import warnings
8+
9+
from pandera.pandas_accessor import (
10+
PanderaDataFrameAccessor,
11+
PanderaSeriesAccessor,
12+
)
13+
14+
15+
# pylint: disable=too-few-public-methods
16+
class CachedAccessor:
    """
    Property-like descriptor that builds an accessor lazily and caches it.

    :param name: Namespace that the accessor's methods, properties, etc. are
        reached under, e.g. "foo" for a dataframe accessor yields ``df.foo``.
    :param accessor: Class with the extension methods. It is instantiated
        with the owning object as its only argument.
    """

    def __init__(self, name, accessor):
        self._accessor = accessor
        self._name = name

    def __get__(self, obj, cls):
        if obj is not None:
            # Build the accessor for this instance and store it on the
            # instance under the same name, so it is created only once.
            built = self._accessor(obj)
            object.__setattr__(obj, self._name, built)
            return built
        # Looked up on the class itself: hand back the accessor class.
        return self._accessor  # pragma: no cover
41+
42+
43+
def _register_accessor(name, cls):
    """
    Build a class decorator that registers a custom accessor on ``cls``.

    :param name: Name under which the accessor should be registered. A warning
        is issued if this name conflicts with a preexisting attribute.
    :param cls: Class that receives the accessor attribute.
    :returns: A class decorator callable. The decorator installs the accessor
        on ``cls`` and returns the accessor class unchanged.
    """

    def decorator(accessor):
        if hasattr(cls, name):
            # Each interpolated fragment needs its own ``f`` prefix: implicit
            # string concatenation does not propagate it, so without the
            # prefix "{cls.__name__}" would be emitted literally.
            msg = (
                f"registration of accessor {accessor} under name '{name}' for "
                f"type {cls.__name__} is overriding a preexisting attribute "
                "with the same name."
            )
            warnings.warn(msg, UserWarning, stacklevel=2)
        setattr(cls, name, CachedAccessor(name, accessor))
        return accessor

    return decorator
69+
70+
71+
def register_dataframe_accessor(name):
    """
    Create a decorator that registers a custom accessor on modin DataFrames.

    :param name: name used when calling the accessor once it is registered
    :returns: a class decorator callable.
    """
    # pylint: disable=import-outside-toplevel
    import modin.pandas as modin_pd

    target_cls = modin_pd.DataFrame
    return _register_accessor(name, target_cls)
82+
83+
84+
def register_series_accessor(name):
    """
    Create a decorator that registers a custom accessor on modin Series.

    :param name: name used when calling the accessor once it is registered
    :returns: a class decorator callable.
    """
    # pylint: disable=import-outside-toplevel
    import modin.pandas as modin_pd

    target_cls = modin_pd.Series
    return _register_accessor(name, target_cls)
95+
96+
97+
register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
98+
register_series_accessor("pandera")(PanderaSeriesAccessor)

pandera/typing.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Typing definitions and helpers."""
22
# pylint:disable=abstract-method,disable=too-many-ancestors
3+
import inspect
34
from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar
45

56
import pandas as pd
@@ -134,8 +135,35 @@ def __get__(
134135
T = Schema
135136

136137

138+
class DataFrameBase(pd.DataFrame):
    """
    Pandera pandas.DataFrame base class for validating dataframes on
    initialization.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Set an attribute and validate when ``__orig_class__`` is assigned.

        The attribute is always stored via ``object.__setattr__``. When the
        attribute is ``__orig_class__`` and its first generic argument is a
        class with ``SchemaModel`` in its MRO, the dataframe is validated
        against that model unless it already carries an equal schema.

        :param name: name of the attribute being set.
        :param value: value to store; for ``__orig_class__`` this is expected
            to expose the generic arguments through ``__args__``.
        """
        object.__setattr__(self, name, value)
        if name != "__orig_class__":
            return

        # NOTE(review): ``__orig_class__`` is presumably assigned by the
        # typing machinery when ``DataFrame[Schema](...)`` is instantiated;
        # confirm against the callers.
        class_args = getattr(value, "__args__", None)
        if not class_args:
            # No generic arguments available: nothing to validate against.
            return

        candidate = class_args[0]
        # Only classes have an MRO; e.g. a bare TypeVar argument is skipped
        # instead of failing inside ``inspect.getmro``.
        if not inspect.isclass(candidate) or not any(
            base.__name__ == "SchemaModel"
            for base in inspect.getmro(candidate)
        ):
            # Not a SchemaModel: returning here avoids referencing an
            # unbound ``schema_model`` below.
            return

        schema_model = candidate
        schema = schema_model.to_schema()

        # prevent the double validation problem by preventing checks for
        # dataframes with a defined pandera.schema
        if self.pandera.schema is None or self.pandera.schema != schema:
            validated = schema_model.validate(self)
            validated.pandera.add_schema(schema)
163+
164+
137165
# pylint:disable=too-few-public-methods
138-
class DataFrame(pd.DataFrame, Generic[T]):
166+
class DataFrame(Generic[T], DataFrameBase):
139167
"""
140168
Representation of pandas.DataFrame, only used for type annotation.
141169

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ packaging >= 20.0
66
hypothesis >= 5.41.1
77
numpy >= 1.9.0
88
pandas
9+
pandas-stubs
910
scipy
1011
wrapt
1112
pyyaml >=5.1
@@ -17,7 +18,7 @@ pydantic
1718
koalas
1819
pyspark
1920
modin
20-
ray
21+
ray <= 1.7.0
2122
dask
2223
distributed
2324
black >= 20.8b1

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
"hypotheses": ["scipy"],
1313
"io": ["pyyaml >= 5.1", "black", "frictionless"],
1414
"koalas": ["koalas", "pyspark"],
15-
"modin": ["modin", "ray", "dask"],
16-
"modin-ray": ["modin", "ray"],
15+
"modin": ["modin", "ray <= 1.7.0", "dask"],
16+
"modin-ray": ["modin", "ray <= 1.7.0"],
1717
"modin-dask": ["modin", "dask"],
1818
"dask": ["dask"],
19-
}
19+
}
2020
extras_require = {
2121
**_extras_require,
2222
"all": list(set(x for y in _extras_require.values() for x in y)),
@@ -45,6 +45,7 @@
4545
"packaging >= 20.0",
4646
"numpy >= 1.9.0",
4747
"pandas >= 1.0",
48+
"pandas-stubs",
4849
"typing_extensions >= 3.7.4.3 ; python_version<'3.8'",
4950
"typing_inspect >= 0.6.0",
5051
"wrapt",

0 commit comments

Comments
 (0)