Skip to content

Commit 0ddb923

Browse files
Exposes base OptionsContext class in the API
In order to support more than just the StableHloToExecutable pipeline, we need to be able to create different option types from the API. This commit exposes the base `OptionsContext` class in the Python API and includes a mechanism for child classes to register themselves with the client, allowing them to be created through a common API.
1 parent c8ee99f commit 0ddb923

File tree

10 files changed

+215
-1
lines changed

10 files changed

+215
-1
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,29 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) {
5151
return !options.ptr;
5252
}
5353

54+
//===----------------------------------------------------------------------===//
55+
// MTRT_OptionsContext
56+
//===----------------------------------------------------------------------===//
57+
58+
/// Handle type for an options context exposed through the C API. The handle
/// carries a single untyped pointer; a handle whose `ptr` is null is treated
/// as the null handle (see `mtrtOptionsConextIsNull`).
typedef struct MTRT_OptionsContext {
59+
void *ptr;
60+
} MTRT_OptionsContext;
61+
62+
/// Creates an options context of the kind named by `optionsType` for `client`
/// from the `argc` argument strings pointed to by `argv`. On success the new
/// handle is written to `*options` and an OK status is returned; otherwise the
/// error status is returned. A handle obtained here is released with
/// `mtrtOptionsContextDestroy`.
MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs(
63+
MTRT_CompilerClient client, MTRT_OptionsContext *options,
64+
MlirStringRef optionsType, const MlirStringRef *argv, unsigned argc);
65+
66+
/// Prints the textual form of `options` by feeding it to the `append`
/// callback; `userData` is handed to the callback unchanged.
MLIR_CAPI_EXPORTED void mtrtOptionsContextPrint(MTRT_OptionsContext options,
67+
MlirStringCallback append,
68+
void *userData);
69+
70+
/// Destroys the options context referenced by `options` and returns an OK
/// status. The handle must not be used afterwards.
MLIR_CAPI_EXPORTED MTRT_Status
71+
mtrtOptionsContextDestroy(MTRT_OptionsContext options);
72+
73+
/// Returns true if `options` is the null handle, i.e. its `ptr` is null.
static inline bool mtrtOptionsContextIsNull(MTRT_OptionsContext options) {
  return !options.ptr;
}

/// Misspelled alias of `mtrtOptionsContextIsNull` ("Conext"). It is kept so
/// that existing callers of the original name continue to compile; new code
/// should use the correctly spelled function.
static inline bool mtrtOptionsConextIsNull(MTRT_OptionsContext options) {
  return mtrtOptionsContextIsNull(options);
}
76+
5477
//===----------------------------------------------------------------------===//
5578
// MTRT_StableHLOToExecutableOptions
5679
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@
2828
#define MLIR_TENSORRT_COMPILER_CLIENT
2929

3030
#include "mlir-executor/Support/Status.h"
31+
#include "mlir-tensorrt-dialect/Utils/Options.h"
3132
#include "mlir-tensorrt/Compiler/Options.h"
33+
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
3234
#include "mlir/IR/MLIRContext.h"
3335
#include "mlir/Pass/PassManager.h"
3436
#include "mlir/Support/TypeID.h"
3537
#include "llvm/ADT/DenseMap.h"
3638
#include "llvm/ADT/Hashing.h"
39+
#include "llvm/ADT/StringRef.h"
3740
#include "llvm/Support/ErrorHandling.h"
41+
#include <functional>
3842
#include <memory>
3943

