diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD index b0b34057e9..a1a48f64d5 100644 --- a/lib/Dialect/TensorExt/Transforms/BUILD +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -144,6 +144,7 @@ cc_library( ":Patterns", ":pass_inc_gen", "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Transforms/LayoutOptimization:Patterns", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TransformUtils", diff --git a/lib/Dialect/TensorExt/Transforms/FoldConvertLayoutIntoAssignLayout.cpp b/lib/Dialect/TensorExt/Transforms/FoldConvertLayoutIntoAssignLayout.cpp index e839b3e2b5..60b69cc012 100644 --- a/lib/Dialect/TensorExt/Transforms/FoldConvertLayoutIntoAssignLayout.cpp +++ b/lib/Dialect/TensorExt/Transforms/FoldConvertLayoutIntoAssignLayout.cpp @@ -3,6 +3,7 @@ #include #include "lib/Dialect/TensorExt/Transforms/Patterns.h" +#include "lib/Transforms/LayoutOptimization/Patterns.h" #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -23,7 +24,8 @@ struct FoldConvertLayoutIntoAssignLayout void runOnOperation() override { MLIRContext* context = &getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/lib/Kernel/Kernel.cpp b/lib/Kernel/Kernel.cpp index ab6b87896d..2763b0cf69 100644 --- a/lib/Kernel/Kernel.cpp +++ b/lib/Kernel/Kernel.cpp @@ -4,8 +4,11 @@ #include #include #include +#include #include "lib/Kernel/KernelName.h" +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project @@ -18,13 +21,15 @@ namespace mlir { namespace heir { namespace { -static std::unordered_map correspondingOp = { - {KernelName::MatvecNaive, "linalg.matvec"}, - {KernelName::MatvecDiagonal, "linalg.matvec"}, - {KernelName::VecmatDiagonal, "linalg.vecmat"}, - {KernelName::MatmulDiagonal, "linalg.matmul"}, - {KernelName::MatmulDiagonal, "linalg.conv2d"}, - {KernelName::MatmulBicyclic, "linalg.matmul"}, +static std::unordered_map> + correspondingOp = { + {KernelName::MatvecNaive, {"linalg.matvec"}}, + {KernelName::MatvecDiagonal, + {"linalg.matvec", "linalg.conv_2d_nchw_fchw"}}, + {KernelName::VecmatDiagonal, {"linalg.vecmat"}}, + {KernelName::MatmulDiagonal, {"linalg.matmul"}}, + {KernelName::MatmulDiagonal, {"linalg.conv2d"}}, + {KernelName::MatmulBicyclic, {"linalg.matmul"}}, }; std::set requiredNontrivial = {"linalg"}; @@ -36,7 +41,8 @@ bool isSupportedKernel(Operation* op, KernelName name) { return requiredNontrivial.count(dialect) == 0; } - if (correspondingOp.find(name) == correspondingOp.end()) { + auto it = correspondingOp.find(name); + if (it == correspondingOp.end()) { LLVM_DEBUG(llvm::dbgs() << "Kernel name " << kernelNameAsStr(name) << "not found in correspondingOp legality map\n"); return false; @@ -46,14 +52,15 @@ bool isSupportedKernel(Operation* op, KernelName name) { llvm::raw_string_ostream ss(actual); ss << op->getName().getStringRef(); - std::string resolvedOpName = correspondingOp.at(name); - if (resolvedOpName == actual) { + auto opForKernelIt = llvm::find(it->second, actual); + if (opForKernelIt != it->second.end()) { return true; } LLVM_DEBUG(llvm::dbgs() << "Kernel " << kernelNameAsStr(name) - << " is not legal for op " << actual << ", requires " - << resolvedOpName << "\n"); + << " is not legal for op " << actual + << ", expected one of ops: " + << llvm::join(it->second, ", ") << "\n"); return false; } diff --git a/lib/Kernel/KernelImplementationTest.cpp b/lib/Kernel/KernelImplementationTest.cpp index 6c141ba22f..127f42b7c5 100644 --- a/lib/Kernel/KernelImplementationTest.cpp +++ b/lib/Kernel/KernelImplementationTest.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "gtest/gtest.h" // from @googletest @@ -23,6 +24,8 @@ namespace { // Parametrize over whether the kernel is rolled class KernelImplementationTest : public testing::TestWithParam {}; +using tensor4d = std::vector>>>; + TEST_P(KernelImplementationTest, TestHaleviShoupMatvec) { std::vector vector = {0, 1, 2, 3}; // Pre-packed diagonally @@ -122,11 +125,6 @@ TEST_P(KernelImplementationTest, Test2DConvWithLayout) { std::vector> packedData = evaluateLayoutOnMatrix(dataLayout, data); - SmallVector strides = {1, 1}; - auto filterLayout = get2dConvFilterRelation(filterType, dataType, strides, 0); - std::vector> packedFilter = - evaluateLayoutOnMatrix(filterLayout, matrix); - auto matrixLayout = get2dConvFilterDiagonalizedRelation(filterType, dataType, /*padding=*/0, numSlots) @@ -152,6 +150,67 @@ TEST_P(KernelImplementationTest, Test2DConvWithLayout) { EXPECT_EQ(extractedResult, expected); } +TEST_P(KernelImplementationTest, Test2DNchwFchwConvWithLayout) { + MLIRContext context; + RankedTensorType dataType = + RankedTensorType::get({1, 1, 4, 4}, mlir::IndexType::get(&context)); + RankedTensorType filterType = + RankedTensorType::get({4, 1, 2, 2}, mlir::IndexType::get(&context)); + + int numSlots = 16; + // 1x1x4x4 input data, 4x1x2x2 filter + tensor4d data = { + {{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}}}; + tensor4d filter = {{{{7, 0}, {5, 4}}}, + {{{9, 4}, {0, 3}}}, + {{{7, 8}, {8, 6}}}, + {{{6, 0}, {5, 4}}}}; + + std::function&)> getDataValueFn = + [&](const std::vector& domainPoint) -> int { + return data[domainPoint[0]][domainPoint[1]][domainPoint[2]][domainPoint[3]]; + }; + auto dataLayout = getRowMajorLayoutRelation(dataType, numSlots); + std::vector> packedData = + evaluateLayout(dataLayout, getDataValueFn); + + SmallVector strides = {2, 2}; + auto filterLayout = get2dConvChwFchwFilterDiagonalizedRelation( + filterType, dataType, strides, 0, numSlots, true); + ASSERT_TRUE(succeeded(filterLayout)); + std::function&)> getFilterValueFn = + [&](const std::vector& domainPoint) -> int { + return filter[domainPoint[0]][domainPoint[1]][domainPoint[2]] + [domainPoint[3]]; + }; + std::vector> packedFilter = + evaluateLayout(filterLayout.value(), getFilterValueFn); + auto expandedFilterShape = + get2dConvChwFchwFilterExpandedType(filterType, dataType, 0, strides); + + // The expected row-major layout of the 4 2x2 result tensors is: + // [40, 72, 168, 200] + // [19, 51, 147, 179] + // [70, 128, 302, 360] + // [40, 70, 160, 190] + // But the result of the halevi-shoup kernel also includes gapping between the + // values, so the expected vector is: + std::vector expected = {40, 19, 72, 51, 70, 40, 128, 70, + 168, 147, 200, 179, 302, 160, 360, 190}; + + LiteralValue matrixInput = packedFilter; + LiteralValue vectorInput = packedData[0]; + + auto dag = implementHaleviShoup(vectorInput, matrixInput, + expandedFilterShape.getShape(), + DagType::intTensor(32, {numSlots}), + /*zeroDiagonals=*/{}, /*unroll=*/GetParam()); + LiteralValue actual = evalKernel(dag)[0]; + // Result is 4 2x2 tensors with a row-major layout. + std::vector actualVector = std::get>(actual.get()); + EXPECT_EQ(actualVector, expected); +} + TEST(KernelImplementationTest, BicyclicMatmul) { MLIRContext context; std::vector> matrixA = { diff --git a/lib/Transforms/LayoutOptimization/InterfaceImpl.cpp b/lib/Transforms/LayoutOptimization/InterfaceImpl.cpp index ddee635307..5908ff79f3 100644 --- a/lib/Transforms/LayoutOptimization/InterfaceImpl.cpp +++ b/lib/Transforms/LayoutOptimization/InterfaceImpl.cpp @@ -33,6 +33,8 @@ namespace heir { using tensor_ext::ConvertLayoutOp; using tensor_ext::LayoutAttr; static auto& kLayoutAttrName = tensor_ext::TensorExtDialect::kLayoutAttrName; +using ::mlir::linalg::Conv2DNchwFchwOp; +using ::mlir::linalg::MatmulOp; using ::mlir::linalg::MatvecOp; namespace { @@ -109,6 +111,34 @@ struct MatmulHoistingImpl } }; +struct Conv2DNchwFchwHoistingImpl + : public LayoutConversionHoistableOpInterface::ExternalModel< + Conv2DNchwFchwHoistingImpl, Conv2DNchwFchwOp> { + std::vector getHoisters( + Operation* op, tensor_ext::ConvertLayoutOp convertLayoutOp) const { + std::vector hoisters; + + auto kernel = op->getAttrOfType( + secret::SecretDialect::kKernelAttrName); + if (!kernel) { + LLVM_DEBUG(llvm::dbgs() + << "Kernel attribute not found on op " << *op << "\n"); + return hoisters; + } + + if (!op->hasAttr(tensor_ext::TensorExtDialect::kLayoutAttrName)) { + LLVM_DEBUG(llvm::dbgs() + << "Layout attribute not found on op " << *op << "\n"); + return hoisters; + } + + // FIXME(#2800): add support for Conv2DNchwFchwOp hoisters - right now the + // default matvec hoister assumes that there is only one input/output + // vector. + return hoisters; + } +}; + } // namespace Hoister createTrivialHoister(Operation* op) { @@ -166,6 +196,46 @@ Hoister createPrecomposingMatvecHoister(linalg::MatvecOp op) { }; } +Hoister createPrecomposingConvHoister(linalg::Conv2DNchwFchwOp op) { + return [op](ConvertLayoutOp convertLayoutOp) -> llvm::FailureOr { + HoistResult result; + auto fromLayout = dyn_cast(convertLayoutOp.getFromLayout()); + auto toLayout = dyn_cast(convertLayoutOp.getToLayout()); + + if (!fromLayout || !toLayout) return failure(); + + // Operand order for Conv2DNchwFchwOp is: + // 0: data + // 1: filter + // 2: init (output vector) + FailureOr oldFilterLayoutRes = + findAttributeAssociatedWith(op->getOperand(1), kLayoutAttrName); + assert(succeeded(oldFilterLayoutRes) && "failed to find filter layout!"); + LayoutAttr oldFilterLayout = + dyn_cast(oldFilterLayoutRes.value()); + if (!oldFilterLayout) return failure(); + + result.convertLayoutOp = convertLayoutOp; + result.newOutputLayout = toLayout; + + // The kernel is unchanged, so copy the existing kernel attr + result.newKernel = op->getAttrOfType( + secret::SecretDialect::kKernelAttrName) + .getName(); + + presburger::IntegerRelation newFilterLayoutRelation = + hoistConversionThroughMatvec(oldFilterLayout.getIntegerRelation(), + fromLayout.getIntegerRelation(), + toLayout.getIntegerRelation()); + Attribute newFilterLayout = LayoutAttr::getFromIntegerRelation( + op->getContext(), newFilterLayoutRelation); + // New operand order: data(0), filter(1), init(2) + result.newInputLayouts = + SmallVector{toLayout, newFilterLayout, toLayout}; + return result; + }; +} + void registerLayoutConversionHoistableInterface(DialectRegistry& registry) { registry.addExtension(+[](MLIRContext* ctx, arith::ArithDialect* dialect) { arith::AddFOp::attachInterface>(*ctx); @@ -178,6 +248,7 @@ void registerLayoutConversionHoistableInterface(DialectRegistry& registry) { registry.addExtension(+[](MLIRContext* ctx, linalg::LinalgDialect* dialect) { linalg::MatvecOp::attachInterface(*ctx); linalg::MatmulOp::attachInterface(*ctx); + linalg::Conv2DNchwFchwOp::attachInterface(*ctx); }); } diff --git a/lib/Transforms/LayoutOptimization/InterfaceImpl.h b/lib/Transforms/LayoutOptimization/InterfaceImpl.h index 7331bffc32..1d974886cf 100644 --- a/lib/Transforms/LayoutOptimization/InterfaceImpl.h +++ b/lib/Transforms/LayoutOptimization/InterfaceImpl.h @@ -17,6 +17,9 @@ Hoister createTrivialHoister(Operation* op); /// to vecToLayout, while keeping the kernel the same. Hoister createPrecomposingMatvecHoister(linalg::MatvecOp op); +/// Same as createPrecomposingMatvecHoister but for Conv2DNchwFchwOp. +Hoister createPrecomposingConvHoister(linalg::Conv2DNchwFchwOp op); + void registerLayoutConversionHoistableInterface(DialectRegistry& registry); } // namespace heir diff --git a/lib/Transforms/LayoutPropagation/BUILD b/lib/Transforms/LayoutPropagation/BUILD index 9510fedba7..eb58dd10f7 100644 --- a/lib/Transforms/LayoutPropagation/BUILD +++ b/lib/Transforms/LayoutPropagation/BUILD @@ -23,6 +23,7 @@ cc_library( "@heir//lib/Utils:AttributeUtils", "@heir//lib/Utils/Layout:Convolution", "@heir//lib/Utils/Layout:Hoisting", + "@heir//lib/Utils/Layout:IslConversion", "@heir//lib/Utils/Layout:Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", diff --git a/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp b/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp index 7e11cfb016..5f6efdfcb5 100644 --- a/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp +++ b/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp @@ -22,6 +22,7 @@ #include "lib/Utils/AttributeUtils.h" #include "lib/Utils/Layout/Convolution.h" #include "lib/Utils/Layout/Hoisting.h" +#include "lib/Utils/Layout/IslConversion.h" #include "lib/Utils/Layout/Utils.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project @@ -45,6 +46,7 @@ #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project @@ -59,6 +61,7 @@ namespace mlir { namespace heir { +using linalg::Conv2DNchwFchwOp; using linalg::Conv2DOp; using linalg::MatmulOp; using linalg::MatvecOp; @@ -138,6 +141,7 @@ struct LayoutPropagation : impl::LayoutPropagationBase { LogicalResult visitOperation(GenericOp op); LogicalResult visitOperation(ReduceOp op); LogicalResult visitOperation(Conv2DOp op); + LogicalResult visitOperation(Conv2DNchwFchwOp op); LogicalResult visitOperation(VecmatOp op); LogicalResult visitOperation(MatvecOp op); LogicalResult visitOperation(MatmulOp op); @@ -158,6 +162,7 @@ struct LayoutPropagation : impl::LayoutPropagationBase { // Op-specific compatibility functions CompatibilityResult hasCompatibleArgumentLayouts(Conv2DOp op); + CompatibilityResult hasCompatibleArgumentLayouts(Conv2DNchwFchwOp op); CompatibilityResult hasCompatibleArgumentLayouts(ReduceOp op); CompatibilityResult hasCompatibleArgumentLayouts(VecmatOp op); CompatibilityResult hasCompatibleArgumentLayouts(MatvecOp op); @@ -284,7 +289,7 @@ LogicalResult LayoutPropagation::visitOperation(Operation* op) { // secret ops .Case([&](auto op) { return visitOperation(op); }) // linalg ops - .Case( + .Case( [&](auto op) { return visitOperation(op); }) // affine ops .Case([&](auto op) { return visitOperation(op); }) @@ -652,12 +657,123 @@ LogicalResult LayoutPropagation::visitOperation(Conv2DOp op) { assignedLayouts.insert({result, resultLayout}); setResultLayoutAttr(op); - debugAssignLayout(result, resultLayout); + auto kernelAttr = + secret::KernelAttr::get(ctx, KernelName::MatvecDiagonal, /*force=*/false); + op->setAttr(secret::SecretDialect::kKernelAttrName, kernelAttr); + + return success(); +} + +LogicalResult LayoutPropagation::visitOperation(Conv2DNchwFchwOp op) { + LLVM_DEBUG(llvm::dbgs() << "Specializing visitor on Conv2DNchwFchwOp\n"); + Value data = op.getInputs().front(); + Value filter = op.getInputs().back(); + auto dataType = cast(data.getType()); + auto filterType = cast(filter.getType()); + RankedTensorType outputType = + cast(op.getResult(0).getType()); + + SmallVector strides(op.getStrides().getValues().begin(), + op.getStrides().getValues().end()); + if (!llvm::all_equal(strides)) { + return op->emitOpError() << "Expected equal strides for Conv2DNchwFchwOp"; + } + + MLIRContext* ctx = &getContext(); + mlir::IRRewriter builder(ctx); + // Ensure data is in gapped row-major layout with current inputGap. + // We expect 4-D tensor (N, C, H, W) but only support N=1. + if (dataType.getRank() != 4 || dataType.getDimSize(0) != 1) { + return op->emitOpError() << "Expected 4-D data tensor (N=1, C, H, W)"; + } + + // Since the stride will be used as the gap factor, the layout requires that + // the number of output channels is divisible by gap^2. + // TODO(FIXME): handle padding the output channels to be divisible by gap^2. + if (outputType.getDimSize(1) % (strides[0] * strides[0]) != 0) { + return op->emitOpError() + << "Expected number of output channels to be divisible by gap^2"; + } + + LayoutAttr dataLayout = assignedLayouts.at(data); + IntegerRelation targetDataRelation = + getRowMajorLayoutRelation(dataType, ciphertextSize); + + if (!dataLayout.getIntegerRelation().isEqual(targetDataRelation)) { + LLVM_DEBUG(llvm::dbgs() << "conv_2d data input is not row major, " + "inserting layout conversion.\n"); + auto [toReplace, newDataLayoutAttr] = + convertToLayout(ctx, builder, op, data, dataLayout, targetDataRelation); + debugAssignLayout(toReplace, newDataLayoutAttr); + assignedLayouts.insert({toReplace, newDataLayoutAttr}); + } + + // The kernel for this operation requires expanding the conv filter matrix + // into a larger matrix and then diagonalizing. + LayoutAttr filterLayout = assignedLayouts.at(filter); + auto convRelation = get2dConvChwFchwFilterDiagonalizedRelation( + filterType, dataType, strides, /*padding=*/0, ciphertextSize); + if (failed(convRelation)) { + return failure(); + } + if (!isRelationEqual(filterLayout.getIntegerRelation(), + convRelation.value())) { + LLVM_DEBUG(llvm::dbgs() << "conv_2d filter input is not diagonalized, " + "inserting layout conversion.\n"); + + // Insert a layout conversion op to make the matrix layout expanded and + // squat diagonal + auto [toReplace, newFilterLayoutAttr] = convertToLayout( + ctx, builder, op, filter, filterLayout, convRelation.value()); + debugAssignLayout(toReplace, newFilterLayoutAttr); + assignedLayouts.insert({toReplace, newFilterLayoutAttr}); + } + + // Always one result. For a gapped output, the result will also have a + // pixel-shuffled gap. Convert the result back to a row-major layout for other + // users and let future passes handle propagating the gap. + auto result = op->getResult(0); + presburger::IntegerRelation rowMajorRelation = + getRowMajorLayoutRelation(outputType, ciphertextSize); + // FIXME: getRowInterchange just takes one to one mapping of the index + // remapping, but doesn't distribute them into ct, slot values. + presburger::IntegerRelation interchangeRelation = getRowInterchangeRelation( + outputType.getDimSize(1), outputType.getDimSize(2), + outputType.getDimSize(3), strides[0]); + interchangeRelation.insertVar(presburger::VarKind::Domain, 0); + interchangeRelation.insertVar(presburger::VarKind::Range, 0); + addConstraint( + interchangeRelation, + {{interchangeRelation.getVarKindOffset(presburger::VarKind::Domain) + 1, + -1}, + {interchangeRelation.getVarKindOffset(presburger::VarKind::Range) + 1, + 1}}, + /*equality=*/true); + auto gappedRelation = rowMajorRelation.clone(); + gappedRelation->compose(interchangeRelation); + LayoutAttr gappedLayoutAttr = + LayoutAttr::getFromIntegerRelation(ctx, *gappedRelation); + + assignedLayouts.insert({result, gappedLayoutAttr}); + setResultLayoutAttr(op); auto kernelAttr = secret::KernelAttr::get(ctx, KernelName::MatvecDiagonal, /*force=*/false); op->setAttr(secret::SecretDialect::kKernelAttrName, kernelAttr); + // Insert a layout conversion op to make the result row-major again. + LayoutAttr rowMajorLayoutAttr = + LayoutAttr::getFromIntegerRelation(ctx, rowMajorRelation); + builder.setInsertionPointAfter(op); + ConvertLayoutOp convertLayoutOp = ConvertLayoutOp::create( + builder, op->getLoc(), result, gappedLayoutAttr, rowMajorLayoutAttr); + convertLayoutOp->setAttr(tensor_ext::TensorExtDialect::kLayoutAttrName, + rowMajorLayoutAttr); + Value toReplace = convertLayoutOp.getResult(); + builder.replaceAllUsesExcept(result, toReplace, convertLayoutOp); + debugAssignLayout(toReplace, rowMajorLayoutAttr); + assignedLayouts.insert({toReplace, rowMajorLayoutAttr}); + return success(); } @@ -905,7 +1021,8 @@ CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( affine::AffineYieldOp>( [&](auto op) { return CompatibilityResult{true, std::nullopt}; }) // Ops with special rules - .Case( + .Case( [&](auto op) { return hasCompatibleArgumentLayouts(op); }) // By default, assume operands must all have the same layout. .Default([&](Operation* op) { @@ -1020,6 +1137,22 @@ CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( return {true, std::nullopt}; } +CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( + Conv2DNchwFchwOp op) { + // Currently only support secret data and plaintext filters. + Value data = op.getInputs().front(); + Value filter = op.getInputs().back(); + if (isSecret(filter, solver) || !isSecret(data, solver)) { + return {false, op->emitError("Only secret data and plaintext filters are " + "supported for linalg.conv_2d_nchw_fchw")}; + } + + if (!assignedLayouts.contains(data)) { + return {false, op->emitError("data operand has no assigned layout")}; + } + return {true, std::nullopt}; +} + CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( tensor::InsertSliceOp op) { // The arguments of a tensor::InsertSliceOp are the tensors to insert and the diff --git a/lib/Utils/Layout/BUILD b/lib/Utils/Layout/BUILD index d2529b1eea..1f71127772 100644 --- a/lib/Utils/Layout/BUILD +++ b/lib/Utils/Layout/BUILD @@ -66,6 +66,7 @@ cc_library( deps = [ ":IslConversion", ":Utils", + "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", diff --git a/lib/Utils/Layout/Convolution.cpp b/lib/Utils/Layout/Convolution.cpp index 03cc9fd24b..44dfca31a1 100644 --- a/lib/Utils/Layout/Convolution.cpp +++ b/lib/Utils/Layout/Convolution.cpp @@ -1,11 +1,13 @@ #include "lib/Utils/Layout/Convolution.h" #include +#include #include #include #include "lib/Utils/Layout/IslConversion.h" #include "lib/Utils/Layout/Utils.h" +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "mlir/include/mlir/Analysis/Presburger/IntegerRelation.h" // from @llvm-project #include "mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h" // from @llvm-project @@ -152,18 +154,43 @@ RankedTensorType get2dConvFilterExpandedType(RankedTensorType filterType, return RankedTensorType::get({rows, cols}, filterType.getElementType()); } +RankedTensorType get2dConvChwFchwFilterExpandedType(RankedTensorType filterType, + RankedTensorType dataType, + int64_t padding, + ArrayRef strides) { + // Get the filter relation for a single input and output channel and multiply + // the dimensions by the number of input and output channels for the row and + // column dimensions respectively. + RankedTensorType singleFilterType = RankedTensorType::get( + {filterType.getDimSize(2), filterType.getDimSize(3)}, + filterType.getElementType()); + RankedTensorType singleDataType = + RankedTensorType::get({dataType.getDimSize(2), dataType.getDimSize(3)}, + dataType.getElementType()); + auto singleResultType = get2dConvFilterExpandedType( + singleFilterType, singleDataType, padding, strides); + + int64_t inputChannels = dataType.getDimSize(1); + int64_t outputChannels = filterType.getDimSize(0); + + int64_t rows = outputChannels * singleResultType.getDimSize(0); + int64_t cols = inputChannels * singleResultType.getDimSize(1); + return RankedTensorType::get({rows, cols}, filterType.getElementType()); +} + presburger::IntegerRelation get2dConvChwFchwFilterRelation( RankedTensorType filterType, RankedTensorType dataType, ArrayRef strides, int64_t padding) { assert(filterType.getRank() == 4 && "expected 4-D filter matrix"); - assert(dataType.getRank() == 3 && "expected 3-D data matrix"); + assert(dataType.getRank() == 4 && "expected 4-D data matrix"); + assert(dataType.getDimSize(0) == 1 && "expected N=1 batch size"); // Get the filter relation for a single input and output channel. RankedTensorType singleFilterType = RankedTensorType::get( {filterType.getDimSize(2), filterType.getDimSize(3)}, filterType.getElementType()); RankedTensorType singleDataType = - RankedTensorType::get({dataType.getDimSize(1), dataType.getDimSize(2)}, + RankedTensorType::get({dataType.getDimSize(2), dataType.getDimSize(3)}, dataType.getElementType()); auto singleFilterRelation = get2dConvFilterRelation( singleFilterType, singleDataType, strides, padding); @@ -179,7 +206,7 @@ presburger::IntegerRelation get2dConvChwFchwFilterRelation( auto cDim = singleFilterRelation.getVarKindOffset(presburger::VarKind::Domain) + 1; - auto inputChannels = dataType.getDimSize(0); + auto inputChannels = dataType.getDimSize(1); auto outputChannels = filterType.getDimSize(0); assert(inputChannels == filterType.getDimSize(1) && "input channels must match filter input channels"); @@ -222,28 +249,40 @@ FailureOr get2dConvFilterDiagonalizedRelation( SmallVector strides = {1, 1}; auto expandedFilterRelation = get2dConvFilterRelation(filterType, dataType, strides, padding); - // Get size of the expanded filter matrix. - auto rowBound = expandedFilterRelation.getConstantBound64( - BoundType::UB, expandedFilterRelation.getVarKindOffset(VarKind::Range)); - if (!rowBound.has_value()) { - return failure(); - } - auto colBound = expandedFilterRelation.getConstantBound64( - BoundType::UB, - expandedFilterRelation.getVarKindOffset(VarKind::Range) + 1); - if (!colBound.has_value()) { - return failure(); - } - RankedTensorType expandedFilterType = - RankedTensorType::get({rowBound.value() + 1, colBound.value() + 1}, - filterType.getElementType()); - - auto diagonalizedFilterRelation = - getDiagonalLayoutRelation(expandedFilterType, ciphertextSize); + return diagonalize2dMatrix(expandedFilterRelation, filterType, + ciphertextSize); +} - // Compose these relations. - expandedFilterRelation.compose(diagonalizedFilterRelation); - return expandedFilterRelation; +FailureOr +get2dConvChwFchwFilterDiagonalizedRelation(RankedTensorType filterType, + RankedTensorType dataType, + ArrayRef strides, + int64_t padding, + int64_t ciphertextSize, + bool interchangeRows) { + auto expandedFilterRelation = + get2dConvChwFchwFilterRelation(filterType, dataType, strides, padding); + // Permutate the rows of the matrix to minimize the number of non-zero + // diagonals. + if (interchangeRows) { + auto rowInterchangeRelation = getRowInterchangeRelation( + filterType.getDimSize(0), filterType.getDimSize(2), + filterType.getDimSize(3), strides[0]); + rowInterchangeRelation.appendVar(presburger::VarKind::Domain); + rowInterchangeRelation.appendVar(presburger::VarKind::Range); + addConstraint( + rowInterchangeRelation, + {{rowInterchangeRelation.getVarKindOffset(presburger::VarKind::Domain) + + 1, + -1}, + {rowInterchangeRelation.getVarKindOffset(presburger::VarKind::Range) + + 1, + 1}}, + /*equality=*/true); + expandedFilterRelation.compose(rowInterchangeRelation); + } + return diagonalize2dMatrix(expandedFilterRelation, filterType, + ciphertextSize); } bool isRelation2dConvFilterDiagonalized( @@ -254,15 +293,7 @@ bool isRelation2dConvFilterDiagonalized( if (failed(diagonalizedRelation)) { return false; } - bool fastCheck = relation.isObviouslyEqual(diagonalizedRelation.value()); - if (fastCheck) return true; - - LogicalResult inequalityTest = - tryProveUnequal(diagonalizedRelation.value(), relation); - if (succeeded(inequalityTest)) return false; - - bool slowCheck = relation.isEqual(diagonalizedRelation.value()); - return slowCheck; + return isRelationEqual(relation, diagonalizedRelation.value()); } presburger::IntegerRelation getRowInterchangeRelation(int64_t c, int64_t h, @@ -275,22 +306,66 @@ presburger::IntegerRelation getRowInterchangeRelation(int64_t c, int64_t h, int64_t hOut = w * g; int64_t wOut = h * g; int64_t cOut = c / (g * g); + int64_t numElements = c * h * w; - // Domain: [idx_in] - // Range: [ct, slot] where ct=0 and slot=idx_out + // One to one mapping from idx_in to idx_out. std::string islStr = llvm::formatv( - "{{ [idx_in] -> [0, idx_out] : exists hi, wi, ci, ho, wo, co : " + "{{ [idx_in] -> [idx_out] : exists hi, wi, ci, ho, wo, co : " "0 <= hi < {0} and 0 <= wi < {1} and 0 <= ci < {2} and " "0 <= ho < {3} and 0 <= wo < {4} and 0 <= co < {5} and co = ci // " "{10}^2 and " "wo = wi * {10} + (ci % {10}) and " "ho = hi * {10} + (ci % {10}^2) // {10} and " "idx_in = wi + hi * {6} + ci * {7} and " - "idx_out = wo + ho * {8} + co * {9} }", - h, w, c, hOut, wOut, cOut, w, h * w, wOut, hOut * wOut, g); + "idx_out = wo + ho * {8} + co * {9} and 0 <= idx_in < {11} and 0 <= " + "idx_out < {11} }", + h, w, c, hOut, wOut, cOut, w, h * w, wOut, hOut * wOut, g, numElements); return getIntegerRelationFromIslStr(islStr).value(); } +presburger::IntegerRelation get2dConvResultRelation(RankedTensorType outputType, + ArrayRef strides, + int64_t padding, + int64_t ciphertextSize) { + assert(llvm::all_equal(strides) && "strides must be equal"); + + // First flatten the output tensor into a 1-D tensor of (ct, slot) where ct = + // 0 (set the "ciphertextSize" to be the same as the number of elements). This + // creates outputType -> [0, slot]. + auto flattenedOutput = + getRowMajorLayoutRelation(outputType, outputType.getNumElements()); + + // Create the interchange permutation [idx_in] -> [idx_out] and add a domain + // var = 0 to align with the range of the flattenedOutput relation. + int64_t c = outputType.getDimSize(1); + int64_t h = outputType.getDimSize(2); + int64_t w = outputType.getDimSize(3); + int64_t g = strides[0]; + auto rowInterchange = getRowInterchangeRelation(c, h, w, g); + rowInterchange.insertVar(presburger::VarKind::Domain, 0); + addConstraint( + rowInterchange, + {{rowInterchange.getVarKindOffset(presburger::VarKind::Domain), 1}, + {rowInterchange.getNumCols() - 1, 0}}, + /*equality=*/true); + // Compose the row interchange relation with the flattened output relation: + // [outputType] -> [0, slot] -> [slot']. + flattenedOutput.compose(rowInterchange); + + // Compose with the [slot'] -> [ct, slot] relation across multiple + // ciphertexts. + int64_t numCiphertexts = + std::ceil((float)outputType.getNumElements() / ciphertextSize); + std::string mapToCtSlot = llvm::formatv( + "{{ [idx_out] -> [ct, slot] : " + "0 <= ct < {0} and 0 <= slot < {1} and idx_out = ct * {1} + slot }", + numCiphertexts, ciphertextSize); + auto toCtSlot = getIntegerRelationFromIslStr(mapToCtSlot).value(); + flattenedOutput.compose(toCtSlot); + + return flattenedOutput; +} + } // namespace heir } // namespace mlir diff --git a/lib/Utils/Layout/Convolution.h b/lib/Utils/Layout/Convolution.h index 07a4339ce9..f453d0cdbb 100644 --- a/lib/Utils/Layout/Convolution.h +++ b/lib/Utils/Layout/Convolution.h @@ -24,6 +24,14 @@ RankedTensorType get2dConvFilterExpandedType( RankedTensorType filterType, RankedTensorType dataType, int64_t padding, ArrayRef strides = {1, 1}); +// Returns an IntegerRelation that expands a 2-D filter matrix used in a +// convolution into a 2-D matrix such that the convolution is +// equivalent a matrix product with the flattened input vector. Each row +// corresponds to one filter multiplication. +FailureOr get2dConvFilterDiagonalizedRelation( + RankedTensorType filterType, RankedTensorType dataType, int64_t padding, + int64_t ciphertextSize); + // Returns an IntegerRelation that expands a multichannel filter used // in a 2-D convolution into a 2-D Toeplitz matrix such that the convolution is // equivalent a matrix product with the flattened multichannel input vector. @@ -35,13 +43,23 @@ presburger::IntegerRelation get2dConvChwFchwFilterRelation( RankedTensorType filterType, RankedTensorType dataType, ArrayRef strides, int64_t padding); -// Returns an IntegerRelation that expands a 2-D filter matrix used in a -// convolution into a 2-D matrix such that the convolution is -// equivalent a matrix product with the flattened input vector. Each row -// corresponds to one filter multiplication. -FailureOr get2dConvFilterDiagonalizedRelation( +// Returns an IntegerRelation that represents a diagonalized 2-D Toeplitz matrix +// that is used to compute a 2-D multichannel convolution filter such that the +// convolution is equivalent a matrix product with the flattened multichannel +// input vector. Each row corresponds to one filter multiplication. The filter +// type is assumed to be 4-D with dimensions (f, c, h, w) and the data type is +// assumed to be 3-D with dimensions (c, h, w). +FailureOr +get2dConvChwFchwFilterDiagonalizedRelation(RankedTensorType filterType, + RankedTensorType dataType, + ArrayRef strides, + int64_t padding, + int64_t ciphertextSize, + bool interchangeRows = true); + +RankedTensorType get2dConvChwFchwFilterExpandedType( RankedTensorType filterType, RankedTensorType dataType, int64_t padding, - int64_t ciphertextSize); + ArrayRef strides = {1, 1}); // Returns an IntegerRelation for a row-interchange map that optimizes the // diagonal structure of a convolution's Toeplitz matrix. @@ -59,6 +77,14 @@ FailureOr get2dConvFilterDiagonalizedRelation( presburger::IntegerRelation getRowInterchangeRelation(int64_t c, int64_t h, int64_t w, int64_t g); +// Returns an IntegerRelation that corresponds to the output layout of a 2-D +// multi-channel convolution. This includes the row interchange from pixel +// shuffling. The result is a relation mapping to (ct, slot) of the output. +presburger::IntegerRelation get2dConvResultRelation(RankedTensorType outputType, + ArrayRef strides, + int64_t padding, + int64_t ciphertextSize); + bool isRelation2dConvFilterDiagonalized( RankedTensorType filterType, RankedTensorType dataType, int64_t padding, int64_t ciphertextSize, const presburger::IntegerRelation& relation); diff --git a/lib/Utils/Layout/ConvolutionTest.cpp b/lib/Utils/Layout/ConvolutionTest.cpp index d8b6a227a6..089fd032c5 100644 --- a/lib/Utils/Layout/ConvolutionTest.cpp +++ b/lib/Utils/Layout/ConvolutionTest.cpp @@ -1,9 +1,11 @@ #include +#include #include #include "gtest/gtest.h" // from @googletest #include "lib/Utils/Layout/Convolution.h" #include "lib/Utils/Layout/Evaluate.h" +#include "lib/Utils/Layout/Utils.h" #include "mlir/include/mlir/Analysis/Presburger/IntegerRelation.h" // from @llvm-project #include "mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project @@ -190,7 +192,7 @@ TEST(ConvolutionTest, ConvChwFchwFilterRelation) { RankedTensorType filterType = RankedTensorType::get({2, 2, 3, 3}, IndexType::get(&context)); RankedTensorType dataType = - RankedTensorType::get({2, 3, 3}, IndexType::get(&context)); + RankedTensorType::get({1, 2, 3, 3}, IndexType::get(&context)); SmallVector strides = {1, 1}; int64_t padding = 1; IntegerRelation rel = @@ -214,7 +216,7 @@ TEST(ConvolutionTest, ConvChwFchwNoPaddingFilterRelation) { RankedTensorType filterType = RankedTensorType::get({2, 2, 2, 2}, IndexType::get(&context)); RankedTensorType dataType = - RankedTensorType::get({2, 4, 4}, IndexType::get(&context)); + RankedTensorType::get({1, 2, 4, 4}, IndexType::get(&context)); SmallVector strides = {2, 2}; int64_t padding = 0; IntegerRelation rel = @@ -244,7 +246,7 @@ TEST(ConvolutionTest, ConvChwFchwFilterRelationUnequalStrides) { RankedTensorType filterType = RankedTensorType::get({2, 2, 3, 3}, IndexType::get(&context)); RankedTensorType dataType = - RankedTensorType::get({2, 5, 5}, IndexType::get(&context)); + RankedTensorType::get({1, 2, 5, 5}, IndexType::get(&context)); SmallVector strides = {2, 3}; int64_t padding = 0; IntegerRelation rel = @@ -275,7 +277,7 @@ TEST(ConvolutionTest, ConvChwFchwFilterRelationPadding) { RankedTensorType filterType = RankedTensorType::get({2, 2, 3, 3}, IndexType::get(&context)); RankedTensorType dataType = - RankedTensorType::get({2, 3, 3}, IndexType::get(&context)); + RankedTensorType::get({1, 2, 3, 3}, IndexType::get(&context)); SmallVector strides = {2, 2}; int64_t padding = 1; IntegerRelation rel = @@ -294,6 +296,7 @@ TEST(ConvolutionTest, ConvChwFchwFilterRelationPadding) { // singleColSize = 9, c=2 -> slotBound = 1 * 9 + 8 = 17 EXPECT_EQ(slotBound.value(), 17); } + TEST(ConvolutionTest, TestRowInterchange) { MLIRContext context; // c=4, h=2, w=2, g=2 @@ -301,14 +304,22 @@ TEST(ConvolutionTest, TestRowInterchange) { std::vector input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - auto result = evaluateLayoutOnVector(rel, input); - - ASSERT_EQ(result.size(), 1); - // Expected permutation: [0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, - // 15] std::vector expectedPermutation = {0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}; - EXPECT_EQ(result[0], expectedPermutation); + PointPairCollector collector(1, 1); // 1 domain dim, 1 range dim + enumeratePoints(rel, collector); + + EXPECT_EQ(collector.points.size(), expectedPermutation.size()); + + for (const auto& actualPoint : collector.points) { + // The permutation in the relation is the expected (i -> j) mappings. + auto startVal = actualPoint.first[0]; + auto permuteIdx = actualPoint.second[0]; + auto resultingVal = expectedPermutation[permuteIdx]; + EXPECT_EQ(startVal, resultingVal) + << "Point not found: domain=" << actualPoint.first[0] + << ", range=" << actualPoint.second[0]; + } } TEST(ConvolutionTest, TestRowInterchangeMultiChannel) { @@ -317,26 +328,28 @@ TEST(ConvolutionTest, TestRowInterchangeMultiChannel) { // Input: 2x2x18 = 72 elements. Output: 6x6x2 IntegerRelation rel = getRowInterchangeRelation(18, 2, 2, 3); - std::vector input(72); - for (int i = 0; i < 72; ++i) input[i] = i; - auto result = evaluateLayoutOnVector(rel, input); - - ASSERT_EQ(result.size(), 1); - ASSERT_EQ(result[0].size(), 72); - - // expected{i} is the flattened output channel i - std::vector expected0 = {0, 4, 8, 1, 5, 9, 12, 16, 20, 13, 17, 21, - 24, 28, 32, 25, 29, 33, 2, 6, 10, 3, 7, 11, - 14, 18, 22, 15, 19, 23, 26, 30, 34, 27, 31, 35}; - std::vector expected1 = {36, 40, 44, 37, 41, 45, 48, 52, 56, 49, 53, 57, - 60, 64, 68, 61, 65, 69, 38, 42, 46, 39, 43, 47, - 50, 54, 58, 51, 55, 59, 62, 66, 70, 63, 67, 71}; - - std::vector expectedAll; - expectedAll.insert(expectedAll.end(), expected0.begin(), expected0.end()); - expectedAll.insert(expectedAll.end(), expected1.begin(), expected1.end()); - - EXPECT_EQ(result[0], expectedAll); + PointPairCollector collector(1, 1); // 1 domain dim, 1 range dim + enumeratePoints(rel, collector); + + EXPECT_EQ(collector.points.size(), 72); + + // expected contains all the flattened output channels in order + std::vector expected = { + 0, 4, 8, 1, 5, 9, 12, 16, 20, 13, 17, 21, 24, 28, 32, 25, 29, 33, + 2, 6, 10, 3, 7, 11, 14, 18, 22, 15, 19, 23, 26, 30, 34, 27, 31, 35, + 36, 40, 44, 37, 41, 45, 48, 52, 56, 49, 53, 57, 60, 64, 68, 61, 65, 69, + 38, 42, 46, 39, 43, 47, 50, 54, 58, 51, 55, 59, 62, 66, 70, 63, 67, 71}; + + for (const auto& actualPoint : collector.points) { + // The permutation in the relation is the expected (i -> j) mappings. + auto startVal = actualPoint.first[0]; + auto permuteIdx = actualPoint.second[0]; + auto resultingVal = expected[permuteIdx]; + EXPECT_EQ(startVal, resultingVal) + << "Point not found: domain=" << actualPoint.first[0] + << ", range=" << actualPoint.second[0]; + ; + } } } // namespace diff --git a/lib/Utils/Layout/EvaluateTest.cpp b/lib/Utils/Layout/EvaluateTest.cpp index 2949bdb4fc..e6ebd789fe 100644 --- a/lib/Utils/Layout/EvaluateTest.cpp +++ b/lib/Utils/Layout/EvaluateTest.cpp @@ -189,7 +189,7 @@ TEST(EvaluateTest, EvaluateLayoutFor2DConvChwFchw) { RankedTensorType filterType = RankedTensorType::get({2, 2, 3, 3}, IndexType::get(&context)); RankedTensorType dataType = - RankedTensorType::get({2, 3, 3}, IndexType::get(&context)); + RankedTensorType::get({1, 2, 3, 3}, IndexType::get(&context)); SmallVector strides = {1, 1}; int64_t padding = 1; IntegerRelation rel = @@ -236,7 +236,7 @@ TEST(EvaluateTest, EvaluateLayoutFor2DConvChwFchwNoPadding) { RankedTensorType filterType = RankedTensorType::get({2, 2, 2, 2}, IndexType::get(&context)); RankedTensorType dataType = - RankedTensorType::get({2, 4, 4}, IndexType::get(&context)); + RankedTensorType::get({1, 2, 4, 4}, IndexType::get(&context)); SmallVector strides = {2, 2}; int64_t padding = 0; IntegerRelation rel = @@ -275,6 +275,135 @@ TEST(EvaluateTest, EvaluateLayoutFor2DConvChwFchwNoPadding) { ASSERT_THAT(result, Eq(expected)); } +TEST(EvaluateTest, EvaluateLayoutFor2DConvChwFchwNoPaddingDiagonalized) { + MLIRContext context; + // Filter 2x2 and data is 4x4 so there are 2x2 sliding windows. + RankedTensorType filterType = + RankedTensorType::get({4, 1, 2, 2}, IndexType::get(&context)); + RankedTensorType dataType = + RankedTensorType::get({1, 1, 4, 4}, IndexType::get(&context)); + SmallVector strides = {2, 2}; + int64_t padding = 0; + auto rel = get2dConvChwFchwFilterDiagonalizedRelation( + filterType, dataType, strides, padding, 16, false); + ASSERT_TRUE(succeeded(rel)); + + std::vector>>> filter = { + {{{1, 2}, {3, 4}}}, // Channel 0 + {{{1, 2}, {3, 4}}}, // Channel 1 + {{{1, 2}, {3, 4}}}, // Channel 2 + {{{1, 2}, {3, 4}}} // Channel 3 + }; + std::function&)> getValueFn = + [&](const std::vector& domainPoint) -> int { + return filter[domainPoint[0]][domainPoint[1]][domainPoint[2]] + [domainPoint[3]]; + }; + + auto result = evaluateLayout(rel.value(), getValueFn); + + std::vector> expected = { + {1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 4}, + {2, 1, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 2, 0, 0, 0, 4, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 4, 3, 0, 0, 0, 0}, + {3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 4, 1, 0, 0, 0}, + {4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0}, + {0, 4, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0}, + {0, 0, 2, 1, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 2, 0, 0, 0, 4, 1, 0, 0, 0, 3, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 4, 3, 0, 0}, + {0, 0, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 4, 1, 0}, + {0, 0, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1}, + {0, 0, 0, 4, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 2}, + {0, 0, 0, 0, 2, 1, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 2, 0, 0, 0, 4, 1, 0, 0, 0, 3, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 4, 3}}; + EXPECT_THAT(result, Eq(expected)); + + // Now test minimal non zero diagonals + auto relOptimized = get2dConvChwFchwFilterDiagonalizedRelation( + filterType, dataType, strides, padding, 16, true); + ASSERT_TRUE(succeeded(relOptimized)); + auto resultOptimized = evaluateLayout(relOptimized.value(), getValueFn); + + std::vector> expectedOptimized = { + {1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, + {2, 0, 2, 0, 4, 0, 4, 0, 2, 0, 2, 0, 4, 0, 4, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 3, 0, 3, 0, 0, 0, 0, 0, 3, 0, 3, 0, 0, 0, 0}, + {3, 4, 3, 4, 0, 0, 0, 0, 3, 4, 3, 4, 0, 0, 0, 0}, + {4, 0, 4, 0, 0, 0, 0, 0, 4, 0, 4, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1}, + {0, 0, 0, 0, 1, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 2}, + {0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 1, 0, 1, 0, 3, 0, 3, 0, 1, 0, 1, 0, 3, 0, 3}}; + EXPECT_THAT(resultOptimized, Eq(expectedOptimized)); +} + +TEST(EvaluateTest, Conv2dResultRelation) { + MLIRContext context; + RankedTensorType outputType = + RankedTensorType::get({1, 4, 2, 2}, IndexType::get(&context)); + SmallVector strides = {2, 2}; + int64_t padding = 0; + + // Fits in one ciphertext + int64_t ciphertextSize = 16; + IntegerRelation rel = + get2dConvResultRelation(outputType, strides, padding, ciphertextSize); + EXPECT_EQ(rel.getNumDomainVars(), outputType.getRank()); + EXPECT_EQ(rel.getNumRangeVars(), 2); + + std::vector>>> output = { + {{{1, 2}, {3, 4}}, + {{5, 6}, {7, 8}}, + {{9, 10}, {11, 12}}, + {{13, 14}, {15, 16}}}}; + std::function&)> getValueFn = + [&](const std::vector& domainPoint) -> int { + return output[domainPoint[0]][domainPoint[1]][domainPoint[2]] + [domainPoint[3]]; + }; + auto result = evaluateLayout(rel, getValueFn); + std::vector> expected = { + {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}}; + EXPECT_THAT(result, Eq(expected)); +} + +TEST(EvaluateTest, Conv2dResultRelationTwoCiphertexts) { + MLIRContext context; + RankedTensorType outputType = + RankedTensorType::get({1, 4, 2, 2}, IndexType::get(&context)); + SmallVector strides = {2, 2}; + int64_t padding = 0; + + int64_t ciphertextSize = 8; + IntegerRelation rel = + get2dConvResultRelation(outputType, strides, padding, ciphertextSize); + + std::vector>>> output = { + {{{1, 2}, {3, 4}}, + {{5, 6}, {7, 8}}, + {{9, 10}, {11, 12}}, + {{13, 14}, {15, 16}}}}; + std::function&)> getValueFn = + [&](const std::vector& domainPoint) -> int { + return output[domainPoint[0]][domainPoint[1]][domainPoint[2]] + [domainPoint[3]]; + }; + auto result = evaluateLayout(rel, getValueFn); + std::vector> expected = {{1, 5, 2, 6, 9, 13, 10, 14}, + {3, 7, 4, 8, 11, 15, 12, 16}}; + EXPECT_THAT(result, Eq(expected)); +} + } // namespace } // namespace heir } // namespace mlir diff --git a/lib/Utils/Layout/Utils.cpp b/lib/Utils/Layout/Utils.cpp index 3a1c698eb9..80cca1adf1 100644 --- a/lib/Utils/Layout/Utils.cpp +++ b/lib/Utils/Layout/Utils.cpp @@ -336,6 +336,27 @@ presburger::IntegerRelation getDiagonalLayoutRelation( return result; } +FailureOr diagonalize2dMatrix( + presburger::IntegerRelation relation, RankedTensorType originalType, + int64_t ciphertextSize) { + // Get size of the matrix. + auto rowBound = relation.getConstantBound64( + BoundType::UB, relation.getVarKindOffset(VarKind::Range)); + auto colBound = relation.getConstantBound64( + BoundType::UB, relation.getVarKindOffset(VarKind::Range) + 1); + if (!rowBound.has_value() || !colBound.has_value()) { + return failure(); + } + RankedTensorType matrixType = + RankedTensorType::get({rowBound.value() + 1, colBound.value() + 1}, + originalType.getElementType()); + auto diagonalRelation = getDiagonalLayoutRelation(matrixType, ciphertextSize); + + // Compose these relations. + relation.compose(diagonalRelation); + return relation; +} + presburger::IntegerRelation getBicyclicLayoutRelation( RankedTensorType matrixType, int64_t numSlots) { unsigned int rows = matrixType.getDimSize(0); @@ -892,5 +913,17 @@ FailureOr getSliceExtractionRelation( return result; } +bool isRelationEqual(const presburger::IntegerRelation& relation1, + const presburger::IntegerRelation& relation2) { + bool fastCheck = relation1.isObviouslyEqual(relation2); + if (fastCheck) return true; + + LogicalResult inequalityTest = tryProveUnequal(relation2, relation1); + if (succeeded(inequalityTest)) return false; + + bool slowCheck = relation1.isEqual(relation2); + return slowCheck; +} + } // namespace heir } // namespace mlir diff --git a/lib/Utils/Layout/Utils.h b/lib/Utils/Layout/Utils.h index 3168a7febd..7357473d10 100644 --- a/lib/Utils/Layout/Utils.h +++ b/lib/Utils/Layout/Utils.h @@ -73,6 +73,11 @@ presburger::IntegerRelation getRowMajorLayoutRelation( presburger::IntegerRelation getDiagonalLayoutRelation( RankedTensorType matrixType, int64_t ciphertextSize); +// Applies a diagonal layout onto a given 2-D matrix layout. +FailureOr diagonalize2dMatrix( + presburger::IntegerRelation relation, RankedTensorType originalType, + int64_t ciphertextSize); + // Returns an IntegerRelation that represents a bicyclic layout for a matrix. // See https://eprint.iacr.org/2024/1762 for details. presburger::IntegerRelation getBicyclicLayoutRelation( @@ -226,6 +231,9 @@ FailureOr getSliceExtractionRelation( SmallVector offsets, SmallVector sizes, SmallVector strides); +bool isRelationEqual(const presburger::IntegerRelation& relation1, + const presburger::IntegerRelation& relation2); + } // namespace heir } // namespace mlir diff --git a/tests/Transforms/layout_optimization/hoist_conv.mlir b/tests/Transforms/layout_optimization/hoist_conv.mlir new file mode 100644 index 0000000000..d7ba985b3c --- /dev/null +++ b/tests/Transforms/layout_optimization/hoist_conv.mlir @@ -0,0 +1,28 @@ +// RUN: heir-opt --layout-optimization --canonicalize %s | FileCheck %s + +#data_layout = #tensor_ext.layout<"{ [d0, d1, d2, d3] -> [0, slot] : 0 <= d0 < 1 and 0 <= d1 < 1 and 0 <= d2 < 4 and 0 <= d3 < 4 and slot = d3 + 4 * d2 + 16 * d1 + 16 * d0 }"> +#data_layout_2 = #tensor_ext.layout<"{ [d0, d1, d2, d3] -> [0, slot] : 0 <= d0 < 1 and 0 <= d1 < 1 and 0 <= d2 < 4 and 0 <= d3 < 4 and slot = d2 + 4 * d3 + 16 * d1 + 16 * d0 }"> +// Filter layout is large but we check that it changes +#filter_layout = #tensor_ext.layout<"{ [d0, d1, d2, d3] -> [0, slot] : 0 <= d0 < 1 and 0 <= d1 < 1 and 0 <= d2 < 2 and 0 <= d3 < 2 and slot = d3 + 2 * d2 + 4 * d1 + 4 * d0 }"> + +// CHECK: func.func @hoist_conv +func.func @hoist_conv(%arg0: !secret.secret> {tensor_ext.layout = #data_layout}, %arg1: tensor<1x1x2x2xf32>) -> (!secret.secret> {tensor_ext.layout = #data_layout_2}) { + %cst = arith.constant dense<0.000000e+00> : tensor<1x1x3x3xf32> + %1 = tensor_ext.assign_layout %arg1 {layout = #filter_layout, tensor_ext.layout = #filter_layout} : tensor<1x1x2x2xf32> + %2 = tensor_ext.assign_layout %cst {layout = #data_layout, tensor_ext.layout = #data_layout} : tensor<1x1x3x3xf32> + + // CHECK: secret.generic + // CHECK-NOT: tensor_ext.convert_layout + %3 = secret.generic(%arg0 : !secret.secret>) { + ^body(%input0: tensor<1x1x4x4xf32>): + %4 = linalg.conv_2d_nchw_fchw + { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>, + secret.kernel = #secret.kernel, + tensor_ext.layout = #data_layout } + ins(%input0, %1 : tensor<1x1x4x4xf32>, tensor<1x1x2x2xf32>) + outs(%2 : tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> + %5 = tensor_ext.convert_layout %4 {from_layout = #data_layout, tensor_ext.layout = #data_layout_2, to_layout = #data_layout_2} : tensor<1x1x3x3xf32> + secret.yield %5 : tensor<1x1x3x3xf32> + } -> (!secret.secret> {tensor_ext.layout = #data_layout_2}) + return %3 : !secret.secret> +} diff --git a/tests/Transforms/layout_propagation/conv2d_nchw.mlir b/tests/Transforms/layout_propagation/conv2d_nchw.mlir new file mode 100644 index 0000000000..a6221502a6 --- /dev/null +++ b/tests/Transforms/layout_propagation/conv2d_nchw.mlir @@ -0,0 +1,32 @@ +// RUN: heir-opt --layout-propagation %s | FileCheck %s + +// Input layout is a flattened row-major layout. +// CHECK-DAG: #[[layout1:.*]] = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-10i2 - i3 + slot) mod 128 = 0 and 0 <= i2 <= 9 and 0 <= i3 <= 9 and 0 <= slot <= 1023 }"> +// CHECK-DAG: #[[layout:.*]] = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and ct = 0 and (-25i1 - 5i2 - i3 + slot) mod 128 = 0 and 0 <= i1 <= 3 and 0 <= i2 <= 4 and 0 <= i3 <= 1023 - 25i1 - 5i2 and i3 <= 4 and 0 <= slot <= 1023 and 1024*floor((-128 + 25i1 + 5i2 + i3)/1024) <= -1024 + 25i1 + 5i2 + i3 }"> +// CHECK-DAG: #[[layout3:.*]] = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : exists (e1, e2, e3: i0 = 0 and (-25i1 - 5i2 - i3 + slot) mod 128 = 0 and 0 <= i1 <= 3 and 0 <= i2 <= 4 and 0 <= i3 <= 1023 - 25i1 - 5i2 and i3 <= 4 and 0 <= slot <= 99 and 1024*floor((-128 + 25i1 + 5i2 + i3)/1024) <= -1024 + 25i1 + 5i2 + i3 and 0 <= e1 <= 4 and slot <= 2e2 <= 3 + slot and -4 + 26slot - 5e1 <= 50e2 <= 26slot - 5e1 and -1 - slot + 4e1 + 2e2 <= 2e3 <= -slot + 4e1 + 2e2 and -1 - 51slot + 10e1 + 100e2 <= 10e3 <= -51slot + 10e1 + 100e2) }"> +// CHECK-DAG: #kernel = #secret.kernel + +// CHECK: @conv2d_nchw +// CHECK-SAME: %[[arg0:.*]]: !secret.secret> {tensor_ext.layout = #[[layout1]]} +func.func @conv2d_nchw(%arg0: !secret.secret>) -> !secret.secret> { + %cst = arith.constant dense<0.000000e+00> : tensor<1x4x5x5xf32> + %filter = arith.constant dense<2.500000e-01> : tensor<4x1x2x2xf32> + + // CHECK: %[[res:.*]] = secret.generic + %0 = secret.generic(%arg0 : !secret.secret>) { + ^body(%input0: tensor<1x1x10x10xf32>): + // CHECK: linalg.conv_2d_nchw_fchw + // CHECK-SAME: secret.kernel = #kernel + // CHECK-SAME: strides = dense<2> : tensor<2xi64> + // CHECK-SAME: tensor_ext.layout = #[[layout3]] + %1 = linalg.conv_2d_nchw_fchw + { dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> } + ins(%input0, %filter : tensor<1x1x10x10xf32>, tensor<4x1x2x2xf32>) + outs(%cst : tensor<1x4x5x5xf32>) -> tensor<1x4x5x5xf32> + secret.yield %1 : tensor<1x4x5x5xf32> + // CHECK: convert_layout + // CHECK: secret.yield + // CHECK-NEXT: {tensor_ext.layout = #[[layout]]} + } -> !secret.secret> + return %0 : !secret.secret> +}