Skip to content

Commit bfc54ed

Browse files
authored
4-bit support for convolution (#85)
- Adds arm_convolve_wrapper function for 4-bit weights
- Adds arm_convolve_1x1 normal and fast variants for 4-bit weights
- Adds arm_nn_mat_mult_kernel_s4_s16 to multiply 4-bit weights with 16-bit input
- Adds mat_mult_nt_t_s4 function for DSP and scalar
- Adds scalar and DSP implementations of arm_convolve with 4-bit weights
- Adds unit tests for 4-bit weight convolutions

Change-Id: Idea55432fdab2db05a033889d7c39dd0ea69f8ad Signed-off-by: Ryan O'Shea <ryan.oshea3@arm.com>
1 parent edececa commit bfc54ed

231 files changed

Lines changed: 10857 additions & 54 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

ARM.CMSIS-NN.pdsc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@
3535
<file category="header" name="Include/arm_nn_math_types.h"/>
3636

3737
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_1_x_n_s8.c"/>
38+
<file category="source" name="Source/ConvolutionFunctions/arm_nn_mat_mult_kernel_s4_s16.c"/>
3839
<file category="source" name="Source/ConvolutionFunctions/arm_nn_mat_mult_kernel_s8_s16.c"/>
3940
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c"/>
4041
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_1x1_s8_fast.c"/>
42+
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_1x1_s4_fast.c"/>
43+
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_1x1_s4.c"/>
4144
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_1x1_s8.c"/>
4245
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_s8.c"/>
4346
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_s16.c"/>
@@ -50,14 +53,17 @@
5053
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_s4_opt.c"/>
5154
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_s4.c"/>
5255
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_fast_s16.c"/>
56+
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_s4.c"/>
5357
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_s8.c"/>
5458
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_s16.c"/>
5559
<file category="source" name="Source/ConvolutionFunctions/arm_nn_mat_mult_s8.c"/>
5660
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_3x3_s8.c"/>
5761
<file category="source" name="Source/ConvolutionFunctions/arm_depthwise_conv_s8_opt.c"/>
62+
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_wrapper_s4.c"/>
5863
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c"/>
5964
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_wrapper_s16.c"/>
6065
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_get_buffer_sizes_s16.c"/>
66+
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_get_buffer_sizes_s4.c"/>
6167
<file category="source" name="Source/ConvolutionFunctions/arm_convolve_get_buffer_sizes_s8.c"/>
6268
<file category="source" name="Source/ConvolutionFunctions/arm_nn_depthwise_conv_s8_core.c"/>
6369
<file category="source" name="Source/ConvolutionFunctions/arm_transpose_conv_s8.c"/>
@@ -90,6 +96,7 @@
9096
<file category="source" name="Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_svdf_s8.c"/>
9197
<file category="source" name="Source/NNSupportFunctions/arm_q7_to_q15_with_offset.c"/>
9298
<file category="source" name="Source/NNSupportFunctions/arm_s8_to_s16_unordered_with_offset.c"/>
99+
<file category="source" name="Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s4.c"/>
93100
<file category="source" name="Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8.c"/>
94101
<file category="source" name="Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8_s32.c"/>
95102
<file category="source" name="Source/NNSupportFunctions/arm_nn_depthwise_conv_nt_t_s16.c"/>

Include/arm_nnfunctions.h

Lines changed: 229 additions & 2 deletions
Large diffs are not rendered by default.

Include/arm_nnsupportfunctions.h

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
* Title: arm_nnsupportfunctions.h
2222
* Description: Public header file of support functions for CMSIS NN Library
2323
*
24-
* $Date: 7 November 2023
25-
* $Revision: V.17.5.0
24+
* $Date: 13 November 2023
25+
* $Revision: V.17.6.0
2626
*
2727
* Target : Arm(R) M-Profile Architecture
2828
* -------------------------------------------------------------------- */
@@ -340,6 +340,53 @@ int8_t *arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
340340
const int32_t *bias,
341341
int8_t *output);
342342

