Skip to content

Commit c57f04e

Browse files
authored
refactor: provide infrastructure for SQLGlot aggregations compiler (#1926)
Fixes internal issue 431277229
1 parent 9e8c426 commit c57f04e

File tree

15 files changed

+425
-73
lines changed

15 files changed

+425
-73
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 13 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
import functools
17-
import typing
18-
1916
import sqlglot.expressions as sge
2017

21-
from bigframes.core import expression, window_spec
18+
from bigframes.core import expression
19+
from bigframes.core.compile.sqlglot.aggregations import (
20+
binary_compiler,
21+
nullary_compiler,
22+
ordered_unary_compiler,
23+
unary_compiler,
24+
)
2225
from bigframes.core.compile.sqlglot.expressions import typed_expr
2326
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
24-
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
25-
import bigframes.operations as ops
2627

2728

2829
def compile_aggregate(
@@ -31,16 +32,18 @@ def compile_aggregate(
3132
) -> sge.Expression:
3233
"""Compiles BigFrames aggregation expression into SQLGlot expression."""
3334
if isinstance(aggregate, expression.NullaryAggregation):
34-
return compile_nullary_agg(aggregate.op)
35+
return nullary_compiler.compile(aggregate.op)
3536
if isinstance(aggregate, expression.UnaryAggregation):
3637
column = typed_expr.TypedExpr(
3738
scalar_compiler.compile_scalar_expression(aggregate.arg),
3839
aggregate.arg.output_type,
3940
)
4041
if not aggregate.op.order_independent:
41-
return compile_ordered_unary_agg(aggregate.op, column, order_by=order_by)
42+
return ordered_unary_compiler.compile(
43+
aggregate.op, column, order_by=order_by
44+
)
4245
else:
43-
return compile_unary_agg(aggregate.op, column)
46+
return unary_compiler.compile(aggregate.op, column)
4447
elif isinstance(aggregate, expression.BinaryAggregation):
4548
left = typed_expr.TypedExpr(
4649
scalar_compiler.compile_scalar_expression(aggregate.left),
@@ -50,63 +53,6 @@ def compile_aggregate(
5053
scalar_compiler.compile_scalar_expression(aggregate.right),
5154
aggregate.right.output_type,
5255
)
53-
return compile_binary_agg(aggregate.op, left, right)
56+
return binary_compiler.compile(aggregate.op, left, right)
5457
else:
5558
raise ValueError(f"Unexpected aggregation: {aggregate}")
56-
57-
58-
@functools.singledispatch
59-
def compile_nullary_agg(
60-
op: ops.aggregations.WindowOp,
61-
window: typing.Optional[window_spec.WindowSpec] = None,
62-
) -> sge.Expression:
63-
raise ValueError(f"Can't compile unrecognized operation: {op}")
64-
65-
66-
@functools.singledispatch
67-
def compile_binary_agg(
68-
op: ops.aggregations.WindowOp,
69-
left: typed_expr.TypedExpr,
70-
right: typed_expr.TypedExpr,
71-
window: typing.Optional[window_spec.WindowSpec] = None,
72-
) -> sge.Expression:
73-
raise ValueError(f"Can't compile unrecognized operation: {op}")
74-
75-
76-
@functools.singledispatch
77-
def compile_unary_agg(
78-
op: ops.aggregations.WindowOp,
79-
column: typed_expr.TypedExpr,
80-
window: typing.Optional[window_spec.WindowSpec] = None,
81-
) -> sge.Expression:
82-
raise ValueError(f"Can't compile unrecognized operation: {op}")
83-
84-
85-
@functools.singledispatch
86-
def compile_ordered_unary_agg(
87-
op: ops.aggregations.WindowOp,
88-
column: typed_expr.TypedExpr,
89-
window: typing.Optional[window_spec.WindowSpec] = None,
90-
order_by: typing.Sequence[sge.Expression] = [],
91-
) -> sge.Expression:
92-
raise ValueError(f"Can't compile unrecognized operation: {op}")
93-
94-
95-
@compile_unary_agg.register
96-
def _(
97-
op: ops.aggregations.SumOp,
98-
column: typed_expr.TypedExpr,
99-
window: typing.Optional[window_spec.WindowSpec] = None,
100-
) -> sge.Expression:
101-
# Will be null if all inputs are null. Pandas defaults to zero sum though.
102-
expr = _apply_window_if_present(sge.func("SUM", column.expr), window)
103-
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
104-
105-
106-
def _apply_window_if_present(
107-
value: sge.Expression,
108-
window: typing.Optional[window_spec.WindowSpec] = None,
109-
) -> sge.Expression:
110-
if window is not None:
111-
raise NotImplementedError("Can't apply window to the expression.")
112-
return value
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes.core import window_spec
22+
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
24+
from bigframes.operations import aggregations as agg_ops
25+
26+
BINARY_OP_REGISTRATION = reg.OpRegistration()
27+
28+
29+
def compile(
    op: agg_ops.WindowOp,
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile a two-input aggregation into a SQLGlot expression.

    The compilation function is looked up in ``BINARY_OP_REGISTRATION`` using
    ``op`` and then invoked with the operator, both operands and the window.

    Args:
        op: Aggregation operator used to select the compilation function.
        left: First typed operand.
        right: Second typed operand.
        window: Optional window specification, forwarded as a keyword argument.

    Returns:
        The expression produced by the registered compilation function.
    """
    handler = BINARY_OP_REGISTRATION[op]
    return handler(op, left, right, window=window)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes.core import window_spec
22+
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
24+
from bigframes.operations import aggregations as agg_ops
25+
26+
NULLARY_OP_REGISTRATION = reg.OpRegistration()
27+
28+
29+
def compile(
    op: agg_ops.WindowOp,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile an aggregation that takes no input column.

    The compilation function is looked up in ``NULLARY_OP_REGISTRATION`` using
    ``op`` and then invoked with the operator and the window.

    Args:
        op: Aggregation operator used to select the compilation function.
        window: Optional window specification, forwarded as a keyword argument.

    Returns:
        The expression produced by the registered compilation function.
    """
    handler = NULLARY_OP_REGISTRATION[op]
    return handler(op, window=window)
34+
35+
36+
@NULLARY_OP_REGISTRATION.register(agg_ops.SizeOp)
def _(
    op: agg_ops.SizeOp,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile the size aggregation as ``COUNT(1)``.

    Args:
        op: The size operator; not consulted when building the expression.
        window: Optional window specification passed to
            ``apply_window_if_present``.

    Returns:
        The ``COUNT(1)`` expression after ``apply_window_if_present``.
    """
    row_count = sge.func("COUNT", sge.convert(1))
    return apply_window_if_present(row_count, window)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
from sqlglot import expressions as sge
20+
21+
from bigframes.operations import aggregations as agg_ops
22+
23+
# We should've been more specific about input types. Unfortunately,
24+
# MyPy doesn't support more rigorous checks.
25+
CompilationFunc = typing.Callable[..., sge.Expression]
26+
27+
28+
class OpRegistration:
    """Registry of compilation functions keyed by an operator's ``name``.

    Functions are added with the :meth:`register` decorator and retrieved by
    indexing the registry with either an operator instance or a name string.
    """

    def __init__(self) -> None:
        # Maps operator name -> validating wrapper around the compilation function.
        self._registered_ops: dict[str, CompilationFunc] = {}

    def register(
        self, op: agg_ops.WindowOp | type[agg_ops.WindowOp]
    ) -> typing.Callable[[CompilationFunc], CompilationFunc]:
        """Build a decorator that registers a compilation function for ``op``.

        Args:
            op: Operator instance or operator class. Its ``name`` attribute is
                used as the registry key.

        Returns:
            A decorator. Applying it stores the decorated function (wrapped in
            a first-argument check) under the operator's name and returns that
            same wrapper.

        Raises:
            ValueError: Raised when the decorator is applied if ``op`` has no
                ``name`` attribute or the name is already registered.
        """

        def decorator(item: CompilationFunc) -> CompilationFunc:
            def arg_checker(*args, **kwargs):
                # The operator normally arrives positionally; also accept it as
                # the ``op`` keyword so keyword-only calls keep working.
                candidate = args[0] if args else kwargs.get("op")
                if not isinstance(candidate, agg_ops.WindowOp):
                    raise ValueError(
                        "The first parameter must be a window operator. "
                        f"Got {type(candidate)}"
                    )
                return item(*args, **kwargs)

            if not hasattr(op, "name"):
                raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
            key = typing.cast(str, op.name)
            if key in self._registered_ops:
                raise ValueError(f"{key} is already registered")

            # Store the validating wrapper rather than the raw function so the
            # first-argument check also runs for lookups made via __getitem__.
            self._registered_ops[key] = arg_checker
            return arg_checker

        return decorator

    def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
        """Return the compilation function registered for ``op``.

        Args:
            op: An operator instance (looked up by its ``name``) or the name
                string itself.

        Raises:
            ValueError: If ``op`` is an operator without a ``name`` attribute.
            KeyError: If nothing is registered under the resulting name.
        """
        if isinstance(op, agg_ops.WindowOp):
            if not hasattr(op, "name"):
                raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
            return self._registered_ops[typing.cast(str, op.name)]
        return self._registered_ops[op]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes.core import window_spec
22+
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
24+
from bigframes.operations import aggregations as agg_ops
25+
26+
ORDERED_UNARY_OP_REGISTRATION = reg.OpRegistration()
27+
28+
29+
def compile(
    op: agg_ops.WindowOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
    # An immutable empty default: a shared ``[]`` would be the same list object
    # on every call and could be mutated by a handler.
    order_by: typing.Sequence[sge.Expression] = (),
) -> sge.Expression:
    """Compile an order-dependent single-column aggregation.

    The compilation function is looked up in ``ORDERED_UNARY_OP_REGISTRATION``
    using ``op`` and then invoked with the operator, the column, the window
    and the ordering expressions.

    Args:
        op: Aggregation operator used to select the compilation function.
        column: Typed input column expression.
        window: Optional window specification, forwarded as a keyword argument.
        order_by: Ordering expressions forwarded to the compilation function.
            Defaults to an empty sequence.

    Returns:
        The expression produced by the registered compilation function.
    """
    handler = ORDERED_UNARY_OP_REGISTRATION[op]
    return handler(op, column, window=window, order_by=order_by)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes.core import window_spec
22+
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
24+
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
25+
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
26+
from bigframes.operations import aggregations as agg_ops
27+
28+
UNARY_OP_REGISTRATION = reg.OpRegistration()
29+
30+
31+
def compile(
    op: agg_ops.WindowOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile an order-independent single-column aggregation.

    The compilation function is looked up in ``UNARY_OP_REGISTRATION`` using
    ``op`` and then invoked with the operator, the column and the window.

    Args:
        op: Aggregation operator used to select the compilation function.
        column: Typed input column expression.
        window: Optional window specification, forwarded as a keyword argument.

    Returns:
        The expression produced by the registered compilation function.
    """
    handler = UNARY_OP_REGISTRATION[op]
    return handler(op, column, window=window)
37+
38+
39+
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
def _(
    op: agg_ops.SumOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile a sum as ``IFNULL(SUM(<column>), 0)``.

    Args:
        op: The sum operator; not consulted when building the expression.
        column: Typed input column. Its expression is summed and its dtype
            selects the type of the zero literal.
        window: Optional window specification passed to
            ``apply_window_if_present``.

    Returns:
        The ``IFNULL`` expression wrapping the (optionally windowed) sum.
    """
    total = sge.func("SUM", column.expr)
    windowed_total = apply_window_if_present(total, window)
    # SUM is null when every input is null, whereas pandas reports a zero sum,
    # so substitute a zero literal of the column's dtype.
    zero = ir._literal(0, column.dtype)
    return sge.func("IFNULL", windowed_total, zero)
48+
49+
50+
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
def _(
    op: agg_ops.SizeUnaryOp,
    _,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile the unary size aggregation as ``COUNT(1)``.

    Args:
        op: The size operator; not consulted when building the expression.
        _: The input column, which is ignored because the count is taken over
            the literal ``1``.
        window: Optional window specification passed to
            ``apply_window_if_present``.

    Returns:
        The ``COUNT(1)`` expression after ``apply_window_if_present``.
    """
    row_count = sge.func("COUNT", sge.convert(1))
    return apply_window_if_present(row_count, window)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import typing
17+
18+
import sqlglot.expressions as sge
19+
20+
from bigframes.core import window_spec
21+
22+
23+
def apply_window_if_present(
    value: sge.Expression,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Return ``value`` unchanged when no window is given.

    Args:
        value: Expression to return.
        window: Optional window specification. Any value other than ``None``
            is rejected because windowing is not implemented here.

    Returns:
        ``value`` itself when ``window`` is ``None``.

    Raises:
        NotImplementedError: If ``window`` is not ``None``.
    """
    if window is None:
        return value
    raise NotImplementedError("Can't apply window to the expression.")

0 commit comments

Comments
 (0)