Skip to content

Commit e3a9618

Browse files
authored
REF: re-use machinery for DataFrameGroupBy.nunique (#41390)
1 parent e23a1d3 commit e3a9618

File tree

4 files changed

+40
-47
lines changed

4 files changed

+40
-47
lines changed

pandas/core/groupby/generic.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
Mapping,
2323
TypeVar,
2424
Union,
25-
cast,
2625
)
2726
import warnings
2827

@@ -1626,6 +1625,10 @@ def _wrap_aggregated_output(
16261625

16271626
if self.axis == 1:
16281627
result = result.T
1628+
if result.index.equals(self.obj.index):
1629+
# Retain e.g. DatetimeIndex/TimedeltaIndex freq
1630+
result.index = self.obj.index.copy()
1631+
# TODO: Do this more systematically
16291632

16301633
return self._reindex_output(result)
16311634

@@ -1677,21 +1680,21 @@ def _wrap_agged_manager(self, mgr: Manager2D) -> DataFrame:
16771680

16781681
return self._reindex_output(result)._convert(datetime=True)
16791682

1680-
def _iterate_column_groupbys(self):
1681-
for i, colname in enumerate(self._selected_obj.columns):
1683+
def _iterate_column_groupbys(self, obj: FrameOrSeries):
1684+
for i, colname in enumerate(obj.columns):
16821685
yield colname, SeriesGroupBy(
1683-
self._selected_obj.iloc[:, i],
1686+
obj.iloc[:, i],
16841687
selection=colname,
16851688
grouper=self.grouper,
16861689
exclusions=self.exclusions,
16871690
)
16881691

1689-
def _apply_to_column_groupbys(self, func) -> DataFrame:
1692+
def _apply_to_column_groupbys(self, func, obj: FrameOrSeries) -> DataFrame:
16901693
from pandas.core.reshape.concat import concat
16911694

1692-
columns = self._selected_obj.columns
1695+
columns = obj.columns
16931696
results = [
1694-
func(col_groupby) for _, col_groupby in self._iterate_column_groupbys()
1697+
func(col_groupby) for _, col_groupby in self._iterate_column_groupbys(obj)
16951698
]
16961699

16971700
if not len(results):
@@ -1778,41 +1781,21 @@ def nunique(self, dropna: bool = True) -> DataFrame:
17781781
4 ham 5 x
17791782
5 ham 5 y
17801783
"""
1781-
from pandas.core.reshape.concat import concat
17821784

1783-
# TODO: this is duplicative of how GroupBy naturally works
1784-
# Try to consolidate with normal wrapping functions
1785+
if self.axis != 0:
1786+
# see test_groupby_crash_on_nunique
1787+
return self._python_agg_general(lambda sgb: sgb.nunique(dropna))
17851788

17861789
obj = self._obj_with_exclusions
1787-
if self.axis == 0:
1788-
iter_func = obj.items
1789-
else:
1790-
iter_func = obj.iterrows
1791-
1792-
res_list = [
1793-
SeriesGroupBy(content, selection=label, grouper=self.grouper).nunique(
1794-
dropna
1795-
)
1796-
for label, content in iter_func()
1797-
]
1798-
if res_list:
1799-
results = concat(res_list, axis=1)
1800-
results = cast(DataFrame, results)
1801-
else:
1802-
# concat would raise
1803-
results = DataFrame(
1804-
[], index=self.grouper.result_index, columns=obj.columns[:0]
1805-
)
1806-
1807-
if self.axis == 1:
1808-
results = results.T
1809-
1810-
other_axis = 1 - self.axis
1811-
results._get_axis(other_axis).names = obj._get_axis(other_axis).names
1790+
results = self._apply_to_column_groupbys(
1791+
lambda sgb: sgb.nunique(dropna), obj=obj
1792+
)
1793+
results.columns.names = obj.columns.names # TODO: do at higher level?
18121794

18131795
if not self.as_index:
18141796
results.index = ibase.default_index(len(results))
18151797
self._insert_inaxis_grouper_inplace(results)
1798+
18161799
return results
18171800

18181801
@Appender(DataFrame.idxmax.__doc__)

pandas/core/groupby/groupby.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1882,7 +1882,9 @@ def ohlc(self) -> DataFrame:
18821882
)
18831883
return self._reindex_output(result)
18841884

1885-
return self._apply_to_column_groupbys(lambda x: x.ohlc())
1885+
return self._apply_to_column_groupbys(
1886+
lambda x: x.ohlc(), self._obj_with_exclusions
1887+
)
18861888

18871889
@final
18881890
@doc(DataFrame.describe)

pandas/tests/groupby/test_groupby.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,24 +2058,36 @@ def test_dup_labels_output_shape(groupby_func, idx):
20582058

20592059
def test_groupby_crash_on_nunique(axis):
20602060
# Fix following 30253
2061+
dti = date_range("2016-01-01", periods=2, name="foo")
20612062
df = DataFrame({("A", "B"): [1, 2], ("A", "C"): [1, 3], ("D", "B"): [0, 0]})
2063+
df.columns.names = ("bar", "baz")
2064+
df.index = dti
20622065

20632066
axis_number = df._get_axis_number(axis)
20642067
if not axis_number:
20652068
df = df.T
20662069

2067-
result = df.groupby(axis=axis_number, level=0).nunique()
2070+
gb = df.groupby(axis=axis_number, level=0)
2071+
result = gb.nunique()
20682072

2069-
expected = DataFrame({"A": [1, 2], "D": [1, 1]})
2073+
expected = DataFrame({"A": [1, 2], "D": [1, 1]}, index=dti)
2074+
expected.columns.name = "bar"
20702075
if not axis_number:
20712076
expected = expected.T
20722077

20732078
tm.assert_frame_equal(result, expected)
20742079

2075-
# same thing, but empty columns
2076-
gb = df[[]].groupby(axis=axis_number, level=0)
2077-
res = gb.nunique()
2078-
exp = expected[[]]
2080+
if axis_number == 0:
2081+
# same thing, but empty columns
2082+
gb2 = df[[]].groupby(axis=axis_number, level=0)
2083+
exp = expected[[]]
2084+
else:
2085+
# same thing, but empty rows
2086+
gb2 = df.loc[[]].groupby(axis=axis_number, level=0)
2087+
# default for empty when we can't infer a dtype is float64
2088+
exp = expected.loc[[]].astype(np.float64)
2089+
2090+
res = gb2.nunique()
20792091
tm.assert_frame_equal(res, exp)
20802092

20812093

pandas/tests/resample/test_time_grouper.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,8 @@ def test_aaa_group_order():
121121
tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 5)), df[4::5])
122122

123123

124-
def test_aggregate_normal(request, resample_method):
124+
def test_aggregate_normal(resample_method):
125125
"""Check TimeGrouper's aggregation is identical as normal groupby."""
126-
if resample_method == "ohlc":
127-
request.node.add_marker(
128-
pytest.mark.xfail(reason="DataError: No numeric types to aggregate")
129-
)
130126

131127
data = np.random.randn(20, 4)
132128
normal_df = DataFrame(data, columns=["A", "B", "C", "D"])

0 commit comments

Comments
 (0)