@@ -404,7 +404,6 @@ def validate(
404404 lazy : bool = False ,
405405 inplace : bool = False ,
406406 ) -> pd .DataFrame :
407- # pylint: disable=too-many-locals,too-many-branches,too-many-statements
408407 """Check if all columns in a dataframe have a column in the Schema.
409408
410409 :param pd.DataFrame check_obj: the dataframe to be validated.
@@ -460,6 +459,51 @@ def validate(
460459 5 0.76 dog
461460 """
462461
462+ if not check_utils .is_table (check_obj ):
463+ raise TypeError (f"expected pd.DataFrame, got { type (check_obj )} " )
464+
465+ if hasattr (check_obj , "dask" ):
466+ # special case for dask dataframes
467+ if inplace :
468+ check_obj = check_obj .pandera .add_schema (self )
469+ else :
470+ check_obj = check_obj .copy ()
471+
472+ check_obj = check_obj .map_partitions (
473+ self ._validate ,
474+ head = head ,
475+ tail = tail ,
476+ sample = sample ,
477+ random_state = random_state ,
478+ lazy = lazy ,
479+ inplace = inplace ,
480+ meta = check_obj ,
481+ )
482+
483+ return check_obj .pandera .add_schema (self )
484+
485+ return self ._validate (
486+ check_obj = check_obj ,
487+ head = head ,
488+ tail = tail ,
489+ sample = sample ,
490+ random_state = random_state ,
491+ lazy = lazy ,
492+ inplace = inplace ,
493+ )
494+
495+ def _validate (
496+ self ,
497+ check_obj : pd .DataFrame ,
498+ head : Optional [int ] = None ,
499+ tail : Optional [int ] = None ,
500+ sample : Optional [int ] = None ,
501+ random_state : Optional [int ] = None ,
502+ lazy : bool = False ,
503+ inplace : bool = False ,
504+ ) -> pd .DataFrame :
505+ # pylint: disable=too-many-locals,too-many-branches,too-many-statements
506+
463507 if self ._is_inferred :
464508 warnings .warn (
465509 f"This { type (self )} is an inferred schema that hasn't been "
@@ -2074,7 +2118,6 @@ def validate(
20742118 lazy : bool = False ,
20752119 inplace : bool = False ,
20762120 ) -> pd .Series :
2077- # pylint: disable=too-many-branches
20782121 """Validate a Series object.
20792122
20802123 :param check_obj: One-dimensional ndarray with axis labels
@@ -2118,8 +2161,48 @@ def validate(
21182161
21192162 """
21202163 if not check_utils .is_field (check_obj ):
2121- raise TypeError (f"expected { pd .Series } , got { type (check_obj )} " )
2164+ raise TypeError (f"expected pd.Series, got { type (check_obj )} " )
2165+
2166+ if hasattr (check_obj , "dask" ):
2167+ # special case for dask series
2168+ if inplace :
2169+ check_obj = check_obj .pandera .add_schema (self )
2170+ else :
2171+ check_obj = check_obj .copy ()
2172+
2173+ check_obj = check_obj .map_partitions (
2174+ self ._validate ,
2175+ head = head ,
2176+ tail = tail ,
2177+ sample = sample ,
2178+ random_state = random_state ,
2179+ lazy = lazy ,
2180+ inplace = inplace ,
2181+ meta = check_obj ,
2182+ )
21222183
2184+ return check_obj .pandera .add_schema (self )
2185+
2186+ return self ._validate (
2187+ check_obj = check_obj ,
2188+ head = head ,
2189+ tail = tail ,
2190+ sample = sample ,
2191+ random_state = random_state ,
2192+ lazy = lazy ,
2193+ inplace = inplace ,
2194+ )
2195+
2196+ def _validate (
2197+ self ,
2198+ check_obj : pd .Series ,
2199+ head : Optional [int ] = None ,
2200+ tail : Optional [int ] = None ,
2201+ sample : Optional [int ] = None ,
2202+ random_state : Optional [int ] = None ,
2203+ lazy : bool = False ,
2204+ inplace : bool = False ,
2205+ ) -> pd .Series :
21232206 if not inplace :
21242207 check_obj = check_obj .copy ()
21252208
0 commit comments