Skip to content

Commit 3d7233e

Browse files
committed
Extend CompiledKernel to fat binary with JIT/MLIR support
1 parent 5372c6b commit 3d7233e

File tree

5 files changed

+201
-54
lines changed

5 files changed

+201
-54
lines changed

runtime/common/CompiledKernel.cpp

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,36 @@
77
******************************************************************************/
88

99
#include "CompiledKernel.h"
10+
#include <memory>
11+
#include <stdexcept>
1012

11-
cudaq::CompiledKernel::CompiledKernel(
12-
JitEngine engine, std::string kernelName, void (*entryPoint)(),
13-
int64_t (*argsCreator)(const void *, void **), bool hasResult)
14-
: engine(engine), name(std::move(kernelName)), entryPoint(entryPoint),
15-
argsCreator(argsCreator), hasResult(hasResult) {}
13+
cudaq::CompiledKernel::CompiledKernel(std::string kernelName,
14+
ResultInfo resultInfo)
15+
: name(std::move(kernelName)), resultInfo(std::move(resultInfo)) {}
16+
17+
const cudaq::CompiledKernel::JitRepr &cudaq::CompiledKernel::getJit() const {
18+
if (!jitRepr)
19+
throw std::runtime_error("CompiledKernel has no JIT representation.");
20+
return *jitRepr;
21+
}
22+
23+
const cudaq::CompiledKernel::MlirRepr &cudaq::CompiledKernel::getMlir() const {
24+
if (!mlirRepr)
25+
throw std::runtime_error("CompiledKernel has no MLIR representation.");
26+
return *mlirRepr;
27+
}
1628

