1+ // //////////////////////////////////////////////////////////////////////////////
2+ // BSD 3-Clause License
3+ //
4+ // Copyright (c) 2021, NVIDIA Corporation
5+ // All rights reserved.
6+ //
7+ // Redistribution and use in source and binary forms, with or without
8+ // modification, are permitted provided that the following conditions are met:
9+ //
10+ // 1. Redistributions of source code must retain the above copyright notice, this
11+ // list of conditions and the following disclaimer.
12+ //
13+ // 2. Redistributions in binary form must reproduce the above copyright notice,
14+ // this list of conditions and the following disclaimer in the documentation
15+ // and/or other materials provided with the distribution.
16+ //
17+ // 3. Neither the name of the copyright holder nor the names of its
18+ // contributors may be used to endorse or promote products derived from
19+ // this software without specific prior written permission.
20+ //
21+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+ // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+ // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+ // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+ // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+ // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+ // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+ // ///////////////////////////////////////////////////////////////////////////////
32+
33+ #pragma once
34+
35+ #include " matx/core/type_utils.h"
36+ #include " matx/operators/base_operator.h"
37+ #include " matx/operators/permute.h"
38+ #include " matx/transforms/percentile.h"
39+
40+ namespace matx {
41+
42+
43+ namespace detail {
44+ template <typename OpA, int ORank>
45+ class PercentileOp : public BaseOp <PercentileOp<OpA,ORank>>
46+ {
47+ private:
48+ OpA a_;
49+ uint32_t q_;
50+ PercentileMethod method_;
51+ std::array<index_t , ORank> out_dims_;
52+ mutable matx::tensor_t <typename remove_cvref_t <OpA>::scalar_type, ORank> tmp_out_;
53+
54+ public:
55+ using matxop = bool ;
56+ using scalar_type = typename remove_cvref_t <OpA>::scalar_type;
57+ using matx_transform_op = bool ;
58+ using prod_xform_op = bool ;
59+
60+ __MATX_INLINE__ std::string str () const { return " percentile(" + get_type_str (a_) + " )" ; }
61+ __MATX_INLINE__ PercentileOp (OpA a, unsigned char q, PercentileMethod method) : a_(a), q_(q), method_(method) {
62+ for (int r = 0 ; r < ORank; r++) {
63+ out_dims_[r] = (r == ORank - 1 ) ? 1 : a_.Size (r);
64+ }
65+ };
66+
67+ template <typename ... Is>
68+ __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator ()(Is... indices) const {
69+ return tmp_out_ (indices...);
70+ };
71+
72+ template <typename Out, typename Executor>
73+ void Exec (Out &&out, Executor &&ex) const {
74+ percentile_impl (std::get<0 >(out), a_, q_, method_, ex);
75+ }
76+
77+ static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank ()
78+ {
79+ return ORank;
80+ }
81+
82+ template <typename ShapeType, typename Executor>
83+ __MATX_INLINE__ void PreRun ([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
84+ {
85+ if constexpr (is_matx_op<OpA>()) {
86+ a_.PreRun (std::forward<ShapeType>(shape), std::forward<Executor>(ex));
87+ }
88+
89+ if constexpr (is_device_executor_v<Executor>) {
90+ make_tensor (tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream ());
91+ }
92+ else {
93+ make_tensor (tmp_out_, out_dims_, MATX_HOST_MEMORY);
94+ }
95+
96+ Exec (std::make_tuple (tmp_out_), std::forward<Executor>(ex));
97+ }
98+
99+ constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size (int dim) const
100+ {
101+ return out_dims_[dim];
102+ }
103+
104+ };
105+ }
106+
107+ /* *
108+ * Compute product of numbers along axes
109+ *
110+ * Returns a tensor representing the product of all items in the reduction
111+ *
112+ * @tparam InType
113+ * Input data type
114+ * @tparam D
115+ * Num of dimensions to reduce over
116+ *
117+ * @param in
118+ * Input data to reduce
119+ * @param q
120+ * Percentile to compute (between 0-100)
121+ * @param dims
122+ * Array containing dimensions to compute over
123+ * @param method
124+ * Method of interpolation
125+ * @returns Operator with reduced values of prod-reduce computed
126+ */
127+ template <typename InType, int D>
128+ __MATX_INLINE__ auto percentile (const InType &in, unsigned char q, const int (&dims)[D], PercentileMethod method = PercentileMethod::LINEAR)
129+ {
130+ static_assert (D < InType::Rank (), " reduction dimensions must be <= Rank of input" );
131+ MATX_ASSERT_STR (q < 100 , matxInvalidParameter, " Percentile must be < 100" );
132+ auto perm = detail::getPermuteDims<InType::Rank ()>(dims);
133+ auto permop = permute (in, perm);
134+
135+ return detail::PercentileOp<decltype (permop), InType::Rank () - D>(permop, q, method);
136+ }
137+
138+ /* *
139+ * Compute product of numbers
140+ *
141+ * Returns a tensor representing the product of all items in the reduction
142+ *
143+ * @tparam InType
144+ * Input data type
145+ *
146+ * @param in
147+ * Input data to reduce
148+ * @param q
149+ * Percentile to compute (between 0-100)
150+ * @param method
151+ * Method of interpolation
152+ * @returns Operator with reduced values of prod-reduce computed
153+ */
154+ template <typename InType>
155+ __MATX_INLINE__ auto percentile (const InType &in, unsigned char q, PercentileMethod method = PercentileMethod::LINEAR)
156+ {
157+ return detail::PercentileOp<decltype (in), 0 >(in, q, method);
158+ }
159+
160+ }
0 commit comments