Skip to content

[ENH] Add Basic ARIMA model #2860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 58 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
d381d5e
arima first
TonyBagnall May 24, 2025
3a0552b
move utils
TonyBagnall May 24, 2025
0ac5380
make functions private
TonyBagnall May 24, 2025
44b36a7
Modularise SARIMA model
May 28, 2025
6d18de9
Add ARIMA forecaster to forecasting package
May 28, 2025
b7e6424
Add example to ARIMA forecaster, this also tests the forecaster is pr…
May 28, 2025
e33fa4d
Basic ARIMA model
May 28, 2025
f613f7e
Convert ARIMA to numba version
May 28, 2025
a6b708c
Merge branch 'main' into arb/base_arima
alexbanwell1 May 28, 2025
9eb00f6
Adjust parameters to allow modification in fit
May 28, 2025
d4ed4b1
Update example and return native python type
May 28, 2025
2893e1b
Fix examples for tests
May 28, 2025
9801e8b
Fix Nelder-Mead Optimisation Algorithm Example
May 28, 2025
2f928c7
Fix Nelder-Mead Optimisation Algorithm Example #2
May 28, 2025
94cd5b3
Remove Nelder-Mead Example due to issues with numba caching functions
May 28, 2025
0d0d63f
Fix return type issue
May 28, 2025
39a3ed2
Address PR Feedback
May 28, 2025
05a2785
Ignore small tolerances in floating point value in output of example
May 28, 2025
73966ab
Fix kpss_test example
May 28, 2025
a0f090d
Fix kpss_test example #2
May 28, 2025
6884703
Update documentation for ARIMAForecaster, change constant_term to be …
Jun 2, 2025
44a8647
Merge branch 'main' into arb/base_arima
alexbanwell1 Jun 2, 2025
9af3a56
Modify ARIMA to allow predicting multiple values by updating the stat…
Jun 8, 2025
4c63af5
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 9, 2025
e898f2f
Fix bug using self.d rather than self.d_
Jun 9, 2025
11c4987
Merge branch 'arb/base_arima' of https://github.com/aeon-toolkit/aeon…
Jun 9, 2025
6314a6f
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 11, 2025
72b7980
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 11, 2025
3c644a0
refactor ARIMA
TonyBagnall Jun 11, 2025
350252e
Merge branch 'main' into arb/base_arima
MatthewMiddlehurst Jun 16, 2025
1bd6a32
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 16, 2025
b91d135
docstring
TonyBagnall Jun 16, 2025
420cd72
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 18, 2025
061f286
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 21, 2025
149c0ad
find forecast_ in fit
TonyBagnall Jun 21, 2025
745806e
Merge branch 'main' into arb/base_arima
MatthewMiddlehurst Jul 4, 2025
1d300a4
Merge branch 'main' into arb/base_arima
TonyBagnall Jul 10, 2025
9d8b24f
remove optional y
TonyBagnall Jul 10, 2025
d9b1e7a
add iterative
TonyBagnall Jul 10, 2025
5e4f138
Merge branch 'main' into arb/base_arima
TonyBagnall Jul 16, 2025
1b10109
Merge branch 'main' into arb/base_arima
TonyBagnall Jul 16, 2025
2a962d8
typo
TonyBagnall Jul 16, 2025
6f8cd55
Merge branch 'arb/base_arima' of https://github.com/aeon-toolkit/aeon…
TonyBagnall Jul 16, 2025
c7616a4
typo
TonyBagnall Jul 17, 2025
f29c809
calculate forecast_
TonyBagnall Jul 17, 2025
5a2ee8d
use differenced
TonyBagnall Jul 17, 2025
42e699c
example
TonyBagnall Jul 17, 2025
d1caed3
iterative
TonyBagnall Jul 17, 2025
9f2a85d
arima tests
TonyBagnall Jul 17, 2025
ca30d17
revert to float
TonyBagnall Jul 17, 2025
46d5ebc
switch nelder_mead version
TonyBagnall Jul 17, 2025
9648b99
isolate loss function
TonyBagnall Jul 17, 2025
e8157d6
isolate loss function
TonyBagnall Jul 17, 2025
f45400c
remove the utils version of nelder mead
TonyBagnall Jul 17, 2025
1509989
set self.c_ correctly
TonyBagnall Jul 19, 2025
a2578d9
numba optimise
TonyBagnall Jul 19, 2025
35e5b1c
numba optimise
TonyBagnall Jul 19, 2025
79c8c2d
numba optimise
TonyBagnall Jul 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aeon/forecasting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Forecasters."""

# Public API of the forecasting package; each name is listed exactly once.
__all__ = [
    "BaseForecaster",
    "NaiveForecaster",
    "RegressionForecaster",
    "ETSForecaster",
    "TVPForecaster",
    "ARIMA",
]

from aeon.forecasting._arima import ARIMA
from aeon.forecasting._ets import ETSForecaster
from aeon.forecasting._naive import NaiveForecaster
from aeon.forecasting._regression import RegressionForecaster
Expand Down
294 changes: 294 additions & 0 deletions aeon/forecasting/_arima.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
"""ARIMA.

