99
1010#include " common/JIT.h"
1111#include " common/ThunkInterface.h"
12+ #include < optional>
1213#include < string>
1314#include < vector>
1415
16+ // This header file and the types defined within are designed to have no
17+ // dependencies and be useable across the compiler and runtime. However,
18+ // constructing instances of these types is easiest done within compilation
19+ // units that do link against MLIR. We provide this functionality via free
20+ // functions, defined as friends of the types defined here and implemented in
21+ // the `cudaq-mlir-runtime` library.
22+
23+ namespace mlir {
24+ class Type ;
25+ class ModuleOp ;
26+ } // namespace mlir
27+
1528namespace cudaq {
1629
30+ // / Pre-computed result metadata, set at build time. Used at execution time
31+ // / for result buffer allocation and type conversion. Construct via
32+ // / `createResultInfo` (implemented in `cudaq-mlir-runtime`).
33+ class ResultInfo {
34+ // Friend factory function, to be used for construction.
35+ friend ResultInfo createResultInfo (mlir::Type resultType, bool isEntryPoint,
36+ mlir::ModuleOp module );
37+ friend class CompiledKernel ;
38+
39+ // / Opaque pointer to the `mlir::Type` of the result. Obtained via
40+ // / `mlir::Type::getAsOpaquePointer()`.
41+ // / Lifetime: the `MLIRContext` that owns the Type must outlive this object.
42+ const void *typeOpaquePtr = nullptr ;
43+
44+ // / Size (in bytes) of the buffer needed to hold the result value.
45+ // / Pre-computed from the MLIR type at build time.
46+ std::size_t bufferSize = 0 ;
47+
48+ // / Pre-computed struct field offsets (from `getTargetLayout`). Only non-empty
49+ // / for struct return types.
50+ std::vector<std::size_t > fieldOffsets;
51+
52+ public:
53+ // / Whether this kernel has a result that must be marshaled.
54+ bool hasResult () const { return typeOpaquePtr != nullptr ; }
55+ };
56+
1757// / @brief A compiled, ready-to-execute kernel.
1858// /
19- // / This type does not have a dependency on MLIR (or LLVM) as it only keeps
20- // / type-erased pointers to JIT-related types .
59+ // / Bundles one or more representations of a compiled kernel (JIT binary, MLIR
60+ // / module) along with metadata needed for execution and result extraction .
2161// /
22- // / The constructor is private; use the factory function in
23- // / `runtime/common/JIT.h` to construct instances.
62+ // / This type does not have a dependency on MLIR (or LLVM) as it only keeps
63+ // / type-erased / opaque pointers. Use `attachJit` (defined in
64+ // / `cudaq-mlir-runtime`) to attach a compiled JIT representation after
65+ // / construction.
2466class CompiledKernel {
2567public:
26- // / @brief Execute the JIT-ed kernel.
27- // /
28- // / If the kernel has a return type, the caller must have appended a result
29- // / buffer as the last element of \p rawArgs.
68+ // --- Construction ---
69+
70+ CompiledKernel (std::string kernelName, ResultInfo resultInfo);
71+
72+ // --- Queries ---
73+
74+ bool hasJit () const { return jitRepr.has_value (); }
75+ bool hasMlir () const { return mlirRepr.has_value (); }
76+
77+ // / Whether the kernel is fully specialized (all arguments inlined). For JIT
78+ // / kernels this means `argsCreator` is null.
79+ // / Currently, MLIR-only kernels are always considered fully specialized.
80+ bool isFullySpecialized () const {
81+ return !jitRepr || jitRepr->argsCreator == nullptr ;
82+ }
83+
84+ const std::string &getName () const { return name; }
85+
86+ // --- Execution (local JIT path) ---
87+
88+ // / @brief Execute a fully specialized kernel (no external arguments needed).
89+ KernelThunkResultType execute () const ;
90+
91+ // / @brief Execute the JIT-ed kernel with caller-provided arguments.
3092 KernelThunkResultType execute (const std::vector<void *> &rawArgs) const ;
3193
32- // TODO: remove the following two methods once the CompiledKernel is returned
33- // to Python.
94+ // TODO: remove the following two methods once the `CompiledKernel` is
95+ // returned to Python.
96+
3497 // / @brief Get the entry point of the kernel as a function pointer.
3598 // /
3699 // / The returned function pointer will expect different arguments depending
@@ -46,31 +109,53 @@ class CompiledKernel {
46109 JitEngine getEngine () const ;
47110
48111private:
49- CompiledKernel (JitEngine engine, std::string kernelName, void (*entryPoint)(),
50- int64_t (*argsCreator)(const void *, void **), bool hasResult);
112+ // Friend functions to attach compiled representations after construction.
113+ friend void attachJit (CompiledKernel &ck, JitEngine engine,
114+ bool isFullySpecialized);
51115
52- // Use the following factory function (compiled into cudaq-mlir-runtime) to
53- // construct CompiledKernels.
54- friend CompiledKernel createCompiledKernel (JitEngine engine,
55- std::string kernelName,
56- bool hasResult,
57- bool isFullySpecialized);
116+ // --- Compiled representation formats ---
58117
59- JitEngine engine;
60- std::string name;
118+ // / JIT-compiled representation of a kernel, used for local execution.
119+ struct JitRepr {
120+ JitEngine engine;
121+ void (*entryPoint)() = nullptr ;
122+ int64_t (*argsCreator)(const void *, void **) = nullptr ;
123+ };
61124
62- // Function pointers into JITEngine
63- void (*entryPoint)();
64- int64_t (*argsCreator)(const void *, void **);
125+ // / MLIR module representation for remote code generation or re-targeting.
126+ // / The opaque pointer is obtained via `ModuleOp::getAsOpaquePointer()`.
127+ // / Lifetime: the `MLIRContext` that owns the module must outlive this object.
128+ struct MlirRepr {
129+ const void *modulePtr = nullptr ;
130+ };
65131
66- bool hasResult;
132+ const JitRepr &getJit () const ;
133+ const MlirRepr &getMlir () const ;
134+
135+ std::string name;
136+ ResultInfo resultInfo;
137+ std::optional<JitRepr> jitRepr;
138+ std::optional<MlirRepr> mlirRepr;
67139};
68140
69- // / @brief Create a CompiledKernel from JIT-compiled code .
141+ // / @brief Populate the JIT representation of a `CompiledKernel` .
70142// /
71- // / `hasResult` and `isFullySpecialized` affect how the mangled kernel name
72- // / and the arguments buffer passed to the compiled kernel are constructed.
73- // / See `CompiledKernel::getEntryPoint` for more details.
74- CompiledKernel createCompiledKernel (JitEngine engine, std::string kernelName,
75- bool hasResult, bool isFullySpecialized);
143+ // / Resolves the entry point and (optionally) `argsCreator` symbols from the
144+ // / engine, using the kernel's name and result metadata to determine the
145+ // / correct mangled symbol names.
146+ // /
147+ // / Implemented in `JIT.cpp` (requires MLIR linkage).
148+ void attachJit (CompiledKernel &ck, JitEngine engine, bool isFullySpecialized);
149+
150+ // / @brief Create a `ResultInfo` from opaque MLIR type and module pointers.
151+ // /
152+ // / `resultTypePtr` is obtained via `mlir::Type::getAsOpaquePointer()` (may be
153+ // / null for void-returning kernels). `modulePtr` is obtained via
154+ // / `ModuleOp::getAsOpaquePointer()`. When `resultTypePtr` is null or
155+ // / `isEntryPoint` is false, returns an empty `ResultInfo`.
156+ // /
157+ // / Implemented in `JIT.cpp` (requires MLIR linkage).
158+ ResultInfo createResultInfo (mlir::Type resultType, bool isEntryPoint,
159+ mlir::ModuleOp module );
160+
76161} // namespace cudaq
0 commit comments