Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pandas/core/arrays/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
ExtensionArray,
ExtensionOpsMixin,
ExtensionScalarOpsMixin,
try_cast_to_ea,
)
from pandas.core.arrays.boolean import BooleanArray
from pandas.core.arrays.categorical import Categorical
Expand All @@ -19,7 +18,6 @@
"ExtensionArray",
"ExtensionOpsMixin",
"ExtensionScalarOpsMixin",
"try_cast_to_ea",
"BooleanArray",
"Categorical",
"DatetimeArray",
Expand Down
24 changes: 1 addition & 23 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pandas.util._decorators import Appender, Substitution
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.cast import try_cast_to_ea
from pandas.core.dtypes.common import is_array_like, is_list_like
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
Expand All @@ -32,29 +33,6 @@
_extension_array_shared_docs: Dict[str, str] = dict()


def try_cast_to_ea(cls_or_instance, obj, dtype=None):
    """
    Attempt `_from_sequence`, falling back to the input on any failure.

    Parameters
    ----------
    cls_or_instance : ExtensionArray subclass or instance
    obj : arraylike
        Values to pass to cls._from_sequence
    dtype : ExtensionDtype, optional

    Returns
    -------
    ExtensionArray or obj
        The constructed array, or ``obj`` unchanged if construction failed.
    """
    try:
        return cls_or_instance._from_sequence(obj, dtype=dtype)
    except Exception:
        # We can't predict what downstream EA constructors may raise
        return obj


class ExtensionArray:
"""
Abstract base class for custom 1-D array types.
Expand Down
95 changes: 94 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
iNaT,
)
from pandas._libs.tslibs.timezones import tz_compare
from pandas._typing import Dtype
from pandas._typing import Dtype, DtypeObj
from pandas.util._validators import validate_bool_kwarg

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -246,6 +246,99 @@ def trans(x):
return result


def maybe_cast_result(
    result, obj: ABCSeries, numeric_only: bool = False, how: str = ""
):
    """
    Try casting result to a different type if appropriate

    Parameters
    ----------
    result : array-like
        Result to cast.
    obj : ABCSeries
        Input series from which result was calculated.
    numeric_only : bool, default False
        Whether to cast only numerics or datetimes as well.
    how : str, default ""
        How the result was computed.

    Returns
    -------
    result : array-like
        result maybe casted to the dtype.
    """
    # The target dtype is derived from the input, possibly adjusted for
    # the operation (e.g. bool + "sum" -> int64).
    original_dtype = obj._values.dtype if obj.ndim > 1 else obj.dtype
    dtype = maybe_cast_result_dtype(original_dtype, how)

    if is_scalar(result):
        return result

    if is_extension_array_dtype(dtype) and dtype.kind != "M":
        # The function can return something of any type, so check
        # if the type is compatible with the calling EA.
        # datetime64tz is handled correctly in agg_series,
        # so is excluded here.
        if len(result) and isinstance(result[0], dtype.type):
            array_cls = dtype.construct_array_type()
            result = try_cast_to_ea(array_cls, result, dtype=dtype)
    elif not numeric_only or is_numeric_dtype(dtype):
        # equivalent to the original
        # `numeric_only and is_numeric_dtype(dtype) or not numeric_only`
        result = maybe_downcast_to_dtype(result, dtype)

    return result


def maybe_cast_result_dtype(dtype: "DtypeObj", how: str) -> "DtypeObj":
    """
    Get the desired dtype of a result based on the
    input dtype and how it was computed.

    Parameters
    ----------
    dtype : DtypeObj
        Input dtype.
    how : str
        How the result was computed.

    Returns
    -------
    DtypeObj
        The desired dtype of the result.
    """
    # Summing booleans must produce integers (e.g. GH#7001); any
    # (dtype, how) pair not listed here keeps the input dtype.
    # Use np.bool_ rather than np.bool: the bare alias was deprecated
    # in NumPy 1.20 and removed in NumPy 1.24.
    d = {
        (np.dtype(np.bool_), "add"): np.dtype(np.int64),
        (np.dtype(np.bool_), "cumsum"): np.dtype(np.int64),
        (np.dtype(np.bool_), "sum"): np.dtype(np.int64),
    }
    return d.get((dtype, how), dtype)


def try_cast_to_ea(cls_or_instance, obj, dtype=None):
    """
    Call to `_from_sequence` that returns the object unchanged on Exception.

    Parameters
    ----------
    cls_or_instance : ExtensionArray subclass or instance
    obj : arraylike
        Values to pass to cls._from_sequence
    dtype : ExtensionDtype, optional

    Returns
    -------
    ExtensionArray or obj
    """
    try:
        casted = cls_or_instance._from_sequence(obj, dtype=dtype)
    except Exception:
        # We can't predict what downstream EA constructors may raise
        casted = obj
    return casted


def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other):
"""
A safe version of putmask that potentially upcasts the result.
Expand Down
14 changes: 9 additions & 5 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from pandas.util._decorators import Appender, Substitution

