Skip to content

feat: add DataFrame.ai.forecast() support #1828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,17 @@ class BaseBqml:

def __init__(self, session: bigframes.session.Session):
self._session = session
self._base_sql_generator = ml_sql.BaseSqlGenerator()
self._sql_generator = ml_sql.BaseSqlGenerator()

def ai_forecast(
    self,
    input_data: "bpd.DataFrame",
    options: Mapping[str, Union[str, int, float, Iterable[str]]],
) -> "bpd.DataFrame":
    """Run BigQuery's AI.FORECAST over ``input_data`` and return the result.

    The AI.FORECAST SQL is built by the session-level SQL generator (no
    model object is required) and executed through the session.
    """
    forecast_sql = self._sql_generator.ai_forecast(
        source_sql=input_data.sql, options=options
    )
    return self._session.read_gbq(forecast_sql)


class BqmlModel(BaseBqml):
Expand All @@ -55,8 +65,8 @@ def __init__(self, session: bigframes.Session, model: bigquery.Model):
self._model = model
model_ref = self._model.reference
assert model_ref is not None
self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
model_ref
self._sql_generator: ml_sql.ModelManipulationSqlGenerator = (
ml_sql.ModelManipulationSqlGenerator(model_ref)
)

def _apply_ml_tvf(
Expand Down Expand Up @@ -126,30 +136,28 @@ def model(self) -> bigquery.Model:
def recommend(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
self._model_manipulation_sql_generator.ml_recommend,
self._sql_generator.ml_recommend,
)

def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
self._model_manipulation_sql_generator.ml_predict,
self._sql_generator.ml_predict,
)

