1 change: 1 addition & 0 deletions docs_input/api/math/misc/interp.rst
@@ -6,6 +6,7 @@ interp1
Piecewise interpolation with various methods (linear, nearest, next, previous, spline).

.. doxygenfunction:: interp1(const OpX &x, const OpV &v, const OpXQ &xq, InterpMethod method)
.. doxygenfunction:: interp1(const OpX &x, const OpV &v, const OpXQ &xq, const int (&axis)[1], InterpMethod method)

Interpolation Methods
~~~~~~~~~~~~~~~~~~~~~
8 changes: 8 additions & 0 deletions include/matx/core/utils.h
@@ -209,5 +209,13 @@ auto __MATX_INLINE__ getPermuteDims( const int (&dims)[D] ) {
return getPermuteDims<RANK>(detail::to_array(dims));
}

template <int RANK>
auto __MATX_INLINE__ invPermute(const cuda::std::array<int, RANK> &perm) {
cuda::std::array<int, RANK> inv_perm;
for (int i = 0; i < RANK; i++) {
inv_perm[perm[i]] = i;
}
return inv_perm;
}
};
};
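A hedged sketch of what the new helper does (the values are made up; it assumes the matx::detail namespace and the cuda::std::array include already present in this header):

// Inverting a permutation: applying perm and then inv_perm restores the original axis order.
cuda::std::array<int, 3> perm{2, 0, 1};                 // output axis i comes from input axis perm[i]
auto inv_perm = matx::detail::invPermute<3>(perm);      // yields {1, 2, 0}
// inv_perm[perm[i]] == i for every i, so permute(permute(t, perm), inv_perm) recovers
// t's original layout; the new interp1 axis overload below uses exactly this round trip.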
109 changes: 79 additions & 30 deletions include/matx/operators/interp.h
@@ -66,6 +66,10 @@ namespace matx {
using x_val_type = typename OpX::value_type;
using v_val_type = typename OpV::value_type;

constexpr static int RANK = O::Rank();
constexpr static int AXIS = RANK - 1;
constexpr static int AXIS_X = OpX::Rank() - 1;
constexpr static int AXIS_V = OpV::Rank() - 1;

public:
InterpSplineTridiagonalFillOp(const O& dl, const O& d, const O& du, const O& b, const OpX& x, const OpV& v)
@@ -76,16 +80,16 @@
{

cuda::std::array idx{indices...};
- index_t idxInterp = idx[Rank() - 1];
+ index_t idxInterp = idx[AXIS];

cuda::std::array idx0{idx};
cuda::std::array idx1{idx};
cuda::std::array idx2{idx};

if (idxInterp == 0) { // left boundary condition
- idx0[Rank() - 1] = idxInterp + 0;
- idx1[Rank() - 1] = idxInterp + 1;
- idx2[Rank() - 1] = idxInterp + 2;
+ idx0[AXIS] = idxInterp + 0;
+ idx1[AXIS] = idxInterp + 1;
+ idx2[AXIS] = idxInterp + 2;

x_val_type x0 = get_value(x_, idx0);
x_val_type x1 = get_value(x_, idx1);
@@ -105,10 +109,10 @@ namespace matx {
du_(indices...) = h1 + h0;
b_(indices...) = ((2*h1 + 3*h0)*h1*delta0 + h0*h0*delta1) / (h1 + h0);
}
- else if (idxInterp == x_.Size(0) - 1) { // right boundary condition
- idx0[Rank() - 1] = idxInterp - 2;
- idx1[Rank() - 1] = idxInterp - 1;
- idx2[Rank() - 1] = idxInterp;
+ else if (idxInterp == x_.Size(AXIS_X) - 1) { // right boundary condition
+ idx0[AXIS] = idxInterp - 2;
+ idx1[AXIS] = idxInterp - 1;
+ idx2[AXIS] = idxInterp;

x_val_type x0 = get_value(x_, idx0);
x_val_type x1 = get_value(x_, idx1);
@@ -130,9 +134,9 @@ namespace matx {
b_(indices...) = ((2*h0 + 3*h1)*h0*delta1 + h1*h1*delta0) / (h0 + h1);
}
else { // interior points
- idx0[Rank() - 1] = idxInterp - 1;
- idx1[Rank() - 1] = idxInterp;
- idx2[Rank() - 1] = idxInterp + 1;
+ idx0[AXIS] = idxInterp - 1;
+ idx1[AXIS] = idxInterp;
+ idx2[AXIS] = idxInterp + 1;

x_val_type x0 = get_value(x_, idx0);
x_val_type x1 = get_value(x_, idx1);
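For context, the rows assembled here appear to form the standard tridiagonal system for the spline slopes $m_i$, written with $h_i = x_{i+1} - x_i$ and $\delta_i = (v_{i+1} - v_i)/h_i$; at an interior sample the row is

h_i\, m_{i-1} + 2\,(h_{i-1} + h_i)\, m_i + h_{i-1}\, m_{i+1} = 3\,\bigl(h_i\,\delta_{i-1} + h_{i-1}\,\delta_i\bigr),

while the boundary rows visible above (e.g. du_ = h1 + h0 and b_ = ((2 h_1 + 3 h_0) h_1 \delta_0 + h_0^2 \delta_1)/(h_0 + h_1) on the left) match a not-a-knot style end condition; that naming is an inference from the coefficients, not something stated in the diff.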
@@ -176,8 +180,10 @@ namespace matx {
mutable detail::tensor_impl_t<value_type, OpV::Rank()> m_; // Derivatives at sample points (spline only)
mutable value_type *ptr_m_ = nullptr;

- constexpr static int32_t RANK = OpXQ::Rank();
- constexpr static int32_t AXIS = RANK - 1;
+ constexpr static int RANK = OpXQ::Rank();
+ constexpr static int AXIS = RANK - 1;
constexpr static int AXIS_X = OpX::Rank() - 1;
constexpr static int AXIS_V = OpV::Rank() - 1;

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto searchsorted(const cuda::std::array<index_t, RANK> idx, const domain_type x_query) const
@@ -192,22 +198,22 @@ namespace matx {
cuda::std::array idx_mid{idx};

idx_low[AXIS] = 0;
- idx_high[AXIS] = x_.Size(x_.Rank() - 1) - 1;
+ idx_high[AXIS] = x_.Size(AXIS_X) - 1;

domain_type x_low, x_high, x_mid;

x_low = get_value(x_, idx_low);
if (x_query < x_low) {
- idx_low[AXIS] = x_.Size(x_.Rank() - 1);
+ idx_low[AXIS] = x_.Size(AXIS_X);
idx_high[AXIS] = 0;
return cuda::std::make_tuple(idx_low, idx_high);
} else if (x_query == x_low) {
return cuda::std::make_tuple(idx_low, idx_low);
}
x_high = get_value(x_, idx_high);
if (x_query > x_high) {
- idx_low[AXIS] = x_.Size(x_.Rank() - 1) - 1;
- idx_high[AXIS] = x_.Size(x_.Rank() - 1);
+ idx_low[AXIS] = x_.Size(AXIS_X) - 1;
+ idx_high[AXIS] = x_.Size(AXIS_X);
return cuda::std::make_tuple(idx_low, idx_high);
} else if (x_query == x_high) {
return cuda::std::make_tuple(idx_high, idx_high);
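The tuple returned here encodes the bracketing convention the interpolate_* helpers below rely on: the size of the sample axis is used as an out-of-range sentinel. A host-side sketch of that contract on a plain sorted array (illustrative only; the collapsed portion of the real method performs the equivalent search over operator indices):

#include <utility>

// Return (lo, hi) with x[lo] <= xq <= x[hi] and hi == lo + 1; lo == hi marks an
// exact hit, and n is used as a sentinel for queries outside [x[0], x[n-1]].
std::pair<int, int> bracket(const float* x, int n, float xq) {
  if (xq < x[0])     return {n, 0};       // below range: lo sentinel = n, hi = 0
  if (xq > x[n - 1]) return {n - 1, n};   // above range: lo = n - 1, hi sentinel = n
  int lo = 0, hi = n - 1;
  while (hi - lo > 1) {                   // binary search for the bracketing pair
    int mid = lo + (hi - lo) / 2;
    if (x[mid] == xq) return {mid, mid};
    (x[mid] < xq ? lo : hi) = mid;
  }
  if (x[lo] == xq) return {lo, lo};
  if (x[hi] == xq) return {hi, hi};
  return {lo, hi};
}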
@@ -237,7 +243,7 @@ namespace matx {

if (idx_high[AXIS] == 0 || idx_low[AXIS] == idx_high[AXIS]) { // x_query <= x(0) or x_query == x(idx_low) == x(idx_high)
v = get_value(v_, idx_high);
- } else if (idx_low[AXIS] == x_.Size(0) - 1) { // x_query > x(n-1)
+ } else if (idx_low[AXIS] == x_.Size(AXIS_X) - 1) { // x_query > x(n-1)
v = get_value(v_, idx_low);
} else {
domain_type x_low = get_value(x_, idx_low);
@@ -253,9 +259,9 @@ namespace matx {
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__
value_type interpolate_nearest(const domain_type x_query, cuda::std::array<index_t, RANK> idx_low, cuda::std::array<index_t, RANK> idx_high) const {
value_type v;
- if (idx_low[AXIS] == x_.Size(0)) { // x_query < x(0)
+ if (idx_low[AXIS] == x_.Size(AXIS_X)) { // x_query < x(0)
v = get_value(v_, idx_high);
- } else if (idx_high[AXIS] == x_.Size(0)) { // x_query > x(n-1)
+ } else if (idx_high[AXIS] == x_.Size(AXIS_X)) { // x_query > x(n-1)
v = get_value(v_, idx_low);
} else {
domain_type x_low = get_value(x_, idx_low);
@@ -274,7 +280,7 @@ namespace matx {
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__
value_type interpolate_next(const domain_type x_query, cuda::std::array<index_t, RANK> idx_low, cuda::std::array<index_t, RANK> idx_high) const {
value_type v;
- if (idx_high[AXIS] == x_.Size(0)) { // x_query > x(n-1)
+ if (idx_high[AXIS] == x_.Size(AXIS_X)) { // x_query > x(n-1)
v = get_value(v_, idx_low);
} else {
v = get_value(v_, idx_high);
@@ -286,7 +292,7 @@ namespace matx {
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__
value_type interpolate_prev(const domain_type x_query, cuda::std::array<index_t, RANK> idx_low, cuda::std::array<index_t, RANK> idx_high) const {
value_type v;
- if (idx_low[AXIS] == x_.Size(0)) { // x_query < x(0)
+ if (idx_low[AXIS] == x_.Size(AXIS_X)) { // x_query < x(0)
v = get_value(v_, idx_high);
} else {
v = get_value(v_, idx_low);
@@ -301,12 +307,12 @@ namespace matx {
if (idx_high[AXIS] == idx_low[AXIS]) {
value_type v = get_value(v_, idx_low);
return v;
- } else if (idx_low[AXIS] == x_.Size(0)) { // x_query < x(0)
+ } else if (idx_low[AXIS] == x_.Size(AXIS_X)) { // x_query < x(0)
idx_low[AXIS] = 0;
idx_high[AXIS] = 1;
- } else if (idx_high[AXIS] == x_.Size(0)) { // x_query > x(n-1)
- idx_high[AXIS] = x_.Size(0) - 1;
- idx_low[AXIS] = x_.Size(0) - 2;
+ } else if (idx_high[AXIS] == x_.Size(AXIS_X)) { // x_query > x(n-1)
+ idx_high[AXIS] = x_.Size(AXIS_X) - 1;
+ idx_low[AXIS] = x_.Size(AXIS_X) - 2;
}

// sample points
@@ -523,9 +529,9 @@ namespace matx {


/**
* 1D interpolation of samples at query points.
*
* Interpolation is performed along the last dimension. All other dimensions must be of
* compatible size.
*
* @tparam OpX
@@ -549,6 +555,49 @@ auto interp1(const OpX &x, const OpV &v, const OpXQ &xq, InterpMethod method = I
static_assert(OpX::Rank() >= 1, "interp: sample points must be at least 1D");
static_assert(OpV::Rank() >= OpX::Rank(), "interp: sample values must have at least the same rank as sample points");
static_assert(OpXQ::Rank() >= OpV::Rank(), "interp: query points must have at least the same rank as sample values");
- return detail::Interp1Op<OpX, OpV, OpXQ>(x, v, xq, method);
+ return detail::Interp1Op(x, v, xq, method);
}
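A minimal usage sketch of this overload (not part of the diff; the shapes, the fill step, and the executor exec, e.g. a cudaExecutor, are assumptions):

// Interpolate samples (x, v) at query points xq along the last (and only) dimension.
auto x  = matx::make_tensor<float>({5});   // sample points, filled in ascending order
auto v  = matx::make_tensor<float>({5});   // sample values at x
auto xq = matx::make_tensor<float>({8});   // query points
auto vq = matx::make_tensor<float>({8});   // result has the same shape as xq
// ... fill x, v, xq ...
(vq = matx::interp1(x, v, xq, matx::InterpMethod::LINEAR)).run(exec);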


/**
* 1D interpolation of samples at query points.
*
* Interpolation is performed along the specified dimension. All other dimensions must be of compatible size.
*
* @tparam OpX
* Type of sample points
* @tparam OpV
* Type of sample values
* @tparam OpXQ
* Type of query points
* @param x
* Sample points. Must be sorted in ascending order along the interpolation axis.
* @param v
* Sample values. Must have compatible dimensions with x.
* @param xq
* Query points at which to interpolate. All dimensions except the interpolation axis must be of compatible size with x and v (e.g. x and v can be vectors while xq is a matrix).
* @param axis
* Dimension (of xq) along which to interpolate.
* @param method
* Interpolation method (LINEAR, NEAREST, NEXT, PREV, SPLINE)
* @returns Operator that interpolates values at query points, with the same dimensions as xq.
*/
template <typename OpX, typename OpV, typename OpXQ>
auto interp1(const OpX &x, const OpV &v, const OpXQ &xq, const int (&axis)[1], InterpMethod method = InterpMethod::LINEAR) {
static_assert(OpX::Rank() >= 1, "interp: sample points must be at least 1D");
static_assert(OpV::Rank() >= OpX::Rank(), "interp: sample values must have at least the same rank as sample points");
static_assert(OpXQ::Rank() >= OpV::Rank(), "interp: query points must have at least the same rank as sample values");


auto x_perm = detail::getPermuteDims<OpX::Rank()>({axis[0] + OpX::Rank() - OpXQ::Rank()});
auto v_perm = detail::getPermuteDims<OpV::Rank()>({axis[0] + OpV::Rank() - OpXQ::Rank()});
auto xq_perm = detail::getPermuteDims<OpXQ::Rank()>({axis[0]});

auto px = permute(x, x_perm);
auto pv = permute(v, v_perm);
auto pxq = permute(xq, xq_perm);
auto inv_perm = detail::invPermute<OpXQ::Rank()>(xq_perm);

return permute(detail::Interp1Op(px, pv, pxq, method), inv_perm);
}
} // namespace matx
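A matching sketch for the new axis overload (again, shapes and exec are illustrative): interpolation runs along axis 0 while a trailing batch dimension of size 2 is left untouched, mirroring the permuted cases exercised in the test below.

// Interpolate along axis 0 of the query tensor; the trailing axis of size 2 is a batch.
auto x  = matx::make_tensor<float>({5, 2});   // sample points, ascending along axis 0
auto v  = matx::make_tensor<float>({5, 2});   // sample values
auto xq = matx::make_tensor<float>({8, 2});   // query points
auto vq = matx::make_tensor<float>({8, 2});
(vq = matx::interp1(x, v, xq, {0}, matx::InterpMethod::LINEAR)).run(exec);

Under the hood the overload permutes the requested axis of each operand to the last position, reuses the last-dimension Interp1Op, and then applies invPermute so the result comes back in xq's original dimension order.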
52 changes: 0 additions & 52 deletions test/00_operators/CMakeLists.txt
@@ -1,55 +1,3 @@
set(OPERATOR_TEST_FILES
abs2_test.cu
advanced_op_test.cu
advanced_remap_test.cu
angle_test.cu
at_test.cu
base_op_test.cu
broadcast_test.cu
cast_test.cu
clone_add_test.cu
clone_test.cu
collapse_test.cu
complex_cast_exceptions_test.cu
complex_cast_test.cu
complex_type_compatibility_test.cu
concat_test.cu
cross_test.cu
fftshift_test.cu
flatten_test.cu
fmod_test.cu
frexp_test.cu
frexpc_test.cu
get_string_test.cu
interleaved_test.cu
interp_test.cu
isclose_test.cu
isnaninf_test.cu
legendre_test.cu
operator_func_test.cu
overlap_test.cu
permute_test.cu
planar_test.cu
polyval_test.cu
print_test.cu
r2c_test.cu
real_imag_test.cu
remap_rank_zero_test.cu
remap_test.cu
repmat_test.cu
reshape_test.cu
reverse_test.cu
shift_test.cu
simple_executor_accessor_test.cu
slice_and_reduce_test.cu
slice_and_reshape_test.cu
slice_stride_test.cu
slice_test.cu
sph2cart_test.cu
square_copy_transpose_test.cu
stack_test.cu
toeplitz_test.cu
transpose_test.cu
trig_funcs_test.cu
updownsample_test.cu
)
81 changes: 81 additions & 0 deletions test/00_operators/interp_test.cu
@@ -173,5 +173,86 @@ TEST(InterpTests, Interp)
}


auto px2 = make_tensor<TestType>({5, 2});
auto pv3 = make_tensor<TestType>({3, 5, 2});
auto pxq4 = make_tensor<TestType>({4, 3, 6, 2});

(px2 = permute(x2, {1, 0})).run(exec);
(pv3 = permute(v3, {0, 2, 1})).run(exec);
(pxq4 = permute(xq4, {0, 1, 3, 2})).run(exec);



auto out_perm_linear4 = make_tensor<TestType>(pxq4.Shape());
(out_perm_linear4 = interp1(px2, pv3, pxq4, {2}, InterpMethod::LINEAR)).run(exec);
exec.sync();

for (index_t i = 0; i < pxq4.Size(0); i++) {
for (index_t j = 0; j < pxq4.Size(1); j++) {
for (index_t k = 0; k < pxq4.Size(2); k++) {
for (index_t l = 0; l < pxq4.Size(3); l++) {
ASSERT_EQ(out_perm_linear4(i, j, k, l), vq_linear(k));
}
}
}
}


auto out_perm_nearest4 = make_tensor<TestType>(pxq4.Shape());
(out_perm_nearest4 = interp1(px2, pv3, pxq4, {2}, InterpMethod::NEAREST)).run(exec);
exec.sync();

for (index_t i = 0; i < pxq4.Size(0); i++) {
for (index_t j = 0; j < pxq4.Size(1); j++) {
for (index_t k = 0; k < pxq4.Size(2); k++) {
for (index_t l = 0; l < pxq4.Size(3); l++) {
ASSERT_EQ(out_perm_nearest4(i, j, k, l), vq_nearest(k));
}
}
}
}

auto out_perm_next4 = make_tensor<TestType>(pxq4.Shape());
(out_perm_next4 = interp1(px2, pv3, pxq4, {2}, InterpMethod::NEXT)).run(exec);
exec.sync();

for (index_t i = 0; i < pxq4.Size(0); i++) {
for (index_t j = 0; j < pxq4.Size(1); j++) {
for (index_t k = 0; k < pxq4.Size(2); k++) {
for (index_t l = 0; l < pxq4.Size(3); l++) {
ASSERT_EQ(out_perm_next4(i, j, k, l), vq_next(k));
}
}
}
}

auto out_perm_prev4 = make_tensor<TestType>(pxq4.Shape());
(out_perm_prev4 = interp1(px2, pv3, pxq4, {2}, InterpMethod::PREV)).run(exec);
exec.sync();

for (index_t i = 0; i < pxq4.Size(0); i++) {
for (index_t j = 0; j < pxq4.Size(1); j++) {
for (index_t k = 0; k < pxq4.Size(2); k++) {
for (index_t l = 0; l < pxq4.Size(3); l++) {
ASSERT_EQ(out_perm_prev4(i, j, k, l), vq_prev(k));
}
}
}
}

auto out_perm_spline4 = make_tensor<TestType>(pxq4.Shape());
(out_perm_spline4 = interp1(px2, pv3, pxq4, {2}, InterpMethod::SPLINE)).run(exec);
exec.sync();

for (index_t i = 0; i < pxq4.Size(0); i++) {
for (index_t j = 0; j < pxq4.Size(1); j++) {
for (index_t k = 0; k < pxq4.Size(2); k++) {
for (index_t l = 0; l < pxq4.Size(3); l++) {
ASSERT_NEAR(out_perm_spline4(i, j, k, l), vq_spline(k), 1e-4);
}
}
}
}

MATX_EXIT_HANDLER();
}