Skip to content

Commit 02f6ed1

Browse files
committed
use tridiagonal constructor
1 parent f40166f commit 02f6ed1

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

include/matx/operators/interp.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ namespace matx {
443443
// Allocate temporary storage for tridiagonal system
444444
// use a single buffer for all three diagonals so that we can use the DIA format
445445
value_type *ptr_tridiag_ = nullptr;
446-
matxAlloc((void**)&ptr_tridiag_, 3 * batch_count * n * sizeof(value_type), MATX_MANAGED_MEMORY, stream);
446+
matxAlloc((void**)&ptr_tridiag_, 3 * batch_count * n * sizeof(value_type), MATX_ASYNC_DEVICE_MEMORY, stream);
447447
value_type *ptr_dl_ = ptr_tridiag_;
448448
value_type *ptr_d_ = ptr_tridiag_ + batch_count * n;
449449
value_type *ptr_du_ = ptr_tridiag_ + batch_count * n * 2;
@@ -458,9 +458,8 @@ namespace matx {
458458

459459
// // Convert to uniform batched dia format
460460
auto val_tensor = make_tensor(ptr_tridiag_, {batch_count * n * 3});
461-
auto offset_tensor = make_tensor<index_t>({3}, MATX_MANAGED_MEMORY, stream);
462-
offset_tensor.SetVals({-1, 0, 1});
463-
auto A = experimental::make_tensor_uniform_batched_dia<experimental::DIA_INDEX_I>(val_tensor, offset_tensor, {batch_count, n, n});
461+
462+
auto A = experimental::make_tensor_uniform_batched_tri_dia<experimental::DIA_INDEX_I>(val_tensor, {batch_count, n, n});
464463

465464
auto M = make_tensor(ptr_m_, {batch_count * n});
466465

0 commit comments

Comments
 (0)