Skip to content

Commit 5661705

Browse files
Merge 937aafd into 2378fba
2 parents 2378fba + 937aafd commit 5661705

File tree

6 files changed

+333
-206
lines changed

6 files changed

+333
-206
lines changed

ydb/core/kqp/opt/peephole/kqp_opt_peephole.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#include "kqp_opt_peephole.h"
23
#include "kqp_opt_peephole_rules.h"
34

@@ -88,8 +89,8 @@ TStatus ReplaceNonDetFunctionsWithParams(TExprNode::TPtr& input, TExprContext& c
8889

8990
class TKqpPeepholeTransformer : public TOptimizeTransformerBase {
9091
public:
91-
TKqpPeepholeTransformer(TTypeAnnotationContext& typesCtx, TSet<TString> disabledOpts)
92-
: TOptimizeTransformerBase(&typesCtx, NYql::NLog::EComponent::ProviderKqp, disabledOpts)
92+
TKqpPeepholeTransformer(TTypeAnnotationContext& typesCtx, TSet<TString> disabledOpts, TKikimrConfiguration::TPtr config)
93+
: TOptimizeTransformerBase(&typesCtx, NYql::NLog::EComponent::ProviderKqp, disabledOpts), Config(config)
9394
{
9495
#define HNDL(name) "KqpPeephole-"#name, Hndl(&TKqpPeepholeTransformer::name)
9596
AddHandler(0, &TDqReplicate::Match, HNDL(RewriteReplicate));
@@ -99,6 +100,13 @@ class TKqpPeepholeTransformer : public TOptimizeTransformerBase {
99100
AddHandler(0, &TDqPhyJoinDict::Match, HNDL(RewriteDictJoin));
100101
AddHandler(0, &TDqJoin::Match, HNDL(RewritePureJoin));
101102
AddHandler(0, &TDqPhyBlockHashJoin::Match, HNDL(RewriteBlockHashJoin));
103+
AddHandler(0, [](const TExprNode* node) {
104+
bool isGraceJoinCore = node->IsCallable("GraceJoinCore");
105+
if (isGraceJoinCore) {
106+
Cerr << "Handler matched GraceJoinCore for BlockHashJoin conversion!" << Endl;
107+
}
108+
return isGraceJoinCore;
109+
}, HNDL(RewriteBlockHashJoinCore));
102110
AddHandler(0, TOptimizeTransformerBase::Any(), HNDL(BuildWideReadTable));
103111
AddHandler(0, &TDqPhyLength::Match, HNDL(RewriteLength));
104112
AddHandler(0, &TKqpWriteConstraint::Match, HNDL(RewriteKqpWriteConstraint));
@@ -160,11 +168,21 @@ class TKqpPeepholeTransformer : public TOptimizeTransformerBase {
160168
return output;
161169
}
162170

171+
TMaybeNode<TExprBase> RewriteBlockHashJoinCore(TExprBase node, TExprContext& ctx) {
172+
bool useBlockHashJoin = Config && Config->UseBlockHashJoin.Get().GetOrElse(false);
173+
TExprBase output = DqPeepholeRewriteBlockHashJoinCore(node, ctx, useBlockHashJoin);
174+
DumpAppliedRule("RewriteBlockHashJoinCore", node.Ptr(), output.Ptr(), ctx);
175+
return output;
176+
}
177+
163178
TMaybeNode<TExprBase> RewriteKqpWriteConstraint(TExprBase node, TExprContext& ctx) {
164179
TExprBase output = KqpRewriteWriteConstraint(node, ctx);
165180
DumpAppliedRule("RewriteKqpWriteConstraint", node.Ptr(), output.Ptr(), ctx);
166181
return output;
167182
}
183+
184+
private:
185+
TKikimrConfiguration::TPtr Config;
168186
};
169187

170188
struct TKqpPeepholePipelineConfigurator : IPipelineConfigurator {
@@ -183,7 +201,7 @@ struct TKqpPeepholePipelineConfigurator : IPipelineConfigurator {
183201
}
184202

185203
void AfterOptimize(TTransformationPipeline* pipeline) const override {
186-
pipeline->Add(new TKqpPeepholeTransformer(*pipeline->GetTypeAnnotationContext(), DisabledOpts), "KqpPeephole");
204+
pipeline->Add(new TKqpPeepholeTransformer(*pipeline->GetTypeAnnotationContext(), DisabledOpts, Config), "KqpPeephole");
187205
}
188206

189207
private:
@@ -651,3 +669,4 @@ TAutoPtr<IGraphTransformer> CreateKqpTxsPeepholeTransformer(
651669
}
652670

653671
} // namespace NKikimr::NKqp::NOpt
672+

ydb/library/yql/dq/opt/dq_opt_join.cpp

Lines changed: 75 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,37 @@ using namespace NYql::NNodes;
1414

1515
namespace {
1616

17+
TExprNode::TPtr ExpandJoinInput(const TStructExprType& type, TExprNode::TPtr&& arg, TExprContext& ctx, std::vector<std::pair<TString, const TTypeAnnotationNode*>>& convertedItems, TPositionHandle position) {
18+
return ctx.Builder(arg->Pos())
19+
.Callable("ExpandMap")
20+
.Add(0, std::move(arg))
21+
.Lambda(1)
22+
.Param("item")
23+
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
24+
auto i = 0U;
25+
for (const auto& item : type.GetItems()) {
26+
parent.Callable(i, "Member")
27+
.Arg(0, "item")
28+
.Atom(1, item->GetName())
29+
.Seal();
30+
i++;
31+
}
32+
for (const auto& convertedItem : convertedItems) {
33+
parent.Callable(i, "StrictCast")
34+
.Callable(0, "Member")
35+
.Arg(0, "item")
36+
.Atom(1, convertedItem.first)
37+
.Seal()
38+
.Add(1, ExpandType(position, *convertedItem.second, ctx))
39+
.Seal();
40+
i++;
41+
}
42+
return parent;
43+
})
44+
.Seal()
45+
.Seal().Build();
46+
}
47+
1748
struct TJoinInputDesc {
1849
TJoinInputDesc(TMaybe<THashSet<TStringBuf>> labels, const TExprBase& input,
1950
TSet<std::pair<TStringBuf, TStringBuf>>&& keys)
@@ -1151,25 +1182,7 @@ TExprBase DqBuildJoinDict(const TDqJoin& join, TExprContext& ctx) {
11511182

11521183
namespace {
11531184

1154-
TExprNode::TPtr ExpandJoinInput(const TStructExprType& type, TExprNode::TPtr&& arg, TExprContext& ctx) {
1155-
return ctx.Builder(arg->Pos())
1156-
.Callable("ExpandMap")
1157-
.Add(0, std::move(arg))
1158-
.Lambda(1)
1159-
.Param("item")
1160-
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
1161-
auto i = 0U;
1162-
for (const auto& item : type.GetItems()) {
1163-
parent.Callable(i++, "Member")
1164-
.Arg(0, "item")
1165-
.Atom(1, item->GetName())
1166-
.Seal();
1167-
}
1168-
return parent;
1169-
})
1170-
.Seal()
1171-
.Seal().Build();
1172-
}
1185+
11731186

11741187
TExprNode::TPtr SqueezeJoinInputToDict(TExprNode::TPtr&& input, size_t width, const std::vector<ui32>& keys, bool withPayloads, bool multiRow, TExprContext& ctx) {
11751188
YQL_ENSURE(width > 0U && !keys.empty());
@@ -1317,7 +1330,7 @@ TExprNode::TPtr ReplaceJoinOnSide(TExprNode::TPtr&& input, const TTypeAnnotation
13171330

13181331
}
13191332

1320-
TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext& ctx, IOptimizationContext& optCtx, bool shuffleElimination, bool shuffleEliminationWithMap, bool useBlockHashJoin) {
1333+
TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext& ctx, IOptimizationContext& optCtx, bool shuffleElimination, bool shuffleEliminationWithMap, bool) {
13211334
const auto joinType = join.JoinType().Value();
13221335
YQL_ENSURE(joinType != "Cross"sv);
13231336

@@ -1558,8 +1571,10 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
15581571
TCoArgument leftInputArg{ctx.NewArgument(join.LeftInput().Pos(), "_dq_join_left")};
15591572
TCoArgument rightInputArg{ctx.NewArgument(join.RightInput().Pos(), "_dq_join_right")};
15601573

1561-
auto leftWideFlow = ExpandJoinInput(*leftStructType, leftInputArg.Ptr(), ctx);
1562-
auto rightWideFlow = ExpandJoinInput(*rightStructType, rightInputArg.Ptr(), ctx);
1574+
// For standard Grace join we don't need type conversions
1575+
std::vector<std::pair<TString, const TTypeAnnotationNode*>> emptyConversions;
1576+
auto leftWideFlow = ExpandJoinInput(*leftStructType, leftInputArg.Ptr(), ctx, emptyConversions, join.Pos());
1577+
auto rightWideFlow = ExpandJoinInput(*rightStructType, rightInputArg.Ptr(), ctx, emptyConversions, join.Pos());
15631578

15641579
const auto leftFullWidth = leftNames.size();
15651580
const auto rightFullWidth = rightNames.size();
@@ -1580,22 +1595,6 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
15801595
switch (mode) {
15811596
case EHashJoinMode::GraceAndSelf:
15821597
case EHashJoinMode::Grace:
1583-
if (useBlockHashJoin) {
1584-
// Create TDqPhyBlockHashJoin node with structured inputs - peephole will handle conversion
1585-
// Pass the original structured inputs, not wide flows
1586-
hashJoin = Build<TDqPhyBlockHashJoin>(ctx, join.Pos())
1587-
.LeftInput(leftInputArg)
1588-
.RightInput(rightInputArg)
1589-
.LeftLabel(join.LeftLabel())
1590-
.RightLabel(join.RightLabel())
1591-
.JoinType(join.JoinType())
1592-
.JoinKeys(join.JoinKeys())
1593-
.LeftJoinKeyNames(join.LeftJoinKeyNames())
1594-
.RightJoinKeyNames(join.RightJoinKeyNames())
1595-
.Done().Ptr();
1596-
break;
1597-
}
1598-
15991598
hashJoin = ctx.Builder(join.Pos())
16001599
.Callable(callableName)
16011600
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
@@ -1727,7 +1726,7 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
17271726
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
17281727
for (ui32 i = 0U; i < rightNames.size(); ++i) {
17291728
parent.Atom(2*i, ctx.GetIndexAsString(i), TNodeFlags::Default);
1730-
parent.Atom(2*i + 1, ctx.GetIndexAsString(i + leftNames.size()), TNodeFlags::Default);
1729+
parent.Atom(2*i + 1, ctx.GetIndexAsString(i), TNodeFlags::Default);
17311730
}
17321731
return parent;
17331732
})
@@ -1827,46 +1826,46 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
18271826
ythrow yexception() << "Invalid hash join mode: " << mode;
18281827
}
18291828

1830-
if (!useBlockHashJoin) {
1831-
std::vector<TString> fullColNames;
1832-
for (const auto& v: leftNames) {
1833-
if (leftTableName.empty()) {
1834-
fullColNames.emplace_back(v.first);
1835-
} else {
1836-
fullColNames.emplace_back(FullColumnName(leftTableName, v.first));
1837-
}
1829+
// Collect the output column names (left input columns first, then right input columns) consumed by the NarrowMap below
1830+
std::vector<TString> fullColNames;
1831+
for (const auto& v: leftNames) {
1832+
if (leftTableName.empty()) {
1833+
fullColNames.emplace_back(v.first);
1834+
} else {
1835+
fullColNames.emplace_back(FullColumnName(leftTableName, v.first));
18381836
}
1837+
}
18391838

1840-
for (const auto& v: rightNames ) {
1841-
if (rightTableName.empty()) {
1842-
fullColNames.emplace_back(v.first);
1843-
} else {
1844-
fullColNames.emplace_back(FullColumnName(rightTableName, v.first));
1845-
}
1839+
for (const auto& v: rightNames ) {
1840+
if (rightTableName.empty()) {
1841+
fullColNames.emplace_back(v.first);
1842+
} else {
1843+
fullColNames.emplace_back(FullColumnName(rightTableName, v.first));
18461844
}
1845+
}
18471846

1848-
hashJoin = ctx.Builder(join.Pos())
1849-
.Callable("NarrowMap")
1850-
.Add(0, std::move(hashJoin))
1851-
.Lambda(1)
1852-
.Params("output", fullColNames.size())
1853-
.Callable("AsStruct")
1854-
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
1855-
ui32 i = 0U;
1856-
for (const auto& colName : fullColNames) {
1857-
parent.List(i)
1858-
.Atom(0, colName)
1859-
.Arg(1, "output", i)
1860-
.Seal();
1861-
i++;
1862-
}
1863-
return parent;
1864-
})
1865-
.Seal()
1847+
// Apply NarrowMap to convert wide output to structured output for all join types
1848+
hashJoin = ctx.Builder(join.Pos())
1849+
.Callable("NarrowMap")
1850+
.Add(0, std::move(hashJoin))
1851+
.Lambda(1)
1852+
.Params("output", fullColNames.size())
1853+
.Callable("AsStruct")
1854+
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
1855+
ui32 i = 0U;
1856+
for (const auto& colName : fullColNames) {
1857+
parent.List(i)
1858+
.Atom(0, colName)
1859+
.Arg(1, "output", i)
1860+
.Seal();
1861+
i++;
1862+
}
1863+
return parent;
1864+
})
18661865
.Seal()
18671866
.Seal()
1868-
.Build();
1869-
}
1867+
.Seal()
1868+
.Build();
18701869

