Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2422,7 +2422,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
while (srcShape[ref.back()] == 1 && ref.size() > 1)
ref = ref.drop_back();
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
auto precedingRef = ref.drop_back();
bool allUnitPreceding = llvm::all_of(
precedingRef, [&srcShape](int idx) { return srcShape[idx] == 1; });
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1 ||
allUnitPreceding) {
resultStrides.push_back(srcStrides[ref.back()]);
} else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Dialect/MemRef/collapse-strided.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: mlir-opt %s | FileCheck %s

// CHECK-LABEL: test_collapse(
func.func @test_collapse(%arg0: memref<1x?xf32, strided<[5, 1]>>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : memref<1x?xf32, strided<[5, 1]>> into memref<?xf32, strided<[1]>>
return
}

// CHECK-LABEL: test_collapse_5d_middle_dynamic(
func.func @test_collapse_5d_middle_dynamic(%arg0: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]]
: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<?xf32, strided<[?]>>
return
}

// CHECK-LABEL: test_collapse_5d_mostly_units(
func.func @test_collapse_5d_mostly_units(%arg0: memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]]
: memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>> into memref<?xf32, strided<[2]>>
return
}

// CHECK-LABEL: test_partial_collapse_6d(
func.func @test_partial_collapse_6d(%arg0: memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3], [4, 5]]
: memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>> into memref<?x5xf32, strided<[420, 7]>>
return
}

// CHECK-LABEL: test_collapse_5d_grouped(
func.func @test_collapse_5d_grouped(%arg0: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>) {
%collapse_shape = memref.collapse_shape %arg0 [[0], [1, 2, 3, 4]]
: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<1x?xf32, strided<[540, ?]>>
return
}

// CHECK-LABEL: test_collapse_all_units(
func.func @test_collapse_all_units(%arg0: memref<1x1x1x1x1xf32, strided<[100, 50, 25, 10, 1]>>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]]
: memref<1x1x1x1x1xf32, strided<[100, 50, 25, 10, 1]>> into memref<1xf32, strided<[100]>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I am a bit confused. It can also have stride 1. The definition of the stride here is amorphic since there is only one element.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be trivial and irrelevant, because dimension with size 1 will be ignored anyway

return
}