Skip to content

Commit 411b345

Browse files
TunaNet Integration: MI250x (#2421)
1 parent 102fbee commit 411b345

File tree

4 files changed

+375
-6
lines changed

4 files changed

+375
-6
lines changed

src/conv/heuristics/ai_heuristics.cpp

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class Model
145145
virtual std::vector<float> ToFeatures(const ProblemDescription& problem) const = 0;
146146
};
147147

148-
class Gfx908Model : public Model
148+
class Gfx908Model final : public Model
149149
{
150150
public:
151151
Gfx908Model() : Model("gfx908") {}
@@ -255,7 +255,106 @@ class Gfx908Model : public Model
255255
}
256256
};
257257

258-
std::unique_ptr<Model> GetModel(const std::string&) { return std::make_unique<Gfx908Model>(); }
258+
class Gfx90aModel final : public Model
259+
{
260+
public:
261+
Gfx90aModel() : Model("gfx90a") {}
262+
bool IsProblemSupported(const ProblemDescription& problem,
263+
const ExecutionContext& ctx) const override
264+
{
265+
// check if problem is of the kind TunaNet was trained to handle
266+
if(!problem.Is2d())
267+
{
268+
MIOPEN_LOG_I2("TunaNet Inapplicable: Problem not 2D");
269+
return false;
270+
}
271+
if(problem.GetInLayout() != "NCHW")
272+
{
273+
MIOPEN_LOG_I2("TunaNet Inapplicable: Layout not supported");
274+
return false;
275+
}
276+
if(problem.GetKernelStrideH() != problem.GetKernelStrideW())
277+
{
278+
MIOPEN_LOG_I2("TunaNet Inapplicable: Stride must be equal along all axes");
279+
return false;
280+
}
281+
if(problem.GetDilationH() != problem.GetDilationW())
282+
{
283+
MIOPEN_LOG_I2("TunaNet Inapplicable: Dilation must be 1");
284+
return false;
285+
}
286+
if(problem.GetBias() != 0)
287+
{
288+
MIOPEN_LOG_I2("TunaNet Inapplicable: Bias must be 0");
289+
return false;
290+
}
291+
const auto data_type = problem.GetInDataType();
292+
if(data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16)
293+
{
294+
MIOPEN_LOG_I2("TunaNet Inapplicable: Unsupported data type");
295+
return false;
296+
}
297+
298+
// check if the context is s.t. no solver TunaNet may predict would be applicable
299+
size_t applicable_solvers = 0;
300+
for(const auto& solver_name : metadata.solver_map)
301+
{
302+
auto solver_id = solver::Id{solver_name.second};
303+
auto solver = solver_id.GetSolver();
304+
if(solver.IsApplicable(ctx, problem))
305+
{
306+
applicable_solvers++;
307+
break;
308+
}
309+
}
310+
if(applicable_solvers == 0)
311+
{
312+
MIOPEN_LOG_I2("TunaNet Inapplicable: No solver that TunaNet may predict applies");
313+
return false;
314+
}
315+
MIOPEN_LOG_I2("TunaNet Applicable");
316+
return true;
317+
}
318+
319+
protected:
320+
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
321+
{
322+
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
323+
std::vector<float> features = {
324+
static_cast<float>(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()),
325+
static_cast<float>(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()),
326+
static_cast<float>(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()),
327+
static_cast<float>(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()),
328+
static_cast<float>(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()),
329+
static_cast<float>(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()),
330+
static_cast<float>(problem.GetWeightsHeight_()),
331+
static_cast<float>(problem.GetWeightsWidth_()),
332+
static_cast<float>(problem.GetPadH()),
333+
static_cast<float>(problem.GetPadW()),
334+
static_cast<float>(problem.GetKernelStrideH()),
335+
static_cast<float>(problem.GetKernelStrideW()),
336+
static_cast<float>(problem.GetDilationH()),
337+
static_cast<float>(problem.GetDilationW()),
338+
static_cast<float>(problem.GetOutBatchSize_()),
339+
static_cast<float>(metadata.EncodePrecision(problem.GetInDataType())),
340+
static_cast<float>(metadata.EncodeDirection(problem.GetDirection())),
341+
static_cast<float>(problem.GetGroupCount())};
342+
343+
// normalize
344+
for(size_t i = 0; i < features.size(); ++i)
345+
features[i] = (features[i] - metadata.features_mean[i]) / metadata.features_std[i];
346+
347+
return features;
348+
}
349+
};
350+
351+
std::unique_ptr<Model> GetModel(const std::string& device)
352+
{
353+
if(device == "gfx90a")
354+
return std::make_unique<Gfx90aModel>();
355+
else
356+
return std::make_unique<Gfx908Model>();
357+
}
259358

260359
std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
261360
const ExecutionContext& ctx,
@@ -270,7 +369,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
270369
auto db_res = db.FindRecord(static_cast<const conv::ProblemDescription&>(problem));
271370
if(db_res)
272371
{
273-
MIOPEN_LOG_I2("Cached heuristic result found");
372+
MIOPEN_LOG_I2("Cached heuristic (TunaNet) result found");
274373
std::vector<uint64_t> db_sol(db_res->size());
275374
// cast returned record to solver ids
276375
std::transform(db_res->begin(), db_res->end(), db_sol.begin(), [](boost::any id) {
@@ -286,7 +385,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
286385
return db_sol;
287386
}
288387

289-
MIOPEN_LOG_I2("Evaluating Heuristic");
388+
MIOPEN_LOG_I2("Evaluating TunaNet");
290389

291390
std::vector<float> res = model->Forward(problem);
292391
std::vector<std::pair<int, float>> sort_res(res.size());
@@ -322,7 +421,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
322421
std::stringstream ss;
323422
for(auto& id : sol)
324423
ss << solver::Id{id}.ToString() << " ID:" << id << ", ";
325-
MIOPEN_LOG_I2("Heuristic Result: " << ss.str());
424+
MIOPEN_LOG_I2("TunaNet Result: " << ss.str());
326425
}
327426
return sol;
328427
}

src/include/miopen/problem_description.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ struct ProblemDescriptionCompatTemporary
141141
int GetOutWidth() const { return out_width; }
142142
// int GetOutDepth() const { return out_depth; }
143143
int GetBatchSize() const { return batch_sz; }
144-
// int GetBias() const { return bias; }
144+
int GetBias() const { return bias; }
145145
// std::string GetInLayout() const { return in_layout; }
146146
// std::string GetOutLayout() const { return out_layout; }
147147
miopenDataType_t GetInDataType() const { return in_data_type; }

src/kernels/gfx90a.tn.model

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)