18711870
// this func add join to the stage and add connection to it. we do this instead of map connection to reduce data network interacting
18721871
auto addJoinToStage =
@@ -1961,4 +1960,7 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
19611960
.Done();
19621961
}
19631962

1963+
1964+
19641965
} // namespace NYql::NDq
1966+

ydb/library/yql/dq/opt/dq_opt_join.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
12
#pragma once
23

34
#include "dq_opt.h"
45

56
#include <ydb/library/yql/dq/common/dq_common.h>
67
#include <yql/essentials/core/yql_expr_optimize.h>
78
#include <yql/essentials/core/cbo/cbo_optimizer_new.h>
9+
#include <util/generic/map.h>
10+
#include <util/generic/vector.h>
811

912
namespace NYql {
1013

@@ -37,7 +40,19 @@ NNodes::TExprBase DqBuildJoin(
3740

3841
NNodes::TExprBase DqBuildHashJoin(const NNodes::TDqJoin& join, EHashJoinMode mode, TExprContext& ctx, IOptimizationContext& optCtx, bool shuffleElimination, bool shuffleEliminationWithMap, bool useBlockHashJoin = false);
3942

40-
NNodes::TExprBase DqBuildBlockHashJoin(const NNodes::TDqJoin& join, TExprContext& ctx);
43+
// DqBuildBlockHashJoin takes the join node plus caller-supplied pieces: left/right struct types, column-name maps, join key atoms and input arguments
44+
NNodes::TExprBase DqBuildBlockHashJoin(
45+
const NNodes::TDqJoin& join,
46+
const TStructExprType* leftStructType,
47+
const TStructExprType* rightStructType,
48+
const std::map<std::string_view, ui32>& leftNames,
49+
const std::map<std::string_view, ui32>& rightNames,
50+
const TVector<NNodes::TCoAtom>& leftJoinKeys,
51+
const TVector<NNodes::TCoAtom>& rightJoinKeys,
52+
NNodes::TCoArgument leftInputArg,
53+
NNodes::TCoArgument rightInputArg,
54+
TExprContext& ctx
55+
);
4156

4257
NNodes::TExprBase DqBuildJoinDict(const NNodes::TDqJoin& join, TExprContext& ctx);
4358

@@ -51,3 +66,4 @@ bool DqCollectJoinRelationsWithStats(
5166

5267
} // namespace NDq
5368
} // namespace NYql
69+

0 commit comments

Comments
 (0)