Skip to content

Commit ab7c452

Browse files
Ensure that NI_Vector_Dot is always handled as TYP_SIMD (#88447)
* Ensure that NI_Vector_Dot is always handled as TYP_SIMD
* Apply formatting patch
* Ensure simdType is passed in to gtNewSimdDotProdNode
* Apply suggestions from code review

Co-authored-by: Kunal Pathak <[email protected]>

---------

Co-authored-by: Kunal Pathak <[email protected]>
1 parent acafb20 commit ab7c452

File tree

7 files changed

+87
-98
lines changed

7 files changed

+87
-98
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22052,9 +22052,7 @@ GenTree* Compiler::gtNewSimdDotProdNode(
2205222052
assert(op2->TypeIs(simdType));
2205322053

2205422054
var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
22055-
22056-
// We support the return type being a SIMD for floating-point as a special optimization
22057-
assert(varTypeIsArithmetic(type) || (varTypeIsSIMD(type) && varTypeIsFloating(simdBaseType)));
22055+
assert(varTypeIsSIMD(type));
2205822056

2205922057
NamedIntrinsic intrinsic = NI_Illegal;
2206022058

src/coreclr/jit/hwintrinsicarm64.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
881881
op2 = impSIMDPopStack();
882882
op1 = impSIMDPopStack();
883883

884-
retNode = gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
884+
retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
885+
retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
885886
}
886887
break;
887888
}

src/coreclr/jit/hwintrinsicxarch.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1778,7 +1778,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
17781778
op2 = impSIMDPopStack();
17791779
op1 = impSIMDPopStack();
17801780

1781-
retNode = gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
1781+
retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
1782+
retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
17821783
break;
17831784
}
17841785

src/coreclr/jit/lowerarmarch.cpp

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,9 +1628,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
16281628
assert(varTypeIsSIMD(simdType));
16291629
assert(varTypeIsArithmetic(simdBaseType));
16301630
assert(simdSize != 0);
1631-
1632-
// We support the return type being a SIMD for floating-point as a special optimization
1633-
assert(varTypeIsArithmetic(node) || (varTypeIsSIMD(node) && varTypeIsFloating(simdBaseType)));
1631+
assert(varTypeIsSIMD(node));
16341632

16351633
GenTree* op1 = node->Op(1);
16361634
GenTree* op2 = node->Op(2);
@@ -1647,7 +1645,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
16471645

16481646
// For 12 byte SIMD, we need to clear the upper 4 bytes:
16491647
// idx = CNS_INT int 0x03
1650-
// tmp1 = * CNS_DLB float 0.0
1648+
// tmp1 = * CNS_DBL float 0.0
16511649
// /--* op1 simd16
16521650
// +--* idx int
16531651
// +--* tmp1 simd16
@@ -1887,34 +1885,16 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
18871885
}
18881886
}
18891887

1890-
if (varTypeIsSIMD(node->gtType))
1891-
{
1892-
// We're producing a vector result, so just return the result directly
1893-
1894-
LIR::Use use;
1895-
1896-
if (BlockRange().TryGetUse(node, &use))
1897-
{
1898-
use.ReplaceWith(tmp2);
1899-
}
1888+
// We're producing a vector result, so just return the result directly
1889+
LIR::Use use;
19001890

1901-
BlockRange().Remove(node);
1902-
return tmp2->gtNext;
1903-
}
1904-
else
1891+
if (BlockRange().TryGetUse(node, &use))
19051892
{
1906-
// We will be constructing the following parts:
1907-
// ...
1908-
// /--* tmp2 simd16
1909-
// node = * HWINTRINSIC simd16 T ToScalar
1910-
1911-
// This is roughly the following managed code:
1912-
// ...
1913-
// return tmp2.ToScalar();
1914-
1915-
node->ResetHWIntrinsicId((simdSize == 8) ? NI_Vector64_ToScalar : NI_Vector128_ToScalar, tmp2);
1916-
return LowerNode(node);
1893+
use.ReplaceWith(tmp2);
19171894
}
1895+
1896+
BlockRange().Remove(node);
1897+
return tmp2->gtNext;
19181898
}
19191899
#endif // FEATURE_HW_INTRINSICS
19201900

