Skip to content

Commit db3fef2

Browse files
authored
Added find_peaks operator (#1029)
* Added find_peaks operator
1 parent 75a279c commit db3fef2

File tree

10 files changed

+423
-25
lines changed

10 files changed

+423
-25
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
.. _find_peaks_func:
2+
3+
find_peaks
4+
##########
5+
6+
Finds the peaks in a 1D input operator.
7+
8+
Currently only the `height` and `threshold` parameters are supported.
9+
10+
.. doxygenfunction:: find_peaks
11+
12+
Examples
13+
~~~~~~~~
14+
15+
.. literalinclude:: ../../../test/00_operators/find_peaks.cu
16+
:language: cpp
17+
:start-after: example-begin findpeaks-test-1
18+
:end-before: example-end findpeaks-test-1
19+
:dedent:
20+
21+
22+

include/matx/core/type_utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,26 @@ inline constexpr bool has_shape_type_v = detail::has_shape_type<typename remove_
677677

678678

679679

680+
// Detect presence of nested alias `index_cmp_op` and verify it equals bool
681+
namespace detail {
682+
template <typename T, typename = void>
683+
struct has_index_cmp_op : std::false_type {};
684+
685+
template <typename T>
686+
struct has_index_cmp_op<T, std::void_t<typename T::index_cmp_op>>
687+
: std::bool_constant<std::is_same_v<typename T::index_cmp_op, bool>> {};
688+
}
689+
690+
/**
691+
* @brief Determine if a type defines `using index_cmp_op = bool;`
692+
*
693+
* @tparam T Type to test
694+
*/
695+
template <typename T>
696+
inline constexpr bool has_index_cmp_op_v = detail::has_index_cmp_op<typename remove_cvref<T>::type>::value;
697+
698+
699+
680700
namespace detail {
681701
template <typename T>
682702
struct is_complex_half
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2021, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#pragma once
34+
35+
36+
#include "matx/core/type_utils.h"
37+
#include "matx/operators/base_operator.h"
38+
#include "matx/operators/permute.h"
39+
#include "matx/transforms/find_peaks.h"
40+
41+
namespace matx {
42+
43+
44+
45+
namespace detail {
46+
template<typename OpA>
47+
class FindPeaksOp : public BaseOp<FindPeaksOp<OpA>>
48+
{
49+
private:
50+
typename detail::base_type_t<OpA> a_;
51+
typename remove_cvref_t<OpA>::value_type height_;
52+
typename remove_cvref_t<OpA>::value_type threshold_;
53+
54+
public:
55+
using matxop = bool;
56+
using value_type = typename remove_cvref_t<OpA>::value_type;
57+
using matx_transform_op = bool;
58+
using find_peaks_xform_op = bool;
59+
60+
__MATX_INLINE__ std::string str() const { return "find_peaks(" + get_type_str(a_) + ")"; }
61+
__MATX_INLINE__ FindPeaksOp(const OpA &a, value_type height,
62+
value_type threshold) :
63+
a_(a), height_(height), threshold_(threshold) {
64+
}
65+
66+
template <typename... Is>
67+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const = delete;
68+
69+
template <OperatorCapability Cap>
70+
__MATX_INLINE__ __MATX_HOST__ auto get_capability() const {
71+
auto self_has_cap = capability_attributes<Cap>::default_value;
72+
return combine_capabilities<Cap>(self_has_cap, detail::get_operator_capability<Cap>(a_));
73+
}
74+
75+
template <typename Out, typename Executor>
76+
void Exec(Out &&out, Executor &&ex) const {
77+
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 3, "Must use mtie with 2 outputs on find_peaks(). ie: (mtie(O, num_found) = find_peaks(A, height, threshold))");
78+
static_assert(remove_cvref_t<decltype(cuda::std::get<1>(out))>::Rank() == 0 &&
79+
std::is_same_v<typename remove_cvref_t<decltype(cuda::std::get<1>(out))>::value_type, int>,
80+
"Num elements output must be a scalar integer tensor");
81+
static_assert(std::is_same_v<typename remove_cvref_t<decltype(cuda::std::get<0>(out))>::value_type, index_t>,
82+
"Peak indices output must be a 1D matx::index_t tensor");
83+
find_peaks_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, height_, threshold_, ex);
84+
}
85+
86+
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
87+
{
88+
return remove_cvref_t<OpA>::Rank();
89+
}
90+
91+
template <typename ShapeType, typename Executor>
92+
__MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
93+
{
94+
if constexpr (is_matx_op<OpA>()) {
95+
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
96+
}
97+
}
98+
99+
template <typename ShapeType, typename Executor>
100+
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
101+
{
102+
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
103+
}
104+
105+
template <typename ShapeType, typename Executor>
106+
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept
107+
{
108+
if constexpr (is_matx_op<OpA>()) {
109+
a_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
110+
}
111+
}
112+
113+
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
114+
{
115+
return a_.Size(dim);
116+
}
117+
118+
};
119+
}
120+
121+
122+
/**
123+
* Compute peak search of input
124+
*
125+
* Returns a tensor representing the indices of peaks found in the input operator. The first output parameter holds the indices
126+
* while the second holds the number of indices/peaks found. The output index tensor must be large enough to hold all of the peaks
127+
* found or the behavior is undefined.
128+
*
129+
* @tparam InType
130+
* Input data type
131+
* @tparam D
132+
* Number of right-most dimensions to reduce over
133+
*
134+
* @param in
135+
* Input data to reduce
136+
* @param height
137+
* Height threshold for peak detection. Values below this threshold are not considered peaks.
138+
* @param threshold
139+
* Threshold for peak detection. Neighboring values must be larger in vertical distance than this threshold
140+
* @returns Operator with reduced values of peak search computed
141+
*/
142+
template <typename InType>
143+
__MATX_INLINE__ auto find_peaks(const InType &in,
144+
typename InType::value_type height,
145+
typename InType::value_type threshold)
146+
{
147+
static_assert(InType::Rank() == 1, "Input to find_peaks() must be rank 1");
148+
return detail::FindPeaksOp<decltype(in)>(in, height, threshold);
149+
}
150+
151+
}

