include/matx/kernels/conv.cuh (4 changes: 2 additions & 2 deletions)

@@ -73,7 +73,7 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
   // load filter
   for (uint32_t idx = threadIdx.x; idx < filter_len; idx += THREADS) {
     bdims[Rank - 1] = idx;
-    detail::mapply([&](auto &&...args) {
+    detail::mapply([&, d_filter](auto &&...args) {
      s_filter[idx] = d_filter.operator()(args...);
    }, bdims);
  }
@@ -103,7 +103,7 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,

    if( gidx >= 0 && gidx < signal_len) {
      bdims[Rank - 1] = gidx;
-      detail::mapply([&](auto &&...args) {
+      detail::mapply([&val, d_in](auto &&...args) {
        val = d_in.operator()(args...);
      }, bdims);
    }
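The only functional change in this file is the lambda capture list: the operator objects are now captured by copy instead of through the blanket `[&]`. A plain C++ illustration (not MatX code) of what that capture form means:

```cpp
// Illustration of the capture-list change above: [&, x] copies x into the
// closure while every other captured variable is still taken by reference.
#include <cassert>

int main() {
  int x = 1, y = 1;
  auto f = [&, x](int) { return x + y; };  // x captured by copy, y by reference
  x = 10;
  y = 10;
  assert(f(0) == 11);  // the copied x is still 1; the referenced y sees 10
  return 0;
}
```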
include/matx/operators/cast.h (16 changes: 11 additions & 5 deletions)

@@ -66,14 +66,20 @@ namespace matx
       using matxop = bool;
       using scalar_type = NewType;

-        __MATX_INLINE__ std::string str() const { return as_type_str<NewType>() + "(" + op_.str() + ")"; }
+      __MATX_INLINE__ std::string str() const { return as_type_str<NewType>() + "(" + op_.str() + ")"; }
       __MATX_INLINE__ CastOp(T op) : op_(op){};

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
-          return static_cast<NewType>(op_(indices...));
-        }
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {
+        return static_cast<NewType>(op_(indices...));
+      }
+
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+        return static_cast<NewType>(op_(indices...));
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
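cast.h is the first of several operators in this PR that gain a second, non-const `operator()` so the operator can appear on the left-hand side of an assignment. A generic sketch of that read/write overload pair on a toy view type (illustrative only, not MatX's `CastOp`; all names here are made up):

```cpp
// Toy view with the same overload pair the operators in this PR adopt:
// the const overload computes a value for reads, the non-const overload
// hands back a mutable reference so writes reach the underlying data.
#include <cstddef>
#include <vector>

template <typename T>
struct ShiftView {
  std::vector<T>& data;
  std::size_t shift;

  // read path: free to transform or compute a value
  T operator()(std::size_t i) const { return data[(i + shift) % data.size()]; }

  // write path: must return a real lvalue into the underlying storage
  T& operator()(std::size_t i) { return data[(i + shift) % data.size()]; }
};

int main() {
  std::vector<int> v{0, 1, 2, 3};
  ShiftView<int> view{v, 2};
  view(0) = 42;                 // writes through the view: v[2] becomes 42
  return view(0) == 42 ? 0 : 1;
}
```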
include/matx/operators/clone.h (39 changes: 28 additions & 11 deletions)

@@ -76,22 +76,39 @@ namespace matx
       };

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {

-          // convert variadic type to tuple so we can read/update
-          std::array<index_t, Rank()> sind{indices...};
-          std::array<index_t, T::Rank()> gind;
+        // convert variadic type to tuple so we can read/update
+        std::array<index_t, Rank()> sind{indices...};
+        std::array<index_t, T::Rank()> gind;

-          // gather indices
-          for(int i = 0; i < T::Rank(); i++) {
-            auto idx = dims_[i];
-            gind[i] = sind[idx];
-          }
+        // gather indices
+        for(int i = 0; i < T::Rank(); i++) {
+          auto idx = dims_[i];
+          gind[i] = sind[idx];
+        }

-          return mapply(op_, gind);
+        return mapply(op_, gind);
       }
+
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+
+        // convert variadic type to tuple so we can read/update
+        std::array<index_t, Rank()> sind{indices...};
+        std::array<index_t, T::Rank()> gind;
+
+        // gather indices
+        for(int i = 0; i < T::Rank(); i++) {
+          auto idx = dims_[i];
+          gind[i] = sind[idx];
+        }
+
+        return mapply(op_, gind);
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
         return CRank;
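`mapply` shows up in every hunk below. Judging from how it is called (an operator plus an array or tuple of indices), it behaves like `std::apply` for index packs; a minimal sketch under that assumption (a hypothetical helper, not the MatX implementation):

```cpp
// Sketch of an apply-style helper as assumed from the call sites above:
// expand an array of indices into op(idx[0], idx[1], ..., idx[N-1]).
#include <array>
#include <cstddef>
#include <utility>

template <typename Op, typename Idx, std::size_t N, std::size_t... Is>
auto mapply_impl(Op&& op, const std::array<Idx, N>& idx, std::index_sequence<Is...>) {
  return std::forward<Op>(op)(idx[Is]...);
}

template <typename Op, typename Idx, std::size_t N>
auto mapply_sketch(Op&& op, const std::array<Idx, N>& idx) {
  return mapply_impl(std::forward<Op>(op), idx, std::make_index_sequence<N>{});
}
```

With that reading, both CloneOp overloads simply gather the requested indices into the underlying operator's index space before forwarding the call.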
include/matx/operators/fftshift.h (21 changes: 15 additions & 6 deletions)

@@ -56,12 +56,21 @@ namespace matx
       };

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
-          auto tup = cuda::std::make_tuple(indices...);
-          cuda::std::get<Rank()-1>(tup) = (cuda::std::get<Rank()-1>(tup) + (Size(Rank()-1) + 1) / 2) % Size(Rank()-1);
-          return mapply(op_, tup);
-        }
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {
+        auto tup = cuda::std::make_tuple(indices...);
+        cuda::std::get<Rank()-1>(tup) = (cuda::std::get<Rank()-1>(tup) + (Size(Rank()-1) + 1) / 2) % Size(Rank()-1);
+        return mapply(op_, tup);
+      }
+
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+        auto tup = cuda::std::make_tuple(indices...);
+        cuda::std::get<Rank()-1>(tup) = (cuda::std::get<Rank()-1>(tup) + (Size(Rank()-1) + 1) / 2) % Size(Rank()-1);
+        return mapply(op_, tup);
+      }
+

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
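Both overloads apply the same index rotation along the last dimension. A host-side sketch of that mapping (illustrative only, not MatX code):

```cpp
// out[i] reads in[(i + (N + 1) / 2) % N]: the zero-frequency element moves to
// the middle. For N = 4 the source indices are {2, 3, 0, 1}.
#include <cstddef>
#include <vector>

std::vector<float> fftshift1d_sketch(const std::vector<float>& in) {
  const std::size_t n = in.size();
  std::vector<float> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = in[(i + (n + 1) / 2) % n];
  }
  return out;
}
```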
include/matx/operators/flatten.h (20 changes: 13 additions & 7 deletions)

@@ -53,15 +53,21 @@ namespace matx
       __MATX_INLINE__ std::string str() const { return "flatten(" + op1_.str() + ")"; }

       __MATX_INLINE__ FlattenOp(const T1 &op1) : op1_(op1)
-        {
-          static_assert(T1::Rank() > 1, "flatten has no effect on tensors of rank 0 and 1");
-        }
+      {
+        static_assert(T1::Rank() > 1, "flatten has no effect on tensors of rank 0 and 1");
+      }

       template <typename Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is id0) const
-        {
-          return *RandomOperatorIterator{op1_, id0};
-        }
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is id0) const
+      {
+        return *RandomOperatorIterator{op1_, id0};
+      }
+
+      template <typename Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is id0)
+      {
+        return *RandomOperatorOutputIterator{op1_, id0};
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
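The new write path dereferences a `RandomOperatorOutputIterator` built from the operator and a single linear index. The iterator types come from the diff itself; the sketch below only illustrates, as an assumption about what such an iterator has to do internally, the row-major conversion from one flat index back to per-dimension indices:

```cpp
// Row-major delinearization: turn a flat index into per-dimension indices.
#include <array>
#include <cstddef>
#include <cstdint>

template <std::size_t R>
std::array<std::int64_t, R> delinearize(std::int64_t lin,
                                        const std::array<std::int64_t, R>& sizes) {
  std::array<std::int64_t, R> out{};
  for (int d = static_cast<int>(R) - 1; d >= 0; --d) {
    out[d] = lin % sizes[d];  // position within dimension d
    lin /= sizes[d];
  }
  return out;
}
```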
include/matx/operators/interleaved.h (33 changes: 23 additions & 10 deletions)

@@ -56,21 +56,34 @@ namespace matx

       __MATX_INLINE__ ComplexInterleavedOp(T1 op) : op_(op) {
         static_assert(!is_complex_v<extract_scalar_type_t<T1>>, "Complex interleaved op only works on scalar input types");
-          static_assert(Rank() > 0);
+        static_assert(Rank() > 0);
       };

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
-          auto real = op_(indices...);
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {
+        auto real = op_(indices...);

-          constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2);
-          auto tup = cuda::std::make_tuple(indices...);
-          cuda::std::get<rank_idx>(tup) += op_.Size(rank_idx) / 2;
+        constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2);
+        auto tup = cuda::std::make_tuple(indices...);
+        cuda::std::get<rank_idx>(tup) += op_.Size(rank_idx) / 2;

-          auto imag = mapply(op_, tup);
-          return complex_type{real, imag};
-        }
+        auto imag = mapply(op_, tup);
+        return complex_type{real, imag};
+      }
+
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+        auto real = op_(indices...);
+
+        constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2);
+        auto tup = cuda::std::make_tuple(indices...);
+        cuda::std::get<rank_idx>(tup) += op_.Size(rank_idx) / 2;
+
+        auto imag = mapply(op_, tup);
+        return complex_type{real, imag};
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
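Both overloads pair a sample from the lower half of the chosen dimension (the real part) with the sample half a dimension further on (the imaginary part). A standalone sketch of that planar-to-interleaved layout (illustrative, not MatX code):

```cpp
// The input packs n real values followed by n imaginary values; element i of
// the interleaved output pairs planar[i] with planar[i + n].
#include <complex>
#include <cstddef>
#include <vector>

std::vector<std::complex<float>> planar_to_interleaved(const std::vector<float>& planar) {
  const std::size_t n = planar.size() / 2;
  std::vector<std::complex<float>> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = {planar[i], planar[i + n]};
  }
  return out;
}
```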
include/matx/operators/kronecker.h (34 changes: 24 additions & 10 deletions)

@@ -65,18 +65,32 @@ namespace matx
       }

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
-          auto tup1 = cuda::std::make_tuple(indices...);
-          auto tup2 = cuda::std::make_tuple(indices...);
-          cuda::std::get<Rank() - 2>(tup2) = pp_get<Rank() - 2>(indices...) % op2_.Size(Rank() - 2);
-          cuda::std::get<Rank() - 1>(tup2) = pp_get<Rank() - 1>(indices...) % op2_.Size(Rank() - 1);
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {
+        auto tup1 = cuda::std::make_tuple(indices...);
+        auto tup2 = cuda::std::make_tuple(indices...);
+        cuda::std::get<Rank() - 2>(tup2) = pp_get<Rank() - 2>(indices...) % op2_.Size(Rank() - 2);
+        cuda::std::get<Rank() - 1>(tup2) = pp_get<Rank() - 1>(indices...) % op2_.Size(Rank() - 1);
+
+        cuda::std::get<Rank() - 2>(tup1) = pp_get<Rank() - 2>(indices...) / op2_.Size(Rank() - 2);
+        cuda::std::get<Rank() - 1>(tup1) = pp_get<Rank() - 1>(indices...) / op2_.Size(Rank() - 1);
+
+        return mapply(op2_, tup2) * mapply(op1_, tup1);
+      }

-          cuda::std::get<Rank() - 2>(tup1) = pp_get<Rank() - 2>(indices...) / op2_.Size(Rank() - 2);
-          cuda::std::get<Rank() - 1>(tup1) = pp_get<Rank() - 1>(indices...) / op2_.Size(Rank() - 1);
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+        auto tup1 = cuda::std::make_tuple(indices...);
+        auto tup2 = cuda::std::make_tuple(indices...);
+        cuda::std::get<Rank() - 2>(tup2) = pp_get<Rank() - 2>(indices...) % op2_.Size(Rank() - 2);
+        cuda::std::get<Rank() - 1>(tup2) = pp_get<Rank() - 1>(indices...) % op2_.Size(Rank() - 1);

-          return mapply(op2_, tup2) * mapply(op1_, tup1);
-        }
+        cuda::std::get<Rank() - 2>(tup1) = pp_get<Rank() - 2>(indices...) / op2_.Size(Rank() - 2);
+        cuda::std::get<Rank() - 1>(tup1) = pp_get<Rank() - 1>(indices...) / op2_.Size(Rank() - 1);
+
+        return mapply(op2_, tup2) * mapply(op1_, tup1);
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
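The `/` and `%` pairs implement the standard Kronecker-product index rule: with B of shape p x q, (A ⊗ B)[i][j] = A[i / p][j / q] * B[i % p][j % q]. A small dense sketch of the same arithmetic (illustrative only, not MatX code):

```cpp
// Dense Kronecker product using the same / and % index split as the operator.
#include <cstddef>
#include <vector>

std::vector<std::vector<int>> kron_sketch(const std::vector<std::vector<int>>& a,
                                          const std::vector<std::vector<int>>& b) {
  const std::size_t m = a.size(), n = a[0].size();
  const std::size_t p = b.size(), q = b[0].size();
  std::vector<std::vector<int>> c(m * p, std::vector<int>(n * q));
  for (std::size_t i = 0; i < m * p; ++i) {
    for (std::size_t j = 0; j < n * q; ++j) {
      c[i][j] = a[i / p][j / q] * b[i % p][j % q];
    }
  }
  return c;
}
```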
include/matx/operators/planar.h (33 changes: 23 additions & 10 deletions)

@@ -57,17 +57,30 @@ namespace matx
       };

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
-          constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2);
-          auto tup = cuda::std::make_tuple(indices...);
-          if (cuda::std::get<rank_idx>(tup) >= op_.Size(rank_idx)) {
-            cuda::std::get<rank_idx>(tup) -= op_.Size(rank_idx);
-            return mapply(op_, tup).imag();
-          }
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {
+        constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2);
+        auto tup = cuda::std::make_tuple(indices...);
+        if (cuda::std::get<rank_idx>(tup) >= op_.Size(rank_idx)) {
+          cuda::std::get<rank_idx>(tup) -= op_.Size(rank_idx);
+          return mapply(op_, tup).imag();
+        }

-          return op_(indices...).real();
-        }
+
+        return op_(indices...).real();
+      }
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+        constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2);
+        auto tup = cuda::std::make_tuple(indices...);
+        if (cuda::std::get<rank_idx>(tup) >= op_.Size(rank_idx)) {
+          cuda::std::get<rank_idx>(tup) -= op_.Size(rank_idx);
+          return mapply(op_, tup).imag();
+        }
+
+        return op_(indices...).real();
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
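This is the inverse view of interleaved.h: the chosen dimension is logically doubled, with the lower half reading real parts and the upper half reading imaginary parts. A one-dimensional sketch of that read path (illustrative only, not MatX code):

```cpp
// Indices past the real half wrap back into the complex input and return the
// imaginary component instead of the real one.
#include <complex>
#include <cstddef>
#include <vector>

float planar_read(const std::vector<std::complex<float>>& in, std::size_t i) {
  const std::size_t n = in.size();  // the planar view exposes 2 * n elements
  return (i >= n) ? in[i - n].imag() : in[i].real();
}
```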
include/matx/operators/r2c.h (38 changes: 26 additions & 12 deletions)

@@ -57,18 +57,32 @@ namespace matx
       };

       template <typename... Is>
-        __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
-        {
-          auto tup = cuda::std::make_tuple(indices...);
-
-          // If we're on the upper part of the spectrum, return the conjugate of the first half
-          if (cuda::std::get<Rank()-1>(tup) >= op_.Size(Rank()-1)) {
-            cuda::std::get<Rank()-1>(tup) = orig_size_ - cuda::std::get<Rank()-1>(tup);
-            return cuda::std::conj(mapply(op_, tup));
-          }
-
-          return mapply(op_, tup);
-        }
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
+      {
+        auto tup = cuda::std::make_tuple(indices...);
+
+        // If we're on the upper part of the spectrum, return the conjugate of the first half
+        if (cuda::std::get<Rank()-1>(tup) >= op_.Size(Rank()-1)) {
+          cuda::std::get<Rank()-1>(tup) = orig_size_ - cuda::std::get<Rank()-1>(tup);
+          return cuda::std::conj(mapply(op_, tup));
+        }
+
+        return mapply(op_, tup);
+      }
+
+      template <typename... Is>
+      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+      {
+        auto tup = cuda::std::make_tuple(indices...);
+
+        // If we're on the upper part of the spectrum, return the conjugate of the first half
+        if (cuda::std::get<Rank()-1>(tup) >= op_.Size(Rank()-1)) {
+          cuda::std::get<Rank()-1>(tup) = orig_size_ - cuda::std::get<Rank()-1>(tup);
+          return cuda::std::conj(mapply(op_, tup));
+        }
+
+        return mapply(op_, tup);
+      }

       static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
       {
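Both overloads rebuild the upper half of the spectrum from the stored lower half using Hermitian symmetry: for a length-N real signal, X[k] = conj(X[N - k]) for the bins past N/2. A host-side sketch of that expansion (illustrative only; it assumes `half` holds the N/2 + 1 non-redundant bins):

```cpp
// Expand the non-redundant half spectrum of a real signal to all N bins.
#include <complex>
#include <cstddef>
#include <vector>

std::vector<std::complex<float>>
expand_half_spectrum(const std::vector<std::complex<float>>& half, std::size_t n) {
  std::vector<std::complex<float>> full(n);
  for (std::size_t k = 0; k < n; ++k) {
    full[k] = (k < half.size()) ? half[k] : std::conj(half[n - k]);
  }
  return full;
}
```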