diff --git a/CMakeLists.txt b/CMakeLists.txt index ac3863622..dc644e270 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ set(MATX_EN_PYBIND11 OFF CACHE BOOL "Enable pybind11 support") set(cutensor_DIR "" CACHE PATH "Directory where cuTENSOR is installed.") set(cutensornet_DIR "" CACHE PATH "Directory where cuTensorNet is installed.") +set(eigen_DIR "" CACHE PATH "Directory where Eigen is installed") # Enable compile_commands.json set(CMAKE_EXPORT_COMPILE_COMMANDS ON) diff --git a/docs_input/basics/matlabpython.rst b/docs_input/basics/matlabpython.rst index 02082fec7..eeeea4a0d 100644 --- a/docs_input/basics/matlabpython.rst +++ b/docs_input/basics/matlabpython.rst @@ -1,11 +1,20 @@ -MatX For MATLAB/NumPy Users -=========================== +MatX For MATLAB/NumPy/Eigen Users +================================= MatX has many features directly inspired by MATLAB and Python (Numpy/Scipy) for translating these high-level languages into C++. The table below aims to give users of either of those tools a translation guide to writing efficient MatX code by learning the syntax mapping between the tools. Most of these conversions can also be found inside either the unit tests or the source code as well. +Due to its popularity in linear algebra applications, examples of common Eigen operations have been added to the below table. An example file is provided at ``examples/eigenExample.cu`` with examples of common operations in Eigen and their MatX equivalent. +If you have Eigen installed on your system, you can build the examples with Eigen by setting the cmake variable ``eigen_DIR=/path/to/eigen/``. + +A few key notes to be aware of when using MatX and Eigen in the same environment: + +1. The below example are only valid for 2D data. Eigen and its API is primarily targeting 2D problems (without the unsupported/tensor library), so there is not a single pattern to follow for porting code with rank > 2 tensors from Eigen to MatX; each user's solution for higher rank data will result in a unique mapping to MatX tensor memory. +2. When copying data between Eigen and MatX structures (most likely Eigen::MatrixXd to MatX tensors) keep in mind that the underlying data structure may or may not be available on the CPU. use the accessor functions () or a cudaMemcpy when applicable. +3. Eigen has column-major storage by default, so ensure you transpose any raw data copies between structures. + Overview -------- @@ -25,47 +34,50 @@ Conversion Table ---------------- .. table:: Conversion Table - :widths: 10 15 15 15 35 10 + :widths: 10 15 15 25 15 35 - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Operation | MATLAB | Python | MatX | Notes | Examples | - +===========================+========================================+================================================+===================================================================+========================================================================================================================+==========+ - | Basic indexing | ``A(1,5)`` | ``A[0,4]`` | ``A(0,4)`` | Retrieves the element in the first row and fifth column | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Tensor addition | ``A + B`` | ``A + B`` | ``A + B`` | Adds two tensors element-wise | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Tensor subtraction | ``A - B`` | ``A - B`` | ``A - B`` | Subtracts two tensors element-wise | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Tensor multiplication | ``A .* B`` | ``A * B`` | ``A * B`` | Multiplies two tensors element-wise | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Tensor division | ``A ./ B`` | ``A / B`` | ``A / B`` | Divides two tensors element-wise | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Tensor slice (contiguous) | ``A(1:4,2:5)`` | ``A[0:5,1:6]`` | ``slice(A, {0,1}, {5,6});`` | Slices 4 elements of the outer dimension starting at 0, | | - | | | | | and 5 elements of the inner dimension, starting at the second element. | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Tensor slice (w/stride) | ``A(1:2:end,2:3:8)`` | ``A[::2,1:9:3]`` | ``slice(A, {0,1}, {matxEnd,9}, {2,3});`` | Slices N elements of the outer dimension starting at the first element and picking every second element until the end. | | - | | | | | In the inner dimension, start at the first element and grab every third item, and stop at the 8th item. | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Cloning a dimension | ``reshape(repmat(A, [4,1]), [4 4 4])`` | ``np.repeat(np.expand_dims(A, axis=0), 5, 0)`` | ``clone<3>(A, {5, matxKeepDim, matxKeepDim})`` | Takes a 4x4 2D tensor and makes it a 5x4x4 3D tensor where every outer dimension replicates the two inner | | - | | | | | inner dimensions | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Slice off a row or column | ``A(5,:)`` | ``A[4,:]`` | ``slice<1>(A, {4, 0}, {matxDropDim, matxEnd})`` | Selects the fifth row and all columns from a 2D tensor, and returns a 1D tensor | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Permute dimensions | ``permute(A, [3 2 1])`` | ``np.einsum('kij->ijk', A)`` | ``permute(A, {2,1,0})`` or ``cutensor::einsum("kij->ijk", A);`` | Permutes the three axes into the opposite order | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Get real values | ``real(A)`` | ``np.real(A)`` | ``A.RealView()`` | Returns only the real values of the complex series | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Matrix multiply (GEMM) | ``A * B`` | ``np.matmul(A, B)`` or ``A @ B`` | ``matmul(A, B)`` | Computes the matrix multiplication of ``A * B`` | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Compute matrix inverse | ``inv(A)`` | ``np.linalg.inv(A)`` | ``inv(A)`` | Computes the inverse of matrix A using LU factorization | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | 1D FFT | ``fft(A)`` | ``np.fft.fft(A)`` | ``fft(A)`` | Computes the 1D fast fourier transfor, (FFT) of rows of A | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | 1D IFFT | ``ifft(A)`` | ``np.fft.ifft(A)`` | ``ifft(A)`` | Computes the 1D inverse fast fourier transfor, (IFFT) of rows of A | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | 2D FFT | ``fft2(A)`` | ``np.fft.fft2(A)`` | ``fft2(A)`` | Computes the 2D fast fourier transfor, (FFT) of matrices in outer 2 dimensions of A | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | 2D IFFT | ``ifft2(A)`` | ``np.fft.ifft2(A)`` | ``ifft2(A)`` | Computes the 2D inverse fast fourier transfor, (IFFT) of matrices in outer 2 dimensions of A | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ - | Covariance | ``cov(A)`` | ``np.cov(A)`` | ``cov(A)`` | Computes the covariance on the rows of matrix A | | - +---------------------------+----------------------------------------+------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+----------+ + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Operation | MATLAB | Python | Eigen | MatX | Notes | + +===========================+========================================+================================================+===================================================================================+===================================================================+========================================================================================================================+ + | Basic indexing | ``A(1,5)`` | ``A[0,4]`` | ``A(0,4)`` | ``A(0,4)`` | Retrieves the element in the first row and fifth column | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Tensor addition | ``A + B`` | ``A + B`` | ``A + B`` | ``A + B`` | Adds two tensors element-wise | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Tensor subtraction | ``A - B`` | ``A - B`` | ``A - B`` | ``A - B`` | Subtracts two tensors element-wise | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Tensor multiplication | ``A .* B`` | ``A * B`` | ``A.cwiseProduct(B)`` | ``A * B`` | Multiplies two tensors element-wise | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Tensor division | ``A ./ B`` | ``A / B`` | ``A.cwiseQuotient(B)`` | ``A / B`` | Divides two tensors element-wise | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Tensor slice (contiguous) | ``A(1:4,2:5)`` | ``A[0:5,1:6]`` | ``A.block(0, 1, 5, 6)`` | ``slice(A, {0,1}, {5,6});`` | Slices 4 elements of the outer dimension starting at 0, | + | | | | | | and 5 elements of the inner dimension, starting at the second element. | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Tensor slice (w/stride) | ``A(1:2:end,2:3:8)`` | ``A[::2,1:9:3]`` | ``Eigen::Map>`` | ``slice(A, {0,1}, {matxEnd,9}, {2,3});`` | Slices N elements of the outer dimension starting at the first element and picking every second element until the end. | + | | | | ``strided(matrix.data() + 0 * matrix.outerStride() + 0, 5, 3,`` | | In the inner dimension, start at the first element and grab every third item, and stop at the 8th item. | + | | | | ``Eigen::Stride(3 * matrix.outerStride(), 2))`` | | | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Cloning a dimension | ``reshape(repmat(A, [4,1]), [4 4 4])`` | ``np.repeat(np.expand_dims(A, axis=0), 5, 0)`` | ``cloneMat = A.replicate(5, 1)`` | ``clone<3>(A, {5, matxKeepDim, matxKeepDim})`` | Takes a 4x4 2D tensor and makes it a 5x4x4 3D tensor where every outer dimension replicates the two inner | + | | | | | | inner dimensions | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Slice off a row or column | ``A(5,:)`` | ``A[4,:]`` | ``Eigen::RowVector3d row = a.row(1)`` | ``slice<1>(A, {4, 0}, {matxDropDim, matxEnd})`` | Selects the fifth row and all columns from a 2D tensor, and returns a 1D tensor | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Permute dimensions | ``permute(A, [3 2 1])`` | ``np.einsum('kij->ijk', A)`` | ``Eigen::PermutationMatrix<3> perm`` | ``permute(A, {2,1,0})`` or ``cutensor::einsum("kij->ijk", A);`` | Permutes the three axes into the opposite order | + | | | | ``perm.indices() << 2, 1, 0`` | | In the inner dimension, start at the first element and grab every third item, and stop at the 8th item. | + | | | | ``Eigen::Matrix3d permutedMatrix = perm * a`` | | In the inner dimension, start at the first element and grab every third item, and stop at the 8th item. | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Get real values | ``real(A)`` | ``np.real(A)`` | ``A.real()`` | ``A.RealView()`` | Returns only the real values of the complex series | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Matrix multiply (GEMM) | ``A * B`` | ``np.matmul(A, B)`` or ``A @ B`` | ``A * B`` | ``matmul(A, B)`` | Computes the matrix multiplication of ``A * B`` | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Compute matrix inverse | ``inv(A)`` | ``np.linalg.inv(A)`` | ``A.inverse()`` | ``inv(A)`` | Computes the inverse of matrix A using LU factorization | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | 1D FFT | ``fft(A)`` | ``np.fft.fft(A)`` | N/A | ``fft(A)`` | Computes the 1D fast fourier transfor, (FFT) of rows of A | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | 1D IFFT | ``ifft(A)`` | ``np.fft.ifft(A)`` | N/A | ``ifft(A)`` | Computes the 1D inverse fast fourier transfor, (IFFT) of rows of A | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | 2D FFT | ``fft2(A)`` | ``np.fft.fft2(A)`` | N/A | ``fft2(A)`` | Computes the 2D fast fourier transfor, (FFT) of matrices in outer 2 dimensions of A | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | 2D IFFT | ``ifft2(A)`` | ``np.fft.ifft2(A)`` | N/A | ``ifft2(A)`` | Computes the 2D inverse fast fourier transfor, (IFFT) of matrices in outer 2 dimensions of A | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ + | Covariance | ``cov(A)`` | ``np.cov(A)`` | N/A | ``cov(A)`` | Computes the covariance on the rows of matrix A | + +---------------------------+----------------------------------------+------------------------------------------------+-----------------------------------------------------------------------------------+-------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+ diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 34b740d2b..66523d298 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -5,6 +5,7 @@ set(examples convolution conv2d cgsolve + eigenExample fft_conv resample mvdr_beamformer @@ -18,12 +19,21 @@ set(examples black_scholes print_styles) + + + add_library(example_lib INTERFACE) target_include_directories(example_lib SYSTEM INTERFACE ${CUTLASS_INC} ${pybind11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS}) target_link_libraries(example_lib INTERFACE matx::matx) # Transitive properties set_property(TARGET example_lib PROPERTY ENABLE_EXPORTS 1) + +if(eigen_DIR) + include_directories(SYSTEM ${eigen_DIR}) + add_definitions(-DUSE_EIGEN) + target_compile_definitions(example_lib INTERFACE USE_EIGEN) +endif() if (MSVC) target_compile_options(example_lib INTERFACE /W4 /WX) diff --git a/examples/eigenExample.cu b/examples/eigenExample.cu new file mode 100644 index 000000000..6e53cd701 --- /dev/null +++ b/examples/eigenExample.cu @@ -0,0 +1,387 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#include + + +// BUILD NOTES: TO build, include the path to the eigen in cmake with the variable eigen_DIR="Path/To/Eigen" +#ifdef USE_EIGEN + #include +#endif + +#include + + + +int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) +{ + int dimX = 3; + int dimY = 3; + + + /////////////////////////////////////////////////////////////////////////////// + ////////////// Eigen Test Data Setup ////////////// + /////////////////////////////////////////////////////////////////////////////// +#ifdef USE_EIGEN + + typedef Eigen::Matrix MatrixXdRowMajor; // define a custom type that is aligned to MatX row-Major. + + Eigen::MatrixXd a(dimX, dimY); + MatrixXdRowMajor b(dimX, dimY); + Eigen::RowVectorXd rowVec(dimX); + Eigen::Matrix, 2, 2> complexMatrix; + Eigen::MatrixXf matrix10x10(10, 10); +#endif + /////////////////////////////////////////////////////////////////////////////// + ////////////// MatX Test Data Setup ////////////// + /////////////////////////////////////////////////////////////////////////////// + auto aTensor = matx::make_tensor({dimX,dimY}); + auto bTensor = matx::make_tensor({dimX,dimY}); + auto tensor1D = matx::make_tensor({dimX}); + auto complexTensor = matx::make_tensor>({2,2}); + auto matTensor10x10 = matx::make_tensor({10,10}); + + + /////////////////////////////////////////////////////////////////////////////// + ////////////// Initialize Data ////////////// + /////////////////////////////////////////////////////////////////////////////// +#ifdef USE_EIGEN + std::cout <<"!!!!!!!!! Using Eigen in Test !!!!!!!!!" << std ::endl; + // Initialize with random values + a.setRandom(); + b.setRandom(); + matrix10x10.setRandom(); + + rowVec << 1, 2, 3; + + complexMatrix(0, 0) = std::complex(1.0, 2.0); + complexMatrix(0, 1) = std::complex(2.0, 3.0); + complexMatrix(1, 0) = std::complex(3.0, 4.0); + complexMatrix(1, 1) = std::complex(4.0, 5.0); + +#else + std::cout <<"!!!!!!!!! Eigen NOT USED in Test !!!!!!!!!" << std ::endl; + // provide data in tensors if eigen is not used + (aTensor = matx::random({dimX, dimY}, matx::UNIFORM)).run(); + (bTensor = matx::random({dimX, dimY}, matx::UNIFORM)).run(); + (complexTensor = matx::random>({2, 2}, matx::UNIFORM)).run(); + (matTensor10x10 = matx::random({10, 10}, matx::UNIFORM)).run(); + +#endif + + + + + /////////////////////////////////////////////////////////////////////////////// + ////////////// Copy Eigen inputs to MatX ////////////// + /////////////////////////////////////////////////////////////////////////////// +#ifdef USE_EIGEN + cudaMemcpy(aTensor.Data(), a.data(), sizeof(double) * dimX * dimY, cudaMemcpyHostToDevice); + cudaMemcpy(bTensor.Data(), b.data(), sizeof(double) * dimX * dimY, cudaMemcpyHostToDevice); + cudaMemcpy(complexTensor.Data(), complexMatrix.data(), sizeof(std::complex)*2*2, cudaMemcpyHostToDevice); + cudaMemcpy(matTensor10x10.Data(), matrix10x10.data(), sizeof(float)*10*10, cudaMemcpyHostToDevice); + + (aTensor = matx::transpose(aTensor)).run(); + // (bTensor = matx::transpose(bTensor)).run(); // do not need to transpose because b has the same layout + (complexTensor = matx::transpose(complexTensor)).run(); + (matTensor10x10 = matx::transpose(matTensor10x10)).run(); +#endif + + tensor1D(0) = 1; + tensor1D(1) = 2; + tensor1D(2) = 3; + cudaDeviceSynchronize(); + + // slower alternative of copying per-element + // for(int curX=0; curX mappedMatrix(raw_data, dimX, dimY); + std::cout << "Eigen Mapped Data :\n" << mappedMatrix << std::endl; + + // map user memory into Eigen Matrix + auto mappedTensor = matx::make_tensor(raw_data, {dimX, dimY}, false); // create MatX tensor with non-owning user allocated memory + matx::print(mappedTensor); + + // modify the data from each of the references + raw_data[4] = 117; + mappedMatrix(0,1) = 42; + mappedTensor(2,1) = 87; + + // print modified data + std::cout << "Eigen Mapped Data After Modified :\n" << mappedMatrix << std::endl; + matx::print(mappedTensor); +#endif + + // + // Basic Indexing + // + std::cout << "=================== Indexing ===================" << std::endl; +#ifdef USE_EIGEN + std::cout << "eigen a(1,2) = " << a(1,2) << std::endl; +#endif + + std::cout << "MatX a(1,2) = " << aTensor(1,2) << std::endl; + + + // + // Add A and B + // + std::cout << "=================== Addition ===================" << std::endl; +#ifdef USE_EIGEN + Eigen::MatrixXd addResult = a + b; + std::cout << "A + B = \n" << addResult << std::endl; +#endif + + auto addTensor = aTensor + bTensor; + matx::print(addTensor); + + + // + // Element-Wise Multiply A and B + // + std::cout << "=================== Element-Wise Multiply ===================" << std::endl; +#ifdef USE_EIGEN + Eigen::MatrixXd elementWise = a.cwiseProduct(b); + std::cout << "A .* B = \n" << elementWise << std::endl; +#endif + + auto elementWiseTensor = aTensor*bTensor; + matx::print(elementWiseTensor); + + + // + // Divide A and B + // + std::cout << "=================== Element-Wise Division ===================" << std::endl; +#ifdef USE_EIGEN + Eigen::MatrixXd divResult = a.cwiseQuotient(b); + std::cout << "A / B = \n" << divResult << std::endl; +#endif + + auto divResultTensor = aTensor / bTensor; + matx::print(divResultTensor); + + + // + // Slice (Continuous) + // + std::cout << "=================== Continuous Slice ===================" << std::endl; +#ifdef USE_EIGEN + Eigen::Matrix2d aSlice = a.block(0, 0, 2, 2); + std::cout << "A Sliced: \n" << aSlice << std::endl; +#endif + + auto aSliceTensor = matx::slice<2>(aTensor,{0,0},{2,2}); + matx::print(aSliceTensor); + + + // + // Slice (Strided) + // + std::cout << "=================== Strided Slice ===================" << std::endl; +#ifdef USE_EIGEN + std::cout << "Original matrix10x10:\n" << matrix10x10 << "\n\n"; + // Define the starting point, number of elements to select, and strides for both rows and columns + // int startRow = 0, startCol = 0; // Starting index for rows and columns + // int rowStride = 3, colStride = 2; // Stride along rows and columns + // int numRows = 5; // Calculate the number of rows, considering every second element + // int numCols = 3; // Grab every third item until the 8th item (0, 3, 6) + + // Create a Map with Stride to access the elements + Eigen::Map> + strided(matrix10x10.data() + 0 * matrix10x10.outerStride() + 0, + 5, 3, + Eigen::Stride(3 * matrix10x10.outerStride(), 2)); + + // Print the strided matrix10x10 + std::cout << "Strided matrix10x10:\n" << strided << "\n"; +#endif + + auto slicedMat = matx::slice(matTensor10x10, {0,0}, {matx::matxEnd,9}, {2,3}); + matx::print(slicedMat); + + + // + // Clone + // + std::cout << "=================== Clone ===================" << std::endl; +#ifdef USE_EIGEN + // Use the replicate function to create a 5x5 matrix by replicating the 1x5 matrix + Eigen::MatrixXd mat = rowVec.replicate(3, 1); + std::cout << "1D Cloned to 2D \n" << mat << std::endl; +#endif + + auto cloned3Tensor = matx::clone<2>(tensor1D, {3, matx::matxKeepDim}); + matx::print(cloned3Tensor); + + + // + // Slice Row + // + std::cout << "=================== Slice Row ===================" << std::endl; +#ifdef USE_EIGEN + Eigen::RowVector3d row = a.row(1); + std::cout << "Sliced Row \n" << row << std::endl; +#endif + + auto rowSlice = matx::slice<1>(aTensor, {1, 0}, {matx::matxDropDim, matx::matxEnd}); + matx::print(rowSlice); + + + // + // Permute Rows + // + std::cout << "=================== Permute Rows ===================" << std::endl; +#ifdef USE_EIGEN + std::cout << "Original Matrix:\n" << a << std::endl; + // Define a permutation a + Eigen::PermutationMatrix<3> perm; + perm.indices() << 2, 1, 0; // This permutation swaps the first and third rows + // Apply the permutation to the rows + Eigen::Matrix3d permutedMatrix = perm * a; + std::cout << "Permuted Matrix (Rows):\n" << permutedMatrix << std::endl; +#endif + + // Define a permutation a + auto permVec = matx::make_tensor({dimX}); + permVec(0) = 2; + permVec(1) = 1; + permVec(2) = 0; + // Apply the permutation to the rows + auto permTensor = matx::remap<0>(aTensor, permVec); + matx::print(permTensor); + + + // + // Permutation Dimensions + // + std::cout << "=================== Permute Dimension ===================" << std::endl; + // Unsupported by eigen + auto permA = permute(aTensor, {1,0}); + matx::print(permA); + + // + // Get Real Value + // + std::cout << "=================== Get Real Values ===================" << std::endl; +#ifdef USE_EIGEN + std::cout << "Original Complex Matrix:\n" << complexMatrix << std::endl; + + // Extract and output the real part of the complex matrix + Eigen::Matrix realMatrix = complexMatrix.real(); + std::cout << "Real Part of Matrix:\n" << realMatrix << std::endl; +#endif + + auto realTensor = matx::real(complexTensor); + matx::print(realTensor); + + + // + // Multiply A and B + // + std::cout << "=================== Matrix Multiply ===================" << std::endl; +#ifdef USE_EIGEN + Eigen::MatrixXd multResult = a * b; + std::cout << "A * B = \n" << multResult << std::endl; +#endif + + auto multResultTensor=matmul(aTensor,bTensor); + matx::print(multResultTensor); + + + // + // inverse Matrix + // + std::cout << "=================== Invert Matrix ===================" << std::endl; +#ifdef USE_EIGEN + // Eigen::MatrixXd inverseMatrix = a.inverse(); // current bug where .run() in inverse is ambiguous, so cannot be used with MatX + // std::cout << "Inverse of the Real Part:\n" << inverseMatrix << std::endl; // current bug where .run() in inverse is ambiguous, so cannot be used with MatX +#endif + + auto invTensor = matx::inv(aTensor); + matx::print(invTensor); + + // + // 1D FFT + // + // Unsupported by eigen + + // + // 1D IFFT + // + // Unsupported by eigen + + // + // 2D FFT + // + // Unsupported by eigen + + // + // 2D IFFT + // + // Unsupported by eigen + + // + // Covariance + // + // Unsupported by eigen + + return 0; +}