Skip to content

Commit efaafeb

Browse files
authored
[Tests] Convert test_rnn_vanilla , test_gru, test_rnn_extra and test_gru_extra gTests (#2550)
1 parent c5024fd commit efaafeb

File tree

9 files changed

+627
-161
lines changed

9 files changed

+627
-161
lines changed

test/CMakeLists.txt

Lines changed: 0 additions & 80 deletions
Large diffs are not rendered by default.

test/gru.cpp

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,7 @@
2424
*
2525
*******************************************************************************/
2626

27-
#include "gru_common.hpp"
28-
29-
template <class T>
30-
struct gru_driver : gru_basic_driver<T>
31-
{
32-
gru_driver() : gru_basic_driver<T>()
33-
{
34-
std::vector<int> modes(2, 0);
35-
modes[1] = 1;
36-
std::vector<int> defaultBS(1);
37-
38-
this->add(this->batchSize, "batch-size", this->generate_data(get_gru_batchSize(), {17}));
39-
this->add(this->seqLength, "seq-len", this->generate_data(get_gru_seq_len(), {2}));
40-
this->add(this->inVecLen, "vector-len", this->generate_data(get_gru_vector_len()));
41-
this->add(this->hiddenSize, "hidden-size", this->generate_data(get_gru_hidden_size()));
42-
this->add(this->numLayers, "num-layers", this->generate_data(get_gru_num_layers()));
43-
this->add(this->nohx, "no-hx", this->flag());
44-
this->add(this->nodhy, "no-dhy", this->flag());
45-
this->add(this->nohy, "no-hy", this->flag());
46-
this->add(this->nodhx, "no-dhx", this->flag());
47-
this->add(this->flatBatchFill, "flat-batch-fill", this->flag());
48-
this->add(this->useDropout, "use-dropout", this->generate_data({0}));
49-
50-
#if(MIO_GRU_TEST_DEBUG == 3)
51-
this->biasMode = 0;
52-
this->dirMode = 1;
53-
this->inputMode = 0;
54-
#else
55-
this->add(this->inputMode, "in-mode", this->generate_data(modes));
56-
this->add(this->biasMode, "bias-mode", this->generate_data(modes));
57-
this->add(this->dirMode, "dir-mode", this->generate_data(modes));
58-
#endif
59-
this->add(
60-
this->batchSeq,
61-
"batch-seq",
62-
this->lazy_generate_data(
63-
[=] { return generate_batchSeq(this->batchSize, this->seqLength); }, defaultBS));
64-
}
65-
};
27+
#include "gru.hpp"
6628

6729
int main(int argc, const char* argv[])
6830
{

test/gru.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*******************************************************************************
2+
*
3+
* MIT License
4+
*
5+
* Copyright (c) 2017 Advanced Micro Devices, Inc.
6+
*
7+
* Permission is hereby granted, free of charge, to any person obtaining a copy
8+
* of this software and associated documentation files (the "Software"), to deal
9+
* in the Software without restriction, including without limitation the rights
10+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the Software is
12+
* furnished to do so, subject to the following conditions:
13+
*
14+
* The above copyright notice and this permission notice shall be included in all
15+
* copies or substantial portions of the Software.
16+
*
17+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
* SOFTWARE.
24+
*
25+
*******************************************************************************/
26+
#pragma once
27+
28+
#include "gru_common.hpp"
29+
30+
template <class T>
31+
struct gru_driver : gru_basic_driver<T>
32+
{
33+
gru_driver() : gru_basic_driver<T>()
34+
{
35+
std::vector<int> modes(2, 0);
36+
modes[1] = 1;
37+
std::vector<int> defaultBS(1);
38+
39+
this->add(this->batchSize, "batch-size", this->generate_data(get_gru_batchSize(), {17}));
40+
this->add(this->seqLength, "seq-len", this->generate_data(get_gru_seq_len(), {2}));
41+
this->add(this->inVecLen, "vector-len", this->generate_data(get_gru_vector_len()));
42+
this->add(this->hiddenSize, "hidden-size", this->generate_data(get_gru_hidden_size()));
43+
this->add(this->numLayers, "num-layers", this->generate_data(get_gru_num_layers()));
44+
this->add(this->nohx, "no-hx", this->flag());
45+
this->add(this->nodhy, "no-dhy", this->flag());
46+
this->add(this->nohy, "no-hy", this->flag());
47+
this->add(this->nodhx, "no-dhx", this->flag());
48+
this->add(this->flatBatchFill, "flat-batch-fill", this->flag());
49+
this->add(this->useDropout, "use-dropout", this->generate_data({0}));
50+
51+
#if(MIO_GRU_TEST_DEBUG == 3)
52+
this->biasMode = 0;
53+
this->dirMode = 1;
54+
this->inputMode = 0;
55+
#else
56+
this->add(this->inputMode, "in-mode", this->generate_data(modes));
57+
this->add(this->biasMode, "bias-mode", this->generate_data(modes));
58+
this->add(this->dirMode, "dir-mode", this->generate_data(modes));
59+
#endif
60+
this->add(
61+
this->batchSeq,
62+
"batch-seq",
63+
this->lazy_generate_data(
64+
[=] { return generate_batchSeq(this->batchSize, this->seqLength); }, defaultBS));
65+
}
66+
};

test/gtest/deepbench_gru.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*******************************************************************************
2+
*
3+
* MIT License
4+
*
5+
* Copyright (c) 2023 Advanced Micro Devices, Inc.
6+
*
7+
* Permission is hereby granted, free of charge, to any person obtaining a copy
8+
* of this software and associated documentation files (the "Software"), to deal
9+
* in the Software without restriction, including without limitation the rights
10+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the Software is
12+
* furnished to do so, subject to the following conditions:
13+
*
14+
* The above copyright notice and this permission notice shall be included in all
15+
* copies or substantial portions of the Software.
16+
*
17+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
* SOFTWARE.
24+
*
25+
*******************************************************************************/
26+
#include <miopen/miopen.h>
27+
#include <gtest/gtest.h>
28+
#include <miopen/env.hpp>
29+
#include "../gru.hpp"
30+
#include "get_handle.hpp"
31+
32+
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_DEEPBENCH)
33+
34+
namespace deepbench_gru {
35+
static bool SkipTest(void) { return !miopen::IsEnabled(ENV(MIOPEN_TEST_DEEPBENCH)); }
36+
37+
void GetArgs(const std::string& param, std::vector<std::string>& tokens)
38+
{
39+
std::stringstream ss(param);
40+
std::istream_iterator<std::string> begin(ss);
41+
std::istream_iterator<std::string> end;
42+
while(begin != end)
43+
tokens.push_back(*begin++);
44+
}
45+
46+
class DeepBenchGRUConfigWithFloat : public testing::TestWithParam<std::vector<std::string>>
47+
{
48+
};
49+
50+
void Run2dDriverFloat(void)
51+
{
52+
std::vector<std::string> params = DeepBenchGRUConfigWithFloat::GetParam();
53+
54+
for(const auto& test_value : params)
55+
{
56+
std::vector<std::string> tokens;
57+
GetArgs(test_value, tokens);
58+
std::vector<const char*> ptrs;
59+
60+
std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) {
61+
return str.data();
62+
});
63+
64+
testing::internal::CaptureStderr();
65+
test_drive<gru_driver>(ptrs.size(), ptrs.data());
66+
auto capture = testing::internal::GetCapturedStderr();
67+
std::cout << capture;
68+
}
69+
};
70+
71+
std::vector<std::string> GetTestCases(void)
72+
{
73+
std::string flags = " --verbose";
74+
std::string commonFlags =
75+
" --num-layers 1 --in-mode 1 --bias-mode 0 -dir-mode 0 --rnn-mode 0 --flat-batch-fill";
76+
77+
const std::vector<std::string> test_cases = {
78+
// clang-format off
79+
{flags + " --batch-size 32 --seq-len 1500 --vector-len 2816 --hidden-size 2816" + commonFlags},
80+
{flags + " --batch-size 32 --seq-len 750 --vector-len 2816 --hidden-size 2816" + commonFlags},
81+
{flags + " --batch-size 32 --seq-len 375 --vector-len 2816 --hidden-size 2816" + commonFlags},
82+
{flags + " --batch-size 32 --seq-len 187 --vector-len 2816 --hidden-size 2816" + commonFlags},
83+
{flags + " --batch-size 32 --seq-len 1500 --vector-len 2048 --hidden-size 2048" + commonFlags},
84+
{flags + " --batch-size 32 --seq-len 750 --vector-len 2048 --hidden-size 2048" + commonFlags},
85+
{flags + " --batch-size 32 --seq-len 375 --vector-len 2048 --hidden-size 2048" + commonFlags},
86+
{flags + " --batch-size 32 --seq-len 187 --vector-len 2048 --hidden-size 2048" + commonFlags},
87+
{flags + " --batch-size 32 --seq-len 1500 --vector-len 1536 --hidden-size 1536" + commonFlags},
88+
{flags + " --batch-size 32 --seq-len 750 --vector-len 1536 --hidden-size 1536" + commonFlags},
89+
{flags + " --batch-size 32 --seq-len 375 --vector-len 1536 --hidden-size 1536" + commonFlags},
90+
{flags + " --batch-size 32 --seq-len 187 --vector-len 1536 --hidden-size 1536" + commonFlags},
91+
{flags + " --batch-size 32 --seq-len 1500 --vector-len 2560 --hidden-size 2560" + commonFlags},
92+
{flags + " --batch-size 32 --seq-len 750 --vector-len 2560 --hidden-size 2560" + commonFlags},
93+
{flags + " --batch-size 32 --seq-len 375 --vector-len 2560 --hidden-size 2560" + commonFlags},
94+
{flags + " --batch-size 32 --seq-len 187 --vector-len 2560 --hidden-size 2560" + commonFlags},
95+
{flags + " --batch-size 32 --seq-len 1 --vector-len 512 --hidden-size 512" + commonFlags},
96+
{flags + " --batch-size 32 --seq-len 1500 --vector-len 1024 --hidden-size 1024" + commonFlags},
97+
{flags + " --batch-size 64 --seq-len 1500 --vector-len 1024 --hidden-size 1024" + commonFlags}
98+
// clang-format on
99+
};
100+
101+
return test_cases;
102+
}
103+
104+
} // namespace deepbench_gru
105+
106+
using namespace deepbench_gru;
107+
108+
TEST_P(DeepBenchGRUConfigWithFloat, FloatTest_deepbench_gru)
109+
{
110+
if(SkipTest())
111+
{
112+
GTEST_SKIP();
113+
}
114+
else
115+
{
116+
Run2dDriverFloat();
117+
}
118+
};
119+
120+
INSTANTIATE_TEST_SUITE_P(ConvTrans, DeepBenchGRUConfigWithFloat, testing::Values(GetTestCases()));

