@@ -443,7 +443,7 @@ namespace matx {
443443 // Allocate temporary storage for tridiagonal system
444444 // use a single buffer for all three diagonals so that we can use the DIA format
445445 value_type *ptr_tridiag_ = nullptr ;
446- matxAlloc ((void **)&ptr_tridiag_, 3 * batch_count * n * sizeof (value_type), MATX_MANAGED_MEMORY , stream);
446+ matxAlloc ((void **)&ptr_tridiag_, 3 * batch_count * n * sizeof (value_type), MATX_ASYNC_DEVICE_MEMORY , stream);
447447 value_type *ptr_dl_ = ptr_tridiag_;
448448 value_type *ptr_d_ = ptr_tridiag_ + batch_count * n;
449449 value_type *ptr_du_ = ptr_tridiag_ + batch_count * n * 2 ;
@@ -458,9 +458,8 @@ namespace matx {
458458
459459 // // Convert to uniform batched dia format
460460 auto val_tensor = make_tensor (ptr_tridiag_, {batch_count * n * 3 });
461- auto offset_tensor = make_tensor<index_t >({3 }, MATX_MANAGED_MEMORY, stream);
462- offset_tensor.SetVals ({-1 , 0 , 1 });
463- auto A = experimental::make_tensor_uniform_batched_dia<experimental::DIA_INDEX_I>(val_tensor, offset_tensor, {batch_count, n, n});
461+
462+ auto A = experimental::make_tensor_uniform_batched_tri_dia<experimental::DIA_INDEX_I>(val_tensor, {batch_count, n, n});
464463
465464 auto M = make_tensor (ptr_m_, {batch_count * n});
466465
0 commit comments