Skip to content

Commit 039bd30

Browse files
committed
Fix to ConvCorr tests to skip host tests when host not enabled
1 parent d9b1399 commit 039bd30

File tree

2 files changed

+101
-32
lines changed

2 files changed

+101
-32
lines changed

include/matx/executors/support.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,26 @@ constexpr bool CheckDirect1DConvSupport() {
7777
}
7878
}
7979

80+
template <typename Exec, typename T>
81+
constexpr bool CheckFFT1DConvSupport() {
82+
if constexpr (is_host_executor_v<Exec>) {
83+
return CheckFFTSupport<Exec, T>();
84+
}
85+
else {
86+
return true;
87+
}
88+
}
89+
90+
template <typename Exec>
91+
constexpr bool Check2DConvSupport() {
92+
if constexpr (is_host_executor_v<Exec>) {
93+
return false;
94+
}
95+
else {
96+
return true;
97+
}
98+
}
99+
80100
template <typename Exec, typename T>
81101
constexpr bool CheckMatMulSupport() {
82102
if constexpr (is_host_executor_v<Exec>) {

test/00_transform/ConvCorr.cu

Lines changed: 81 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,15 @@ constexpr index_t a_len = 8 * 122880 + 2 * 32768;
6060
constexpr index_t b_len = 209;
6161
constexpr index_t c_len = a_len + b_len - 1;
6262

63-
template <typename T>
63+
template <typename T, matxConvCorrMethod_t METHOD = MATX_C_METHOD_DIRECT>
6464
class CorrelationConvolutionTest : public ::testing::Test {
6565
using GTestType = cuda::std::tuple_element_t<0, T>;
6666
using GExecType = cuda::std::tuple_element_t<1, T>;
6767
protected:
6868
void SetUp() override
6969
{
7070
CheckTestTypeSupport<GTestType>();
71+
CheckExecSupport();
7172
pb = std::make_unique<detail::MatXPybind>();
7273

7374
// Half precision needs a bit more tolerance when compared to
@@ -78,6 +79,19 @@ protected:
7879
}
7980

8081
void TearDown() override { pb.reset(); }
82+
83+
void CheckExecSupport() {
84+
if constexpr (METHOD == MATX_C_METHOD_FFT) {
85+
if constexpr (!detail::CheckFFT1DConvSupport<GExecType, GTestType>()) {
86+
GTEST_SKIP();
87+
}
88+
} else {
89+
if constexpr (!detail::CheckDirect1DConvSupport<GExecType>()) {
90+
GTEST_SKIP();
91+
}
92+
}
93+
}
94+
8195
GExecType exec{};
8296
std::unique_ptr<detail::MatXPybind> pb;
8397
tensor_t<GTestType, 1> av{{a_len0}};
@@ -100,6 +114,11 @@ protected:
100114
void SetUp() override
101115
{
102116
CheckTestTypeSupport<GTestType>();
117+
118+
if constexpr (!detail::Check2DConvSupport<GExecType>()) {
119+
GTEST_SKIP();
120+
}
121+
103122
pb = std::make_unique<detail::MatXPybind>();
104123

105124
// Half precision needs a bit more tolerance when compared to
@@ -123,7 +142,7 @@ protected:
123142
float thresh = 0.01f;
124143
};
125144

126-
template <typename T>
145+
template <typename T, matxConvCorrMethod_t METHOD = MATX_C_METHOD_DIRECT>
127146
class CorrelationConvolutionLargeTest : public ::testing::Test {
128147
protected:
129148
using GTestType = cuda::std::tuple_element_t<0, T>;
@@ -132,6 +151,7 @@ protected:
132151
void SetUp() override
133152
{
134153
CheckTestTypeSupport<GTestType>();
154+
CheckExecSupport();
135155
pb = std::make_unique<detail::MatXPybind>();
136156

137157
// Half precision needs a bit more tolerance when compared to
@@ -142,6 +162,19 @@ protected:
142162
}
143163

144164
void TearDown() override { pb.reset(); }
165+
166+
void CheckExecSupport() {
167+
if constexpr (METHOD == MATX_C_METHOD_FFT) {
168+
if constexpr (!detail::CheckFFT1DConvSupport<GExecType, GTestType>()) {
169+
GTEST_SKIP();
170+
}
171+
} else {
172+
if constexpr (!detail::CheckDirect1DConvSupport<GExecType>()) {
173+
GTEST_SKIP();
174+
}
175+
}
176+
}
177+
145178
GExecType exec{};
146179
std::unique_ptr<detail::MatXPybind> pb;
147180
tensor_t<GTestType, 1> av{{a_len}};
@@ -151,18 +184,33 @@ protected:
151184
};
152185

153186
template <typename TensorType>
154-
class CorrelationConvolutionTestFloatTypes
155-
: public CorrelationConvolutionTest<TensorType> {
187+
class CorrelationConvolutionFFTTestFloatTypes
188+
: public CorrelationConvolutionTest<TensorType, MATX_C_METHOD_FFT> {
156189
};
157190

158191
template <typename TensorType>
159-
class CorrelationConvolutionTestNonHalfFloatTypes
160-
: public CorrelationConvolutionTest<TensorType> {
192+
class CorrelationConvolutionDirectTestFloatTypes
193+
: public CorrelationConvolutionTest<TensorType, MATX_C_METHOD_DIRECT> {
194+
};
195+
196+
template <typename TensorType>
197+
class CorrelationConvolutionFFTTestNonHalfFloatTypes
198+
: public CorrelationConvolutionTest<TensorType, MATX_C_METHOD_FFT> {
199+
};
200+
201+
template <typename TensorType>
202+
class CorrelationConvolutionDirectTestNonHalfFloatTypes
203+
: public CorrelationConvolutionTest<TensorType, MATX_C_METHOD_DIRECT> {
204+
};
205+
206+
template <typename TensorType>
207+
class CorrelationConvolutionLargeFFTTestFloatTypes
208+
: public CorrelationConvolutionLargeTest<TensorType, MATX_C_METHOD_FFT> {
161209
};
162210

163211
template <typename TensorType>
164-
class CorrelationConvolutionLargeTestFloatTypes
165-
: public CorrelationConvolutionLargeTest<TensorType> {
212+
class CorrelationConvolutionLargeDirectTestFloatTypes
213+
: public CorrelationConvolutionLargeTest<TensorType, MATX_C_METHOD_DIRECT> {
166214
};
167215

168216
template <typename TensorType>
@@ -175,13 +223,14 @@ class CorrelationConvolutionComplexTypes
175223
: public CorrelationConvolutionTest<TensorType> {
176224
};
177225

178-
TYPED_TEST_SUITE(CorrelationConvolutionTestFloatTypes, MatXFloatTypesCUDAExec);
179-
TYPED_TEST_SUITE(CorrelationConvolutionTestNonHalfFloatTypes, MatXFloatNonHalfTypesAllExecs);
180-
TYPED_TEST_SUITE(CorrelationConvolutionLargeTestFloatTypes, MatXFloatNonHalfTypesAllExecs);
226+
TYPED_TEST_SUITE(CorrelationConvolutionDirectTestFloatTypes, MatXFloatTypesCUDAExec);
227+
TYPED_TEST_SUITE(CorrelationConvolutionFFTTestNonHalfFloatTypes, MatXFloatNonHalfTypesAllExecs);
228+
TYPED_TEST_SUITE(CorrelationConvolutionLargeDirectTestFloatTypes, MatXFloatNonHalfTypesAllExecs);
229+
TYPED_TEST_SUITE(CorrelationConvolutionLargeFFTTestFloatTypes, MatXFloatNonHalfTypesAllExecs);
181230
TYPED_TEST_SUITE(CorrelationConvolution2DTestFloatTypes, MatXFloatNonHalfTypesCUDAExec);
182231

183232
// Real/real direct 1D convolution Large
184-
TYPED_TEST(CorrelationConvolutionLargeTestFloatTypes, Direct1DConvolutionLarge)
233+
TYPED_TEST(CorrelationConvolutionLargeDirectTestFloatTypes, Direct1DConvolutionLarge)
185234
{
186235
MATX_ENTER_HANDLER();
187236
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -203,7 +252,7 @@ TYPED_TEST(CorrelationConvolutionLargeTestFloatTypes, Direct1DConvolutionLarge)
203252
MATX_EXIT_HANDLER();
204253
}
205254

206-
TYPED_TEST(CorrelationConvolutionLargeTestFloatTypes, FFT1DConvolutionLarge)
255+
TYPED_TEST(CorrelationConvolutionLargeFFTTestFloatTypes, FFT1DConvolutionLarge)
207256
{
208257
MATX_ENTER_HANDLER();
209258
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -222,7 +271,7 @@ TYPED_TEST(CorrelationConvolutionLargeTestFloatTypes, FFT1DConvolutionLarge)
222271

223272

224273
// Real/real direct 1D convolution
225-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullEven)
274+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionFullEven)
226275
{
227276
MATX_ENTER_HANDLER();
228277
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -236,7 +285,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullEven)
236285
MATX_EXIT_HANDLER();
237286
}
238287

239-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionFullEven)
288+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionFullEven)
240289
{
241290
MATX_ENTER_HANDLER();
242291
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -268,7 +317,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionFullEven)
268317

269318

270319

271-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSameEven)
320+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionSameEven)
272321
{
273322
MATX_ENTER_HANDLER();
274323
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -282,7 +331,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSameEven)
282331
MATX_EXIT_HANDLER();
283332
}
284333