def explain_predict(
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_explain_predict(
lambda source_sql: self._sql_generator.ml_explain_predict(
source_sql=source_sql,
struct_options=options,
),
)

def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_global_explain(
struct_options=options
)
sql = self._sql_generator.ml_global_explain(struct_options=options)
return (
self._session.read_gbq(sql)
.sort_values(by="attribution", ascending=False)
Expand All @@ -159,7 +167,7 @@ def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
self._model_manipulation_sql_generator.ml_transform,
self._sql_generator.ml_transform,
)

def generate_text(
Expand All @@ -170,7 +178,7 @@ def generate_text(
options["flatten_json_output"] = True
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text(
lambda source_sql: self._sql_generator.ml_generate_text(
source_sql=source_sql,
struct_options=options,
),
Expand All @@ -186,7 +194,7 @@ def generate_embedding(
options["flatten_json_output"] = True
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_embedding(
lambda source_sql: self._sql_generator.ml_generate_embedding(
source_sql=source_sql,
struct_options=options,
),
Expand All @@ -201,7 +209,7 @@ def generate_table(
) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ai_generate_table(
lambda source_sql: self._sql_generator.ai_generate_table(
source_sql=source_sql,
struct_options=options,
),
Expand All @@ -216,14 +224,14 @@ def detect_anomalies(

return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_detect_anomalies(
lambda source_sql: self._sql_generator.ml_detect_anomalies(
source_sql=source_sql,
struct_options=options,
),
)

def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
sql = self._sql_generator.ml_forecast(struct_options=options)
timestamp_col_name = "forecast_timestamp"
index_cols = [timestamp_col_name]
first_col_name = self._session.read_gbq(sql).columns.values[0]
Expand All @@ -232,9 +240,7 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
return self._session.read_gbq(sql, index_col=index_cols).reset_index()

def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_explain_forecast(
struct_options=options
)
sql = self._sql_generator.ml_explain_forecast(struct_options=options)
timestamp_col_name = "time_series_timestamp"
index_cols = [timestamp_col_name]
first_col_name = self._session.read_gbq(sql).columns.values[0]
Expand All @@ -243,7 +249,7 @@ def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
return self._session.read_gbq(sql, index_col=index_cols).reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
sql = self._model_manipulation_sql_generator.ml_evaluate(
sql = self._sql_generator.ml_evaluate(
input_data.sql if (input_data is not None) else None
)

Expand All @@ -254,28 +260,24 @@ def llm_evaluate(
input_data: bpd.DataFrame,
task_type: Optional[str] = None,
):
sql = self._model_manipulation_sql_generator.ml_llm_evaluate(
input_data.sql, task_type
)
sql = self._sql_generator.ml_llm_evaluate(input_data.sql, task_type)

return self._session.read_gbq(sql)

def arima_evaluate(self, show_all_candidate_models: bool = False):
sql = self._model_manipulation_sql_generator.ml_arima_evaluate(
show_all_candidate_models
)
sql = self._sql_generator.ml_arima_evaluate(show_all_candidate_models)

return self._session.read_gbq(sql)

def arima_coefficients(self) -> bpd.DataFrame:
sql = self._model_manipulation_sql_generator.ml_arima_coefficients()
sql = self._sql_generator.ml_arima_coefficients()

return self._session.read_gbq(sql)

def centroids(self) -> bpd.DataFrame:
assert self._model.model_type == "KMEANS"

sql = self._model_manipulation_sql_generator.ml_centroids()
sql = self._sql_generator.ml_centroids()

return self._session.read_gbq(
sql, index_col=["centroid_id", "feature"]
Expand All @@ -284,7 +286,7 @@ def centroids(self) -> bpd.DataFrame:
def principal_components(self) -> bpd.DataFrame:
assert self._model.model_type == "PCA"

sql = self._model_manipulation_sql_generator.ml_principal_components()
sql = self._sql_generator.ml_principal_components()

return self._session.read_gbq(
sql, index_col=["principal_component_id", "feature"]
Expand All @@ -293,7 +295,7 @@ def principal_components(self) -> bpd.DataFrame:
def principal_component_info(self) -> bpd.DataFrame:
assert self._model.model_type == "PCA"

sql = self._model_manipulation_sql_generator.ml_principal_component_info()
sql = self._sql_generator.ml_principal_component_info()

return self._session.read_gbq(sql)

Expand All @@ -319,7 +321,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
# truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models.
# The possibility of conflicts should be low.
vertex_ai_model_id = vertex_ai_model_id[:63]
sql = self._model_manipulation_sql_generator.alter_model(
sql = self._sql_generator.alter_model(
options={"vertex_ai_model_id": vertex_ai_model_id}
)
# Register the model and wait for it to finish
Expand Down
17 changes: 17 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def build_parameters(self, **kwargs: Union[str, int, float, Iterable[str]]) -> s
param_strs = [f"{k}={self.encode_value(v)}" for k, v in kwargs.items()]
return "\n" + INDENT_STR + f",\n{INDENT_STR}".join(param_strs)

def build_named_parameters(
    self, **kwargs: Union[str, int, float, Iterable[str]]
) -> str:
    """Encode keyword arguments as SQL named arguments (``name => value``),
    one per line, each indented by ``INDENT_STR``."""
    pairs = [f"{name} => {self.encode_value(value)}" for name, value in kwargs.items()]
    separator = f",\n{INDENT_STR}"
    return "\n" + INDENT_STR + separator.join(pairs)

def build_structs(self, **kwargs: Union[int, float, str, Mapping]) -> str:
"""Encode a dict of values into a formatted STRUCT items for SQL"""
param_strs = []
Expand Down Expand Up @@ -187,6 +193,17 @@ def ml_distance(
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-distance"""
return f"""SELECT *, ML.DISTANCE({sql_utils.identifier(col_x)}, {sql_utils.identifier(col_y)}, '{type}') AS {sql_utils.identifier(name)} FROM ({source_sql})"""

def ai_forecast(
    self,
    source_sql: str,
    options: Mapping[str, Union[int, float, bool, Iterable[str]]],
):
    """Encode an AI.FORECAST table-function call over ``source_sql``.

    See:
    https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-forecast
    """
    named_args = self.build_named_parameters(**options)
    return f"SELECT * FROM AI.FORECAST(({source_sql}),{named_args})"


class ModelCreationSqlGenerator(BaseSqlGenerator):
"""Sql generator for creating a model entity. Model id is the standalone id without project id and dataset id."""
Expand Down
95 changes: 89 additions & 6 deletions bigframes/operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,24 @@

import re
import typing
from typing import Dict, List, Optional, Sequence
from typing import Dict, Iterable, List, Optional, Sequence, Union
import warnings

import numpy as np

from bigframes import dtypes, exceptions
from bigframes import dtypes, exceptions, options
from bigframes.core import guid, log_adapter


@log_adapter.class_logger
class AIAccessor:
def __init__(self, df) -> None:
def __init__(self, df, base_bqml=None) -> None:
import bigframes # Import in the function body to avoid circular imports.
import bigframes.dataframe

if not bigframes.options.experiments.ai_operators:
raise NotImplementedError()
from bigframes.ml import core as ml_core

self._df: bigframes.dataframe.DataFrame = df
self._base_bqml: ml_core.BaseBqml = base_bqml or ml_core.BaseBqml(df._session)

def filter(
self,
Expand Down Expand Up @@ -89,6 +88,8 @@ def filter(
ValueError: when the instruction refers to a non-existing column, or when no
columns are referred to.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

answer_col = "answer"

Expand Down Expand Up @@ -181,6 +182,9 @@ def map(
ValueError: when the instruction refers to a non-existing column, or when no
columns are referred to.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

import bigframes.dataframe
import bigframes.series

Expand Down Expand Up @@ -320,6 +324,8 @@ def classify(
columns are referred to, or when the count of labels does not meet the
requirement.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

if len(labels) < 2 or len(labels) > 20:
raise ValueError(
Expand Down Expand Up @@ -401,6 +407,9 @@ def join(
Raises:
ValueError if the amount of data that will be sent for LLM processing is larger than max_rows.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

self._validate_model(model)
columns = self._parse_columns(instruction)

Expand Down Expand Up @@ -525,6 +534,8 @@ def search(
ValueError: when the search_column is not found in the data frame.
TypeError: when the provided model is not TextEmbeddingGenerator.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

if search_column not in self._df.columns:
raise ValueError(f"Column `{search_column}` not found")
Expand Down Expand Up @@ -640,6 +651,9 @@ def top_k(
ValueError: when the instruction refers to a non-existing column, or when no
columns are referred to.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

import bigframes.dataframe
import bigframes.series

Expand Down Expand Up @@ -834,6 +848,8 @@ def sim_join(
Raises:
ValueError: when the amount of data to be processed exceeds the specified max_rows.
"""
if not options.experiments.ai_operators:
raise NotImplementedError()

if left_on not in self._df.columns:
raise ValueError(f"Left column {left_on} not found")
Expand Down Expand Up @@ -883,6 +899,73 @@ def sim_join(

return join_result

def forecast(
    self,
    timestamp_column: str,
    data_column: str,
    *,
    model: str = "TimesFM 2.0",
    id_columns: Optional[Iterable[str]] = None,
    horizon: int = 10,
    confidence_level: float = 0.95,
):
    """
    Forecast time series at a future horizon, using Google Research's open source
    TimesFM (https://github.com/google-research/timesfm) model.

    .. note::

        This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
        Service Specific Terms (https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
        and might have limited support. For more information, see the launch stage descriptions
        (https://cloud.google.com/products#product-launch-stages).

    Args:
        timestamp_column (str):
            A str value that specifies the name of the time points column.
            The time points column provides the time points used to generate the forecast.
            The time points column must use one of the following data types: TIMESTAMP, DATE and DATETIME.
        data_column (str):
            A str value that specifies the name of the data column. The data column contains the data to forecast.
            The data column must use one of the following data types: INT64, NUMERIC and FLOAT64.
        model (str, default "TimesFM 2.0"):
            A str value that specifies the name of the model. TimesFM 2.0 is the only supported value, and is the default value.
        id_columns (Iterable[str] or None, default None):
            An iterable of str values that specifies the names of one or more ID columns. Each ID identifies a unique time series to forecast.
            Specify one or more values for this argument in order to forecast multiple time series using a single query.
            The columns that you specify must use one of the following data types: STRING, INT64, ARRAY<STRING> and ARRAY<INT64>.
        horizon (int, default 10):
            An int value that specifies the number of time points to forecast. The default value is 10. The valid input range is [1, 10000].
        confidence_level (float, default 0.95):
            A FLOAT64 value that specifies the percentage of the future values that fall in the prediction interval.
            The default value is 0.95. The valid input range is [0, 1).

    Returns:
        DataFrame:
            The forecast dataframe matches that of the BigQuery AI.FORECAST function.
            See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-forecast

    Raises:
        ValueError: when referring to a non-existing column.
    """
    # Materialize once: ``id_columns`` may be a one-shot iterator, and it is
    # needed twice below (column validation and the options payload). The
    # original code exhausted the iterator during validation and then stored
    # the empty iterator in the options dict.
    id_cols = list(id_columns) if id_columns is not None else []

    # Fail fast on typos before any SQL is sent to the backend.
    for column in [timestamp_column, data_column, *id_cols]:
        if column not in self._df.columns:
            raise ValueError(f"Column `{column}` not found")

    # Named ``forecast_options`` (not ``options``) so the module-level
    # ``bigframes.options`` import used by the sibling methods is not shadowed.
    forecast_options: dict[str, Union[int, float, str, Iterable[str]]] = {
        "data_col": data_column,
        "timestamp_col": timestamp_column,
        "model": model,
        "horizon": horizon,
        "confidence_level": confidence_level,
    }
    if id_cols:
        forecast_options["id_cols"] = id_cols

    return self._base_bqml.ai_forecast(input_data=self._df, options=forecast_options)

@staticmethod
def _attach_embedding(dataframe, source_column: str, embedding_column: str, model):
result_df = dataframe.copy()
Expand Down
Loading