Skip to content

[mlir][spirv] Use assemblyFormat to define AccessChainOp assembly #116545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
let builders = [OpBuilder<(ins "Value":$basePtr, "ValueRange":$indices)>];

let hasCanonicalizer = 1;

let hasCustomAssemblyFormat = 0;

let assemblyFormat = [{
$base_ptr `[` $indices `]` attr-dict `:` type($base_ptr) `,` type($indices) `->` type(results)
}];
}

// -----
Expand Down
50 changes: 0 additions & 50 deletions mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,62 +320,12 @@ void AccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, indices);
}

ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand ptrInfo;
SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
Type type;
auto loc = parser.getCurrentLocation();
SmallVector<Type, 4> indicesTypes;

if (parser.parseOperand(ptrInfo) ||
parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
parser.parseColonType(type) ||
parser.resolveOperand(ptrInfo, type, result.operands)) {
return failure();
}

// Check that the provided indices list is not empty before parsing their
// type list.
if (indicesInfo.empty()) {
return mlir::emitError(result.location,
"'spirv.AccessChain' op expected at "
"least one index ");
}

if (parser.parseComma() || parser.parseTypeList(indicesTypes))
return failure();

// Check that the indices types list is not empty and that it has a one-to-one
// mapping to the provided indices.
if (indicesTypes.size() != indicesInfo.size()) {
return mlir::emitError(
result.location, "'spirv.AccessChain' op indices types' count must be "
"equal to indices info count");
}

if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
return failure();

auto resultType = getElementPtrType(
type, llvm::ArrayRef(result.operands).drop_front(), result.location);
if (!resultType) {
return failure();
}

result.addTypes(resultType);
return success();
}

template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
printer << ' ' << op.getBasePtr() << '[' << indices
<< "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
}

void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
printAccessChain(*this, getIndices(), printer);
}

template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module attributes {gpu.container_module} {
%0 = spirv.mlir.addressof @kernel_arg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
%2 = spirv.Constant 0 : i32
%3 = spirv.mlir.addressof @kernel_arg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
%4 = spirv.AccessChain %0[%2, %2] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
%4 = spirv.AccessChain %0[%2, %2] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%5 = spirv.Load "StorageBuffer" %4 : f32
spirv.Return
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ spirv.func @access_chain() "None" {
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.struct<packed (f32, array<4 x f32>)>
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
spirv.Return
}

Expand All @@ -20,7 +20,7 @@ spirv.func @access_chain_array(%arg0 : i32) "None" {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.array<4 x array<4 x f32>>
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4xf32>
spirv.Return
}
Expand Down
38 changes: 19 additions & 19 deletions mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,37 @@ func.func @access_chain_struct() -> () {
%0 = spirv.Constant 1: i32
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Function>
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}

func.func @access_chain_1D_array(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4xf32>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x f32>, Function>
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
return
}

func.func @access_chain_2D_array_1(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
%1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
%1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : f32
return
}

func.func @access_chain_2D_array_2(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4xf32>
return
}

func.func @access_chain_rtarray(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.rtarray<f32>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.rtarray<f32>, Function>
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32 -> !spirv.ptr<f32, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : f32
return
}
Expand All @@ -49,16 +49,16 @@ func.func @access_chain_non_composite() -> () {
%0 = spirv.Constant 1: i32
%1 = spirv.Variable : !spirv.ptr<f32, Function>
// expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
%2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32
%2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32 -> !spirv.ptr<f32, Function>
return
}

// -----

func.func @access_chain_no_indices(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// expected-error @+1 {{expected at least one index}}
%1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
// expected-error @+1 {{custom op 'spirv.AccessChain' 0 operands present, but expected 1}}
%1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
return
}

Expand All @@ -75,17 +75,17 @@ func.func @access_chain_missing_comma(%index0 : i32) -> () {

func.func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
%1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
// expected-error @+1 {{custom op 'spirv.AccessChain' 1 operands present, but expected 2}}
%1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
return
}

// -----

func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
// expected-error @+1 {{custom op 'spirv.AccessChain' 2 operands present, but expected 1}}
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
return
}

Expand All @@ -94,8 +94,8 @@ func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
func.func @access_chain_invalid_type(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
%1 = spirv.Load "Function" %0 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
// expected-error @+1 {{expected a pointer to composite type, but provided '!spirv.array<4 x !spirv.array<4 x f32>>'}}
%2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32
// expected-error @+1 {{'spirv.AccessChain' op operand #0 must be any SPIR-V pointer type, but got '!spirv.array<4 x !spirv.array<4 x f32>>'}}
%2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32 -> f32
return
}

