Skip to content

Commit 16e985b

Browse files
authored
Add smem-based polyphase channelizer kernel (#613)
Add a new polyphase channelizer kernel that utilizes shared memory for both the filter coefficients and input samples. This kernel is automatically selected in cases where the corresponding data will fit in shared memory.
1 parent ff37ea0 commit 16e985b

File tree

3 files changed

+223
-2
lines changed

3 files changed

+223
-2
lines changed

include/matx/kernels/channelize_poly.cuh

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,148 @@ __global__ void ChannelizePoly1D(OutType output, InType input, FilterType filter
174174
}
175175
}
176176

177+
// This kernel works in cases where the full filter (with potentially some zero padding) and
178+
// the inputs required to compute elems_per_channel_per_cta outputs all fit into shared memory.
179+
template <typename OutType, typename InType, typename FilterType>
180+
__global__ void ChannelizePoly1D_Smem(OutType output, InType input, FilterType filter, index_t elems_per_channel_per_cta)
181+
{
182+
using output_t = typename OutType::scalar_type;
183+
using input_t = typename InType::scalar_type;
184+
using filter_t = typename FilterType::scalar_type;
185+
186+
extern __shared__ uint8_t __attribute((aligned(16))) smem_dyn_align16[];
187+
188+
constexpr int InRank = InType::Rank();
189+
constexpr int OutRank = OutType::Rank();
190+
constexpr int ChannelRank = OutRank-1;
191+
constexpr int OutElemRank = OutRank-2;
192+
193+
const index_t input_len = input.Size(InRank-1);
194+
const index_t output_len_per_channel = output.Size(OutElemRank);
195+
// If the filter fits into shared memory, then a 32-bit index is sufficient. One
196+
// edge case exception would be num_channels > 2^32-1, but with a small filter
197+
// implicitly padded with zeros. We assume that the kernel selection logic
198+
// considers the size of the zero-padded filter since that is what we actually
199+
// store in shared memory.
200+
const uint32_t num_channels = static_cast<uint32_t>(output.Size(ChannelRank));
201+
const uint32_t filter_full_len = static_cast<uint32_t>(filter.Size(0));
202+
const uint32_t filter_phase_len = static_cast<uint32_t>((filter_full_len + num_channels - 1) / num_channels);
203+
204+
filter_t *smem_h = reinterpret_cast<filter_t *>(smem_dyn_align16);
205+
size_t smem_input_offset = sizeof(filter_t) * filter_phase_len * num_channels;
206+
if (smem_input_offset % sizeof(input_t)) {
207+
smem_input_offset += sizeof(input_t) - smem_input_offset % sizeof(input_t);
208+
}
209+
input_t *smem_input = reinterpret_cast<input_t *>(smem_dyn_align16 + smem_input_offset);
210+
211+
const uint32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
212+
const uint32_t nthreads = blockDim.x * blockDim.y;
213+
const uint32_t chan = threadIdx.x;
214+
const uint32_t ty = threadIdx.y;
215+
const uint32_t by = blockDim.y;
216+
217+
for (uint32_t t = tid; t < filter_full_len; t += nthreads) {
218+
smem_h[t] = filter.operator()(t);
219+
}
220+
221+
for (uint32_t t = filter_full_len+tid; t < filter_phase_len * num_channels; t += nthreads) {
222+
smem_h[t] = static_cast<filter_t>(0);
223+
}
224+
225+
// The input stored in shared memory is logically [smem_input_height, num_channels] where
226+
// smem_input_height is the number of samples at the output sample rate stored in smem.
227+
const uint32_t smem_input_height = filter_phase_len + by - 1;
228+
229+
const index_t start_elem = blockIdx.x * elems_per_channel_per_cta;
230+
const index_t last_elem = std::min(output_len_per_channel-1, (blockIdx.x+1) * elems_per_channel_per_cta - 1);
231+
auto indims = BlockToIdx(input, blockIdx.z, 1);
232+
auto outdims = BlockToIdx(output, blockIdx.z, 2);
233+
outdims[ChannelRank] = chan;
234+
235+
for (uint32_t t = ty; t < filter_phase_len-1; t += by) {
236+
const index_t out_sample_ind = start_elem - (filter_phase_len-1) + t;
237+
const uint32_t smem_ind = t * num_channels + chan;
238+
const index_t input_ind = out_sample_ind * num_channels + chan;
239+
if (input_ind >= 0 && input_ind < input_len) {
240+
indims[InRank-1] = input_ind;
241+
detail::mapply([smem_input, smem_ind, &input](auto &&...args) {
242+
smem_input[smem_ind] = input.operator()(args...);
243+
}, indims);
244+
} else {
245+
smem_input[smem_ind] = static_cast<filter_t>(0);
246+
}
247+
}
248+
249+
index_t next_start_elem = start_elem;
250+
const index_t num_elem_iters = (last_elem - start_elem + 1 + by - 1) / by;
251+
252+
uint32_t cached_input_ind_tail = filter_phase_len - 1 + ty;
253+
const filter_t *h_start = smem_h + num_channels * filter_phase_len - (num_channels - chan);
254+
for (index_t iter = 0; iter < num_elem_iters; iter++) {
255+
256+
__syncthreads();
257+
258+
// Load next elems_per_channel_per_cta elements for each channel
259+
const index_t next_last_elem = std::min(next_start_elem + by - 1, last_elem);
260+
const uint32_t out_samples_this_iter = static_cast<uint32_t>(next_last_elem - next_start_elem + 1);
261+
if (ty < out_samples_this_iter) {
262+
indims[InRank-1] = (next_start_elem + ty) * num_channels + chan;
263+
const uint32_t smem_ind = cached_input_ind_tail * num_channels + chan;
264+
if (indims[InRank-1] < input_len) {
265+
detail::mapply([smem_input, smem_ind, &input](auto &&...args) {
266+
smem_input[smem_ind] = input.operator()(args...);
267+
}, indims);
268+
} else {
269+
smem_input[smem_ind] = static_cast<filter_t>(0);
270+
}
271+
}
272+
273+
cached_input_ind_tail += by;
274+
// The below effectively mods cached_input_ind_tail by smem_input_height. Since
275+
// smem_input_height is >= by, adding by means that we will need to subtract
276+
// smem_input_height at most once for cached_input_ind_tail to be in the range
277+
// [0, smem_input_height-1]. The conditional is cheaper than the mod, unless
278+
// smem_input_height is known at compile time.
279+
if (cached_input_ind_tail >= smem_input_height) {
280+
cached_input_ind_tail -= smem_input_height;
281+
}
282+
283+
__syncthreads();
284+
285+
outdims[OutElemRank] = next_start_elem + ty;
286+
if (outdims[OutElemRank] <= last_elem) {
287+
const filter_t *h = h_start;
288+
output_t accum { 0 };
289+
const int first_end = std::min(cached_input_ind_tail + filter_phase_len - 1, smem_input_height - 1);
290+
// The footprint of samples involved in the convolution may wrap from the end
291+
// to the beginning of smem_input. The prologue below handles the samples from
292+
// the current tail to the end of smem_input and the epilogue starts back at the
293+
// beginning of smem_input.
294+
const int prologue_count = (first_end - cached_input_ind_tail + 1);
295+
const int epilogue_count = (prologue_count < filter_phase_len) ? filter_phase_len - prologue_count : 0;
296+
const input_t *sample = smem_input + cached_input_ind_tail * num_channels + (num_channels - 1 - chan);
297+
// Apply the filter h in reverse order below to flip the filter for convolution
298+
for (int k = 0; k < prologue_count; k++) {
299+
accum += (*h) * (*sample);
300+
sample += num_channels;
301+
h -= num_channels;
302+
}
303+
sample = smem_input + (num_channels - 1 - chan);
304+
for (int k = 0; k < epilogue_count; k++) {
305+
accum += (*h) * (*sample);
306+
sample += num_channels;
307+
h -= num_channels;
308+
}
309+
310+
detail::mapply([accum, &output](auto &&...args) {
311+
output.operator()(args...) = accum;
312+
}, outdims);
313+
}
314+
315+
next_start_elem += out_samples_this_iter;
316+
}
317+
}
318+
177319
template <int THREADS, int NUM_CHAN, typename OutType, typename InType, typename FilterType>
178320
__launch_bounds__(THREADS)
179321
__global__ void ChannelizePoly1D_FusedChan(OutType output, InType input, FilterType filter)