An implementation of the ARIMA forecasting algorithm.
"""

__maintainer__ = ["alexbanwell1", "TonyBagnall"]
__all__ = ["ARIMA"]

import numpy as np
from numba import njit

from aeon.forecasting._extract_paras import _extract_arma_params
from aeon.forecasting._nelder_mead import nelder_mead
from aeon.forecasting.base import BaseForecaster


class ARIMA(BaseForecaster):
    """AutoRegressive Integrated Moving Average (ARIMA) forecaster.

    ARIMA with a fixed model structure ``(p, d, q)``. The coefficients are found
    with a Nelder-Mead optimiser to minimise the AIC on the ``d``-times
    differenced series.

    Parameters
    ----------
    p : int, default=1
        Autoregressive (p) order of the ARIMA model.
    d : int, default=0
        Differencing (d) order of the ARIMA model.
    q : int, default=1
        Moving average (q) order of the ARIMA model.
    use_constant : bool, default=False
        Presence of a constant/intercept term in the model.
    iterations : int, default=200
        Maximum number of iterations to use in the Nelder-Mead parameter search.

    Attributes
    ----------
    residuals_ : np.ndarray
        Residual errors from the fitted model (on the differenced series).
    fitted_values_ : np.ndarray
        Fitted values on the differenced series; the last element is the
        one-step-ahead forecast of the differenced series.
    forecast_ : float
        One-step-ahead forecast on the original scale, set in ``fit``.
    aic_ : float
        Akaike Information Criterion for the fitted model.
    c_ : float
        Intercept term, stays 0 unless ``use_constant`` is True.
    phi_ : np.ndarray
        Coefficients for autoregressive terms (length p).
    theta_ : np.ndarray
        Coefficients for moving average terms (length q).

    References
    ----------
    .. [1] R. J. Hyndman and G. Athanasopoulos,
       Forecasting: Principles and Practice. OTexts, 2014.
       https://otexts.com/fpp3/
    """

    _tags = {
        "capability:horizon": False,  # cannot fit to a horizon other than 1
    }

    def __init__(
        self,
        p: int = 1,
        d: int = 0,
        q: int = 1,
        use_constant: bool = False,
        iterations: int = 200,
    ):
        self.p = p
        self.d = d
        self.q = q
        self.use_constant = use_constant
        self.iterations = iterations
        self.phi_ = 0
        self.theta_ = 0
        self.c_ = 0
        self._series = []
        self._differenced_series = []
        self.residuals_ = []
        self.fitted_values_ = []
        self.aic_ = 0
        self._model = []
        self._parameters = []
        super().__init__(horizon=1, axis=1)

    @staticmethod
    def _undifference(series, diff_forecasts, d):
        """Map forecasts of the d-times differenced series back to the raw scale.

        Uses the identity ``y[t+1] = D^d y[t+1] + sum_{k<d} D^k y[t]`` where
        ``D^k`` is the k-th order difference. Summing the last ``d`` raw values
        (as is valid only for ``d == 1``) is not correct for ``d >= 2``.

        Parameters
        ----------
        series : np.ndarray
            Undifferenced series the forecasts follow on from.
        diff_forecasts : float or np.ndarray
            Consecutive forecasts of the ``d``-times differenced series.
        d : int
            Differencing order.

        Returns
        -------
        np.ndarray
            Forecasts on the scale of ``series``, one per differenced forecast.

        Raises
        ------
        ValueError
            If ``d > 0`` and ``series`` has fewer than ``d`` values.
        """
        diff_forecasts = np.atleast_1d(np.asarray(diff_forecasts, dtype=np.float64))
        if d == 0:
            return diff_forecasts.copy()
        if len(series) < d:
            raise ValueError("Series too short for differencing.")
        # Last observed value at every differencing level 0 .. d-1
        last_levels = [np.diff(series, n=k)[-1] for k in range(d)]
        restored = np.empty(len(diff_forecasts))
        for i, value in enumerate(diff_forecasts):
            # Integrate once per level, from the highest difference down to the
            # raw series, updating each level so the next step can build on it
            for k in range(d - 1, -1, -1):
                value = value + last_levels[k]
                last_levels[k] = value
            restored[i] = value
        return restored

    def _fit(self, y, exog=None):
        """Fit ARIMA forecaster to series y to predict one ahead using y.

        Parameters
        ----------
        y : np.ndarray
            A time series on which to learn a forecaster to predict horizon ahead
        exog : np.ndarray, default=None
            Not allowed for this forecaster

        Returns
        -------
        self
            Fitted ARIMA.
        """
        self._series = np.array(y.squeeze(), dtype=np.float64)
        # Model is an array of (constant flag, p, q)
        self._model = np.array(
            (1 if self.use_constant else 0, self.p, self.q), dtype=np.int32
        )
        self._differenced_series = np.diff(self._series, n=self.d)
        # Nelder Mead returns the parameters in a single array
        (self._parameters, self.aic_) = nelder_mead(
            0,
            np.sum(self._model[:3]),
            self._differenced_series,
            self._model,
            max_iter=self.iterations,
        )
        # Re-run the model with the optimised parameters to obtain the residuals
        # and fitted values alongside the AIC
        (self.aic_, self.residuals_, self.fitted_values_) = _arima_model(
            self._parameters,
            self._differenced_series,
            self._model,
        )
        formatted_params = _extract_arma_params(self._parameters, self._model)
        # The final fitted value is one step past the end of the differenced data
        differenced_forecast = self.fitted_values_[-1]
        self.forecast_ = self._undifference(
            self._series, differenced_forecast, self.d
        )[0]
        # Extract the parameter values
        if self.use_constant:
            self.c_ = formatted_params[0][0]
        self.phi_ = formatted_params[1][: self.p]
        self.theta_ = formatted_params[2][: self.q]

        return self

    def _predict(self, y, exog=None):
        """
        Predict the next step ahead for y.

        Parameters
        ----------
        y : np.ndarray
            A time series to predict the value of. y can be independent of the series
            seen in fit.
        exog : np.ndarray, default=None
            Optional exogenous time series data assumed to be aligned with y

        Returns
        -------
        float
            Prediction 1 step ahead of the last value in y.

        Raises
        ------
        ValueError
            If y is too short for the differencing order or for the ARMA orders.
        """
        y = y.squeeze()
        p, q, d = self.p, self.q, self.d
        phi, theta = self.phi_, self.theta_
        c = 0.0
        if self.use_constant:
            c = self.c_

        # Apply differencing
        if d > 0:
            if len(y) <= d:
                raise ValueError("Series too short for differencing.")
            y_diff = np.diff(y, n=d)
        else:
            y_diff = y

        n = len(y_diff)
        if n < max(p, q):
            raise ValueError("Series too short for ARMA(p,q) with given order.")

        # Estimate in-sample residuals using model (fixed parameters)
        residuals = np.zeros(n)
        for t in range(max(p, q), n):
            ar_part = np.dot(phi, y_diff[t - np.arange(1, p + 1)]) if p > 0 else 0.0
            ma_part = (
                np.dot(theta, residuals[t - np.arange(1, q + 1)]) if q > 0 else 0.0
            )
            pred = c + ar_part + ma_part
            residuals[t] = y_diff[t] - pred

        # Use most recent p values of y_diff and q values of residuals to forecast t+1
        ar_forecast = np.dot(phi, y_diff[-np.arange(1, p + 1)]) if p > 0 else 0.0
        ma_forecast = np.dot(theta, residuals[-np.arange(1, q + 1)]) if q > 0 else 0.0

        forecast_diff = c + ar_forecast + ma_forecast

        # Undifference the forecast
        return self._undifference(y, forecast_diff, d)[0]

    def _forecast(self, y, exog=None):
        """Forecast one ahead for time series y."""
        self.fit(y, exog)
        return float(self.forecast_)

    def iterative_forecast(self, y, prediction_horizon):
        """Fit on y, then forecast ``prediction_horizon`` steps by iterating.

        Each forecast of the differenced series is fed back in as if it were an
        observation to produce the next one. Residuals for forecast steps cannot
        be computed and are treated as zero.

        Parameters
        ----------
        y : np.ndarray
            The time series to fit on and forecast from.
        prediction_horizon : int
            Number of steps ahead to forecast.

        Returns
        -------
        np.ndarray
            The ``prediction_horizon`` forecasts on the original scale of y.
        """
        self.fit(y)
        n = len(self._differenced_series)
        p, q = self.p, self.q
        phi, theta = self.phi_, self.theta_
        h = prediction_horizon
        c = 0.0
        if self.use_constant:
            c = self.c_

        # Differenced series and residuals, each padded with h slots for forecasts
        residuals = np.zeros(len(self.residuals_) + h)
        residuals[: len(self.residuals_)] = self.residuals_
        forecast_series = np.zeros(n + h)
        forecast_series[:n] = self._differenced_series
        for i in range(h):
            t = n + i
            # Most recent p values (lags), including earlier forecasts
            ar_term = 0.0
            if p > 0:
                ar_term = np.dot(phi, forecast_series[t - np.arange(1, p + 1)])
            # Most recent q residuals (lags); forecast-step residuals stay zero
            ma_term = 0.0
            if q > 0:
                ma_term = np.dot(theta, residuals[t - np.arange(1, q + 1)])
            forecast_series[t] = c + ar_term + ma_term

        # Map the differenced forecasts back to the original scale
        return self._undifference(self._series, forecast_series[n : n + h], self.d)


@njit(cache=True, fastmath=True)
def _aic(residuals, num_params):
    """Calculate the AIC of a model from its residuals.

    Computes ``n * (log(2 * pi) + log(mean(residuals**2)) + 1) + 2 * num_params``
    where ``n`` is the number of residuals.

    Parameters
    ----------
    residuals : np.ndarray
        Residual errors of the model.
    num_params : int
        Number of parameters, each adding a penalty of 2.

    Returns
    -------
    float
        The AIC value.
    """
    n_obs = len(residuals)
    mean_sq_error = np.mean(residuals**2)
    neg2_log_lik = n_obs * (np.log(2 * np.pi) + np.log(mean_sq_error) + 1)
    return neg2_log_lik + 2 * num_params


# Evaluate an ARMA model on a series: AIC, residuals and fitted values
@njit(cache=True, fastmath=True)
def _arima_model(params, data, model):
    """Run the model over data and return its AIC, residuals and fitted values.

    Parameters
    ----------
    params : np.ndarray
        Flat parameter array, unpacked with ``_extract_arma_params``.
    data : np.ndarray
        Series the model is evaluated on (ARIMA passes the differenced series).
    model : np.ndarray
        Model description, passed through to ``_extract_arma_params`` and
        ``_in_sample_forecast``.

    Returns
    -------
    tuple
        ``(aic, residuals, fitted_values)``. ``residuals`` has ``len(data)``
        entries; ``fitted_values`` has one more, the last being the prediction
        for the step after the end of ``data``.
    """
    coeffs = _extract_arma_params(params, model)

    n_obs = len(data)
    residuals = np.zeros(n_obs)
    fitted_values = np.zeros(n_obs + 1)
    for t in range(n_obs):
        fitted_values[t] = _in_sample_forecast(data, model, t, coeffs, residuals)
        residuals[t] = data[t] - fitted_values[t]
    # One step past the data: there is no observation, so no residual
    fitted_values[n_obs] = _in_sample_forecast(data, model, n_obs, coeffs, residuals)
    return _aic(residuals, len(params)), residuals, fitted_values


@njit(cache=True, fastmath=True)
def _in_sample_forecast(data, model, t, formatted_params, residuals):
    """One-step ARMA prediction for index t from earlier data and residuals.

    Parameters
    ----------
    data : np.ndarray
        Series being modelled; only values before index ``t`` are read.
    model : np.ndarray
        ``model[0]`` flags the constant term, ``model[1]`` is the AR order and
        ``model[2]`` the MA order.
    t : int
        Index to predict. At most ``t`` lags are used, so early indices use
        fewer than the full order.
    formatted_params : np.ndarray
        2D array: row 0 holds the constant, row 1 the AR coefficients and
        row 2 the MA coefficients.
    residuals : np.ndarray
        Residuals computed so far; only values before index ``t`` are read.

    Returns
    -------
    float
        Constant plus the AR and MA contributions.
    """
    n_ar_lags = min(model[1], t)
    n_ma_lags = min(model[2], t)

    intercept = 0.0
    if model[0]:
        intercept = formatted_params[0][0]

    ar_term = 0.0
    for lag in range(1, n_ar_lags + 1):
        ar_term += formatted_params[1, lag - 1] * data[t - lag]

    ma_term = 0.0
    for lag in range(1, n_ma_lags + 1):
        ma_term += formatted_params[2, lag - 1] * residuals[t - lag]

    return intercept + ar_term + ma_term
39 changes: 39 additions & 0 deletions aeon/forecasting/_loss_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Loss functions for optimiser."""

