Skip to content

Commit 9037920

Browse files
committed
fft_conv
1 parent 85bff91 commit 9037920

File tree

1 file changed

+91
-121
lines changed

1 file changed

+91
-121
lines changed

examples/fft_conv.cu

Lines changed: 91 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -71,133 +71,103 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
7171
MATX_ENTER_HANDLER();
7272
using complex = cuda::std::complex<float>;
7373

74-
// index_t signal_size = 1ULL << 16;
75-
// index_t filter_size = 16;
76-
// index_t batches = 8;
77-
// index_t filtered_size = signal_size + filter_size - 1;
78-
// float separate_ms;
79-
// float fused_ms;
80-
// constexpr int iterations = 100;
81-
// cudaStream_t stream;
82-
// cudaStreamCreate(&stream);
83-
// cudaEvent_t start, stop;
84-
// cudaEventCreate(&start);
85-
// cudaEventCreate(&stop);
86-
87-
// // Create time domain buffers
88-
// auto sig_time = make_tensor<complex>({batches, signal_size});
89-
// auto filt_time = make_tensor<complex>({batches, filter_size});
90-
// auto time_out = make_tensor<complex>({batches, filtered_size});
91-
92-
// // Frequency domain buffers
93-
// auto sig_freq = make_tensor<complex>({batches, filtered_size});
94-
// auto filt_freq = make_tensor<complex>({batches, filtered_size});
95-
96-
// for (index_t b = 0; b < batches; b++) {
97-
// // Fill the time domain signals with data
98-
// for (index_t i = 0; i < signal_size; i++) {
99-
// sig_time(b,i) = {-1.0f * (2.0f * static_cast<float>(i % 2) + 1.0f) *
100-
// (static_cast<float>(i % 10) / 10.0f) +
101-
// 0.1f,
102-
// -1.0f * (static_cast<float>(i % 2) == 0.0f) *
103-
// (static_cast<float>(i % 10) / 5.0f) -
104-
// 0.1f};
105-
// }
106-
// for (index_t i = 0; i < filter_size; i++) {
107-
// filt_time(b,i) = {static_cast<float>(i) / static_cast<float>(filter_size),
108-
// static_cast<float>(-i) / static_cast<float>(filter_size) +
109-
// 0.5f};
110-
// }
111-
// }
112-
113-
// // Prefetch the data we just created
114-
// sig_time.PrefetchDevice(0);
115-
// filt_time.PrefetchDevice(0);
116-
117-
118-
// // Perform the FFT in-place on both signal and filter
119-
// for (int i = 0; i < iterations; i++) {
120-
// if (i == 1) {
121-
// cudaEventRecord(start, stream);
122-
// }
123-
// fft_impl(sig_freq, sig_time, 0, stream);
124-
// fft_impl(filt_freq, filt_time, 0, stream);
125-
126-
// (sig_freq = sig_freq * filt_freq).run(stream);
127-
128-
// // IFFT in-place
129-
// (sig_freq = ifft(sig_freq)).run(stream);
130-
// }
131-
132-
// cudaEventRecord(stop, stream);
133-
// cudaStreamSynchronize(stream);
134-
// cudaEventElapsedTime(&separate_ms, start, stop);
135-
136-
// for (int i = 0; i < iterations; i++) {
137-
// if (i == 1) {
138-
// cudaEventRecord(start, stream);
139-
// }
140-
// (sig_freq = ifft(fft(sig_time, filtered_size) * fft(filt_time, filtered_size))).run(stream);
141-
// }
74+
index_t signal_size = 1ULL << 16;
75+
index_t filter_size = 16;
76+
index_t batches = 8;
77+
index_t filtered_size = signal_size + filter_size - 1;
78+
float separate_ms;
79+
float fused_ms;
80+
constexpr int iterations = 100;
81+
cudaStream_t stream;
82+
cudaStreamCreate(&stream);
83+
cudaEvent_t start, stop;
84+
cudaEventCreate(&start);
85+
cudaEventCreate(&stop);
86+
87+
// Create time domain buffers
88+
auto sig_time = make_tensor<complex>({batches, signal_size});
89+
auto filt_time = make_tensor<complex>({batches, filter_size});
90+
auto time_out = make_tensor<complex>({batches, filtered_size});
91+
92+
// Frequency domain buffers
93+
auto sig_freq = make_tensor<complex>({batches, filtered_size});
94+
auto filt_freq = make_tensor<complex>({batches, filtered_size});
95+
96+
for (index_t b = 0; b < batches; b++) {
97+
// Fill the time domain signals with data
98+
for (index_t i = 0; i < signal_size; i++) {
99+
sig_time(b,i) = {-1.0f * (2.0f * static_cast<float>(i % 2) + 1.0f) *
100+
(static_cast<float>(i % 10) / 10.0f) +
101+
0.1f,
102+
-1.0f * (static_cast<float>(i % 2) == 0.0f) *
103+
(static_cast<float>(i % 10) / 5.0f) -
104+
0.1f};
105+
}
106+
for (index_t i = 0; i < filter_size; i++) {
107+
filt_time(b,i) = {static_cast<float>(i) / static_cast<float>(filter_size),
108+
static_cast<float>(-i) / static_cast<float>(filter_size) +
109+
0.5f};
110+
}
111+
}
112+
113+
// Prefetch the data we just created
114+
sig_time.PrefetchDevice(0);
115+
filt_time.PrefetchDevice(0);
116+
117+
118+
// Perform the FFT of both signal and filter from the time-domain buffers into the separate frequency-domain buffers
119+
for (int i = 0; i < iterations; i++) {
120+
if (i == 1) {
121+
cudaEventRecord(start, stream);
122+
}
123+
fft_impl(sig_freq, sig_time, 0, stream);
124+
fft_impl(filt_freq, filt_time, 0, stream);
125+
126+
(sig_freq = sig_freq * filt_freq).run(stream);
127+
128+
// IFFT in-place
129+
(sig_freq = ifft(sig_freq)).run(stream);
130+
}
131+
132+
cudaEventRecord(stop, stream);
133+
cudaStreamSynchronize(stream);
134+
cudaEventElapsedTime(&separate_ms, start, stop);
135+
136+
for (int i = 0; i < iterations; i++) {
137+
if (i == 1) {
138+
cudaEventRecord(start, stream);
139+
}
140+
(sig_freq = ifft(fft(sig_time, filtered_size) * fft(filt_time, filtered_size))).run(stream);
141+
}
142142

