Skip to content

Commit ef4d4c1

Browse files
committed
fix categorical dtype isinstance checks
1 parent c775bbe commit ef4d4c1

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

explainerdashboard/explainer_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def merge_categorical_columns(
402402
cat_pieces.append(pd.DataFrame({col_name: merged_col}))
403403
else:
404404
if not drop_regular:
405-
if isinstance(X[col_name], pd.CategoricalDtype):
405+
if isinstance(X[col_name].dtype, pd.CategoricalDtype):
406406
cat_pieces.append(
407407
pd.DataFrame({col_name: pd.Categorical(X[col_name])})
408408
)

tests/test_catboost_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def test_get_col(precalculated_catboost_regression_explainer):
9595
precalculated_catboost_regression_explainer.get_col("Sex"), pd.Series
9696
)
9797
assert isinstance(
98-
precalculated_catboost_regression_explainer.get_col("Sex"), pd.CategoricalDtype
98+
precalculated_catboost_regression_explainer.get_col("Sex").dtype,
99+
pd.CategoricalDtype,
99100
)
100101

101102
assert isinstance(

tests/test_classifier_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,13 @@ def test_get_col(precalculated_rf_classifier_explainer):
127127
precalculated_rf_classifier_explainer.get_col("Gender"), pd.Series
128128
)
129129
assert isinstance(
130-
precalculated_rf_classifier_explainer.get_col("Gender"), pd.CategoricalDtype
130+
precalculated_rf_classifier_explainer.get_col("Gender").dtype,
131+
pd.CategoricalDtype,
131132
)
132133

133134
assert isinstance(precalculated_rf_classifier_explainer.get_col("Deck"), pd.Series)
134135
assert isinstance(
135-
precalculated_rf_classifier_explainer.get_col("Deck"), pd.CategoricalDtype
136+
precalculated_rf_classifier_explainer.get_col("Deck").dtype, pd.CategoricalDtype
136137
)
137138

138139
assert isinstance(precalculated_rf_classifier_explainer.get_col("Age"), pd.Series)

tests/test_regression_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def test_get_col(precalculated_rf_regression_explainer):
7676
precalculated_rf_regression_explainer.get_col("Gender"), pd.Series
7777
)
7878
assert isinstance(
79-
precalculated_rf_regression_explainer.get_col("Gender"), pd.CategoricalDtype
79+
precalculated_rf_regression_explainer.get_col("Gender").dtype,
80+
pd.CategoricalDtype,
8081
)
8182

8283
assert isinstance(precalculated_rf_regression_explainer.get_col("Age"), pd.Series)

0 commit comments

Comments
 (0)