Skip to content

Commit c7fa6f7

Browse files
Add GroupNorm forward operation (#2623)
1 parent b9f6ef2 commit c7fa6f7

36 files changed

+1996
-103
lines changed

docs/apireference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ API Reference
2424
layernorm
2525
sum
2626
argmax
27+
groupnorm
2728

docs/groupnorm.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
GroupNorm Layer(experimental)
3+
=============================
4+
5+
The groupnorm types and functions.
6+
It splits input channels into num_group groups and do normalize for each group.
7+
8+
To enable this, define MIOPEN_BETA_API before including miopen.h.
9+
10+
11+
miopenNormMode_t
12+
-----------------------
13+
14+
.. doxygenenum:: miopenNormMode_t
15+
16+
miopenGroupNormForward
17+
----------------------------------
18+
19+
.. doxygenfunction:: miopenGroupNormForward
20+

docs/layernorm.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ The layernorm types and functions.
66
To enable this, define MIOPEN_BETA_API before including miopen.h.
77

88

9-
miopenLayerNormMode_t
9+
miopenNormMode_t
1010
-----------------------
1111

12-
.. doxygenenum:: miopenLayerNormMode_t
12+
.. doxygenenum:: miopenNormMode_t
1313

1414
miopenLayerNormForward
1515
----------------------------------

driver/driver.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
151151
"pool[fp16], lrn[fp16], "
152152
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
153153
"tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], "
154-
"argmax[bfp16|fp16]\n");
154+
"argmax[bfp16|fp16], groupnorm[bfp16|fp16]\n");
155155
exit(0); // NOLINT (concurrency-mt-unsafe)
156156
}
157157

@@ -175,6 +175,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
175175
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" &&
176176
arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" &&
177177
arg != "sumbfp16" && arg != "argmax" && arg != "argmaxfp16" && arg != "argmaxbfp16" &&
178+
arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" &&
178179
arg != "--version")
179180
{
180181
printf("FAILED: Invalid Base Input Argument\n");

0 commit comments

Comments
 (0)