Skip to content
Merged
13 changes: 13 additions & 0 deletions bigframes/_config/compute_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ class ComputeOptions:
ai_ops_threshold_autofail: bool = False

allow_large_results: Optional[bool] = None
maximum_rows_downloaded: Optional[int] = None
"""Limits the number of rows downloaded from BigQuery.

When converting a BigQuery DataFrames object to a pandas DataFrame or Series
(e.g., using ``.to_pandas()``, ``.head()``, ``.__repr__()``, direct iteration),
the data is downloaded from BigQuery to the client machine. This option
restricts the number of rows that can be downloaded.

If the number of rows to be downloaded exceeds this limit, a
``bigframes.exceptions.MaximumRowsDownloadedExceeded`` exception is raised.

Set to ``None`` (the default) for no limit.
"""

def assign_extra_query_labels(self, **kwargs: Any) -> None:
"""
Expand Down
4 changes: 4 additions & 0 deletions bigframes/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class OperationAbortedError(RuntimeError):
"""Operation is aborted."""


class MaximumRowsDownloadedExceeded(RuntimeError):
    """Raised when a result download would exceed the configured row limit."""


class TimeTravelDisabledWarning(Warning):
"""A query was reattempted without time travel."""

Expand Down
15 changes: 14 additions & 1 deletion bigframes/session/bq_caching_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import google.cloud.bigquery.table as bq_table
import google.cloud.bigquery_storage_v1

import bigframes.options
import bigframes.constants
import bigframes.core
from bigframes.core import compile, local_data, rewrite
Expand Down Expand Up @@ -687,7 +688,7 @@ def _execute_plan(
bigframes.core.ArrayValue(plan), iterator.schema
)

return executor.ExecuteResult(
result = executor.ExecuteResult(
arrow_batches=iterator.to_arrow_iterable(
bqstorage_client=self.bqstoragereadclient
),
Expand All @@ -697,6 +698,18 @@ def _execute_plan(
total_rows=iterator.total_rows,
)

# Check if the number of rows exceeds the maximum allowed
if result.total_rows is not None:
max_rows = bigframes.options.compute.maximum_rows_downloaded
if max_rows is not None and result.total_rows > max_rows:
raise bfe.MaximumRowsDownloadedExceeded(
f"Query would download {result.total_rows} rows, which "
f"exceeds the limit of {max_rows}. "
"You can adjust this limit by setting "
"`bigframes.options.compute.maximum_rows_downloaded`."
)
return result


def _if_schema_match(
table_schema: Tuple[bigquery.SchemaField, ...], schema: schemata.ArraySchema
Expand Down
13 changes: 13 additions & 0 deletions bigframes/session/direct_gbq_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import google.cloud.bigquery.job as bq_job
import google.cloud.bigquery.table as bq_table

import bigframes.exceptions
import bigframes.options
from bigframes.core import compile, nodes
from bigframes.session import executor, semi_executor
import bigframes.session._io.bigquery as bq_io
Expand Down Expand Up @@ -49,6 +51,17 @@ def execute(
sql=compiled.sql,
)

# Check if the number of rows exceeds the maximum allowed
if iterator.total_rows is not None:
max_rows = bigframes.options.compute.maximum_rows_downloaded
if max_rows is not None and iterator.total_rows > max_rows:
raise bigframes.exceptions.MaximumRowsDownloadedExceeded(
f"Query would download {iterator.total_rows} rows, which "
f"exceeds the limit of {max_rows}. "
"You can adjust this limit by setting "
"`bigframes.options.compute.maximum_rows_downloaded`."
)

return executor.ExecuteResult(
arrow_batches=iterator.to_arrow_iterable(),
schema=plan.schema,
Expand Down
107 changes: 107 additions & 0 deletions tests/system/small/test_pandas_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,113 @@ def test_credentials_need_reauthentication(
df = bpd.read_gbq(test_query)
assert df is not None


@pytest.fixture(scope="function")
def max_rows_option_manager():
    """Restore ``maximum_rows_downloaded`` to its pre-test value after each test."""
    saved_limit = bpd.options.compute.maximum_rows_downloaded
    try:
        yield
    finally:
        bpd.options.compute.maximum_rows_downloaded = saved_limit


def test_max_rows_normal_execution_no_limit(
    max_rows_option_manager, test_data_tables, scalars_pandas_df_index
):
    """An unset limit (``None``) must leave downloads unrestricted."""
    bpd.options.compute.maximum_rows_downloaded = None
    scalars_table = test_data_tables["scalars"]
    downloaded = bpd.read_gbq(scalars_table).to_pandas()
    # scalars table has 10 rows
    assert len(downloaded) == len(scalars_pandas_df_index)


def test_max_rows_normal_execution_within_limit(
    max_rows_option_manager, test_data_tables, scalars_pandas_df_index
):
    """Downloads at or under the configured limit succeed normally."""
    scalars_table = test_data_tables["scalars"]
    expected_len = len(scalars_pandas_df_index)
    # scalars table has 10 rows: exercise an exact-match limit and a roomier one.
    for limit in (10, 15):
        bpd.options.compute.maximum_rows_downloaded = limit
        downloaded = bpd.read_gbq(scalars_table).to_pandas()
        assert len(downloaded) == expected_len


def test_max_rows_exceeds_limit_to_pandas(
    max_rows_option_manager, test_data_tables
):
    """Materializing more rows than the limit must raise the dedicated error."""
    # scalars table has 10 rows
    bpd.options.compute.maximum_rows_downloaded = 5
    over_limit_df = bpd.read_gbq(test_data_tables["scalars"])
    with pytest.raises(bpd.exceptions.MaximumRowsDownloadedExceeded):
        over_limit_df.to_pandas()


def test_max_rows_exceeds_limit_to_pandas_batches(
    max_rows_option_manager, test_data_tables
):
    """Batched iteration over a too-large result must also raise the error."""
    # scalars table has 10 rows
    bpd.options.compute.maximum_rows_downloaded = 5
    over_limit_df = bpd.read_gbq(test_data_tables["scalars"])
    with pytest.raises(bpd.exceptions.MaximumRowsDownloadedExceeded):
        next(iter(over_limit_df.to_pandas_batches()))


def test_max_rows_repr_does_not_raise(max_rows_option_manager, test_data_tables):
    """``repr(df)`` must stay below the limit (it only previews a few rows)."""
    bpd.options.compute.maximum_rows_downloaded = 1
    # scalars table has 10 rows, limit is 1, but repr should only fetch a few
    preview = repr(bpd.read_gbq(test_data_tables["scalars"]))
    assert "int_col" in preview


def test_max_rows_peek_does_not_raise(max_rows_option_manager, test_data_tables):
    """``df.peek()`` must not trip the limit (it only fetches ``n`` rows)."""
    bpd.options.compute.maximum_rows_downloaded = 1
    # scalars table has 10 rows, limit is 1, but peek should only fetch a few
    sample = bpd.read_gbq(test_data_tables["scalars"]).peek(n=2)
    assert len(sample) == 2


def test_max_rows_shape_does_not_raise(max_rows_option_manager, test_data_tables):
    """``df.shape`` must not trip the limit.

    Shape currently executes a full count query, but doesn't download rows.
    """
    bpd.options.compute.maximum_rows_downloaded = 1
    # scalars table has 10 rows, limit is 1
    assert bpd.read_gbq(test_data_tables["scalars"]).shape == (10, 16)


def test_max_rows_to_gbq_does_not_raise(
    max_rows_option_manager, test_data_tables, dataset_id_permanent, session
):
    """Test df.to_gbq() does not raise the exception.

    ``to_gbq`` only executes a query and stores the result in a destination
    table server-side; no rows are downloaded to the client, so the
    ``maximum_rows_downloaded`` limit must not apply.
    """
    table_name = test_data_tables["scalars"]
    bpd.options.compute.maximum_rows_downloaded = 1
    df = bpd.read_gbq(table_name)
    # scalars table has 10 rows, limit is 1
    dest_table = f"{dataset_id_permanent}.test_max_rows_to_gbq_output"
    try:
        df.to_gbq(destination_table=dest_table, if_exists="replace")
        # Simple check to ensure table was created
        dest_df = bpd.read_gbq(dest_table)
        assert dest_df.shape[0] == 10
    finally:
        # Clean up even when to_gbq or the assertions above fail, so that a
        # stale destination table is not leaked into the permanent dataset.
        session.bqclient.delete_table(dest_table, not_found_ok=True)

# Call get_global_session() *after* read_gbq so that our location detection
# has a chance to work.
session = bpd.get_global_session()
Expand Down
Loading