Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/matx/core/sparse_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ template <typename VAL, typename CRD, typename POS, typename TF,
typename StorageC = DefaultStorage<CRD>,
typename StorageP = DefaultStorage<POS>,
typename DimDesc = DefaultDescriptor<TF::DIM>>
class sparse_tensor_t : public detail::tensor_impl_t<
VAL, TF::DIM, DimDesc,
detail::SparseTensorData<VAL, CRD, POS, TF::LVL>> {
class sparse_tensor_t
: public detail::tensor_impl_t<
VAL, TF::DIM, DimDesc, detail::SparseTensorData<VAL, CRD, POS, TF>> {
public:
using sparse_tensor = bool;
static constexpr int DIM = TF::DIM;
Expand All @@ -79,7 +79,7 @@ class sparse_tensor_t : public detail::tensor_impl_t<
sparse_tensor_t(const typename DimDesc::shape_type (&shape)[DIM],
StorageV &&vals, StorageC (&&crd)[LVL], StorageP (&&pos)[LVL])
: detail::tensor_impl_t<VAL, DIM, DimDesc,
detail::SparseTensorData<VAL, CRD, POS, LVL>>(
detail::SparseTensorData<VAL, CRD, POS, TF>>(
shape) {
// Initialize primary and secondary storage.
values_ = std::move(vals);
Expand Down
38 changes: 21 additions & 17 deletions include/matx/core/sparse_tensor_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ template <typename Expr, LvlType ltype> class LvlSpec {
//
template <int D, typename... LvlSpecs> class SparseTensorFormat {
public:
using LVLSPECS = std::tuple<LvlSpecs...>;
static constexpr int DIM = D;
static constexpr int LVL = sizeof...(LvlSpecs);

Expand All @@ -199,7 +200,7 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isDnVec() {
if constexpr (LVL == 1) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
return first_type::lvltype == LvlType::Dense &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0;
}
Expand All @@ -208,7 +209,7 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isSpVec() {
if constexpr (LVL == 1) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
return first_type::lvltype == LvlType::Compressed &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0;
}
Expand All @@ -217,8 +218,8 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isCOO() {
if constexpr (LVL == 2) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using second_type = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
using second_type = std::tuple_element_t<1, LVLSPECS>;
return first_type::lvltype == LvlType::CompressedNonUnique &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0 &&
second_type::lvltype == LvlType::Singleton &&
Expand All @@ -229,8 +230,8 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isCSR() {
if constexpr (LVL == 2) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using second_type = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
using second_type = std::tuple_element_t<1, LVLSPECS>;
return first_type::lvltype == LvlType::Dense &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0 &&
second_type::lvltype == LvlType::Compressed &&
Expand All @@ -241,8 +242,8 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isCSC() {
if constexpr (LVL == 2) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using second_type = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
using second_type = std::tuple_element_t<1, LVLSPECS>;
return first_type::lvltype == LvlType::Dense &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 1 &&
second_type::lvltype == LvlType::Compressed &&
Expand All @@ -252,12 +253,13 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
}

template <typename CRD>
static CRD *dim2lvl(const CRD *dims, CRD *lvls, bool asSize) {
static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ *
dim2lvl(const CRD *dims, CRD *lvls, bool asSize) {
// Lambda for dim2lvl translation.
auto loop_fun = [&dims, &lvls, &asSize](auto ic) {
constexpr int idx = decltype(ic)::value;
if constexpr (LVL >= (idx + 1)) {
using ftype = std::tuple_element_t<idx, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<idx, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
lvls[idx] = dims[ftype::expr::di];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
Expand All @@ -278,12 +280,14 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
return lvls;
}

template <typename CRD> static CRD *lvl2dim(const CRD *lvls, CRD *dims) {
template <typename CRD>
static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ *
lvl2dim(const CRD *lvls, CRD *dims) {
// Lambda for lvl2dim translation.
auto loop_fun = [&lvls, &dims](auto ic) {
constexpr int idx = decltype(ic)::value;
if constexpr (LVL >= (idx + 1)) {
using ftype = std::tuple_element_t<idx, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<idx, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
dims[ftype::expr::di] = lvls[idx];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
Expand Down Expand Up @@ -314,35 +318,35 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
// Assumes LVL <= 5.
static_assert(LVL <= 5);
if constexpr (LVL > 1) {
using ftype = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<0, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL != 1) {
std::cout << ",";
}
}
if constexpr (LVL >= 2) {
using ftype = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<1, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 2) {
std::cout << ",";
}
}
if constexpr (LVL >= 3) {
using ftype = std::tuple_element_t<2, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<2, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 3) {
std::cout << ",";
}
}
if constexpr (LVL >= 4) {
using ftype = std::tuple_element_t<3, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<3, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 4) {
std::cout << ",";
}
}
if constexpr (LVL >= 5) {
using ftype = std::tuple_element_t<4, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<4, LVLSPECS>;
std::cout << " " << ftype::toString();
}
std::cout << " )" << std::endl;
Expand Down
5 changes: 3 additions & 2 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ struct DenseTensorData {
T *ldata_;
};

template <typename T, typename CRD, typename POS, int L>
template <typename T, typename CRD, typename POS, typename TF>
struct SparseTensorData {
using sparse_data = bool;
using crd_type = CRD;
using pos_type = POS;
static constexpr int LVL = L;
using Format = TF;
static constexpr int LVL = TF::LVL;
T *ldata_;
CRD *crd_[LVL];
POS *pos_[LVL];
Expand Down