@@ -428,105 +428,46 @@ namespace matx {
     // Allocate temporary storage for spline coefficients
     if (method_ == InterpMethod::SPLINE) {
       static_assert(is_cuda_executor_v<Executor>, "cubic spline interpolation only supports the CUDA executor currently");
+      cudaStream_t stream = ex.getStream();
 
-      index_t _batch_count = 1;
+      index_t batch_count = 1;
       for (int i = 0; i < v_.Rank() - 1; i++) {
-        _batch_count *= v_.Size(i);
-      }
-      index_t _n = v_.Size(v_.Rank() - 1);
-      if (_batch_count > std::numeric_limits<int>::max() || _n > std::numeric_limits<int>::max()) {
-        const std::string err_msg = "Spline interpolation is not supported for tensors with more than 2^" + std::to_string(std::numeric_limits<int>::digits) + " items";
-        MATX_THROW(matxInvalidSize, err_msg.c_str());
+        batch_count *= v_.Size(i);
       }
-      int batch_count = static_cast<int>(_batch_count);
-      int n = static_cast<int>(_n); // number of sample points
+      index_t n = v_.Size(v_.Rank() - 1);
 
+
       cuda::std::array m_shape = v_.Shape();
       detail::AllocateTempTensor(m_, std::forward<Executor>(ex), m_shape, &ptr_m_);
 
-      detail::tensor_impl_t<value_type, OpV::Rank()> d_tensor, dl_tensor, du_tensor; // Derivatives at sample points (spline only)
-      value_type *ptr_dl_ = nullptr;
-      value_type *ptr_d_ = nullptr;
-      value_type *ptr_du_ = nullptr;
+      // Allocate temporary storage for tridiagonal system
+      // use a single buffer for all three diagonals so that we can use the DIA format
+      value_type *ptr_tridiag_ = nullptr;
+      matxAlloc((void **)&ptr_tridiag_, 3 * batch_count * n * sizeof(value_type), MATX_ASYNC_DEVICE_MEMORY, stream);
+      value_type *ptr_dl_ = ptr_tridiag_;
+      value_type *ptr_d_ = ptr_tridiag_ + batch_count * n;
+      value_type *ptr_du_ = ptr_tridiag_ + batch_count * n * 2;
 
-      detail::AllocateTempTensor(dl_tensor, std::forward<Executor>(ex), m_shape, &ptr_dl_);
-      detail::AllocateTempTensor(d_tensor, std::forward<Executor>(ex), m_shape, &ptr_d_);
-      detail::AllocateTempTensor(du_tensor, std::forward<Executor>(ex), m_shape, &ptr_du_);
+      detail::tensor_impl_t<value_type, OpV::Rank()> dl_tensor, d_tensor, du_tensor; // Sub-, main, and super-diagonal views over the shared buffer
+      make_tensor(dl_tensor, ptr_dl_, m_shape);
+      make_tensor(d_tensor, ptr_d_, m_shape);
+      make_tensor(du_tensor, ptr_du_, m_shape);
 
       // Fill tridiagonal system via custom operator
-      InterpSplineTridiagonalFillOp(dl_tensor,d_tensor, du_tensor, m_, x_, v_).run(std::forward<Executor>(ex));
+      InterpSplineTridiagonalFillOp(dl_tensor, d_tensor, du_tensor, m_, x_, v_).run(std::forward<Executor>(ex));
+
+      // Convert to uniform batched DIA format
+      auto val_tensor = make_tensor(ptr_tridiag_, {batch_count * n * 3});
+
+      auto A = experimental::make_tensor_uniform_batched_tri_dia<experimental::DIA_INDEX_I>(val_tensor, {batch_count, n, n});
+
+      auto M = make_tensor(ptr_m_, {batch_count * n});
 
       // Solve tridiagonal system using cuSPARSE
-      cudaStream_t stream = ex.getStream();
-      cusparseHandle_t handle = nullptr;
-      [[maybe_unused]] cusparseStatus_t cusparse_status = cusparseCreate(&handle);
-      MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
-      cusparse_status = cusparseSetStream(handle, stream);
-      MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
-
-      size_t workspace_size = 0;
-      void *workspace = nullptr;
-      if constexpr (std::is_same_v<value_type, float>) {
-        cusparse_status = cusparseSgtsv2StridedBatch_bufferSizeExt(
-          handle,           // cuSPARSE handle
-          n,                // n
-          ptr_dl_,          // sub-diagonal
-          ptr_d_,           // main-diagonal
-          ptr_du_,          // super-diagonal
-          ptr_m_,           // right-hand side and solution
-          batch_count,      // batch_count
-          n,                // batch_stride
-          &workspace_size); // workspace size
-      } else if constexpr (std::is_same_v<value_type, double>) {
-        cusparse_status = cusparseDgtsv2StridedBatch_bufferSizeExt(
-          handle,           // cuSPARSE handle
-          n,                // n
-          ptr_dl_,          // sub-diagonal
-          ptr_d_,           // main-diagonal
-          ptr_du_,          // super-diagonal
-          ptr_m_,           // right-hand side and solution
-          batch_count,      // batch_count
-          n,                // batch_stride
-          &workspace_size); // workspace size
-      }
-      MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
-      [[maybe_unused]] cudaError_t err = cudaMallocAsync(&workspace, workspace_size, stream);
-      MATX_ASSERT(err == cudaSuccess, matxCudaError);
-
-      if constexpr (std::is_same_v<value_type, float>) {
-        cusparse_status = cusparseSgtsv2StridedBatch(
-          handle,       // cuSPARSE handle
-          n,            // Size of the system
-          ptr_dl_,      // Sub-diagonal
-          ptr_d_,       // Main diagonal
-          ptr_du_,      // Super-diagonal
-          ptr_m_,       // Right-hand side and solution
-          batch_count,  // batch_count
-          n,            // batch_stride
-          workspace);   // Workspace buffer
-      } else if constexpr (std::is_same_v<value_type, double>) {
-        cusparse_status = cusparseDgtsv2StridedBatch(
-          handle,       // cuSPARSE handle
-          n,            // Size of the system
-          ptr_dl_,      // Sub-diagonal
-          ptr_d_,       // Main diagonal
-          ptr_du_,      // Super-diagonal
-          ptr_m_,       // Right-hand side and solution
-          batch_count,  // batch_count
-          n,            // batch_stride
-          workspace);   // Workspace buffer
-      }
-      MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
-      // cleanup
-      err = cudaFreeAsync(workspace, stream);
-      MATX_ASSERT(err == cudaSuccess, matxCudaError);
-      cusparse_status = cusparseDestroy(handle);
-      MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
-      matxFree(ptr_d_);
-      matxFree(ptr_dl_);
-      matxFree(ptr_du_);
-    }
+      (M = solve(A, M)).run(std::forward<Executor>(ex));
 
+      matxFree(ptr_tridiag_);
+    }
   }
 
   template <typename ShapeType, typename Executor>
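For reference, the new path packs the sub-, main, and super-diagonals of every batched system into one contiguous buffer, views that buffer as a uniform batched tri-diagonal matrix in DIA format, and routes the solve through the generic sparse solve() instead of hand-managing a cuSPARSE handle, workspace, and gtsv2StridedBatch calls. Below is a minimal, self-contained sketch of the same pattern, assuming the experimental::make_tensor_uniform_batched_tri_dia builder and solve() overload used in this commit; the helper name solve_batched_tridiag and the trivial identity system are illustrative only.

#include "matx.h"

using namespace matx;

// Sketch: solve batch_count independent n x n tridiagonal systems whose three
// diagonals are packed into a single device buffer laid out as [dl | d | du].
void solve_batched_tridiag(index_t batch_count, index_t n, cudaStream_t stream)
{
  cudaExecutor exec{stream};

  // One allocation holds all three diagonals, each batch_count * n long.
  float *ptr_tridiag = nullptr;
  matxAlloc((void **)&ptr_tridiag, 3 * batch_count * n * sizeof(float),
            MATX_ASYNC_DEVICE_MEMORY, stream);

  auto dl = make_tensor(ptr_tridiag,                       {batch_count, n});
  auto d  = make_tensor(ptr_tridiag + batch_count * n,     {batch_count, n});
  auto du = make_tensor(ptr_tridiag + 2 * batch_count * n, {batch_count, n});

  // Right-hand side, flattened to one vector per batch; overwritten with the solution.
  auto rhs = make_tensor<float>({batch_count * n});

  // Trivial system (identity matrix, unit RHS) just so the sketch runs end to end.
  (dl = 0.f).run(exec);
  (d = 1.f).run(exec);
  (du = 0.f).run(exec);
  (rhs = 1.f).run(exec);

  // View the packed diagonals as the DIA value array and build the batched operator.
  auto vals = make_tensor(ptr_tridiag, {3 * batch_count * n});
  auto A = experimental::make_tensor_uniform_batched_tri_dia<experimental::DIA_INDEX_I>(
      vals, {batch_count, n, n});

  // In-place tridiagonal solve, mirroring (M = solve(A, M)) in the diff above.
  (rhs = solve(A, rhs)).run(exec);

  exec.sync();
  matxFree(ptr_tridiag);
}

Keeping the sizes in index_t throughout is what lets the commit drop the old 32-bit limit, i.e. the static_cast<int> conversions and the MATX_THROW size check in the removed code.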