Skip to content

Commit 0e28ae4

Browse files
[compiler] Move StableHLO partitioning attribute handling to StablehloToPlan pass
Move the logic for reading StableHLO partitioning attributes (mhlo.num_replicas and mhlo.num_partitions) and setting the executor process grid shape from the plan-clustering pass to the StablehloToPlan conversion pass. This is a more appropriate location since this conversion happens during the StableHLO to Plan dialect conversion phase. Also clean up unused includes in Clustering.cpp and remove unnecessary dependent dialects declaration from Passes.td. GitOrigin-RevId: 112e7c7386773df5cb57982ac1e3a831af5be621
1 parent ef0e28f commit 0e28ae4

File tree

6 files changed

+38
-415
lines changed

6 files changed

+38
-415
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -330,26 +330,18 @@ def ClusteringPass : Pass<"plan-clustering", "::mlir::ModuleOp"> {
330330
[CompilerBackendAttrInterface](../IR/PlanInterfaces.td).
331331
}];
332332

333-
let options = [
334-
Option<"entrypoint", "entrypoint", "std::string", "\"\"",
335-
"the name of the entrypoint function; if empty then the clustering runs"
336-
" on all functions">,
337-
Option<"forceEntrypointsReturnAllocs",
338-
"force-entrypoints-return-allocs", "bool", "false",
339-
"allow backend clusters to directly allocate outputs">,
340-
Option<"disableCreateShapeFuncPass", "disable-create-shape-func-pass", "bool", "false",
341-
"don't apply create shape to func pass in TensorRT clusters">,
342-
InputKindOption
343-
];
344-
345-
let dependentDialects = [
346-
// TODO: TensorRT and Tensor dialects needed since the different
347-
// backends may create these ops. Add a way for backends to declare
348-
// dialect dependencies.
349-
"::mlir::tensor::TensorDialect",
350-
"::mlir::tensorrt::TensorRTDialect",
351-
"::mlir::plan::PlanDialect"
352-
];
333+
let options =
334+
[Option<"entrypoint", "entrypoint", "std::string", "\"\"",
335+
"the name of the entrypoint function; if empty then the "
336+
"clustering runs"
337+
" on all functions">,
338+
Option<"forceEntrypointsReturnAllocs", "force-entrypoints-return-allocs",
339+
"bool", "false",
340+
"allow backend clusters to directly allocate outputs">,
341+
Option<"disableCreateShapeFuncPass", "disable-create-shape-func-pass",
342+
"bool", "false",
343+
"don't apply create shape to func pass in TensorRT clusters">,
344+
InputKindOption];
353345
}
354346

355347
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Conversion/StablehloToPlan/StablehloToPlan.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,15 @@ struct ShapeAssertionPattern
288288
return success();
289289
}
290290
};
291+
} // namespace
292+
293+
static auto getIntegerAttrOrDefault(Operation *op, StringRef name,
294+
int64_t defaultValue) {
295+
if (auto attr = op->getAttrOfType<IntegerAttr>(name))
296+
return attr.getInt();
297+
return defaultValue;
298+
}
299+
namespace {
291300

292301
struct ConvertStablehloToPlanPass
293302
: public impl::ConvertStablehloToPlanPassBase<ConvertStablehloToPlanPass> {
@@ -325,6 +334,22 @@ struct ConvertStablehloToPlanPass
325334
if (walkResult.wasInterrupted())
326335
return signalPassFailure();
327336
}
337+
338+
// Check for StableHLO partitioning attributes and attach the executor
339+
// grid shape attribute.
340+
ModuleOp module = getOperation();
341+
{
342+
auto numReplicas =
343+
getIntegerAttrOrDefault(module, "mhlo.num_replicas", 1);
344+
auto numPartitions =
345+
getIntegerAttrOrDefault(module, "mhlo.num_partitions", 1);
346+
if (failed(executor::setModuleProcessGridShape(
347+
module, {numReplicas, numPartitions}))) {
348+
emitError(module->getLoc())
349+
<< "failed to set the Executor process grid shape attribute";
350+
return signalPassFailure();
351+
}
352+
}
328353
}
329354
};
330355
} // namespace

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Clustering.cpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,16 @@
2121
/// Implementation of the `plan-clustering` pass.
2222
///
2323
//===----------------------------------------------------------------------===//
24-
#include "mlir-executor/Executor/IR/Executor.h"
2524
#include "mlir-executor/Transforms/Clustering/Patterns.h"
2625
#include "mlir-tensorrt-common/Interfaces/TensorKindOpInterface.h"
2726
#include "mlir-tensorrt-dialect/Analysis/TensorKindAnalysis.h"
28-
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2927
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
3028
#include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h"
3129
#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
3230
#include "mlir-tensorrt/Utils/DataFlowUtils.h"
3331
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
3432
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
3533
#include "mlir/Analysis/DataFlowFramework.h"
36-
#include "mlir/Dialect/Func/IR/FuncOps.h"
37-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
3834
#include "mlir/IR/OpDefinition.h"
3935
#include "mlir/Transforms/DialectConversion.h"
4036
#include "mlir/Transforms/RegionUtils.h"
@@ -149,13 +145,6 @@ applyClusteringToFunc(RewriterBase &rewriter, FunctionOpInterface func,
149145
return success();
150146
}
151147

