Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ option(MATX_EN_VISUALIZATION "Enable visualization support" OFF)
option(MATX_EN_CUTLASS OFF)
option(MATX_EN_CUTENSOR OFF)
option(MATX_EN_FILEIO OFF)
# option(<variable> "<help_text>" [value]): help text comes second, default value last
option(MATX_EN_NVPL "Enable NVIDIA Performance Libraries for optimized ARM CPU support" OFF)
option(MATX_DISABLE_CUB_CACHE "Disable caching for CUB allocations" ON)

set(MATX_EN_PYBIND11 OFF CACHE BOOL "Enable pybind11 support")
Expand Down Expand Up @@ -152,6 +153,15 @@ else()
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)
endif()

if (MATX_EN_NVPL)
message(STATUS "Enabling NVPL library support")
# find_package is currently broken in NVPL. Use proper targets once working
#find_package(nvpl REQUIRED COMPONENTS fft)
#target_link_libraries(matx INTERFACE nvpl::fftw)
# Until then, link the nvpl_fftw library by its plain name
target_link_libraries(matx INTERFACE nvpl_fftw)
# INTERFACE scope: MATX_EN_NVPL=1 is defined for every target that links against matx
target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
endif()

if (MATX_DISABLE_CUB_CACHE)
target_compile_definitions(matx INTERFACE MATX_DISABLE_CUB_CACHE=1)
endif()
Expand Down Expand Up @@ -291,4 +301,3 @@ if (MATX_BUILD_TESTS)
include(cmake/GetGTest.cmake)
add_subdirectory(test)
endif()

4 changes: 2 additions & 2 deletions docs_input/api/dft/fft/fft2d.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Perform a 2D FFT
These functions are currently not supported with host-based executors (CPU)


.. doxygenfunction:: fft2(OpA &&a)
.. doxygenfunction:: fft2(OpA &&a, const int32_t (&axis)[2])
.. doxygenfunction:: fft2(OpA &&a, FFTNorm norm = FFTNorm::BACKWARD)
.. doxygenfunction:: fft2(OpA &&a, const int32_t (&axis)[2], FFTNorm norm = FFTNorm::BACKWARD)

Examples
~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions docs_input/api/dft/fft/ifft2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Perform a 2D inverse FFT
These functions are currently not supported with host-based executors (CPU)


.. doxygenfunction:: ifft2(OpA &&a)
.. doxygenfunction:: ifft2(OpA &&a, const int32_t (&axis)[2])
.. doxygenfunction:: ifft2(OpA &&a, FFTNorm norm = FFTNorm::BACKWARD)
.. doxygenfunction:: ifft2(OpA &&a, const int32_t (&axis)[2], FFTNorm norm = FFTNorm::BACKWARD)

Examples
~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion docs_input/api/manipulation/basic/copy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ since it cannot be chained with other expressions.
Examples
~~~~~~~~

.. literalinclude:: ../../../../include/matx/transforms/fft.h
.. literalinclude:: ../../../../include/matx/transforms/fft/fft_common.h
:language: cpp
:start-after: example-begin copy-test-1
:end-before: example-end copy-test-1
Expand Down
10 changes: 10 additions & 0 deletions docs_input/build.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ Optional Third-party Dependencies
- `cutensor <https://developer.nvidia.com/cutensor>`_ 1.7.0.1+ (Required when using `einsum`)
- `cutensornet <https://docs.nvidia.com/cuda/cuquantum/cutensornet>`_ 23.03.0.20+ (Required when using `einsum`)

Host (CPU) Support
------------------
Host support is provided both by the C++ standard library and NVIDIA's NVPL_ library. Host support is
considered experimental and is still a work in progress. Currently all reduction functions are supported,
but FFTs are the only transforms supported. All host support is limited to a single thread in this release.

To enable NVPL support use the CMake option `-DMATX_EN_NVPL=ON`.

.. _NVPL: https://developer.nvidia.com/nvpl

