Skip to content

Commit 3b3b348

Browse files
[ENH] fit_predict for FreshPRINCE and RotationForest (#1456)
* fp and rotf fit_predict * fix
1 parent f95d6e7 commit 3b3b348

File tree

7 files changed

+404
-383
lines changed

7 files changed

+404
-383
lines changed

aeon/classification/feature_based/_fresh_prince.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010

1111
import numpy as np
12-
from sklearn.utils import check_random_state
1312

1413
from aeon.classification.base import BaseClassifier
1514
from aeon.classification.sklearn import RotationForestClassifier
@@ -117,8 +116,8 @@ def _fit(self, X, y):
117116
Changes state by creating a fitted model that updates attributes
118117
ending in "_" and sets is_fitted flag to True.
119118
"""
120-
self._fit_fresh_prince(X, y)
121-
119+
X_t = self._fit_fp_shared(X, y)
120+
self._rotf.fit(X_t, y)
122121
return self
123122

124123
def _predict(self, X) -> np.ndarray:
@@ -155,24 +154,18 @@ def _predict_proba(self, X) -> np.ndarray:
155154
return self._rotf.predict_proba(self._tsfresh.transform(X))
156155

157156
def _fit_predict(self, X, y) -> np.ndarray:
158-
rng = check_random_state(self.random_state)
159-
return np.array(
160-
[
161-
self.classes_[int(rng.choice(np.flatnonzero(prob == prob.max())))]
162-
for prob in self._fit_predict_proba(X, y)
163-
]
164-
)
157+
X_t = self._fit_fp_shared(X, y)
158+
return self._rotf.fit_predict(X_t, y)
165159

166160
def _fit_predict_proba(self, X, y) -> np.ndarray:
167-
Xt = self._fit_fresh_prince(X, y, save_rotf_data=True)
168-
return self._rotf._get_train_probs(Xt, y)
161+
X_t = self._fit_fp_shared(X, y)
162+
return self._rotf.fit_predict_proba(X_t, y)
169163

170-
def _fit_fresh_prince(self, X, y, save_rotf_data=False):
164+
def _fit_fp_shared(self, X, y):
171165
self.n_cases_, self.n_channels_, self.n_timepoints_ = X.shape
172166

173167
self._rotf = RotationForestClassifier(
174168
n_estimators=self.n_estimators,
175-
save_transformed_data=save_rotf_data,
176169
n_jobs=self._n_jobs,
177170
random_state=self.random_state,
178171
)
@@ -184,10 +177,7 @@ def _fit_fresh_prince(self, X, y, save_rotf_data=False):
184177
disable_progressbar=self.verbose < 1,
185178
)
186179

187-
X_t = self._tsfresh.fit_transform(X, y)
188-
self._rotf.fit(X_t, y)
189-
190-
return X_t
180+
return self._tsfresh.fit_transform(X, y)
191181

192182
@classmethod
193183
def get_test_params(cls, parameter_set="default"):

aeon/classification/shapelet_based/_stc.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ def _fit(self, X, y):
187187
Changes state by creating a fitted model that updates attributes
188188
ending in "_".
189189
"""
190-
self._fit_stc(X, y)
191-
190+
X_t = self._fit_stc_shared(X, y)
191+
self._estimator.fit(X_t, y)
192192
return self
193193

194194
def _predict(self, X) -> np.ndarray:
@@ -243,13 +243,15 @@ def _fit_predict(self, X, y) -> np.ndarray:
243243
)
244244

245245
def _fit_predict_proba(self, X, y) -> np.ndarray:
246-
Xt = self._fit_stc(X, y, save_rotf_data=True)
246+
X_t = self._fit_stc_shared(X, y)
247247

248248
if (isinstance(self.estimator, RotationForestClassifier)) or (
249249
self.estimator is None
250250
):
251-
return self._estimator._get_train_probs(Xt, y)
251+
return self._estimator.fit_predict_proba(X_t, y)
252252
else:
253+
self._estimator.fit(X_t, y)
254+
253255
m = getattr(self._estimator, "predict_proba", None)
254256
if not callable(m):
255257
raise ValueError("Estimator must have a predict_proba method.")
@@ -269,14 +271,14 @@ def _fit_predict_proba(self, X, y) -> np.ndarray:
269271

270272
return cross_val_predict(
271273
estimator,
272-
X=Xt,
274+
X=X_t,
273275
y=y,
274276
cv=cv_size,
275277
method="predict_proba",
276278
n_jobs=self._n_jobs,
277279
)
278280

279-
def _fit_stc(self, X, y, save_rotf_data=False):
281+
def _fit_stc_shared(self, X, y):
280282
self.n_cases_, self.n_channels_, self.n_timepoints_ = X.shape
281283

282284
if self.time_limit_in_minutes > 0:
@@ -304,9 +306,6 @@ def _fit_stc(self, X, y, save_rotf_data=False):
304306
self.random_state,
305307
)
306308

307-
if isinstance(self._estimator, RotationForestClassifier):
308-
self._estimator.save_transformed_data = save_rotf_data
309-
310309
m = getattr(self._estimator, "n_jobs", None)
311310
if m is not None:
312311
self._estimator.n_jobs = self._n_jobs
@@ -315,11 +314,7 @@ def _fit_stc(self, X, y, save_rotf_data=False):
315314
if m is not None and self.time_limit_in_minutes > 0:
316315
self._estimator.time_limit_in_minutes = self._classifier_limit_in_minutes
317316

318-
Xt = self._transformer.fit_transform(X, y)
319-
320-
self._estimator.fit(Xt, y)
321-
322-
return Xt
317+
return self._transformer.fit_transform(X, y)
323318

324319
@classmethod
325320
def get_test_params(cls, parameter_set="default"):

0 commit comments

Comments
 (0)