4044
namespace mlirtrt::compiler {
@@ -85,6 +89,19 @@ class CompilationTask : public CompilationTaskBase {
8589
// CompilerClient
8690
//===----------------------------------------------------------------------===//
8791

92+
class CompilerClient;
93+
94+
/// Signature of a factory that builds an options object: it receives the
/// compiler client and a list of argument strings and returns either the
/// newly created `mlir::OptionsContext` or an error status.
using OptionsConstructorFuncT =
95+
std::function<StatusOr<std::unique_ptr<mlir::OptionsContext>>(
96+
const CompilerClient &client, const llvm::ArrayRef<llvm::StringRef>)>;
97+
98+
/// Registers `func` as the factory used to create options of the kind named
/// `optionsType`. Registering the same name again replaces the earlier
/// factory.
void registerOption(const llvm::StringRef optionsType,
99+
OptionsConstructorFuncT func);
100+
101+
/// Creates an options object of the kind named `optionsType` by invoking the
/// factory registered under that name with `client` and `args`. Returns an
/// invalid-argument status if no factory is registered for `optionsType`;
/// otherwise returns whatever the factory returns.
StatusOr<std::unique_ptr<mlir::OptionsContext>>
102+
createOptions(const CompilerClient &client, const llvm::StringRef optionsType,
103+
const llvm::ArrayRef<llvm::StringRef> args);
104+
88105
/// C++ users of the MLIR-TensorRT Compiler API should create a CompilerClient
89106
/// once for each process or thread that will be performing concurrent
90107
/// compilation work. The CompilerClient holds long-lived resources such as the
@@ -147,6 +164,39 @@ class CompilerClient {
147164
cachedPassManagers;
148165
};
149166

167+
/// Helper to register option types with the client
168+
template <typename OptionsT, typename TaskT>
169+
StatusOr<std::unique_ptr<mlir::OptionsContext>>
170+
optionsCreateFromArgs(const CompilerClient &client,
171+
const llvm::ArrayRef<llvm::StringRef> args) {
172+
// Load available extensions.
173+
mlir::MLIRContext *context = client.getContext();
174+
mlir::plan::PlanDialect *planDialect =
175+
context->getLoadedDialect<mlir::plan::PlanDialect>();
176+
compiler::TaskExtensionRegistry extensions =
177+
planDialect->extensionConstructors.getExtensionRegistryForTask<TaskT>();
178+
179+
auto result = std::make_unique<OptionsT>(std::move(extensions));
180+
181+
std::string err;
182+
if (failed(result->parse(args, err))) {
183+
return getInternalErrorStatus(
184+
"failed to parse options string \"{0:$[ ]}\" due to error {1}",
185+
llvm::iterator_range(args), err);
186+
}
187+
188+
// TODO: Figure out whether to add a method in the base class like
189+
// "finalizeOptions" or a callback here, or something else if
190+
// `inferDeviceOptionsFromHost` is unique to StableHLO.
191+
//
192+
// Populate device options from host information.
193+
Status inferStatus = result->inferDeviceOptionsFromHost();
194+
if (!inferStatus.isOk())
195+
return inferStatus;
196+
197+
return std::unique_ptr<mlir::OptionsContext>(result.release());
198+
}
199+
150200
} // namespace mlirtrt::compiler
151201

152202
#endif // MLIR_TENSORRT_COMPILER_CLIENT

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ class StableHloToExecutableTask
189189
const StableHLOToExecutableOptions &options);
190190
};
191191

192+
/// Registers the StableHLO-to-executable options type with the options
/// registry under the name "stable-hlo-to-executable", so that it can be
/// created by name through `createOptions`.
193+
void registerStableHloToExecutableTask();
194+
192195
//===----------------------------------------------------------------------===//
193196
// Pipeline Registrations
194197
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ inline void registerAllMlirTensorRtPasses() {
5353
mlir::registerConvertPDLToPDLInterp();
5454

5555
#ifdef MLIR_TRT_ENABLE_HLO
56+
mlirtrt::compiler::registerStableHloToExecutableTask();
5657
mlirtrt::compiler::registerStablehloClusteringPipelines();
5758
registerStableHloInputPipelines();
5859
stablehlo_ext::registerStableHloExtPasses();

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
#include "mlir-c/Support.h"
2727
#include "mlir-executor-c/Support/Status.h"
2828
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
29+
#include "mlir-tensorrt-dialect/Utils/Options.h"
2930
#include "mlir-tensorrt/Compiler/Extension.h"
3031
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
3132
#include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h"
3233
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
3334
#include "mlir/CAPI/IR.h"
35+
#include "mlir/CAPI/Utils.h"
3436
#include "llvm/ADT/StringExtras.h"
3537

3638
using namespace mlirtrt;
@@ -44,6 +46,7 @@ using namespace mlir;
4446
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
4547
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
4648
StableHLOToExecutableOptions)
49+
DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext)
4750
#if defined(__GNUC__) || defined(__clang__)
4851
#pragma GCC diagnostic pop
4952
#endif
@@ -99,6 +102,40 @@ MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) {
99102
return mtrtStatusGetOk();
100103
}
101104

