Skip to content

Commit c29cab8

Browse files
committed
merge
2 parents 22f5657 + 91108c7 commit c29cab8

File tree

128 files changed

+1981
-599
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

128 files changed

+1981
-599
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ repos:
5353
rev: v1.27.3
5454
hooks:
5555
- id: typos
56-
args: []
56+
args: [--force-exclude]
5757
# For Python files
5858
- repo: https://github.com/psf/black-pre-commit-mirror
5959
rev: 24.8.0

_typos.toml

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
[files]
2+
# The following files will be excluded from spell check during commits
3+
extend-exclude = [
4+
"test/dataset/imikolov_test.py"
5+
]
6+
17
[default.extend-words]
28
# PaddlePaddle specific words
39
lod = "lod"
@@ -71,7 +77,6 @@ Prepar = 'Prepar'
7177
precent = 'precent'
7278
sheduler = 'sheduler'
7379
outpus = 'outpus'
74-
atrribute = 'atrribute'
7580
normlize = 'normlize'
7681
Costum = 'Costum'
7782
differnt = 'differnt'
@@ -150,11 +155,9 @@ Uniqe = 'Uniqe'
150155
valuse = 'valuse'
151156
exsits = 'exsits'
152157
sucessfully = 'sucessfully'
153-
CACH = 'CACH'
154158
endianess = 'endianess'
155159
VAILD = 'VAILD'
156160
ues = 'ues'
157-
aer = 'aer'
158161
elemenents = 'elemenents'
159162
CANN = 'CANN'
160163
pathes = 'pathes'
@@ -197,7 +200,6 @@ sotring = 'sotring'
197200
overriden = 'overriden'
198201
Maxinum = 'Maxinum'
199202
caculate = 'caculate'
200-
cahr = 'cahr'
201203
occures = 'occures'
202204
framwork = 'framwork'
203205
localy = 'localy'
@@ -243,7 +245,6 @@ starup = 'starup'
243245
iy = 'iy'
244246
bindins = 'bindins'
245247
choses = 'choses'
246-
catched = 'catched'
247248
rewrited = 'rewrited'
248249
targt = 'targt'
249250
Theoritical = 'Theoritical'
@@ -293,10 +294,8 @@ wiil = 'wiil'
293294
configurated = 'configurated'
294295
perfome = 'perfome'
295296
consructor = 'consructor'
296-
attribtue = 'attribtue'
297297
quitted = 'quitted'
298298
attribtes = 'attribtes'
299-
automatical = 'automatical'
300299
orignal = 'orignal'
301300
furture = 'furture'
302301
Indext = 'Indext'
@@ -335,7 +334,6 @@ funtion = 'funtion'
335334
optin = 'optin'
336335
defualt = 'defualt'
337336
envirnment = 'envirnment'
338-
cuase = 'cuase'
339337
fot = 'fot'
340338
coloumn = 'coloumn'
341339
inital = 'inital'
@@ -463,11 +461,9 @@ Leafs = 'Leafs'
463461
effecient = 'effecient'
464462
modifed = 'modifed'
465463
deserailize = 'deserailize'
466-
channnel = 'channnel'
467464
Suger = 'Suger'
468465
Actural = 'Actural'
469466
subsituted = 'subsituted'
470-
automaticly = 'automaticly'
471467
Minium = 'Minium'
472468
sequnece = 'sequnece'
473469
payed = 'payed'
@@ -615,7 +611,6 @@ theads = 'theads'
615611
postive = 'postive'
616612
progrss = 'progrss'
617613
diffrent = 'diffrent'
618-
attritube = 'attritube'
619614
compability = 'compability'
620615
hge = 'hge'
621616
Funcion = 'Funcion'
@@ -771,15 +766,12 @@ distrubuted = 'distrubuted'
771766
Localy = 'Localy'
772767
PARM = 'PARM'
773768
thi = 'thi'
774-
Oll = 'Oll'
775-
Auxillary = 'Auxillary'
776769
Infor = 'Infor'
777770
statment = 'statment'
778771
varn = 'varn'
779772
exmaple = 'exmaple'
780773
happend = 'happend'
781774
sequentail = 'sequentail'
782-
channles = 'channles'
783775
Mutiply = 'Mutiply'
784776
Currenly = 'Currenly'
785777
dimention = 'dimention'

