diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 33ff182f5baee..7a81331e2b7a9 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -13,6 +13,10 @@ from pandas._libs import missing as libmissing from pandas._libs.sparse import IntIndex +from pandas.core.dtypes.cast import ( + find_common_type, + infer_dtype_from_scalar, +) from pandas.core.dtypes.common import ( is_integer_dtype, is_list_like, @@ -567,7 +571,13 @@ def from_dummies( ) else: data_slice = data_to_decode.loc[:, prefix_slice] - cats_array = data._constructor_sliced(cats, dtype=data.columns.dtype) + dtype = data.columns.dtype + if default_category: + default_category_dtype = infer_dtype_from_scalar(default_category[prefix])[ + 0 + ] + dtype = find_common_type([dtype, default_category_dtype]) + cats_array = data._constructor_sliced(cats, dtype=dtype) # get indices of True entries along axis=1 true_values = data_slice.idxmax(axis=1) indexer = data_slice.columns.get_indexer_for(true_values) diff --git a/pandas/tests/reshape/test_from_dummies.py b/pandas/tests/reshape/test_from_dummies.py index da1930323f464..063bd22a0c511 100644 --- a/pandas/tests/reshape/test_from_dummies.py +++ b/pandas/tests/reshape/test_from_dummies.py @@ -1,8 +1,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas import ( DataFrame, Series, @@ -330,14 +328,10 @@ def test_no_prefix_string_cats_contains_get_dummies_NaN_column(): ), ], ) -def test_no_prefix_string_cats_default_category( - default_category, expected, using_infer_string -): +def test_no_prefix_string_cats_default_category(default_category, expected): dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]}) result = from_dummies(dummies, default_category=default_category) expected = DataFrame(expected) - if using_infer_string: - expected[""] = expected[""].astype("str") tm.assert_frame_equal(result, expected) @@ -364,7 +358,6 @@ def test_with_prefix_contains_get_dummies_NaN_column(): tm.assert_frame_equal(result, expected) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) @pytest.mark.parametrize( "default_category, expected", [ @@ -390,7 +383,7 @@ def test_with_prefix_contains_get_dummies_NaN_column(): ), pytest.param( {"col2": None, "col1": False}, - {"col1": ["a", "b", False], "col2": [None, "a", "c"]}, + {"col1": ["a", "b", False], "col2": Series([None, "a", "c"], dtype=object)}, id="default_category is a dict with bool and None values", ), pytest.param( diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index 34a6377b5786f..30e6ebf0eed13 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -293,23 +293,12 @@ def test_startswith_endswith_validate_na(any_string_dtype): dtype=any_string_dtype, ) - dtype = ser.dtype - if (isinstance(dtype, pd.StringDtype)) or dtype == np.dtype("object"): - msg = "Allowing a non-bool 'na' in obj.str.startswith is deprecated" - with tm.assert_produces_warning(FutureWarning, match=msg): - ser.str.startswith("kapow", na="baz") - msg = "Allowing a non-bool 'na' in obj.str.endswith is deprecated" - with tm.assert_produces_warning(FutureWarning, match=msg): - ser.str.endswith("bar", na="baz") - else: - # TODO(infer_string): don't surface pyarrow errors - import pyarrow as pa - - msg = "Could not convert 'baz' with type str: tried to convert to boolean" - with pytest.raises(pa.lib.ArrowInvalid, match=msg): - ser.str.startswith("kapow", na="baz") - with pytest.raises(pa.lib.ArrowInvalid, match=msg): - ser.str.endswith("kapow", na="baz") + msg = "Allowing a non-bool 'na' in obj.str.startswith is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + ser.str.startswith("kapow", na="baz") + msg = "Allowing a non-bool 'na' in obj.str.endswith is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + ser.str.endswith("bar", na="baz") @pytest.mark.parametrize("pat", ["foo", ("foo", "baz")])