Skip to content

Commit e226725

Browse files
committed
Update Black-Scholes to add apply() example
1 parent ea07b17 commit e226725

File tree

2 files changed

+87
-38
lines changed

2 files changed

+87
-38
lines changed

examples/black_scholes.cu

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,29 @@ using namespace matx;
4848
* boilerplate code around the original expression. This custom operator can then be used either alone or inside
4949
* other arithmetic expressions, and only a single load is issued for each tensor.
5050
*
51-
* This example uses the Black-Scholes equtation to demonstrate the two ways to implement the equation in MatX, and
52-
* shows the performance difference.
51+
* This example uses the Black-Scholes equation to demonstrate three ways to implement the equation in MatX, and
52+
* shows the performance difference between them. The three ways are:
53+
* 1. Using a custom operator
54+
* 2. Using a lambda function via apply()
55+
* 3. Using a MatX expression
56+
*
57+
* Which method to use depends on the use case, but the lambda function is preferred for simplicity and readability.
5358
*/
5459

5560
/* Custom operator */
56-
template <class O, class I1>
57-
class BlackScholes : public BaseOp<BlackScholes<O, I1>> {
61+
template <class I1>
62+
class BlackScholes : public BaseOp<BlackScholes<I1>> {
5863
private:
59-
O out_;
6064
I1 V_, S_, K_, r_, T_;
6165

6266
public:
6367
using matxop = bool;
6468

65-
BlackScholes(O out, I1 K, I1 V, I1 S, I1 r, I1 T)
66-
: out_(out), V_(V), S_(S), K_(K), r_(r), T_(T) {}
69+
BlackScholes(I1 K, I1 V, I1 S, I1 r, I1 T)
70+
: V_(V), S_(S), K_(K), r_(r), T_(T) {}
6771

6872
template <detail::ElementsPerThread EPT>
69-
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void operator()(index_t idx)
73+
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto operator()(index_t idx) const
7074
{
7175
auto V = V_(idx);
7276
auto K = K_(idx);
@@ -81,27 +85,32 @@ public:
8185
auto cdf_d2 = normcdff(d2);
8286
auto expRT = exp(-1.f * r * T);
8387

84-
out_(idx) = S * cdf_d1 - K * expRT * cdf_d2;
88+
return S * cdf_d1 - K * expRT * cdf_d2;
8589
}
8690

8791
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void operator()(index_t idx) {
8892
return this->operator()<detail::ElementsPerThread::ONE>(idx);
8993
}
9094

91-
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(uint32_t i) const { return out_.Size(i); }
92-
static constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { return O::Rank(); }
95+
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(uint32_t i) const { return V_.Size(i); }
96+
static constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { return I1::Rank(); }
9397

9498
template <detail::OperatorCapability Cap>
9599
__MATX_INLINE__ __MATX_HOST__ auto get_capability() const {
96-
auto self_has_cap = detail::capability_attributes<Cap>::default_value;
97-
return detail::combine_capabilities<Cap>(
98-
self_has_cap,
99-
detail::get_operator_capability<Cap>(V_),
100-
detail::get_operator_capability<Cap>(S_),
101-
detail::get_operator_capability<Cap>(K_),
102-
detail::get_operator_capability<Cap>(r_),
103-
detail::get_operator_capability<Cap>(T_)
104-
);
100+
// Don't support vectorization yet
101+
if constexpr (Cap == detail::OperatorCapability::ELEMENTS_PER_THREAD) {
102+
return detail::ElementsPerThread::ONE;
103+
} else {
104+
auto self_has_cap = detail::capability_attributes<Cap>::default_value;
105+
return detail::combine_capabilities<Cap>(
106+
self_has_cap,
107+
detail::get_operator_capability<Cap>(V_),
108+
detail::get_operator_capability<Cap>(S_),
109+
detail::get_operator_capability<Cap>(K_),
110+
detail::get_operator_capability<Cap>(r_),
111+
detail::get_operator_capability<Cap>(T_)
112+
);
113+
}
105114
}
106115
};
107116

@@ -132,7 +141,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
132141
using dtype = float;
133142

134143
index_t input_size = 100000000;
135-
constexpr uint32_t num_iterations = 1;
144+
constexpr uint32_t num_iterations = 100;
136145
float time_ms;
137146

138147
tensor_t<dtype, 1> K_tensor{{input_size}};
@@ -141,12 +150,20 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
141150
tensor_t<dtype, 1> r_tensor{{input_size}};
142151
tensor_t<dtype, 1> T_tensor{{input_size}};
143152
tensor_t<dtype, 1> output_tensor{{input_size}};
153+
tensor_t<dtype, 1> output_tensor2{{input_size}};
154+
tensor_t<dtype, 1> output_tensor3{{input_size}};
155+
156+
(K_tensor = random<float>({input_size}, UNIFORM)).run();
157+
(S_tensor = random<float>({input_size}, UNIFORM)).run();
158+
(V_tensor = random<float>({input_size}, UNIFORM)).run();
159+
(r_tensor = random<float>({input_size}, UNIFORM)).run();
160+
(T_tensor = random<float>({input_size}, UNIFORM)).run();
144161

145162
cudaStream_t stream;
146163
cudaStreamCreate(&stream);
147164
cudaExecutor exec{stream};
148165

149-
compute_black_scholes_matx(K_tensor, S_tensor, V_tensor, r_tensor, T_tensor, output_tensor, exec);
166+
//compute_black_scholes_matx(K_tensor, S_tensor, V_tensor, r_tensor, T_tensor, output_tensor, exec);
150167

151168
cudaEvent_t start, stop;
152169
cudaEventCreate(&start);
@@ -159,6 +176,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
159176
}
160177
cudaEventRecord(stop, stream);
161178
exec.sync();
179+
162180
cudaEventElapsedTime(&time_ms, start, stop);
163181

