@@ -41,8 +41,10 @@ typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs,
41
41
});
42
42
}
43
43
44
- template <typename DeviceOpType, typename CKArgsType>
45
- std::vector<std::string> FillValidKernelsIDs (const ProblemDescription& problem)
44
+ template <typename DeviceOpType,
45
+ typename CKArgsType,
46
+ typename ProblemDescriptionType = ProblemDescription>
47
+ std::vector<std::string> FillValidKernelsIDs (const ProblemDescriptionType& problem)
46
48
{
47
49
const auto args = CKArgsType{problem};
48
50
const auto conv_ptrs = DeviceOpType::GetInstances ();
@@ -59,29 +61,36 @@ std::vector<std::string> FillValidKernelsIDs(const ProblemDescription& problem)
59
61
return valid_kernels;
60
62
}
61
63
62
- template <typename DeviceOpType, typename CKArgsType>
63
- bool IsCKArgsSupported (const ProblemDescription& problem, const std::string& kernel_id)
64
+ template <typename DeviceOpType,
65
+ typename CKArgsType,
66
+ typename ProblemDescriptionType = ProblemDescription>
67
+ bool IsCKArgsSupported (const ProblemDescriptionType& problem, const std::string& kernel_id)
64
68
{
65
69
auto conv_ptrs = DeviceOpType::GetInstances ();
66
70
auto ptr_iter = FindConvPtrByID (conv_ptrs, kernel_id);
67
71
68
72
return (ptr_iter != conv_ptrs.end ()) && CKArgsType{problem}.IsSupportedBy (*ptr_iter);
69
73
}
70
74
71
- template <typename DeviceOpType, typename CKArgsType>
72
- bool IsCKApplicable (const ProblemDescription& problem)
75
+ template <typename DeviceOpType,
76
+ typename CKArgsType,
77
+ typename ProblemDescriptionType = ProblemDescription>
78
+ bool IsCKApplicable (const ProblemDescriptionType& problem)
73
79
{
74
80
const auto args = CKArgsType{problem};
75
- if (!std::all_of (args.strides .begin (), args.strides .end (), [](auto x) { return x == 1 ; }))
76
- return false ;
81
+ // if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; }))
82
+ // return false;
77
83
78
84
const auto ptrs = DeviceOpType::GetInstances ();
79
85
return std::any_of (
80
86
ptrs.begin (), ptrs.end (), [&args](auto & ptr) { return args.IsSupportedBy (ptr); });
81
87
}
82
88
83
- template <typename DeviceOpType, typename CKArgsType, typename CastType>
84
- ConvSolution InitInvokerFactory (const ProblemDescription& problem, const std::string& kernel_id)
89
+ template <typename DeviceOpType,
90
+ typename CKArgsType,
91
+ typename CastType,
92
+ typename ProblemDescriptionType = ProblemDescription>
93
+ ConvSolution InitInvokerFactory (const ProblemDescriptionType& problem, const std::string& kernel_id)
85
94
{
86
95
auto conv_ptrs = DeviceOpType::GetInstances ();
87
96
auto ptr_iter = FindConvPtrByID (conv_ptrs, kernel_id);
@@ -112,5 +121,41 @@ ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::st
112
121
return result;
113
122
}
114
123
124
+ template <typename DeviceOpType,
125
+ typename CKArgsType,
126
+ typename CastType,
127
+ typename ProblemDescriptionType = ProblemDescription>
128
+ ConvSolution InitAnyInvokerFactory (const ProblemDescriptionType& problem,
129
+ const std::string& kernel_id)
130
+ {
131
+ auto conv_ptrs = DeviceOpType::GetInstances ();
132
+ auto ptr_iter = FindConvPtrByID (conv_ptrs, kernel_id);
133
+
134
+ if (ptr_iter == conv_ptrs.end ())
135
+ return {miopenStatusInvalidValue};
136
+
137
+ ConvSolution result;
138
+ result.invoker_factory =
139
+ [ck_args = CKArgsType{problem},
140
+ sh_conv_ptr = std::shared_ptr{std::move (*ptr_iter)}](const std::vector<Kernel>&) mutable {
141
+ return [ck_args = std::move (ck_args), sh_conv_ptr = std::move (sh_conv_ptr)](
142
+ const Handle& handle, const AnyInvokeParams& primitive_parameters) {
143
+ const auto & data_ctx = primitive_parameters.CastTo <CastType>();
144
+ auto argument_ptr = ck_args.MakeArgPtr (sh_conv_ptr, data_ctx);
145
+ auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer ();
146
+
147
+ const auto enable_profiling = handle.IsProfilingEnabled ();
148
+ float elapsed_time =
149
+ invoker_ptr->Run (argument_ptr.get (), {handle.GetStream (), enable_profiling});
150
+ if (enable_profiling)
151
+ {
152
+ handle.ResetKernelTime ();
153
+ handle.AccumKernelTime (elapsed_time);
154
+ }
155
+ };
156
+ };
157
+ return result;
158
+ }
159
+
115
160
} // namespace solver
116
161
} // namespace miopen
0 commit comments