cmake/external/pybind11.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ set(PYBIND_PATCH_COMMAND "")
2727
if(LINUX
2828
AND (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
2929
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
30-
set(PYBIND_TAG v2.12.0)
30+
set(PYBIND_TAG v2.13.6)
3131
file(TO_NATIVE_PATH
3232
${PADDLE_SOURCE_DIR}/patches/pybind/detail/internals.h.patch native_dst)
3333
# Note: [Why calling some `git` commands before `patch`?]

paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ void CollectSymbolNames(const std::vector<symbol::DimExpr>& dim_exprs,
539539

540540
template <typename SymbolBindingsT>
541541
void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
542-
const std::set<std::string>& symbol_names,
542+
std::set<std::string>* remain_symbol_names_to_bind,
543543
int in_tensor_idx,
544544
SymbolBindings* symbol_bindings) {
545545
for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size();
@@ -550,7 +550,10 @@ void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
550550
"The type of dim_expr is not atomic"));
551551
if (!dim_expr.isa<std::string>()) continue;
552552
const auto& sym_name = dim_expr.dyn_cast<std::string>();
553-
if (symbol_names.find(sym_name) == symbol_names.end()) continue;
553+
if (remain_symbol_names_to_bind->find(sym_name) ==
554+
remain_symbol_names_to_bind->end())
555+
continue;
556+
remain_symbol_names_to_bind->erase(sym_name);
554557
symbol_bindings->emplace_back(SymbolBindingsT{
555558
/*.symbol_name=*/sym_name,
556559
/*.input_tensor_idx=*/in_tensor_idx,
@@ -564,14 +567,17 @@ void GenerateSymbolBindings(
564567
const std::vector<pir::Value>& input_tensors,
565568
const std::set<std::string>& symbol_names,
566569
SymbolBindings* symbol_bindings) {
570+
std::set<std::string> remain_symbol_names_to_bind = symbol_names;
567571
for (int i = 0; i < input_tensors.size(); ++i) {
568572
const auto& input_tensor = input_tensors.at(i);
569573
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
570574
AppendSymbolBindings<ShapeSymbolBinding>(
571-
dim_exprs.shape(), symbol_names, i, symbol_bindings);
575+
dim_exprs.shape(), &remain_symbol_names_to_bind, i, symbol_bindings);
572576
if (dim_exprs.data().has_value()) {
573-
AppendSymbolBindings<DataSymbolBinding>(
574-
dim_exprs.data().value(), symbol_names, i, symbol_bindings);
577+
AppendSymbolBindings<DataSymbolBinding>(dim_exprs.data().value(),
578+
&remain_symbol_names_to_bind,
579+
i,
580+
symbol_bindings);
575581
}
576582
}
577583
}

paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class MergeParallelMatmulPattern
4444
if (!op->dyn_cast<paddle::dialect::MatmulOp>()) {
4545
return false;
4646
}
47+
4748
bool trans_x =
4849
op->attribute("transpose_x").dyn_cast<pir::BoolAttribute>().data();
4950
bool trans_y =
@@ -54,6 +55,10 @@ class MergeParallelMatmulPattern
5455
return false;
5556
}
5657

58+
auto IsFirstInput = [&](pir::Operation* op, pir::Value in_x) -> bool {
59+
return in_x == op->operand_source(0);
60+
};
61+
5762
auto VectorPrefixEqual = [](const std::vector<std::int64_t>& a,
5863
const std::vector<std::int64_t>& b) {
5964
return std::vector<std::int64_t>(a.begin(), a.end() - 1) ==
@@ -74,6 +79,10 @@ class MergeParallelMatmulPattern
7479
if (!ValidMatmulTranspose(it->owner())) {
7580
continue;
7681
}
82+
83+
if (!IsFirstInput(it->owner(), input_x)) {
84+
continue;
85+
}
7786
if (!pre_dim.has_value()) {
7887
pre_dim = ::common::vectorize(
7988
it->owner()

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ class FuseSingleElementShapeOpsIntoGenerateShapeOpPattern
423423
auto* user = iter->owner();
424424
if (IsSingleElementShapeOp(user, &shape_analysis)) return false;
425425
if (user->isa<cinn::dialect::GenerateShapeOp>()) return false;
426+
if (user->isa<pir::ShadowOutputOp>()) return false;
426427
}
427428

428429
return true;

paddle/cinn/hlir/framework/pir/trivial_op_util.cc

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h"
1616

17+
#include "paddle/cinn/common/dim_expr_converter.h"
1718
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
1819
#include "paddle/cinn/hlir/framework/compile_error.h"
1920
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
@@ -547,9 +548,6 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
547548
* remove it in axes.bind()
548549
*/
549550
const auto& f = [=](const ir::Expr& e) -> ir::Expr {
550-
VLOG(4) << "Start RemoveVarInScheduleBlockRealize(" << target_vars << ", "
551-
<< replaced_expr << ")";
552-
VLOG(4) << " Input is " << e;
553551
PADDLE_ENFORCE_NE(
554552
e.As<ir::ScheduleBlockRealize>(),
555553
nullptr,
@@ -562,22 +560,11 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
562560
auto block_bound_vars = copied_ir.As<ir::ScheduleBlockRealize>()
563561
->schedule_block.As<ir::ScheduleBlock>()
564562
->iter_vars;
565-
for (const auto& i_var : schedule_block_iter_vars) {
566-
PADDLE_ENFORCE_EQ(
567-
i_var.is_var(),
568-
true,
569-
::common::errors::InvalidArgument("RemoveVarInScheduleBlockRealize: "
570-
"axes.bind rhs is is not a Var."));
571-
}
572563
// find replace idx
573564
int target_idx = -1;
574565
for (int i = 0; i < schedule_block_iter_vars.size(); ++i) {
575-
VLOG(4) << "RemoveVarInScheduleBlockRealize: compare with "
576-
<< schedule_block_iter_vars[i] << " vs " << target_vars
577-
<< ", and equality is: "
578-
<< (schedule_block_iter_vars[i].as_var()->name ==
579-
target_vars->name);
580-
if (schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
566+
if (schedule_block_iter_vars[i].is_var() &&
567+
schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
581568
target_idx = i;
582569
}
583570
}
@@ -688,8 +675,6 @@ ExprTransformer RemoveOneTransformer(int one) {
688675
.GetSingle(copied);
689676
const ir::Expr& target_block =
690677
ExprSetFinderUtils::DirectlyFather(copied).GetSingle(target_for);
691-
VLOG(4) << "RemoveOneTransformer: directly target_block of for is "
692-
<< target_block;
693678
if (target_block.As<ir::ScheduleBlockRealize>() != nullptr) {
694679
VLOG(4) << "RemoveOneTransformer: father block is root realize";
695680
ir::Expr shedule_block =
@@ -708,7 +693,6 @@ ExprTransformer RemoveOneTransformer(int one) {
708693
shedule_block.As<ir::ScheduleBlock>()->body = for_body;
709694
}
710695
} else if (target_block.As<ir::Block>() != nullptr) {
711-
VLOG(4) << "RemoveOneTransformer: father block is Block";
712696
std::vector<ir::Expr> new_bodies;
713697
for (const auto& expr : target_block.As<ir::Block>()->stmts) {
714698
if (expr != target_for) {
@@ -728,7 +712,6 @@ ExprTransformer RemoveOneTransformer(int one) {
728712
"RemoveOneTransformer: target for father should be a ir::Block or "
729713
"ir::ScheduleBlockRealize."));
730714
}
731-
VLOG(4) << "Remove Var to 0 in ScheduleBlockRealizer: " << copied;
732715
// Remove var to 0 in ScheduleBlockRealizer
733716
InplaceMutateSingleExpr(
734717
&copied,
@@ -949,6 +932,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root) {
949932

950933
ir::Expr GetBodyBlock(const ir::Expr& root) {
951934
const auto& iters = GetNonReduceLoopVars(root);
935+
if (iters.empty()) {
936+
return ir::Block::Make(
937+
{ExprSetFinderUtils::ChildScheduleBlockRealizes.GetSingle(root)});
938+
}
952939
const size_t reduce_size =
953940
std::count_if(iters.begin(), iters.end(), [](const ir::Var& v) {
954941
return v->is_reduce_axis;
@@ -965,6 +952,74 @@ ir::Expr GetBodyBlock(const ir::Expr& root) {
965952
->body;
966953
}
967954

955+
ir::Expr ReshapeLoop(const ir::Expr& root,
956+
const std::vector<symbol::DimExpr>& in_shape,
957+
const std::vector<symbol::DimExpr>& out_shape) {
958+
auto copied = ir::ir_utils::IRCopy(root);
959+
960+
ir::ModuleExpr mod_expr({copied});
961+
ir::IRSchedule ir_sch(
962+
mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true);
963+
964+
const auto block_realize =
965+
(ExprSetFinderUtils::ChildScheduleBlockRealizes).GetSingle(copied);
966+
const auto block_name = block_realize.As<ir::ScheduleBlockRealize>()
967+
->schedule_block.As<ir::ScheduleBlock>()
968+
->name;
969+
const auto shape_partion = fusion::PartionReshapeAxes(in_shape, out_shape);
970+
971+
for (int idx = shape_partion.size() - 1; idx > 0; --idx) {
972+
const auto& in_s = shape_partion[idx - 1].first;
973+
const auto& in_e = shape_partion[idx].first;
974+
const auto& out_s = shape_partion[idx - 1].second;
975+
const auto& out_e = shape_partion[idx].second;
976+
977+
std::vector<int> fuse_indices;
978+
for (int i = in_e - 1; i >= in_s; --i) {
979+
if (in_shape[i] != symbol::DimExpr(1)) {
980+
fuse_indices.insert(fuse_indices.begin(), i);
981+
} else {
982+
VLOG(4) << "Remove index[" << i << "]: " << in_shape[i]
983+
<< " for expr: \n"
984+
<< copied;
985+
copied = ExprTransformerUtils::RemoveOneTransformer(i)(copied);
986+
ir_sch.SetExprs({copied});
987+
for (auto& index : fuse_indices) {
988+
index--;
989+
}
990+
}
991+
}
992+
if (fuse_indices.size() > 1) {
993+
VLOG(4) << "fuse_indices: " << cinn::utils::Join(fuse_indices, ",");
994+
ir_sch.Fuse(block_name, fuse_indices);
995+
}
996+
997+
std::vector<ir::Expr> split_shapes;
998+
for (int i = out_s; i < out_e; ++i) {
999+
if (out_shape[i] != symbol::DimExpr(1)) {
1000+
split_shapes.push_back(
1001+
cinn::common::DimExprConverter().ConvertToIrExpr(out_shape[i]));
1002+
}
1003+
}
1004+
if (split_shapes.size() > 1) {
1005+
ir_sch.Split(ir_sch.GetLoops(block_name)[in_s], split_shapes)[0];
1006+
}
1007+
}
1008+
1009+
std::vector<int> insert_axis;
1010+
std::vector<ir::Var> ones_var;
1011+
for (int i = 0; i < out_shape.size(); ++i) {
1012+
if (out_shape[i] == symbol::DimExpr(1)) {
1013+
insert_axis.push_back(i);
1014+
ones_var.push_back(ir::Var(1, "one_" + std::to_string(ones_var.size())));
1015+
}
1016+
}
1017+
copied = ExprTransformerUtils::InsertForsTransformer(insert_axis,
1018+
ones_var)(copied);
1019+
1020+
return copied;
1021+
}
1022+
9681023
} // namespace trivial_fusion_detail
9691024
} // namespace pir
9701025
} // namespace framework

paddle/cinn/hlir/framework/pir/trivial_op_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root);
297297

298298
ir::Expr GetBodyBlock(const ir::Expr& root);
299299

300+
ir::Expr ReshapeLoop(const ir::Expr& root,
301+
const std::vector<symbol::DimExpr>& in_shape,
302+
const std::vector<symbol::DimExpr>& out_shape);
303+
300304
} // namespace trivial_fusion_detail
301305
} // namespace pir
302306
} // namespace framework

paddle/cinn/ir/group_schedule/config/group_tile_util.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
6363
auto* block = expr_block.As<ir::ScheduleBlockRealize>();
6464
auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
6565
for (int i = 0; i < iter_vars.size(); i++) {
66-
ir::Var loop_var = block->iter_values[i].as_var_ref();
67-
if (reduce_loop_vars.count(loop_var->name) > 0) {
66+
if (block->iter_values[i].is_var() &&
67+
reduce_loop_vars.count(block->iter_values[i].as_var()->name) > 0) {
6868
reduce_iter_vars.insert(iter_vars[i]->name);
6969
}
7070
}

paddle/cinn/operator_fusion/fusion_tracker/expr_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ir::Expr ApplyItersTransform::operator()(const TransposeItersTransform& trans) {
3131

3232
ir::Expr ApplyItersTransform::operator()(const RemoveOnesTransform& trans) {
3333
VLOG(4) << "[ItersTransform] Before RemoveOnesTransform("
34-
<< utils::Join(trans.ones_, ",") << "'): " << expr_;
34+
<< utils::Join(trans.ones_, ",") << "): " << expr_;
3535
auto result = RemoveOnesTransformer(trans.ones_)(expr_);
3636
VLOG(4) << "[ItersTransform] After RemoveOnesTransform: " << result;
3737
return result;

0 commit comments

Comments
 (0)