Skip to content

Commit 717861b

Browse files
committed
Fix GemmDriver uninitialized field
1 parent f185a64 commit 717861b

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

driver/driver.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
141141
printf("Usage: ./driver *base_arg* *other_args*\n");
142142
printf("Supported Base Arguments: conv[fp16|int8|bfp16], CBAInfer[fp16], "
143143
"pool[fp16], lrn[fp16], "
144-
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
144+
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
145145
"tensorop[fp16], reduce[fp16,fp64]\n");
146146
exit(0); // NOLINT (concurrency-mt-unsafe)
147147
}
@@ -160,7 +160,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
160160
arg != "CBAInfer" && arg != "CBAInferfp16" && arg != "pool" && arg != "poolfp16" &&
161161
arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" &&
162162
arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" &&
163-
arg != "rnn" && arg != "rnnfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
163+
arg != "rnn" && arg != "rnnfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
164164
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
165165
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "--version")
166166
{

driver/gemm_driver.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,19 @@ int GemmDriver<T>::GetandSetData()
207207
gemm_desc.strideB = gemm_desc.k * gemm_desc.n;
208208
gemm_desc.strideC = gemm_desc.m * gemm_desc.n;
209209

210+
if constexpr (std::is_same_v<T, float>)
211+
{
212+
gemm_desc.dataType = miopenFloat;
213+
}
214+
else if constexpr (std::is_same_v<T, float16>)
215+
{
216+
gemm_desc.dataType = miopenHalf;
217+
}
218+
else
219+
{
220+
static_assert(!"unsupported type");
221+
}
222+
210223
return (0);
211224
}
212225

@@ -230,9 +243,9 @@ int GemmDriver<T>::AllocateBuffersAndCopy()
230243
a = std::vector<T>(a_sz);
231244
b = std::vector<T>(b_sz);
232245
#if GEMM_DRIVER_DEBUG
233-
c = std::vector<T>(c_sz, 1.);
246+
c = std::vector<T>(c_sz, static_cast<T>(1.));
234247
#else
235-
c = std::vector<T>(c_sz, 0.);
248+
c = std::vector<T>(c_sz, static_cast<T>(0.));
236249
#endif
237250
chost = c;
238251

driver/main.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ int main(int argc, char* argv[])
126126
drv = new GemmDriver<float>();
127127
}
128128
// TODO half is not supported in gemm
129-
// else if(base_arg == "gemmfp16")
130-
// {
131-
// drv = new GemmDriver<float16>();
132-
// }
129+
else if(base_arg == "gemmfp16")
130+
{
131+
drv = new GemmDriver<float16>();
132+
}
133133
#endif
134134
else if(base_arg == "bnorm")
135135
{

0 commit comments

Comments
 (0)