143-
// cudaEventRecord(stop, stream);
144-
// cudaStreamSynchronize(stream);
145-
// cudaEventElapsedTime(&fused_ms, start, stop);
143+
cudaEventRecord(stop, stream);
144+
cudaStreamSynchronize(stream);
145+
cudaEventElapsedTime(&fused_ms, start, stop);
146146

147-
// printf("FFT runtimes for separate = %.2f ms, fused = %.2f ms\n", separate_ms/(iterations-1), fused_ms/(iterations-1));
147+
printf("FFT runtimes for separate = %.2f ms, fused = %.2f ms\n", separate_ms/(iterations-1), fused_ms/(iterations-1));
148148

149-
// // Now the sig_freq view contains the full convolution result. Verify against
150-
// // a direct convolution. The conv1d function only accepts a 1D filter, so we
151-
// // create a sliced view here.
152-
// auto filt1 = filt_time.Slice<1>({0,0}, {matxDropDim, matxEnd});
153-
// (time_out = conv1d(sig_time, filt1, matxConvCorrMode_t::MATX_C_MODE_FULL)).run();
149+
// Now the sig_freq view contains the full convolution result. Verify against
150+
// a direct convolution. The conv1d function only accepts a 1D filter, so we
151+
// create a sliced view here.
152+
auto filt1 = filt_time.Slice<1>({0,0}, {matxDropDim, matxEnd});
153+
(time_out = conv1d(sig_time, filt1, matxConvCorrMode_t::MATX_C_MODE_FULL)).run();
154154

155-
// cudaStreamSynchronize(0);
155+
cudaStreamSynchronize(0);
156156

157-
// // Compare signals
158-
// for (index_t b = 0; b < batches; b++) {
159-
// for (index_t i = 0; i < filtered_size; i++) {
160-
// if (fabs(time_out(b,i).real() - sig_freq(b,i).real()) > 0.001 ||
161-
// fabs(time_out(b,i).imag() - sig_freq(b,i).imag()) > 0.001) {
162-
// std::cout <<
163-
// "Verification failed at item " << i << ". Direct=" << time_out(b,i).real() << " " << time_out(b,i).imag() << ", FFT=" <<
164-
// sig_freq(b,i).real() << " " <<
165-
// sig_freq(b,i).imag() << "\n";
166-
// return -1;
167-
// }
168-
// }
169-
// }
157+
// Compare signals
158+
for (index_t b = 0; b < batches; b++) {
159+
for (index_t i = 0; i < filtered_size; i++) {
160+
if (fabs(time_out(b,i).real() - sig_freq(b,i).real()) > 0.001 ||
161+
fabs(time_out(b,i).imag() - sig_freq(b,i).imag()) > 0.001) {
162+
std::cout <<
163+
"Verification failed at item " << i << ". Direct=" << time_out(b,i).real() << " " << time_out(b,i).imag() << ", FFT=" <<
164+
sig_freq(b,i).real() << " " <<
165+
sig_freq(b,i).imag() << "\n";
166+
return -1;
167+
}
168+
}
169+
}
170170

171-
{
172-
constexpr index_t m = 16;
173-
constexpr index_t k = 32;
174-
constexpr index_t n = 64;
175-
constexpr index_t b = 8;
176-
tensor_t<float, 3> a3{{b, m, k}};
177-
tensor_t<float, 3> b3{{b, k, n}};
178-
tensor_t<float, 3> c3{{b, m, n}};
179-
const int axis[2] = {2, 1};
180-
std::array<int, 3> perm({0, 2, 1});
181-
182-
auto ai = make_tensor<float>({b, k, m});
183-
auto bi = make_tensor<float>({b, n, k});
184-
auto ci = make_tensor<float>({b, n, m});
185-
186-
auto ap = permute(ai, perm);
187-
auto bp = permute(bi, perm);
188-
auto cp = permute(ci, perm);
189-
190-
// copy data into permuted inputs
191-
(ap = a3).run();
192-
(bp = b3).run();
193-
194-
// Perform a GEMM with the last two dimensions permuted
195-
(ci = matmul(ai, bi, axis)).run();
196-
// example-end matmul-test-6
197-
198-
// copy result from permuted output
199-
(c3 = cp).run();
200-
}
201171

202172
std::cout << "Verification successful" << std::endl;
203173

0 commit comments

Comments
 (0)