Skip to content

Commit 60d78aa

Browse files
MrSidimsvmaksimo
andauthored
[NFC] Remove legacy joint matrix instructions (#3438)
Brand-new spec: intel/llvm#12497 Signed-off-by: Dmitry Sidorov <[email protected]> Co-authored-by: Viktoria Maximova <[email protected]>
1 parent 333f956 commit 60d78aa

22 files changed

+50
-1239
lines changed

lib/SPIRV/OCLUtil.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -919,8 +919,6 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
919919
case OpConstantSampler:
920920
case OpTypeSampler:
921921
return SPIRV_SAMPLER_T_ADDR_SPACE;
922-
case internal::OpTypeJointMatrixINTEL:
923-
case internal::OpTypeJointMatrixINTELv2:
924922
case OpTypeCooperativeMatrixKHR:
925923
case internal::OpTypeTaskSequenceINTEL:
926924
return SPIRAS_Global;

lib/SPIRV/SPIRVInternal.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,6 @@ const static char ConstantSampler[] = "ConstantSampler";
317317
const static char PipeStorage[] = "PipeStorage";
318318
const static char ConstantPipeStorage[] = "ConstantPipeStorage";
319319
const static char VmeImageINTEL[] = "VmeImageINTEL";
320-
const static char JointMatrixINTEL[] = "JointMatrixINTEL";
321320
const static char CooperativeMatrixKHR[] = "CooperativeMatrixKHR";
322321
const static char BufferSurfaceINTEL[] = "BufferSurfaceINTEL";
323322
} // namespace kSPIRVTypeName
@@ -972,7 +971,6 @@ template <> inline void SPIRVMap<std::string, Op, SPIRVOpaqueType>::init() {
972971
_SPIRV_OP(BufferSurfaceINTEL)
973972
_SPIRV_OP(CooperativeMatrixKHR)
974973
#undef _SPIRV_OP
975-
add("JointMatrixINTEL", internal::OpTypeJointMatrixINTEL);
976974
add("TaskSequenceINTEL", internal::OpTypeTaskSequenceINTEL);
977975
}
978976

