@@ -71,133 +71,103 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
7171 MATX_ENTER_HANDLER ();
7272 using complex = cuda::std::complex <float >;
7373
74- // index_t signal_size = 1ULL << 16;
75- // index_t filter_size = 16;
76- // index_t batches = 8;
77- // index_t filtered_size = signal_size + filter_size - 1;
78- // float separate_ms;
79- // float fused_ms;
80- // constexpr int iterations = 100;
81- // cudaStream_t stream;
82- // cudaStreamCreate(&stream);
83- // cudaEvent_t start, stop;
84- // cudaEventCreate(&start);
85- // cudaEventCreate(&stop);
86-
87- // // Create time domain buffers
88- // auto sig_time = make_tensor<complex>({batches, signal_size});
89- // auto filt_time = make_tensor<complex>({batches, filter_size});
90- // auto time_out = make_tensor<complex>({batches, filtered_size});
91-
92- // // Frequency domain buffers
93- // auto sig_freq = make_tensor<complex>({batches, filtered_size});
94- // auto filt_freq = make_tensor<complex>({batches, filtered_size});
95-
96- // for (index_t b = 0; b < batches; b++) {
97- // // Fill the time domain signals with data
98- // for (index_t i = 0; i < signal_size; i++) {
99- // sig_time(b,i) = {-1.0f * (2.0f * static_cast<float>(i % 2) + 1.0f) *
100- // (static_cast<float>(i % 10) / 10.0f) +
101- // 0.1f,
102- // -1.0f * (static_cast<float>(i % 2) == 0.0f) *
103- // (static_cast<float>(i % 10) / 5.0f) -
104- // 0.1f};
105- // }
106- // for (index_t i = 0; i < filter_size; i++) {
107- // filt_time(b,i) = {static_cast<float>(i) / static_cast<float>(filter_size),
108- // static_cast<float>(-i) / static_cast<float>(filter_size) +
109- // 0.5f};
110- // }
111- // }
112-
113- // // Prefetch the data we just created
114- // sig_time.PrefetchDevice(0);
115- // filt_time.PrefetchDevice(0);
116-
117-
118- // // Perform the FFT in-place on both signal and filter
119- // for (int i = 0; i < iterations; i++) {
120- // if (i == 1) {
121- // cudaEventRecord(start, stream);
122- // }
123- // fft_impl(sig_freq, sig_time, 0, stream);
124- // fft_impl(filt_freq, filt_time, 0, stream);
125-
126- // (sig_freq = sig_freq * filt_freq).run(stream);
127-
128- // // IFFT in-place
129- // (sig_freq = ifft(sig_freq)).run(stream);
130- // }
131-
132- // cudaEventRecord(stop, stream);
133- // cudaStreamSynchronize(stream);
134- // cudaEventElapsedTime(&separate_ms, start, stop);
135-
136- // for (int i = 0; i < iterations; i++) {
137- // if (i == 1) {
138- // cudaEventRecord(start, stream);
139- // }
140- // (sig_freq = ifft(fft(sig_time, filtered_size) * fft(filt_time, filtered_size))).run(stream);
141- // }
74+ index_t signal_size = 1ULL << 16 ;
75+ index_t filter_size = 16 ;
76+ index_t batches = 8 ;
77+ index_t filtered_size = signal_size + filter_size - 1 ;
78+ float separate_ms;
79+ float fused_ms;
80+ constexpr int iterations = 100 ;
81+ cudaStream_t stream;
82+ cudaStreamCreate (&stream);
83+ cudaEvent_t start, stop;
84+ cudaEventCreate (&start);
85+ cudaEventCreate (&stop);
86+
87+ // Create time domain buffers
88+ auto sig_time = make_tensor<complex >({batches, signal_size});
89+ auto filt_time = make_tensor<complex >({batches, filter_size});
90+ auto time_out = make_tensor<complex >({batches, filtered_size});
91+
92+ // Frequency domain buffers
93+ auto sig_freq = make_tensor<complex >({batches, filtered_size});
94+ auto filt_freq = make_tensor<complex >({batches, filtered_size});
95+
96+ for (index_t b = 0 ; b < batches; b++) {
97+ // Fill the time domain signals with data
98+ for (index_t i = 0 ; i < signal_size; i++) {
99+ sig_time (b,i) = {-1 .0f * (2 .0f * static_cast <float >(i % 2 ) + 1 .0f ) *
100+ (static_cast <float >(i % 10 ) / 10 .0f ) +
101+ 0 .1f ,
102+ -1 .0f * (static_cast <float >(i % 2 ) == 0 .0f ) *
103+ (static_cast <float >(i % 10 ) / 5 .0f ) -
104+ 0 .1f };
105+ }
106+ for (index_t i = 0 ; i < filter_size; i++) {
107+ filt_time (b,i) = {static_cast <float >(i) / static_cast <float >(filter_size),
108+ static_cast <float >(-i) / static_cast <float >(filter_size) +
109+ 0 .5f };
110+ }
111+ }
112+
113+ // Prefetch the data we just created
114+ sig_time.PrefetchDevice (0 );
115+ filt_time.PrefetchDevice (0 );
116+
117+
118+ // Perform the FFT in-place on both signal and filter
119+ for (int i = 0 ; i < iterations; i++) {
120+ if (i == 1 ) {
121+ cudaEventRecord (start, stream);
122+ }
123+ fft_impl (sig_freq, sig_time, 0 , stream);
124+ fft_impl (filt_freq, filt_time, 0 , stream);
125+
126+ (sig_freq = sig_freq * filt_freq).run (stream);
127+
128+ // IFFT in-place
129+ (sig_freq = ifft (sig_freq)).run (stream);
130+ }
131+
132+ cudaEventRecord (stop, stream);
133+ cudaStreamSynchronize (stream);
134+ cudaEventElapsedTime (&separate_ms, start, stop);
135+
136+ for (int i = 0 ; i < iterations; i++) {
137+ if (i == 1 ) {
138+ cudaEventRecord (start, stream);
139+ }
140+ (sig_freq = ifft (fft (sig_time, filtered_size) * fft (filt_time, filtered_size))).run (stream);
141+ }
142142
143- // cudaEventRecord(stop, stream);
144- // cudaStreamSynchronize(stream);
145- // cudaEventElapsedTime(&fused_ms, start, stop);
143+ cudaEventRecord (stop, stream);
144+ cudaStreamSynchronize (stream);
145+ cudaEventElapsedTime (&fused_ms, start, stop);
146146
147- // printf("FFT runtimes for separate = %.2f ms, fused = %.2f ms\n", separate_ms/(iterations-1), fused_ms/(iterations-1));
147+ printf (" FFT runtimes for separate = %.2f ms, fused = %.2f ms\n " , separate_ms/(iterations-1 ), fused_ms/(iterations-1 ));
148148
149- // // Now the sig_freq view contains the full convolution result. Verify against
150- // // a direct convolution. The conv1d function only accepts a 1D filter, so we
151- // // create a sliced view here.
152- // auto filt1 = filt_time.Slice<1>({0,0}, {matxDropDim, matxEnd});
153- // (time_out = conv1d(sig_time, filt1, matxConvCorrMode_t::MATX_C_MODE_FULL)).run();
149+ // Now the sig_freq view contains the full convolution result. Verify against
150+ // a direct convolution. The conv1d function only accepts a 1D filter, so we
151+ // create a sliced view here.
152+ auto filt1 = filt_time.Slice <1 >({0 ,0 }, {matxDropDim, matxEnd});
153+ (time_out = conv1d (sig_time, filt1, matxConvCorrMode_t::MATX_C_MODE_FULL)).run ();
154154
155- // cudaStreamSynchronize(0);
155+ cudaStreamSynchronize (0 );
156156
157- // // Compare signals
158- // for (index_t b = 0; b < batches; b++) {
159- // for (index_t i = 0; i < filtered_size; i++) {
160- // if (fabs(time_out(b,i).real() - sig_freq(b,i).real()) > 0.001 ||
161- // fabs(time_out(b,i).imag() - sig_freq(b,i).imag()) > 0.001) {
162- // std::cout <<
163- // "Verification failed at item " << i << ". Direct=" << time_out(b,i).real() << " " << time_out(b,i).imag() << ", FFT=" <<
164- // sig_freq(b,i).real() << " " <<
165- // sig_freq(b,i).imag() << "\n";
166- // return -1;
167- // }
168- // }
169- // }
157+ // Compare signals
158+ for (index_t b = 0 ; b < batches; b++) {
159+ for (index_t i = 0 ; i < filtered_size; i++) {
160+ if (fabs (time_out (b,i).real () - sig_freq (b,i).real ()) > 0.001 ||
161+ fabs (time_out (b,i).imag () - sig_freq (b,i).imag ()) > 0.001 ) {
162+ std::cout <<
163+ " Verification failed at item " << i << " . Direct=" << time_out (b,i).real () << " " << time_out (b,i).imag () << " , FFT=" <<
164+ sig_freq (b,i).real () << " " <<
165+ sig_freq (b,i).imag () << " \n " ;
166+ return -1 ;
167+ }
168+ }
169+ }
170170
171- {
172- constexpr index_t m = 16 ;
173- constexpr index_t k = 32 ;
174- constexpr index_t n = 64 ;
175- constexpr index_t b = 8 ;
176- tensor_t <float , 3 > a3{{b, m, k}};
177- tensor_t <float , 3 > b3{{b, k, n}};
178- tensor_t <float , 3 > c3{{b, m, n}};
179- const int axis[2 ] = {2 , 1 };
180- std::array<int , 3 > perm ({0 , 2 , 1 });
181-
182- auto ai = make_tensor<float >({b, k, m});
183- auto bi = make_tensor<float >({b, n, k});
184- auto ci = make_tensor<float >({b, n, m});
185-
186- auto ap = permute (ai, perm);
187- auto bp = permute (bi, perm);
188- auto cp = permute (ci, perm);
189-
190- // copy data into permuted inputs
191- (ap = a3).run ();
192- (bp = b3).run ();
193-
194- // Perform a GEMM with the last two dimensions permuted
195- (ci = matmul (ai, bi, axis)).run ();
196- // example-end matmul-test-6
197-
198- // copy result from permuted output
199- (c3 = cp).run ();
200- }
201171
202172 std::cout << " Verification successful" << std::endl;
203173
0 commit comments