Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ cc_library(
":Patterns",
":pass_inc_gen",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Transforms/LayoutOptimization:Patterns",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformUtils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <utility>

#include "lib/Dialect/TensorExt/Transforms/Patterns.h"
#include "lib/Transforms/LayoutOptimization/Patterns.h"
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
Expand All @@ -23,7 +24,8 @@ struct FoldConvertLayoutIntoAssignLayout
void runOnOperation() override {
MLIRContext* context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FoldConvertLayoutIntoAssignLayoutPattern>(context);
patterns.add<FoldConvertLayoutIntoAssignLayoutPattern, HoistArgLayouts,
FoldLayoutConversions>(context);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
Expand Down
31 changes: 19 additions & 12 deletions lib/Kernel/Kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "lib/Kernel/KernelName.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
Expand All @@ -18,13 +21,15 @@ namespace mlir {
namespace heir {

namespace {
// Maps each kernel to the set of linalg op names it can legally implement.
// NOTE: std::unordered_map keeps only the FIRST entry for a duplicated key in
// an initializer list, so every op for a given kernel must be listed in a
// single vector. The previous initializer listed MatmulDiagonal twice, which
// silently dropped its "linalg.conv2d" mapping; the two entries are merged
// here.
static std::unordered_map<KernelName, std::vector<std::string>>
    correspondingOp = {
        {KernelName::MatvecNaive, {"linalg.matvec"}},
        {KernelName::MatvecDiagonal,
         {"linalg.matvec", "linalg.conv_2d_nchw_fchw"}},
        {KernelName::VecmatDiagonal, {"linalg.vecmat"}},
        {KernelName::MatmulDiagonal, {"linalg.matmul", "linalg.conv2d"}},
        {KernelName::MatmulBicyclic, {"linalg.matmul"}},
};

std::set<std::string> requiredNontrivial = {"linalg"};
Expand All @@ -36,7 +41,8 @@ bool isSupportedKernel(Operation* op, KernelName name) {
return requiredNontrivial.count(dialect) == 0;
}

if (correspondingOp.find(name) == correspondingOp.end()) {
auto it = correspondingOp.find(name);
if (it == correspondingOp.end()) {
LLVM_DEBUG(llvm::dbgs() << "Kernel name " << kernelNameAsStr(name)
<< "not found in correspondingOp legality map\n");
return false;
Expand All @@ -46,14 +52,15 @@ bool isSupportedKernel(Operation* op, KernelName name) {
llvm::raw_string_ostream ss(actual);
ss << op->getName().getStringRef();

std::string resolvedOpName = correspondingOp.at(name);
if (resolvedOpName == actual) {
auto opForKernelIt = llvm::find(it->second, actual);
if (opForKernelIt != it->second.end()) {
return true;
}

LLVM_DEBUG(llvm::dbgs() << "Kernel " << kernelNameAsStr(name)
<< " is not legal for op " << actual << ", requires "
<< resolvedOpName << "\n");
<< " is not legal for op " << actual
<< ", expected one of ops: "
<< llvm::join(it->second, ", ") << "\n");
return false;
}

Expand Down
69 changes: 64 additions & 5 deletions lib/Kernel/KernelImplementationTest.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <cstdint>
#include <functional>
#include <vector>

#include "gtest/gtest.h" // from @googletest
Expand All @@ -23,6 +24,8 @@ namespace {
// Parametrize over whether the kernel is rolled
class KernelImplementationTest : public testing::TestWithParam<bool> {};

using tensor4d = std::vector<std::vector<std::vector<std::vector<int>>>>;

TEST_P(KernelImplementationTest, TestHaleviShoupMatvec) {
std::vector<int> vector = {0, 1, 2, 3};
// Pre-packed diagonally
Expand Down Expand Up @@ -122,11 +125,6 @@ TEST_P(KernelImplementationTest, Test2DConvWithLayout) {
std::vector<std::vector<int>> packedData =
evaluateLayoutOnMatrix(dataLayout, data);

SmallVector<int64_t> strides = {1, 1};
auto filterLayout = get2dConvFilterRelation(filterType, dataType, strides, 0);
std::vector<std::vector<int>> packedFilter =
evaluateLayoutOnMatrix(filterLayout, matrix);

auto matrixLayout =
get2dConvFilterDiagonalizedRelation(filterType, dataType,
/*padding=*/0, numSlots)
Expand All @@ -152,6 +150,67 @@ TEST_P(KernelImplementationTest, Test2DConvWithLayout) {
EXPECT_EQ(extractedResult, expected);
}

TEST_P(KernelImplementationTest, Test2DNchwFchwConvWithLayout) {
  MLIRContext context;
  auto elementType = mlir::IndexType::get(&context);
  // 1x1x4x4 input data, 4x1x2x2 filter.
  RankedTensorType dataType = RankedTensorType::get({1, 1, 4, 4}, elementType);
  RankedTensorType filterType =
      RankedTensorType::get({4, 1, 2, 2}, elementType);

  int numSlots = 16;
  tensor4d input = {
      {{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}}};
  tensor4d kernels = {{{{7, 0}, {5, 4}}},
                      {{{9, 4}, {0, 3}}},
                      {{{7, 8}, {8, 6}}},
                      {{{6, 0}, {5, 4}}}};

  // Pack the input row-major into ciphertext slots.
  std::function<int(const std::vector<int64_t>&)> lookupInput =
      [&](const std::vector<int64_t>& point) -> int {
    return input[point[0]][point[1]][point[2]][point[3]];
  };
  auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots);
  std::vector<std::vector<int>> packedData =
      evaluateLayout(dataLayout, lookupInput);

  // Pack the filter via its diagonalized CHW/FCHW relation.
  SmallVector<int64_t> strides = {2, 2};
  auto filterLayout = get2dConvChwFchwFilterDiagonalizedRelation(
      filterType, dataType, strides, 0, numSlots, true);
  ASSERT_TRUE(succeeded(filterLayout));
  std::function<int(const std::vector<int64_t>&)> lookupFilter =
      [&](const std::vector<int64_t>& point) -> int {
    return kernels[point[0]][point[1]][point[2]][point[3]];
  };
  std::vector<std::vector<int>> packedFilter =
      evaluateLayout(filterLayout.value(), lookupFilter);
  auto expandedFilterShape =
      get2dConvChwFchwFilterExpandedType(filterType, dataType, 0, strides);

  // The expected row-major layout of the 4 2x2 result tensors is:
  // [40, 72, 168, 200]
  // [19, 51, 147, 179]
  // [70, 128, 302, 360]
  // [40, 70, 160, 190]
  // But the result of the halevi-shoup kernel also includes gapping between
  // the values, so the expected vector is:
  std::vector<int> expected = {40, 19, 72, 51, 70, 40, 128, 70,
                               168, 147, 200, 179, 302, 160, 360, 190};

  LiteralValue matrixInput = packedFilter;
  LiteralValue vectorInput = packedData[0];

  auto dag = implementHaleviShoup(vectorInput, matrixInput,
                                  expandedFilterShape.getShape(),
                                  DagType::intTensor(32, {numSlots}),
                                  /*zeroDiagonals=*/{}, /*unroll=*/GetParam());
  LiteralValue actual = evalKernel(dag)[0];
  // Result is 4 2x2 tensors with a row-major layout.
  std::vector<int> observed = std::get<std::vector<int>>(actual.get());
  EXPECT_EQ(observed, expected);
}

TEST(KernelImplementationTest, BicyclicMatmul) {
MLIRContext context;
std::vector<std::vector<int>> matrixA = {
Expand Down
71 changes: 71 additions & 0 deletions lib/Transforms/LayoutOptimization/InterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace heir {
using tensor_ext::ConvertLayoutOp;
using tensor_ext::LayoutAttr;
static auto& kLayoutAttrName = tensor_ext::TensorExtDialect::kLayoutAttrName;
using ::mlir::linalg::Conv2DNchwFchwOp;
using ::mlir::linalg::MatmulOp;
using ::mlir::linalg::MatvecOp;

namespace {
Expand Down Expand Up @@ -109,6 +111,34 @@ struct MatmulHoistingImpl
}
};

// External model attaching the layout-conversion-hoisting interface to
// Conv2DNchwFchwOp. Currently a stub: every path yields no hoisters.
struct Conv2DNchwFchwHoistingImpl
    : public LayoutConversionHoistableOpInterface::ExternalModel<
          Conv2DNchwFchwHoistingImpl, Conv2DNchwFchwOp> {
  std::vector<Hoister> getHoisters(
      Operation* op, tensor_ext::ConvertLayoutOp convertLayoutOp) const {
    std::vector<Hoister> none;

    // Hoisting requires a kernel attribute to know which kernel is in use.
    auto kernelAttr = op->getAttrOfType<secret::KernelAttr>(
        secret::SecretDialect::kKernelAttrName);
    if (!kernelAttr) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Kernel attribute not found on op " << *op << "\n");
      return none;
    }

    // A layout attribute must also be present on the op.
    if (!op->hasAttr(tensor_ext::TensorExtDialect::kLayoutAttrName)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Layout attribute not found on op " << *op << "\n");
      return none;
    }

    // FIXME(#2800): add support for Conv2DNchwFchwOp hoisters - right now the
    // default matvec hoister assumes that there is only one input/output
    // vector.
    return none;
  }
};

} // namespace

Hoister createTrivialHoister(Operation* op) {
Expand Down Expand Up @@ -166,6 +196,46 @@ Hoister createPrecomposingMatvecHoister(linalg::MatvecOp op) {
};
}

// Returns a hoister that moves a layout conversion on the result of a
// Conv2DNchwFchwOp backwards through the op by precomposing the conversion
// with the filter's layout, keeping the kernel unchanged (the conv analogue
// of createPrecomposingMatvecHoister).
Hoister createPrecomposingConvHoister(linalg::Conv2DNchwFchwOp op) {
  return [op](ConvertLayoutOp convertLayoutOp) -> llvm::FailureOr<HoistResult> {
    HoistResult result;
    auto fromLayout = dyn_cast<LayoutAttr>(convertLayoutOp.getFromLayout());
    auto toLayout = dyn_cast<LayoutAttr>(convertLayoutOp.getToLayout());

    // Only LayoutAttr-based conversions can be precomposed.
    if (!fromLayout || !toLayout) return failure();

    // Operand order for Conv2DNchwFchwOp is:
    // 0: data
    // 1: filter
    // 2: init (output vector)
    FailureOr<Attribute> oldFilterLayoutRes =
        findAttributeAssociatedWith(op->getOperand(1), kLayoutAttrName);
    // Fail gracefully rather than asserting: asserts compile out in release
    // builds, and calling value() on a failed FailureOr is undefined behavior.
    if (failed(oldFilterLayoutRes)) return failure();
    LayoutAttr oldFilterLayout =
        dyn_cast<LayoutAttr>(oldFilterLayoutRes.value());
    if (!oldFilterLayout) return failure();

    result.convertLayoutOp = convertLayoutOp;
    result.newOutputLayout = toLayout;

    // The kernel is unchanged, so copy the existing kernel attr
    result.newKernel = op->getAttrOfType<secret::KernelAttr>(
                             secret::SecretDialect::kKernelAttrName)
                           .getName();

    // Precompose the conversion with the filter layout; the relation algebra
    // is the same composition used for the matvec hoister.
    presburger::IntegerRelation newFilterLayoutRelation =
        hoistConversionThroughMatvec(oldFilterLayout.getIntegerRelation(),
                                     fromLayout.getIntegerRelation(),
                                     toLayout.getIntegerRelation());
    Attribute newFilterLayout = LayoutAttr::getFromIntegerRelation(
        op->getContext(), newFilterLayoutRelation);
    // New operand order: data(0), filter(1), init(2)
    result.newInputLayouts =
        SmallVector<Attribute>{toLayout, newFilterLayout, toLayout};
    return result;
  };
}

void registerLayoutConversionHoistableInterface(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, arith::ArithDialect* dialect) {
arith::AddFOp::attachInterface<DoNothingHoistingImpl<arith::AddFOp>>(*ctx);
Expand All @@ -178,6 +248,7 @@ void registerLayoutConversionHoistableInterface(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, linalg::LinalgDialect* dialect) {
linalg::MatvecOp::attachInterface<MatvecHoistingImpl>(*ctx);
linalg::MatmulOp::attachInterface<MatmulHoistingImpl>(*ctx);
linalg::Conv2DNchwFchwOp::attachInterface<Conv2DNchwFchwHoistingImpl>(*ctx);
});
}

Expand Down
3 changes: 3 additions & 0 deletions lib/Transforms/LayoutOptimization/InterfaceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ Hoister createTrivialHoister(Operation* op);
/// to vecToLayout, while keeping the kernel the same.
Hoister createPrecomposingMatvecHoister(linalg::MatvecOp op);

/// Same as createPrecomposingMatvecHoister but for Conv2DNchwFchwOp.
Hoister createPrecomposingConvHoister(linalg::Conv2DNchwFchwOp op);

void registerLayoutConversionHoistableInterface(DialectRegistry& registry);

} // namespace heir
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/LayoutPropagation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cc_library(
"@heir//lib/Utils:AttributeUtils",
"@heir//lib/Utils/Layout:Convolution",
"@heir//lib/Utils/Layout:Hoisting",
"@heir//lib/Utils/Layout:IslConversion",
"@heir//lib/Utils/Layout:Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineAnalysis",
Expand Down
Loading
Loading