lib/SPIRV/SPIRVReader.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -475,31 +475,6 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
475475
}
476476
return mapType(T, Ty);
477477
}
478-
case internal::OpTypeJointMatrixINTEL: {
479-
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
480-
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
481-
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
482-
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
483-
if (auto *Layout = MT->getLayout())
484-
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
485-
Params.push_back(
486-
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
487-
if (auto *Use = MT->getUse())
488-
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
489-
auto *CTI = MT->getComponentTypeInterpretation();
490-
if (!CTI)
491-
return mapType(
492-
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
493-
transType(MT->getCompType()), Params));
494-
const unsigned CTIValue =
495-
static_cast<SPIRVConstant *>(CTI)->getZExtIntValue();
496-
assert(CTIValue <= internal::InternalJointMatrixCTI::PackedInt4 &&
497-
"Unknown matrix component type interpretation");
498-
Params.push_back(CTIValue);
499-
return mapType(
500-
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
501-
transType(MT->getCompType()), Params));
502-
}
503478
case OpTypeCooperativeMatrixKHR: {
504479
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(T);
505480
unsigned Scope =
@@ -2561,7 +2536,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25612536
auto *Load = new LoadInst(ST, Alloca, "load", false, BB);
25622537
return mapValue(BV, Load);
25632538
}
2564-
case internal::OpTypeJointMatrixINTEL:
25652539
case OpTypeCooperativeMatrixKHR:
25662540
case internal::OpTypeTaskSequenceINTEL:
25672541
return mapValue(BV, transSPIRVBuiltinFromInst(CC, BB));
@@ -2592,9 +2566,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25922566
case OpVectorExtractDynamic: {
25932567
auto *VED = static_cast<SPIRVVectorExtractDynamic *>(BV);
25942568
SPIRVValue *Vec = VED->getVector();
2595-
if (Vec->getType()->getOpCode() == internal::OpTypeJointMatrixINTEL) {
2596-
return mapValue(BV, transSPIRVBuiltinFromInst(VED, BB));
2597-
}
25982569
return mapValue(
25992570
BV, ExtractElementInst::Create(transValue(Vec, F, BB),
26002571
transValue(VED->getIndex(), F, BB),
@@ -2625,9 +2596,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
26252596
case OpVectorInsertDynamic: {
26262597
auto *VID = static_cast<SPIRVVectorInsertDynamic *>(BV);
26272598
SPIRVValue *Vec = VID->getVector();
2628-
if (Vec->getType()->getOpCode() == internal::OpTypeJointMatrixINTEL) {
2629-
return mapValue(BV, transSPIRVBuiltinFromInst(VID, BB));
2630-
}
26312599
return mapValue(
26322600
BV, InsertElementInst::Create(
26332601
transValue(Vec, F, BB), transValue(VID->getComponent(), F, BB),
@@ -3913,7 +3881,6 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
39133881
case OpUDotAccSatKHR:
39143882
case OpSUDotAccSatKHR:
39153883
case OpReadClockKHR:
3916-
case internal::OpJointMatrixLoadINTEL:
39173884
case OpCooperativeMatrixLoadKHR:
39183885
case internal::OpCooperativeMatrixLoadCheckedINTEL:
39193886
case internal::OpCooperativeMatrixLoadOffsetINTEL:

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -592,18 +592,6 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
592592
ArrayRef<unsigned> Ops = TargetTy->int_params();
593593
return mapType(T, BM->addBufferSurfaceINTELType(CastAccess(Ops[0])));
594594
}
595-
case internal::OpTypeJointMatrixINTEL: {
596-
// The expected representation is:
597-
// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%,
598-
// %layout%, %scope%, %use%,
599-
// (optional) %element_type_interpretation%)
600-
auto *ElemTy = transType(TargetTy->getTypeParameter(0));
601-
ArrayRef<unsigned> Ops = TargetTy->int_params();
602-
std::vector<SPIRVValue *> Args;
603-
for (const auto &Op : Ops)
604-
Args.emplace_back(transConstant(getUInt32(M, Op)));
605-
return mapType(T, BM->addJointMatrixINTELType(ElemTy, Args));
606-
}
607595
case OpTypeCooperativeMatrixKHR: {
608596
// The expected representation is:
609597
// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,6 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
8686
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
8787
std::end(Table));
8888

89-
// TODO: To remove this when we make a switch to new version
90-
if (OpCode == internal::OpTypeJointMatrixINTELv2)
91-
OpCode = internal::OpTypeJointMatrixINTEL;
92-
9389
// OpAtomicCompareExchangeWeak is removed starting from SPIR-V 1.4
9490
if (OpCode == OpAtomicCompareExchangeWeak)
9591
OpCode = OpAtomicCompareExchange;

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,11 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
206206
ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL,
207207
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
208208
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
209-
{internal::CapabilityJointMatrixINTEL});
210-
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
211-
{internal::CapabilityJointMatrixINTEL});
212-
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
213-
{internal::CapabilityJointMatrixINTEL});
214-
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
215-
{internal::CapabilityJointMatrixINTEL});
216-
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
217-
{internal::CapabilityJointMatrixINTEL});
209+
{CapabilityCooperativeMatrixKHR});
210+
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL,
211+
{CapabilityCooperativeMatrixKHR});
212+
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL,
213+
{CapabilityCooperativeMatrixKHR});
218214
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL,
219215
{CapabilityCooperativeMatrixKHR});
220216
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,8 +2127,6 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
21272127
break;
21282128
case OpTypeArray:
21292129
case OpTypeStruct:
2130-
case internal::OpTypeJointMatrixINTEL:
2131-
case internal::OpTypeJointMatrixINTELv2:
21322130
case OpTypeCooperativeMatrixKHR:
21332131
break;
21342132
default:
@@ -2388,8 +2386,7 @@ class SPIRVVectorExtractDynamic : public SPIRVInstruction {
23882386
SPIRVInstruction::validate();
23892387
if (getValue(VectorId)->isForward())
23902388
return;
2391-
assert(getValueType(VectorId)->isTypeVector() ||
2392-
getValueType(VectorId)->isTypeJointMatrixINTEL());
2389+
assert(getValueType(VectorId)->isTypeVector());
23932390
}
23942391
SPIRVId VectorId;
23952392
SPIRVId IndexId;
@@ -2426,8 +2423,7 @@ class SPIRVVectorInsertDynamic : public SPIRVInstruction {
24262423
SPIRVInstruction::validate();
24272424
if (getValue(VectorId)->isForward())
24282425
return;
2429-
assert(getValueType(VectorId)->isTypeVector() ||
2430-
getValueType(VectorId)->isTypeJointMatrixINTEL());
2426+
assert(getValueType(VectorId)->isTypeVector());
24312427
}
24322428
SPIRVId VectorId;
24332429
SPIRVId IndexId;
@@ -3605,8 +3601,9 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
36053601
SPIRVCapVec getRequiredCapability() const override {
36063602
SPIRVType *ResCompTy = this->getType();
36073603
if (ResCompTy->isTypeCooperativeMatrixKHR())
3608-
return getVec(CapabilityBFloat16ConversionINTEL,
3609-
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
3604+
return getVec(
3605+
CapabilityBFloat16ConversionINTEL,
3606+
internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL);
36103607
return getVec(CapabilityBFloat16ConversionINTEL);
36113608
}
36123609

@@ -3701,26 +3698,6 @@ class SPIRVJointMatrixINTELInstBase : public SPIRVInstTemplateBase {
37013698
}
37023699
};
37033700

