Skip to content

Commit 22658a6

Browse files
authored
Add a check for packed tensors for convolution solvers (#2471)
1 parent a6ae364 commit 22658a6

File tree

62 files changed

+152
-133
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+152
-133
lines changed

src/include/miopen/conv/problem_description.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,17 @@ struct ProblemDescription : ProblemDescriptionBase
367367
bool IsNCHWc_NCHWc() const;
368368
bool IsNCHWc_CHWNc() const;
369369

370+
bool HasNonPackedTensors() const
371+
{
372+
return !(in.IsPacked() && weights.IsPacked() && out.IsPacked());
373+
}
374+
375+
bool HasMixedDataTypes() const
376+
{
377+
return !(GetInDataType() == GetWeightsDataType() &&
378+
GetWeightsDataType() == GetOutDataType());
379+
}
380+
370381
void HeuristicUpdateLayouts();
371382

372383
void MakeNetworkConfig(std::string& conf_key) const;

src/include/miopen/fusion/utils.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes
8484
return false;
8585
if(!conv_problem.IsFp32())
8686
return false;
87+
if(conv_problem.HasNonPackedTensors())
88+
return false;
8789
if(!conv_problem.IsLayoutDefault())
8890
return false;
8991
if(!conv_problem.direction.IsForward())

src/solver/conv_MP_bidirectional_winograd.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc
229229
}
230230

231231
if(!problem.IsLayoutDefault())
232-
{
233232
return false;
234-
}
235233

236234
{
237235
unsigned int const waves_in_group = 512 / wave_size;
@@ -323,11 +321,11 @@ bool ConvMPBidirectWinograd<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsA
323321
{
324322
// HIP backend required for sending ptr (buffer + offset)
325323
// ROCBLAS for GEMM step
324+
if(problem.HasNonPackedTensors())
325+
return false;
326326

327327
if(!problem.IsLayoutDefault())
328-
{
329328
return false;
330-
}
331329

332330
if(problem.IsTensorsCasted())
333331
return false;

src/solver/conv_asm_1x1u.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,8 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
527527
return false;
528528
if(!problem.Is2d())
529529
return false;
530+
if(problem.HasNonPackedTensors())
531+
return false;
530532
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
531533
return false;
532534
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
@@ -545,13 +547,9 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
545547

546548
const std::string name = ctx.GetStream().GetDeviceName();
547549
if(name.find("gfx9") == std::string::npos)
548-
{
549550
return false;
550-
}
551551
if(!problem.IsLayoutDefault())
552-
{
553552
return false;
554-
}
555553

556554
if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
557555
return false;

src/solver/conv_asm_1x1u_bias_activ_fused.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,6 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context,
256256
if(conv_problem.GetDilationH() != 1)
257257
return false;
258258

259-
if(conv_problem.IsTensorsCasted())
260-
return false;
261-
262259
// Check if the convolution part is applicable
263260
return sol.IsApplicable(conv_ctx, conv_problem);
264261
}

src/solver/conv_asm_1x1u_stride2.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,8 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
489489
return false;
490490
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
491491
return false;
492+
if(problem.HasNonPackedTensors())
493+
return false;
492494
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
493495
return false;
494496
if(!ctx.rmv.IsV2orV3())
@@ -505,13 +507,9 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
505507

506508
const std::string name = ctx.GetStream().GetDeviceName();
507509
if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos)
508-
{
509510
return false;
510-
}
511511
if(!problem.IsLayoutDefault())
512-
{
513512
return false;
514-
}
515513

516514
if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
517515
return false;

src/solver/conv_asm_3x3u.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
176176
return false;
177177
if(!problem.Is2d())
178178
return false;
179+
if(problem.HasNonPackedTensors())
180+
return false;
179181
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
180182
return false;
181183
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
@@ -194,9 +196,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
194196
if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx90")))
195197
return false;
196198
if(!problem.IsLayoutDefault())
197-
{
198199
return false;
199-
}
200200

201201
if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
202202
return false;

src/solver/conv_asm_5x10u2v2b1.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
4545
return false;
4646
if(!problem.Is2d())
4747
return false;
48+
if(problem.HasNonPackedTensors())
49+
return false;
4850
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
4951
return false;
5052
if(!ctx.rmv.IsV2orV3())
@@ -63,17 +65,11 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
6365
return false;
6466
#endif
6567
if(!device_is_gfx8_9_no_xnack)
66-
{
6768
return false;
68-
}
6969
if(!problem.direction.IsBackwardData())
70-
{
7170
return false;
72-
}
7371
if(!problem.IsLayoutDefault())
74-
{
7572
return false;
76-
}
7773
if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
7874
return false;
7975

src/solver/conv_asm_5x10u2v2f1.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
4646
return false;
4747
if(!problem.Is2d())
4848
return false;
49+
if(problem.HasNonPackedTensors())
50+
return false;
4951
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
5052
return false;
5153
if(!ctx.rmv.IsV2orV3())
@@ -64,18 +66,11 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
6466
return false;
6567
#endif
6668
if(!device_is_gfx8_9_no_xnack)
67-
{
6869
return false;
69-
}
7070
if(!problem.direction.IsForward())
71-
{
7271
return false;
73-
}
7472
if(!problem.IsLayoutDefault())
75-
{
7673
return false;
77-
}
78-
7974
if(problem.IsTensorsCasted())
8075
return false;
8176

src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
5151
if(!ctx.rmv.IsV2orV3())
5252
return false;
5353

54+
if(problem.HasNonPackedTensors())
55+
return false;
56+
5457
if(problem.IsTensorsCasted())
5558
return false;
5659

@@ -65,17 +68,11 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
6568
#endif
6669
if(!(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" ||
6770
name == "gfx900" || name == "gfx904" || name == "gfx906" || name == "gfx908"))
68-
{
6971
return false;
70-
}
7172
if(!problem.direction.IsForward())
72-
{
7373
return false;
74-
}
7574
if(!problem.IsLayoutDefault())
76-
{
7775
return false;
78-
}
7976

8077
// clang-format off
8178
return problem.GetPadW() == 3 // -q

0 commit comments

Comments (0)