Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions include/matx/transforms/inverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,11 @@ class matxInversePlan_t {
* Inverse of A (if it exists)
*
*/
matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a)
{
static_assert(RANK >= 2);

MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

stream_ = stream;

// Ok to remove since we're just passing a list of RO pointers
//using a_nc = typename std::remove_const<decltype(a)>(a);
Expand All @@ -125,29 +123,26 @@ class matxInversePlan_t {
// here as our batch dims
std::vector<const T1 *> in_pointers;
std::vector<T1 *> out_pointers;
make_tensor(tmp_a_, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
(tmp_a_ = a).run(stream);

if constexpr (RANK == 2) {
in_pointers.push_back(&tmp_a_(0, 0));
in_pointers.push_back(&a(0, 0));
out_pointers.push_back(&a_inv(0, 0));
}
else {
using shape_type = typename TensorTypeA::desc_type::shape_type;
int batch_offset = 2;
std::array<shape_type, TensorTypeA::Rank()> idx{0};
auto a_shape = tmp_a_.Shape();
auto a_shape = a.Shape();
// Get total number of batches
size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<shape_type>());
for (size_t iter = 0; iter < total_iter; iter++) {
auto ip = std::apply([&](auto... param) { return tmp_a_.GetPointer(param...); }, idx);
auto ip = std::apply([&a](auto... param) { return a.GetPointer(param...); }, idx);
auto op = std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, idx);

in_pointers.push_back(ip);
out_pointers.push_back(op);

// Update all but the last 2 indices
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(tmp_a_, idx, batch_offset);
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a, idx, batch_offset);
}
}

Expand Down Expand Up @@ -312,8 +307,6 @@ class matxInversePlan_t {
int *d_info;
T1 **d_A_array;
T1 **d_A_inv_array;
cudaStream_t stream_;
matx::tensor_t<typename TensorTypeA::scalar_type, TensorTypeA::Rank()> tmp_a_;
};

/**
Expand Down Expand Up @@ -374,7 +367,7 @@ void inv_impl(TensorTypeAInv &a_inv, const TensorTypeA &a,
// Get cache or new inverse plan if it doesn't exist
auto ret = detail::inv_cache.Lookup(params);
if (ret == std::nullopt) {
auto tmp = new detail::matxInversePlan_t{a_inv, a, stream};
auto tmp = new detail::matxInversePlan_t{a_inv, a};
detail::inv_cache.Insert(params, static_cast<void *>(tmp));
tmp->Exec(stream);
}
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/resample_poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ inline void resample_poly_impl(OutType &out, const InType &in, const FilterType
}

const index_t up_size = in.Size(RANK-1) * up;
[[maybe_unused]] const index_t outlen = up_size / down + ((up_size % down) ? 1 : 0);
const index_t outlen = up_size / down + ((up_size % down) ? 1 : 0);

MATX_ASSERT_STR(out.Size(RANK-1) == outlen, matxInvalidDim, "resample_poly: output size mismatch");

Expand Down
10 changes: 5 additions & 5 deletions test/00_solver/Inverse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv4x4)

// example-begin inv-test-1
// Perform an inverse on matrix "A" and store the output in "Ainv"
inv_impl(Ainv, A);
(Ainv = inv(A)).run();
// example-end inv-test-1
cudaStreamSynchronize(0);

Expand Down Expand Up @@ -103,7 +103,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv4x4Batched)
this->pb->NumpyToTensorView(A, "A");
this->pb->NumpyToTensorView(Ainv_ref, "A_inv");

inv_impl(Ainv, A, 0);
(Ainv = inv(A)).run();
cudaStreamSynchronize(0);

for (index_t b = 0; b < A.Size(0); b++) {
Expand Down Expand Up @@ -135,7 +135,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv8x8)
this->pb->NumpyToTensorView(A, "A");
this->pb->NumpyToTensorView(Ainv_ref, "A_inv");

inv_impl(Ainv, A, 0);
(Ainv = inv(A)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < A.Size(0); i++) {
Expand Down Expand Up @@ -165,7 +165,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv8x8Batched)
this->pb->NumpyToTensorView(A, "A");
this->pb->NumpyToTensorView(Ainv_ref, "A_inv");

inv_impl(Ainv, A, 0);
(Ainv = inv(A)).run();
cudaStreamSynchronize(0);

for (index_t b = 0; b < A.Size(0); b++) {
Expand Down Expand Up @@ -198,7 +198,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv256x256)
this->pb->NumpyToTensorView(A, "A");
this->pb->NumpyToTensorView(Ainv_ref, "A_inv");

inv_impl(Ainv, A, 0);
(Ainv = inv(A)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < A.Size(0); i++) {
Expand Down