Skip to content

Commit 02063c8

Browse files
Brian Phillips authored and cosmicBboy committed
Add Basic Dask Support (#665)
* first pass of basic Dask support
* cleanup docstrings
* cleanup after rebase
* improve coverage
* update CI for new extra
* cover branches for dask not installed
* more coverage improvements
* further coverage improvements
1 parent b7f6516 commit 02063c8

File tree

13 files changed

+372
-13
lines changed

13 files changed

+372
-13
lines changed

.github/workflows/ci-tests.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,13 @@ jobs:
124124
--non-interactive
125125
--session "tests-${{ matrix.python-version }}(extra='core', pandas='${{ matrix.pandas-version }}')"
126126
127+
- name: Unit Tests - Dask
128+
run: >
129+
nox
130+
-db virtualenv -r -v
131+
--non-interactive
132+
--session "tests-${{ matrix.python-version }}(extra='dask', pandas='${{ matrix.pandas-version }}')"
133+
127134
- name: Unit Tests - Hypotheses
128135
run: >
129136
nox

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ dependencies:
2727
# modin extra
2828
- modin
2929
- ray
30+
31+
# dask extra
3032
- dask
3133
- distributed
3234

pandera/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,11 @@
5959
if platform.system() != "Windows":
6060
# pylint: disable=ungrouped-imports
6161
from pandera.dtypes import Complex256, Float128
62+
63+
64+
try:
65+
import dask.dataframe
66+
67+
from . import dask_accessor
68+
except ImportError:
69+
pass

pandera/check_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
SupportedTypes = NamedTuple(
99
"SupportedTypes",
1010
(
11-
("table_types", Tuple[type]),
12-
("field_types", Tuple[type]),
13-
("index_types", Tuple[type]),
14-
("multiindex_types", Tuple[type]),
11+
("table_types", Tuple[type, ...]),
12+
("field_types", Tuple[type, ...]),
13+
("index_types", Tuple[type, ...]),
14+
("multiindex_types", Tuple[type, ...]),
1515
),
1616
)
1717

@@ -42,6 +42,14 @@ def _supported_types():
4242
multiindex_types.append(mpd.MultiIndex)
4343
except ImportError:
4444
pass
45+
try:
46+
import dask.dataframe as dd
47+
48+
table_types.append(dd.DataFrame)
49+
field_types.append(dd.Series)
50+
index_types.append(dd.Index)
51+
except ImportError:
52+
pass
4553

4654
return SupportedTypes(
4755
tuple(table_types),

pandera/dask_accessor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Register dask accessor for pandera schema metadata."""
2+
3+
from dask.dataframe.extensions import (
4+
register_dataframe_accessor,
5+
register_series_accessor,
6+
)
7+
8+
from pandera.pandas_accessor import (
9+
PanderaDataFrameAccessor,
10+
PanderaSeriesAccessor,
11+
)
12+
13+
register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
14+
register_series_accessor("pandera")(PanderaSeriesAccessor)

pandera/schemas.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,6 @@ def validate(
404404
lazy: bool = False,
405405
inplace: bool = False,
406406
) -> pd.DataFrame:
407-
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
408407
"""Check if all columns in a dataframe have a column in the Schema.
409408
410409
:param pd.DataFrame check_obj: the dataframe to be validated.
@@ -460,6 +459,51 @@ def validate(
460459
5 0.76 dog
461460
"""
462461

462+
if not check_utils.is_table(check_obj):
463+
raise TypeError(f"expected pd.DataFrame, got {type(check_obj)}")
464+
465+
if hasattr(check_obj, "dask"):
466+
# special case for dask dataframes
467+
if inplace:
468+
check_obj = check_obj.pandera.add_schema(self)
469+
else:
470+
check_obj = check_obj.copy()
471+
472+
check_obj = check_obj.map_partitions(
473+
self._validate,
474+
head=head,
475+
tail=tail,
476+
sample=sample,
477+
random_state=random_state,
478+
lazy=lazy,
479+
inplace=inplace,
480+
meta=check_obj,
481+
)
482+
483+
return check_obj.pandera.add_schema(self)
484+
485+
return self._validate(
486+
check_obj=check_obj,
487+
head=head,
488+
tail=tail,
489+
sample=sample,
490+
random_state=random_state,
491+
lazy=lazy,
492+
inplace=inplace,
493+
)
494+
495+
def _validate(
496+
self,
497+
check_obj: pd.DataFrame,
498+
head: Optional[int] = None,
499+
tail: Optional[int] = None,
500+
sample: Optional[int] = None,
501+
random_state: Optional[int] = None,
502+
lazy: bool = False,
503+
inplace: bool = False,
504+
) -> pd.DataFrame:
505+
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
506+
463507
if self._is_inferred:
464508
warnings.warn(
465509
f"This {type(self)} is an inferred schema that hasn't been "
@@ -2074,7 +2118,6 @@ def validate(
20742118
lazy: bool = False,
20752119
inplace: bool = False,
20762120
) -> pd.Series:
2077-
# pylint: disable=too-many-branches
20782121
"""Validate a Series object.
20792122
20802123
:param check_obj: One-dimensional ndarray with axis labels
@@ -2118,8 +2161,48 @@ def validate(
21182161
21192162
"""
21202163
if not check_utils.is_field(check_obj):
2121-
raise TypeError(f"expected {pd.Series}, got {type(check_obj)}")
2164+
raise TypeError(f"expected pd.Series, got {type(check_obj)}")
2165+
2166+
if hasattr(check_obj, "dask"):
2167+
# special case for dask series
2168+
if inplace:
2169+
check_obj = check_obj.pandera.add_schema(self)
2170+
else:
2171+
check_obj = check_obj.copy()
2172+
2173+
check_obj = check_obj.map_partitions(
2174+
self._validate,
2175+
head=head,
2176+
tail=tail,
2177+
sample=sample,
2178+
random_state=random_state,
2179+
lazy=lazy,
2180+
inplace=inplace,
2181+
meta=check_obj,
2182+
)
21222183

2184+
return check_obj.pandera.add_schema(self)
2185+
2186+
return self._validate(
2187+
check_obj=check_obj,
2188+
head=head,
2189+
tail=tail,
2190+
sample=sample,
2191+
random_state=random_state,
2192+
lazy=lazy,
2193+
inplace=inplace,
2194+
)
2195+
2196+
def _validate(
2197+
self,
2198+
check_obj: pd.Series,
2199+
head: Optional[int] = None,
2200+
tail: Optional[int] = None,
2201+
sample: Optional[int] = None,
2202+
random_state: Optional[int] = None,
2203+
lazy: bool = False,
2204+
inplace: bool = False,
2205+
) -> pd.Series:
21232206
if not inplace:
21242207
check_obj = check_obj.copy()
21252208

pandera/typing.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
except ImportError:
2121
ModelField = Any # type: ignore
2222

23+
try:
24+
import dask.dataframe as dd
25+
26+
_DASK_INSTALLED = True
27+
except ImportError:
28+
_DASK_INSTALLED = False
29+
2330
Bool = dtypes.Bool #: ``"bool"`` numpy dtype
2431
DateTime = dtypes.DateTime #: ``"datetime64[ns]"`` numpy dtype
2532
Timedelta = dtypes.Timedelta #: ``"timedelta64[ns]"`` numpy dtype
@@ -178,6 +185,15 @@ def _pydantic_validate(
178185
raise ValueError(str(exc)) from exc
179186

180187

188+
if _DASK_INSTALLED:
189+
# pylint:disable=too-few-public-methods
190+
class DaskDataFrame(dd.DataFrame, Generic[T]):
191+
"""
192+
Representation of dask.dataframe.DataFrame, only used for type
193+
annotation.
194+
"""
195+
196+
181197
class AnnotationInfo: # pylint:disable=too-few-public-methods
182198
"""Captures extra information about an annotation.
183199
@@ -195,11 +211,16 @@ def __init__(self, raw_annotation: Type) -> None:
195211

196212
@property
197213
def is_generic_df(self) -> bool:
198-
"""True if the annotation is a pandera.typing.DataFrame."""
214+
"""True if the annotation is a pandera.typing.DataFrame or
215+
pandera.typing.DaskDataFrame.
216+
"""
199217
try:
200-
return self.origin is not None and issubclass(
201-
self.origin, DataFrame
202-
)
218+
if self.origin is None:
219+
return False
220+
if _DASK_INSTALLED:
221+
return issubclass(self.origin, (DataFrame, DaskDataFrame))
222+
else:
223+
return issubclass(self.origin, DataFrame)
203224
except TypeError:
204225
return False
205226

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"modin": ["modin", "ray", "dask"],
1616
"modin-ray": ["modin", "ray"],
1717
"modin-dask": ["modin", "dask"],
18+
"dask": ["dask"],
1819
}
1920
extras_require = {
2021
**_extras_require,

tests/core/test_pandas_accessor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for pandas_accessor module."""
22
from typing import Union
3+
from unittest.mock import patch
34

45
import pandas as pd
56
import pytest
@@ -49,8 +50,20 @@ def test_dataframe_series_add_schema(
4950
assert validated_data_1.pandera.schema == schema1
5051
assert validated_data_2.pandera.schema == schema2
5152

52-
with pytest.raises(TypeError):
53+
with pytest.raises(TypeError, match=f"expected pd.{type(data).__name__}"):
5354
schema1(invalid_data)
5455

55-
with pytest.raises(TypeError):
56+
with pytest.raises(TypeError, match=f"expected pd.{type(data).__name__}"):
5657
schema2(invalid_data)
58+
59+
with patch.object(pa.schemas.check_utils, "is_table", return_value=True):
60+
with patch.object(
61+
pa.schemas.check_utils,
62+
"is_field",
63+
return_value=True,
64+
):
65+
with pytest.raises(TypeError, match="schema arg"):
66+
schema1(invalid_data)
67+
68+
with pytest.raises(TypeError, match="schema arg"):
69+
schema2(invalid_data)

tests/dask/__init__.py

Whitespace-only changes.

0 commit comments

Comments (0)