include/matx/transforms/channelize_poly.h

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ namespace detail {
5252
// channel counts.
5353
constexpr index_t MATX_CHANNELIZE_POLY1D_FUSED_CHAN_KERNEL_THRESHOLD = 6;
5454

55+
// Number of output samples per channel per iteration for the kernel that stores
56+
// the input data in shared memory. Ideally, this value would be determined dynamically
57+
// to balance occupancy and CTA size. For now, we choose a reasonable default.
58+
constexpr index_t MATX_CHANNELIZE_POLY1D_FULL_SMEM_KERNEL_NOUT_PER_ITER = 4;
59+
5560
template <typename OutType, typename InType, typename FilterType>
5661
inline void matxChannelizePoly1DInternal(OutType o, const InType &i,
5762
const FilterType &filter, cudaStream_t stream)
@@ -78,6 +83,70 @@ inline void matxChannelizePoly1DInternal(OutType o, const InType &i,
7883
#endif
7984
}
8085

86+
template <typename OutType, typename InType, typename FilterType>
87+
inline size_t matxChannelizePoly1DInternal_SmemSizeBytes(const OutType &o, const InType &, const FilterType &filter)
88+
{
89+
using input_t = typename InType::scalar_type;
90+
using filter_t = typename FilterType::scalar_type;
91+
92+
index_t filter_len = filter.Size(FilterType::Rank()-1);
93+
94+
const index_t num_channels = o.Size(OutType::Rank()-1);
95+
const index_t nout_per_channel = o.Size(OutType::Rank()-2);
96+
const index_t filter_phase_len = (filter_len + num_channels - 1) / num_channels;
97+
98+
size_t smem_size = sizeof(filter_t)*(num_channels)*(filter_phase_len) +
99+
sizeof(input_t)*(num_channels)*(filter_phase_len + MATX_CHANNELIZE_POLY1D_FULL_SMEM_KERNEL_NOUT_PER_ITER - 1);
100+
const size_t max_sizeof = std::max(sizeof(filter_t), sizeof(input_t));
101+
if (smem_size % max_sizeof) {
102+
smem_size += max_sizeof - (smem_size % max_sizeof);
103+
}
104+
return smem_size;
105+
}
106+
107+
template <typename OutType, typename InType, typename FilterType>
108+
inline size_t matxChannelizePoly1DInternal_ShouldUseSmemKernel(const OutType &out, const InType &in, const FilterType &filter)
109+
{
110+
// 48 KB is the largest shared memory allocation that does not require
111+
// explicit opt-in via cudaFuncSetAttribute()
112+
const size_t MAX_SMEM_BYTES = 48 * 1024;
113+
// The full shared memory kernel uses blocks of size
114+
// (num_channels, detail::MATX_CHANNELIZE_POLY1D_FULL_SMEM_KERNEL_NOUT_PER_ITER), so ensure
115+
// that the resulting thread per block count will not exceed MAX_NUM_THREADS_PER_BLOCK
116+
const int MAX_NUM_THREADS_PER_BLOCK = 1024;
117+
const index_t num_channels = out.Size(OutType::Rank()-1);
118+
return (
119+
matxChannelizePoly1DInternal_SmemSizeBytes(out, in, filter) <= MAX_SMEM_BYTES &&
120+
num_channels <= (MAX_NUM_THREADS_PER_BLOCK/detail::MATX_CHANNELIZE_POLY1D_FULL_SMEM_KERNEL_NOUT_PER_ITER));
121+
}
122+
123+
template <typename OutType, typename InType, typename FilterType>
124+
inline void matxChannelizePoly1DInternal_Smem(OutType o, const InType &i, const FilterType &filter, cudaStream_t stream)
125+
{
126+
#ifdef __CUDACC__
127+
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
128+
129+
using input_t = typename InType::scalar_type;
130+
using filter_t = typename FilterType::scalar_type;
131+
132+
index_t filter_len = filter.Size(FilterType::Rank()-1);
133+
134+
const index_t num_channels = o.Size(OutType::Rank()-1);
135+
const index_t nout_per_channel = o.Size(OutType::Rank()-2);
136+
const int num_batches = static_cast<int>(TotalSize(i)/i.Size(i.Rank() - 1));
137+
138+
const int target_num_blocks = 1024;
139+
const int elem_per_block = static_cast<int>(
140+
(nout_per_channel + target_num_blocks - 1) / target_num_blocks);
141+
dim3 block(static_cast<int>(num_channels), MATX_CHANNELIZE_POLY1D_FULL_SMEM_KERNEL_NOUT_PER_ITER);
142+
const uint32_t num_blocks = static_cast<uint32_t>((nout_per_channel + elem_per_block - 1) / elem_per_block);
143+
dim3 grid(num_blocks, 1, num_batches);
144+
const size_t smem_size = matxChannelizePoly1DInternal_SmemSizeBytes(o, i, filter);
145+
ChannelizePoly1D_Smem<OutType, InType, FilterType><<<grid, block, smem_size, stream>>>(
146+
o, i, filter, elem_per_block);
147+
#endif
148+
}
149+
81150
template <typename OutType, typename InType, typename FilterType>
82151
inline void matxChannelizePoly1DInternal_FusedChan(OutType o, const InType &i,
83152
const FilterType &filter, cudaStream_t stream)
@@ -237,7 +306,11 @@ inline void channelize_poly_impl(OutType out, const InType &in, const FilterType
237306
}
238307
}();
239308

