example or documentation for StratifiedKFold use #370
Description
Is your feature request related to a problem? Please describe.
Using plain StratifiedKFold, e.g. in a GridSearchCV, does not work as expected because StratifiedKFold cannot determine the event-versus-no-event class from the structured y that is passed to fit. In practice this shows up as a not very obvious exception: ValueError: n_splits=5 cannot be greater than the number of members in each class.
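The following NumPy-only sketch gives one likely explanation for the error. When the whole structured y is used as the class label, every (event, time) record appears to count as its own class. The sketch rebuilds the toy data from the snippet in this issue as a plain structured array with fields `event` and `time`. That layout is an assumption, chosen to match the field names the issue's code uses. The sketch then counts the distinct labels both ways. With 100 single-member classes, no class has at least `n_splits=5` members, which is consistent with the error message.

```python
import numpy as np

# Toy data as in the issue: 100 samples, only the 10 shortest times are events.
# Plain structured array; assumed to mirror the layout of the sksurv structured y.
n = 100
times = np.arange(n) + 1.0
y = np.empty(n, dtype=[("event", "?"), ("time", "<f8")])
y["event"] = times <= n / 10
y["time"] = times

# Whole records as labels: every (event, time) pair is distinct,
# so there are n "classes" with a single member each.
record_classes, record_counts = np.unique(y, return_counts=True)
print(len(record_classes), record_counts.max())

# The event field alone yields the two intended classes (censored vs. event).
event_classes, event_counts = np.unique(y["event"], return_counts=True)
print(event_classes, event_counts)
```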
Describe the solution you'd like
Information or an example in the documentation on how StratifiedKFold can be used with survival data would be great.
Describe alternatives you've considered
An example would be a nice-to-have; the alternative is for users to work something out on their own.
References and existing implementations
none and/or unknown
Code snippets
Something like this illustrates the issue and one possible approach.
```python
import numpy as np
import sklearn.model_selection
import sksurv.tree
import sksurv.util

# Toy data: 100 samples ordered by time; only the first 10 experience an event.
n = 100
feature1 = np.arange(0, n, 1)
feature2 = n - feature1
times = np.arange(0, n) + 1
events = times <= n / 10
X = np.vstack((feature1, feature2)).T
y = sksurv.util.Surv.from_arrays(time=times, event=events)


def train_test_generator(X, y, stratified=None, n_splits=2):
    """Return an iterator of (train, test) index arrays, optionally stratified on field `stratified` of y."""
    if stratified is not None:
        # Stratify on a single field (e.g. the event indicator) instead of the whole record.
        return sklearn.model_selection.StratifiedKFold(n_splits=n_splits).split(X, y[stratified])
    else:
        return sklearn.model_selection.KFold(n_splits=n_splits).split(X)


# Compare how events are spread over the folds. Because the samples are ordered
# by time, plain KFold puts all 10 events into the same fold.
for stratified in [None, "event"]:
    print(f"stratified={stratified}")
    for i, (train_index, test_index) in enumerate(train_test_generator(X, y, stratified=stratified)):
        print(f'split {i+1}: {y[train_index]["event"].sum()} + {y[test_index]["event"].sum()} = {y["event"].sum()}')

# This does not work; it raises
# ValueError: n_splits=5 cannot be greater than the number of members in each class.
# gcv = sklearn.model_selection.GridSearchCV(estimator=sksurv.tree.SurvivalTree(random_state=0),
#                                            param_grid={"min_samples_leaf": [3, 6, 9]},
#                                            cv=sklearn.model_selection.StratifiedKFold(n_splits=5))
# gcv.fit(X, y)

# This does work: pass the event-stratified splits as an iterable of (train, test) indices.
gcv = sklearn.model_selection.GridSearchCV(estimator=sksurv.tree.SurvivalTree(random_state=0),
                                           param_grid={"min_samples_leaf": [3, 6, 9]},
                                           cv=train_test_generator(X, y, stratified="event", n_splits=5))
gcv.fit(X, y)
print(gcv.cv_results_)
```
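A generator such as the one returned by `train_test_generator` can only be consumed once, so reusing it for a second `fit` would likely fail. A reusable alternative is to subclass StratifiedKFold so that its `split` stratifies on the event field. The sketch below uses `EventStratifiedKFold`, which is a hypothetical helper and not part of scikit-survival or scikit-learn. It assumes the event indicator lives in a field named `event`, as in the issue's code. To stay independent of sksurv, it builds the structured y directly with NumPy.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


class EventStratifiedKFold(StratifiedKFold):
    """Hypothetical helper (not a library API): stratified k-fold on the
    boolean "event" field of a structured survival array y."""

    def split(self, X, y, groups=None):
        # Stratify on the event indicator instead of the whole (event, time) record.
        return super().split(X, y["event"], groups)


# Toy data as in the issue, with the structured y built directly with NumPy.
n = 100
X = np.column_stack((np.arange(n), n - np.arange(n)))
times = np.arange(n) + 1.0
y = np.empty(n, dtype=[("event", "?"), ("time", "<f8")])
y["event"] = times <= n / 10
y["time"] = times

cv = EventStratifiedKFold(n_splits=5)
# Number of events in each test fold; the 10 events are spread evenly.
test_events = [int(y["event"][test].sum()) for _, test in cv.split(X, y)]
print(test_events)
```

Since GridSearchCV passes the structured y on to the splitter's `split`, the splitter should then be usable directly, e.g. `cv=EventStratifiedKFold(n_splits=5)` in the GridSearchCV call above, without pre-computing the splits.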