Skip to content

Commit 1251ded

Browse files
authored
feat: (Preview) Support aggregations over timedeltas (#1418)
* feat: (Preview) Support aggregations over timedeltas
* rename variable
1 parent aeb5063 commit 1251ded

File tree

4 files changed

+109
-12
lines changed

4 files changed

+109
-12
lines changed

bigframes/core/compile/aggregate_compiler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,11 @@ def _(
231231
column: ibis_types.NumericColumn,
232232
window=None,
233233
) -> ibis_types.NumericValue:
234-
return _apply_window_if_present(column.quantile(op.q), window)
234+
result = column.quantile(op.q)
235+
if op.should_floor_result:
236+
result = result.floor() # type:ignore
237+
238+
return _apply_window_if_present(result, window)
235239

236240

237241
@compile_unary_agg.register
@@ -242,7 +246,8 @@ def _(
242246
window=None,
243247
# order_by: typing.Sequence[ibis_types.Value] = [],
244248
) -> ibis_types.NumericValue:
245-
return _apply_window_if_present(column.mean(), window)
249+
result = column.mean().floor() if op.should_floor_result else column.mean()
250+
return _apply_window_if_present(result, window)
246251

247252

248253
@compile_unary_agg.register
@@ -306,10 +311,11 @@ def _(
306311
@numeric_op
307312
def _(
308313
op: agg_ops.StdOp,
309-
x: ibis_types.Column,
314+
x: ibis_types.NumericColumn,
310315
window=None,
311316
) -> ibis_types.Value:
312-
return _apply_window_if_present(cast(ibis_types.NumericColumn, x).std(), window)
317+
result = x.std().floor() if op.should_floor_result else x.std()
318+
return _apply_window_if_present(result, window)
313319

314320

315321
@compile_unary_agg.register

bigframes/core/rewrite/timedeltas.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
7070
root.skip_reproject_unsafe,
7171
)
7272

73+
if isinstance(root, nodes.AggregateNode):
74+
updated_aggregations = tuple(
75+
(_rewrite_aggregation(agg, root.child.schema), col_id)
76+
for agg, col_id in root.aggregations
77+
)
78+
return nodes.AggregateNode(
79+
root.child,
80+
updated_aggregations,
81+
root.by_column_ids,
82+
root.order_by,
83+
root.dropna,
84+
)
85+
7386
return root
7487

7588

@@ -196,17 +209,34 @@ def _rewrite_aggregation(
196209
) -> ex.Aggregation:
197210
if not isinstance(aggregation, ex.UnaryAggregation):
198211
return aggregation
199-
if not isinstance(aggregation.op, aggs.DiffOp):
200-
return aggregation
201212

202213
if isinstance(aggregation.arg, ex.DerefOp):
203214
input_type = schema.get_type(aggregation.arg.id.sql)
204215
else:
205216
input_type = aggregation.arg.dtype
206217

207-
if dtypes.is_datetime_like(input_type):
218+
if isinstance(aggregation.op, aggs.DiffOp) and dtypes.is_datetime_like(input_type):
208219
return ex.UnaryAggregation(
209220
aggs.TimeSeriesDiffOp(aggregation.op.periods), aggregation.arg
210221
)
211222

223+
if isinstance(aggregation.op, aggs.StdOp) and input_type is dtypes.TIMEDELTA_DTYPE:
224+
return ex.UnaryAggregation(
225+
aggs.StdOp(should_floor_result=True), aggregation.arg
226+
)
227+
228+
if isinstance(aggregation.op, aggs.MeanOp) and input_type is dtypes.TIMEDELTA_DTYPE:
229+
return ex.UnaryAggregation(
230+
aggs.MeanOp(should_floor_result=True), aggregation.arg
231+
)
232+
233+
if (
234+
isinstance(aggregation.op, aggs.QuantileOp)
235+
and input_type is dtypes.TIMEDELTA_DTYPE
236+
):
237+
return ex.UnaryAggregation(
238+
aggs.QuantileOp(q=aggregation.op.q, should_floor_result=True),
239+
aggregation.arg,
240+
)
241+
212242
return aggregation

bigframes/operations/aggregations.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,16 @@ class SumOp(UnaryAggregateOp):
142142
name: ClassVar[str] = "sum"
143143

144144
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the dtype produced by summing a column of the given dtype.

    Only ``input_types[0]`` is inspected.

    Returns:
        ``dtypes.TIMEDELTA_DTYPE`` for timedelta input, ``dtypes.INT_DTYPE``
        for boolean input, and the input dtype itself for any other dtype
        accepted by ``dtypes.is_numeric``.

    Raises:
        TypeError: If the input dtype is neither numeric nor timedelta.
    """
    input_type = input_types[0]

    # Compare by equality rather than identity: a dtype object that is equal
    # to TIMEDELTA_DTYPE is not guaranteed to be the very same instance, and
    # an identity test would wrongly fall through to the TypeError below.
    if input_type == dtypes.TIMEDELTA_DTYPE:
        return dtypes.TIMEDELTA_DTYPE

    if dtypes.is_numeric(input_type):
        if pd.api.types.is_bool_dtype(input_type):
            return dtypes.INT_DTYPE
        return input_type

    raise TypeError(f"Type {input_type} is not numeric or timedelta")
154+
152155

153156
@dataclasses.dataclass(frozen=True)
154157
class MedianOp(UnaryAggregateOp):
@@ -171,6 +174,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
171174
@dataclasses.dataclass(frozen=True)
172175
class QuantileOp(UnaryAggregateOp):
173176
q: float
177+
should_floor_result: bool = False
174178

175179
@property
176180
def name(self):
@@ -181,6 +185,8 @@ def order_independent(self) -> bool:
181185
return True
182186

183187
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the dtype produced by a quantile over the given input dtype.

    Only ``input_types[0]`` is inspected. Timedelta input yields
    ``dtypes.TIMEDELTA_DTYPE``; every other dtype is delegated to
    ``signatures.UNARY_REAL_NUMERIC``.
    """
    # Compare by equality rather than identity: a dtype object that is equal
    # to TIMEDELTA_DTYPE is not guaranteed to be the very same instance.
    if input_types[0] == dtypes.TIMEDELTA_DTYPE:
        return dtypes.TIMEDELTA_DTYPE
    return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0])