105+
//===----------------------------------------------------------------------===//
106+
// MTRT_OptionsContext
107+
//===----------------------------------------------------------------------===//
108+
109+
/// C API entry point that creates an options context by type name.
///
/// Converts the `argc` strings in `argv` to `llvm::StringRef`s and forwards
/// them, together with `optionsType`, to `createOptions`. On success the
/// caller owns the object referenced by `*options` and must release it with
/// `mtrtOptionsContextDestroy`.
///
/// `*options` is reset to the null handle before any work is done, so it is
/// never left indeterminate when an error status is returned. An
/// invalid-argument status is returned if `options` is null, `client` is the
/// null handle, or `argv` is null while `argc` is non-zero.
MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs(
    MTRT_CompilerClient client, MTRT_OptionsContext *options,
    MlirStringRef optionsType, const MlirStringRef *argv, unsigned argc) {
  if (!options) {
    const mlirtrt::Status status =
        getInvalidArgStatus("options output pointer must not be null");
    return wrap(status);
  }
  *options = MTRT_OptionsContext{nullptr};

  if (mtrtCompilerClientIsNull(client)) {
    const mlirtrt::Status status =
        getInvalidArgStatus("compiler client must not be null");
    return wrap(status);
  }
  if (argc > 0 && !argv) {
    const mlirtrt::Status status =
        getInvalidArgStatus("argv must not be null when argc is non-zero");
    return wrap(status);
  }

  // The StringRefs only borrow the caller's buffers; they are consumed before
  // this function returns.
  std::vector<llvm::StringRef> argvStrRef(argc);
  for (unsigned i = 0; i < argc; i++)
    argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length);

  auto result = createOptions(
      *unwrap(client), llvm::StringRef(optionsType.data, optionsType.length),
      argvStrRef);
  if (!result.isOk())
    return wrap(result.getStatus());

  // Ownership moves from the unique_ptr to the C handle.
  *options = wrap(result->release());
  return mtrtStatusGetOk();
}
125+
126+
/// C API entry point that prints `options` through a string callback: the
/// options object writes into a stream that relays its output to `append`
/// along with `userData`.
MLIR_CAPI_EXPORTED void mtrtOptionsContextPrint(MTRT_OptionsContext options,
                                                MlirStringCallback append,
                                                void *userData) {
  mlir::detail::CallbackOstream callbackStream(append, userData);
  OptionsContext *context = unwrap(options);
  context->print(callbackStream);
}
132+
133+
/// C API entry point that deletes the options object behind `options` and
/// reports success.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtOptionsContextDestroy(MTRT_OptionsContext options) {
  OptionsContext *context = unwrap(options);
  delete context;
  return mtrtStatusGetOk();
}
138+
102139
//===----------------------------------------------------------------------===//
103140
// MTRT_StableHLOToExecutableOptions
104141
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Compiler/Client.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
#include "mlir-tensorrt/Compiler/Client.h"
2525
#include "mlir/IR/BuiltinOps.h"
2626
#include "mlir/Support/FileUtilities.h"
27+
#include "llvm/ADT/ArrayRef.h"
28+
#include "llvm/ADT/StringRef.h"
29+
#include "llvm/Support/ManagedStatic.h"
2730

