Skip to content

Commit ecde931

Browse files
committed
Add std::transform_reduce
1 parent 12ca9ca commit ecde931

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Algorithms are added on an as-needed basis. If you need one that is not present
3131

3232
### `<numeric>`
3333
* [std::reduce](https://en.cppreference.com/w/cpp/algorithm/reduce)
34+
* [std::transform_reduce](https://en.cppreference.com/w/cpp/algorithm/transform_reduce) (C++17 only)
3435

3536
Note: All iterators must be random access.
3637

include/poolstl/numeric

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,57 @@ namespace std {
5353
typename std::iterator_traits<RandIt>::value_type{});
5454
}
5555

56+
#if POOLSTL_HAVE_CXX17_LIB
57+
/**
58+
* NOTE: Iterators are expected to be random access.
59+
* See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce
60+
*/
61+
template <class ExecPolicy, class RandIt1, class T, class BinaryReductionOp, class UnaryTransformOp>
62+
poolstl::internal::enable_if_poolstl_execution_policy<ExecPolicy, T>
63+
transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, T init,
64+
BinaryReductionOp reduce_op, UnaryTransformOp transform_op) {
65+
66+
auto futures = poolstl::internal::parallel_chunk_for(std::forward<ExecPolicy>(policy), first1, last1,
67+
[&init, &reduce_op, &transform_op](RandIt1 chunk_first1, RandIt1 chunk_last1) {
68+
return std::transform_reduce(chunk_first1, chunk_last1, init, reduce_op, transform_op);
69+
});
70+
71+
return poolstl::internal::cpp17::reduce(
72+
poolstl::internal::get_wrap(futures.begin()),
73+
poolstl::internal::get_wrap(futures.end()), init, reduce_op);
74+
}
75+
76+
/**
77+
* NOTE: Iterators are expected to be random access.
78+
* See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce
79+
*/
80+
template <class ExecPolicy, class RandIt1, class RandIt2, class T, class BinaryReductionOp, class BinaryTransformOp>
81+
poolstl::internal::enable_if_poolstl_execution_policy<ExecPolicy, T>
82+
transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init,
83+
BinaryReductionOp reduce_op, BinaryTransformOp transform_op) {
84+
85+
auto futures = poolstl::internal::parallel_chunk_for(std::forward<ExecPolicy>(policy), first1, last1, first2,
86+
[&init, &reduce_op, &transform_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt2 chunk_first2) {
87+
return std::transform_reduce(chunk_first1, chunk_last1, chunk_first2, init, reduce_op, transform_op);
88+
});
89+
90+
return poolstl::internal::cpp17::reduce(
91+
poolstl::internal::get_wrap(futures.begin()),
92+
poolstl::internal::get_wrap(futures.end()), init, reduce_op);
93+
}
94+
95+
/**
96+
* NOTE: Iterators are expected to be random access.
97+
* See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce
98+
*/
99+
template< class ExecPolicy, class RandIt1, class RandIt2, class T >
100+
poolstl::internal::enable_if_poolstl_execution_policy<ExecPolicy, T>
101+
transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init ) {
102+
return transform_reduce(std::forward<ExecPolicy>(policy),
103+
first1, last1, first2, init, std::plus<>(), std::multiplies<>());
104+
}
105+
#endif
106+
56107
}
57108

58109
#endif

tests/poolstl_test.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,38 @@ TEST_CASE("reduce", "[alg][numeric]") {
178178
}
179179
}
180180

181+
#if POOLSTL_HAVE_CXX17_LIB
182+
TEST_CASE("transform_reduce_1", "[alg][numeric]") {
183+
for (auto num_threads : test_thread_counts) {
184+
ttp::task_thread_pool pool(num_threads);
185+
186+
for (auto num_iters : test_arr_sizes) {
187+
auto v = iota_vector(num_iters);
188+
189+
auto doubler = [&](auto x) { return 2*x; };
190+
auto seq = std::transform_reduce(v.cbegin(), v.cend(), 0, std::plus<>(), doubler);
191+
auto par = std::transform_reduce(poolstl::par_pool(pool), v.cbegin(), v.cend(), 0, std::plus<>(), doubler);
192+
REQUIRE(seq == par);
193+
}
194+
}
195+
}
196+
197+
TEST_CASE("transform_reduce_2", "[alg][numeric]") {
198+
for (auto num_threads : test_thread_counts) {
199+
ttp::task_thread_pool pool(num_threads);
200+
201+
for (auto num_iters : test_arr_sizes) {
202+
auto v1 = iota_vector(num_iters);
203+
auto v2 = iota_vector(num_iters);
204+
205+
auto seq = std::transform_reduce(v1.cbegin(), v1.cend(), v2.cbegin(), 0);
206+
auto par = std::transform_reduce(poolstl::par_pool(pool), v1.cbegin(), v1.cend(), v2.cbegin(), 0);
207+
REQUIRE(seq == par);
208+
}
209+
}
210+
}
211+
#endif
212+
181213
TEST_CASE("default_pool", "[execution]") {
182214
std::atomic<int> sum{0};
183215
for (auto num_iters : test_arr_sizes) {

0 commit comments

Comments
 (0)