From 7ad370ae8be268373a74a1ef027ff717f28df1de Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 10 Jun 2025 14:49:55 -0600 Subject: [PATCH] [mlir][spirv] Add lowering of multiple math trigonometric/hyperbolic unary intrinsics --- .../Conversion/MathToSPIRV/MathToSPIRV.cpp | 20 ++++++++++-- .../MathToSPIRV/math-to-gl-spirv.mlir | 32 +++++++++++++++++++ .../MathToSPIRV/math-to-opencl-spirv.mlir | 32 +++++++++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 1b83794b5f450..501bfa223fb18 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -509,7 +509,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, - CheckedElementwiseOpPattern>( + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern>( typeConverter, patterns.getContext()); // OpenCL patterns @@ -533,7 +541,15 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, - CheckedElementwiseOpPattern>( + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir index 5c6561c104389..b8e001c9f6950 100644 --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -46,6 +46,22 @@ func.func @float32_unary_scalar(%arg0: f32) { %14 = math.ceil %arg0 : f32 // CHECK: spirv.GL.Floor %{{.*}}: f32 %15 = math.floor %arg0 : f32 + // CHECK: spirv.GL.Tan %{{.*}}: f32 + %16 = math.tan %arg0 : f32 + // CHECK: spirv.GL.Asin %{{.*}}: f32 + %17 = math.asin %arg0 : f32 + // CHECK: spirv.GL.Acos %{{.*}}: f32 + %18 = math.acos %arg0 : f32 + // CHECK: spirv.GL.Sinh %{{.*}}: f32 + %19 = math.sinh %arg0 : f32 + // CHECK: spirv.GL.Cosh %{{.*}}: f32 + %20 = math.cosh %arg0 : f32 + // CHECK: spirv.GL.Asinh %{{.*}}: f32 + %21 = math.asinh %arg0 : f32 + // CHECK: spirv.GL.Acosh %{{.*}}: f32 + %22 = math.acosh %arg0 : f32 + // CHECK: spirv.GL.Atanh %{{.*}}: f32 + %23 = math.atanh %arg0 : f32 return } @@ -85,6 +101,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) { %11 = math.tanh %arg0 : vector<3xf32> // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32> %12 = math.sin %arg0 : vector<3xf32> + // CHECK: spirv.GL.Tan %{{.*}}: vector<3xf32> + %13 = math.tan %arg0 : vector<3xf32> + // CHECK: spirv.GL.Asin %{{.*}}: vector<3xf32> + %14 = math.asin %arg0 : vector<3xf32> + // CHECK: spirv.GL.Acos %{{.*}}: vector<3xf32> + %15 = math.acos %arg0 : vector<3xf32> + // CHECK: spirv.GL.Sinh %{{.*}}: vector<3xf32> + %16 = math.sinh %arg0 : vector<3xf32> + // CHECK: spirv.GL.Cosh %{{.*}}: vector<3xf32> + %17 = math.cosh %arg0 : vector<3xf32> + // CHECK: spirv.GL.Asinh %{{.*}}: vector<3xf32> + %18 = math.asinh %arg0 : vector<3xf32> + // CHECK: spirv.GL.Acosh %{{.*}}: vector<3xf32> + %19 = math.acosh %arg0 : vector<3xf32> + // CHECK: spirv.GL.Atanh %{{.*}}: vector<3xf32> + %20 = math.atanh %arg0 : vector<3xf32> return } diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir index 393a910c1fb1d..56a0d4dafec8c 100644 --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -48,6 +48,22 @@ func.func @float32_unary_scalar(%arg0: f32) { %16 = math.erf %arg0 : f32 // CHECK: spirv.CL.round %{{.*}}: f32 %17 = math.round %arg0 : f32 + // CHECK: spirv.CL.tan %{{.*}}: f32 + %18 = math.tan %arg0 : f32 + // CHECK: spirv.CL.asin %{{.*}}: f32 + %19 = math.asin %arg0 : f32 + // CHECK: spirv.CL.acos %{{.*}}: f32 + %20 = math.acos %arg0 : f32 + // CHECK: spirv.CL.sinh %{{.*}}: f32 + %21 = math.sinh %arg0 : f32 + // CHECK: spirv.CL.cosh %{{.*}}: f32 + %22 = math.cosh %arg0 : f32 + // CHECK: spirv.CL.asinh %{{.*}}: f32 + %23 = math.asinh %arg0 : f32 + // CHECK: spirv.CL.acosh %{{.*}}: f32 + %24 = math.acosh %arg0 : f32 + // CHECK: spirv.CL.atanh %{{.*}}: f32 + %25 = math.atanh %arg0 : f32 return } @@ -87,6 +103,22 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) { %11 = math.tanh %arg0 : vector<3xf32> // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32> %12 = math.sin %arg0 : vector<3xf32> + // CHECK: spirv.CL.tan %{{.*}}: vector<3xf32> + %13 = math.tan %arg0 : vector<3xf32> + // CHECK: spirv.CL.asin %{{.*}}: vector<3xf32> + %14 = math.asin %arg0 : vector<3xf32> + // CHECK: spirv.CL.acos %{{.*}}: vector<3xf32> + %15 = math.acos %arg0 : vector<3xf32> + // CHECK: spirv.CL.sinh %{{.*}}: vector<3xf32> + %16 = math.sinh %arg0 : vector<3xf32> + // CHECK: spirv.CL.cosh %{{.*}}: vector<3xf32> + %17 = math.cosh %arg0 : vector<3xf32> + // CHECK: spirv.CL.asinh %{{.*}}: vector<3xf32> + %18 = math.asinh %arg0 : vector<3xf32> + // CHECK: spirv.CL.acosh %{{.*}}: vector<3xf32> + %19 = math.acosh %arg0 : vector<3xf32> + // CHECK: spirv.CL.atanh %{{.*}}: vector<3xf32> + %20 = math.atanh %arg0 : vector<3xf32> return }