Skip to content

Commit fa9e872

Browse files
authored
use batched sparse solve for interp (#1016)
* use batched sparse solve for interp
* use tridiagonal constructor
1 parent df8c5b2 commit fa9e872

File tree

1 file changed

+27
-86
lines changed

1 file changed

+27
-86
lines changed

include/matx/operators/interp.h

Lines changed: 27 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -428,105 +428,46 @@ namespace matx {
428428
// Allocate temporary storage for spline coefficients
429429
if (method_ == InterpMethod::SPLINE) {
430430
static_assert(is_cuda_executor_v<Executor>, "cubic spline interpolation only supports the CUDA executor currently");
431+
cudaStream_t stream = ex.getStream();
431432

432-
index_t _batch_count = 1;
433+
index_t batch_count = 1;
433434
for (int i = 0; i < v_.Rank() - 1; i++) {
434-
_batch_count *= v_.Size(i);
435-
}
436-
index_t _n = v_.Size(v_.Rank() - 1);
437-
if (_batch_count > std::numeric_limits<int>::max() || _n > std::numeric_limits<int>::max()) {
438-
const std::string err_msg = "Spline interpolation is not supported for tensors with more than 2^" + std::to_string(std::numeric_limits<int>::digits) + " items";
439-
MATX_THROW(matxInvalidSize, err_msg.c_str());
435+
batch_count *= v_.Size(i);
440436
}
441-
int batch_count = static_cast<int>(_batch_count);
442-
int n = static_cast<int>(_n); // number of sample points
437+
index_t n = v_.Size(v_.Rank() - 1);
443438

439+
444440
cuda::std::array m_shape = v_.Shape();
445441
detail::AllocateTempTensor(m_, std::forward<Executor>(ex), m_shape, &ptr_m_);
446442

447-
detail::tensor_impl_t<value_type, OpV::Rank()> d_tensor, dl_tensor, du_tensor; // Derivatives at sample points (spline only)
448-
value_type *ptr_dl_ = nullptr;
449-
value_type *ptr_d_ = nullptr;
450-
value_type *ptr_du_ = nullptr;
443+
// Allocate temporary storage for tridiagonal system
444+
// use a single buffer for all three diagonals so that we can use the DIA format
445+
value_type *ptr_tridiag_ = nullptr;
446+
matxAlloc((void**)&ptr_tridiag_, 3 * batch_count * n * sizeof(value_type), MATX_ASYNC_DEVICE_MEMORY, stream);
447+
value_type *ptr_dl_ = ptr_tridiag_;
448+
value_type *ptr_d_ = ptr_tridiag_ + batch_count * n;
449+
value_type *ptr_du_ = ptr_tridiag_ + batch_count * n * 2;
451450

452-
detail::AllocateTempTensor(dl_tensor, std::forward<Executor>(ex), m_shape, &ptr_dl_);
453-
detail::AllocateTempTensor(d_tensor, std::forward<Executor>(ex), m_shape, &ptr_d_);
454-
detail::AllocateTempTensor(du_tensor, std::forward<Executor>(ex), m_shape, &ptr_du_);
451+
detail::tensor_impl_t<value_type, OpV::Rank()> dl_tensor, d_tensor, du_tensor; // Derivatives at sample points (spline only)
452+
make_tensor(dl_tensor, ptr_dl_, m_shape);
453+
make_tensor(d_tensor, ptr_d_, m_shape);
454+
make_tensor(du_tensor, ptr_du_, m_shape);
455455

456456
// Fill tridiagonal system via custom operator
457-
InterpSplineTridiagonalFillOp(dl_tensor,d_tensor, du_tensor, m_, x_, v_).run(std::forward<Executor>(ex));
457+
InterpSplineTridiagonalFillOp(dl_tensor, d_tensor, du_tensor, m_, x_, v_).run(std::forward<Executor>(ex));
458+
459+
// // Convert to uniform batched dia format
460+
auto val_tensor = make_tensor(ptr_tridiag_, {batch_count * n * 3});
461+
462+
auto A = experimental::make_tensor_uniform_batched_tri_dia<experimental::DIA_INDEX_I>(val_tensor, {batch_count, n, n});
463+
464+
auto M = make_tensor(ptr_m_, {batch_count * n});
458465

459466
// Solve tridiagonal system using cuSPARSE
460-
cudaStream_t stream = ex.getStream();
461-
cusparseHandle_t handle = nullptr;
462-
[[maybe_unused]] cusparseStatus_t cusparse_status = cusparseCreate(&handle);
463-
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
464-
cusparse_status = cusparseSetStream(handle, stream);
465-
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
466-
467-
size_t workspace_size = 0;
468-
void* workspace = nullptr;
469-
if constexpr (std::is_same_v<value_type, float>) {
470-
cusparse_status = cusparseSgtsv2StridedBatch_bufferSizeExt(
471-
handle, // cuSPARSE handle
472-
n, // n
473-
ptr_dl_, // sub-diagonal
474-
ptr_d_, // main-diagonal
475-
ptr_du_, // super-diagonal
476-
ptr_m_, // right-hand side and solution
477-
batch_count, // batch_count
478-
n, // batch_stride
479-
&workspace_size); // workspace size
480-
} else if constexpr (std::is_same_v<value_type, double>) {
481-
cusparse_status = cusparseDgtsv2StridedBatch_bufferSizeExt(
482-
handle, // cuSPARSE handle
483-
n, // n
484-
ptr_dl_, // sub-diagonal
485-
ptr_d_, // main-diagonal
486-
ptr_du_, // super-diagonal
487-
ptr_m_, // right-hand side and solution
488-
batch_count, // batch_count
489-
n, // batch_stride
490-
&workspace_size); // workspace size
491-
}
492-
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
493-
[[maybe_unused]] cudaError_t err = cudaMallocAsync(&workspace, workspace_size, stream);
494-
MATX_ASSERT(err == cudaSuccess, matxCudaError);
495-
496-
if constexpr (std::is_same_v<value_type, float>) {
497-
cusparse_status = cusparseSgtsv2StridedBatch(
498-
handle, // cuSPARSE handle
499-
n, // Size of the system
500-
ptr_dl_, // Sub-diagonal
501-
ptr_d_, // Main diagonal
502-
ptr_du_, // Super-diagonal
503-
ptr_m_, // Right-hand side and solution
504-
batch_count, // batch_count
505-
n, // batch_stride
506-
workspace); // Workspace buffer
507-
} else if constexpr (std::is_same_v<value_type, double>) {
508-
cusparse_status = cusparseDgtsv2StridedBatch(
509-
handle, // cuSPARSE handle
510-
n, // Size of the system
511-
ptr_dl_, // Sub-diagonal
512-
ptr_d_, // Main diagonal
513-
ptr_du_, // Super-diagonal
514-
ptr_m_, // Right-hand side and solution
515-
batch_count, // batch_count
516-
n, // batch_stride
517-
workspace); // Workspace buffer
518-
}
519-
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
520-
// cleanup
521-
err = cudaFreeAsync(workspace, stream);
522-
MATX_ASSERT(err == cudaSuccess, matxCudaError);
523-
cusparse_status = cusparseDestroy(handle);
524-
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
525-
matxFree(ptr_d_);
526-
matxFree(ptr_dl_);
527-
matxFree(ptr_du_);
528-
}
467+
(M = solve(A, M)).run(std::forward<Executor>(ex));
529468

469+
matxFree(ptr_tridiag_);
470+
}
530471
}
531472

532473
template <typename ShapeType, typename Executor>

0 commit comments

Comments
 (0)