example or documentation for StratifiedKFold use #370

@cpoerschke

Description

Is your feature request related to a problem? Please describe.

Using plain StratifiedKFold, e.g. in a GridSearchCV, conceptually might not work because the event-versus-no-event class cannot be determined from the structured y that is passed to fit. In practice this can manifest as a not very obvious exception: ValueError: n_splits=5 cannot be greater than the number of members in each class.
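
To make the problem more concrete, here is a minimal sketch that uses only NumPy. It assumes the structured y has a boolean field named "event" and a float field named "time", which mimics (but is not produced by) sksurv.util.Surv.from_arrays. When every time is distinct, every whole record is distinct as well, which is presumably why a stratifier that treats each record as a class label finds only one member per class. The event field on its own has just two values.

```python
import numpy as np

n = 100
times = np.arange(n) + 1.0
events = times <= n / 10  # the first 10 samples are events

# Hand-built structured array; the field names "event" and "time" are assumptions
# chosen to mirror what the snippet in this issue indexes with y["event"].
y = np.empty(n, dtype=[("event", bool), ("time", float)])
y["event"] = events
y["time"] = times

# Treating each (event, time) record as a label gives as many "classes" as samples ...
n_distinct_records = len(np.unique(y))
# ... whereas the event indicator alone has two classes to stratify on.
n_distinct_events = len(np.unique(y["event"]))

print(n_distinct_records, n_distinct_events)
```

The internals of StratifiedKFold are not inspected here; the sketch only shows that the full structured y is not a usable class label, while y["event"] is.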

Describe the solution you'd like

Information or an example somewhere in the documentation on how StratifiedKFold can be used with a structured y would be great.

Describe alternatives you've considered

An example would be nice to have; the alternative is that users work something out on their own.

References and existing implementations

none and/or unknown

Code snippets

Something like this illustrates the issue and one possible approach.

import numpy as np
import sklearn.model_selection
import sklearn.tree
import sksurv.tree
import sksurv.util

n = 100

feature1 = np.arange(0, n, 1)
feature2 = n - feature1

times = np.arange(0,n) + 1
events = (times <= n/10)

X = np.vstack((feature1, feature2)).T
y = sksurv.util.Surv.from_arrays(time=times, event=events)


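def train_test_generator(X, y, stratified=None, n_splits=2):
    # `stratified` names the field of the structured array y to stratify on, e.g. "event";
    # None falls back to plain, unstratified KFold.
    if stratified is not None:
        return sklearn.model_selection.StratifiedKFold(n_splits=n_splits).split(X, y[stratified])
    return sklearn.model_selection.KFold(n_splits=n_splits).split(X)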


for stratified in [None, "event"]:
    print(f"stratified={stratified}")
    for i, (train_index, test_index) in enumerate(train_test_generator(X, y, stratified=stratified)):
        print(f'split {i+1}: {y[train_index]["event"].sum()} + {y[test_index]["event"].sum()} = {y["event"].sum()}')


# this does not work: StratifiedKFold is given the whole structured y rather than y["event"]
# gcv = sklearn.model_selection.GridSearchCV(estimator=sksurv.tree.SurvivalTree(random_state=0),
#                                            param_grid={ "min_samples_leaf": [3, 6, 9] },
#                                            cv=sklearn.model_selection.StratifiedKFold(n_splits=5))
# gcv.fit(X, y)

# this does work: the splits are computed from y["event"] up front and passed to cv as an iterable
gcv = sklearn.model_selection.GridSearchCV(estimator=sksurv.tree.SurvivalTree(random_state=0),
                                           param_grid={ "min_samples_leaf": [3, 6, 9] },
                                           cv=train_test_generator(X, y, stratified="event", n_splits=5))
gcv.fit(X, y)

gcv.cv_results_
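
Another possible approach, sketched here under assumptions, is a small splitter subclass that picks out the event field itself, so that it can be passed as cv without precomputing the splits. EventStratifiedKFold is a hypothetical name and is not part of scikit-learn or scikit-survival. The structured y below is built by hand with NumPy, with assumed field names "event" and "time", so that the sketch runs without sksurv.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


class EventStratifiedKFold(StratifiedKFold):
    """Hypothetical helper: stratify on the "event" field of a structured y."""

    field = "event"  # assumed field name, matching y["event"] in the snippet above

    def split(self, X, y, groups=None):
        # Hand only the boolean event indicator to the regular stratification logic.
        return super().split(X, y[self.field], groups)


n = 100
times = np.arange(n) + 1.0
events = times <= n / 10

# Hand-built stand-in for a survival outcome array (field names are assumptions).
y = np.empty(n, dtype=[("event", bool), ("time", float)])
y["event"] = events
y["time"] = times

X = np.column_stack((np.arange(n), n - np.arange(n)))

cv = EventStratifiedKFold(n_splits=5)
events_per_test_fold = [int(y["event"][test].sum()) for _, test in cv.split(X, y)]
print(events_per_test_fold)
```

An instance such as EventStratifiedKFold(n_splits=5) could then be given as the cv argument of GridSearchCV in place of the precomputed generator. This has not been checked against SurvivalTree here, so treat it as a starting point rather than a verified recipe.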
