Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
${MLAS_SRC_DIR}/erf_neon_fp16.h
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -460,13 +462,17 @@ else()
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
${MLAS_SRC_DIR}/erf_neon_fp16.h
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
)

# Conditionally add the SVE implementation if compiler supports it
if (onnxruntime_USE_SVE)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp)
set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sve/Elementwise_sve_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Rename to elementwise_sve_fp16.cpp ? (casing is inconsistent)

list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
endif()

Expand Down Expand Up @@ -502,6 +508,7 @@ else()
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/erf_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
Expand All @@ -517,6 +524,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
9 changes: 8 additions & 1 deletion cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,14 @@ if (onnxruntime_ENABLE_CPU_FP16_OPS)
set_source_files_properties(${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/adasum_kernels.cc PROPERTIES COMPILE_FLAGS " -fassociative-math -ffast-math -ftree-vectorize -funsafe-math-optimizations -mf16c -mavx -mfma ")
endif()

target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT})
if(onnxruntime_target_platform STREQUAL "aarch64" OR onnxruntime_target_platform STREQUAL "ARM64" OR onnxruntime_target_platform STREQUAL "arm64")
set_source_files_properties("${ONNXRUNTIME_ROOT}/core/providers/cpu/tensor/gelu.cc" PROPERTIES COMPILE_FLAGS -march=armv8.2-a+fp16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be a duplicated comment - I think this is coming about because some CPU EP files are now directly using intrinsics and I feel the hardware accelerated intrinsic using routines should live in MLAS and only be called from the CPU EP files

endif()
target_include_directories(onnxruntime_providers PRIVATE
${ONNXRUNTIME_ROOT}
${ONNXRUNTIME_ROOT}/core/mlas/inc
)

onnxruntime_add_include_to_target(onnxruntime_providers re2::re2 Eigen3::Eigen)
add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})

Expand Down
147 changes: 147 additions & 0 deletions onnxruntime/core/mlas/lib/erf_neon_fp16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*++

Copyright 2025 FUJITSU LIMITED

Module Name:

erf_neon_fp16.cpp

Abstract:

This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#include "erf_neon_fp16.h"

#include <cmath>
#include <cstring>

// Helpers to safely convert between float and FP16-bit representation
static float
fp16_to_float(uint16_t h)
{
__fp16 tmp;
memcpy(&tmp, &h, sizeof(h));
return (float)tmp;
}

static uint16_t
float_to_fp16(float f)
{
__fp16 tmp = (__fp16)f;
uint16_t h;
memcpy(&h, &tmp, sizeof(h));
return h;
}

// Vectorized rational approximation of exp(-x) on fp16 lanes:
//
//     exp(-x) ~= (c0 + c1*x + c2*x^2) / (d0 + d1*x + d2*x^2)
//
// The argument is clamped to 6.0 first, so inputs above 6 (e.g. x^2 values up
// to 16 from the erf kernel) all evaluate at x == 6.
// The division is done as multiply-by-reciprocal: a vrecpe estimate refined
// by two Newton-Raphson steps.  NOTE(review): MlasReciprocalSqrtFloat16 is,
// despite its name, the reciprocal *step* intrinsic (vrecpsq_f16), so the
// refinement here is correct for 1/den — the helper name is misleading.
static inline MLAS_FLOAT16X8
exp_neg_rational_approx_f16(MLAS_FLOAT16X8 x)
{
// Clamp: the rational fit is only intended for x in [0, 6].
const float16_t a0 = 6.0f;
MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8(a0);
x = MlasMinimumFloat16(x, max_x);

// Numerator coefficients.
const float16_t c0 = 1.330f;
const float16_t c1 = -0.390f;
const float16_t c2 = 0.0288f;

// Denominator coefficients.
const float16_t d0 = 1.338f;
const float16_t d1 = 0.848f;
const float16_t d2 = 0.467f;

MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8(c0);
MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8(c1);
MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8(c2);

MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8(d0);
MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8(d1);
MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8(d2);
MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
// num = c0 + c1*x + c2*x^2 (fused multiply-adds).
MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16(c1v, x, c0v);
num = MlasMultiplyAddFloat16(c2v, x2, num);
// den = d0 + d1*x + d2*x^2.
MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16(d1v, x, d0v);
den = MlasMultiplyAddFloat16(d2v, x2, den);
// recip ~= 1/den: vrecpe estimate + two vrecps Newton-Raphson refinements.
MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16(den);
recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
recip = MlasMultiplyFloat16(recip, MlasReciprocalSqrtFloat16(den, recip));
MLAS_FLOAT16X8 result = MlasMultiplyFloat16(num, recip);
return result;
}