240-
matxChannelizePoly1DInternal(fft_in_slice, in, f, stream);
309+
if (matxChannelizePoly1DInternal_ShouldUseSmemKernel(out, in, f)) {
310+
matxChannelizePoly1DInternal_Smem(fft_in_slice, in, f, stream);
311+
} else {
312+
matxChannelizePoly1DInternal(fft_in_slice, in, f, stream);
313+
}
241314
stop_dims[OUT_RANK-1] = (num_channels/2) + 1;
242315
auto out_packed = slice<OUT_RANK>(out, start_dims, stop_dims);
243316
(out_packed = fft(fft_in_slice, num_channels)).run(stream);
@@ -247,7 +320,11 @@ inline void channelize_poly_impl(OutType out, const InType &in, const FilterType
247320
if (num_channels <= detail::MATX_CHANNELIZE_POLY1D_FUSED_CHAN_KERNEL_THRESHOLD) {
248321
matxChannelizePoly1DInternal_FusedChan(out, in, f, stream);
249322
} else {
250-
matxChannelizePoly1DInternal(out, in, f, stream);
323+
if (matxChannelizePoly1DInternal_ShouldUseSmemKernel(out, in, f)) {
324+
matxChannelizePoly1DInternal_Smem(out, in, f, stream);
325+
} else {
326+
matxChannelizePoly1DInternal(out, in, f, stream);
327+
}
251328
// Specify FORWARD here to prevent any normalization after the ifft. We do not
252329
// want any extra scaling on the output values.
253330
(out = ifft(out, num_channels, FFTNorm::FORWARD)).run(stream);

test/00_transform/ChannelizePoly.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ TYPED_TEST(ChannelizePolyTestNonHalfFloatTypes, Simple)
145145
{ 271374, 31*14+4, 14 },
146146
{ 27137, 301*13+3, 13 },
147147
{ 27138, 301*14+4, 14 },
148+
{ 1000000, 32*16, 32 },
149+
{ 1000000, 40*16, 40 }
148150
};
149151

150152
cudaStream_t stream = 0;

0 commit comments

Comments
 (0)