Skip to content

Commit fa05c1d

Browse files
committed
Make size non-constexpr
1 parent b2e5c66 commit fa05c1d

23 files changed

+289
-215
lines changed

aten/src/ATen/cpu/vec/functional_base.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ inline scalar_t vec_reduce_all(
1818
scalar_t acc_arr[Vec::size()];
1919
acc_vec.store(acc_arr);
2020
for (const auto i : c10::irange(1, size)) {
21-
std::array<scalar_t, Vec::size()> acc_arr_next = {0};
21+
scalar_t acc_arr_next[Vec::size()] = {0};
2222
acc_arr_next[0] = acc_arr[i];
23-
Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
23+
Vec acc_vec_next = Vec::loadu(acc_arr_next);
2424
acc_vec = vec_fun(acc_vec, acc_vec_next);
2525
}
2626
acc_vec.store(acc_arr);

aten/src/ATen/cpu/vec/sve/vec_double.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ template <> class Vectorized<double> {
4747
operator svfloat64_t() const {
4848
return values;
4949
}
50-
template <uint64_t mask>
51-
static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
50+
static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b, const uint64_t mask) {
5251
// Build an array of flags: each element is 1 if the corresponding bit in 'mask' is set, 0 otherwise.
5352
__at_align__ int64_t flag_arr[size()];
5453
for (int i = 0; i < size(); i++) {

aten/src/ATen/cpu/vec/sve/vec_float_2.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ template <> class Vectorized<float> {
3737

3838
using value_type = float;
3939
using size_type = int;
40-
static inline constexpr size_type size() {
41-
return SVE_FLOAT_VEC_SIZE;
40+
static inline size_type size() {
41+
return svcntw();
4242
}
4343
inline Vectorized() {}
4444
inline Vectorized(const float val) {
@@ -84,8 +84,7 @@ template <> class Vectorized<float> {
8484
return result;
8585
}
8686

87-
template <uint64_t mask>
88-
static inline Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
87+
static inline Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b, const uint64_t mask) {
8988
// Build an array of flags: each element is 1 if the corresponding bit in 'mask' is set, 0 otherwise.
9089
__at_align__ int32_t flag_arr[size()];
9190
for (int i = 0; i < size(); i++) {

aten/src/ATen/cpu/vec/sve/vec_int.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ public:
4242
operator svint##bit##_t() const { \
4343
return values; \
4444
} \
45-
template <uint64_t mask> \
46-
static Vectorized<int##bit##_t> blend(const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
45+
static Vectorized<int##bit##_t> blend(const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b, const uint64_t mask) { \
4746
__at_align__ int##bit##_t flag_arr[size()]; \
4847
for (int i = 0; i < size(); ++i) { \
4948
flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 1 : 0; \

aten/src/ATen/cpu/vec/sve/vec_qint.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
193193
int32_t zero_point,
194194
float inverse_scale) {
195195
std::array<value_type, size()> qvals;
196-
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
196+
float float_vals[float_num_vecs() * Vectorized<float>::size()];
197197

198198
for (int i = 0; i < float_num_vecs(); ++i) {
199199
rhs[i].store(&float_vals[i * Vectorized<float>::size()], Vectorized<float>::size());
@@ -202,7 +202,7 @@ struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
202202
at::native::quantize_vec<c10::qint32, /*precision=*/32>(
203203
scale,
204204
zero_point,
205-
float_vals.data(),
205+
float_vals,
206206
(c10::qint32*)qvals.data(),
207207
Vectorized<float>::size() * float_num_vecs());
208208

@@ -337,7 +337,7 @@ struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
337337
int32_t zero_point,
338338
float inverse_scale) {
339339
std::array<value_type, size()> qvals;
340-
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
340+
float float_vals[float_num_vecs() * Vectorized<float>::size()];
341341

342342
for (int i = 0; i < float_num_vecs(); ++i) {
343343
rhs[i].store(&float_vals[i * Vectorized<float>::size()], Vectorized<float>::size());
@@ -346,7 +346,7 @@ struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
346346
at::native::quantize_vec<c10::qint8>(
347347
scale,
348348
zero_point,
349-
float_vals.data(),
349+
float_vals,
350350
(c10::qint8*)qvals.data(),
351351
Vectorized<float>::size() * float_num_vecs());
352352

@@ -476,7 +476,7 @@ struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
476476
int32_t zero_point,
477477
float inverse_scale) {
478478
std::array<value_type, size()> qvals;
479-
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
479+
float float_vals[float_num_vecs() * Vectorized<float>::size()];
480480

481481
for (int i = 0; i < float_num_vecs(); ++i) {
482482
rhs[i].store(&float_vals[i * Vectorized<float>::size()], Vectorized<float>::size());
@@ -485,7 +485,7 @@ struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
485485
at::native::quantize_vec<c10::quint8>(
486486
scale,
487487
zero_point,
488-
float_vals.data(),
488+
float_vals,
489489
(c10::quint8*)qvals.data(),
490490
Vectorized<float>::size() * float_num_vecs());
491491

aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,12 @@ inline namespace CPU_CAPABILITY {
4141
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
4242
#endif
4343

44-
template<int index, bool mask_val>
45-
struct BlendRegs {
46-
static float32x4_t impl(
47-
const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
48-
};
49-
50-
template<int index>
51-
struct BlendRegs<index, true>{
52-
static float32x4_t impl(
53-
const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
54-
return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
55-
}
56-
};
57-
5844
template<int index>
59-
struct BlendRegs<index, false>{
45+
struct BlendRegs{
6046
static float32x4_t impl(
61-
const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
47+
const float32x4_t& a, const float32x4_t& b, float32x4_t& res, bool mask_val) {
48+
if (mask_val)
49+
return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
6250
return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
6351
}
6452
};
@@ -81,21 +69,20 @@ template <> class Vectorized<float> {
8169
operator float32x4_t() const {
8270
return values;
8371
}
84-
template <int64_t mask>
85-
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
72+
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b, const int64_t mask) {
8673
Vectorized<float> vec;
8774
vec.values =
88-
BlendRegs<0, (mask & 0x01)!=0>::impl(
89-
a.values, b.values, vec.values);
75+
BlendRegs<0>::impl(
76+
a.values, b.values, vec.values, (mask & 0x01)!=0);
9077
vec.values =
91-
BlendRegs<1, (mask & 0x02)!=0>::impl(
92-
a.values, b.values, vec.values);
78+
BlendRegs<1>::impl(
79+
a.values, b.values, vec.values, (mask & 0x02)!=0);
9380
vec.values =
94-
BlendRegs<2, (mask & 0x04)!=0>::impl(
95-
a.values, b.values, vec.values);
81+
BlendRegs<2>::impl(
82+
a.values, b.values, vec.values, (mask & 0x04)!=0);
9683
vec.values =
97-
BlendRegs<3, (mask & 0x08)!=0>::impl(
98-
a.values, b.values, vec.values);
84+
BlendRegs<3>::impl(
85+
a.values, b.values, vec.values, (mask & 0x08)!=0);
9986
return vec;
10087
}
10188
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,

aten/src/ATen/cpu/vec/vec_base.h

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ struct Vectorized {
182182
auto as_bytes() const -> const char* {
183183
return reinterpret_cast<const char*>(values);
184184
}
185-
template <int64_t mask_>
186-
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
185+
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b, const int64_t mask_) {
187186
int64_t mask = mask_;
188187
Vectorized vector;
189188
for (const auto i : c10::irange(size())) {
@@ -1013,7 +1012,7 @@ template <int64_t scale = 1, typename T = void>
10131012
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
10141013
inline mask_gather(const Vectorized<T>& src, T const* base_addr,
10151014
const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
1016-
static constexpr int size = Vectorized<T>::size();
1015+
static const int size = Vectorized<T>::size();
10171016
T src_arr[size];
10181017
int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
10191018
int_same_size_t<T> index_arr[size];
@@ -1097,7 +1096,7 @@ inline Vectorized<T> convert_to_fp_of_same_size(const Vectorized<IntType>& src)
10971096
// returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
10981097
// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
10991098
template <typename T>
1100-
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1099+
inline std::enable_if_t<true, std::pair<Vectorized<T>, Vectorized<T>>>
11011100
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
11021101
static constexpr int size = Vectorized<T>::size();
11031102
static constexpr int half_size = size / 2;
@@ -1116,6 +1115,26 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
11161115
return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
11171116
Vectorized<T>::loadu(static_cast<void*>(buffer2)));
11181117
}
1118+
// template <typename T>
1119+
// inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1120+
// deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1121+
// static constexpr int size = Vectorized<T>::size();
1122+
// static constexpr int half_size = size / 2;
1123+
// T a_arr[size];
1124+
// T b_arr[size];
1125+
// T buffer1[size];
1126+
// T buffer2[size];
1127+
// a.store(static_cast<void*>(a_arr));
1128+
// b.store(static_cast<void*>(b_arr));
1129+
// for (const auto i : c10::irange(half_size)) {
1130+
// buffer1[i] = a_arr[i * 2];
1131+
// buffer1[half_size + i] = b_arr[i * 2];
1132+
// buffer2[i] = a_arr[i * 2 + 1];
1133+
// buffer2[half_size + i] = b_arr[i * 2 + 1];
1134+
// }
1135+
// return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1136+
// Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1137+
// }
11191138

11201139
// inverse operation of deinterleave2
11211140
// Example inputs for AVX512:
@@ -1129,7 +1148,7 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
11291148
// returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
11301149
// Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
11311150
template <typename T>
1132-
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1151+
inline std::enable_if_t<true, std::pair<Vectorized<T>, Vectorized<T>>>
11331152
interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
11341153
static constexpr int size = Vectorized<T>::size();
11351154
static constexpr int half_size = size / 2;
@@ -1149,6 +1168,27 @@ interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
11491168
Vectorized<T>::loadu(static_cast<void*>(buffer2)));
11501169
}
11511170

1171+
// template <typename T>
1172+
// inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1173+
// interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1174+
// static constexpr int size = Vectorized<T>::size();
1175+
// static constexpr int half_size = size / 2;
1176+
// T a_arr[size];
1177+
// T b_arr[size];
1178+
// T buffer1[size];
1179+
// T buffer2[size];
1180+
// a.store(static_cast<void*>(a_arr));
1181+
// b.store(static_cast<void*>(b_arr));
1182+
// for (const auto i : c10::irange(half_size)) {
1183+
// buffer1[i * 2] = a_arr[i];
1184+
// buffer1[i * 2 + 1] = b_arr[i];
1185+
// buffer2[i * 2] = a_arr[half_size + i];
1186+
// buffer2[i * 2 + 1] = b_arr[half_size + i];
1187+
// }
1188+
// return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1189+
// Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1190+
// }
1191+
11521192
template <typename src_T, typename dst_T>
11531193
inline void convert(const src_T *src, dst_T *dst, int64_t n) {
11541194
#ifndef _MSC_VER
@@ -1163,7 +1203,7 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) {
11631203

11641204
template <typename T>
11651205
inline Vectorized<T> flip(const Vectorized<T> & data) {
1166-
static constexpr int size = Vectorized<T>::size();
1206+
static const int size = Vectorized<T>::size();
11671207
T output[size];
11681208
T buffer[size];
11691209
data.store(static_cast<void*>(buffer));

aten/src/ATen/cpu/vec/vec_convert.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ template <
1515
struct VecConvert {
1616
static inline VectorizedN<dst_t, dst_n> apply(
1717
const VectorizedN<src_t, src_n>& src) {
18-
constexpr int count = std::min(
18+
const int count = std::min(
1919
VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
2020
__at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
2121
src.store(src_buf);

aten/src/ATen/cpu/vec/vec_mask.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <ATen/cpu/vec/vec_base.h>
44
#include <ATen/cpu/vec/vec_n.h>
5+
6+
#include <cassert>
57
namespace at::vec {
68
inline namespace CPU_CAPABILITY {
79

@@ -38,9 +40,9 @@ struct VecMaskLoad {
3840
static inline VectorizedN<data_t, data_n> apply(
3941
const data_t* ptr,
4042
const VecMask<mask_t, mask_n>& vec_mask) {
41-
constexpr typename VecMask<mask_t, mask_n>::size_type size =
43+
const typename VecMask<mask_t, mask_n>::size_type size =
4244
VecMask<mask_t, mask_n>::size();
43-
static_assert(VectorizedN<data_t, data_n>::size() >= size);
45+
assert((VectorizedN<data_t, data_n>::size() >= size));
4446
__at_align__ data_t data[size];
4547
__at_align__ mask_t mask[size];
4648
auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
@@ -127,7 +129,7 @@ class VecMask {
127129
template <typename U, int L>
128130
static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
129131
__at_align__ U b_buf[size()];
130-
if constexpr (size() >= VectorizedN<U, L>::size()) {
132+
if (size() >= VectorizedN<U, L>::size()) {
131133
b_vec.store(b_buf);
132134
for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
133135
b_buf[i] = static_cast<U>(0);
@@ -230,16 +232,18 @@ class VecMask {
230232
template <
231233
typename U,
232234
int L,
233-
std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
235+
std::enable_if_t<L >= 2, int> = 0>
234236
VectorizedN<U, L> loadu(const U* ptr) const {
237+
assert((VectorizedN<U, L>::size() >= size()));
235238
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
236239
}
237240

238241
template <
239242
typename U,
240243
int L,
241-
std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
244+
std::enable_if_t<L == 1, int> = 0>
242245
Vectorized<U> loadu(const U* ptr) const {
246+
assert((Vectorized<U>::size() >= size()));
243247
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
244248
}
245249
};

aten/src/ATen/cpu/vec/vec_n.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class VectorizedN {
2828
using size_type = int;
2929

3030
static constexpr size_type size_T = sizeof(T);
31-
static constexpr size_type size() {
31+
static size_type size() {
3232
return Vectorized<T>::size() * N;
3333
}
3434

0 commit comments

Comments
 (0)