Skip to content

Commit 67c652a

Browse files
## [compiler] Drop the StableHLO "signature refinement" API
Previously, we added a special API for TriPy to be able to perform only pre-processing (including entrypoint signature type refinement) on the input module. This path is no longer needed by TriPy, and since there are no other customers, it can be safely dropped from the C/C++/Python APIs.

MR: initialdl/mlir-tensorrt!1722

## [executor/runtime] Update NCCL module to enable non-blocking communicator

Updates the NCCL runtime module so that the communicators are non-blocking and so that more consistent logic is used for handling errors. This helps resolve issues where the test may deadlock (either because of an incorrect runtime implementation or because of a system config issue) without errors being printed to stderr. Additional TODOs are noted where the implementation can be further improved.

GitOrigin-RevId: 5321de2a3d779500436c7a62097e0fc219958caf
1 parent a2542cf commit 67c652a

File tree

10 files changed

+186
-298
lines changed

10 files changed

+186
-298
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -108,38 +108,6 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable(
108108
MTRT_CompilerClient client, MlirOperation module,
109109
MTRT_StableHLOToExecutableOptions options, MTRT_Executable *result);
110110

111-
//===----------------------------------------------------------------------===//
112-
// MTRT_StableHLOProgramSignatureRefinementOptions
113-
//===----------------------------------------------------------------------===//
114-
115-
/// Options for compiling StableHLO MLIR to an Executable.
116-
typedef struct MTRT_StableHLOProgramSignatureRefinementOptions {
117-
void *ptr;
118-
} MTRT_StableHLOProgramSignatureRefinementOptions;
119-
120-
MLIR_CAPI_EXPORTED MTRT_Status
121-
mtrtStableHloProgramSignatureRefinementOptionsCreate(
122-
MTRT_StringView funcName,
123-
MTRT_StableHLOProgramSignatureRefinementOptions *options);
124-
125-
MLIR_CAPI_EXPORTED MTRT_Status
126-
mtrtStableHloProgramSignatureRefinementOptionsDestroy(
127-
MTRT_StableHLOProgramSignatureRefinementOptions options);
128-
129-
static inline bool mtrtStableHloProgramSignatureRefinementOptionsIsNull(
130-
MTRT_StableHLOProgramSignatureRefinementOptions options) {
131-
return !options.ptr;
132-
}
133-
134-
//===----------------------------------------------------------------------===//
135-
// Main StableHLO Program Signature Refinement API Functions
136-
//===----------------------------------------------------------------------===//
137-
138-
/// Compiler StableHLO to Executable.
139-
MLIR_CAPI_EXPORTED MTRT_Status mtrtGetStableHloProgramRefinedSignature(
140-
MTRT_CompilerClient client, MlirOperation module,
141-
MTRT_StableHLOProgramSignatureRefinementOptions options, MlirType *result);
142-
143111
#ifdef __cplusplus
144112
}
145113
#endif

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -44,41 +44,6 @@
4444

