@@ -145,7 +145,7 @@ class Model
145
145
virtual std::vector<float > ToFeatures (const ProblemDescription& problem) const = 0;
146
146
};
147
147
148
- class Gfx908Model : public Model
148
+ class Gfx908Model final : public Model
149
149
{
150
150
public:
151
151
Gfx908Model () : Model(" gfx908" ) {}
@@ -255,7 +255,106 @@ class Gfx908Model : public Model
255
255
}
256
256
};
257
257
258
- std::unique_ptr<Model> GetModel (const std::string&) { return std::make_unique<Gfx908Model>(); }
258
+ class Gfx90aModel final : public Model
259
+ {
260
+ public:
261
+ Gfx90aModel () : Model(" gfx90a" ) {}
262
+ bool IsProblemSupported (const ProblemDescription& problem,
263
+ const ExecutionContext& ctx) const override
264
+ {
265
+ // check if problem is of the kind TunaNet was trained to handle
266
+ if (!problem.Is2d ())
267
+ {
268
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: Problem not 2D" );
269
+ return false ;
270
+ }
271
+ if (problem.GetInLayout () != " NCHW" )
272
+ {
273
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: Layout not supported" );
274
+ return false ;
275
+ }
276
+ if (problem.GetKernelStrideH () != problem.GetKernelStrideW ())
277
+ {
278
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: Stride must be equal along all axes" );
279
+ return false ;
280
+ }
281
+ if (problem.GetDilationH () != problem.GetDilationW ())
282
+ {
283
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: Dilation must be 1" );
284
+ return false ;
285
+ }
286
+ if (problem.GetBias () != 0 )
287
+ {
288
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: Bias must be 0" );
289
+ return false ;
290
+ }
291
+ const auto data_type = problem.GetInDataType ();
292
+ if (data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16)
293
+ {
294
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: Unsupported data type" );
295
+ return false ;
296
+ }
297
+
298
+ // check if the context is s.t. no solver TunaNet may predict would be applicable
299
+ size_t applicable_solvers = 0 ;
300
+ for (const auto & solver_name : metadata.solver_map )
301
+ {
302
+ auto solver_id = solver::Id{solver_name.second };
303
+ auto solver = solver_id.GetSolver ();
304
+ if (solver.IsApplicable (ctx, problem))
305
+ {
306
+ applicable_solvers++;
307
+ break ;
308
+ }
309
+ }
310
+ if (applicable_solvers == 0 )
311
+ {
312
+ MIOPEN_LOG_I2 (" TunaNet Inapplicable: No solver that TunaNet may predict applies" );
313
+ return false ;
314
+ }
315
+ MIOPEN_LOG_I2 (" TunaNet Applicable" );
316
+ return true ;
317
+ }
318
+
319
+ protected:
320
+ std::vector<float > ToFeatures (const ProblemDescription& problem) const override
321
+ {
322
+ const bool isFwd = problem.GetDirection () == conv::Direction::Forward;
323
+ std::vector<float > features = {
324
+ static_cast <float >(isFwd ? problem.GetInChannels_ () : problem.GetOutChannels_ ()),
325
+ static_cast <float >(isFwd ? problem.GetInHeight_ () : problem.GetOutHeight_ ()),
326
+ static_cast <float >(isFwd ? problem.GetInWidth_ () : problem.GetOutWidth_ ()),
327
+ static_cast <float >(isFwd ? problem.GetOutChannels_ () : problem.GetInChannels_ ()),
328
+ static_cast <float >(isFwd ? problem.GetOutHeight_ () : problem.GetInHeight_ ()),
329
+ static_cast <float >(isFwd ? problem.GetOutWidth_ () : problem.GetInWidth_ ()),
330
+ static_cast <float >(problem.GetWeightsHeight_ ()),
331
+ static_cast <float >(problem.GetWeightsWidth_ ()),
332
+ static_cast <float >(problem.GetPadH ()),
333
+ static_cast <float >(problem.GetPadW ()),
334
+ static_cast <float >(problem.GetKernelStrideH ()),
335
+ static_cast <float >(problem.GetKernelStrideW ()),
336
+ static_cast <float >(problem.GetDilationH ()),
337
+ static_cast <float >(problem.GetDilationW ()),
338
+ static_cast <float >(problem.GetOutBatchSize_ ()),
339
+ static_cast <float >(metadata.EncodePrecision (problem.GetInDataType ())),
340
+ static_cast <float >(metadata.EncodeDirection (problem.GetDirection ())),
341
+ static_cast <float >(problem.GetGroupCount ())};
342
+
343
+ // normalize
344
+ for (size_t i = 0 ; i < features.size (); ++i)
345
+ features[i] = (features[i] - metadata.features_mean [i]) / metadata.features_std [i];
346
+
347
+ return features;
348
+ }
349
+ };
350
+
351
+ std::unique_ptr<Model> GetModel (const std::string& device)
352
+ {
353
+ if (device == " gfx90a" )
354
+ return std::make_unique<Gfx90aModel>();
355
+ else
356
+ return std::make_unique<Gfx908Model>();
357
+ }
259
358
260
359
std::vector<uint64_t > PredictSolver (const ProblemDescription& problem,
261
360
const ExecutionContext& ctx,
@@ -270,7 +369,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
270
369
auto db_res = db.FindRecord (static_cast <const conv::ProblemDescription&>(problem));
271
370
if (db_res)
272
371
{
273
- MIOPEN_LOG_I2 (" Cached heuristic result found" );
372
+ MIOPEN_LOG_I2 (" Cached heuristic (TunaNet) result found" );
274
373
std::vector<uint64_t > db_sol (db_res->size ());
275
374
// cast returned record to solver ids
276
375
std::transform (db_res->begin (), db_res->end (), db_sol.begin (), [](boost::any id) {
@@ -286,7 +385,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
286
385
return db_sol;
287
386
}
288
387
289
- MIOPEN_LOG_I2 (" Evaluating Heuristic " );
388
+ MIOPEN_LOG_I2 (" Evaluating TunaNet " );
290
389
291
390
std::vector<float > res = model->Forward (problem);
292
391
std::vector<std::pair<int , float >> sort_res (res.size ());
@@ -322,7 +421,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
322
421
std::stringstream ss;
323
422
for (auto & id : sol)
324
423
ss << solver::Id{id}.ToString () << " ID:" << id << " , " ;
325
- MIOPEN_LOG_I2 (" Heuristic Result: " << ss.str ());
424
+ MIOPEN_LOG_I2 (" TunaNet Result: " << ss.str ());
326
425
}
327
426
return sol;
328
427
}
0 commit comments