
Add support for GridSearchCV to cuml.accel #7843

Open
betatim wants to merge 20 commits into rapidsai:main from betatim:cuml-accel-gridsearchcv

Conversation

@betatim
Member

@betatim betatim commented Mar 3, 2026

This PR builds on the new Pipeline infrastructure. It uses the array API support in scikit-learn's GridSearchCV to achieve acceleration: we don't have our own GPU implementation; instead we enable array API support in scikit-learn and then call GridSearchCV's fit with CuPy arrays as input.

At the start of the patched fit we perform a few checks to determine whether we should give up on acceleration. If the estimator isn't a proxy, or none of the parameters being searched over are supported, we bail out.
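
A rough sketch of that bail-out logic (the helper names `should_accelerate`, `is_proxy` and `supported` are hypothetical stand-ins for illustration, not the PR's actual internals):

```python
def should_accelerate(estimator, param_grid, *, is_proxy, supported):
    """Return True if the patched GridSearchCV.fit should try the GPU path.

    is_proxy: predicate telling whether `estimator` is a cuml.accel proxy.
    supported: set of hyperparameter names the accelerated estimator accepts.
    """
    # Bail out if the estimator isn't wrapped by cuml.accel at all.
    if not is_proxy(estimator):
        return False
    # Bail out if none of the searched-over parameters are supported.
    return any(name in supported for name in param_grid)


# Toy usage with stand-in values:
accelerate = should_accelerate(
    object(),
    {"alpha": [0.1, 1.0]},
    is_proxy=lambda est: True,
    supported={"alpha", "fit_intercept"},
)
```

The real patch performs more checks than this (sparsity, dtypes, n_jobs, callable scoring), but the shape is the same: cheap pre-flight tests, then fall back to the stock CPU path on any failure.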

I also tried to take care of cross_val_score, cross_val_predict and cross_validate, to avoid having to add xfails, but that led to more failures and fix-ups. I might try again later.

The other new xfail is because that test passes K_train.tolist() and expects scikit-learn to raise a ValueError, since a plain list isn't a valid precomputed-kernel format. But the patching calls cp.asarray(K_train.tolist()) before passing the data on, and CuPy happily converts a nested list into a CuPy array. So the validation that would reject the list never fires, and the test fails. I'm not sure this is worth fixing; for now I'd just xfail it, since we make something work that shouldn't (in the case where a user passes something they shouldn't).
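
This is easy to reproduce on the host: numpy.asarray shows the same behavior as cp.asarray here, so the downstream list-specific validation never sees a list:

```python
import numpy as np

K_train = np.eye(3)         # stand-in for a precomputed kernel matrix
as_list = K_train.tolist()  # a plain nested list; sklearn would reject this

# cp.asarray(as_list) on the GPU does the same thing: the nested list
# becomes a regular 2-D array, so the check that rejects plain lists
# as a precomputed kernel never fires.
K_back = np.asarray(as_list)

print(type(as_list).__name__, K_back.shape)  # list (3, 3)
```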

Some open questions:

  • we can't accelerate sparse inputs because the array API doesn't know about those.
  • I think it makes sense not to use several processes, but I'm not sure. Should we override this?
  • There are a few # XXX comments where I still have questions to myself/others.

@copy-pr-bot

copy-pr-bot bot commented Mar 3, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@github-actions github-actions bot added the Cython / Python (Cython or Python issue) label Mar 3, 2026
@betatim betatim added the improvement (Improvement / enhancement to an existing function) and non-breaking (Non-breaking change) labels Mar 4, 2026
@betatim betatim marked this pull request as ready for review March 10, 2026 08:05
@betatim betatim requested a review from a team as a code owner March 10, 2026 08:05
@betatim betatim requested a review from jcrist March 10, 2026 08:05
@coderabbitai

coderabbitai bot commented Mar 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds a GPU-accelerated patch for sklearn.model_selection.GridSearchCV with pre-flight checks and CPU fallback; updates docs for GridSearchCV and scikit-learn 1.8 compatibility; registers the new patch; adjusts SVC host conversion; adds an output-placement helper; and introduces comprehensive GridSearchCV tests and xfails.

Changes

  • Documentation (docs/source/cuml-accel/faq.rst, docs/source/cuml-accel/limitations.rst): adds GridSearchCV to the list of accelerated estimators and documents CPU-fallback conditions; bumps stated scikit-learn compatibility to include 1.8.
  • Patch Registration (python/cuml/cuml/accel/core.py): registers sklearn.model_selection in the cuml.accel patch set so the new GridSearchCV patch is applied.
  • GridSearchCV Patch (python/cuml/cuml/accel/_patches/sklearn/model_selection.py): new module patching GridSearchCV.fit: SciPy array-API context manager, pre-flight checks (proxy compatibility, sparsity, numeric dtypes, n_jobs, callable scoring), CuPy conversions of inputs/params, CPU fallback, restoring user-facing attributes; exported via __all__.
  • Estimator Proxy (python/cuml/cuml/accel/estimator_proxy.py): adds _maybe_to_device(out) and routes proxy outputs through it to honor GlobalSettings().output_type (NumPy vs CuPy).
  • SVM Override (python/cuml/cuml/accel/_overrides/sklearn/svm.py): uses ensure_host(y) when computing classes/counts in SVC._gpu_fit to handle potential CuPy y inputs from GridSearchCV.
  • Tests & CI xfails (python/cuml/cuml_accel_tests/test_grid_search.py, python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml): adds an extensive GridSearchCV test suite covering GPU/CPU fallback, array-API behavior, pipelines, scoring, n_jobs variants, and edge cases; adds xfail entries gated on sklearn >= 1.8.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

cuml-accel

Suggested reviewers

  • csadorf
  • jcrist
  • divyegala
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 55.26%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: the title clearly and accurately summarizes the main change, adding GridSearchCV support to cuml.accel, which is the primary objective reflected across all file modifications.
  • Description check ✅ Passed: the description is directly related to the changeset, explaining the implementation approach using array API support and the decision-making process for acceleration bailout conditions.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/source/cuml-accel/limitations.rst`:
- Around line 345-349: Update the limitations text for GridSearchCV to mention
that the optimized GPU path is skipped when n_jobs != 1; specifically, add a
bullet (or clause) alongside the existing conditions stating that if
GridSearchCV's n_jobs parameter is set to any value other than 1 the
implementation will fall back to CPU. Reference GridSearchCV and the n_jobs
parameter in the sentence so readers can find the relevant API option.

In `@python/cuml/cuml_accel_tests/test_grid_search.py`:
- Around line 188-203: The assertion currently compares type names of scorer
arguments, which can't distinguish NumPy vs CuPy arrays; update the final
assertion to use isinstance checks against np.ndarray instead. Specifically,
keep the scorer_arg_types collection produced by my_metric but change the assert
to verify each recorded pair with isinstance(y_true, np.ndarray) and
isinstance(y_pred, np.ndarray) (referencing my_metric and scorer_arg_types) so
the test fails if scorer receives device (CuPy) arrays. Ensure the updated
assertion still reports the scorer_arg_types on failure for debugging.

In `@python/cuml/cuml/accel/_patches/sklearn/model_selection.py`:
- Around line 96-154: The dtype validation using np.asarray(X) in the patched
GridSearchCV.fit runs before checking for GPU-resident inputs and before
sklearn's own validation, which throws on cp.ndarray and alters validation
semantics; change the flow so input-type checks happen first (detect if X or y
are cupy arrays via isinstance(X, cp.ndarray) / isinstance(y, cp.ndarray) and
skip np.asarray for those), move the numeric-dtype validation to after
confirming inputs are host arrays (or handle cupy arrays separately), and defer
converting X, y and params to cupy (cp.asarray) until after sklearn's validation
has completed with the original inputs; update references in this patch around
GridSearchCV.fit, estimator_name, is_proxy/ParameterGrid checks, and the
cp.asarray conversions that build X_gpu, y_gpu and params so validation uses the
original (possibly host) objects before any cp.asarray conversions occur.
- Around line 27-50: The current _enable_scipy_array_api context manager mutates
the process-global SCIPY_ARRAY_API env var and _scipy_array_api._GLOBAL_CONFIG
without synchronization, causing races when multiple threads call
GridSearchCV.fit(); make it thread-safe by adding a module-level lock and a
reference counter (similar to enter_internal_context in cuml.internals.outputs):
acquire the lock on entry in _enable_scipy_array_api, if the counter is zero
record the current env and _scipy_array_api._GLOBAL_CONFIG value and set both to
"1", then increment the counter and release the lock; on exit acquire the lock,
decrement the counter, and only when the counter becomes zero restore/pop the
saved env value and restore _scipy_array_api._GLOBAL_CONFIG, then release the
lock; ensure the lock and counter are module-level symbols so reentrancy across
threads is handled correctly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 61aa83b6-f112-4359-a3da-282433e1ef27

📥 Commits

Reviewing files that changed from the base of the PR and between 60715ad and d1f8465.

📒 Files selected for processing (8)
  • docs/source/cuml-accel/faq.rst
  • docs/source/cuml-accel/limitations.rst
  • python/cuml/cuml/accel/_overrides/sklearn/svm.py
  • python/cuml/cuml/accel/_patches/sklearn/model_selection.py
  • python/cuml/cuml/accel/core.py
  • python/cuml/cuml/accel/estimator_proxy.py
  • python/cuml/cuml_accel_tests/test_grid_search.py
  • python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml

Comment on lines +27 to +50
@contextlib.contextmanager
def _enable_scipy_array_api():
    """Enable scipy's array API support.

    Sets the SCIPY_ARRAY_API env var (checked by sklearn's config validation)
    and updates scipy's cached config (in case scipy had already been imported).

    Both are restored on exit.
    """
    old_env = os.environ.get("SCIPY_ARRAY_API")
    os.environ["SCIPY_ARRAY_API"] = "1"

    old_cached = _scipy_array_api._GLOBAL_CONFIG["SCIPY_ARRAY_API"]
    _scipy_array_api._GLOBAL_CONFIG["SCIPY_ARRAY_API"] = "1"

    try:
        yield
    finally:
        if old_env is None:
            os.environ.pop("SCIPY_ARRAY_API", None)
        else:
            os.environ["SCIPY_ARRAY_API"] = old_env

        _scipy_array_api._GLOBAL_CONFIG["SCIPY_ARRAY_API"] = old_cached


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Find where _enable_scipy_array_api is used
rg -n "_enable_scipy_array_api" --type py

Repository: rapidsai/cuml

Length of output: 251


🏁 Script executed:

# Get the full context of the model_selection.py file
wc -l python/cuml/cuml/accel/_patches/sklearn/model_selection.py

Repository: rapidsai/cuml

Length of output: 119


🏁 Script executed:

# Read the entire model_selection.py file to see the monkey patch and usage
cat -n python/cuml/cuml/accel/_patches/sklearn/model_selection.py

Repository: rapidsai/cuml

Length of output: 7274


🏁 Script executed:

# Check if there are any synchronization primitives (locks, threading) in the file
rg -n "threading|Lock|RLock|Semaphore" python/cuml/cuml/accel/_patches/sklearn/model_selection.py

Repository: rapidsai/cuml

Length of output: 39


🏁 Script executed:

# Search for threading usage in tests related to GridSearchCV or model_selection
rg -n "threading|Thread|concurrent" --type py | grep -i "gridsearch\|model_selection" | head -20

Repository: rapidsai/cuml

Length of output: 39


🏁 Script executed:

# Check if there are any tests for the monkey-patched GridSearchCV
fd -t f "*test*model_selection*" python/cuml/ --type f

Repository: rapidsai/cuml

Length of output: 358


🏁 Script executed:

# Search for concurrent or thread-safety related documentation
rg -n "thread|concurrent|reentrant" python/cuml/cuml/accel/_patches/sklearn/model_selection.py

Repository: rapidsai/cuml

Length of output: 39


🏁 Script executed:

# Check if sklearn's GridSearchCV itself handles thread safety
web_search
What is the thread-safety model for scikit-learn's GridSearchCV? Should concurrent fit() calls be supported?

Repository: rapidsai/cuml

Length of output: 176


🏁 Script executed:

# Search for test files related to model_selection patches
find python/cuml -type f -name "*test*" -path "*model_selection*"

Repository: rapidsai/cuml

Length of output: 39


🏁 Script executed:

# Check if there are tests for the patched GridSearchCV
rg -l "GridSearchCV" python/cuml --type py | grep test

Repository: rapidsai/cuml

Length of output: 275


🏁 Script executed:

# Look for any reentrancy or thread-safety discussions in comments/issues
rg -i "reentrant|thread.safe|concurrent.*fit" python/cuml --type py

Repository: rapidsai/cuml

Length of output: 275


🏁 Script executed:

# Check the actual usage pattern: is _enable_scipy_array_api nested or sequential?
# Look at the context manager composition on line 156-160
sed -n '156,161p' python/cuml/cuml/accel/_patches/sklearn/model_selection.py

Repository: rapidsai/cuml

Length of output: 278


🏁 Script executed:

# Check if contextlib.contextmanager provides any built-in reentrancy handling
# by looking at Python documentation or how it's used elsewhere
rg -n "contextlib.contextmanager" python/cuml --type py -A 5 | head -50

Repository: rapidsai/cuml

Length of output: 2086


🏁 Script executed:

# Look for the test file for grid search to see if there are any concurrent/threading tests
cat -n python/cuml/cuml_accel_tests/test_grid_search.py | head -100

Repository: rapidsai/cuml

Length of output: 3612


🏁 Script executed:

# Check if there are actual concurrent fit() calls tested anywhere
rg -n "Thread|threading|concurrent" python/cuml/cuml_accel_tests/test_grid_search.py

Repository: rapidsai/cuml

Length of output: 39


🏁 Script executed:

# Examine how using_output_type context manager handles reentrancy
sed -n '200,240p' python/cuml/cuml/internals/outputs.py

Repository: rapidsai/cuml

Length of output: 1396


🌐 Web query:

scikit-learn GridSearchCV thread safety concurrent fit calls

💡 Result:

  • Don’t call GridSearchCV.fit() concurrently on the same GridSearchCV instance. fit() mutates the object (e.g., it writes results like cv_results_, best_params_, best_estimator_), and scikit-learn objects are generally designed for single-threaded mutation during fit(). The docs also note that scikit-learn’s global configuration context is not thread-safe, which is another way concurrent fit() calls in one process can go wrong. [3]

  • Inside one GridSearchCV.fit() call, parallelism is handled for you. GridSearchCV parallelizes the many underlying estimator fits via joblib; by default it uses the loky process-based backend, which isolates workers and helps avoid shared-state/thread-safety issues during the internal parallel work. [1]

  • If you need “concurrent grid searches,” use separate objects (and avoid shared RNG objects). Run one GridSearchCV per thread/process (or protect each fit() with a lock). Also avoid passing the same np.random.RandomState instance into multiple estimators/searches, since fit() consumes (mutates) that RNG and can create cross-talk between fits; prefer an int random_state (or distinct RNG instances per run). [2]

Sources:
[1] scikit-learn docs: parallelism / GridSearchCV uses joblib loky backend by default
[2] scikit-learn docs: fit() consumes RNG; cloning and GridSearchCV clone estimators internally
[3] scikit-learn docs: config_context is not thread-safe


Make the SciPy array-API toggle thread-safe.

SCIPY_ARRAY_API and _scipy_array_api._GLOBAL_CONFIG are process-global. If users call GridSearchCV.fit() concurrently from multiple threads, the first thread's finally block can restore/pop the old value while the second thread is still running, disabling array-API mid-flight. (sklearn's own config_context has the same limitation and is documented as not thread-safe.) Fix this with a lock or ref-counting mechanism, similar to how cuml.internals.outputs handles reentrancy in enter_internal_context().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/cuml/cuml/accel/_patches/sklearn/model_selection.py` around lines 27 -
50, The current _enable_scipy_array_api context manager mutates the
process-global SCIPY_ARRAY_API env var and _scipy_array_api._GLOBAL_CONFIG
without synchronization, causing races when multiple threads call
GridSearchCV.fit(); make it thread-safe by adding a module-level lock and a
reference counter (similar to enter_internal_context in cuml.internals.outputs):
acquire the lock on entry in _enable_scipy_array_api, if the counter is zero
record the current env and _scipy_array_api._GLOBAL_CONFIG value and set both to
"1", then increment the counter and release the lock; on exit acquire the lock,
decrement the counter, and only when the counter becomes zero restore/pop the
saved env value and restore _scipy_array_api._GLOBAL_CONFIG, then release the
lock; ensure the lock and counter are module-level symbols so reentrancy across
threads is handled correctly.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
python/cuml/cuml_accel_tests/test_grid_search.py (2)

298-327: Consider adding a test for sparse input fallback.

The implementation handles sparse inputs by falling back to CPU (line 108-110 in model_selection.py), but there's no test validating this path. A test with scipy.sparse input would improve coverage.

🧪 Suggested test
def test_grid_search_sparse_input_skips_optimization(patch_methods):
    """Sparse input skips GPU optimization."""
    import scipy.sparse as sp
    
    patch_methods(Ridge, "fit")
    
    X, y = make_regression(n_samples=100, n_features=5, random_state=42)
    X_sparse = sp.csr_matrix(X)
    
    gs = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0]}, cv=3)
    gs.fit(X_sparse, y)
    
    # Should receive sparse input, not cupy
    assert sp.issparse(Ridge.fit.args[0]), (
        "Expected sparse (optimization should be skipped)"
    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/cuml/cuml_accel_tests/test_grid_search.py` around lines 298 - 327, Add
a new unit test that verifies sparse inputs trigger the CPU fallback path:
create sparse input via scipy.sparse.csr_matrix from data produced by
make_regression, patch Ridge.fit (or use the existing patch_methods helper) to
capture the first positional argument passed into Ridge.fit, run
GridSearchCV(Ridge(), {"alpha":[...]}, cv=3).fit(X_sparse, y) and assert that
the captured training input is a scipy sparse matrix (use sp.issparse) so the
code path in model_selection.py that falls back for sparse inputs is exercised.
Ensure the test name and docstring mirror other tests (e.g.,
test_grid_search_sparse_input_skips_optimization) and import scipy.sparse as sp
inside the test.

39-52: Consider extracting MockMethod to a shared test utilities module.

The XXX comment notes this pattern could be shared. Similar method-patching helpers may exist elsewhere in cuml_accel_tests. Consolidating would reduce duplication.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/cuml/cuml_accel_tests/test_grid_search.py` around lines 39 - 52,
Extract the MockMethod class into a shared test utilities module (e.g., create
or use an existing test_utils) and replace the inline definition in
test_grid_search.py with an import; update references to the class (MockMethod)
and ensure the __get__/__call__ behavior is preserved and tests import types if
needed from the shared module; run tests to confirm no behavior changes and
remove the XXX comment from test_grid_search.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/source/cuml-accel/limitations.rst`:
- Around line 345-351: Update the limitations text to mention the
estimator-proxy requirement: add a bullet stating that GridSearchCV will fall
back to CPU if the estimator (or any Pipeline step) is not an accelerated proxy
estimator (the condition checked by _contains_proxy in model_selection.py), so
only accelerated proxy estimators (or all Pipeline steps being proxies) enable
acceleration.

---

Nitpick comments:
In `@python/cuml/cuml_accel_tests/test_grid_search.py`:
- Around line 298-327: Add a new unit test that verifies sparse inputs trigger
the CPU fallback path: create sparse input via scipy.sparse.csr_matrix from data
produced by make_regression, patch Ridge.fit (or use the existing patch_methods
helper) to capture the first positional argument passed into Ridge.fit, run
GridSearchCV(Ridge(), {"alpha":[...]}, cv=3).fit(X_sparse, y) and assert that
the captured training input is a scipy sparse matrix (use sp.issparse) so the
code path in model_selection.py that falls back for sparse inputs is exercised.
Ensure the test name and docstring mirror other tests (e.g.,
test_grid_search_sparse_input_skips_optimization) and import scipy.sparse as sp
inside the test.
- Around line 39-52: Extract the MockMethod class into a shared test utilities
module (e.g., create or use an existing test_utils) and replace the inline
definition in test_grid_search.py with an import; update references to the class
(MockMethod) and ensure the __get__/__call__ behavior is preserved and tests
import types if needed from the shared module; run tests to confirm no behavior
changes and remove the XXX comment from test_grid_search.py.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 53ce39db-6d8e-49e4-97a1-a679dfddf573

📥 Commits

Reviewing files that changed from the base of the PR and between d1f8465 and 6dd2628.

📒 Files selected for processing (3)
  • docs/source/cuml-accel/limitations.rst
  • python/cuml/cuml/accel/_patches/sklearn/model_selection.py
  • python/cuml/cuml_accel_tests/test_grid_search.py

Member

@jcrist jcrist left a comment


Quick first pass through.

os.environ["SCIPY_ARRAY_API"] = "1"

old_cached = _scipy_array_api._GLOBAL_CONFIG["SCIPY_ARRAY_API"]
_scipy_array_api._GLOBAL_CONFIG["SCIPY_ARRAY_API"] = "1"
Member


IIUC we actually don't need scipy's array api, we just need the sklearn array-api check to pass. How do you feel about not mucking in the scipy internals and instead just temporarily setting the SCIPY_ARRAY_API environment variable?

Member Author


I think we should set the env variable and make sure the cached value in scipy matches (for the case where scipy had already been imported). Otherwise we have to remember to update this context manager at a later stage when we encounter something that does use scipy (and the errors might be weird).

Member


How about enabling the array-api once when activating cuml.accel, rather than contextually? Could do it for:

  • Scipy alone. This seems minimal, and the scipy setting is global (rather than thread-local), so doing it contextually at all feels a bit weird.
  • Both scipy and sklearn. I'm not sure what fallouts might be from enabling it in sklearn. The sklearn version is complicated by their setting being a threadlocal and not global to the process, so maybe we'd always want to leave that contextual.

I'd vote for enabling it globally once for scipy (including the mucking you do here) when cuml.accel.install() is called.

Member Author


Set Scipy env variable in cuml.accel install, enable array API contextually.

Member Author

@betatim betatim Mar 19, 2026


If we enable it globally then we change which parametrisations of scikit-learn tests run. For example we get test failures because PCA returns a numpy array when it should return cupy. I think this is because the estimators that we wrap are configured to always return numpy, no matter what the input array's type is. The array API tests that get enabled check that the input and output array type match. Basically more things get tested with array API input and some of those fail.

Instead of xfailing these tests I think we should consider returning a cupy array if array API support is enabled in scikit-learn. But that is a discussion bigger than this PR -> I'd either address this in a new PR or leave it as is.

@jcrist jcrist changed the title from "Add support for GirdSearchCV to cuml.accel" to "Add support for GridSearchCV to cuml.accel" Mar 10, 2026
@csadorf csadorf linked an issue Mar 10, 2026 that may be closed by this pull request
@betatim betatim changed the base branch from main to release/26.04 March 19, 2026 09:40

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
python/cuml/cuml/accel/_patches/sklearn/model_selection.py (1)

35-49: ⚠️ Potential issue | 🟠 Major

Serialize the SciPy array-API toggle.

Lines 35-49 mutate SCIPY_ARRAY_API and SciPy’s cached array-API config as process-global state. Two overlapping accelerated GridSearchCV.fit() calls can interleave these enter/exit paths and restore the old value while another fit is still running. Please guard this with a module-level lock/refcount instead of per-call save/restore.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/cuml/cuml/accel/_patches/sklearn/model_selection.py` around lines 35 -
49, The current context manager directly mutates os.environ["SCIPY_ARRAY_API"]
and _scipy_array_api._GLOBAL_CONFIG, which races when multiple accelerated
GridSearchCV.fit() calls run concurrently; replace the per-call save/restore
with a module-level serialization: add a module-level threading.Lock (or RLock)
and an integer refcount (e.g., _scipy_array_api_enter_count) and on
context-manager enter acquire the lock, increment the refcount and set the
env/config only when transitioning from 0->1, and on exit decrement the refcount
and restore the saved old_env and old_cached only when transitioning to 0 before
releasing the lock; keep the existing try/finally structure and reference the
same symbols os.environ["SCIPY_ARRAY_API"] and _scipy_array_api._GLOBAL_CONFIG
so concurrent GridSearchCV.fit() calls are serialized and the global config is
only restored once.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/cuml/cuml/accel/_patches/sklearn/model_selection.py`:
- Around line 158-163: The restoration logic after the using_output_type("cupy")
block only converts attributes back to host for output types (None, "numpy");
update the check on GlobalSettings().output_type within the patch that handles
attributes best_score_ and best_index_ to also treat the explicit "input" mode
as host output. In practice, modify the condition that currently tests
GlobalSettings().output_type in (None, "numpy") to include "input" so that
getattr(self, attr, None) values are passed through ensure_host(...) for
("best_score_", "best_index_") and do not leak CuPy scalars to callers of the
patched estimators.
- Around line 134-148: The params dict conversion only moves numpy ndarrays to
GPU, leaving list/pandas/other array-like sample-wise params on host which will
be sliced with device indices and fail; update the params normalization (the
comprehension that builds params) so that for every key != "groups" you convert
any array-like non-cupy value to CuPy (e.g., use cp.asarray for values that are
not None and not already cp.ndarray, detecting array-likeness via hasattr(v,
"__array__") or isinstance(v, (list, tuple, np.ndarray, pd.Series)) ),
preserving None and leaving "groups" on host; alternatively, if you prefer
strictness, detect any non-cupy array-like non-groups param and raise/bail out
early — change the block that creates X_gpu, y_gpu and the params dict to
perform this normalization.

---

Duplicate comments:
In `@python/cuml/cuml/accel/_patches/sklearn/model_selection.py`:
- Around line 35-49: The current context manager directly mutates
os.environ["SCIPY_ARRAY_API"] and _scipy_array_api._GLOBAL_CONFIG, which races
when multiple accelerated GridSearchCV.fit() calls run concurrently; replace the
per-call save/restore with a module-level serialization: add a module-level
threading.Lock (or RLock) and an integer refcount (e.g.,
_scipy_array_api_enter_count) and on context-manager enter acquire the lock,
increment the refcount and set the env/config only when transitioning from 0->1,
and on exit decrement the refcount and restore the saved old_env and old_cached
only when transitioning to 0 before releasing the lock; keep the existing
try/finally structure and reference the same symbols
os.environ["SCIPY_ARRAY_API"] and _scipy_array_api._GLOBAL_CONFIG so concurrent
GridSearchCV.fit() calls are serialized and the global config is only restored
once.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0363691a-0c22-441f-9f55-dd78ee41bef7

📥 Commits

Reviewing files that changed from the base of the PR and between c6469e4 and 8c4371f.

📒 Files selected for processing (1)
  • python/cuml/cuml/accel/_patches/sklearn/model_selection.py

@betatim
Member Author

betatim commented Mar 23, 2026

Benchmark if it is actually faster to do this, or if it is faster to do the random subset on the host and then move that data (even if we do that repeatedly).

Investigate detecting if we are in a joblib.parallel_backend context that sets n_jobs to a value that is not 1.

@csadorf csadorf force-pushed the cuml-accel-gridsearchcv branch from 1d63725 to 9ec3a05 on April 3, 2026 21:26
@csadorf csadorf requested review from a team as code owners April 3, 2026 21:26
@csadorf csadorf requested review from aamijar, jcrist and msarahan April 3, 2026 21:26
@github-actions github-actions bot added the conda (conda issue) and CUDA/C++ labels Apr 3, 2026
@csadorf csadorf changed the base branch from release/26.04 to main April 3, 2026 21:27
@csadorf csadorf removed request for a team and msarahan April 3, 2026 21:27

Labels

conda (conda issue), CUDA/C++, Cython / Python (Cython or Python issue), improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add cuml.accel support for GridSearchCV

4 participants