void
MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
/*++

Routine Description:

    Computes Output[i] = erf(Input[i]) for N half-precision elements using an
    Abramowitz & Stegun 7.1.26-style rational approximation:

        erf(x) ~= sign(x) * (1 - poly(t) * exp(-x^2)),  t = 1 / (1 + p*|x|)

    The main loop processes 8 fp16 lanes per iteration with NEON intrinsics;
    the remaining (N % 8) elements go through a scalar fp32 tail.  For
    |x| >= 4 the result saturates to sign(x).

Arguments:

    Input - Supplies the buffer of N fp16 values (raw uint16_t bit pattern).

    Output - Supplies the buffer that receives the N fp16 results.

    N - Supplies the number of elements.

--*/
{
    // Coefficients are stored at fp16 precision so the scalar tail rounds
    // the same way the vector path does.
    const float16_t p = 0.328f;
    const float16_t a1 = 0.2505f;
    const float16_t a2 = -0.2881f;
    const float16_t a3 = 1.4102f;
    const float16_t a4 = -1.423f;
    const float16_t a5 = 1.0547f;

    MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8(p);
    MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8(a1);
    MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8(a2);
    MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8(a3);
    MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8(a4);
    MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8(a5);

    constexpr float16_t one_fp16 = 1.0f;
    constexpr float16_t neg_one_fp16 = -1.0f;
    constexpr float16_t zero_fp16 = 0.0f;
    constexpr float16_t four_fp16 = 4.0f;  // saturation threshold

    MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8(one_fp16);
    MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8(neg_one_fp16);
    MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8(zero_fp16);
    MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8(four_fp16);

    size_t i = 0;
    for (; i + 8 <= N; i += 8) {
        MLAS_FLOAT16X8 x = MlasLoadFloat16x8(&Input[i]);
        // sign(x): -1 for strictly negative lanes, +1 otherwise.
        MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16(x, vzero);
        MLAS_FLOAT16X8 sign = MlasSelectFloat16(neg_mask, vneg_one, vone);
        MLAS_FLOAT16X8 absx = MlasAbsFloat16(x);
        // Lanes with |x| < 4 (strict) use the approximation; the rest
        // saturate to sign(x) via the final select.
        MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16(absx, vth);
        MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16(absx, vth);
        // t = 1 / (1 + p*|x|): vrecpe estimate refined by two vrecps
        // Newton-Raphson steps (MlasReciprocalSqrtFloat16 wraps vrecpsq_f16,
        // the reciprocal step, despite its name).
        MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16(vp, absx_clamped, vone);
        MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16(denom);
        t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
        t = MlasMultiplyFloat16(t, MlasReciprocalSqrtFloat16(denom, t));
        // poly = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5.
        MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16(t, t);
        MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16(t2, t);
        MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16(t3, t);
        MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16(t4, t);
        MLAS_FLOAT16X8 poly = MlasMultiplyFloat16(va1, t);
        poly = MlasMultiplyAddFloat16(va2, t2, poly);
        poly = MlasMultiplyAddFloat16(va3, t3, poly);
        poly = MlasMultiplyAddFloat16(va4, t4, poly);
        poly = MlasMultiplyAddFloat16(va5, t5, poly);
        // erf ~= sign * (1 - poly * exp(-x^2)), clamped to [-1, 1].
        MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(absx_clamped, absx_clamped);
        MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16(x2);
        MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16(poly, exp_neg_x2);
        MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16(vone, poly_mul_exp);
        MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16(sign, one_minus_term);
        erf_approx = MlasMinimumFloat16(erf_approx, vone);
        erf_approx = MlasMaximumFloat16(erf_approx, vneg_one);
        MLAS_FLOAT16X8 result = MlasSelectFloat16(use_mask, erf_approx, sign);
        MlasStoreFloat16x8(&Output[i], result);
    }

    // Scalar tail for the remaining (N % 8) elements, computed in fp32 and
    // rounded back to fp16 on store.
    for (; i < N; i++) {
        float x = fp16_to_float(Input[i]);
        float sign = (x < 0) ? -1.0f : 1.0f;
        float absx = fabsf(x);

        // Saturate when |x| >= 4 so the tail agrees with the vector path,
        // whose use_mask is a strict |x| < 4 comparison.  (The original used
        // '>', which diverged from the vector path at exactly |x| == 4.)
        if (absx >= 4.0f) {
            Output[i] = float_to_fp16(sign);
            continue;
        }

        float t = 1.0f / (1.0f + p * absx);
        float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
        float exp_neg_x2 = expf(-absx * absx);
        float erf_approx = sign * (1.0f - poly * exp_neg_x2);
        if (erf_approx > 1.0f) erf_approx = 1.0f;
        if (erf_approx < -1.0f) erf_approx = -1.0f;

        Output[i] = float_to_fp16(erf_approx);
    }
}
24 changes: 24 additions & 0 deletions onnxruntime/core/mlas/lib/erf_neon_fp16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*++

