|
12 | 12 | #include "common/ArgumentConversion.h" |
13 | 13 | #include "common/ArgumentWrapper.h" |
14 | 14 | #include "common/Environment.h" |
| 15 | +#include "common/LayoutInfo.h" |
15 | 16 | #include "cudaq/Optimizer/Builder/Marshal.h" |
16 | 17 | #include "cudaq/Optimizer/Builder/Runtime.h" |
17 | 18 | #include "cudaq/Optimizer/CAPI/Dialects.h" |
@@ -169,31 +170,6 @@ py::args cudaq::simplifiedValidateInputArguments(py::args &args) { |
169 | 170 | return processed; |
170 | 171 | } |
171 | 172 |
|
172 | | -std::pair<std::size_t, std::vector<std::size_t>> |
173 | | -cudaq::getTargetLayout(mlir::ModuleOp mod, cudaq::cc::StructType structTy) { |
174 | | - mlir::StringRef dataLayoutSpec = ""; |
175 | | - if (auto attr = mod->getAttr(cudaq::opt::factory::targetDataLayoutAttrName)) |
176 | | - dataLayoutSpec = mlir::cast<mlir::StringAttr>(attr); |
177 | | - else |
178 | | - throw std::runtime_error("No data layout attribute is set on the module."); |
179 | | - |
180 | | - auto dataLayout = llvm::DataLayout(dataLayoutSpec); |
181 | | - // Convert bufferTy to llvm. |
182 | | - llvm::LLVMContext context; |
183 | | - mlir::LLVMTypeConverter converter(structTy.getContext()); |
184 | | - cudaq::opt::initializeTypeConversions(converter); |
185 | | - auto llvmDialectTy = converter.convertType(structTy); |
186 | | - mlir::LLVM::TypeToLLVMIRTranslator translator(context); |
187 | | - auto *llvmStructTy = |
188 | | - mlir::cast<llvm::StructType>(translator.translateType(llvmDialectTy)); |
189 | | - auto *layout = dataLayout.getStructLayout(llvmStructTy); |
190 | | - auto strSize = layout->getSizeInBytes(); |
191 | | - std::vector<std::size_t> fieldOffsets; |
192 | | - for (std::size_t i = 0, I = structTy.getMembers().size(); i != I; ++i) |
193 | | - fieldOffsets.emplace_back(layout->getElementOffset(i)); |
194 | | - return {strSize, fieldOffsets}; |
195 | | -} |
196 | | - |
197 | 173 | void cudaq::handleStructMemberVariable(void *data, std::size_t offset, |
198 | 174 | mlir::Type memberType, |
199 | 175 | py::object value) { |
@@ -626,75 +602,11 @@ cudaq::OpaqueArguments *cudaq::toOpaqueArgs(py::args &args, MlirModule mod, |
626 | 602 | static void appendTheResultValue(ModuleOp module, const std::string &name, |
627 | 603 | cudaq::OpaqueArguments &runtimeArgs, |
628 | 604 | Type returnType) { |
629 | | - TypeSwitch<Type, void>(returnType) |
630 | | - .Case([&](IntegerType type) { |
631 | | - if (type.getIntOrFloatBitWidth() == 1) { |
632 | | - bool *ourAllocatedArg = new bool(); |
633 | | - *ourAllocatedArg = 0; |
634 | | - runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) { |
635 | | - delete static_cast<bool *>(ptr); |
636 | | - }); |
637 | | - return; |
638 | | - } |
639 | | - |
640 | | - long *ourAllocatedArg = new long(); |
641 | | - *ourAllocatedArg = 0; |
642 | | - runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) { |
643 | | - delete static_cast<long *>(ptr); |
644 | | - }); |
645 | | - }) |
646 | | - .Case([&](ComplexType type) { |
647 | | - Py_complex *ourAllocatedArg = new Py_complex(); |
648 | | - ourAllocatedArg->real = 0.0; |
649 | | - ourAllocatedArg->imag = 0.0; |
650 | | - runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) { |
651 | | - delete static_cast<Py_complex *>(ptr); |
652 | | - }); |
653 | | - }) |
654 | | - .Case([&](Float64Type type) { |
655 | | - double *ourAllocatedArg = new double(); |
656 | | - *ourAllocatedArg = 0.; |
657 | | - runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) { |
658 | | - delete static_cast<double *>(ptr); |
659 | | - }); |
660 | | - }) |
661 | | - .Case([&](Float32Type type) { |
662 | | - float *ourAllocatedArg = new float(); |
663 | | - *ourAllocatedArg = 0.; |
664 | | - runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) { |
665 | | - delete static_cast<float *>(ptr); |
666 | | - }); |
667 | | - }) |
668 | | - .Case([&](cudaq::cc::StdvecType ty) { |
669 | | - // Vector is a span: `{ data, length }`. |
670 | | - struct vec { |
671 | | - char *data; |
672 | | - std::size_t length; |
673 | | - }; |
674 | | - vec *ourAllocatedArg = new vec{nullptr, 0}; |
675 | | - runtimeArgs.emplace_back( |
676 | | - ourAllocatedArg, [](void *ptr) { delete static_cast<vec *>(ptr); }); |
677 | | - }) |
678 | | - .Case([&](cudaq::cc::StructType ty) { |
679 | | - auto [size, offsets] = cudaq::getTargetLayout(module, ty); |
680 | | - auto ourAllocatedArg = std::malloc(size); |
681 | | - runtimeArgs.emplace_back(ourAllocatedArg, |
682 | | - [](void *ptr) { std::free(ptr); }); |
683 | | - }) |
684 | | - .Case([&](cudaq::cc::CallableType ty) { |
685 | | - // Callables may not be returned from entry-point kernels. Append a |
686 | | - // dummy value as a placeholder. |
687 | | - runtimeArgs.emplace_back(nullptr, [](void *) {}); |
688 | | - }) |
689 | | - .Default([](Type ty) { |
690 | | - std::string msg; |
691 | | - { |
692 | | - llvm::raw_string_ostream os(msg); |
693 | | - ty.print(os); |
694 | | - } |
695 | | - throw std::runtime_error("Unsupported CUDA-Q kernel return type - " + |
696 | | - msg + ".\n"); |
697 | | - }); |
| 605 | + auto [bufferSize, offsets] = cudaq::getResultBufferLayout(module, returnType); |
| 606 | + if (bufferSize == 0) |
| 607 | + return; |
| 608 | + auto *buf = std::calloc(1, bufferSize); |
| 609 | + runtimeArgs.emplace_back(buf, [](void *ptr) { std::free(ptr); }); |
698 | 610 | } |
699 | 611 |
|
700 | 612 | // Launching the module \p mod will modify its content, such as by argument |
|
0 commit comments