include/matx/operators/operators.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
#include "matx/operators/fft.h"
6464
#include "matx/operators/fftshift.h"
6565
#include "matx/operators/filter.h"
66+
#include "matx/operators/find_peaks.h"
6667
#include "matx/operators/flatten.h"
6768
#include "matx/operators/frexp.h"
6869
#include "matx/operators/hermitian.h"

include/matx/transforms/cub.h

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ typedef enum {
7272
CUB_OP_REDUCE_SUM,
7373
CUB_OP_REDUCE_MIN,
7474
CUB_OP_REDUCE_MAX,
75-
CUB_OP_SELECT,
75+
CUB_OP_SELECT_VALS,
7676
CUB_OP_SELECT_IDX,
7777
CUB_OP_UNIQUE,
7878
CUB_OP_SINGLE_ARG_REDUCE,
@@ -178,7 +178,7 @@ class matxCubPlan_t {
178178
else if constexpr (op == CUB_OP_REDUCE_MAX) {
179179
ExecMax(a_out, a, stream);
180180
}
181-
else if constexpr (op == CUB_OP_SELECT) {
181+
else if constexpr (op == CUB_OP_SELECT_VALS) {
182182
ExecSelect(a_out, a, stream);
183183
}
184184
else if constexpr (op == CUB_OP_SELECT_IDX) {
@@ -819,8 +819,6 @@ inline void ExecSort(OutputTensor &a_out,
819819
#endif
820820
}
821821

822-
823-
824822
/**
825823
* Execute a selection reduction on a tensor
826824
*
@@ -922,16 +920,30 @@ inline void ExecSort(OutputTensor &a_out,
922920
#ifdef __CUDACC__
923921
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
924922

925-
if constexpr (is_tensor_view_v<InputOperator>) {
926-
if (a.IsContiguous()) {
927-
cub::DeviceSelect::If(d_temp,
928-
temp_storage_bytes,
929-
thrust::counting_iterator<index_t>(0),
930-
a_out.Data(),
931-
cparams_.num_found.Data(),
932-
static_cast<int>(TotalSize(a)),
933-
IndexToSelectOp<decltype(a.Data()), decltype(cparams_.op)>{a.Data(), cparams_.op},
934-
stream);
923+
if (!has_index_cmp_op_v<decltype(cparams_.op)>) {
924+
if constexpr (is_tensor_view_v<InputOperator>) {
925+
if (a.IsContiguous()) {
926+
cub::DeviceSelect::If(d_temp,
927+
temp_storage_bytes,
928+
thrust::counting_iterator<index_t>(0),
929+
a_out.Data(),
930+
cparams_.num_found.Data(),
931+
static_cast<int>(TotalSize(a)),
932+
IndexToSelectOp<decltype(a.Data()), decltype(cparams_.op)>{a.Data(), cparams_.op},
933+
stream);
934+
}
935+
else {
936+
tensor_impl_t<typename InputOperator::value_type, InputOperator::Rank(), typename InputOperator::desc_type> base = a;
937+
cub::DeviceSelect::If(d_temp,
938+
temp_storage_bytes,
939+
thrust::counting_iterator<index_t>(0),
940+
a_out.Data(),
941+
cparams_.num_found.Data(),
942+
static_cast<int>(TotalSize(a)),
943+
IndexToSelectOp<decltype(RandomOperatorIterator{base}), decltype(cparams_.op)>
944+
{RandomOperatorIterator{base}, cparams_.op},
945+
stream);
946+
}
935947
}
936948
else {
937949
tensor_impl_t<typename InputOperator::value_type, InputOperator::Rank(), typename InputOperator::desc_type> base = a;
@@ -947,16 +959,18 @@ inline void ExecSort(OutputTensor &a_out,
947959
}
948960
}
949961
else {
962+
// Custom compare op that only takes an index. This can be more powerful for users by allowing them to define whatever
963+
// they want inside the op and not be limited to simple binary comparisons.
950964
cub::DeviceSelect::If(d_temp,
951-
temp_storage_bytes,
952-
thrust::counting_iterator<index_t>(0),
953-
a_out.Data(),
954-
cparams_.num_found.Data(),
955-
static_cast<int>(TotalSize(a)),
956-
IndexToSelectOp<decltype(RandomOperatorIterator{a}), decltype(cparams_.op)>
957-
{RandomOperatorIterator{a}, cparams_.op},
958-
stream);
965+
temp_storage_bytes,
966+
thrust::counting_iterator<index_t>(0),
967+
a_out.Data(),
968+
cparams_.num_found.Data(),
969+
static_cast<int>(TotalSize(a)),
970+
cparams_.op,
971+
stream);
959972
}
973+
960974
#endif
961975
}
962976

@@ -2399,6 +2413,7 @@ struct GTE
23992413
};
24002414

24012415

2416+
24022417
/**
24032418
* Reduce values that meet a certain criteria
24042419
*
@@ -2443,9 +2458,9 @@ void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator
24432458
auto params =
24442459
detail::matxCubPlan_t<OutputTensor,
24452460
InputOperator,
2446-
detail::CUB_OP_SELECT,
2461+
detail::CUB_OP_SELECT_VALS,
24472462
param_type>::GetCubParams(a_out, a, stream);
2448-
using cache_val_type = detail::matxCubPlan_t<OutputTensor, InputOperator, detail::CUB_OP_SELECT, param_type>;
2463+
using cache_val_type = detail::matxCubPlan_t<OutputTensor, InputOperator, detail::CUB_OP_SELECT_VALS, param_type>;
24492464
detail::GetCache().LookupAndExec<detail::cub_cache_t>(
24502465
detail::GetCacheIdFromType<detail::cub_cache_t>(),
24512466
params,
@@ -2461,7 +2476,7 @@ void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator
24612476
#else
24622477
auto tmp = detail::matxCubPlan_t< OutputTensor,
24632478
InputOperator,
2464-
detail::CUB_OP_SELECT,
2479+
detail::CUB_OP_SELECT_VALS,
24652480
decltype(cparams)>{a_out, a, cparams, stream};
24662481
tmp.ExecSelect(a_out, a, stream);
24672482
#endif

0 commit comments

Comments
 (0)