@@ -174,6 +174,148 @@ __global__ void ChannelizePoly1D(OutType output, InType input, FilterType filter
  }
}

// This kernel works in cases where the full filter (with potentially some zero padding) and
// the inputs required to compute elems_per_channel_per_cta outputs all fit into shared memory.
template <typename OutType, typename InType, typename FilterType>
__global__ void ChannelizePoly1D_Smem(OutType output, InType input, FilterType filter, index_t elems_per_channel_per_cta)
{
  using output_t = typename OutType::scalar_type;
  using input_t = typename InType::scalar_type;
  using filter_t = typename FilterType::scalar_type;

  extern __shared__ uint8_t __attribute__((aligned(16))) smem_dyn_align16[];

  constexpr int InRank = InType::Rank();
  constexpr int OutRank = OutType::Rank();
  constexpr int ChannelRank = OutRank - 1;
  constexpr int OutElemRank = OutRank - 2;

  const index_t input_len = input.Size(InRank-1);
  const index_t output_len_per_channel = output.Size(OutElemRank);
  // If the filter fits into shared memory, then a 32-bit index is sufficient. The one
  // edge case is num_channels > 2^32-1 paired with a small filter that is implicitly
  // zero-padded. We assume that the kernel selection logic considers the size of the
  // zero-padded filter, since that is what we actually store in shared memory.
  const uint32_t num_channels = static_cast<uint32_t>(output.Size(ChannelRank));
  const uint32_t filter_full_len = static_cast<uint32_t>(filter.Size(0));
  const uint32_t filter_phase_len = static_cast<uint32_t>((filter_full_len + num_channels - 1) / num_channels);

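  // Carve the dynamic shared memory into two regions: the zero-padded filter first,
  // then the input samples. Rounding the offset up to the next multiple of
  // sizeof(input_t) keeps smem_input suitably aligned for input_t accesses, since
  // smem_dyn_align16 itself is 16-byte aligned.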
  filter_t *smem_h = reinterpret_cast<filter_t *>(smem_dyn_align16);
  size_t smem_input_offset = sizeof(filter_t) * filter_phase_len * num_channels;
  if (smem_input_offset % sizeof(input_t)) {
    smem_input_offset += sizeof(input_t) - smem_input_offset % sizeof(input_t);
  }
  input_t *smem_input = reinterpret_cast<input_t *>(smem_dyn_align16 + smem_input_offset);

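  // Thread mapping (assumed from this kernel's usage, not checked here): blockDim.x
  // equals num_channels, so threadIdx.x selects the output channel while threadIdx.y
  // strides across this CTA's output samples.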
  const uint32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
  const uint32_t nthreads = blockDim.x * blockDim.y;
  const uint32_t chan = threadIdx.x;
  const uint32_t ty = threadIdx.y;
  const uint32_t by = blockDim.y;

  for (uint32_t t = tid; t < filter_full_len; t += nthreads) {
    smem_h[t] = filter.operator()(t);
  }

  for (uint32_t t = filter_full_len + tid; t < filter_phase_len * num_channels; t += nthreads) {
    smem_h[t] = static_cast<filter_t>(0);
  }

  // The input stored in shared memory is logically [smem_input_height, num_channels] where
  // smem_input_height is the number of samples at the output sample rate stored in smem.
  const uint32_t smem_input_height = filter_phase_len + by - 1;
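  // smem_input is treated as a ring buffer over these rows; cached_input_ind_tail
  // below tracks the row holding the oldest resident sample.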

  const index_t start_elem = blockIdx.x * elems_per_channel_per_cta;
  const index_t last_elem = std::min(output_len_per_channel - 1, (blockIdx.x + 1) * elems_per_channel_per_cta - 1);
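  // Map blockIdx.z onto the batch (non-time, non-channel) dimensions of the input
  // and output tensors; the innermost one (input) or two (output) indices are filled
  // in per element below.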
  auto indims = BlockToIdx(input, blockIdx.z, 1);
  auto outdims = BlockToIdx(output, blockIdx.z, 2);
  outdims[ChannelRank] = chan;

  for (uint32_t t = ty; t < filter_phase_len - 1; t += by) {
    const index_t out_sample_ind = start_elem - (filter_phase_len - 1) + t;
    const uint32_t smem_ind = t * num_channels + chan;
    const index_t input_ind = out_sample_ind * num_channels + chan;
    if (input_ind >= 0 && input_ind < input_len) {
      indims[InRank-1] = input_ind;
      detail::mapply([smem_input, smem_ind, &input](auto &&...args) {
        smem_input[smem_ind] = input.operator()(args...);
      }, indims);
    } else {
      smem_input[smem_ind] = static_cast<input_t>(0);
    }
  }

  index_t next_start_elem = start_elem;
  const index_t num_elem_iters = (last_elem - start_elem + 1 + by - 1) / by;

  uint32_t cached_input_ind_tail = filter_phase_len - 1 + ty;
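  // h_start points at element (filter_phase_len-1)*num_channels + chan, i.e. the last
  // coefficient in this channel's column of the zero-padded filter. The inner loops
  // below step it backward by num_channels to walk that column in reverse.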
  const filter_t *h_start = smem_h + num_channels * filter_phase_len - (num_channels - chan);
  for (index_t iter = 0; iter < num_elem_iters; iter++) {

    __syncthreads();

    // Load the next batch of up to blockDim.y output samples' worth of input for each channel
    const index_t next_last_elem = std::min(next_start_elem + by - 1, last_elem);
    const uint32_t out_samples_this_iter = static_cast<uint32_t>(next_last_elem - next_start_elem + 1);
    if (ty < out_samples_this_iter) {
      indims[InRank-1] = (next_start_elem + ty) * num_channels + chan;
      const uint32_t smem_ind = cached_input_ind_tail * num_channels + chan;
      if (indims[InRank-1] < input_len) {
        detail::mapply([smem_input, smem_ind, &input](auto &&...args) {
          smem_input[smem_ind] = input.operator()(args...);
        }, indims);
      } else {
        smem_input[smem_ind] = static_cast<input_t>(0);
      }
    }

    cached_input_ind_tail += by;
    // The below effectively mods cached_input_ind_tail by smem_input_height. Since
    // smem_input_height is >= by, adding by means that we will need to subtract
    // smem_input_height at most once for cached_input_ind_tail to be in the range
    // [0, smem_input_height-1]. The conditional is cheaper than the mod, unless
    // smem_input_height is known at compile time.
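    // For example, with by = 4 and smem_input_height = 6, a tail of 5 advances to 9,
    // which a single subtraction wraps back to 3; a second subtraction is never
    // needed because 9 < 2 * smem_input_height.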
    if (cached_input_ind_tail >= smem_input_height) {
      cached_input_ind_tail -= smem_input_height;
    }

    __syncthreads();

    outdims[OutElemRank] = next_start_elem + ty;
    if (outdims[OutElemRank] <= last_elem) {
      const filter_t *h = h_start;
      output_t accum { 0 };
      const int first_end = std::min(cached_input_ind_tail + filter_phase_len - 1, smem_input_height - 1);
      // The footprint of samples involved in the convolution may wrap from the end
      // to the beginning of smem_input. The prologue below handles the samples from
      // the current tail to the end of smem_input and the epilogue starts back at the
      // beginning of smem_input.
      const int prologue_count = (first_end - cached_input_ind_tail + 1);
      const int epilogue_count = (prologue_count < filter_phase_len) ? filter_phase_len - prologue_count : 0;
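      // For example, with filter_phase_len = 4, smem_input_height = 7, and a tail of 5:
      // first_end = min(5 + 3, 6) = 6, so the prologue covers rows 5..6 (2 taps) and
      // the epilogue covers rows 0..1 (the remaining 2 taps).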
      const input_t *sample = smem_input + cached_input_ind_tail * num_channels + (num_channels - 1 - chan);
      // Apply the filter h in reverse order below to flip the filter for convolution
      for (int k = 0; k < prologue_count; k++) {
        accum += (*h) * (*sample);
        sample += num_channels;
        h -= num_channels;
      }
      sample = smem_input + (num_channels - 1 - chan);
      for (int k = 0; k < epilogue_count; k++) {
        accum += (*h) * (*sample);
        sample += num_channels;
        h -= num_channels;
      }

      detail::mapply([accum, &output](auto &&...args) {
        output.operator()(args...) = accum;
      }, outdims);
    }

    next_start_elem += out_samples_this_iter;
  }
}
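// The sketch below is illustrative and not part of this change: one way a host-side
// selection heuristic could size the dynamic shared memory for ChannelizePoly1D_Smem,
// matching the layout above (zero-padded filter, alignment slack, then the input ring
// buffer of filter_phase_len + blockDim.y - 1 rows). The function name and the
// block_dim_y parameter are hypothetical.
template <typename filter_t, typename input_t>
inline size_t ChannelizePoly1DSmemBytes(index_t filter_full_len, index_t num_channels, index_t block_dim_y)
{
  const index_t filter_phase_len = (filter_full_len + num_channels - 1) / num_channels;
  // Zero-padded filter region, as staged by the kernel
  size_t bytes = sizeof(filter_t) * filter_phase_len * num_channels;
  // Worst-case padding so that smem_input starts on an input_t boundary
  bytes += sizeof(input_t) - 1;
  // Input ring buffer: (filter_phase_len + block_dim_y - 1) rows of num_channels samples
  bytes += sizeof(input_t) * (filter_phase_len + block_dim_y - 1) * num_channels;
  return bytes;
}
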
template <int THREADS, int NUM_CHAN, typename OutType, typename InType, typename FilterType>
__launch_bounds__(THREADS)
__global__ void ChannelizePoly1D_FusedChan(OutType output, InType input, FilterType filter)