Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions include/matx/core/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,24 @@
#include <optional>
#include <any>
#include <unordered_map>
#include <cuda/atomic>

#include "matx/core/error.h"

namespace matx {
namespace detail {

/**
 * Names of the individual caches. In this view the enum serves as the key type
 * of the map held by matxCache_t and as the `name` argument of Clear() and
 * LookupAndExec().
 *
 * NOTE(review): a CacheId alias is declared right after this enum and is used
 * in the same places -- confirm whether this enum is still needed.
 */
enum class CacheName {
FFT_1D,
FFT_2D,
CHOL,
LU,
QR,
SVD,
EIG,
CUB,
GEMM,
COV,
FILTER,
INV
};
using CacheId = uint64_t;

inline cuda::std::atomic<CacheId> CacheIdCounter{0};

/**
 * Returns the cache identifier associated with CacheType.
 *
 * The identifier is taken from the shared CacheIdCounter the first time this
 * function is called for a given CacheType and is kept in a function-local
 * static, so later calls with the same CacheType return the same value.
 *
 * @tparam CacheType Type used to select the identifier
 * @return Identifier assigned to CacheType
 */
template<typename CacheType>
CacheId GetCacheIdFromType()
{
  // Initialized exactly once per CacheType instantiation; fetch_add hands out
  // the counter's previous value and advances it by one.
  static const CacheId assigned_id = CacheIdCounter.fetch_add(1);
  return assigned_id;
}

/**
* Generic caching object for caching parameters. This class is used for
Expand All @@ -72,7 +69,7 @@ class matxCache_t {
/**
 * Releases every object still held by the cache.
 */
~matxCache_t() {
  // Destroy all outstanding objects in the cache to free memory.
  // A single reset() per entry is sufficient: it destroys the contained
  // object and leaves the std::any empty, so repeating it would do nothing.
  for (auto &[k, v]: cache) {
    v.reset();
  }
}

Expand All @@ -81,22 +78,22 @@ class matxCache_t {
*
*/
template <typename CacheType>
void Clear(const CacheName &name) {
auto el = cache.find(name);
void Clear(const CacheId &id) {
auto el = cache.find(id);
MATX_ASSERT_STR(el != cache.end(), matxInvalidType, "Cache type not found");

std::any_cast<CacheType>(el->second).clear();
}

template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun>
void LookupAndExec(const CacheName &name, const InParams &params, const MakeFun &mfun, const ExecFun &efun) {
void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun, const ExecFun &efun) {
// Create named cache if it doesn't exist
auto el = cache.find(name);
auto el = cache.find(id);
if (el == cache.end()) {
cache[name] = CacheType{};
cache[id] = CacheType{};
}

auto &cval = cache[name];
auto &cval = cache[id];
auto &rmap = std::any_cast<CacheType&>(cval);
auto cache_el = rmap.find(params);
if (cache_el == rmap.end()) {
Expand All @@ -110,7 +107,7 @@ class matxCache_t {
}

private:
std::unordered_map<CacheName, std::any> cache;
std::unordered_map<CacheId, std::any> cache;
};

/**
Expand Down
14 changes: 7 additions & 7 deletions include/matx/transforms/cov.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
static_assert(RANK >= 2);
MATX_ASSERT(c.Size(RANK - 1) == c.Size(RANK - 2), matxInvalidSize);
MATX_ASSERT(a.Size(RANK - 1) == c.Size(RANK - 1), matxInvalidSize);
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

// Ensure batch dimensions are equal
for (int i = 2; i < RANK - 2; i++) {
Expand Down Expand Up @@ -144,7 +144,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
inline void Exec(TensorTypeC &c, const TensorTypeA &a,
cudaStream_t stream)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
// Calculate a matrix of means
matmul_impl(means, onesM, a, stream,
1.0f / static_cast<float>(a.Size(RANK - 2)));
Expand All @@ -167,7 +167,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
// Multiply by itself and scale by N-1 for the final covariance
matmul_impl(c, devsT, devs, stream,
1.0f / static_cast<float>(a.Size(RANK - 2) - 1));
}
}

private:
// Member variables
Expand Down Expand Up @@ -231,21 +231,21 @@ template <typename TensorTypeC, typename TensorTypeA>
void cov_impl(TensorTypeC &c, const TensorTypeA &a,
              cudaStream_t stream = 0)
{
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
  // Get parameters required by these tensors
  auto params = detail::matxCovHandle_t<TensorTypeC, TensorTypeA>::GetCovParams(c, a, stream);

  using cache_val_type = detail::matxCovHandle_t<TensorTypeC, TensorTypeA>;
  detail::GetCache().LookupAndExec<detail::cov_cache_t>(
    // The cache is selected by an id derived from the cache type
    detail::GetCacheIdFromType<detail::cov_cache_t>(),
    params,
    // Factory: builds a new covariance handle for these tensors
    [&]() {
      return std::make_shared<cache_val_type>(c, a);
    },
    // Executor: runs the handle on the requested stream
    [&](std::shared_ptr<cache_val_type> ctype) {
      ctype->Exec(c, a, stream);
    }
  );
}

} // end namespace matx
Loading