Skip to content

Commit 21446e9

Browse files
authored
opt longlong2int (#69522)
1 parent 9a2972a commit 21446e9

File tree

4 files changed

+9
-15
lines changed

4 files changed

+9
-15
lines changed

paddle/cinn/optim/longlong2int.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class CastLonglong2Int : public ir::IRMutator<> {
6969
std::for_each(node->shape.begin(),
7070
node->shape.end(),
7171
[&](cinn::ir::Expr& e) { e->convert_int64_to_int32(); });
72-
CastBuffer(node->buffer);
72+
CastBufferMeta(node->buffer);
7373
}
7474
void Visit(const ir::Load* op, Expr* expr) override {
7575
auto node = expr->As<ir::Load>();
@@ -107,7 +107,7 @@ class CastLonglong2Int : public ir::IRMutator<> {
107107
range->ranges.end(),
108108
[&](cinn::ir::Var& v) { CastVarWithBound(v); });
109109
auto bf = range->buffer.as_buffer_ref();
110-
CastBuffer(bf);
110+
CastBufferMeta(bf);
111111
}
112112
}
113113

@@ -117,7 +117,7 @@ class CastLonglong2Int : public ir::IRMutator<> {
117117
range->ranges.end(),
118118
[&](cinn::ir::Var& v) { CastVarWithBound(v); });
119119
auto bf = range->buffer.as_buffer_ref();
120-
CastBuffer(bf);
120+
CastBufferMeta(bf);
121121
}
122122
}
123123
ir::IRMutator<>::Visit(&(node->body), &(node->body));
@@ -140,7 +140,7 @@ class CastLonglong2Int : public ir::IRMutator<> {
140140
if (lb.defined()) lb->convert_int64_to_int32();
141141
if (ub.defined()) ub->convert_int64_to_int32();
142142
}
143-
void CastBuffer(cinn::ir::Buffer& bf) { // NOLINT
143+
void CastBufferMeta(cinn::ir::Buffer& bf) { // NOLINT
144144
if (!bf.defined()) return;
145145
std::for_each(bf->shape.begin(), bf->shape.end(), [&](cinn::ir::Expr& e) {
146146
e->convert_int64_to_int32();

paddle/cinn/optim/replace_var_with_expr.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,7 @@ struct CollectTensorIndexMutator : public ir::IRMutator<> {
155155
std::vector<std::vector<Expr>> CollectTensorIndex(
156156
Expr* source, const std::string& tensor_name) {
157157
CollectTensorIndexMutator mutator(tensor_name);
158-
std::vector<std::vector<Expr>> result = mutator(source);
159-
for (auto& i : result) {
160-
for (auto& j : i) {
161-
j = cinn::common::AutoSimplify(j);
162-
}
163-
}
164-
return result;
158+
return mutator(source);
165159
}
166160

167161
} // namespace optim

paddle/cinn/optim/transform_gpu_forloop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
#include "paddle/cinn/utils/string.h"
4141
#include "paddle/common/enforce.h"
4242

43-
PD_DECLARE_bool(cinn_longlong2int_for_integer);
43+
PD_DECLARE_bool(cinn_longlong2int);
4444
namespace cinn {
4545
namespace optim {
4646

@@ -487,7 +487,7 @@ void OptimizeExprGPU(Expr *expr) {
487487
ReplaceVarToZero replace_var_to_zero;
488488
replace_var_to_zero(expr);
489489

490-
if (FLAGS_cinn_longlong2int_for_integer) {
490+
if (FLAGS_cinn_longlong2int) {
491491
TryCastLonglong2Int(expr);
492492
}
493493

paddle/cinn/runtime/flags.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ PD_DEFINE_bool(cinn_check_tensor_buffer_map,
296296
BoolFromEnv("FLAGS_cinn_check_tensor_buffer_map", false),
297297
"Whether to check tensor buffer mapping in cinn ir.");
298298

299-
PD_DEFINE_bool(cinn_longlong2int_for_integer,
300-
BoolFromEnv("FLAGS_cinn_longlong2int_for_integer", true),
299+
PD_DEFINE_bool(cinn_longlong2int,
300+
BoolFromEnv("FLAGS_cinn_longlong2int", true),
301301
"Whether to cast long long to int for integer.");
302302

303303
namespace cinn {

0 commit comments

Comments
 (0)