164182
printf("Time without custom operator = %.2fms per iteration\n",
@@ -167,10 +185,11 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
167185
cudaEventRecord(start, stream);
168186
// Time non-operator version
169187
for (uint32_t i = 0; i < num_iterations; i++) {
170-
BlackScholes(output_tensor, K_tensor, V_tensor, S_tensor, r_tensor, T_tensor).run(exec);
188+
(output_tensor2 = BlackScholes(K_tensor, V_tensor, S_tensor, r_tensor, T_tensor)).run(exec);
171189
}
172190
cudaEventRecord(stop, stream);
173191
exec.sync();
192+
174193
cudaEventElapsedTime(&time_ms, start, stop);
175194
printf("Time with custom operator = %.2fms per iteration\n",
176195
time_ms / num_iterations);
@@ -192,15 +211,36 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
192211

193212
cudaEventRecord(start, stream);
194213
for (uint32_t i = 0; i < num_iterations; i++) {
195-
(output_tensor = matx::apply(bs_lambda, K_tensor, S_tensor, V_tensor, r_tensor, T_tensor)).run(exec);
214+
(output_tensor3 = matx::apply(bs_lambda, K_tensor, S_tensor, V_tensor, r_tensor, T_tensor)).run(exec);
196215
}
216+
197217
cudaEventRecord(stop, stream);
198218
exec.sync();
219+
199220
cudaEventElapsedTime(&time_ms, start, stop);
200221
printf("Time with lambda = %.2fms per iteration\n",
201222
time_ms / num_iterations);
202223

203-
224+
// Verify all 3 outputs match within 1e-6 using operator() (Managed Memory)
225+
bool all_match = true;
226+
constexpr float tol = 1e-6f;
227+
auto n = K_tensor.Size(0);
228+
229+
for (index_t i = 0; i < n; i++) {
230+
float v1 = output_tensor(i);
231+
float v2 = output_tensor2(i);
232+
float v3 = output_tensor3(i);
233+
if (fabsf(v1 - v2) > tol || fabsf(v1 - v3) > tol || fabsf(v2 - v3) > tol) {
234+
printf("Mismatch at idx %lld: v1=%.8f v2=%.8f v3=%.8f\n", i, v1, v2, v3);
235+
all_match = false;
236+
break;
237+
}
238+
}
239+
if (all_match) {
240+
printf("All outputs match within %.1e tolerance.\n", tol);
241+
} else {
242+
printf("Outputs do NOT match within %.1e tolerance!\n", tol);
243+
}
204244

205245
cudaEventDestroy(start);
206246
cudaEventDestroy(stop);

include/matx/operators/apply.h

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,7 @@ namespace matx
7171
template <ElementsPerThread EPT, typename... Is>
7272
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
7373
{
74-
if constexpr (EPT == ElementsPerThread::ONE) {
75-
return apply_impl(cuda::std::index_sequence_for<Ops...>{}, indices...);
76-
} else {
77-
return Vector<value_type, static_cast<size_t>(EPT)>();
78-
}
74+
return apply_impl<EPT>(cuda::std::index_sequence_for<Ops...>{}, indices...);
7975
}
8076

8177
template <typename... Is>
@@ -86,12 +82,8 @@ namespace matx
8682

8783
template <OperatorCapability Cap>
8884
__MATX_INLINE__ __MATX_HOST__ auto get_capability() const {
89-
if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) {
90-
return ElementsPerThread::ONE;
91-
} else {
92-
auto self_has_cap = capability_attributes<Cap>::default_value;
93-
return combine_capabilities_tuple<Cap>(self_has_cap, ops_, cuda::std::index_sequence_for<Ops...>{});
94-
}
85+
auto self_has_cap = capability_attributes<Cap>::default_value;
86+
return combine_capabilities_tuple<Cap>(self_has_cap, ops_, cuda::std::index_sequence_for<Ops...>{});
9587
}
9688

9789
template <typename ShapeType, typename Executor>
@@ -126,11 +118,28 @@ namespace matx
126118
cuda::std::tuple<typename detail::base_type_t<Ops>...> ops_;
127119
cuda::std::array<index_t, first_op_type::Rank()> sizes_;
128120
// Helper to apply the lambda function to all operators
129-
template <size_t... Is, typename... Indices>
121+
template <ElementsPerThread EPT, size_t... Is, typename... Indices>
130122
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) apply_impl(
131123
cuda::std::index_sequence<Is...>, Indices... indices) const
132124
{
133-
return func_(cuda::std::get<Is>(ops_)(indices...)...);
125+
using out_type = decltype(cuda::std::get<0>(ops_).template operator()<EPT>(indices...));
126+
if constexpr (is_vector_v<out_type>) {
127+
// Each operator returns a vector, so call operator() once per operator to get the vectors
128+
auto op_results = cuda::std::make_tuple(cuda::std::get<Is>(ops_).template operator()<EPT>(indices...)...);
129+
130+
// Deduce the result type by calling func_ on scalar elements
131+
using result_element_type = decltype(func_(cuda::std::get<Is>(op_results).data[0]...));
132+
Vector<result_element_type, static_cast<int>(EPT)> result;
133+
134+
// Unroll loop to call func_ on each element of the vectors
135+
#pragma unroll
136+
for (int i = 0; i < static_cast<int>(EPT); i++) {
137+
result.data[i] = func_(cuda::std::get<Is>(op_results).data[i]...);
138+
}
139+
return result;
140+
} else {
141+
return func_(cuda::std::get<Is>(ops_).template operator()<EPT>(indices...)...);
142+
}
134143
}
135144

136145
// Helper to call PreRun on all operators

0 commit comments

Comments
 (0)