152-
static auto getIntegerAttrOrDefault(Operation *op, StringRef name,
153-
int64_t defaultValue) {
154-
if (auto attr = op->getAttrOfType<IntegerAttr>(name))
155-
return attr.getInt();
156-
return defaultValue;
157-
}
158-
159148
namespace {
160149
class ClusteringPass : public plan::impl::ClusteringPassBase<ClusteringPass> {
161150
public:
@@ -229,22 +218,6 @@ class ClusteringPass : public plan::impl::ClusteringPassBase<ClusteringPass> {
229218
return signalPassFailure();
230219
}
231220

232-
// Check for StableHLO partitioning attributes and attach the executor
233-
// grid shape attribute.
234-
/// TODO: move this logic into a standalone pass that handles partitioning.
235-
{
236-
auto numReplicas =
237-
getIntegerAttrOrDefault(module, "mhlo.num_replicas", 1);
238-
auto numPartitions =
239-
getIntegerAttrOrDefault(module, "mhlo.num_partitions", 1);
240-
if (failed(executor::setModuleProcessGridShape(
241-
module, {numReplicas, numPartitions}))) {
242-
emitError(module->getLoc())
243-
<< "failed to set the Executor process grid shape attribute";
244-
return signalPassFailure();
245-
}
246-
}
247-
248221
// Drop clustering attributes since they are no longer needed.
249222
module->removeAttr(plan::PlanDialect::kBackendsAttrName);
250223
}

mlir-tensorrt/compiler/test/Dialect/Plan/stablehlo-clustering.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func.func @test(%arg0: tensor<4xi32>, %arg1: tensor<i32>)
7575

7676
}
7777

78-
// CHECK-LABEL: module @host_backends_with_values attributes
78+
// CHECK-LABEL: module @host_backends_with_values
7979
// CHECK-LABEL: func.func @test
8080
// CHECK-SAME: (%[[arg0:.+]]: tensor<4xi32>, %[[arg1:.+]]: tensor<i32>)
8181
// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>

mlir-tensorrt/executor/include/mlir-executor/Transforms/Clustering/Patterns.h

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -113,59 +113,6 @@ class ClusteringPatternSet {
113113
SmallVector<std::unique_ptr<RewriteType>> patterns;
114114
};
115115

116-
/// Apply a set of clustering patterns to the function. Patterns are sorted and
117-
/// applied in decreasing order by benefit.
118-
LogicalResult
119-
applyClusteringPatterns(FunctionOpInterface mainFunc,
120-
ClusteringPatternSet<ClusteringRewriter> &patterns);
121-
122-
/// A type of a function that can filter cluster region operations.
123-
using RegionOpFilterFn = std::function<bool(Operation *)>;
124-
125-
/// Create a cluster region op filter using the specified parameters.
126-
RegionOpFilterFn getRegionOpFilter(
127-
Attribute target,
128-
unsigned operationCnt = std::numeric_limits<unsigned>::max(),
129-
IsClusterableOpFn canOpCluster = [](Operation *op) { return true; });
130-
131-
/// Given a `func.func` operation, this class describes a base
132-
/// pattern for doing a "scf::ExecuteRegion" based transformation
133-
class RegionOpFusionRewriter {
134-
public:
135-
RegionOpFusionRewriter(const SmallVector<RegionOpFilterFn> &filters,
136-
Attribute newTarget,
137-
ClusterRegionOpBuilderFunc regionOpBuilderFunc)
138-
: filters(filters), target(newTarget),
139-
regionOpBuilderFunc(regionOpBuilderFunc) {}
140-
141-
/// This function walks on the mainFunc graph and finds any matched
142-
/// patterns according to filters. After it finds matched consecutive
143-
/// Operation* in the graph and it will try to merge them into 1
144-
/// single Operation* and rewrite it into the graph with a new
145-
/// clustering target
146-
void run(FunctionOpInterface mainFunc, RewriterBase &rewriter);
147-
148-
private:
149-
/// A list of filter functions that identify scf.execute_region operations of
150-
/// interest
151-
SmallVector<RegionOpFilterFn> filters;
152-
153-
/// the target of the merged region operation will be set to
154-
Attribute target;
155-
156-
ClusterRegionOpBuilderFunc regionOpBuilderFunc;
157-
};
158-
159-
/// Return the "target" of a particular cluster represented by the
160-
/// `scf.execute_region` operation. This currently returns the StringAttr
161-
/// named `__cluster_target__` if present, failure otherwise.
162-
FailureOr<Attribute> getClusterTarget(Operation *regionOp);
163-
164-
/// Apply a set of region-op rewriter patterns to the function.
165-
LogicalResult applyRegionOpRewritePatterns(
166-
FunctionOpInterface mainFunc,
167-
ClusteringPatternSet<RegionOpFusionRewriter> &patterns);
168-
169116
} // namespace mlir
170117

171118
#endif // MLIR_TENSORRT_TRANSFORMS_CLUSTERING_PATTERNS_H

0 commit comments

Comments
 (0)