import numpy as np
from numba import njit

from aeon.forecasting._extract_paras import _extract_arma_params

# Precomputed value of log(2 * pi), used in the likelihood term of _arima_fit
LOG_2PI = 1.8378770664093453


@njit(cache=True, fastmath=True)
def _arima_fit(params, data, model):
    """Calculate the AIC of an ARIMA model given the parameters.

    Parameters
    ----------
    params : np.ndarray
        Flat parameter array, unpacked with ``_extract_arma_params``.
    data : np.ndarray
        Series the model is evaluated on.
    model : np.ndarray
        ``model[0]`` flags the constant term, ``model[1]`` is the AR order and
        ``model[2]`` the MA order.

    Returns
    -------
    float
        ``n * (LOG_2PI + log(sse / n) + 1) + 2 * len(params)`` where ``sse`` is
        the sum of squared one-step residuals and ``n`` is ``len(data)``.
    """
    coeffs = _extract_arma_params(params, model)

    n_obs = len(data)
    ar_order = model[1]
    ma_order = model[2]
    intercept = coeffs[0][0] if model[0] else 0

    # One-step residuals; early steps use only the lags that exist
    errors = np.zeros(n_obs)
    for t in range(n_obs):
        ar_part = 0.0
        for lag in range(1, min(ar_order, t) + 1):
            ar_part += coeffs[1, lag - 1] * data[t - lag]
        ma_part = 0.0
        for lag in range(1, min(ma_order, t) + 1):
            ma_part += coeffs[2, lag - 1] * errors[t - lag]
        prediction = intercept + ar_part + ma_part
        errors[t] = data[t] - prediction

    sum_sq = 0.0
    for err in errors:
        sum_sq += err**2
    mean_sq = sum_sq / n_obs
    neg2_log_lik = n_obs * (LOG_2PI + np.log(mean_sq) + 1.0)
    return neg2_log_lik + 2 * len(params)
Loading
Loading