4545
namespace mlirtrt::compiler {
4646

47-
//===----------------------------------------------------------------------===//
48-
// StableHLOProgramSignatureRefinementOptions
49-
//===----------------------------------------------------------------------===//
50-
51-
struct StableHLOProgramSignatureRefinementOptions
52-
: public mlir::OptionsContext {
53-
/// Creates default compilation options.
54-
StableHLOProgramSignatureRefinementOptions() {
55-
this->addOption("func-name", funcName, llvm::cl::init("main"));
56-
debugOptions.addToOptions(*this);
57-
}
58-
59-
/// Set the entrypoint function name.
60-
StableHLOProgramSignatureRefinementOptions &
61-
setFuncName(const std::string &name) {
62-
funcName = name;
63-
return *this;
64-
}
65-
66-
std::string funcName = "main";
67-
68-
DebugOptions debugOptions;
69-
};
70-
71-
//===----------------------------------------------------------------------===//
72-
// StableHLO Signature Refinement Entrypoint
73-
//===----------------------------------------------------------------------===//
74-
75-
/// Attempt to refine the function signature of a StableHLO program through
76-
/// canonicalization and constant folding. Returns the refined signature of the
77-
/// specified function of the module.
78-
mlirtrt::StatusOr<mlir::FunctionType> getStableHLOProgramRefinedSignature(
79-
CompilerClient &client, mlir::ModuleOp module,
80-
const StableHLOProgramSignatureRefinementOptions &options);
81-
8247
//===----------------------------------------------------------------------===//
8348
// StableHLOToExecutableOptions
8449
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ using namespace mlir;
4444
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
4545
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
4646
StableHLOToExecutableOptions)
47-
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOProgramSignatureRefinementOptions,
48-
StableHLOProgramSignatureRefinementOptions)
4947
#if defined(__GNUC__) || defined(__clang__)
5048
#pragma GCC diagnostic pop
5149
#endif
@@ -255,43 +253,3 @@ MTRT_Status mtrtCompilerStableHLOToExecutable(
255253

256254
return mtrtStatusGetOk();
257255
}
258-
259-
//===----------------------------------------------------------------------===//
260-
// Main StableHLO Program Signature Refinement Functions
261-
//===----------------------------------------------------------------------===//
262-
263-
MTRT_Status mtrtStableHloProgramSignatureRefinementOptionsCreate(
264-
MTRT_StringView funcName,
265-
MTRT_StableHLOProgramSignatureRefinementOptions *options) {
266-
auto result = std::make_unique<StableHLOProgramSignatureRefinementOptions>();
267-
result->setFuncName(std::string(funcName.data, funcName.length));
268-
*options = wrap(result.release());
269-
return mtrtStatusGetOk();
270-
}
271-
272-
MTRT_Status mtrtStableHloProgramSignatureRefinementOptionsDestroy(
273-
MTRT_StableHLOProgramSignatureRefinementOptions options) {
274-
delete unwrap(options);
275-
return mtrtStatusGetOk();
276-
}
277-
278-
MTRT_Status mtrtGetStableHloProgramRefinedSignature(
279-
MTRT_CompilerClient client, MlirOperation module,
280-
MTRT_StableHLOProgramSignatureRefinementOptions options, MlirType *result) {
281-
ModuleOp moduleOp = llvm::dyn_cast<ModuleOp>(unwrap(module));
282-
if (!moduleOp)
283-
return mtrtStatusCreate(
284-
MTRT_StatusCode::MTRT_StatusCode_InvalidArgument,
285-
"StableHLO program signature refinement expects a ModuleOp");
286-
287-
StatusOr<FunctionType> funcType =
288-
compiler::getStableHLOProgramRefinedSignature(*unwrap(client), moduleOp,
289-
*unwrap(options));
290-
if (!funcType.isOk())
291-
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InvalidArgument,
292-
funcType.getString().c_str());
293-
294-
*result = wrap(mlir::Type(*funcType));
295-
296-
return mtrtStatusGetOk();
297-
}

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -144,68 +144,6 @@ class HloToStdPass
144144
};
145145
} // namespace
146146