from pandas.core.dtypes.cast import (
maybe_cast_result,
maybe_cast_result_dtype,
maybe_convert_objects,
maybe_downcast_numeric,
maybe_downcast_to_dtype,
Expand Down Expand Up @@ -526,7 +528,7 @@ def _transform_fast(self, result, func_nm: str) -> Series:
cast = self._transform_should_cast(func_nm)
out = algorithms.take_1d(result._values, ids)
if cast:
out = self._try_cast(out, self.obj)
out = maybe_cast_result(out, self.obj, how=func_nm)
return Series(out, index=self.obj.index, name=self.obj.name)

def filter(self, func, dropna=True, *args, **kwargs):
Expand Down Expand Up @@ -1072,8 +1074,10 @@ def _cython_agg_blocks(
assert not isinstance(result, DataFrame)

if result is not no_result:
# see if we can cast the block back to the original dtype
result = maybe_downcast_numeric(result, block.dtype)
# see if we can cast the block to the desired dtype
# this may not be the original dtype
dtype = maybe_cast_result_dtype(block.dtype, how)
result = maybe_downcast_numeric(result, dtype)

if block.is_extension and isinstance(result, np.ndarray):
# e.g. block.values was an IntegerArray
Expand Down Expand Up @@ -1175,7 +1179,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:

else:
if cast:
result[item] = self._try_cast(result[item], data)
result[item] = maybe_cast_result(result[item], data)

result_columns = obj.columns
if cannot_agg:
Expand Down Expand Up @@ -1460,7 +1464,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
# TODO: we have no test cases that get here with EA dtypes;
# try_cast may not be needed if EAs never get here
if cast:
res = self._try_cast(res, obj.iloc[:, i])
res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm)
output.append(res)

return DataFrame._from_arrays(output, columns=result.columns, index=obj.index)
Expand Down
45 changes: 7 additions & 38 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ class providing the base-class of operations.
from pandas.errors import AbstractMethodError
from pandas.util._decorators import Appender, Substitution, cache_readonly

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.cast import maybe_cast_result
from pandas.core.dtypes.common import (
ensure_float,
is_datetime64_dtype,
is_extension_array_dtype,
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
Expand All @@ -53,7 +52,7 @@ class providing the base-class of operations.

from pandas.core import nanops
import pandas.core.algorithms as algorithms
from pandas.core.arrays import Categorical, DatetimeArray, try_cast_to_ea
from pandas.core.arrays import Categorical, DatetimeArray
from pandas.core.base import DataError, PandasObject, SelectionMixin
import pandas.core.common as com
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -792,36 +791,6 @@ def _cumcount_array(self, ascending: bool = True):
rev[sorter] = np.arange(count, dtype=np.intp)
return out[rev].astype(np.int64, copy=False)

def _try_cast(self, result, obj, numeric_only: bool = False):
    """
    Try to cast the result to our obj original type,
    we may have roundtripped through object in the mean-time.

    If numeric_only is True, then only try to cast numerics
    and not datetimelikes.
    """
    # Scalars are returned untouched; only array-likes are re-cast.
    dtype = obj._values.dtype if obj.ndim > 1 else obj.dtype

    if is_scalar(result):
        return result

    if is_extension_array_dtype(dtype) and dtype.kind != "M":
        # The function can return something of any type, so check
        # if the type is compatible with the calling EA.
        # datetime64tz is handled correctly in agg_series,
        # so is excluded here.
        if len(result) and isinstance(result[0], dtype.type):
            ea_cls = dtype.construct_array_type()
            result = try_cast_to_ea(ea_cls, result, dtype=dtype)
    elif not numeric_only or is_numeric_dtype(dtype):
        # same truth table as the original
        # `numeric_only and is_numeric_dtype(dtype) or not numeric_only`
        result = maybe_downcast_to_dtype(result, dtype)

    return result

def _transform_should_cast(self, func_nm: str) -> bool:
"""
Parameters
Expand Down Expand Up @@ -852,7 +821,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
continue

if self._transform_should_cast(how):
result = self._try_cast(result, obj)
result = maybe_cast_result(result, obj, how=how)

key = base.OutputKey(label=name, position=idx)
output[key] = result
Expand Down Expand Up @@ -895,12 +864,12 @@ def _cython_agg_general(
assert len(agg_names) == result.shape[1]
for result_column, result_name in zip(result.T, agg_names):
key = base.OutputKey(label=result_name, position=idx)
output[key] = self._try_cast(result_column, obj)
output[key] = maybe_cast_result(result_column, obj, how=how)
idx += 1
else:
assert result.ndim == 1
key = base.OutputKey(label=name, position=idx)
output[key] = self._try_cast(result, obj)
output[key] = maybe_cast_result(result, obj, how=how)
idx += 1

if len(output) == 0:
Expand Down Expand Up @@ -929,7 +898,7 @@ def _python_agg_general(self, func, *args, **kwargs):

assert result is not None
key = base.OutputKey(label=name, position=idx)
output[key] = self._try_cast(result, obj, numeric_only=True)
output[key] = maybe_cast_result(result, obj, numeric_only=True)

if len(output) == 0:
return self._python_apply_general(f)
Expand All @@ -944,7 +913,7 @@ def _python_agg_general(self, func, *args, **kwargs):
if is_numeric_dtype(values.dtype):
values = ensure_float(values)

output[key] = self._try_cast(values[mask], result)
output[key] = maybe_cast_result(values[mask], result)

return self._wrap_aggregated_output(output)

Expand Down
8 changes: 6 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from pandas.util._decorators import Appender, Substitution, doc
from pandas.util._validators import validate_bool_kwarg, validate_percentile

from pandas.core.dtypes.cast import convert_dtypes, validate_numeric_casting
from pandas.core.dtypes.cast import (
convert_dtypes,
try_cast_to_ea,
validate_numeric_casting,
)
from pandas.core.dtypes.common import (
_is_unorderable_exception,
ensure_platform_int,
Expand Down Expand Up @@ -59,7 +63,7 @@
import pandas as pd
from pandas.core import algorithms, base, generic, nanops, ops
from pandas.core.accessor import CachedAccessor
from pandas.core.arrays import ExtensionArray, try_cast_to_ea
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
import pandas.core.common as com
Expand Down
26 changes: 26 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_integer_dtype

import pandas as pd
from pandas import DataFrame, Index, MultiIndex, Series, concat
import pandas._testing as tm
Expand Down Expand Up @@ -340,6 +342,30 @@ def test_groupby_agg_coercing_bools():
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "op",
    [
        lambda x: x.sum(),
        lambda x: x.cumsum(),
        lambda x: x.transform("sum"),
        lambda x: x.transform("cumsum"),
        lambda x: x.agg("sum"),
        lambda x: x.agg("cumsum"),
    ],
)
def test_bool_agg_dtype(op):
    # GH 7001
    # Bool sum aggregations result in int
    df = pd.DataFrame({"a": [1, 1], "b": [False, True]})
    s = df.set_index("a")["b"]

    # DataFrame groupby: the bool column must come back as an integer dtype
    frame_dtype = op(df.groupby("a"))["b"].dtype
    assert is_integer_dtype(frame_dtype)

    # Series groupby: same expectation for the standalone bool series
    series_dtype = op(s.groupby("a")).dtype
    assert is_integer_dtype(series_dtype)


def test_order_aggregate_multiple_funcs():
# GH 25692
df = pd.DataFrame({"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]})
Expand Down