Commit de0ec5f

Add support for unique validation in PySpark (#1396)
* working, not the tests
* tests working, missing docs
* add suggestion to docs
* fix failing test and add specific method for data-scoped validations
* fix one code coverage issue
* accept suggestions from Kasper
* add condition and test for invalid column name and flattened the unique functions

Signed-off-by: Filipe Oliveira <[email protected]>
1 parent cf6b5e4 commit de0ec5f

5 files changed, +241 -44 lines changed

docs/source/pyspark_sql.rst

Lines changed: 13 additions & 0 deletions
@@ -343,3 +343,16 @@ We also provided a helper function to extract metadata from a schema as follows:
 .. note::
 
     This feature is available for ``pyspark.sql`` and ``pandas`` both.
+
+`unique` support
+----------------
+
+*new in 0.17.3*
+
+.. warning::
+
+    The `unique` support for PySpark-based validations to define which columns must be
+    tested for unique values may incur in a performance hit, given Spark's distributed
+    nature.
+
+    Use with caution.
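To make the documented behaviour concrete, here is a minimal usage sketch of the new dataframe-level `unique` option via the `pandera.pyspark` API. The column names and sample data are illustrative and not taken from this commit.

```python
# Minimal sketch (not from the commit): a dataframe-level `unique` constraint
# on a pandera.pyspark schema. Column names and data are illustrative.
import pyspark.sql.types as T
from pyspark.sql import SparkSession

from pandera.pyspark import Column, DataFrameSchema

spark = SparkSession.builder.getOrCreate()

schema = DataFrameSchema(
    {
        "product": Column(T.StringType()),
        "price": Column(T.LongType()),
    },
    # rows must be unique over this combination of columns
    unique=["product", "price"],
)

df = spark.createDataFrame(
    [("Bread", 9), ("Bread", 9), ("Butter", 15)], ["product", "price"]
)

# The PySpark backend collects failures rather than raising, so the duplicate
# row is expected to surface under the DATA error category.
validated = schema.validate(df)
print(validated.pandera.errors)
```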

pandera/backends/pyspark/container.py

Lines changed: 104 additions & 32 deletions
@@ -6,7 +6,7 @@
 from typing import Any, Dict, List, Optional
 
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, count
 
 from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler
 from pandera.api.pyspark.types import is_table
@@ -15,7 +15,6 @@
 from pandera.backends.pyspark.error_formatters import scalar_failure_case
 from pandera.config import CONFIG
 from pandera.errors import (
-    ParserError,
     SchemaDefinitionError,
     SchemaError,
     SchemaErrorReason,
@@ -31,14 +30,14 @@ def preprocess(self, check_obj: DataFrame, inplace: bool = False):
         return check_obj
 
     @validate_scope(scope=ValidationScope.SCHEMA)
-    def _column_checks(
+    def _schema_checks(
         self,
         check_obj: DataFrame,
         schema,
         column_info: ColumnInfo,
         error_handler: ErrorHandler,
     ):
-        """run the checks related to columns presence, uniqueness and filter column if neccesary"""
+        """run the checks related to columns presence, strictness and filter column if neccesary"""
 
         # check the container metadata, e.g. field names
         try:
@@ -71,6 +70,7 @@ def _column_checks(
                 reason_code=exc.reason_code,
                 schema_error=exc,
             )
+
         # try to coerce datatypes
         check_obj = self.coerce_dtype(
             check_obj,
@@ -80,6 +80,28 @@
 
         return check_obj
 
+    @validate_scope(scope=ValidationScope.DATA)
+    def _data_checks(
+        self,
+        check_obj: DataFrame,
+        schema,
+        column_info: ColumnInfo,  # pylint: disable=unused-argument
+        error_handler: ErrorHandler,
+    ):
+        """Run the checks related to data validation and uniqueness."""
+
+        # uniqueness of values
+        try:
+            check_obj = self.unique(
+                check_obj, schema=schema, error_handler=error_handler
+            )
+        except SchemaError as err:
+            error_handler.collect_error(
+                ErrorCategory.DATA, err.reason_code, err
+            )
+
+        return check_obj
+
     def validate(
         self,
         check_obj: DataFrame,
@@ -115,8 +137,13 @@ def validate(
             check_obj = check_obj.pandera.add_schema(schema)
         column_info = self.collect_column_info(check_obj, schema, lazy)
 
-        # validate the columns of the dataframe
-        check_obj = self._column_checks(
+        # validate the columns (schema) of the dataframe
+        check_obj = self._schema_checks(
+            check_obj, schema, column_info, error_handler
+        )
+
+        # validate the rows (data) of the dataframe
+        check_obj = self._data_checks(
             check_obj, schema, column_info, error_handler
         )
 
@@ -191,7 +218,7 @@ def run_checks(self, check_obj: DataFrame, schema, error_handler):
         check_results = []
         for check_index, check in enumerate(
             schema.checks
-        ):  # schama.checks is null
+        ):  # schema.checks is null
             try:
                 check_results.append(
                     self.run_check(check_obj, schema, check, check_index)
@@ -386,8 +413,7 @@ def coerce_dtype(
         except SchemaErrors as err:
             for schema_error_dict in err.schema_errors:
                 if not error_handler.lazy:
-                    # raise the first error immediately if not doing lazy
-                    # validation
+                    # raise the first error immediately if not doing lazy validation
                     raise schema_error_dict["error"]
                 error_handler.collect_error(
                     ErrorCategory.DTYPE_COERCION,
@@ -417,27 +443,6 @@ def _coerce_dtype(
         # NOTE: clean up the error handling!
         error_handler = ErrorHandler(lazy=True)
 
-        def _coerce_df_dtype(obj: DataFrame) -> DataFrame:
-            if schema.dtype is None:
-                raise ValueError(
-                    "dtype argument is None. Must specify this argument "
-                    "to coerce dtype"
-                )
-
-            try:
-                return schema.dtype.try_coerce(obj)
-            except ParserError as exc:
-                raise SchemaError(
-                    schema=schema,
-                    data=obj,
-                    message=(
-                        f"Error while coercing '{schema.name}' to type "
-                        f"{schema.dtype}: {exc}\n{exc.failure_cases}"
-                    ),
-                    failure_cases=exc.failure_cases,
-                    check=f"coerce_dtype('{schema.dtype}')",
-                ) from exc
-
         def _try_coercion(obj, colname, col_schema):
             try:
                 schema = obj.pandera.schema
@@ -490,6 +495,74 @@ def _try_coercion(obj, colname, col_schema):
 
         return obj
 
+    @validate_scope(scope=ValidationScope.DATA)
+    def unique(
+        self,
+        check_obj: DataFrame,
+        *,
+        schema=None,
+        error_handler: ErrorHandler = None,
+    ):
+        """Check uniqueness in the check object."""
+        assert schema is not None, "The `schema` argument must be provided."
+        assert (
+            error_handler is not None
+        ), "The `error_handler` argument must be provided."
+
+        if not schema.unique:
+            return check_obj
+
+        # Determine unique columns based on schema's config
+        unique_columns = (
+            [schema.unique]
+            if isinstance(schema.unique, str)
+            else schema.unique
+        )
+
+        # Check if values belong to the dataframe columns
+        missing_unique_columns = set(unique_columns) - set(check_obj.columns)
+        if missing_unique_columns:
+            raise SchemaDefinitionError(
+                "Specified `unique` columns are missing in the dataframe: "
+                f"{list(missing_unique_columns)}"
+            )
+
+        duplicates_count = (
+            check_obj.select(*unique_columns)  # ignore other cols
+            .groupby(*unique_columns)
+            .agg(count("*").alias("pandera_duplicate_counts"))
+            .filter(
+                col("pandera_duplicate_counts") > 1
+            )  # long name to avoid colisions
+            .count()
+        )
+
+        if duplicates_count > 0:
+            raise SchemaError(
+                schema=schema,
+                data=check_obj,
+                message=(
+                    f"Duplicated rows [{duplicates_count}] were found "
+                    f"for columns {unique_columns}"
+                ),
+                check="unique",
+                reason_code=SchemaErrorReason.DUPLICATES,
+            )
+
+        return check_obj
+
+    def _check_uniqueness(
+        self,
+        obj: DataFrame,
+        schema,
+    ) -> DataFrame:
+        """Ensure uniqueness in dataframe columns.
+
+        :param obj: dataframe to check.
+        :param schema: schema object.
+        :returns: dataframe checked.
+        """
+
     ##########
     # Checks #
     ##########
@@ -516,8 +589,7 @@ def check_column_names_are_unique(self, check_obj: DataFrame, schema):
                 schema=schema,
                 data=check_obj,
                 message=(
-                    "dataframe contains multiple columns with label(s): "
-                    f"{failed}"
+                    f"dataframe contains multiple columns with label(s): {failed}"
                 ),
                 failure_cases=scalar_failure_case(failed),
                 check="dataframe_column_labels_unique",

pandera/backends/pyspark/decorators.py

Lines changed: 20 additions & 10 deletions
@@ -81,6 +81,18 @@ def validate_scope(scope: ValidationScope):
     def _wrapper(func):
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
+            def _get_check_obj():
+                """
+                Get dataframe object passed as arg to the decorated func.
+
+                Returns:
+                    The DataFrame object.
+                """
+                if args:
+                    for value in args:
+                        if isinstance(value, pyspark.sql.DataFrame):
+                            return value
+
             if scope == ValidationScope.SCHEMA:
                 if CONFIG.validation_depth in (
                     ValidationDepth.SCHEMA_AND_DATA,
@@ -89,17 +101,12 @@ def wrapper(self, *args, **kwargs):
                     return func(self, *args, **kwargs)
                 else:
                     warnings.warn(
-                        "Skipping Execution of function as parameters set to DATA_ONLY ",
+                        f"Skipping execution of function {func.__name__} as validation depth is set to DATA_ONLY ",
                         stacklevel=2,
                     )
-                    if not kwargs:
-                        for value in kwargs.values():
-                            if isinstance(value, pyspark.sql.DataFrame):
-                                return value
-                    if args:
-                        for value in args:
-                            if isinstance(value, pyspark.sql.DataFrame):
-                                return value
+                    # If the function was skip, return the `check_obj` value anyway,
+                    # given that some return value is expected
+                    return _get_check_obj()
 
             elif scope == ValidationScope.DATA:
                 if CONFIG.validation_depth in (
@@ -109,9 +116,12 @@
                     return func(self, *args, **kwargs)
                 else:
                     warnings.warn(
-                        "Skipping Execution of function as parameters set to SCHEMA_ONLY ",
+                        f"Skipping execution of function {func.__name__} as validation depth is set to SCHEMA_ONLY",
                         stacklevel=2,
                     )
+                    # If the function was skip, return the `check_obj` value anyway,
+                    # given that some return value is expected
+                    return _get_check_obj()
 
         return wrapper
 
tests/pyspark/test_pyspark_config.py

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ def test_schema_only(self, spark, sample_spark_schema):
         CONFIG.validation_enabled = True
         CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY
 
-        pandra_schema = DataFrameSchema(
+        pandera_schema = DataFrameSchema(
            {
                 "product": Column(T.StringType(), Check.str_startswith("B")),
                 "price_val": Column(T.IntegerType()),
@@ -67,7 +67,7 @@
         assert CONFIG.dict() == expected
 
         input_df = spark_df(spark, self.sample_data, sample_spark_schema)
-        output_dataframeschema_df = pandra_schema.validate(input_df)
+        output_dataframeschema_df = pandera_schema.validate(input_df)
         expected_dataframeschema = {
             "SCHEMA": {
                 "COLUMN_NOT_IN_DATAFRAME": [
