Skip to content

Commit c7e0d37

Browse files
authored
bg/lwpmiopen 193 : Integrate CK's batch norm backward training into non-tunable MIOpen solver (#2385)
1 parent 1605ca8 commit c7e0d37

16 files changed

+1164
-794
lines changed

src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ set( MIOpen_Source
152152
solver/activ/bwd_1.cpp
153153
solver/activ/fwd_0.cpp
154154
solver/activ/fwd_1.cpp
155+
solver/batchnorm/backward_ck.cpp
155156
solver/batchnorm/backward_per_activation.cpp
156157
solver/batchnorm/backward_per_activation_fused.cpp
157158
solver/batchnorm/backward_spatial_multiple.cpp
@@ -163,6 +164,7 @@ set( MIOpen_Source
163164
solver/batchnorm/forward_per_activation_fused.cpp
164165
solver/batchnorm/forward_spatial_multiple.cpp
165166
solver/batchnorm/forward_spatial_single.cpp
167+
solver/batchnorm/forward_training_ck.cpp
166168
solver/conv_asm_1x1u.cpp
167169
solver/conv_asm_1x1u_bias_activ_fused.cpp
168170
solver/conv_asm_1x1u_stride2.cpp

src/batch_norm_api.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,6 @@ miopenBatchNormalizationBackward(miopenHandle_t handle,
243243
const void* savedMean,
244244
const void* savedInvVariance)
245245
{
246-
// bfloat16 not supported for batchnorm operation
247-
if(miopen::deref(xDesc).GetType() == miopenBFloat16 ||
248-
miopen::deref(dyDesc).GetType() == miopenBFloat16 ||
249-
miopen::deref(dxDesc).GetType() == miopenBFloat16)
250-
{
251-
return miopenStatusNotImplemented;
252-
}
253246

254247
MIOPEN_LOG_FUNCTION(handle,
255248
bn_mode,

src/include/miopen/batchnorm/solvers.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,26 @@ struct BnCKFwdInference final : BatchnormSolver
142142
const miopen::batchnorm::ProblemDescription& problem) const override;
143143
};
144144

145+
struct BnCKBwdBackward final : BatchnormSolver
146+
{
147+
const std::string& SolverDbId() const override { return GetSolverDbId<BnCKBwdBackward>(); }
148+
149+
bool IsApplicable(const ExecutionContext& context,
150+
const miopen::batchnorm::ProblemDescription& problem) const override;
151+
ConvSolution GetSolution(const ExecutionContext& context,
152+
const miopen::batchnorm::ProblemDescription& problem) const override;
153+
};
154+
155+
struct BnCKFwdTraining final : BatchnormSolver
156+
{
157+
const std::string& SolverDbId() const override { return GetSolverDbId<BnCKFwdTraining>(); }
158+
159+
bool IsApplicable(const ExecutionContext& context,
160+
const miopen::batchnorm::ProblemDescription& problem) const override;
161+
ConvSolution GetSolution(const ExecutionContext& context,
162+
const miopen::batchnorm::ProblemDescription& problem) const override;
163+
};
164+
145165
} // namespace batchnorm
146166

147167
} // namespace solver

src/include/miopen/solver/implicitgemm_ck_util.hpp

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs,
4141
});
4242
}
4343

44-
template <typename DeviceOpType, typename CKArgsType>
45-
std::vector<std::string> FillValidKernelsIDs(const ProblemDescription& problem)
44+
template <typename DeviceOpType,
45+
typename CKArgsType,
46+
typename ProblemDescriptionType = ProblemDescription>
47+
std::vector<std::string> FillValidKernelsIDs(const ProblemDescriptionType& problem)
4648
{
4749
const auto args = CKArgsType{problem};
4850
const auto conv_ptrs = DeviceOpType::GetInstances();
@@ -59,29 +61,36 @@ std::vector<std::string> FillValidKernelsIDs(const ProblemDescription& problem)
5961
return valid_kernels;
6062
}
6163

62-
template <typename DeviceOpType, typename CKArgsType>
63-
bool IsCKArgsSupported(const ProblemDescription& problem, const std::string& kernel_id)
64+
template <typename DeviceOpType,
65+
typename CKArgsType,
66+
typename ProblemDescriptionType = ProblemDescription>
67+
bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id)
6468
{
6569
auto conv_ptrs = DeviceOpType::GetInstances();
6670
auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);
6771

6872
return (ptr_iter != conv_ptrs.end()) && CKArgsType{problem}.IsSupportedBy(*ptr_iter);
6973
}
7074