3704-
class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
3705-
SPIRVCapVec getRequiredCapability() const override {
3706-
return getVec(internal::CapabilityJointMatrixINTEL);
3707-
}
3708-
};
3709-
3710-
#define _SPIRV_OP(x, ...) \
3711-
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELInst, internal::Op##x##INTEL, \
3712-
__VA_ARGS__> \
3713-
SPIRV##x##INTEL;
3714-
_SPIRV_OP(JointMatrixLoad, true, 6, true)
3715-
_SPIRV_OP(JointMatrixStore, false, 5, true)
3716-
_SPIRV_OP(JointMatrixMad, true, 6, true)
3717-
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
3718-
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
3719-
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
3720-
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
3721-
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
3722-
#undef _SPIRV_OP
3723-
37243701
class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
37253702
protected:
37263703
SPIRVCapVec getRequiredCapability() const override {
@@ -4032,8 +4009,9 @@ class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
40324009
SPIRVCapVec getRequiredCapability() const override {
40334010
SPIRVType *ResCompTy = this->getType();
40344011
if (ResCompTy->isTypeCooperativeMatrixKHR())
4035-
return getVec(CapabilityTensorFloat32RoundingINTEL,
4036-
internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
4012+
return getVec(
4013+
CapabilityTensorFloat32RoundingINTEL,
4014+
internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL);
40374015
return getVec(CapabilityTensorFloat32RoundingINTEL);
40384016
}
40394017

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,6 @@ class SPIRVModuleImpl : public SPIRVModule {
328328
SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) override;
329329
void closeStructType(SPIRVTypeStruct *T, bool) override;
330330
SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) override;
331-
SPIRVTypeJointMatrixINTEL *
332-
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
333331
SPIRVTypeCooperativeMatrixKHR *
334332
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
335333
SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() override;
@@ -1173,12 +1171,6 @@ SPIRVTypeVector *SPIRVModuleImpl::addVectorType(SPIRVType *CompType,
11731171
return addType(Ty);
11741172
}
11751173

1176-
SPIRVTypeJointMatrixINTEL *
1177-
SPIRVModuleImpl::addJointMatrixINTELType(SPIRVType *CompType,
1178-
std::vector<SPIRVValue *> Args) {
1179-
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Args));
1180-
}
1181-
11821174
SPIRVTypeCooperativeMatrixKHR *
11831175
SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
11841176
std::vector<SPIRVValue *> Args) {

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ class SPIRVModule {
289289
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
290290
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
291291
virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) = 0;
292-
virtual SPIRVTypeJointMatrixINTEL *
293-
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
294292
virtual SPIRVTypeCooperativeMatrixKHR *
295293
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
296294
virtual SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() = 0;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -657,23 +657,18 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
657657
add(CapabilityTernaryBitwiseFunctionINTEL, "TernaryBitwiseFunctionINTEL");
658658
// From spirv_internal.hpp
659659
add(internal::CapabilityTokenTypeINTEL, "TokenTypeINTEL");
660-
add(internal::CapabilityJointMatrixINTEL, "JointMatrixINTEL");
661660
add(internal::CapabilityHWThreadQueryINTEL, "HWThreadQueryINTEL");
662661
add(internal::CapabilityGlobalVariableDecorationsINTEL,
663662
"GlobalVariableDecorationsINTEL");
664663
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
665664
add(CapabilityTensorFloat32RoundingINTEL,
666665
"TensorFloat32RoundingINTEL");
667666
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
668-
"JointMatrixWIInstructionsINTEL");
669-
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
670-
"JointMatrixTF32ComponentTypeINTEL");
671-
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
672-
"JointMatrixBF16ComponentTypeINTEL");
673-
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
674-
"JointMatrixPackedInt2ComponentTypeINTEL");
675-
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
676-
"JointMatrixPackedInt4ComponentTypeINTEL");
667+
"CooperativeMatrixInvocationInstructionsINTEL");
668+
add(internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL,
669+
"CooperativeMatrixTF32ComponentTypeINTEL");
670+
add(internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL,
671+
"CooperativeMatrixBFloat16ComponentTypeINTEL");
677672
add(internal::CapabilityCooperativeMatrixPrefetchINTEL,
678673
"CooperativeMatrixPrefetchINTEL");
679674
add(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,

0 commit comments

Comments
 (0)