Copyright 2025 FUJITSU LIMITED

Module Name:

    erf_neon_fp16.h

Abstract:

    This module contains the procedure prototypes for the ERF NEON FP16 intrinsics.

--*/

#pragma once

#include <arm_neon.h>

#include "mlasi.h"
#include "fp16_common.h"
#include "softmax_kernel_neon.h"

// Raw 16-bit storage type for an fp16 value.  NOTE(review): mlasi.h appears
// to use the same _mlas_fp16_ name elsewhere in MLAS — confirm this alias
// does not conflict with an existing declaration.
using _mlas_fp16_ = uint16_t;
// Computes Output[i] = erf(Input[i]) for i in [0, N) on fp16 data using an
// 8-lane NEON main loop with a scalar tail.
void MlasNeonErfKernelFp16(const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N);
50 changes: 50 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); }

// Broadcasts a native float16_t value to all 8 lanes.  Companion to
// MlasBroadcastFloat16x8 above, which takes the raw uint16_t bit pattern.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasBroadcastF16Float16x8(float16_t Value) { return vdupq_n_f16(Value); }

MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); }
Expand All @@ -78,6 +82,10 @@ MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); }

// Loads 8 fp16 lanes from a native float16_t buffer (no reinterpret needed,
// unlike MlasLoadFloat16x8 which loads from the raw uint16_t representation).
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasLoadf16Float16x8(const float16_t* Buffer) { return vld1q_f16(Buffer); }

MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); }
Expand Down Expand Up @@ -115,6 +123,13 @@ MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector)
vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector));
}

// Stores 8 fp16 lanes to a native float16_t buffer (counterpart of
// MlasLoadf16Float16x8).
MLAS_FORCEINLINE
void
MlasStoref16Float16x8(float16_t* Buffer, MLAS_FLOAT16X8 Vector)
{
vst1q_f16(Buffer, Vector);
}

MLAS_FORCEINLINE
void
MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
Expand Down Expand Up @@ -579,4 +594,39 @@ MlasShiftLeftInt16(MLAS_INT16X4 Vector)
return vshl_n_s16(Vector, ShiftCount);
}

// Newton-Raphson reciprocal refinement step: returns (2 - Vector1 * Vector2)
// per lane via vrecpsq_f16 (FRECPS).  NOTE(review): despite the name, this is
// the reciprocal *step*, not a reciprocal square root (that would be
// vrsqrtsq_f16) — consider renaming to MlasReciprocalStepFloat16.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasReciprocalSqrtFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vrecpsq_f16(Vector1, Vector2);
}

// Low-precision per-lane reciprocal estimate (vrecpeq_f16 / FRECPE).  Refine
// with MlasReciprocalSqrtFloat16 (vrecps) steps for usable fp16 accuracy.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasApproximateReciprocalFloat16(MLAS_FLOAT16X8 Vector)
{
return vrecpeq_f16(Vector);
}

// Per-lane strict less-than compare: each result lane is all-ones when
// Vector1 < Vector2, otherwise zero (suitable as a select mask).
MLAS_FORCEINLINE
MLAS_UINT16X8
MlasCompareLessThanFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vcltq_f16(Vector1, Vector2);
}

// Per-lane absolute value.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasAbsFloat16(MLAS_FLOAT16X8 Vector)
{
return vabsq_f16(Vector);
}

// Bitwise per-lane select (vbslq_f16): where a mask bit in Vector is set the
// result takes the bit from Vector1, otherwise from Vector2.
MLAS_FORCEINLINE
MLAS_FLOAT16X8
MlasSelectFloat16(MLAS_UINT16X8 Vector, MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
{
return vbslq_f16(Vector, Vector1, Vector2);
}

#endif // fp16 vector intrinsic supported
Loading
Loading