test/gtest/deepbench_rnn.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*******************************************************************************
2+
*
3+
* MIT License
4+
*
5+
* Copyright (c) 2023 Advanced Micro Devices, Inc.
6+
*
7+
* Permission is hereby granted, free of charge, to any person obtaining a copy
8+
* of this software and associated documentation files (the "Software"), to deal
9+
* in the Software without restriction, including without limitation the rights
10+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the Software is
12+
* furnished to do so, subject to the following conditions:
13+
*
14+
* The above copyright notice and this permission notice shall be included in all
15+
* copies or substantial portions of the Software.
16+
*
17+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
* SOFTWARE.
24+
*
25+
*******************************************************************************/
26+
#include <miopen/miopen.h>
27+
#include <gtest/gtest.h>
28+
#include <miopen/env.hpp>
29+
#include "../rnn_vanilla.hpp"
30+
#include "get_handle.hpp"
31+
32+
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_DEEPBENCH)
33+
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL)
34+
35+
namespace deepbench_rnn {
36+
static bool SkipTest(void) { return !miopen::IsEnabled(ENV(MIOPEN_TEST_DEEPBENCH)); }
37+
38+
void GetArgs(const std::string& param, std::vector<std::string>& tokens)
39+
{
40+
std::stringstream ss(param);
41+
std::istream_iterator<std::string> begin(ss);
42+
std::istream_iterator<std::string> end;
43+
while(begin != end)
44+
tokens.push_back(*begin++);
45+
}
46+
47+
class DeepBenchRNNConfigWithFloat : public testing::TestWithParam<std::vector<std::string>>
48+
{
49+
};
50+
51+
void Run2dDriverFloat(void)
52+
{
53+
54+
std::vector<std::string> params = DeepBenchRNNConfigWithFloat::GetParam();
55+
56+
for(const auto& test_value : params)
57+
{
58+
std::vector<std::string> tokens;
59+
GetArgs(test_value, tokens);
60+
std::vector<const char*> ptrs;
61+
62+
std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) {
63+
return str.data();
64+
});
65+
66+
testing::internal::CaptureStderr();
67+
test_drive<rnn_vanilla_driver>(ptrs.size(), ptrs.data());
68+
auto capture = testing::internal::GetCapturedStderr();
69+
std::cout << capture;
70+
}
71+
};
72+
73+
std::vector<std::string> GetTestCases(void)
74+
{
75+
std::string flags = " --verbose";
76+
77+
std::string postFlags =
78+
"--num-layers 1 --in-mode 1 --bias-mode 0 -dir-mode 0 --rnn-mode 0 --flat-batch-fill";
79+
80+
const std::vector<std::string> test_cases = {
81+
// clang-format off
82+
{flags + " --batch-size 16 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
83+
{flags + " --batch-size 32 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
84+
{flags + " --batch-size 64 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
85+
{flags + " --batch-size 128 --seq-len 50 --vector-len 1760 --hidden-size 1760 " + postFlags},
86+
{flags + " --batch-size 16 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
87+
{flags + " --batch-size 32 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
88+
{flags + " --batch-size 64 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
89+
{flags + " --batch-size 128 --seq-len 50 --vector-len 2048 --hidden-size 2048 " + postFlags},
90+
{flags + " --batch-size 16 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags},
91+
{flags + " --batch-size 32 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags},
92+
{flags + " --batch-size 64 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags},
93+
{flags + " --batch-size 128 --seq-len 50 --vector-len 2560 --hidden-size 2560 " + postFlags}
94+
// clang-format on
95+
};
96+
97+
return test_cases;
98+
}
99+
100+
} // namespace deepbench_rnn
101+
102+
using namespace deepbench_rnn;
103+
104+
TEST_P(DeepBenchRNNConfigWithFloat, FloatTest_deepbench_rnn)
105+
{
106+
if(SkipTest())
107+
{
108+
GTEST_SKIP();
109+
}
110+
else
111+
{
112+
Run2dDriverFloat();
113+
}
114+
};
115+
116+
INSTANTIATE_TEST_SUITE_P(ConvTrans, DeepBenchRNNConfigWithFloat, testing::Values(GetTestCases()));

0 commit comments

Comments
 (0)