Skip to content

Commit 00e921c

Browse files
committed
L2_NORMALIZATION support for tflite
Signed-off-by: Shesung <[email protected]>
1 parent 6905d05 commit 00e921c

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5873,6 +5873,14 @@ def func(x):
58735873
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
58745874
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
58755875

5876+
@skip_tfjs("not supported in tfjs")
5877+
def test_l2normalization(self):
5878+
def func(x):
5879+
op_ = tf.math.l2_normalize(x)
5880+
return tf.identity(op_, name=_TFOUTPUT)
5881+
5882+
x_val = make_xval([3, 4])
5883+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
58765884

58775885
if __name__ == '__main__':
58785886
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,15 @@ class HardSwish:
826826
@classmethod
827827
def version_14(cls, ctx, node, **kwargs):
828828
pass
829+
830+
831+
@tf_op(["L2Normalization"], onnx_op="LpNormalization")
832+
class L2Normalization:
833+
@classmethod
834+
def version_1(cls, ctx, node, **kwargs):
835+
axis = node.get_attr_value("axis")
836+
if axis is None:
837+
# by default use the last dim
838+
axis = -1
839+
node.set_attr("axis", axis)
840+
node.set_attr("p", 2)

tf2onnx/tflite_handlers/tfl_direct.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
@tfl_op("TFL_RFFT2D", tf_op="RFFT2D")
8989
@tfl_op("TFL_COMPLEX_ABS", tf_op="ComplexAbs")
9090
@tfl_op("TFL_HARD_SWISH", tf_op="HardSwish")
91+
@tfl_op("TFL_L2_NORMALIZATION", tf_op="L2Normalization")
9192
class TflDirectOp:
9293
@classmethod
9394
def to_tf(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)