@@ -48,25 +48,29 @@ using namespace matx;
4848 * boilerplate code around the original expression. This custom operator can then be used either alone or inside
4949 * other arithmetic expressions, and only a single load is issued for each tensor.
5050 *
51- * This example uses the Black-Scholes equtation to demonstrate the two ways to implement the equation in MatX, and
52- * shows the performance difference.
51+ * This example uses the Black-Scholes equation to demonstrate three ways to implement the equation in MatX, and
52+ * shows the performance difference between them. The three ways are:
53+ * 1. Using a custom operator
54+ * 2. Using a lambda function via apply()
55+ * 3. Using a MatX expression
56+ *
57+ * Which method to use depends on the use case, but the lambda function is preferred for simplicity and readability.
5358 */
5459
5560/* Custom operator */
56- template <class O , class I1 >
57- class BlackScholes : public BaseOp <BlackScholes<O, I1>> {
61+ template <class I1 >
62+ class BlackScholes : public BaseOp <BlackScholes<I1>> {
5863private:
59- O out_;
6064 I1 V_, S_, K_, r_, T_;
6165
6266public:
6367 using matxop = bool ;
6468
65- BlackScholes (O out, I1 K, I1 V, I1 S, I1 r, I1 T)
66- : out_(out), V_(V), S_(S), K_(K), r_(r), T_(T) {}
69+ BlackScholes (I1 K, I1 V, I1 S, I1 r, I1 T)
70+ : V_(V), S_(S), K_(K), r_(r), T_(T) {}
6771
6872 template <detail::ElementsPerThread EPT>
69- __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void operator ()(index_t idx)
73+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto operator ()(index_t idx) const
7074 {
7175 auto V = V_ (idx);
7276 auto K = K_ (idx);
@@ -81,27 +85,32 @@ public:
8185 auto cdf_d2 = normcdff (d2);
8286 auto expRT = exp (-1 .f * r * T);
8387
84- out_ (idx) = S * cdf_d1 - K * expRT * cdf_d2;
88+ return S * cdf_d1 - K * expRT * cdf_d2;
8589 }
8690
8791 __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void operator ()(index_t idx) {
8892 return this ->operator ()<detail::ElementsPerThread::ONE>(idx);
8993 }
9094
91- __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size (uint32_t i) const { return out_ .Size (i); }
92- static constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ int32_t Rank () { return O ::Rank (); }
95+ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size (uint32_t i) const { return V_ .Size (i); }
96+ static constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ int32_t Rank () { return I1 ::Rank (); }
9397
9498 template <detail::OperatorCapability Cap>
9599 __MATX_INLINE__ __MATX_HOST__ auto get_capability () const {
96- auto self_has_cap = detail::capability_attributes<Cap>::default_value;
97- return detail::combine_capabilities<Cap>(
98- self_has_cap,
99- detail::get_operator_capability<Cap>(V_),
100- detail::get_operator_capability<Cap>(S_),
101- detail::get_operator_capability<Cap>(K_),
102- detail::get_operator_capability<Cap>(r_),
103- detail::get_operator_capability<Cap>(T_)
104- );
100+ // Don't support vectorization yet
101+ if constexpr (Cap == detail::OperatorCapability::ELEMENTS_PER_THREAD) {
102+ return detail::ElementsPerThread::ONE;
103+ } else {
104+ auto self_has_cap = detail::capability_attributes<Cap>::default_value;
105+ return detail::combine_capabilities<Cap>(
106+ self_has_cap,
107+ detail::get_operator_capability<Cap>(V_),
108+ detail::get_operator_capability<Cap>(S_),
109+ detail::get_operator_capability<Cap>(K_),
110+ detail::get_operator_capability<Cap>(r_),
111+ detail::get_operator_capability<Cap>(T_)
112+ );
113+ }
105114 }
106115};
107116
@@ -132,7 +141,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
132141 using dtype = float ;
133142
134143 index_t input_size = 100000000 ;
135- constexpr uint32_t num_iterations = 1 ;
144+ constexpr uint32_t num_iterations = 100 ;
136145 float time_ms;
137146
138147 tensor_t <dtype, 1 > K_tensor{{input_size}};
@@ -141,12 +150,20 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
141150 tensor_t <dtype, 1 > r_tensor{{input_size}};
142151 tensor_t <dtype, 1 > T_tensor{{input_size}};
143152 tensor_t <dtype, 1 > output_tensor{{input_size}};
153+ tensor_t <dtype, 1 > output_tensor2{{input_size}};
154+ tensor_t <dtype, 1 > output_tensor3{{input_size}};
155+
156+ (K_tensor = random<float >({input_size}, UNIFORM)).run ();
157+ (S_tensor = random<float >({input_size}, UNIFORM)).run ();
158+ (V_tensor = random<float >({input_size}, UNIFORM)).run ();
159+ (r_tensor = random<float >({input_size}, UNIFORM)).run ();
160+ (T_tensor = random<float >({input_size}, UNIFORM)).run ();
144161
145162 cudaStream_t stream;
146163 cudaStreamCreate (&stream);
147164 cudaExecutor exec{stream};
148165
149- compute_black_scholes_matx (K_tensor, S_tensor, V_tensor, r_tensor, T_tensor, output_tensor, exec);
166+ // compute_black_scholes_matx(K_tensor, S_tensor, V_tensor, r_tensor, T_tensor, output_tensor, exec);
150167
151168 cudaEvent_t start, stop;
152169 cudaEventCreate (&start);
@@ -159,6 +176,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
159176 }
160177 cudaEventRecord (stop, stream);
161178 exec.sync ();
179+
162180 cudaEventElapsedTime (&time_ms, start, stop);
163181
164182 printf (" Time without custom operator = %.2fms per iteration\n " ,
@@ -167,10 +185,11 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
167185 cudaEventRecord (start, stream);
168186 // Time non-operator version
169187 for (uint32_t i = 0 ; i < num_iterations; i++) {
170- BlackScholes (output_tensor, K_tensor, V_tensor, S_tensor, r_tensor, T_tensor).run (exec);
188+ (output_tensor2 = BlackScholes ( K_tensor, V_tensor, S_tensor, r_tensor, T_tensor) ).run (exec);
171189 }
172190 cudaEventRecord (stop, stream);
173191 exec.sync ();
192+
174193 cudaEventElapsedTime (&time_ms, start, stop);
175194 printf (" Time with custom operator = %.2fms per iteration\n " ,
176195 time_ms / num_iterations);
@@ -192,15 +211,36 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
192211
193212 cudaEventRecord (start, stream);
194213 for (uint32_t i = 0 ; i < num_iterations; i++) {
195- (output_tensor = matx::apply (bs_lambda, K_tensor, S_tensor, V_tensor, r_tensor, T_tensor)).run (exec);
214+ (output_tensor3 = matx::apply (bs_lambda, K_tensor, S_tensor, V_tensor, r_tensor, T_tensor)).run (exec);
196215 }
216+
197217 cudaEventRecord (stop, stream);
198218 exec.sync ();
219+
199220 cudaEventElapsedTime (&time_ms, start, stop);
200221 printf (" Time with lambda = %.2fms per iteration\n " ,
201222 time_ms / num_iterations);
202223
203-
224+ // Verify all 3 outputs match within 1e-6 using operator() (Managed Memory)
225+ bool all_match = true ;
226+ constexpr float tol = 1e-6f ;
227+ auto n = K_tensor.Size (0 );
228+
229+ for (index_t i = 0 ; i < n; i++) {
230+ float v1 = output_tensor (i);
231+ float v2 = output_tensor2 (i);
232+ float v3 = output_tensor3 (i);
233+ if (fabsf (v1 - v2) > tol || fabsf (v1 - v3) > tol || fabsf (v2 - v3) > tol) {
234+ printf (" Mismatch at idx %lld: v1=%.8f v2=%.8f v3=%.8f\n " , i, v1, v2, v3);
235+ all_match = false ;
236+ break ;
237+ }
238+ }
239+ if (all_match) {
240+ printf (" All outputs match within %.1e tolerance.\n " , tol);
241+ } else {
242+ printf (" Outputs do NOT match within %.1e tolerance!\n " , tol);
243+ }
204244
205245 cudaEventDestroy (start);
206246 cudaEventDestroy (stop);
0 commit comments