Skip to content

Try to rewrite gracejoin in peephole instead of creating blockhashjoi… #21015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions ydb/core/kqp/opt/peephole/kqp_opt_peephole.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

#include "kqp_opt_peephole.h"
#include "kqp_opt_peephole_rules.h"

Expand Down Expand Up @@ -88,8 +89,8 @@ TStatus ReplaceNonDetFunctionsWithParams(TExprNode::TPtr& input, TExprContext& c

class TKqpPeepholeTransformer : public TOptimizeTransformerBase {
public:
TKqpPeepholeTransformer(TTypeAnnotationContext& typesCtx, TSet<TString> disabledOpts)
: TOptimizeTransformerBase(&typesCtx, NYql::NLog::EComponent::ProviderKqp, disabledOpts)
TKqpPeepholeTransformer(TTypeAnnotationContext& typesCtx, TSet<TString> disabledOpts, TKikimrConfiguration::TPtr config)
: TOptimizeTransformerBase(&typesCtx, NYql::NLog::EComponent::ProviderKqp, disabledOpts), Config(config)
{
#define HNDL(name) "KqpPeephole-"#name, Hndl(&TKqpPeepholeTransformer::name)
AddHandler(0, &TDqReplicate::Match, HNDL(RewriteReplicate));
Expand All @@ -99,6 +100,13 @@ class TKqpPeepholeTransformer : public TOptimizeTransformerBase {
AddHandler(0, &TDqPhyJoinDict::Match, HNDL(RewriteDictJoin));
AddHandler(0, &TDqJoin::Match, HNDL(RewritePureJoin));
AddHandler(0, &TDqPhyBlockHashJoin::Match, HNDL(RewriteBlockHashJoin));
AddHandler(0, [](const TExprNode* node) {
bool isGraceJoinCore = node->IsCallable("GraceJoinCore");
if (isGraceJoinCore) {
Cerr << "Handler matched GraceJoinCore for BlockHashJoin conversion!" << Endl;
}
return isGraceJoinCore;
}, HNDL(RewriteBlockHashJoinCore));
AddHandler(0, TOptimizeTransformerBase::Any(), HNDL(BuildWideReadTable));
AddHandler(0, &TDqPhyLength::Match, HNDL(RewriteLength));
AddHandler(0, &TKqpWriteConstraint::Match, HNDL(RewriteKqpWriteConstraint));
Expand Down Expand Up @@ -160,11 +168,21 @@ class TKqpPeepholeTransformer : public TOptimizeTransformerBase {
return output;
}

TMaybeNode<TExprBase> RewriteBlockHashJoinCore(TExprBase node, TExprContext& ctx) {
bool useBlockHashJoin = Config && Config->UseBlockHashJoin.Get().GetOrElse(false);
TExprBase output = DqPeepholeRewriteBlockHashJoinCore(node, ctx, useBlockHashJoin);
DumpAppliedRule("RewriteBlockHashJoinCore", node.Ptr(), output.Ptr(), ctx);
return output;
}

TMaybeNode<TExprBase> RewriteKqpWriteConstraint(TExprBase node, TExprContext& ctx) {
TExprBase output = KqpRewriteWriteConstraint(node, ctx);
DumpAppliedRule("RewriteKqpWriteConstraint", node.Ptr(), output.Ptr(), ctx);
return output;
}

private:
TKikimrConfiguration::TPtr Config;
};

struct TKqpPeepholePipelineConfigurator : IPipelineConfigurator {
Expand All @@ -183,7 +201,7 @@ struct TKqpPeepholePipelineConfigurator : IPipelineConfigurator {
}

void AfterOptimize(TTransformationPipeline* pipeline) const override {
pipeline->Add(new TKqpPeepholeTransformer(*pipeline->GetTypeAnnotationContext(), DisabledOpts), "KqpPeephole");
pipeline->Add(new TKqpPeepholeTransformer(*pipeline->GetTypeAnnotationContext(), DisabledOpts, Config), "KqpPeephole");
}

private:
Expand Down Expand Up @@ -651,3 +669,4 @@ TAutoPtr<IGraphTransformer> CreateKqpTxsPeepholeTransformer(
}

} // namespace NKikimr::NKqp::NOpt

148 changes: 75 additions & 73 deletions ydb/library/yql/dq/opt/dq_opt_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,37 @@ using namespace NYql::NNodes;

namespace {

TExprNode::TPtr ExpandJoinInput(const TStructExprType& type, TExprNode::TPtr&& arg, TExprContext& ctx, std::vector<std::pair<TString, const TTypeAnnotationNode*>>& convertedItems, TPositionHandle position) {
return ctx.Builder(arg->Pos())
.Callable("ExpandMap")
.Add(0, std::move(arg))
.Lambda(1)
.Param("item")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
auto i = 0U;
for (const auto& item : type.GetItems()) {
parent.Callable(i, "Member")
.Arg(0, "item")
.Atom(1, item->GetName())
.Seal();
i++;
}
for (const auto& convertedItem : convertedItems) {
parent.Callable(i, "StrictCast")
.Callable(0, "Member")
.Arg(0, "item")
.Atom(1, convertedItem.first)
.Seal()
.Add(1, ExpandType(position, *convertedItem.second, ctx))
.Seal();
i++;
}
return parent;
})
.Seal()
.Seal().Build();
}

struct TJoinInputDesc {
TJoinInputDesc(TMaybe<THashSet<TStringBuf>> labels, const TExprBase& input,
TSet<std::pair<TStringBuf, TStringBuf>>&& keys)
Expand Down Expand Up @@ -1151,25 +1182,7 @@ TExprBase DqBuildJoinDict(const TDqJoin& join, TExprContext& ctx) {

namespace {

TExprNode::TPtr ExpandJoinInput(const TStructExprType& type, TExprNode::TPtr&& arg, TExprContext& ctx) {
return ctx.Builder(arg->Pos())
.Callable("ExpandMap")
.Add(0, std::move(arg))
.Lambda(1)
.Param("item")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
auto i = 0U;
for (const auto& item : type.GetItems()) {
parent.Callable(i++, "Member")
.Arg(0, "item")
.Atom(1, item->GetName())
.Seal();
}
return parent;
})
.Seal()
.Seal().Build();
}


TExprNode::TPtr SqueezeJoinInputToDict(TExprNode::TPtr&& input, size_t width, const std::vector<ui32>& keys, bool withPayloads, bool multiRow, TExprContext& ctx) {
YQL_ENSURE(width > 0U && !keys.empty());
Expand Down Expand Up @@ -1317,7 +1330,7 @@ TExprNode::TPtr ReplaceJoinOnSide(TExprNode::TPtr&& input, const TTypeAnnotation

}

TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext& ctx, IOptimizationContext& optCtx, bool shuffleElimination, bool shuffleEliminationWithMap, bool useBlockHashJoin) {
TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext& ctx, IOptimizationContext& optCtx, bool shuffleElimination, bool shuffleEliminationWithMap, bool) {
const auto joinType = join.JoinType().Value();
YQL_ENSURE(joinType != "Cross"sv);

Expand Down Expand Up @@ -1558,8 +1571,10 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
TCoArgument leftInputArg{ctx.NewArgument(join.LeftInput().Pos(), "_dq_join_left")};
TCoArgument rightInputArg{ctx.NewArgument(join.RightInput().Pos(), "_dq_join_right")};

auto leftWideFlow = ExpandJoinInput(*leftStructType, leftInputArg.Ptr(), ctx);
auto rightWideFlow = ExpandJoinInput(*rightStructType, rightInputArg.Ptr(), ctx);
// For standard Grace join we don't need type conversions
std::vector<std::pair<TString, const TTypeAnnotationNode*>> emptyConversions;
auto leftWideFlow = ExpandJoinInput(*leftStructType, leftInputArg.Ptr(), ctx, emptyConversions, join.Pos());
auto rightWideFlow = ExpandJoinInput(*rightStructType, rightInputArg.Ptr(), ctx, emptyConversions, join.Pos());

const auto leftFullWidth = leftNames.size();
const auto rightFullWidth = rightNames.size();
Expand All @@ -1580,22 +1595,6 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
switch (mode) {
case EHashJoinMode::GraceAndSelf:
case EHashJoinMode::Grace:
if (useBlockHashJoin) {
// Create TDqPhyBlockHashJoin node with structured inputs - peephole will handle conversion
// Pass the original structured inputs, not wide flows
hashJoin = Build<TDqPhyBlockHashJoin>(ctx, join.Pos())
.LeftInput(leftInputArg)
.RightInput(rightInputArg)
.LeftLabel(join.LeftLabel())
.RightLabel(join.RightLabel())
.JoinType(join.JoinType())
.JoinKeys(join.JoinKeys())
.LeftJoinKeyNames(join.LeftJoinKeyNames())
.RightJoinKeyNames(join.RightJoinKeyNames())
.Done().Ptr();
break;
}

hashJoin = ctx.Builder(join.Pos())
.Callable(callableName)
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
Expand Down Expand Up @@ -1727,7 +1726,7 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (ui32 i = 0U; i < rightNames.size(); ++i) {
parent.Atom(2*i, ctx.GetIndexAsString(i), TNodeFlags::Default);
parent.Atom(2*i + 1, ctx.GetIndexAsString(i + leftNames.size()), TNodeFlags::Default);
parent.Atom(2*i + 1, ctx.GetIndexAsString(i), TNodeFlags::Default);
}
return parent;
})
Expand Down Expand Up @@ -1827,46 +1826,46 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
ythrow yexception() << "Invalid hash join mode: " << mode;
}

if (!useBlockHashJoin) {
std::vector<TString> fullColNames;
for (const auto& v: leftNames) {
if (leftTableName.empty()) {
fullColNames.emplace_back(v.first);
} else {
fullColNames.emplace_back(FullColumnName(leftTableName, v.first));
}
// Apply NarrowMap to convert wide output to structured output for all join types
std::vector<TString> fullColNames;
for (const auto& v: leftNames) {
if (leftTableName.empty()) {
fullColNames.emplace_back(v.first);
} else {
fullColNames.emplace_back(FullColumnName(leftTableName, v.first));
}
}

for (const auto& v: rightNames ) {
if (rightTableName.empty()) {
fullColNames.emplace_back(v.first);
} else {
fullColNames.emplace_back(FullColumnName(rightTableName, v.first));
}
for (const auto& v: rightNames ) {
if (rightTableName.empty()) {
fullColNames.emplace_back(v.first);
} else {
fullColNames.emplace_back(FullColumnName(rightTableName, v.first));
}
}

hashJoin = ctx.Builder(join.Pos())
.Callable("NarrowMap")
.Add(0, std::move(hashJoin))
.Lambda(1)
.Params("output", fullColNames.size())
.Callable("AsStruct")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
ui32 i = 0U;
for (const auto& colName : fullColNames) {
parent.List(i)
.Atom(0, colName)
.Arg(1, "output", i)
.Seal();
i++;
}
return parent;
})
.Seal()
// Apply NarrowMap to convert wide output to structured output for all join types
hashJoin = ctx.Builder(join.Pos())
.Callable("NarrowMap")
.Add(0, std::move(hashJoin))
.Lambda(1)
.Params("output", fullColNames.size())
.Callable("AsStruct")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
ui32 i = 0U;
for (const auto& colName : fullColNames) {
parent.List(i)
.Atom(0, colName)
.Arg(1, "output", i)
.Seal();
i++;
}
return parent;
})
.Seal()
.Seal()
.Build();
}
.Seal()
.Build();