2831
using namespace mlirtrt;
2932
using namespace mlirtrt::compiler;
@@ -46,6 +49,25 @@ CompilationTaskBase::~CompilationTaskBase() {}
4649
// CompilerClient
4750
//===----------------------------------------------------------------------===//
4851

52+
namespace mlirtrt::compiler {
53+
static llvm::ManagedStatic<llvm::StringMap<OptionsConstructorFuncT>> registry{};
54+
55+
void registerOption(const llvm::StringRef optionsType,
56+
OptionsConstructorFuncT func) {
57+
(*registry)[optionsType] = std::move(func);
58+
}
59+
60+
StatusOr<std::unique_ptr<mlir::OptionsContext>>
61+
createOptions(const CompilerClient &client, const llvm::StringRef optionsType,
62+
const llvm::ArrayRef<llvm::StringRef> args) {
63+
if (!registry->contains(optionsType))
64+
return getInvalidArgStatus(
65+
"{0} is not a valid option type. Valid options were: {1:$[ ]}",
66+
optionsType, llvm::iterator_range(registry->keys()));
67+
return (*registry)[optionsType](client, args);
68+
}
69+
} // namespace mlirtrt::compiler
70+
4971
StatusOr<std::unique_ptr<CompilerClient>>
5072
CompilerClient::create(MLIRContext *context) {
5173
context->disableMultithreading();

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,12 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
501501
return opts;
502502
}
503503

504+
void mlirtrt::compiler::registerStableHloToExecutableTask() {
505+
registerOption("stable-hlo-to-executable",
506+
optionsCreateFromArgs<StableHLOToExecutableOptions,
507+
StableHloToExecutableTask>);
508+
}
509+
504510
void mlirtrt::compiler::registerStablehloClusteringPipelines() {
505511
PassRegistration<HloToStdPass>();
506512
PassRegistration<HloToArithDynamicPipelinePass>();

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir-tensorrt-c/Compiler/Compiler.h"
1818
#include "mlir/Bindings/Python/PybindAdaptors.h"
1919
#include "pybind11/pybind11.h"
20+
#include "llvm/ADT/StringRef.h"
2021
#include "llvm/Support/DynamicLibrary.h"
2122
#include "llvm/Support/raw_ostream.h"
2223
#include <pybind11/attr.h>
@@ -52,6 +53,17 @@ class PyCompilerClient
5253
mtrtCompilerClientIsNull, mtrtCompilerClientDestroy};
5354
};
5455

56+
/// Python object type wrapper for `MTRT_OptionsContext`.
/// Built on the CRTP-style `PyMTRTWrapper` base, parameterised with this class
/// and the C handle type it wraps.
57+
class PyOptionsContext
58+
: public PyMTRTWrapper<PyOptionsContext, MTRT_OptionsContext> {
59+
public:
60+
// Reuse the base wrapper's constructors.
using PyMTRTWrapper::PyMTRTWrapper;
61+
DECLARE_WRAPPER_CONSTRUCTORS(PyOptionsContext);
62+
63+
// C API functions the wrapper base uses for this handle type: the null check
// and the destroy function.
static constexpr auto kMethodTable = CAPITable<MTRT_OptionsContext>{
64+
mtrtOptionsConextIsNull, mtrtOptionsContextDestroy};
65+
};
66+
5567
/// Python object type wrapper for `MTRT_StableHLOToExecutableOptions`.
5668
class PyStableHLOToExecutableOptions
5769
: public PyMTRTWrapper<PyStableHLOToExecutableOptions,
@@ -240,6 +252,36 @@ PYBIND11_MODULE(_api, m) {
240252
return new PyCompilerClient(client);
241253
}));
242254

