@@ -1976,13 +1976,10 @@ void __MATX_INLINE__ prod_impl(OutType dest, const InType &in, [[maybe_unused]]
19761976
19771977
19781978/* *
1979- * Compute max reduction of a tensor
1979+ * Compute max reduction of an operator
19801980 *
19811981 * Returns a tensor representing the max of all numbers in the reduction
19821982 *
1983- * @note This function uses the name rmax instead of max to not collide with the
1984- * element-wise operator max.
1985- *
19861983 * @tparam OutType
19871984 * Output data type
19881985 * @tparam InType
@@ -2007,13 +2004,10 @@ void __MATX_INLINE__ max_impl(OutType dest, const InType &in, cudaExecutor exec
20072004}
20082005
20092006/* *
2010- * Compute max reduction of a tensor
2007+ * Compute max reduction of an operator
20112008 *
20122009 * Returns a tensor representing the max of all numbers in the reduction
20132010 *
2014- * @note This function uses the name rmax instead of max to not collide with the
2015- * element-wise operator max.
2016- *
20172011 * @tparam OutType
20182012 * Output data type
20192013 * @tparam InType
@@ -2036,8 +2030,9 @@ void __MATX_INLINE__ max_impl(OutType dest, const InType &in, [[maybe_unused]] c
20362030 *lout = *std::max_element (lin, lin + TotalSize (in));
20372031 }
20382032 else {
2039- auto els = lend[1 ] - lbegin[0 ];
2040- for (index_t b = 0 ; b < els; b++) {
2033+ const index_t BATCHES = TotalSize (dest);
2034+ const index_t els = lend[0 ] - lbegin[0 ];
2035+ for (index_t b = 0 ; b < BATCHES; b++) {
20412036 lout[b] = *std::max_element (lin + lbegin[b], lin + lend[b]);
20422037 }
20432038 }
@@ -2084,9 +2079,9 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT
20842079}
20852080
20862081/* *
2087- * Compute maxn reduction of a tensor and returns value + index
2082+ * Compute max reduction of an operator and returns value + index
20882083 *
2089- * Returns a tensor with maximums and indices
2084+ * Returns a tensor with maximums and a tensor with indices
20902085 *
20912086 * @tparam OutType
20922087 * Output data type
@@ -2114,8 +2109,9 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT
21142109 *lout = cuda::std::max_element (lin, lin + TotalSize (in)) - lin;
21152110 }
21162111 else {
2117- auto els = lend[0 ] - lbegin[0 ];
2118- for (index_t b = 0 ; b < els; b++) {
2112+ const index_t BATCHES = TotalSize (dest);
2113+ const index_t els = lend[0 ] - lbegin[0 ];
2114+ for (index_t b = 0 ; b < BATCHES; b++) {
21192115 lout[b] = cuda::std::max_element (lin + lbegin[b], lin + lend[b]) - lin;
21202116 }
21212117 }
@@ -2130,7 +2126,7 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT
21302126
21312127
21322128/* *
2133- * Compute min reduction of a tensor
2129+ * Compute min reduction of an operator
21342130 *
21352131 * Returns a tensor representing the min of all numbers in the reduction
21362132 *
@@ -2158,13 +2154,10 @@ void __MATX_INLINE__ min_impl(OutType dest, const InType &in, cudaExecutor exec
21582154}
21592155
21602156/* *
2161- * Compute min reduction of a tensor
2157+ * Compute min reduction of an operator
21622158 *
21632159 * Returns a tensor representing the min of all numbers in the reduction
21642160 *
2165- * @note This function uses the name rmin instead of min to not collide with the
2166- * element-wise operator min.
2167- *
21682161 * @tparam OutType
21692162 * Output data type
21702163 * @tparam InType
@@ -2186,8 +2179,9 @@ void __MATX_INLINE__ min_impl(OutType dest, const InType &in, [[maybe_unused]] c
21862179 *lout = *std::min_element (lin, lin + TotalSize (in));
21872180 }
21882181 else {
2189- auto els = lend[1 ] - lbegin[0 ];
2190- for (index_t b = 0 ; b < els; b++) {
2182+ const index_t BATCHES = TotalSize (dest);
2183+ const index_t els = lend[0 ] - lbegin[0 ];
2184+ for (index_t b = 0 ; b < BATCHES; b++) {
21912185 lout[b] = *std::min_element (lin + lbegin[b], lin + lend[b]);
21922186 }
21932187 }
@@ -2234,6 +2228,53 @@ void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InT
22342228#endif
22352229}
22362230
2231+ /* *
2232+ * Compute min reduction of an operator and returns value + index
2233+ *
2234+ * Returns a tensor with minimums and a tensor with indices
2235+ *
2236+ * @tparam OutType
2237+ * Output data type
2238+ * @tparam TensorIndexType
2239+ * Output type storing indices
2240+ * @tparam InType
2241+ * Input data type
2242+ * @tparam MODE
2243+ * Host executor threads mode
2244+ *
2245+ * @param dest
2246+ * Destination view of reduction
2247+ * @param idest
2248+ * Destination for indices
2249+ * @param in
2250+ * Input data to reduce
2251+ * @param exec
2252+ * Single host executor
2253+ */
2254+ template <typename OutType, typename TensorIndexType, typename InType, ThreadsMode MODE>
2255+ void __MATX_INLINE__ argmin_impl (OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] const HostExecutor<MODE> &exec)
2256+ {
2257+ MATX_NVTX_START (" argmin_impl(" + get_type_str (in) + " )" , matx::MATX_NVTX_LOG_API)
2258+
2259+ auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { // index-reduction callback handed to ReduceInput below
2260+ if constexpr (OutType::Rank () == 0 ) { // rank-0 output: a single search over all TotalSize(in) elements
2261+ *lout = cuda::std::min_element (lin, lin + TotalSize (in)) - lin; // iterator difference, i.e. offset of the minimum from lin
2262+ }
2263+ else {
2264+ const index_t BATCHES = TotalSize (dest); // the loop below writes one lout entry per element of dest
2265+ const index_t els = lend[0 ] - lbegin[0 ]; // NOTE(review): `els` is never read in this block; candidate for removal (unused-variable warning)
2266+ for (index_t b = 0 ; b < BATCHES; b++) {
2267+ lout[b] = cuda::std::min_element (lin + lbegin[b], lin + lend[b]) - lin; // offset is measured from lin, not from lin + lbegin[b]
2268+ }
2269+ }
2270+ };
2271+
2272+ // This could be more efficient by not running two reductions to find the same values, but
2273+ // running them separately keeps the implementation short
2274+ ReduceInput (ft, idest, in); // presumably runs ft to populate idest with the indices; confirm against ReduceInput
2275+ min_impl (dest, in, exec); // separate reduction for the minimum values written to dest
2276+ }
2277+
22372278/* *
22382279 * Compute min and max reduction of an operator and returns value + index
22392280 *
@@ -2281,51 +2322,45 @@ void __MATX_INLINE__ argminmax_impl(OutType destmin, TensorIndexType &idestmin,
22812322}
22822323
22832324/* *
2284- * Compute min reduction of a tensor and returns value + index
2325+ * Compute min and max reduction of an operator and returns value + index
22852326 *
2286- * Returns a tensor with minimums and indices
2327+ * Returns tensors with minimums and indices, and maximums and indices
22872328 *
22882329 * @tparam OutType
22892330 * Output data type
22902331 * @tparam TensorIndexType
22912332 * Output type storing indices
22922333 * @tparam InType
22932334 * Input data type
2335+ * @tparam MODE
2336+ * Host executor threads mode
22942337 *
2295- * @param dest
2296- * Destination view of reduction
2297- * @param idest
2298- * Destination for indices
2338+ * @param destmin
2339+ * Destination view of min reduction
2340+ * @param idestmin
2341+ * Destination for min indices
2342+ * @param destmax
2343+ * Destination view of max reduction
2344+ * @param idestmax
2345+ * Destination for max indices
22992346 * @param in
23002347 * Input data to reduce
23012348 * @param exec
2302- * SIngle host executor
2349+ * Single host executor
23032350 */
23042351template <typename OutType, typename TensorIndexType, typename InType, ThreadsMode MODE>
2305- void __MATX_INLINE__ argmin_impl (OutType dest , TensorIndexType &idest , const InType &in, [[maybe_unused]] const HostExecutor<MODE> &exec)
2352+ void __MATX_INLINE__ argminmax_impl (OutType destmin , TensorIndexType &idestmin, OutType destmax, TensorIndexType &idestmax , const InType &in, [[maybe_unused]] const HostExecutor<MODE> &exec)
23062353{
2307- MATX_NVTX_START (" argmin_impl(" + get_type_str (in) + " )" , matx::MATX_NVTX_LOG_API)
2308-
2309- auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) {
2310- if constexpr (OutType::Rank () == 0 ) {
2311- *lout = cuda::std::min_element (lin, lin + TotalSize (in)) - lin;
2312- }
2313- else {
2314- auto els = lend[1 ] - lbegin[0 ];
2315- for (index_t b = 0 ; b < els; b++) {
2316- lout[b] = cuda::std::min_element (lin + lbegin[b], lin + lend[b]) - lin;
2317- }
2318- }
2319- };
2354+ static_assert (OutType::Rank () == TensorIndexType::Rank ());
2355+ MATX_NVTX_START (" argminmax_impl(" + get_type_str (in) + " )" , matx::MATX_NVTX_LOG_API)
23202356
2321- // This could be more efficient by not running two reductions to find the same values, but
2357+ // This could be more efficient by not running argmin and argmax separately but
23222358 // for brevity this is faster
2323- ReduceInput (ft, idest , in);
2324- min_impl (dest , in, exec);
2359+ argmin_impl (destmin, idestmin , in, exec );
2360+ argmax_impl (destmax, idestmax , in, exec);
23252361}
23262362
23272363
2328-
23292364/* *
23302365 * Find if any value is != 0
23312366 *
0 commit comments