285-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionSameEven)
334+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionSameEven)
286335
{
287336
MATX_ENTER_HANDLER();
288337
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -312,7 +361,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionSameEven)
312361
MATX_EXIT_HANDLER();
313362
}
314363

315-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionValidEven)
364+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionValidEven)
316365
{
317366
MATX_ENTER_HANDLER();
318367
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -326,7 +375,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionValidEven)
326375
MATX_EXIT_HANDLER();
327376
}
328377

329-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionValidEven)
378+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionValidEven)
330379
{
331380
MATX_ENTER_HANDLER();
332381
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -354,7 +403,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionValidEven)
354403
MATX_EXIT_HANDLER();
355404
}
356405

357-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullOdd)
406+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionFullOdd)
358407
{
359408
MATX_ENTER_HANDLER();
360409
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -368,7 +417,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullOdd)
368417
MATX_EXIT_HANDLER();
369418
}
370419

371-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionFullOdd)
420+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionFullOdd)
372421
{
373422
MATX_ENTER_HANDLER();
374423
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -398,7 +447,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionFullOdd)
398447
MATX_EXIT_HANDLER();
399448
}
400449

401-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSameOdd)
450+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionSameOdd)
402451
{
403452
MATX_ENTER_HANDLER();
404453
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -412,7 +461,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSameOdd)
412461
MATX_EXIT_HANDLER();
413462
}
414463