Expand All @@ -113,7 +113,7 @@ func.func @access_chain_invalid_index_1(%index0 : i32) -> () {
func.func @access_chain_invalid_index_2(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct}}
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}

Expand All @@ -123,7 +123,7 @@ func.func @access_chain_invalid_constant_type_1() -> () {
%0 = arith.constant 1: i32
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct, but provided arith.constant}}
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}

Expand All @@ -133,7 +133,7 @@ func.func @access_chain_out_of_bounds() -> () {
%index0 = "spirv.Constant"() { value = 12: i32} : () -> i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// expected-error @+1 {{'spirv.AccessChain' op index 12 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}

Expand All @@ -142,9 +142,9 @@ func.func @access_chain_out_of_bounds() -> () {
func.func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
%1 = spirv.AccessChain %0[%index, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32
%1 = spirv.AccessChain %0[%index0, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32 -> !spirv.ptr<f32, Function>
return

}
// -----

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ spirv.module Logical GLSL450 {
// CHECK: [[VAR1:%.*]] = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
// CHECK-NEXT: spirv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
%1 = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32 -> !spirv.ptr<f32, Input>
spirv.Return
}
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ spirv.module Logical GLSL450 {
%37 = spirv.IAdd %arg4, %11 : i32
// CHECK: spirv.AccessChain [[ARG0]]
%c0 = spirv.Constant 0 : i32
%38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
%38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%39 = spirv.Load "StorageBuffer" %38 : f32
// CHECK: spirv.AccessChain [[ARG1]]
%40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
%40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%41 = spirv.Load "StorageBuffer" %40 : f32
%42 = spirv.FAdd %39, %41 : f32
// CHECK: spirv.AccessChain [[ARG2]]
%43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
%43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Store "StorageBuffer" %43, %42 : f32
spirv.Return
}
Expand Down
14 changes: 7 additions & 7 deletions mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ func.func @combine_full_access_chain() -> f32 {
// CHECK-NEXT: spirv.Load "Function" %[[PTR]]
%c0 = spirv.Constant 0: i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
%2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
%2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
%3 = spirv.Load "Function" %2 : f32
spirv.ReturnValue %3 : f32
}
Expand All @@ -28,9 +28,9 @@ func.func @combine_access_chain_multi_use() -> !spirv.array<4xf32> {
// CHECK-NEXT: spirv.Load "Function" %[[PTR_1]]
%c0 = spirv.Constant 0: i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
%2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
%3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
%2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
%3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
%4 = spirv.Load "Function" %2 : !spirv.array<4xf32>
%5 = spirv.Load "Function" %3 : f32
spirv.ReturnValue %4: !spirv.array<4xf32>
Expand All @@ -49,8 +49,8 @@ func.func @dont_combine_access_chain_without_common_base() -> !spirv.array<4xi32
%c1 = spirv.Constant 1: i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
%2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
%3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
%2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
%3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
%4 = spirv.Load "Function" %2 : !spirv.array<4xi32>
%5 = spirv.Load "Function" %3 : !spirv.array<4xi32>
spirv.ReturnValue %4 : !spirv.array<4xi32>
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ spirv.module Logical GLSL450 {
spirv.func @callee() "None" {
%0 = spirv.mlir.addressof @data : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>
%1 = spirv.Constant 0: i32
%2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32
%2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
spirv.Branch ^next

^next:
Expand Down Expand Up @@ -196,15 +196,15 @@ spirv.module Logical GLSL450 {
// CHECK: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOADPTR]]
%2 = spirv.mlir.addressof @arg_0 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
%3 = spirv.mlir.addressof @arg_1 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
%4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
%4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
%5 = spirv.Load "StorageBuffer" %4 : i32
%6 = spirv.SGreaterThan %5, %1 : i32
// CHECK: spirv.mlir.selection
spirv.mlir.selection {
spirv.BranchConditional %6, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: [[STOREPTR:%.*]] = spirv.AccessChain [[ADDRESS_ARG1]]
%7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
%7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
// CHECK-NOT: spirv.FunctionCall
// CHECK: spirv.AtomicIAdd <Device> <AcquireRelease> [[STOREPTR]], [[VAL]]
// CHECK: spirv.Branch
Expand Down
Loading
Loading