185191

186192

@@ -224,7 +230,11 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
224230
class MeanOp(UnaryAggregateOp):
225231
name: ClassVar[str] = "mean"
226232

233+
should_floor_result: bool = False
234+
227235
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the dtype produced by a mean over the given input dtype.

    Only ``input_types[0]`` is inspected. Timedelta input yields
    ``dtypes.TIMEDELTA_DTYPE``; every other dtype is delegated to
    ``signatures.UNARY_REAL_NUMERIC``.
    """
    # Compare by equality rather than identity: a dtype object that is equal
    # to TIMEDELTA_DTYPE is not guaranteed to be the very same instance.
    if input_types[0] == dtypes.TIMEDELTA_DTYPE:
        return dtypes.TIMEDELTA_DTYPE
    return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0])
229239

230240

@@ -262,7 +272,12 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
262272
class StdOp(UnaryAggregateOp):
263273
name: ClassVar[str] = "std"
264274

275+
should_floor_result: bool = False
276+
265277
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the dtype produced by a standard deviation over the input dtype.

    Only ``input_types[0]`` is inspected. Timedelta input yields
    ``dtypes.TIMEDELTA_DTYPE``; every other dtype goes through a
    ``signatures.FixedOutputType`` built from ``dtypes.is_numeric`` and
    ``dtypes.FLOAT_DTYPE``.
    """
    # Compare by equality rather than identity: a dtype object that is equal
    # to TIMEDELTA_DTYPE is not guaranteed to be the very same instance.
    if input_types[0] == dtypes.TIMEDELTA_DTYPE:
        return dtypes.TIMEDELTA_DTYPE

    return signatures.FixedOutputType(
        dtypes.is_numeric, dtypes.FLOAT_DTYPE, "numeric"
    ).output_type(input_types[0])

tests/system/small/operations/test_timedeltas.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,49 @@ def test_timedelta_ordering(session):
465465
pandas.testing.assert_series_equal(
466466
actual_result, expected_result, check_index_type=False
467467
)
468+
469+
470+
def test_timedelta_cumsum(temporal_dfs):
    """Check cumsum on a timedelta column against the pandas result.

    The comparison itself is done by ``_assert_series_equal``.
    """
    bigframes_frame, pandas_frame = temporal_dfs
    column = "timedelta_col_1"

    expected = pandas_frame[column].cumsum()
    actual = bigframes_frame[column].cumsum().to_pandas()

    _assert_series_equal(actual, expected)
477+
478+
479+
@pytest.mark.parametrize(
    "agg_func",
    [
        pytest.param(lambda series: series.min(), id="min"),
        pytest.param(lambda series: series.max(), id="max"),
        pytest.param(lambda series: series.sum(), id="sum"),
        pytest.param(lambda series: series.mean(), id="mean"),
        pytest.param(lambda series: series.median(), id="median"),
        pytest.param(lambda series: series.quantile(0.5), id="quantile"),
        pytest.param(lambda series: series.std(), id="std"),
    ],
)
def test_timedelta_agg__timedelta_result(temporal_dfs, agg_func):
    """Check scalar aggregations of a timedelta column against pandas.

    The pandas value is floored to microseconds ("us") before it is compared
    with the value computed from the first frame of ``temporal_dfs``.
    """
    bigframes_frame, pandas_frame = temporal_dfs
    column = "timedelta_col_1"

    expected = agg_func(pandas_frame[column]).floor("us")
    actual = agg_func(bigframes_frame[column])

    assert actual == expected
498+
499+
500+
@pytest.mark.parametrize(
    "agg_func",
    [
        pytest.param(lambda series: series.count(), id="count"),
        pytest.param(lambda series: series.nunique(), id="nunique"),
    ],
)
def test_timedelta_agg__int_result(temporal_dfs, agg_func):
    """Check count-style aggregations of a timedelta column against pandas.

    The same aggregation is applied to both frames of ``temporal_dfs`` and
    the two results are compared directly.
    """
    bigframes_frame, pandas_frame = temporal_dfs
    column = "timedelta_col_1"

    expected = agg_func(pandas_frame[column])
    actual = agg_func(bigframes_frame[column])

    assert actual == expected

0 commit comments

Comments
 (0)