415-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionSameOdd)
464+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionSameOdd)
416465
{
417466
MATX_ENTER_HANDLER();
418467
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -440,7 +489,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionSameOdd)
440489
MATX_EXIT_HANDLER();
441490
}
442491

443-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionValidOdd)
492+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionValidOdd)
444493
{
445494
MATX_ENTER_HANDLER();
446495
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -454,7 +503,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionValidOdd)
454503
MATX_EXIT_HANDLER();
455504
}
456505

457-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionValidOdd)
506+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionValidOdd)
458507
{
459508
MATX_ENTER_HANDLER();
460509
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -482,7 +531,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionValidOdd)
482531
MATX_EXIT_HANDLER();
483532
}
484533

485-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSwap)
534+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DConvolutionSwap)
486535
{
487536
MATX_ENTER_HANDLER();
488537
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -496,7 +545,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSwap)
496545
MATX_EXIT_HANDLER();
497546
}
498547

499-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DConvolutionSwap)
548+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DConvolutionSwap)
500549
{
501550
MATX_ENTER_HANDLER();
502551
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -524,7 +573,7 @@ TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Direct2DConvolutionSwap)
524573
MATX_EXIT_HANDLER();
525574
}
526575

527-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DCorrelation)
576+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DCorrelation)
528577
{
529578
MATX_ENTER_HANDLER();
530579
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -541,7 +590,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DCorrelation)
541590
MATX_EXIT_HANDLER();
542591
}
543592

544-
TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DCorrelation)
593+
TYPED_TEST(CorrelationConvolutionFFTTestNonHalfFloatTypes, FFT1DCorrelation)
545594
{
546595
MATX_ENTER_HANDLER();
547596
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -556,7 +605,7 @@ TYPED_TEST(CorrelationConvolutionTestNonHalfFloatTypes, FFT1DCorrelation)
556605
MATX_EXIT_HANDLER();
557606
}
558607

559-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DCorrelationSwap)
608+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Direct1DCorrelationSwap)
560609
{
561610
MATX_ENTER_HANDLER();
562611
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -570,7 +619,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DCorrelationSwap)
570619
MATX_EXIT_HANDLER();
571620
}
572621

573-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Conv1Axis)
622+
TYPED_TEST(CorrelationConvolutionDirectTestFloatTypes, Conv1Axis)
574623
{
575624
MATX_ENTER_HANDLER();
576625
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -677,7 +726,7 @@ TYPED_TEST(CorrelationConvolutionTestFloatTypes, Conv1Axis)
677726
MATX_EXIT_HANDLER();
678727
}
679728

680-
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Conv2Axis)
729+
TYPED_TEST(CorrelationConvolution2DTestFloatTypes, Conv2Axis)
681730
{
682731
MATX_ENTER_HANDLER();
683732
using TestType = cuda::std::tuple_element_t<0, TypeParam>;

0 commit comments

Comments
 (0)