147-
//===----------------------------------------------------------------------===//
148-
// StableHLO Signature Refinement Entrypoint
149-
//===----------------------------------------------------------------------===//
150-
151-
mlirtrt::StatusOr<mlir::FunctionType>
152-
compiler::getStableHLOProgramRefinedSignature(
153-
CompilerClient &client, mlir::ModuleOp module,
154-
const StableHLOProgramSignatureRefinementOptions &options) {
155-
156-
#ifndef NDEBUG
157-
//===----------------------------------------------------------------------===//
158-
// Set debug options.
159-
//===----------------------------------------------------------------------===//
160-
if (options.debugOptions.enableLLVMDebugFlag) {
161-
SmallVector<const char *> debugTypeLiterals =
162-
llvm::map_to_vector(options.debugOptions.llvmDebugTypes,
163-
[](const std::string &x) { return x.c_str(); });
164-
llvm::setCurrentDebugTypes(debugTypeLiterals.data(),
165-
debugTypeLiterals.size());
166-
llvm::DebugFlag = true;
167-
}
168-
#endif
169-
170-
//===----------------------------------------------------------------------===//
171-
// Setup pass manager
172-
//===----------------------------------------------------------------------===//
173-
174-
mlir::PassManager pm(module->getContext());
175-
if (failed(setupPassManager(pm, options.debugOptions))) {
176-
/// TODO: Ignored. This can fail if pass manager static CL options were not
177-
/// registered/initialized. This happens through invocation of e.g. this
178-
/// function in e.g. Python bindings or standalone calls to C++ or C API
179-
/// without doing all the typical static CL setup. We should instead be
180-
/// accepting a PassManager here that has already been setup to the caller's
181-
/// specifications.
182-
}
183-
184-
// Add pre-processing passes.
185-
{
186-
mlir::StableHloInputOptions opts{};
187-
opts.legalizeControlFlowToSCF = false;
188-
opts.preserveChloErf = true;
189-
opts.preserveChloTopK = true;
190-
mlir::buildStablehloPreProcessingPipeline(pm, opts);
191-
}
192-
193-
// Run pass pipeline.
194-
if (mlir::failed(pm.run(module)))
195-
return getStatusWithMsg(StatusCode::InternalError,
196-
"failed to run compilation pipeline");
197-
198-
// Get the signature.
199-
auto func = llvm::dyn_cast_or_null<func::FuncOp>(
200-
module.lookupSymbol(options.funcName));
201-
if (!func)
202-
return getInvalidArgStatus(
203-
"function with name {0} does not exist in the MLIR module",
204-
options.funcName);
205-
206-
return func.getFunctionType();
207-
}
208-
209147
//===----------------------------------------------------------------------===//
210148
// StableHLOToExecutableOptions
211149
//===----------------------------------------------------------------------===//

mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaErrorHandling.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,16 @@
6767
} \
6868
} while (false)
6969

70-
#define SET_LUA_ERROR_IF_NCCL_ERROR(x, lstate) \
70+
#define SET_LUA_ERROR_IF_NCCL_ERROR(x, lstate, comm) \
7171
do { \
7272
ncclResult_t err = (x); \
73-
if (err != ncclSuccess) { \
73+
if (err != ncclSuccess && err != ncclInProgress) { \
7474
lua_State *L = lstate; \
75-
luaL_error(L, ncclGetLastError(nullptr)); \
75+
std::string msg = llvm::formatv( \
76+
"{0}:{1} NCCL error [msg=\"{2}\" ncclGetLastError=\"{3}\"]", \
77+
__FILE__, __LINE__, ncclGetErrorString(err), \
78+
comm ? ncclGetLastError(comm) : ""); \
79+
luaL_error(L, msg.c_str()); \
7680
} \
7781
} while (false)
7882

mlir-tensorrt/executor/include/mlir-executor/Runtime/Support/Support.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,13 @@ namespace mlirtrt::runtime {
3535
// Debugging and logging tools
3636
//===----------------------------------------------------------------------===//
3737

38+
/// Prints the given printf-style formatted data to stderr if the 'runtime'
39+
/// debug module is enabled. Has no effect in non-assert builds.
40+
/// Note that we prepend a space to assist with readability when the logs are
41+
/// prefixed by other text when wrapped by another runtime system (e.g.
42+
/// 'mpirun').
3843
#define MTRT_DBGF(fmt, ...) \
39-
DEBUG_WITH_TYPE("runtime", fprintf(stderr, "%s:%d " fmt "\n", __FILE__, \
44+
DEBUG_WITH_TYPE("runtime", fprintf(stderr, " %s:%d " fmt "\n", __FILE__, \
4045
__LINE__, __VA_ARGS__))
4146

4247
template <typename... Args>

mlir-tensorrt/executor/include/mlir-executor/Support/Status.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,12 @@ class StatusOr {
217217
} \
218218
} while (false);
219219

220+
/// Causes returning an InternalError status from the current scope if the NCCL
221+
/// result is not ncclSuccess or ncclInProgress.
220222
#define RETURN_ERROR_IF_NCCL_ERROR(x, comm) \
221223
do { \
222224
ncclResult_t err = (x); \
223-
if (err != ncclSuccess) { \
225+
if (err != ncclSuccess && err != ncclInProgress) { \
224226
return getInternalErrorStatus( \
225227
"{0}:{1} NCCL error [msg=\"{2}\" ncclGetLastError=\"{3}\"]", \
226228
__FILE__, __LINE__, ncclGetErrorString(err), \

0 commit comments

Comments
 (0)