Skip to content

Commit 25c977c

Browse files
f-salvettifatcat-z
andauthored
Support bitwise ops (#2192)
* 10787 support bitwise ops * 10787 add bitwise tests * ANTIALIAS has been removed in Pillow 10 (2023-07-01). Use LANCZOS instead. Signed-off-by: Salvetti, Francesco <[email protected]> Signed-off-by: Jay Zhang <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent b27aa05 commit 25c977c

File tree

4 files changed

+50
-2
lines changed

4 files changed

+50
-2
lines changed

support_status.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
| BiasAdd | 1 ~ 17 |
3333
| BiasAddV1 | 1 ~ 17 |
3434
| Bincount | 11 ~ 17 |
35+
| BitwiseAnd | 18 |
36+
| BitwiseOr | 18 |
37+
| BitwiseXor | 18 |
3538
| BroadcastTo | 8 ~ 17 |
3639
| CTCGreedyDecoder | 11 ~ 17 |
3740
| Cast | 1 ~ 17 |
@@ -93,6 +96,7 @@
9396
| Identity | 1 ~ 17 |
9497
| IdentityN | 1 ~ 17 |
9598
| If | 1 ~ 17 |
99+
| Invert | 18 |
96100
| InvertPermutation | 11 ~ 17 |
97101
| IsFinite | 10 ~ 17 |
98102
| IsInf | 10 ~ 17 |

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_img(shape, path, dtype, should_scale=True):
5858
resize_to = shape[1:3]
5959
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
6060
img = PIL.Image.open(path)
61-
img = img.resize(resize_to, PIL.Image.ANTIALIAS)
61+
img = img.resize(resize_to, PIL.Image.LANCZOS)
6262
img_np = np.array(img).astype(dtype)
6363
img_np = np.stack([img_np] * shape[0], axis=0).reshape(shape)
6464
if should_scale:

tests/test_backend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4896,6 +4896,41 @@ def func(x):
48964896
return tf.identity(x_, name=_TFOUTPUT)
48974897
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
48984898

4899+
@check_opset_min_version(18, "BitwiseAnd")
4900+
def test_bitwise_and(self):
4901+
x_val = np.array([21, 4, 1], dtype=np.int32)
4902+
y_val = np.array([45, 69, 3], dtype=np.int32)
4903+
def func(x, y):
4904+
x_ = tf.bitwise.bitwise_and(x, y)
4905+
return tf.identity(x_, name=_TFOUTPUT)
4906+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
4907+
4908+
@check_opset_min_version(18, "BitwiseOr")
4909+
def test_bitwise_or(self):
4910+
x_val = np.array([21, 4, 87], dtype=np.int32)
4911+
y_val = np.array([45, 69, 173], dtype=np.int32)
4912+
def func(x, y):
4913+
x_ = tf.bitwise.bitwise_or(x, y)
4914+
return tf.identity(x_, name=_TFOUTPUT)
4915+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
4916+
4917+
@check_opset_min_version(18, "BitwiseXor")
4918+
def test_bitwise_xor(self):
4919+
x_val = np.array([21, 4, 87], dtype=np.int32)
4920+
y_val = np.array([45, 69, 173], dtype=np.int32)
4921+
def func(x, y):
4922+
x_ = tf.bitwise.bitwise_xor(x, y)
4923+
return tf.identity(x_, name=_TFOUTPUT)
4924+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
4925+
4926+
@check_opset_min_version(18, "BitwiseNot")
4927+
def test_bitwise_not(self):
4928+
x_val = np.array([21, 4, 1], dtype=np.int32)
4929+
def func(x):
4930+
x_ = tf.bitwise.invert(x)
4931+
return tf.identity(x_, name=_TFOUTPUT)
4932+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
4933+
48994934
@check_tf_min_version("1.14", "tensor_scatter_nd_update needs tf 1.14")
49004935
@check_opset_min_version(11, "ScatterND")
49014936
def test_tensor_scatter_update(self):

tf2onnx/onnx_opset/math.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,6 @@ def version_11(cls, ctx, node, **kwargs):
792792

793793
@tf_op(["LeftShift", "RightShift"])
794794
class BitShift:
795-
796795
@classmethod
797796
def version_11(cls, ctx, node, **kwargs):
798797
dir_map = {"LeftShift": "LEFT", "RightShift": "RIGHT"}
@@ -818,6 +817,16 @@ def version_11(cls, ctx, node, **kwargs):
818817
ctx.copy_dtype(node.input[0], node.output[0])
819818

820819

820+
@tf_op("BitwiseAnd")
821+
@tf_op("BitwiseOr")
822+
@tf_op("BitwiseXor")
823+
@tf_op("Invert", onnx_op="BitwiseNot")
824+
class BitwiseOps:
825+
@classmethod
826+
def version_18(cls, ctx, node, **kwargs):
827+
pass
828+
829+
821830
@tf_op("SquaredDistance", onnx_op="MeanSquaredDistance")
822831
class SquaredDistance:
823832
@classmethod

0 commit comments

Comments
 (0)