// this func add join to the stage and add connection to it. we do this instead of map connection to reduce data network interacting
auto addJoinToStage =
Expand Down Expand Up @@ -1961,4 +1960,7 @@ TExprBase DqBuildHashJoin(const TDqJoin& join, EHashJoinMode mode, TExprContext&
.Done();
}



} // namespace NYql::NDq

18 changes: 17 additions & 1 deletion ydb/library/yql/dq/opt/dq_opt_join.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@

#pragma once

#include "dq_opt.h"

#include <ydb/library/yql/dq/common/dq_common.h>
#include <yql/essentials/core/yql_expr_optimize.h>
#include <yql/essentials/core/cbo/cbo_optimizer_new.h>
#include <util/generic/map.h>
#include <util/generic/vector.h>

namespace NYql {

Expand Down Expand Up @@ -37,7 +40,19 @@ NNodes::TExprBase DqBuildJoin(

NNodes::TExprBase DqBuildHashJoin(const NNodes::TDqJoin& join, EHashJoinMode mode, TExprContext& ctx, IOptimizationContext& optCtx, bool shuffleElimination, bool shuffleEliminationWithMap, bool useBlockHashJoin = false);

NNodes::TExprBase DqBuildBlockHashJoin(const NNodes::TDqJoin& join, TExprContext& ctx);
// Updated DqBuildBlockHashJoin function signature with all necessary parameters
NNodes::TExprBase DqBuildBlockHashJoin(
const NNodes::TDqJoin& join,
const TStructExprType* leftStructType,
const TStructExprType* rightStructType,
const std::map<std::string_view, ui32>& leftNames,
const std::map<std::string_view, ui32>& rightNames,
const TVector<NNodes::TCoAtom>& leftJoinKeys,
const TVector<NNodes::TCoAtom>& rightJoinKeys,
NNodes::TCoArgument leftInputArg,
NNodes::TCoArgument rightInputArg,
TExprContext& ctx
);

NNodes::TExprBase DqBuildJoinDict(const NNodes::TDqJoin& join, TExprContext& ctx);

Expand All @@ -51,3 +66,4 @@ bool DqCollectJoinRelationsWithStats(

} // namespace NDq
} // namespace NYql

Loading
Loading