Build Options
=============
MatX provides 5 primary options for builds, and each can be configured independently:
Expand Down
3 changes: 2 additions & 1 deletion include/matx/core/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ namespace matx
matxLUError,
matxInverseError,
matxSolverError,
matxcuTensorError
matxcuTensorError,
matxInvalidExecutor
};

static constexpr const char *matxErrorString(matxError_t e)
Expand Down
1 change: 1 addition & 0 deletions include/matx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {

if constexpr (N > 0) {
if (end != matxDropDim) {
MATX_ASSERT_STR(end != matxKeepDim, matxInvalidParameter, "matxKeepDim only valid for clone(), not slice()");
if (end == matxEnd) {
n[d] = this->Size(i) - first;
}
Expand Down
12 changes: 6 additions & 6 deletions include/matx/core/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ constexpr bool is_executor_t()


namespace detail {
template<typename T> struct is_device_executor : std::false_type {};
template<> struct is_device_executor<matx::cudaExecutor> : std::true_type {};
template<typename T> struct is_cuda_executor : std::false_type {};
template<> struct is_cuda_executor<matx::cudaExecutor> : std::true_type {};
}

/**
Expand All @@ -282,11 +282,11 @@ template<> struct is_device_executor<matx::cudaExecutor> : std::true_type {};
* @tparam T Type to test
*/
template <typename T>
inline constexpr bool is_device_executor_v = detail::is_device_executor<typename remove_cvref<T>::type>::value;
inline constexpr bool is_cuda_executor_v = detail::is_cuda_executor<typename remove_cvref<T>::type>::value;

namespace detail {
template<typename T> struct is_single_thread_host_executor : std::false_type {};
template<> struct is_single_thread_host_executor<matx::HostExecutor> : std::true_type {};
template<typename T> struct is_host_executor : std::false_type {};
template<> struct is_host_executor<matx::HostExecutor> : std::true_type {};
}

/**
Expand All @@ -295,7 +295,7 @@ template<> struct is_single_thread_host_executor<matx::HostExecutor> : std::true
* @tparam T Type to test
*/
template <typename T>
inline constexpr bool is_single_thread_host_executor_v = detail::is_single_thread_host_executor<remove_cvref_t<T>>::value;
inline constexpr bool is_host_executor_v = detail::is_host_executor<remove_cvref_t<T>>::value;


namespace detail {
Expand Down
2 changes: 1 addition & 1 deletion include/matx/executors/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace matx
/*
* @brief Returns stream associated with executor
*/
auto getStream() { return stream_; }
auto getStream() const { return stream_; }

/**
* Execute an operator on a device
Expand Down
1 change: 1 addition & 0 deletions include/matx/executors/executors.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@

#pragma once

#include "matx/executors/support.h"
#include "matx/executors/device.h"
#include "matx/executors/host.h"
2 changes: 2 additions & 0 deletions include/matx/executors/host.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class HostExecutor {
}
}

/**
 * @brief Returns the thread count reported by this executor's HostExecParams
 */
int GetNumThreads() const { return params_.GetNumThreads(); }

private:
HostExecParams params_;
};
Expand Down
60 changes: 60 additions & 0 deletions include/matx/executors/support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2021, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include "matx/core/type_utils.h"

// Utility functions to determine what support is available per-executor

namespace matx {
namespace detail {

// FFT
#if defined(MATX_EN_NVPL)
#define MATX_EN_CPU_FFT 1
#else
#define MATX_EN_CPU_FFT 0
#endif

template <typename Exec>
constexpr bool CheckFFTSupport() {
if constexpr (is_host_executor_v<Exec>) {
return MATX_EN_CPU_FFT;
}
else {
return true;
}
}

}; // detail
}; // matx
8 changes: 4 additions & 4 deletions include/matx/generators/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ template <typename T, int RANK> class randomTensorView_t {
__MATX_INLINE__ void PreRun([[maybe_unused]] ST &&shape, Executor &&ex)
{
#ifdef __CUDACC__
if constexpr (is_device_executor_v<Executor>) {
if constexpr (is_cuda_executor_v<Executor>) {
if (!init_) {
auto stream = ex.getStream();
matxAlloc((void **)&states_,
Expand All @@ -446,7 +446,7 @@ template <typename T, int RANK> class randomTensorView_t {
device_ = true;
}
}
else if constexpr (is_single_thread_host_executor_v<Executor>) {
else if constexpr (is_host_executor_v<Executor>) {
if (!init_) {
[[maybe_unused]] curandStatus_t ret;

Expand All @@ -468,10 +468,10 @@ template <typename T, int RANK> class randomTensorView_t {
template <typename ST, typename Executor>
__MATX_INLINE__ void PostRun([[maybe_unused]] ST &&shape, [[maybe_unused]] Executor &&ex) noexcept
{
if constexpr (is_device_executor_v<Executor>) {
if constexpr (is_cuda_executor_v<Executor>) {
matxFree(states_);
}
else if constexpr (is_single_thread_host_executor_v<Executor>) {
else if constexpr (is_host_executor_v<Executor>) {
curandDestroyGenerator(gen_);
//matxFree(val);
}
Expand Down
43 changes: 34 additions & 9 deletions include/matx/kernels/conv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,27 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,


template <typename T>
union Uninitialized {
__host__ __device__ constexpr Uninitialized() {}
struct Uninitialized {
__host__ __device__ constexpr Uninitialized() {};
T data;
};

/**
 * Non-owning 2D accessor that views a raw byte buffer as rows of X_LEN elements of T
 *
 * @tparam T Element type the buffer is reinterpreted as
 * @tparam X_LEN Number of elements in one row (the stride applied to the y index)
 */
template <typename T, int X_LEN>
struct ShmBuffer2D {
  /// Wraps the buffer starting at p; no allocation or copy is performed
  __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ ShmBuffer2D(char *p) : ptr(reinterpret_cast<T*>(p)) {}

  /// Read-only access to the element at row y, column x
  __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ const T &operator()(index_t y, index_t x) const noexcept {
    const index_t offset = y * X_LEN + x;
    return ptr[offset];
  }

  /// Mutable access to the element at row y, column x
  __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ T &operator()(index_t y, index_t x) noexcept {
    const index_t offset = y * X_LEN + x;
    return ptr[offset];
  }

  T *ptr;
};

template <typename OutType, typename InType1, typename InType2,
int BLOCK_DIM_X, // blockDim.x
int BLOCK_DIM_Y, // blockDim.y
Expand All @@ -216,8 +232,17 @@ __global__ void Conv2D(OutType d_out, InType1 d_in1, InType2 d_in2,

constexpr int Rank = OutType::Rank();

__shared__ Uninitialized<in1type> s_signal[SIGNAL_CHUNK_Y][SIGNAL_CHUNK_X];
__shared__ Uninitialized<in2type> s_filter[FILTER_SHARED_CHUNK_Y][FILTER_SHARED_CHUNK_X];
// Byte offset of the filter region inside shared_buf: the size of the signal
// region rounded up to the next multiple of sizeof(in2type).
// Integer arithmetic is used on purpose: the previous ceil(a / s * s) form
// multiplied before rounding (so nothing was rounded up), and std::ceil is not
// usable in a constant expression before C++23.
constexpr int type2off = static_cast<int>(
    (sizeof(in1type) * SIGNAL_CHUNK_Y * SIGNAL_CHUNK_X + sizeof(in2type) - 1) / sizeof(in2type) *
    sizeof(in2type));
__shared__ char shared_buf[type2off + sizeof(in2type) * FILTER_SHARED_CHUNK_Y * FILTER_SHARED_CHUNK_X];

// __shared__ Uninitialized<in1type> s_signal[SIGNAL_CHUNK_Y][SIGNAL_CHUNK_X];
// __shared__ Uninitialized<in2type> s_filter[FILTER_SHARED_CHUNK_Y][FILTER_SHARED_CHUNK_X];

// Workaround for ARM compiler bug that will not allow the union type above
ShmBuffer2D<in1type, SIGNAL_CHUNK_X> s_signal{&shared_buf[0]};
ShmBuffer2D<in2type, FILTER_SHARED_CHUNK_X> s_filter{&shared_buf[type2off]};

in2type r_filter[FILTER_REG_CHUNK_Y][FILTER_REG_CHUNK_X];

Expand Down Expand Up @@ -279,7 +304,7 @@ __global__ void Conv2D(OutType d_out, InType1 d_in1, InType2 d_in2,
detail::mapply([&](auto &&...args) { val = d_in2.operator()(args...); }, bdims);
}
// store in shared
s_filter[ii][jj].data = val;
s_filter(ii, jj) = val;
}
}

Expand All @@ -300,7 +325,7 @@ __global__ void Conv2D(OutType d_out, InType1 d_in1, InType2 d_in2,
}

// store in shared
s_signal[ii][jj].data = val;
s_signal(ii,jj) = val;
}
}

Expand All @@ -320,7 +345,7 @@ __global__ void Conv2D(OutType d_out, InType1 d_in1, InType2 d_in2,
for(int ii = 0; ii < FILTER_REG_CHUNK_Y; ii++) {
#pragma unroll
for(int jj = 0; jj < FILTER_REG_CHUNK_X; jj++) {
r_filter[ii][jj] = s_filter[nn+ii][mm+jj].data;
r_filter[ii][jj] = s_filter(nn+ii, mm+jj);
}
}

Expand All @@ -341,7 +366,7 @@ __global__ void Conv2D(OutType d_out, InType1 d_in1, InType2 d_in2,
// load ILPY signal points
#pragma unroll
for(int u = 0; u < ILPY; u++) {
i1[u] = s_signal[nn+n+threadIdx.y*ILPY + u][mm+m+threadIdx.x].data;
i1[u] = s_signal(nn+n+threadIdx.y*ILPY + u, mm+m+threadIdx.x);
}
} else {
// advance/shift signal points in registers
Expand All @@ -351,7 +376,7 @@ __global__ void Conv2D(OutType d_out, InType1 d_in1, InType2 d_in2,
}

// load new signal point at end of the array
i1[ILPY-1] = s_signal[nn+n+threadIdx.y*ILPY + ILPY - 1][mm+m+threadIdx.x].data;
i1[ILPY-1] = s_signal(nn+n+threadIdx.y*ILPY + ILPY - 1,mm+m+threadIdx.x);
}

// inner convolution loop
Expand Down
2 changes: 1 addition & 1 deletion include/matx/operators/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace detail {
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_device_executor_v<Executor>) {
if constexpr (is_cuda_executor_v<Executor>) {
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
}
else {
Expand Down
4 changes: 2 additions & 2 deletions include/matx/operators/ambgfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ namespace matx

template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const {
static_assert(is_device_executor_v<Executor>, "ambgfun() only supports the CUDA executor currently");
static_assert(is_cuda_executor_v<Executor>, "ambgfun() only supports the CUDA executor currently");
static_assert(std::tuple_element_t<0, remove_cvref_t<Out>>::Rank() == 2, "Output tensor of ambgfun must be 2D");
ambgfun_impl(std::get<0>(out), x_, y_, fs_, cut_, cut_val_, ex.getStream());
}
Expand All @@ -121,7 +121,7 @@ namespace matx
y_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_device_executor_v<Executor>) {
if constexpr (is_cuda_executor_v<Executor>) {
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
}

Expand Down
2 changes: 1 addition & 1 deletion include/matx/operators/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace detail {
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_device_executor_v<Executor>) {
if constexpr (is_cuda_executor_v<Executor>) {
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
}
else {
Expand Down
4 changes: 2 additions & 2 deletions include/matx/operators/cgsolve.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ namespace matx

template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const{
static_assert(is_device_executor_v<Executor>, "cgsolve() only supports the CUDA executor currently");
static_assert(is_cuda_executor_v<Executor>, "cgsolve() only supports the CUDA executor currently");
cgsolve_impl(std::get<0>(out), a_, b_, tol_, max_iters_, ex.getStream());
}

Expand All @@ -102,7 +102,7 @@ namespace matx
b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_device_executor_v<Executor>) {
if constexpr (is_cuda_executor_v<Executor>) {
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
}

Expand Down
Loading