[MLAS] Enable FP16 for Gelu #26815
Status: Open

akote123 wants to merge 10 commits into microsoft:main from MonakaResearch:gelu_fp16
Commits (10):
645c3fa  akote123: Merged PR 432: Update fj-develop with main
e69b8e9  akote123: Merged PR 453: Update FJ-Develop
a7a111e  nikhilfujitsu: Merged PR 463: Update FJ-develop to latest
ae0da8f  akote123: Merged PR 465: update fj-develop
b0667bc  nikhilfujitsu: Merged PR 466: Updated the fj-develop
7e94153  akote123: Merged PR 508: Update fj-develop with main
1bfec70  akote123: Merged PR 613: Rebase fj-develop with main
dc47d84  akote123: Merged PR 621: Merge oss main with fj-develop
52a38a5  akote123: Merged PR 779: Rebase fj-develop with main
cc2625d  akote123: Enable Gelu Fp16
Files changed

CMake changes:
```diff
@@ -172,7 +172,14 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS)
   set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ")
 endif()

-target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT})
+if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64")
+  set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16)
```
Member (inline comment on the gelu.cc line above): This may be a duplicated comment - I think this is coming about because some CPU EP files are now directly using intrinsics, and I feel the routines that use hardware-accelerated intrinsics should live in MLAS and only be called from the CPU EP files.
```diff
+endif()
+target_include_directories(onnxruntime_providers PRIVATE
+  ${ONNXRUNTIME_ROOT}
+  ${ONNXRUNTIME_ROOT}/core/mlas/inc
+)

 onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen)
 add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
```
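The -march=armv8.2-a+fp16 flag is what makes the __fp16 and float16x8_t intrinsics available when compiling gelu.cc. As context for the review comment above, here is a minimal sketch of the layering the reviewer suggests; MlasComputeErfF16 and MlasErfKernelFallbackFp16 are hypothetical names for illustration, not code from this PR:

```cpp
#include <cstddef>
#include <cstdint>

// This PR's kernel (declared in erf_neon_fp16.h below).
void MlasNeonErfKernelFp16(const uint16_t* Input, uint16_t* Output, size_t N);
// Hypothetical portable fallback, assumed to exist for this sketch.
void MlasErfKernelFallbackFp16(const uint16_t* Input, uint16_t* Output, size_t N);

// Hypothetical single MLAS entry point, so CPU EP code such as gelu.cc never
// touches NEON intrinsics or per-file -march flags directly.
void
MlasComputeErfF16(const uint16_t* Input, uint16_t* Output, size_t N)
{
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    // ACLE defines this macro when FP16 vector arithmetic is enabled,
    // e.g. by -march=armv8.2-a+fp16 as in the CMake change above.
    MlasNeonErfKernelFp16(Input, Output, N);
#else
    MlasErfKernelFallbackFp16(Input, Output, N);
#endif
}
```

With a split like this, only the MLAS translation unit needs the per-file architecture flags.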
New file: erf_neon_fp16.cpp (+147 lines):

```cpp
/*++

Copyright 2025 FUJITSU LIMITED

Module Name:

    erf_neon_fp16.cpp

Abstract:

    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#include "erf_neon_fp16.h"

// Helpers to safely convert between float and FP16-bit representation
static float
fp16_to_float(uint16_t h)
{
    __fp16 tmp;
    memcpy(&tmp, &h, sizeof(h));
    return (float)tmp;
}

static uint16_t
float_to_fp16(float f)
{
    __fp16 tmp = (__fp16)f;
    uint16_t h;
    memcpy(&h, &tmp, sizeof(h));
    return h;
}
```
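The memcpy-based punning in these helpers avoids the strict-aliasing problems that a pointer cast would invite. As a quick round-trip sanity check, a standalone sketch (assumes a toolchain with __fp16 support; the helpers are repeated so the snippet compiles on its own):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static float fp16_to_float(uint16_t h)
{
    __fp16 tmp;
    std::memcpy(&tmp, &h, sizeof(h));
    return (float)tmp;
}

static uint16_t float_to_fp16(float f)
{
    __fp16 tmp = (__fp16)f;
    uint16_t h;
    std::memcpy(&h, &tmp, sizeof(h));
    return h;
}

int main()
{
    assert(fp16_to_float(0x3C00) == 1.0f);                    // 0x3C00 encodes 1.0 in IEEE half
    assert(float_to_fp16(-2.0f) == 0xC000);                   // 0xC000 encodes -2.0
    assert(float_to_fp16(fp16_to_float(0x7BFF)) == 0x7BFF);   // max finite half round-trips
    return 0;
}
```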
```cpp
static inline MLAS_FLOAT16X8
exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
{
    const float16_t a0 = 6.0f;
    MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0);
    x = MlasMinimumFloat16(x, max_x);

    const float16_t c0 = 1.330f;
    const float16_t c1 = -0.390f;
    const float16_t c2 = 0.0288f;

    const float16_t d0 = 1.338f;
    const float16_t d1 = 0.848f;
    const float16_t d2 = 0.467f;

    MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0);
    MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1);
    MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2);

    MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0);
    MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1);
    MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2);
    MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
    MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v);
    num = MlasMultiplyAddFloat16(c2v, x2, num);
    MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
    den = MlasMultiplyAddFloat16(d2v, x2, den);
    MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
    recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
    recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
    MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
    return result;
}
```
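For reference, the same degree-2 over degree-2 rational approximation of exp(-x) in scalar form (my reading of the vector code above; the approximate-reciprocal-plus-refinement sequence is replaced here by an exact divide):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Scalar stand-in for exp_neg_rational_approx_f16: rational approximation
// of exp(-x), with the input clamped at x = 6 as in the vector code.
static float exp_neg_rational_scalar(float x)
{
    x = std::min(x, 6.0f);
    const float num = 1.330f + -0.390f * x + 0.0288f * x * x;
    const float den = 1.338f + 0.848f * x + 0.467f * x * x;
    return num / den;   // exact divide in place of the refined reciprocal estimate
}

int main()
{
    for (float x = 0.0f; x <= 6.0f; x += 1.0f)
        std::printf("x=%.0f  approx=%.4f  exp(-x)=%.4f\n",
                    x, exp_neg_rational_scalar(x), std::exp(-x));
    return 0;
}
```

At x = 2, for example, this gives roughly 0.1357 against an exact value of 0.1353, which is near the limit of FP16 resolution.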
```cpp
void
MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
{
    const float16_t p = 0.328f;
    const float16_t a1 = 0.2505f;
    const float16_t a2 = -0.2881f;
    const float16_t a3 = 1.4102f;
    const float16_t a4 = -1.423f;
    const float16_t a5 = 1.0547f;

    MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
    MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
    MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
    MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
    MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
    MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);

    constexpr float16_t one_fp16 = 1.0f;
    constexpr float16_t neg_one_fp16 = -1.0f;
    constexpr float16_t zero_fp16 = 0.0f;
    constexpr float16_t four_fp16 = 4.0f;

    MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
    MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
    MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
    MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);

    size_t i = 0;
    for (; i + 8 <= N; i += 8) {
        MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&Input[i]);
        MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
        MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
        MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
        MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
        MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
        MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
        MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
        t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
        t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
        MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
        MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
        MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
        MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
        MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
        poly = MlasMultiplyAddFloat16(va2, t2, poly);
        poly = MlasMultiplyAddFloat16(va3, t3, poly);
        poly = MlasMultiplyAddFloat16(va4, t4, poly);
        poly = MlasMultiplyAddFloat16(va5, t5, poly);
        MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
        MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
        MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
        MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
        MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
        erf_approx = MlasMinimumFloat16(erf_approx, vone);
        erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
        MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
        MlasStoreFloat16x8(&Output[i], result);
    }

    for (; i < N; i++) {
        float x = fp16_to_float(Input[i]);
        float sign = (x < 0) ? -1.0f : 1.0f;
        float absx = fabsf(x);

        if (absx > 4.0f) {
            Output[i] = float_to_fp16(sign);
            continue;
        }

        float t = 1.0f / (1.0f + p * absx);
        float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
        float exp_neg_x2 = expf(-absx * absx);
        float erf_approx = sign * (1.0f - poly * exp_neg_x2);
        if (erf_approx > 1.0f) erf_approx = 1.0f;
        if (erf_approx < -1.0f) erf_approx = -1.0f;

        Output[i] = float_to_fp16(erf_approx);
    }
}
```
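Both the vector loop and the scalar tail follow the classic Abramowitz and Stegun formula 7.1.26, erf(x) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·e^(-x²) with t = 1/(1 + p·x), here with the constants rounded to FP16-friendly values. A standalone check of those rounded constants against the library erf:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    // Constants copied from the kernel above (rounded A&S 7.1.26 coefficients).
    const float p = 0.328f;
    const float a1 = 0.2505f, a2 = -0.2881f, a3 = 1.4102f, a4 = -1.423f, a5 = 1.0547f;
    for (float x = 0.25f; x <= 4.0f; x += 0.75f) {
        const float t = 1.0f / (1.0f + p * x);
        // Horner evaluation of a1*t + a2*t^2 + ... + a5*t^5.
        const float poly = t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))));
        const float approx = 1.0f - poly * std::exp(-x * x);
        std::printf("x=%.2f  approx=%.5f  erf=%.5f\n", x, approx, std::erf(x));
    }
    return 0;
}
```

Beyond |x| = 4 the kernel simply returns the sign, which is consistent with erf saturating to ±1 well within FP16 precision there.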
New file: erf_neon_fp16.h (+24 lines):
```cpp
/*++

Copyright 2025 FUJITSU LIMITED

Module Name:

    erf_neon_fp16.h

Abstract:

    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#pragma once

#include <arm_neon.h>

#include "mlasi.h"
#include "fp16_common.h"
#include "softmax_kernel_neon.h"

using _mlas_fp16_ = uint16_t;
void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);
```
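For context on the PR title, a hedged sketch of how a Gelu FP16 path could sit on top of this kernel, using the exact-erf form Gelu(x) = 0.5·x·(1 + erf(x/√2)); GeluFp16Sketch and its conversion steps are illustrative only, not code from this PR:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

void MlasNeonErfKernelFp16(const uint16_t* Input, uint16_t* Output, size_t N);

// Hypothetical wrapper: Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
// Assumes an AArch64 toolchain where __fp16 is available.
void GeluFp16Sketch(const uint16_t* x, uint16_t* y, size_t n)
{
    const float inv_sqrt2 = 0.70710678f;
    std::vector<uint16_t> scaled(n), erf_out(n);
    for (size_t i = 0; i < n; ++i) {
        __fp16 xi;
        std::memcpy(&xi, &x[i], sizeof(xi));
        const __fp16 s = (__fp16)((float)xi * inv_sqrt2);   // x / sqrt(2)
        std::memcpy(&scaled[i], &s, sizeof(s));
    }
    MlasNeonErfKernelFp16(scaled.data(), erf_out.data(), n); // erf on the whole buffer
    for (size_t i = 0; i < n; ++i) {
        __fp16 xi, ei;
        std::memcpy(&xi, &x[i], sizeof(xi));
        std::memcpy(&ei, &erf_out[i], sizeof(ei));
        const __fp16 g = (__fp16)(0.5f * (float)xi * (1.0f + (float)ei));
        std::memcpy(&y[i], &g, sizeof(g));
    }
}
```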
Review comment: Nit: Rename to elementwise_sve_fp16.cpp? (casing is inconsistent)