src/coreclr/jit/lowerxarch.cpp

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4428,9 +4428,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
44284428
assert(varTypeIsSIMD(simdType));
44294429
assert(varTypeIsArithmetic(simdBaseType));
44304430
assert(simdSize != 0);
4431-
4432-
// We support the return type being a SIMD for floating-point as a special optimization
4433-
assert(varTypeIsArithmetic(node) || (varTypeIsSIMD(node) && varTypeIsFloating(simdBaseType)));
4431+
assert(varTypeIsSIMD(node));
44344432

44354433
GenTree* op1 = node->Op(1);
44364434
GenTree* op2 = node->Op(2);
@@ -4479,15 +4477,12 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
44794477
// tmp2 = * HWINTRINSIC simd16 T GetUpper
44804478
// /--* tmp1 simd16
44814479
// +--* tmp2 simd16
4482-
// tmp3 = * HWINTRINSIC simd16 T Add
4483-
// /--* tmp3 simd16
4484-
// node = * HWINTRINSIC simd16 T ToScalar
4480+
// node = * HWINTRINSIC simd16 T Add
44854481

44864482
// This is roughly the following managed code:
44874483
// var tmp1 = Avx.DotProduct(op1, op2, 0xFF);
44884484
// var tmp2 = tmp1.GetUpper();
4489-
// var tmp3 = Sse.Add(tmp1, tmp2);
4490-
// return tmp3.ToScalar();
4485+
// return Sse.Add(tmp1, tmp2);
44914486

44924487
idx = comp->gtNewIconNode(0xF1, TYP_INT);
44934488
BlockRange().InsertBefore(node, idx);
@@ -4515,13 +4510,17 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
45154510

45164511
tmp2 = comp->gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, tmp3, tmp1, simdBaseJitType, 16);
45174512
BlockRange().InsertAfter(tmp1, tmp2);
4518-
LowerNode(tmp2);
45194513

4520-
node->SetSimdSize(16);
4514+
// We're producing a vector result, so just return the result directly
4515+
LIR::Use use;
45214516

4522-
node->ResetHWIntrinsicId(NI_Vector128_ToScalar, tmp2);
4517+
if (BlockRange().TryGetUse(node, &use))
4518+
{
4519+
use.ReplaceWith(tmp2);
4520+
}
45234521

4524-
return LowerNode(node);
4522+
BlockRange().Remove(node);
4523+
return LowerNode(tmp2);
45254524
}
45264525

45274526
case TYP_DOUBLE:
@@ -5043,34 +5042,16 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
50435042
tmp1 = tmp2;
50445043
}
50455044

5046-
if (varTypeIsSIMD(node->gtType))
5047-
{
5048-
// We're producing a vector result, so just return the result directly
5049-
5050-
LIR::Use use;
5051-
5052-
if (BlockRange().TryGetUse(node, &use))
5053-
{
5054-
use.ReplaceWith(tmp1);
5055-
}
5045+
// We're producing a vector result, so just return the result directly
5046+
LIR::Use use;
50565047

5057-
BlockRange().Remove(node);
5058-
return tmp1->gtNext;
5059-
}
5060-
else
5048+
if (BlockRange().TryGetUse(node, &use))
50615049
{
5062-
// We will be constructing the following parts:
5063-
// ...
5064-
// /--* tmp1 simd16
5065-
// node = * HWINTRINSIC simd16 T ToScalar
5066-
5067-
// This is roughly the following managed code:
5068-
// ...
5069-
// return tmp1.ToScalar();
5070-
5071-
node->ResetHWIntrinsicId(NI_Vector128_ToScalar, tmp1);
5072-
return LowerNode(node);
5050+
use.ReplaceWith(tmp1);
50735051
}
5052+
5053+
BlockRange().Remove(node);
5054+
return tmp1->gtNext;
50745055
}
50755056