255+
// Python class `OptionsContext`: constructed from a compiler client, an
// options type name and a list of argument strings; `repr()` renders the
// printed options wrapped in "Options[...]".
py::class_<PyOptionsContext>(m, "OptionsContext", py::module_local())
    .def(py::init<>([](PyCompilerClient &client,
                       const std::string &optionsType,
                       const std::vector<std::string> &args) {
           // Non-owning views over the Python-provided strings; they only
           // need to stay valid for the duration of the C API call below.
           std::vector<MlirStringRef> argRefs;
           argRefs.reserve(args.size());
           for (const std::string &arg : args)
             argRefs.push_back(mlirStringRefCreate(arg.data(), arg.size()));

           const MlirStringRef typeRef =
               mlirStringRefCreate(optionsType.data(), optionsType.size());

           MTRT_OptionsContext options;
           MTRT_Status s = mtrtOptionsContextCreateFromArgs(
               client, &options, typeRef, argRefs.data(), argRefs.size());
           THROW_IF_MTRT_ERROR(s);
           return new PyOptionsContext(options);
         }),
         py::arg("client"), py::arg("options_type"), py::arg("args"))

    .def("__repr__", [](PyOptionsContext &self) {
      // Callback handed to the C API: appends each printed chunk to the
      // std::string passed through the user-data pointer.
      auto appendChunk = [](MlirStringRef chunk, void *userData) {
        auto *buffer = static_cast<std::string *>(userData);
        buffer->append(chunk.data, chunk.length);
      };

      std::string repr("Options[");
      mtrtOptionsContextPrint(self, appendChunk, &repr);
      repr += "]";
      return repr;
    });
284+
243285
py::class_<PyStableHLOToExecutableOptions>(m, "StableHLOToExecutableOptions",
244286
py::module_local())
245287
.def(py::init<>([](PyCompilerClient &client,

mlir-tensorrt/python/bindings/Utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class PyMTRTWrapper {
163163

164164
static py::object createFromCapsule(py::object capsule) {
165165
if constexpr (cFuncTable.capsuleToCApi == nullptr) {
166-
throw py::value_error("boject cannot be converted from opaque capsule");
166+
throw py::value_error("object cannot be converted from opaque capsule");
167167
} else {
168168
MTRT_StableHLOToExecutableOptions cObj =
169169
cFuncTable.capsuleToCApi(capsule.ptr());
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# REQUIRES: host-has-at-least-1-gpus
2+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
3+
4+
import mlir_tensorrt.compiler.api as api
5+
from mlir_tensorrt.compiler.ir import *
6+
7+
8+
# Exercise the OptionsContext Python binding: an unknown options type must
# raise, and a known type must parse its arguments and be printable.
with Context() as context:
9+
client = api.CompilerClient(context)
10+
# Try to create a non-existent option type
11+
try:
12+
opts = api.OptionsContext(client, "non-existent-options-type", [])
13+
except Exception as err:
14+
# The error text is printed so that FileCheck can match it.
print(err)
15+
16+
# Create options for a registered type from command-line style flags.
opts = api.OptionsContext(
17+
client,
18+
"stable-hlo-to-executable",
19+
[
20+
"--tensorrt-builder-opt-level=3",
21+
"--tensorrt-strongly-typed=false",
22+
"--tensorrt-workspace-memory-pool-limit=1gb",
23+
],
24+
)
25+
26+
# Printing the object emits the option listing that FileCheck matches.
print(opts)
27+
28+
29+
# CHECK: InvalidArgument: InvalidArgument: non-existent-options-type is not a valid option type. Valid options were: stable-hlo-to-executable
30+
# CHECK: --tensorrt-timing-cache-path= --device-infer-from-host=true --debug-only= --executor-index-bitwidth=64 --entrypoint=main --plan-clustering-disallow-host-tensors-in-tensorrt-clusters=false --tensorrt-workspace-memory-pool-limit=1073741824 --device-max-registers-per-block=65536 --tensorrt-strongly-typed=false --tensorrt-layer-info-dir= --device-compute-capability=86 --debug=false --mlir-print-ir-tree-dir= --disable-tensorrt-extension=false --tensorrt-builder-opt-level=3 --tensorrt-engines-dir=

0 commit comments

Comments
 (0)