343+
/**
344+
* @brief General Matrix-multiplication function with per-channel requantization.
345+
* This function assumes:
346+
* - LHS input matrix NOT transposed (nt)
347+
* - RHS input matrix transposed (t)
348+
* - RHS is int8 packed with 2x int4
349+
* - LHS is int8
350+
*
351+
* @note This operation also performs the broadcast bias addition before the requantization
352+
*
353+
* @param[in] lhs Pointer to the LHS input matrix
354+
* @param[in] rhs Pointer to the RHS input matrix
355+
* @param[in] bias Pointer to the bias vector. The length of this vector is equal to the number of
356+
* output columns (or RHS input rows)
357+
* @param[out] dst Pointer to the output matrix with "m" rows and "n" columns
358+
* @param[in] dst_multipliers Pointer to the multipliers vector needed for the per-channel requantization.
359+
* The length of this vector is equal to the number of output columns (or RHS input
360+
* rows)
361+
* @param[in] dst_shifts Pointer to the shifts vector needed for the per-channel requantization. The length
362+
* of this vector is equal to the number of output columns (or RHS input rows)
363+
* @param[in] lhs_rows Number of LHS input rows
364+
* @param[in] rhs_rows Number of RHS input rows
365+
* @param[in] rhs_cols Number of LHS/RHS input columns
366+
* @param[in] lhs_offset Offset to be applied to the LHS input value
367+
* @param[in] dst_offset Offset to be applied to the output result
368+
* @param[in] activation_min Minimum value to clamp down the output. Range : int8
369+
* @param[in] activation_max Maximum value to clamp up the output. Range : int8
370+
* @param[in] lhs_cols_offset Column offset between subsequent lhs_rows
371+
*
372+
* @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
373+
*
374+
*/
375+
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
376+
const int8_t *rhs,
377+
const int32_t *bias,
378+
int8_t *dst,
379+
const int32_t *dst_multipliers,
380+
const int32_t *dst_shifts,
381+
const int32_t lhs_rows,
382+
const int32_t rhs_rows,
383+
const int32_t rhs_cols,
384+
const int32_t lhs_offset,
385+
const int32_t dst_offset,
386+
const int32_t activation_min,
387+
const int32_t activation_max,
388+
const int32_t lhs_cols_offset);
389+
343390
/**
344391
* @brief General Matrix-multiplication function with per-channel requantization.
345392
* This function assumes:
@@ -822,6 +869,24 @@ __STATIC_FORCEINLINE void read_and_pad_s4_uneven(const int8_t *source, int32_t *
822869
*out2 = SXTB16_RORn(__sxtb16(inA1), 4);
823870
}
824871

872+
/**
873+
* @brief read and expand one s4 word into two s16 words with ordering.
874+
*/
875+
__STATIC_FORCEINLINE void read_and_pad_s4_ordered(const int8_t *source, int32_t *out1, int32_t *out2)
876+
{
877+
int16_t in = arm_nn_read_s8x2(source);
878+
int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
879+
int32_t inAbuf1 = SXTB16_RORn(__sxtb16(inA), 4);
880+
int32_t inAbuf2 = SXTB16_RORn(__sxtb16(inA << 4), 4);
881+
#ifndef ARM_MATH_BIG_ENDIAN
882+
*out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
883+
*out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
884+
#else
885+
*out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
886+
*out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
887+
#endif
888+
}
889+
825890
/**
826891
* @brief read and expand one s8 word into two s16 words with ordering.
827892
*/
@@ -861,6 +926,39 @@ __STATIC_FORCEINLINE const int8_t *read_and_pad_reordered(const int8_t *source,
861926

862927
#endif
863928

929+
/**
930+
* @brief Matrix-multiplication function for convolution with per-channel requantization and 4 bit weights.
931+
* @param[in] input_a pointer to operand A, int8 packed with 2x int4.
932+
* @param[in] input_b pointer to operand B, always consists of 2 vectors.
933+
* @param[in] output_ch number of rows of A
934+
* @param[in] out_shift pointer to per output channel requantization shift parameter.
935+
* @param[in] out_mult pointer to per output channel requantization multiplier parameter.
936+
* @param[in] out_offset output tensor offset.
937+
* @param[in] activation_min minimum value to clamp the output to. Range : int8
938+
* @param[in] activation_max maximum value to clamp the output to. Range : int8
939+
* @param[in] num_col_a number of columns of A
940+
* @param[in] output_bias per output channel bias. Range : int32
941+
* @param[in,out] out_0 pointer to output
942+
* @return The function returns one of the two
943+
* 1. The incremented output pointer for a successful operation or
944+
* 2. NULL if implementation is not available.
945+
*
946+
* @details This function does the matrix multiplication of weight matrix for all output channels
947+
* with 2 columns from im2col and produces two elements/output_channel. The outputs are
948+
* clamped in the range provided by activation min and max.
949+
* Supported framework: TensorFlow Lite micro.
950+
*/
951+
int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *input_a,
952+
const int16_t *input_b,
953+
const uint16_t output_ch,
954+
const int32_t *out_shift,
955+
const int32_t *out_mult,
956+
const int32_t out_offset,
957+
const int32_t activation_min,
958+
const int32_t activation_max,
959+
const int32_t num_col_a,
960+
const int32_t *const output_bias,
961+
int8_t *out_0);
864962
/**
865963
* @brief Matrix-multiplication function for convolution with per-channel requantization.
866964
* @param[in] input_a pointer to operand A

README.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@ processors here are Cortex-M4 or a Cortex-M33 configured with optional DSP exten
2323
Processors with Arm Helium Technology use the Arm M-profile Vector Extension(MVE) instructions for optimization.
2424
Examples are Cortex-M55 or Cortex-M85 configured with MVE.
2525

26-
| Operator | C <br> int8 | C<br>int16 | C<br>int4* | DSP<br>int8 | DSP<br>int16 | DSP<br>int4* | MVE<br>int8 | MVE<br>int16 |
27-
| --------------- | ----------- | ---------- | ----------- | ------------| -------------| -------------| ------------| -------------|
28-
| Conv2D | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
29-
| DepthwiseConv2D | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
30-
| TransposeConv2D | Yes | No | No | No | No | No | No | No |
31-
| Fully Connected | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
32-
| Add | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
33-
| Mul | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
34-
| MaxPooling | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
35-
| AvgPooling | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
36-
| Softmax | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No |
37-
| LSTM | Yes | NA | No | Yes | NA | No | Yes | NA |
38-
| SVDF | Yes | No | No | Yes | No | No | Yes | No |
26+
| Operator | C <br> int8 | C<br>int16 | C<br>int4* | DSP<br>int8 | DSP<br>int16 | DSP<br>int4* | MVE<br>int8 | MVE<br>int16 |
27+
| --------------- | ----------- | ---------- |------------| ------------| -------------|--------------| ------------| -------------|
28+
| Conv2D | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
29+
| DepthwiseConv2D | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
30+
| TransposeConv2D | Yes | No | No | No | No | No | No | No |
31+
| Fully Connected | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
32+
| Add | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
33+
| Mul | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
34+
| MaxPooling | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
35+
| AvgPooling | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes |
36+
| Softmax | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No |
37+
| LSTM | Yes | NA | No | Yes | NA | No | Yes | NA |
38+
| SVDF | Yes | No | No | Yes | No | No | Yes | No |
3939

4040
* int4 weights + int8 activations
4141

Source/ConvolutionFunctions/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
#
1818

1919
file(GLOB SRC_S4 "./*_s4*.c")
20-
file(GLOB SRC "./*_s8*.c")
20+
file(GLOB SRC_S8 "./*_s8*.c")
2121
file(GLOB SRC_S16 "./*_s16*.c")
22-
target_sources(cmsis-nn PRIVATE ${SRC} ${SRC_S16} ${SRC_S4})
22+
target_sources(cmsis-nn PRIVATE ${SRC_S4} ${SRC_S8} ${SRC_S16})
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3+
*
4+
* SPDX-License-Identifier: Apache-2.0
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the License); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an AS IS BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
/* ----------------------------------------------------------------------
20+
* Project: CMSIS NN Library
21+
* Title: arm_convolve_1x1_s4.c
22+
* Description: Generic s4 version of 1x1 convolution
23+
*
24+
* $Date: 01 November 2023
25+
* $Revision: V.1.0.0
26+
*
27+
* Target : Arm(R) M-Profile Architecture
28+
*
29+
* -------------------------------------------------------------------- */
30+
31+
#include "arm_nnfunctions.h"
32+
#include "arm_nnsupportfunctions.h"
33+
34+
/**
35+
* @ingroup Public
36+
*/
37+
38+
/**
39+
* @addtogroup NNConv
40+
* @{
41+
*/
42+
43+
/*
44+
* A more generic version of s4 1x1 convolution intended for non-unity strides. This is slower
45+
* than the _fast() version if used for unity stride values.
46+
*
47+
* Refer header file for details.
48+
*
49+
*/
50+
arm_cmsis_nn_status arm_convolve_1x1_s4(const cmsis_nn_context *ctx,
51+
const cmsis_nn_conv_params *conv_params,
52+
const cmsis_nn_per_channel_quant_params *quant_params,
53+
const cmsis_nn_dims *input_dims,
54+
const int8_t *input_data,
55+
const cmsis_nn_dims *filter_dims,
56+
const int8_t *filter_data,
57+
const cmsis_nn_dims *bias_dims,
58+
const int32_t *bias_data,
59+
const cmsis_nn_dims *output_dims,
60+
int8_t *output_data)
61+
{
62+
(void)ctx;
63+
(void)filter_dims;
64+
(void)bias_dims;
65+
if (conv_params->padding.w != 0 || conv_params->padding.h != 0)
66+
{
67+
return ARM_CMSIS_NN_ARG_ERROR;
68+
}
69+
70+
const int32_t lhs_rows = output_dims->w;
71+
const int32_t rhs_rows = output_dims->c;
72+
const int32_t rhs_cols = input_dims->c;
73+
const int32_t stride_w = conv_params->stride.w;
74+
const int32_t input_inc = input_dims->w * conv_params->stride.h * rhs_cols;
75+
const int32_t output_inc = output_dims->w * rhs_rows;
76+
const int32_t output_h = output_dims->h;
77+
const int32_t batch = input_dims->n;
78+
const int8_t *input_data_ref = input_data;
79+
80+
for (int i_batch = 0; i_batch < batch; i_batch++)
81+
{
82+
input_data = input_data_ref + (i_batch * rhs_cols * input_dims->w * input_dims->h);
83+
for (int i_output_h = 0; i_output_h < output_h; i_output_h++)
84+
{
85+
// Process one input row
86+
arm_cmsis_nn_status result = arm_nn_mat_mult_nt_t_s4(input_data,
87+
filter_data,
88+
bias_data,
89+
output_data,
90+
quant_params->multiplier,
91+
quant_params->shift,
92+
lhs_rows,
93+
rhs_rows,
94+
rhs_cols,
95+
conv_params->input_offset,
96+
conv_params->output_offset,
97+
conv_params->activation.min,
98+
conv_params->activation.max,
99+
rhs_cols * stride_w);
100+
if (result != ARM_CMSIS_NN_SUCCESS)
101+
{
102+
return result;
103+
}
104+
input_data += input_inc;
105+
output_data += output_inc;
106+
}
107+
}
108+
109+
/* Return to application */
110+
return ARM_CMSIS_NN_SUCCESS;
111+
}
112+
113+
/**
114+
* @} end of NNConv group
115+
*/

0 commit comments

Comments
 (0)