Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,14 @@ class IExecutionProvider {
return InlinedVector<const Node*>();
}

/**
* Returns a the underlying OrtEp instance if this IExecutionProvider wraps a plugin EP.
* Otherwise, returns a nullptr (default implementation).
*/
virtual const OrtEp* GetOrtEp() const {
return nullptr;
}

private:
const std::string type_;

Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2779,6 +2779,10 @@ struct KernelInfoImpl : Base<T> {
Logger GetLogger() const;

KeyValuePairs GetConfigEntries() const;

std::string GetOperatorType() const; ///< Wraps KernelInfo_GetOperatorType
int GetSinceVersion() const; ///< Wraps KernelInfo_GetSinceVersion
const OrtEp* GetEp() const; ///< Wraps KernelInfo_GetEp
};

} // namespace detail
Expand Down
21 changes: 21 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2842,6 +2842,27 @@ inline KeyValuePairs KernelInfoImpl<T>::GetConfigEntries() const {
return KeyValuePairs{out};
}

template <typename T>
inline std::string KernelInfoImpl<T>::GetOperatorType() const {
const char* op_type = nullptr;
Ort::ThrowOnError(GetEpApi().KernelInfo_GetOperatorType(this->p_, &op_type));
return std::string{op_type};
}

template <typename T>
inline int KernelInfoImpl<T>::GetSinceVersion() const {
int out = 0;
ThrowOnError(GetEpApi().KernelInfo_GetSinceVersion(this->p_, &out));
return out;
}

template <typename T>
inline const OrtEp* KernelInfoImpl<T>::GetEp() const {
const OrtEp* ep = nullptr;
ThrowOnError(GetEpApi().KernelInfo_GetEp(this->p_, &ep));
return ep;
}

inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
}
Expand Down
36 changes: 36 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,42 @@ struct OrtEpApi {
*/
ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);

/** \brief Get the graph node operator type from OrtKernelInfo.
*
* \note Used within OrtKernelImpl implementations to obtain operator information.
*
* \param[in] info An instance of ::OrtKernelInfo.
* \param[out] operator_type Output parameter set to the name of the node's operator type.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Outptr_ const char** operator_type);

/** \brief Get the opset version in which the given node's operator type was first defined from OrtKernelInfo.
*
* \note Used within OrtKernelImpl implementations to obtain operator information.
*
* \param[in] info The ::OrtKernelInfo instance.
* \param[out] since_version The opset version in which the node's operator type was first defined.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(KernelInfo_GetSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version);

/** \brief Get the OrtEp instance to which the node is assigned from the OrtKernelInfo.
*
* \note Used within OrtKernelImpl implementations to obtain a reference to the OrtEp.
*
* \param[in] info The ::OrtKernelInfo instance.
* \param[out] ep Output parameter set to the OrtEp instance associated with the OrtKernelInfo.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);
};

/**
Expand Down
56 changes: 56 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,59 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo*
API_IMPL_END
}

ORT_API_STATUS_IMPL(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Outptr_ const char** operator_type) {
API_IMPL_BEGIN
if (operator_type == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Must specify a non-null output parameter for the operator type");
}

auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
*operator_type = op_info->node().OpType().c_str();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(KernelInfo_GetSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version) {
API_IMPL_BEGIN
if (since_version == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Must specify a non-null output parameter for the operator type");
}

auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
*since_version = op_info->node().SinceVersion();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep) {
API_IMPL_BEGIN
if (ep == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Must specify a non-null output parameter in which to store the OrtEp instance");
}

auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
auto internal_ep = op_info->GetExecutionProvider();

if (internal_ep == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"OrtKernelInfo does not have a valid reference to an execution provider instance");
}

const OrtEp* ort_ep = internal_ep->GetOrtEp();

if (ort_ep == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"OrtKernelInfo is not associated with a plugin EP (OrtEp) instance.");
}

*ep = ort_ep;
return nullptr;
API_IMPL_END
}

static constexpr OrtEpApi ort_ep_api = {
// NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
// and no functions can be removed (the implementation needs to change to return an error).
Expand Down Expand Up @@ -636,6 +689,9 @@ static constexpr OrtEpApi ort_ep_api = {
&OrtExecutionProviderApi::KernelDef_GetOutputMemType,
&OrtExecutionProviderApi::GetTensorDataType,
&OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel,
&OrtExecutionProviderApi::KernelInfo_GetOperatorType,
&OrtExecutionProviderApi::KernelInfo_GetSinceVersion,
&OrtExecutionProviderApi::KernelInfo_GetEp,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,9 @@ ORT_API_STATUS_IMPL(GetTensorDataType, _In_ ONNXTensorElementDataType elem_type,
_Outptr_ const OrtDataType** out);
ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);

// KernelInfo
ORT_API_STATUS_IMPL(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, _Outptr_ const char** operator_type);
ORT_API_STATUS_IMPL(KernelInfo_GetSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version);
ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);
} // namespace OrtExecutionProviderApi
Original file line number Diff line number Diff line change
Expand Up @@ -765,4 +765,8 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std
return Status::OK();
}

const OrtEp* PluginExecutionProvider::GetOrtEp() const {
return ort_ep_.get();
}

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class PluginExecutionProvider : public IExecutionProvider {
Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
OrtCompiledModelCompatibility& model_compatibility) const override;

const OrtEp* GetOrtEp() const override;

private:
struct FusedNodeState {
FusedNodeState() = default;
Expand Down
Loading