1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -8,6 +8,7 @@ set(examples
fft_conv
resample
mvdr_beamformer
resample_poly_bench
spectrogram
spectrogram_graph
spherical_harmonics
4 changes: 2 additions & 2 deletions examples/channelize_poly_bench.cu
@@ -94,11 +94,11 @@ void ChannelizePolyBench(matx::index_t channel_start, matx::index_t channel_stop
cudaStreamSynchronize(stream);

float elapsed_ms = 0.0f;
cudaEventRecord(start);
cudaEventRecord(start, stream);
for (int k = 0; k < NUM_ITERATIONS; k++) {
(output = channelize_poly(input, filter, num_channels, decimation_factor)).run(stream);
}
cudaEventRecord(stop);
cudaEventRecord(stop, stream);
cudaStreamSynchronize(stream);
CUDA_CHECK_LAST_ERROR();
cudaEventElapsedTime(&elapsed_ms, start, stop);
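Note on the change above: passing the benchmark stream to cudaEventRecord makes the events bracket the work actually enqueued on that stream. Without a stream argument the events are recorded on the legacy default stream, so the measured interval is not guaranteed to cover the channelizer launches. A minimal sketch of the stream-scoped timing pattern (standard CUDA API, shown only to illustrate the fix):

    cudaEventRecord(start, stream);          // start event enqueued on the work stream
    for (int k = 0; k < NUM_ITERATIONS; k++) {
      // ... enqueue the operation being timed on `stream` ...
    }
    cudaEventRecord(stop, stream);           // stop event follows the timed work on the same stream
    cudaStreamSynchronize(stream);           // wait until the stop event has completed
    float elapsed_ms = 0.0f;
    cudaEventElapsedTime(&elapsed_ms, start, stop);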
148 changes: 148 additions & 0 deletions examples/resample_poly_bench.cu
@@ -0,0 +1,148 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2021, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#include "matx.h"
#include <cassert>
#include <cstdio>
#include <cmath>
#include <memory>
#include <fstream>
#include <istream>
#include <cuda/std/complex>

using namespace matx;

// This example is used primarily for development purposes to benchmark the performance of the
// polyphase resampler kernel(s). Typically, the parameters below (batch size, filter
// length, input signal length, up/down factors) will be adjusted to a range of interest
// and the benchmark will be run with and without the proposed kernel changes.

constexpr int NUM_WARMUP_ITERATIONS = 2;

// Number of iterations per timed test. Iteration times are averaged in the report.
constexpr int NUM_ITERATIONS = 20;

template <typename InType>
void ResamplePolyBench()
{
struct {
matx::index_t num_batches;
matx::index_t input_len;
matx::index_t up;
matx::index_t down;
} test_cases[] = {
{ 1, 256, 384, 3125 },
{ 1, 256, 4, 5 },
{ 1, 256, 1, 4 },
{ 1, 256, 1, 16 },
{ 1, 3000, 384, 3125 },
{ 1, 3000, 4, 5 },
{ 1, 3000, 1, 4 },
{ 1, 3000, 1, 16 },
{ 1, 31000, 384, 3125 },
{ 1, 31000, 4, 5 },
{ 1, 31000, 1, 4 },
{ 1, 31000, 1, 16 },
{ 1, 256000, 384, 3125 },
{ 1, 256000, 4, 5 },
{ 1, 256000, 1, 4 },
{ 1, 256000, 1, 16 },
{ 42, 256000, 384, 3125 },
{ 42, 256000, 4, 5 },
{ 42, 256000, 1, 4 },
{ 42, 256000, 1, 16 },
{ 128, 256000, 384, 3125 },
{ 128, 256000, 4, 5 },
{ 128, 256000, 1, 4 },
{ 128, 256000, 1, 16 },
};

cudaStream_t stream;
cudaStreamCreate(&stream);
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

for (size_t i = 0; i < sizeof(test_cases)/sizeof(test_cases[0]); i++) {
const matx::index_t num_batches = test_cases[i].num_batches;
const matx::index_t input_len = test_cases[i].input_len;
const matx::index_t up = test_cases[i].up;
const matx::index_t down = test_cases[i].down;
const matx::index_t half_len = 10 * std::max(up, down);
const matx::index_t filter_len = 2 * half_len + 1;
const matx::index_t filter_len_per_phase = (filter_len + up - 1) / up;

const index_t up_len = input_len * up;
const index_t output_len = up_len / down + ((up_len % down) ? 1 : 0);

auto input = matx::make_tensor<InType, 2>({num_batches, input_len});
auto filter = matx::make_tensor<InType, 1>({filter_len});
auto output = matx::make_tensor<InType, 2>({num_batches, output_len});

for (int k = 0; k < NUM_WARMUP_ITERATIONS; k++) {
(output = matx::resample_poly(input, filter, up, down)).run(stream);
}

cudaStreamSynchronize(stream);

float elapsed_ms = 0.0f;
cudaEventRecord(start, stream);
for (int k = 0; k < NUM_ITERATIONS; k++) {
(output = matx::resample_poly(input, filter, up, down)).run(stream);
}
cudaEventRecord(stop, stream);
cudaStreamSynchronize(stream);
CUDA_CHECK_LAST_ERROR();
cudaEventElapsedTime(&elapsed_ms, start, stop);

const double gflops = static_cast<double>(num_batches*(2*filter_len_per_phase-1)*output_len) / 1.0e9;
const double avg_elapsed_us = (static_cast<double>(elapsed_ms)/NUM_ITERATIONS)*1.0e3;
printf("Batches: %5lld FilterLen: %5lld InputLen: %7lld OutputLen: %7lld Up/Down: %4lld/%4lld Elapsed Usecs: %12.1f GFLOPS: %10.3f\n",
num_batches, filter_len, input_len, output_len, up, down, avg_elapsed_us, gflops/(avg_elapsed_us/1.0e6));
}

CUDA_CHECK_LAST_ERROR();

cudaEventDestroy(start);
cudaEventDestroy(stop);
cudaStreamDestroy(stream);
}

int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
{
MATX_ENTER_HANDLER();

printf("Benchmarking float\n");
ResamplePolyBench<float>();

MATX_EXIT_HANDLER();
}
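The GFLOPS figure printed by this benchmark follows directly from the formula in the code: each of the output_len samples in each batch costs at most filter_len_per_phase multiplies and filter_len_per_phase - 1 adds, i.e. 2*filter_len_per_phase - 1 floating-point operations. As a worked example for the { 1, 256, 4, 5 } test case above:

    half_len             = 10 * max(4, 5)        = 50
    filter_len           = 2 * 50 + 1            = 101
    filter_len_per_phase = ceil(101 / 4)         = 26
    up_len               = 256 * 4               = 1024
    output_len           = ceil(1024 / 5)        = 205
    flops per iteration  = 1 * (2*26 - 1) * 205  = 10,455  (about 1.05e-5 GFLOP)

and the reported GFLOPS is that figure divided by the average iteration time in seconds.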
21 changes: 16 additions & 5 deletions include/matx/kernels/resample_poly.cuh
@@ -51,7 +51,7 @@ namespace matx {
template <int THREADS, typename OutType, typename InType, typename FilterType>
__launch_bounds__(THREADS)
__global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
index_t up, index_t down)
index_t up, index_t down, index_t elems_per_thread)
{
using output_t = typename OutType::scalar_type;
using input_t = typename InType::scalar_type;
@@ -73,6 +73,7 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
}

const int phase_ind = blockIdx.y;
const int elem_block = blockIdx.z;
const int tid = threadIdx.x;
const index_t filter_len_half = filter_len/2;

@@ -130,7 +131,9 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
if (last_filter_ind < 0) {
for (index_t out_ind = phase_ind + tid * up; out_ind < output_len; out_ind += THREADS * up) {
bdims[Rank - 1] = out_ind;
output.operator()(bdims) = 0;
detail::mapply([&output](auto &&...args) {
output.operator()(args...) = 0;
}, bdims);
}
return;
}
@@ -170,7 +173,9 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
const index_t max_h_epilogue = this_phase_len - left_h_ind - 1;
const index_t max_input_ind = static_cast<int>(input_len) - 1;

for (index_t out_ind = phase_ind + tid * up; out_ind < output_len; out_ind += THREADS * up) {
const index_t start_ind = phase_ind + up * (tid + elem_block * elems_per_thread * THREADS);
const index_t last_ind = std::min(output_len - 1, start_ind + elems_per_thread * THREADS * up);
for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS * up) {
// out_ind is the index in the output array and up_ind is the corresponding
// index in the upsampled array
const index_t up_ind = out_ind * down;
@@ -196,13 +201,19 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
index_t x_ind = input_ind - prologue;
index_t h_ind = left_h_ind - prologue;
output_t accum {};
input_t in_val;
for (index_t j = 0; j < n; j++) {
bdims[Rank - 1] = x_ind++;
accum += s_filter[h_ind++] * input.operator()(bdims);
detail::mapply([&in_val, &input](auto &&...args) {
in_val = input.operator()(args...);
}, bdims);
accum += s_filter[h_ind++] * in_val;
}

bdims[Rank - 1] = out_ind;
output.operator()(bdims) = accum;
detail::mapply([&accum, &output](auto &&...args) {
output.operator()(args...) = accum;
}, bdims);
}
}

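For reference, the index math in the kernel above follows the standard upfirdn identity: each output index out_ind corresponds to up_ind = out_ind * down in the conceptually zero-stuffed (upsampled) signal, and only filter taps k with (up_ind - k) % up == 0 land on real input samples. A minimal host-side sketch of that mapping is below; it is illustrative only (a plain causal upfirdn with zero padding outside the input, omitting the filter centering the kernel applies via filter_len_half), and resample_poly_ref is a hypothetical name, not MatX API:

    #include <cstdint>
    #include <vector>

    template <typename T>
    std::vector<T> resample_poly_ref(const std::vector<T> &x, const std::vector<T> &h,
                                     int64_t up, int64_t down) {
      const int64_t up_len  = static_cast<int64_t>(x.size()) * up;
      const int64_t out_len = up_len / down + ((up_len % down) ? 1 : 0);
      std::vector<T> y(out_len, T{});
      for (int64_t m = 0; m < out_len; m++) {
        const int64_t up_ind = m * down;                 // position in the zero-stuffed signal
        // Walk only the taps that hit non-zero (real) input samples.
        for (int64_t k = up_ind % up; k < static_cast<int64_t>(h.size()); k += up) {
          const int64_t n = (up_ind - k) / up;           // input sample index hit by tap k
          if (n >= 0 && n < static_cast<int64_t>(x.size())) {
            y[m] += h[k] * x[n];
          }
        }
      }
      return y;
    }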
25 changes: 18 additions & 7 deletions include/matx/transforms/resample_poly.h
@@ -64,15 +64,27 @@ inline void matxResamplePoly1DInternal(OutType &o, const InType &i,
static_cast<int>((filter_len + 1 + up - 1) / up) :
static_cast<int>((filter_len + up - 1) / up);
const size_t filter_shm = sizeof(filter_t) * max_phase_len;

const index_t output_len = o.Size(OutType::Rank()-1);
const index_t max_output_len_per_phase = (output_len + up - 1) / up;
const int num_phases = static_cast<int>(up);
const int num_batches = static_cast<int>(TotalSize(i)/i.Size(i.Rank() - 1));
dim3 grid(num_batches, num_phases);

constexpr int THREADS = 128;
constexpr index_t DESIRED_MIN_GRID_SIZE = 512;
// If we do not have enough batches and phases to create a large grid, then
// we try to reduce the number of output elements generated per thread to
// yield a large enough grid to saturate the GPU. However, because each block
// reloads its phase's filter taps into shared memory, we reduce the number of
// elements processed per thread no further than is needed to saturate the GPU.
if (num_batches * num_phases < DESIRED_MIN_GRID_SIZE) {
const index_t desired_elem_blocks = (DESIRED_MIN_GRID_SIZE + num_batches * num_phases - 1) /
(num_batches * num_phases);
const index_t max_output_len_per_thread = (max_output_len_per_phase + THREADS - 1) / THREADS;
grid.z = static_cast<uint32_t>(std::min(desired_elem_blocks, max_output_len_per_thread));
}
const index_t elems_per_thread = (max_output_len_per_phase + THREADS * grid.z - 1) / (THREADS * grid.z);
ResamplePoly1D<THREADS, OutType, InType, FilterType><<<grid, THREADS, filter_shm, stream>>>(
o, i, filter, up, down);

o, i, filter, up, down, elems_per_thread);
#endif
}
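To make the heuristic above concrete, here is a worked trace under assumed sizes (not taken from the PR): one batch, up = 4, output_len = 64000. Following the code above:

    num_batches * num_phases  = 1 * 4 = 4              (< DESIRED_MIN_GRID_SIZE = 512)
    max_output_len_per_phase  = ceil(64000 / 4)        = 16000
    desired_elem_blocks       = ceil(512 / 4)           = 128
    max_output_len_per_thread = ceil(16000 / 128)       = 125
    grid.z                    = min(128, 125)           = 125
    elems_per_thread          = ceil(16000 / (128*125)) = 1

so the kernel launches a (1, 4, 125) grid of 128-thread blocks, each thread producing one output element per pass, rather than the 4-block grid the previous code would have used.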

@@ -125,9 +137,8 @@ inline void resample_poly_impl(OutType &out, const InType &in, const FilterType
MATX_ASSERT_STR(out.Size(i) == in.Size(i), matxInvalidDim, "resample_poly: input/output must have matched batch sizes");
}

const index_t up_size = in.Size(RANK-1) * up;
const index_t outlen = up_size / down + ((up_size % down) ? 1 : 0);

[[maybe_unused]] const index_t up_size = in.Size(RANK-1) * up;
[[maybe_unused]] const index_t outlen = up_size / down + ((up_size % down) ? 1 : 0);
MATX_ASSERT_STR(out.Size(RANK-1) == outlen, matxInvalidDim, "resample_poly: output size mismatch");

const index_t g = gcd(up, down);
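The [[maybe_unused]] qualifiers above are presumably needed because up_size and outlen are only referenced by MATX_ASSERT_STR, so in builds where that assertion expands to nothing they would otherwise trigger unused-variable warnings. A generic illustration of the pattern, using a hypothetical macro rather than MatX source:

    #include <cassert>

    // Hypothetical assert macro that, like many, expands to a no-op when NDEBUG is defined.
    #ifdef NDEBUG
      #define MY_ASSERT(cond, msg) ((void)0)
    #else
      #define MY_ASSERT(cond, msg) assert((cond) && (msg))
    #endif

    void check_output_len(long long out_size, long long in_size, long long up, long long down) {
      // Without [[maybe_unused]], a release (NDEBUG) build would warn that outlen
      // is set but never used, since MY_ASSERT expands to nothing there.
      [[maybe_unused]] const long long up_size = in_size * up;
      [[maybe_unused]] const long long outlen  = up_size / down + ((up_size % down) ? 1 : 0);
      MY_ASSERT(out_size == outlen, "output size mismatch");
    }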
4 changes: 2 additions & 2 deletions test/00_operators/ReductionTests.cu
@@ -623,8 +623,8 @@ TYPED_TEST(ReductionTestsFloatNonComplexNonHalfAllExecs, AllClose)
auto B = make_tensor<TestType>({5, 5, 5});
auto C = make_tensor<int>({});

(A = ones<TestType>(A.Shape())).run();
(B = ones<TestType>(B.Shape())).run();
(A = ones<TestType>(A.Shape())).run(exec);
(B = ones<TestType>(B.Shape())).run(exec);
allclose(C, A, B, 1e-5, 1e-8, exec);
// example-end allclose-test-1
cudaStreamSynchronize(0);
46 changes: 46 additions & 0 deletions test/00_transform/ResamplePoly.cu
@@ -513,5 +513,51 @@ TYPED_TEST(ResamplePolyTestNonHalfFloatTypes, Upsample)
}
}

MATX_EXIT_HANDLER();
}

// Use non-trivial operators for input and filter tensors to ensure
// that the kernel supports such operators.
TYPED_TEST(ResamplePolyTestNonHalfFloatTypes, Operators)
{
MATX_ENTER_HANDLER();

struct {
index_t a_len;
index_t f_len;
index_t up;
index_t down;
} test_cases[] = {
{ 3500, 62501, 384, 3125 },
{ 3501, 62501, 384, 3125 },
{ 3500, 62500, 384, 3125 },
{ 3501, 62500, 384, 3125 },
};

for (size_t i = 0; i < sizeof(test_cases)/sizeof(test_cases[0]); i++) {
const index_t a_len = test_cases[i].a_len;
const index_t f_len = test_cases[i].f_len;
const index_t up = test_cases[i].up;
const index_t down = test_cases[i].down;
const index_t up_len = a_len * up;
[[maybe_unused]] const index_t b_len = up_len / down + ((up_len % down) ? 1 : 0);
this->pb->template InitAndRunTVGenerator<TypeParam>(
"00_transforms", "resample_poly_operators", "resample", {a_len, f_len, up, down});

auto a = make_tensor<TypeParam>({a_len});
auto f = make_tensor<TypeParam>({f_len});
auto b = make_tensor<TypeParam>({b_len});
this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(f, "filter_random");

cudaStreamSynchronize(0);

(b = resample_poly(shift<0>(shift<0>(a, 8), -8), shift<0>(shift<0>(f, 3), -3), up, down)).run();

cudaStreamSynchronize(0);

MATX_TEST_ASSERT_COMPARE(this->pb, b, "b_random", this->thresh);
}

MATX_EXIT_HANDLER();
}