71-
template <typename DeviceOpType, typename CKArgsType>
72-
bool IsCKApplicable(const ProblemDescription& problem)
75+
template <typename DeviceOpType,
76+
typename CKArgsType,
77+
typename ProblemDescriptionType = ProblemDescription>
78+
bool IsCKApplicable(const ProblemDescriptionType& problem)
7379
{
7480
const auto args = CKArgsType{problem};
75-
if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; }))
76-
return false;
81+
// if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; }))
82+
// return false;
7783

7884
const auto ptrs = DeviceOpType::GetInstances();
7985
return std::any_of(
8086
ptrs.begin(), ptrs.end(), [&args](auto& ptr) { return args.IsSupportedBy(ptr); });
8187
}
8288

83-
template <typename DeviceOpType, typename CKArgsType, typename CastType>
84-
ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::string& kernel_id)
89+
template <typename DeviceOpType,
90+
typename CKArgsType,
91+
typename CastType,
92+
typename ProblemDescriptionType = ProblemDescription>
93+
ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id)
8594
{
8695
auto conv_ptrs = DeviceOpType::GetInstances();
8796
auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);
@@ -112,5 +121,41 @@ ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::st
112121
return result;
113122
}
114123

124+
template <typename DeviceOpType,
125+
typename CKArgsType,
126+
typename CastType,
127+
typename ProblemDescriptionType = ProblemDescription>
128+
ConvSolution InitAnyInvokerFactory(const ProblemDescriptionType& problem,
129+
const std::string& kernel_id)
130+
{
131+
auto conv_ptrs = DeviceOpType::GetInstances();
132+
auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);
133+
134+
if(ptr_iter == conv_ptrs.end())
135+
return {miopenStatusInvalidValue};
136+
137+
ConvSolution result;
138+
result.invoker_factory =
139+
[ck_args = CKArgsType{problem},
140+
sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}](const std::vector<Kernel>&) mutable {
141+
return [ck_args = std::move(ck_args), sh_conv_ptr = std::move(sh_conv_ptr)](
142+
const Handle& handle, const AnyInvokeParams& primitive_parameters) {
143+
const auto& data_ctx = primitive_parameters.CastTo<CastType>();
144+
auto argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr, data_ctx);
145+
auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer();
146+
147+
const auto enable_profiling = handle.IsProfilingEnabled();
148+
float elapsed_time =
149+
invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
150+
if(enable_profiling)
151+
{
152+
handle.ResetKernelTime();
153+
handle.AccumKernelTime(elapsed_time);
154+
}
155+
};
156+
};
157+
return result;
158+
}
159+
115160
} // namespace solver
116161
} // namespace miopen

src/ocl/batchnormocl.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ void BatchNormForwardTraining(Handle& handle,
131131
return tmp;
132132
}();
133133

134-
const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdTrainingSpatialSingle,
134+
const auto solvers = solver::SolverContainer<solver::batchnorm::BnCKFwdTraining,
135+
solver::batchnorm::BnFwdTrainingSpatialSingle,
135136
solver::batchnorm::BnFwdTrainingSpatialMultiple,
136137
solver::batchnorm::BnFwdTrainingPerActivation>{};
137138

@@ -300,7 +301,7 @@ void BatchNormBackward(Handle& handle,
300301
{
301302
MIOPEN_THROW(miopenStatusBadParm);
302303
}
303-
if(dxDesc.GetType() != dyDesc.GetType() || dyDesc.GetType() != xDesc.GetType())
304+
if(dxDesc.GetType() != dyDesc.GetType())
304305
{
305306
MIOPEN_THROW(miopenStatusBadParm);
306307
}
@@ -338,15 +339,15 @@ void BatchNormBackward(Handle& handle,
338339
tmp.dx = dx;
339340
tmp.bnScale = bnScale;
340341
tmp.resultBnScaleDiff = resultBnScaleDiff;
341-
tmp.resultBnScaleDiff = resultBnScaleDiff;
342342
tmp.resultBnBiasDiff = resultBnBiasDiff;
343343
tmp.epsilon = epsilon;
344344
tmp.savedMean = savedMean;
345345
tmp.savedInvVariance = savedInvVariance;
346346
return tmp;
347347
}();
348348

349-
const auto solvers = solver::SolverContainer<solver::batchnorm::BnBwdTrainingSpatialSingle,
349+
const auto solvers = solver::SolverContainer<solver::batchnorm::BnCKBwdBackward,
350+
solver::batchnorm::BnBwdTrainingSpatialSingle,
350351
solver::batchnorm::BnBwdTrainingSpatialMultiple,
351352
solver::batchnorm::BnBwdTrainingPerActivation>{};
352353

src/solver.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
569569
RegisterWithSolver(
570570
registry, ++id, ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM);
571571
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId());
572+
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId());
573+
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId());
572574

573575
// IMPORTANT: New solvers should be added to the end of the function!
574576
}

0 commit comments

Comments
 (0)