@aartbik
Collaborator

@aartbik aartbik commented Jun 3, 2025

This uses cuSPARSE's legacy API to implement a direct solve for tri-diagonal systems (lower diagonal, main diagonal, upper diagonal). Because the DIA format is more general, a run-time check ensures that only those three diagonals are used.
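
To illustrate the kind of run-time check described here, the following is a minimal host-side sketch, not the code from this PR. It assumes the DIA offsets are available as a plain list of signed integers, with -1, 0, and +1 denoting the lower, main, and upper diagonals, and that they must appear in that order; the helper name `IsTridiagonal` is hypothetical.

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MatX): returns true only when the DIA
// offsets name exactly the lower (-1), main (0), and upper (+1) diagonals,
// in that order. The ordering requirement is an assumption of this sketch.
bool IsTridiagonal(const std::vector<std::int64_t>& offsets) {
  static constexpr std::array<std::int64_t, 3> kExpected = {-1, 0, 1};
  return offsets.size() == kExpected.size() &&
         std::equal(offsets.begin(), offsets.end(), kExpected.begin());
}
```

A solver built on such a check would reject any more general DIA tensor (for example one with offsets -2, 0, 1) before handing the three diagonals to the tri-diagonal routine.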

@aartbik aartbik requested review from cliffburdick and Copilot June 3, 2025 19:32
Contributor

Copilot AI left a comment

Pull Request Overview

Adds direct tri-diagonal solve support using cuSPARSE’s legacy API with runtime checks for three diagonals in DIA format and updates test and example code.

  • Introduce solve_cusparse.h implementing SolveTridiagonalSystem and hooking it into sparse_dia_solve_impl
  • Update SolveOp to dispatch to the new DIA solver for I-indexed DIA tensors
  • Add tests (SolveDIAI) and an example for tri-diagonal solves

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| test/00_sparse/Dia.cu | Renamed test fixtures and added SolveDIAI tests |
| include/matx/transforms/solve/solve_cusparse.h | New file implementing cuSPARSE-based tridiagonal solve |
| include/matx/operators/solve.h | Include solver header and dispatch solve for DIA |
| examples/sparse_tensor.cu | Example usage demonstrating direct tri-diagonal solve |

Comments suppressed due to low confidence (4)

test/00_sparse/Dia.cu:118

  • [nitpick] The test suite registration uses MatXFloatNonHalfTypesCUDAExec, which doesn’t match the suite used elsewhere. Ensure this type list is correct or replace it with MatXFloatNonComplexHalfTypesCUDAExec for consistency.
TYPED_TEST_SUITE(DiaSolveSparseTestsAll, MatXFloatNonHalfTypesCUDAExec);

include/matx/operators/solve.h:95

  • The dispatch allows both DIAI and DIAJ but the implementation only supports I-index DIA. Restrict this to isDIAI() to prevent runtime errors when using DIAJ.
if constexpr (OpA::Format::isDIAI() || OpA::Format::isDIAJ()) {

include/matx/transforms/solve/solve_cusparse.h:159

  • [nitpick] The message 'Tridiagonal solve overwrites rhs' is ambiguous. Consider clarifying it to 'In-place RHS required for tridiagonal solve' to better convey the constraint.
MATX_THROW(matxNotSupported, "Tridiagonal solve overwrites rhs");

test/00_sparse/Dia.cu:184

  • The new SolveDIAI test covers real types but omits complex specializations. Add tests for complex value types to ensure full coverage of supported data types.
TYPED_TEST(DiaSolveSparseTestsAll, SolveDIAI) {

@cliffburdick
Collaborator

/build

@cliffburdick cliffburdick merged commit 3a6ad86 into NVIDIA:main Jun 3, 2025
1 check passed
@cliffburdick cliffburdick deleted the bik branch June 3, 2025 21:58