
Commit 9a49104

Pulls in llvm/llvm-project#123200, which is useful in its own right and also handles #5664. Integration changes were required due to llvm/llvm-project#123026, llvm/llvm-project#123321, and llvm/llvm-project#123326. Also closes #5685

1 parent: c2c193a

13 files changed, +56 -53 lines
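The integration work is mostly a mechanical migration of FP8 type checks: the deprecated mlir::Type::isFloat8E...() convenience methods are replaced with llvm::isa<> queries against the concrete Float8...Type classes. A minimal sketch of the pattern, not taken from the commit (the helper name isE5M2 is hypothetical):

// Sketch only: illustrates the before/after pattern applied throughout
// this commit. The helper name isE5M2 is hypothetical.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

static bool isE5M2(mlir::Type ty) {
  // Old (removed upstream):  return ty.isFloat8E5M2();
  // New: query the concrete builtin type class instead.
  return llvm::isa<mlir::Float8E5M2Type>(ty);
}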

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-e2402615a5a76d46a433dfcc1de10b38a1263c9d
+c118864223c6309378cd704f3406533474c2759f

include/triton/Conversion/MLIRTypes.h

Lines changed: 7 additions & 9 deletions
@@ -26,17 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
-inline bool isFloat(Type type) {
-  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType,
+             Float8E5M2Type, Float8E5M2FNUZType>(type);
 }
 
-inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+inline bool isFloat(Type type) {
+  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
+         type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
+         isFloat8(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
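With this reordering, isFloat8 is defined before isFloat so the latter can reuse it. A hedged usage sketch of the two helpers; the function fp8AwareBitWidth and its surrounding code are illustrative, not part of the header:

// Illustrative only: how downstream code might use the helpers above.
// Assumes a using-directive for the namespace that MLIRTypes.h declares
// these helpers in.
#include "triton/Conversion/MLIRTypes.h"

static unsigned fp8AwareBitWidth(mlir::Type elemTy) {
  // isFloat8 covers E4M3B11FNUZ, E4M3FN, E4M3FNUZ, E5M2 and E5M2FNUZ.
  if (isFloat8(elemTy))
    return 8;
  // isFloat additionally accepts f16, bf16, f32, f64 and f128.
  return isFloat(elemTy) ? elemTy.getIntOrFloatBitWidth() : 0;
}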

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 4 deletions
@@ -750,14 +750,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
            aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
            aElemTy.isF32()))) {
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -778,8 +778,8 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
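The rewritten conditions rely on the variadic form of llvm::isa, which is true when the value matches any of the listed types, so each call is a drop-in replacement for the chain of ||-ed checks it replaces. A small standalone sketch of that equivalence, not taken from the diff:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"
#include <cassert>

// Sketch: variadic isa<> is an OR over the listed type classes.
static void fp8CheckEquivalence(mlir::Type aElemTy) {
  bool combined =
      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy);
  bool separate = llvm::isa<mlir::Float8E5M2Type>(aElemTy) ||
                  llvm::isa<mlir::Float8E4M3FNType>(aElemTy);
  assert(combined == separate);
  (void)combined;
  (void)separate;
}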

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 1 deletion
@@ -632,7 +632,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 3 deletions
@@ -45,9 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-      eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-      eltType.isF32()) {
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+          eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                    168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
                    80, 72, 64, 56, 48, 40, 32, 24, 16, 8});

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 7 additions & 8 deletions
@@ -1043,17 +1043,16 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() ||
-        srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ() ||
-        srcElementType.isFloat8E5M2FNUZ() ||
-        dstElementType.isFloat8E5M2FNUZ()) {
+    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            srcElementType) ||
+        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
-        srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (dstElementType.isFloat8E4M3FNUZ() ||
-                                     dstElementType.isFloat8E5M2FNUZ()));
+        srcElementType.isF32() &&
+        !(isaFamily == AMD::ISAFamily::CDNA3 &&
+          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;
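Since the same three-type list is checked against both the source and destination element types, the condition could also be factored into a small predicate. A sketch of that alternative formulation only; this is not what the commit does, and the helper name isTwoElementFp8 is made up:

// Sketch of an alternative formulation; isTwoElementFp8 is hypothetical.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

static bool isTwoElementFp8(mlir::Type t) {
  return llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E4M3FNUZType,
                   mlir::Float8E5M2FNUZType>(t);
}
// numElements would then be:
//   (isTwoElementFp8(srcElementType) || isTwoElementFp8(dstElementType)) ? 2 : 4;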

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),

third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp

Lines changed: 10 additions & 5 deletions
@@ -20,19 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) {
+  if (llvm::isa<Float8E5M2Type>(dataTypeA) &&
+      llvm::isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");
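For reference, the FNUZ fp8 types matched by this dispatch can be constructed directly from an MLIRContext. A hypothetical driver, not part of the commit, exercising the E4M3FNUZ/E5M2FNUZ pairing above:

// Hypothetical example: building the FNUZ fp8 types used in the AMD paths
// and noting which MfmaTypeId the A/B pair above selects.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

void fnuzPairExample() {
  mlir::MLIRContext ctx;
  mlir::Type fp8 = mlir::Float8E4M3FNUZType::get(&ctx);
  mlir::Type bf8 = mlir::Float8E5M2FNUZType::get(&ctx);
  // With dataTypeA = fp8 (E4M3FNUZ) and dataTypeB = bf8 (E5M2FNUZ), the
  // dispatch above returns MfmaTypeId::Fp8Bf8TyId.
  (void)fp8;
  (void)bf8;
}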

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 8 additions & 8 deletions
@@ -303,17 +303,17 @@ TensorCoreType getMmaType(triton::DotOp op) {
     return TensorCoreType::FP32_FP16_FP16_FP32;
   if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
     return TensorCoreType::FP32_BF16_BF16_FP32;
-  if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E5M2())
+  if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+      llvm::isa<Float8E5M2Type>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E4M3FN())
+  if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+      llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
-  if (aTy.getElementType().isFloat8E4M3FN() &&
-      bTy.getElementType().isFloat8E5M2())
+  if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+      llvm::isa<Float8E5M2Type>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E4M3FN() &&
-      bTy.getElementType().isFloat8E4M3FN())
+  if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+      llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
   if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
       op.getInputPrecision() == InputPrecision::TF32)
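Note that the NVIDIA paths in this commit (here and in lib/Analysis/Utility.cpp) match the Float8E5M2Type/Float8E4M3FNType classes, while the AMD paths match the FNUZ variants. A small sketch separating the two families; the helper names are hypothetical and not part of the codebase:

// Sketch: the two fp8 families touched by this commit.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

static bool isNvidiaNativeFp8(mlir::Type t) {
  // Matched in the NVIDIA MMA paths.
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(t);
}

static bool isFnuzFp8(mlir::Type t) {
  // Matched in the AMD MFMA / CDNA3 paths.
  return llvm::isa<mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(t);
}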