1729
cudaq::KernelThunkResultType
1830
cudaq::CompiledKernel::execute(const std::vector<void *> &rawArgs) const {
1931
auto funcPtr = getEntryPoint();
20-
if (hasResult) {
32+
if (resultInfo.hasResult()) {
2133
void *buff = const_cast<void *>(rawArgs.back());
2234
return reinterpret_cast<KernelThunkResultType (*)(void *, bool)>(funcPtr)(
2335
buff, /*client_server=*/false);
2436
}
25-
if (argsCreator) {
37+
if (jitRepr && jitRepr->argsCreator) {
2638
void *buff = nullptr;
27-
argsCreator(static_cast<const void *>(rawArgs.data()), &buff);
39+
jitRepr->argsCreator(static_cast<const void *>(rawArgs.data()), &buff);
2840
reinterpret_cast<KernelThunkResultType (*)(void *, bool)>(funcPtr)(
2941
buff, /*client_server=*/false);
3042
std::free(buff);
@@ -35,6 +47,25 @@ cudaq::CompiledKernel::execute(const std::vector<void *> &rawArgs) const {
3547
return {nullptr, 0};
3648
}
3749

38-
void (*cudaq::CompiledKernel::getEntryPoint() const)() { return entryPoint; }
50+
cudaq::KernelThunkResultType cudaq::CompiledKernel::execute() const {
51+
if (jitRepr && jitRepr->argsCreator)
52+
throw std::runtime_error(
53+
"Kernel has unspecialized parameters; call execute(rawArgs) instead.");
54+
if (!resultInfo.hasResult()) {
55+
getEntryPoint()();
56+
return {nullptr, 0};
57+
}
58+
// Allocate a result buffer on-the-fly.
59+
auto buf = std::make_unique<char[]>(resultInfo.bufferSize);
60+
std::vector<void *> rawArgs = {buf.get()};
61+
execute(rawArgs);
62+
return {buf.release(), resultInfo.bufferSize};
63+
}
64+
65+
void (*cudaq::CompiledKernel::getEntryPoint() const)() {
66+
return getJit().entryPoint;
67+
}
3968

40-
cudaq::JitEngine cudaq::CompiledKernel::getEngine() const { return engine; }
69+
cudaq::JitEngine cudaq::CompiledKernel::getEngine() const {
70+
return getJit().engine;
71+
}

runtime/common/CompiledKernel.h

Lines changed: 115 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,91 @@
99

1010
#include "common/JIT.h"
1111
#include "common/ThunkInterface.h"
12+
#include <optional>
1213
#include <string>
1314
#include <vector>
1415

16+
// This header file and the types defined within are designed to have no
17+
// MLIR dependency and be usable across the compiler and runtime. However,
18+
// constructing instances of these types is easiest done within compilation
19+
// units that do link against MLIR. We provide this functionality via free
20+
// functions, defined as friends of the types defined here and implemented in
21+
// the `cudaq-mlir-runtime` library.
22+
23+
namespace mlir {
24+
class Type;
25+
class ModuleOp;
26+
} // namespace mlir
27+
1528
namespace cudaq {
1629

30+
/// Pre-computed result metadata, set at build time. Used at execution time
31+
/// for result buffer allocation and type conversion. Construct via
32+
/// `createResultInfo` (implemented in `cudaq-mlir-runtime`).
33+
class ResultInfo {
34+
// Friend factory function, to be used for construction.
35+
friend ResultInfo createResultInfo(mlir::Type resultType, bool isEntryPoint,
36+
mlir::ModuleOp module);
37+
friend class CompiledKernel;
38+
39+
/// Opaque pointer to the `mlir::Type` of the result. Obtained via
40+
/// `mlir::Type::getAsOpaquePointer()`.
41+
/// Lifetime: the `MLIRContext` that owns the Type must outlive this object.
42+
const void *typeOpaquePtr = nullptr;
43+
44+
/// Size (in bytes) of the buffer needed to hold the result value.
45+
/// Pre-computed from the MLIR type at build time.
46+
std::size_t bufferSize = 0;
47+
48+
/// Pre-computed struct field offsets (from `getTargetLayout`). Only non-empty
49+
/// for struct return types.
50+
std::vector<std::size_t> fieldOffsets;
51+
52+
public:
53+
/// Whether this kernel has a result that must be marshaled.
54+
bool hasResult() const { return typeOpaquePtr != nullptr; }
55+
};
56+
1757
/// @brief A compiled, ready-to-execute kernel.
1858
///
19-
/// This type does not have a dependency on MLIR (or LLVM) as it only keeps
20-
/// type-erased pointers to JIT-related types.
59+
/// Bundles one or more representations of a compiled kernel (JIT binary, MLIR
60+
/// module) along with metadata needed for execution and result extraction.
2161
///
22-
/// The constructor is private; use the factory function in
23-
/// `runtime/common/JIT.h` to construct instances.
62+
/// This type does not have a dependency on MLIR (or LLVM) as it only keeps
63+
/// type-erased / opaque pointers. Use `attachJit` (defined in
64+
/// `cudaq-mlir-runtime`) to attach a compiled JIT representation after
65+
/// construction.
2466
class CompiledKernel {
2567
public:
26-
/// @brief Execute the JIT-ed kernel.
27-
///
28-
/// If the kernel has a return type, the caller must have appended a result
29-
/// buffer as the last element of \p rawArgs.
68+
// --- Construction ---
69+
70+
CompiledKernel(std::string kernelName, ResultInfo resultInfo);
71+
72+
// --- Queries ---
73+
74+
bool hasJit() const { return jitRepr.has_value(); }
75+
bool hasMlir() const { return mlirRepr.has_value(); }
76+
77+
/// Whether the kernel is fully specialized (all arguments inlined). For JIT
78+
/// kernels this means `argsCreator` is null.
79+
/// Currently, MLIR-only kernels are always considered fully specialized.
80+
bool isFullySpecialized() const {
81+
return !jitRepr || jitRepr->argsCreator == nullptr;
82+
}
83+
84+
const std::string &getName() const { return name; }
85+
86+
// --- Execution (local JIT path) ---
87+
88+
/// @brief Execute a fully specialized kernel (no external arguments needed).
89+
KernelThunkResultType execute() const;
90+
91+
/// @brief Execute the JIT-ed kernel with caller-provided arguments.
3092
KernelThunkResultType execute(const std::vector<void *> &rawArgs) const;
3193

32-
// TODO: remove the following two methods once the CompiledKernel is returned
33-
// to Python.
94+
// TODO: remove the following two methods once the `CompiledKernel` is
95+
// returned to Python.
96+
3497
/// @brief Get the entry point of the kernel as a function pointer.
3598
///
3699
/// The returned function pointer will expect different arguments depending
@@ -46,31 +109,53 @@ class CompiledKernel {
46109
JitEngine getEngine() const;
47110

48111
private:
49-
CompiledKernel(JitEngine engine, std::string kernelName, void (*entryPoint)(),
50-
int64_t (*argsCreator)(const void *, void **), bool hasResult);
112+
// Friend functions to attach compiled representations after construction.
113+
friend void attachJit(CompiledKernel &ck, JitEngine engine,
114+
bool isFullySpecialized);
51115

52-
// Use the following factory function (compiled into cudaq-mlir-runtime) to
53-
// construct CompiledKernels.
54-
friend CompiledKernel createCompiledKernel(JitEngine engine,
55-
std::string kernelName,
56-
bool hasResult,
57-
bool isFullySpecialized);
116+
// --- Compiled representation formats ---
58117

59-
JitEngine engine;
60-
std::string name;
118+
/// JIT-compiled representation of a kernel, used for local execution.
119+
struct JitRepr {
120+
JitEngine engine;
121+
void (*entryPoint)() = nullptr;
122+
int64_t (*argsCreator)(const void *, void **) = nullptr;
123+
};
61124

62-
// Function pointers into JITEngine
63-
void (*entryPoint)();
64-
int64_t (*argsCreator)(const void *, void **);
125+
/// MLIR module representation for remote code generation or re-targeting.
126+
/// The opaque pointer is obtained via `ModuleOp::getAsOpaquePointer()`.
127+
/// Lifetime: the `MLIRContext` that owns the module must outlive this object.
128+
struct MlirRepr {
129+
const void *modulePtr = nullptr;
130+
};
65131

66-
bool hasResult;
132+
const JitRepr &getJit() const;
133+
const MlirRepr &getMlir() const;
134+
135+
std::string name;
136+
ResultInfo resultInfo;
137+
std::optional<JitRepr> jitRepr;
138+
std::optional<MlirRepr> mlirRepr;
67139
};
68140

69-
/// @brief Create a CompiledKernel from JIT-compiled code.
141+
/// @brief Populate the JIT representation of a `CompiledKernel`.
70142
///
71-
/// `hasResult` and `isFullySpecialized` affect how the mangled kernel name
72-
/// and the arguments buffer passed to the compiled kernel are constructed.
73-
/// See `CompiledKernel::getEntryPoint` for more details.
74-
CompiledKernel createCompiledKernel(JitEngine engine, std::string kernelName,
75-
bool hasResult, bool isFullySpecialized);
143+
/// Resolves the entry point and (optionally) `argsCreator` symbols from the
144+
/// engine, using the kernel's name and result metadata to determine the
145+
/// correct mangled symbol names.
146+
///
147+
/// Implemented in `JIT.cpp` (requires MLIR linkage).
148+
void attachJit(CompiledKernel &ck, JitEngine engine, bool isFullySpecialized);
149+
150+
/// @brief Create a `ResultInfo` from an MLIR result type and its module.
151+
///
152+
/// `resultType` is the MLIR return type of the kernel (it may be null for
153+
/// void-returning kernels). `module` is the `ModuleOp` used to compute the
154+
/// result buffer layout. When `resultType` is null or `isEntryPoint` is
155+
/// false, returns an empty `ResultInfo`.
156+
///
157+
/// Implemented in `JIT.cpp` (requires MLIR linkage).
158+
ResultInfo createResultInfo(mlir::Type resultType, bool isEntryPoint,
159+
mlir::ModuleOp module);
160+
76161
} // namespace cudaq

runtime/common/JIT.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "JIT.h"
1010
#include "CompiledKernel.h"
1111
#include "common/Environment.h"
12+
#include "common/LayoutInfo.h"
1213
#include "common/Timing.h"
1314
#include "cudaq/Frontend/nvqpp/AttributeNames.h"
1415
#include "cudaq/Optimizer/Builder/Runtime.h"
@@ -333,20 +334,37 @@ cudaq::JitEngine cudaq::createQIRJITEngine(ModuleOp &moduleOp,
333334
return JitEngine(std::move(jitOrError.get()));
334335
}
335336

336-
cudaq::CompiledKernel cudaq::createCompiledKernel(JitEngine engine,
337-
std::string kernelName,
338-
bool hasResult,
339-
bool isFullySpecialized) {
340-
std::string fullName = cudaq::runtime::cudaqGenPrefixName + kernelName;
337+
void cudaq::attachJit(CompiledKernel &ck, JitEngine engine,
338+
bool isFullySpecialized) {
339+
const auto &name = ck.name;
340+
bool hasResult = ck.resultInfo.hasResult();
341+
std::string fullName = cudaq::runtime::cudaqGenPrefixName + name;
341342
std::string entryName =
342-
(hasResult || !isFullySpecialized) ? kernelName + ".thunk" : fullName;
343+
(hasResult || !isFullySpecialized) ? name + ".thunk" : fullName;
343344
void (*entryPoint)() = engine.lookupRawNameOrFail(entryName);
344345
int64_t (*argsCreator)(const void *, void **) = nullptr;
345346
if (!isFullySpecialized)
346347
argsCreator = reinterpret_cast<int64_t (*)(const void *, void **)>(
347-
engine.lookupRawNameOrFail(kernelName + ".argsCreator"));
348-
return cudaq::CompiledKernel(engine, std::move(kernelName), entryPoint,
349-
argsCreator, hasResult);
348+
engine.lookupRawNameOrFail(name + ".argsCreator"));
349+
350+
ck.jitRepr =
351+
CompiledKernel::JitRepr{std::move(engine), entryPoint, argsCreator};
352+
}
353+
354+
/// Build a `cudaq::ResultInfo` from an MLIR return type.
355+
/// \p resultTy may be null (no return value). When \p isEntryPoint is false,
356+
/// the result is not marshaled — returns an empty `ResultInfo`.
357+
cudaq::ResultInfo cudaq::createResultInfo(Type resultTy, bool isEntryPoint,
358+
ModuleOp module) {
359+
cudaq::ResultInfo info;
360+
if (!resultTy || !isEntryPoint)
361+
return info;
362+
363+
info.typeOpaquePtr = resultTy.getAsOpaquePointer();
364+
auto [size, offsets] = cudaq::getResultBufferLayout(module, resultTy);
365+
info.bufferSize = size;
366+
info.fieldOffsets = std::move(offsets);
367+
return info;
350368
}
351369

352370
class cudaq::JitEngine::Impl {

runtime/common/JIT.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,6 @@ class JitEngine {
5656
JitEngine createQIRJITEngine(mlir::ModuleOp &moduleOp,
5757
llvm::StringRef convertTo);
5858

59+
class CompiledKernel;
60+
5961
} // namespace cudaq

runtime/cudaq/platform/default/python/QPU.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "QPU.h"
1010
#include "common/ArgumentConversion.h"
1111
#include "common/ArgumentWrapper.h"
12+
#include "common/CompiledKernel.h"
1213
#include "common/Environment.h"
1314
#include "common/ExecutionContext.h"
1415
#include "common/JIT.h"
@@ -17,6 +18,7 @@
1718
#include "cudaq/Optimizer/Builder/Runtime.h"
1819
#include "cudaq/Optimizer/CodeGen/OpenQASMEmitter.h"
1920
#include "cudaq/Optimizer/CodeGen/Passes.h"
21+
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
2022
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
2123
#include "cudaq/Optimizer/Transforms/AddMetadata.h"
2224
#include "cudaq/Optimizer/Transforms/Passes.h"
@@ -29,6 +31,13 @@
2931

3032
using namespace mlir;
3133

34+
// Forward declaration — defined in `py_alt_launch_kernel.cpp`, compiled into
35+
// the same Python extension binary.
36+
namespace cudaq {
37+
std::pair<std::size_t, std::vector<std::size_t>>
38+
getResultBufferLayout(mlir::ModuleOp mod, mlir::Type resultTy);
39+
}
40+
3241
static void
3342
specializeKernel(const std::string &name, ModuleOp module,
3443
const std::vector<void *> &rawArgs, Type resultTy = {},
@@ -259,11 +268,12 @@ struct PythonLauncher : public cudaq::ModuleLauncher {
259268
varArgIndices.clear();
260269
}
261270
const bool isFullySpecialized = varArgIndices.empty();
262-
const bool hasResult = !!resultTy;
271+
auto resultInfo = cudaq::createResultInfo(resultTy, isEntryPoint, module);
263272

264273
if (auto jit = alreadyBuiltJITCode(name, rawArgs)) {
265-
return cudaq::createCompiledKernel(*jit, name, hasResult && isEntryPoint,
266-
isFullySpecialized);
274+
cudaq::CompiledKernel ck(name, resultInfo);
275+
cudaq::attachJit(ck, *jit, isFullySpecialized);
276+
return ck;
267277
}
268278

269279
// 1. Check that this call is sane.
@@ -297,8 +307,9 @@ struct PythonLauncher : public cudaq::ModuleLauncher {
297307
cudaq::compiler_artifact::saveArtifact(name, rawArgs, jit,
298308
argsCreatorThunk);
299309

300-
return cudaq::createCompiledKernel(jit, name, hasResult && isEntryPoint,
301-
isFullySpecialized);
310+
cudaq::CompiledKernel ck(name, resultInfo);
311+
cudaq::attachJit(ck, jit, isFullySpecialized);
312+
return ck;
302313
}
303314
};
304315
} // namespace

0 commit comments

Comments
 (0)