diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 4f81cc7907..5edcc233d0 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -114,6 +114,33 @@ def _adjust_attributes_of_avg_pool(
     return (kernel_shape, strides, pads)
 
 
+def _aten_avg_pool_onnx(
+    self: TFloat,
+    kernel_shape: Sequence[int],
+    strides: Sequence[int],
+    pads: Sequence[int],
+    ceil_mode: bool,
+    count_include_pad: bool,
+) -> TFloat:
+    self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1
+    if self_rank_is_unbatched_rank:  # C,H,W -> N,C,H,W and N=1
+        self = op.Unsqueeze(self, [0])
+
+    result = op.AveragePool(
+        self,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        kernel_shape=kernel_shape,
+        pads=pads,
+        strides=strides,
+    )
+
+    if self_rank_is_unbatched_rank:
+        result = op.Squeeze(result, [0])
+
+    return result
+
+
 @torch_op("aten::avg_pool1d", trace_only=True)
 def aten_avg_pool1d(
     self: TFloat,
@@ -134,16 +161,7 @@ def aten_avg_pool1d(
         expand_size, kernel_size, stride, padding
     )
 
-    result = op.AveragePool(
-        self,
-        ceil_mode=ceil_mode,
-        count_include_pad=count_include_pad,
-        kernel_shape=kernel_shape,
-        pads=pads,
-        strides=strides,
-    )
-
-    return result
+    return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
 
 
 @torch_op("aten::avg_pool2d", trace_only=True)
@@ -167,15 +185,6 @@ def aten_avg_pool2d(
         expand_size, kernel_size, stride, padding
     )
 
-    result = op.AveragePool(
-        self,
-        ceil_mode=ceil_mode,
-        count_include_pad=count_include_pad,
-        kernel_shape=kernel_shape,
-        pads=pads,
-        strides=strides,
-    )
-
     # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
     # mask = [
     #    1, 2, 3, S,..3, 2, 1
@@ -189,7 +198,7 @@ def aten_avg_pool2d(
     # S is stride size, in this case S=4,
     # S may dup lot of times according to the image size
 
-    return result
+    return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
 
 
 def aten_avg_pool2d_backward(
@@ -228,15 +237,6 @@ def aten_avg_pool3d(
         expand_size, kernel_size, stride, padding
     )
 
-    result = op.AveragePool(
-        self,
-        kernel_shape=kernel_shape,
-        strides=strides,
-        pads=pads,
-        count_include_pad=count_include_pad,
-        ceil_mode=ceil_mode,
-    )
-
     # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
     # mask = [
     #    1, 2, 3, S,..3, 2, 1
@@ -250,7 +250,7 @@ def aten_avg_pool3d(
     # S is stride size, in this case S=4,
     # S may dup lot of times according to the image size
 
-    return result
+    return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
 
 
 def aten_avg_pool3d_backward(
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index 754f5e2a25..3c557be4f0 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -238,6 +238,30 @@ def forward(self, x):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_avg_pool(self):
+        class Model(torch.nn.Module):
+            def forward(self, x2d, x3d, x4d, x5d):
+                return (
+                    torch.nn.functional.avg_pool1d(x2d, 2),  # pylint: disable=not-callable
+                    torch.nn.functional.avg_pool1d(x3d, 2),  # pylint: disable=not-callable
+                    torch.nn.functional.avg_pool2d(x3d, 2),  # pylint: disable=not-callable
+                    torch.nn.functional.avg_pool2d(x4d, 2),  # pylint: disable=not-callable
+                    torch.nn.functional.avg_pool3d(x4d, 2),  # pylint: disable=not-callable
+                    torch.nn.functional.avg_pool3d(x5d, 2),  # pylint: disable=not-callable
+                )
+
+        x2d = torch.randn(10, 10)
+        x3d = torch.randn(10, 10, 10)
+        x4d = torch.randn(10, 10, 10, 10)
+        x5d = torch.randn(10, 10, 10, 10, 10)
+        onnx_program = torch.onnx.export(
+            Model(),
+            (x2d, x3d, x4d, x5d),
+            dynamo=True,
+            verbose=False,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()