50765057
//----------------------------------------------------------------------------------------------

src/coreclr/jit/morph.cpp

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10670,39 +10670,45 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1067010670
#endif // TARGET_ARM64
1067110671
case NI_Vector128_Create:
1067210672
{
10673-
// The `Dot` API returns a scalar. However, many common usages require it to
10674-
// be then immediately broadcast back to a vector so that it can be used in
10675-
// a subsequent operation. One of the most common is normalizing a vector
10673+
// The managed `Dot` API returns a scalar. However, many common usages require
10674+
// it to be then immediately broadcast back to a vector so that it can be used
10675+
// in a subsequent operation. One of the most common is normalizing a vector
1067610676
// which is effectively `value / value.Length` where `Length` is
10677-
// `Sqrt(Dot(value, value))`
10677+
// `Sqrt(Dot(value, value))`. Because of this, and because of how a lot of
10678+
// hardware works, we treat `NI_Vector_Dot` as returning a SIMD type and then
10679+
// also wrap it in `ToScalar` where required.
1067810680
//
1067910681
// In order to ensure that developers can still utilize this efficiently, we
10680-
// will look for two common patterns:
10682+
// then look for four common patterns:
1068110683
// * Create(Dot(..., ...))
1068210684
// * Create(Sqrt(Dot(..., ...)))
10685+
// * Create(ToScalar(Dot(..., ...)))
10686+
// * Create(ToScalar(Sqrt(Dot(..., ...))))
1068310687
//
10684-
// When these exist, we'll avoid converting to a scalar at all and just
10685-
// keep everything as a vector. However, we only do this for Vector64/Vector128
10686-
// and only for float/double.
10688+
// When these exist, we'll avoid converting to a scalar and hence, avoid broadcasting
10689+
// the value back into a vector. Instead we'll just keep everything as a vector.
1068710690
//
10688-
// We don't do this for Vector256 since that is xarch only and doesn't trivially
10689-
// support operations which cross the upper and lower 128-bit lanes
10691+
// We only do this for Vector64/Vector128 today. We could expand this more in
10692+
// the future but it would require additional manual handling for Vector256
10693+
// (since a 256-bit result requires more work). We do some integer handling
10694+
// when the value is trivially replicated to all elements without extra work.
1069010695

1069110696
if (node->GetOperandCount() != 1)
1069210697
{
1069310698
break;
1069410699
}
1069510700

10696-
if (!varTypeIsFloating(node->GetSimdBaseType()))
10697-
{
10698-
break;
10699-
}
10700-
10701-
GenTree* op1 = node->Op(1);
10702-
GenTree* sqrt = nullptr;
10701+
GenTree* op1 = node->Op(1);
10702+
GenTree* sqrt = nullptr;
10703+
GenTree* toScalar = nullptr;
1070310704

1070410705
if (op1->OperIs(GT_INTRINSIC))
1070510706
{
10707+
if (!varTypeIsFloating(node->GetSimdBaseType()))
10708+
{
10709+
break;
10710+
}
10711+
1070610712
if (op1->AsIntrinsic()->gtIntrinsicName != NI_System_Math_Sqrt)
1070710713
{
1070810714
break;
@@ -10719,6 +10725,24 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1071910725

1072010726
GenTreeHWIntrinsic* hwop1 = op1->AsHWIntrinsic();
1072110727

10728+
#if defined(TARGET_ARM64)
10729+
if ((hwop1->GetHWIntrinsicId() == NI_Vector64_ToScalar) ||
10730+
(hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar))
10731+
#else
10732+
if (hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar)
10733+
#endif
10734+
{
10735+
op1 = hwop1->Op(1);
10736+
10737+
if (!op1->OperIs(GT_HWINTRINSIC))
10738+
{
10739+
break;
10740+
}
10741+
10742+
toScalar = hwop1;
10743+
hwop1 = op1->AsHWIntrinsic();
10744+
}
10745+
1072210746
#if defined(TARGET_ARM64)
1072310747
if ((hwop1->GetHWIntrinsicId() != NI_Vector64_Dot) && (hwop1->GetHWIntrinsicId() != NI_Vector128_Dot))
1072410748
#else
@@ -10728,13 +10752,16 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1072810752
break;
1072910753
}
1073010754

10731-
unsigned simdSize = node->GetSimdSize();
10732-
var_types simdType = getSIMDTypeForSize(simdSize);
10733-
10734-
hwop1->gtType = simdType;
10755+
if (toScalar != nullptr)
10756+
{
10757+
DEBUG_DESTROY_NODE(toScalar);
10758+
}
1073510759

1073610760
if (sqrt != nullptr)
1073710761
{
10762+
unsigned simdSize = node->GetSimdSize();
10763+
var_types simdType = getSIMDTypeForSize(simdSize);
10764+
1073810765
node = gtNewSimdSqrtNode(simdType, hwop1, node->GetSimdBaseJitType(), simdSize)->AsHWIntrinsic();
1073910766
DEBUG_DESTROY_NODE(sqrt);
1074010767
}

src/coreclr/jit/simdashwintrinsic.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,8 +1065,7 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
10651065
vecCon->gtSimdVal.f32[3] = +1.0f;
10661066

10671067
GenTree* conjugate = gtNewSimdBinOpNode(GT_MUL, retType, op1, vecCon, simdBaseJitType, simdSize);
1068-
1069-
op1 = gtNewSimdDotProdNode(retType, clonedOp1, clonedOp2, simdBaseJitType, simdSize);
1068+
op1 = gtNewSimdDotProdNode(retType, clonedOp1, clonedOp2, simdBaseJitType, simdSize);
10701069

10711070
return gtNewSimdBinOpNode(GT_DIV, retType, conjugate, op1, simdBaseJitType, simdSize);
10721071
}
@@ -1095,7 +1094,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
10951094
op1 = impCloneExpr(op1, &clonedOp1, CHECK_SPILL_ALL,
10961095
nullptr DEBUGARG("Clone op1 for vector length squared"));
10971096

1098-
return gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
1097+
op1 = gtNewSimdDotProdNode(simdType, op1, clonedOp1, simdBaseJitType, simdSize);
1098+
return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
10991099
}
11001100

11011101
case NI_VectorT128_Load:
@@ -1174,7 +1174,6 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
11741174
nullptr DEBUGARG("Clone op1 for vector normalize (2)"));
11751175

11761176
op1 = gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
1177-
11781177
op1 = gtNewSimdSqrtNode(retType, op1, simdBaseJitType, simdSize);
11791178

11801179
return gtNewSimdBinOpNode(GT_DIV, retType, clonedOp2, op1, simdBaseJitType, simdSize);
@@ -1462,7 +1461,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
14621461
op1 = impCloneExpr(op1, &clonedOp1, CHECK_SPILL_ALL,
14631462
nullptr DEBUGARG("Clone diff for vector distance squared"));
14641463

1465-
return gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
1464+
op1 = gtNewSimdDotProdNode(simdType, op1, clonedOp1, simdBaseJitType, simdSize);
1465+
return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
14661466
}
14671467

14681468
case NI_Quaternion_Divide:
@@ -1492,7 +1492,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
14921492
case NI_VectorT256_Dot:
14931493
#endif // TARGET_XARCH
14941494
{
1495-
return gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
1495+
op1 = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
1496+
return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
14961497
}
14971498

14981499
case NI_VectorT128_Equals:

0 commit comments

Comments
 (0)