Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 27 additions & 86 deletions include/matx/operators/interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,105 +428,46 @@ namespace matx {
// Allocate temporary storage for spline coefficients
if (method_ == InterpMethod::SPLINE) {
static_assert(is_cuda_executor_v<Executor>, "cubic spline interpolation only supports the CUDA executor currently");
cudaStream_t stream = ex.getStream();

index_t _batch_count = 1;
index_t batch_count = 1;
for (int i = 0; i < v_.Rank() - 1; i++) {
_batch_count *= v_.Size(i);
}
index_t _n = v_.Size(v_.Rank() - 1);
if (_batch_count > std::numeric_limits<int>::max() || _n > std::numeric_limits<int>::max()) {
const std::string err_msg = "Spline interpolation is not supported for tensors with more than 2^" + std::to_string(std::numeric_limits<int>::digits) + " items";
MATX_THROW(matxInvalidSize, err_msg.c_str());
batch_count *= v_.Size(i);
}
int batch_count = static_cast<int>(_batch_count);
int n = static_cast<int>(_n); // number of sample points
index_t n = v_.Size(v_.Rank() - 1);


cuda::std::array m_shape = v_.Shape();
detail::AllocateTempTensor(m_, std::forward<Executor>(ex), m_shape, &ptr_m_);

detail::tensor_impl_t<value_type, OpV::Rank()> d_tensor, dl_tensor, du_tensor; // Derivatives at sample points (spline only)
value_type *ptr_dl_ = nullptr;
value_type *ptr_d_ = nullptr;
value_type *ptr_du_ = nullptr;
// Allocate temporary storage for tridiagonal system
// use a single buffer for all three diagonals so that we can use the DIA format
value_type *ptr_tridiag_ = nullptr;
matxAlloc((void**)&ptr_tridiag_, 3 * batch_count * n * sizeof(value_type), MATX_ASYNC_DEVICE_MEMORY, stream);
value_type *ptr_dl_ = ptr_tridiag_;
value_type *ptr_d_ = ptr_tridiag_ + batch_count * n;
value_type *ptr_du_ = ptr_tridiag_ + batch_count * n * 2;

detail::AllocateTempTensor(dl_tensor, std::forward<Executor>(ex), m_shape, &ptr_dl_);
detail::AllocateTempTensor(d_tensor, std::forward<Executor>(ex), m_shape, &ptr_d_);
detail::AllocateTempTensor(du_tensor, std::forward<Executor>(ex), m_shape, &ptr_du_);
detail::tensor_impl_t<value_type, OpV::Rank()> dl_tensor, d_tensor, du_tensor; // Derivatives at sample points (spline only)
make_tensor(dl_tensor, ptr_dl_, m_shape);
make_tensor(d_tensor, ptr_d_, m_shape);
make_tensor(du_tensor, ptr_du_, m_shape);

// Fill tridiagonal system via custom operator
InterpSplineTridiagonalFillOp(dl_tensor,d_tensor, du_tensor, m_, x_, v_).run(std::forward<Executor>(ex));
InterpSplineTridiagonalFillOp(dl_tensor, d_tensor, du_tensor, m_, x_, v_).run(std::forward<Executor>(ex));

// // Convert to uniform batched dia format
auto val_tensor = make_tensor(ptr_tridiag_, {batch_count * n * 3});

auto A = experimental::make_tensor_uniform_batched_tri_dia<experimental::DIA_INDEX_I>(val_tensor, {batch_count, n, n});

auto M = make_tensor(ptr_m_, {batch_count * n});

// Solve tridiagonal system using cuSPARSE
cudaStream_t stream = ex.getStream();
cusparseHandle_t handle = nullptr;
[[maybe_unused]] cusparseStatus_t cusparse_status = cusparseCreate(&handle);
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
cusparse_status = cusparseSetStream(handle, stream);
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);

size_t workspace_size = 0;
void* workspace = nullptr;
if constexpr (std::is_same_v<value_type, float>) {
cusparse_status = cusparseSgtsv2StridedBatch_bufferSizeExt(
handle, // cuSPARSE handle
n, // n
ptr_dl_, // sub-diagonal
ptr_d_, // main-diagonal
ptr_du_, // super-diagonal
ptr_m_, // right-hand side and solution
batch_count, // batch_count
n, // batch_stride
&workspace_size); // workspace size
} else if constexpr (std::is_same_v<value_type, double>) {
cusparse_status = cusparseDgtsv2StridedBatch_bufferSizeExt(
handle, // cuSPARSE handle
n, // n
ptr_dl_, // sub-diagonal
ptr_d_, // main-diagonal
ptr_du_, // super-diagonal
ptr_m_, // right-hand side and solution
batch_count, // batch_count
n, // batch_stride
&workspace_size); // workspace size
}
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
[[maybe_unused]] cudaError_t err = cudaMallocAsync(&workspace, workspace_size, stream);
MATX_ASSERT(err == cudaSuccess, matxCudaError);

if constexpr (std::is_same_v<value_type, float>) {
cusparse_status = cusparseSgtsv2StridedBatch(
handle, // cuSPARSE handle
n, // Size of the system
ptr_dl_, // Sub-diagonal
ptr_d_, // Main diagonal
ptr_du_, // Super-diagonal
ptr_m_, // Right-hand side and solution
batch_count, // batch_count
n, // batch_stride
workspace); // Workspace buffer
} else if constexpr (std::is_same_v<value_type, double>) {
cusparse_status = cusparseDgtsv2StridedBatch(
handle, // cuSPARSE handle
n, // Size of the system
ptr_dl_, // Sub-diagonal
ptr_d_, // Main diagonal
ptr_du_, // Super-diagonal
ptr_m_, // Right-hand side and solution
batch_count, // batch_count
n, // batch_stride
workspace); // Workspace buffer
}
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
// cleanup
err = cudaFreeAsync(workspace, stream);
MATX_ASSERT(err == cudaSuccess, matxCudaError);
cusparse_status = cusparseDestroy(handle);
MATX_ASSERT(cusparse_status == CUSPARSE_STATUS_SUCCESS, matxCudaError);
matxFree(ptr_d_);
matxFree(ptr_dl_);
matxFree(ptr_du_);
}
(M = solve(A, M)).run(std::forward<Executor>(ex));

matxFree(ptr_tridiag_);
}
}

template <typename ShapeType, typename Executor>
Expand Down