Skip to content

Commit 3e4ad9c

Browse files
committed
[core] Update quake.apply operation.
Add a verifier and diagnostic message tests. Fix bugs in apply-op-specialization and canonicalization. Add tests and handling for !quake.struq while we're here. Signed-off-by: Eric Schweitz <eschweitz@nvidia.com>
1 parent 5372c6b commit 3e4ad9c

File tree

10 files changed

+270
-46
lines changed

10 files changed

+270
-46
lines changed

include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -397,27 +397,29 @@ def quake_ApplyOp : QuakeOp<"apply",
397397
Variadic<cc_CallableType>:$indirect_callee, // must be 0 or 1 element
398398
UnitAttr:$is_adj,
399399
Variadic<AnyQType>:$controls,
400-
Variadic<AnyType>:$args
400+
Variadic<AnyType>:$actuals
401401
);
402402
let results = (outs Variadic<AnyType>);
403403

404404
let hasCustomAssemblyFormat = 1;
405+
let hasVerifier = 1;
406+
405407
let builders = [
406408
OpBuilder<(ins "mlir::TypeRange":$retTy,
407409
"mlir::SymbolRefAttr":$callee,
408410
"mlir::UnitAttr":$is_adj,
409411
"mlir::ValueRange":$controls,
410412
"mlir::ValueRange":$args), [{
411-
return build($_builder, $_state, retTy, callee, mlir::ValueRange{},
412-
is_adj, controls, args);
413+
return build($_builder, $_state, retTy, callee, {}, is_adj, controls,
414+
args);
413415
}]>,
414416
OpBuilder<(ins "mlir::TypeRange":$retTy,
415417
"mlir::SymbolRefAttr":$callee,
416418
"bool":$is_adj,
417419
"mlir::ValueRange":$controls,
418420
"mlir::ValueRange":$args), [{
419-
return build($_builder, $_state, retTy, callee, mlir::ValueRange{},
420-
is_adj, controls, args);
421+
return build($_builder, $_state, retTy, callee, {}, is_adj, controls,
422+
args);
421423
}]>,
422424
OpBuilder<(ins "mlir::TypeRange":$retTy,
423425
"mlir::Value":$callable,
@@ -446,7 +448,7 @@ def quake_ApplyOp : QuakeOp<"apply",
446448
operand_range getArgOperands() {
447449
if (getControls().empty())
448450
return {operand_begin(), operand_end()};
449-
return {getArgs().begin(), getArgs().end()};
451+
return {getActuals().begin(), getActuals().end()};
450452
}
451453

452454
bool applyToVariant() {

lib/Optimizer/CodeGen/TranslateToOpenQASM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ static LogicalResult emitOperation(Emitter &emitter, quake::ApplyOp op) {
176176
// Separate classical and quantum arguments.
177177
SmallVector<Value> parameters;
178178
SmallVector<Value> targets;
179-
for (auto arg : op.getArgs()) {
179+
for (auto arg : op.getActuals()) {
180180
if (isa<quake::RefType, quake::VeqType>(arg.getType()))
181181
targets.push_back(arg);
182182
else

lib/Optimizer/Dialect/Quake/CanonicalPatterns.inc

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,25 +316,45 @@ struct ConcatSizePattern : public OpRewritePattern<quake::ConcatOp> {
316316
if (concat.getType().hasSpecifiedSize())
317317
return failure();
318318

319+
auto *ctx = rewriter.getContext();
320+
auto loc = concat.getLoc();
321+
auto refTy = quake::RefType::get(ctx);
319322
// Walk the arguments and sum them, if possible.
320323
std::size_t sum = 0;
324+
SmallVector<Value> targets;
321325
for (auto opnd : concat.getTargets()) {
322326
if (auto veqTy = dyn_cast<quake::VeqType>(opnd.getType())) {
323327
if (!veqTy.hasSpecifiedSize())
324328
return failure();
325329
sum += veqTy.getSize();
330+
targets.push_back(opnd);
326331
continue;
327332
}
328-
assert(isa<quake::RefType>(opnd.getType()));
329-
sum++;
333+
if (auto stqTy = dyn_cast<quake::StruqType>(opnd.getType())) {
334+
if (!stqTy.hasSpecifiedSize())
335+
return failure();
336+
sum += quake::getAllocationSize(stqTy);
337+
auto arity = stqTy.getArity();
338+
if (*arity) {
339+
// Get each member for IR legalization.
340+
for (auto [i, memTy] : llvm::enumerate(stqTy.getMembers())) {
341+
auto mem = rewriter.create<quake::GetMemberOp>(loc, memTy, opnd, i);
342+
targets.push_back(mem);
343+
}
344+
}
345+
continue;
346+
}
347+
if (opnd.getType() == refTy) {
348+
sum++;
349+
targets.push_back(opnd);
350+
continue;
351+
}
352+
return failure();
330353
}
331354

332355
// Leans into the relax_size canonicalization pattern.
333-
auto *ctx = rewriter.getContext();
334-
auto loc = concat.getLoc();
335356
auto newTy = quake::VeqType::get(ctx, sum);
336-
Value newOp =
337-
rewriter.create<quake::ConcatOp>(loc, newTy, concat.getTargets());
357+
Value newOp = rewriter.create<quake::ConcatOp>(loc, newTy, targets);
338358
auto noSizeTy = quake::VeqType::getUnsized(ctx);
339359
rewriter.replaceOpWithNewOp<quake::RelaxSizeOp>(concat, noSizeTy, newOp);
340360
return success();

lib/Optimizer/Dialect/Quake/QuakeOps.cpp

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,46 @@ quake::InitializeStateOp quake::AllocaOp::getInitializedState() {
201201
// Apply
202202
//===----------------------------------------------------------------------===//
203203

204+
LogicalResult quake::ApplyOp::verify() {
205+
FunctionType asSig;
206+
if (auto callee = getCallee()) {
207+
auto fn =
208+
SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(*this, *callee);
209+
if (!fn)
210+
return emitOpError("callee must be declared");
211+
asSig = fn.getFunctionType();
212+
} else {
213+
Value callable = getIndirectCallee().front();
214+
asSig = cast<cudaq::cc::CallableType>(callable.getType()).getSignature();
215+
}
216+
217+
// Arity of callee's signature must be equal to number of arguments provided.
218+
if (getActuals().size() != asSig.getInputs().size())
219+
return emitOpError("number of arguments must be consistent");
220+
221+
// Quantum reference type values are allowed to implicitly coerce to a relaxed
222+
// veq type when they appear as arguments to a `quake.apply` op. Specifically,
223+
// lowering the apply op is required to add a `quake.concat` op to manifest
224+
// the type conversion.
225+
auto isRelaxedVeq = [](Type ty1, Type ty2) {
226+
if (auto veq2 = dyn_cast<quake::VeqType>(ty2))
227+
return quake::isQuantumReferenceType(ty1) && !veq2.hasSpecifiedSize();
228+
return false;
229+
};
230+
231+
// The args are the formal arguments and they must match.
232+
for (auto [ty1, ty2] : llvm::zip(getActuals().getTypes(), asSig.getInputs()))
233+
if (ty1 != ty2 && !isRelaxedVeq(ty1, ty2))
234+
return emitOpError("argument types must match");
235+
236+
// The results are the formal results and they must match.
237+
for (auto [ty1, ty2] : llvm::zip(getResultTypes(), asSig.getResults()))
238+
if (ty1 != ty2 && !isRelaxedVeq(ty1, ty2))
239+
return emitOpError("result types must match");
240+
241+
return success();
242+
}
243+
204244
void quake::ApplyOp::print(OpAsmPrinter &p) {
205245
if (getIsAdj())
206246
p << "<adj>";
@@ -213,7 +253,7 @@ void quake::ApplyOp::print(OpAsmPrinter &p) {
213253
p << ' ';
214254
if (!getControls().empty())
215255
p << '[' << getControls() << "] ";
216-
p << getArgs() << " : ";
256+
p << getActuals() << " : ";
217257
SmallVector<Type> operandTys{(*this)->getOperandTypes().begin(),
218258
(*this)->getOperandTypes().end()};
219259
p.printFunctionalType(ArrayRef<Type>{operandTys}.drop_front(isDirect ? 0 : 1),
@@ -229,13 +269,14 @@ ParseResult quake::ApplyOp::parse(OpAsmParser &parser, OperationState &result) {
229269
return failure();
230270
result.addAttribute("is_adj", parser.getBuilder().getUnitAttr());
231271
}
272+
OpAsmParser::UnresolvedOperand calleeOpnd;
232273
SmallVector<OpAsmParser::UnresolvedOperand> calleeOperand;
233-
if (parser.parseOperandList(calleeOperand))
234-
return failure();
235-
bool isDirect = calleeOperand.empty();
236-
if (calleeOperand.size() > 1)
237-
return failure();
238-
if (isDirect) {
274+
bool isDirect;
275+
if (parser.parseOptionalOperand(calleeOpnd).has_value()) {
276+
isDirect = false;
277+
calleeOperand.push_back(calleeOpnd);
278+
} else {
279+
isDirect = true;
239280
NamedAttrList attrs;
240281
SymbolRefAttr funcAttr;
241282
if (parser.parseCustomAttributeWithFallback(

lib/Optimizer/Transforms/ApplyOpSpecialization.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ struct ApplyOpAnalysis {
105105
void performAnalysis(Operation *op) {
106106
op->walk([&](quake::ApplyOp apply) {
107107
if (constProp) {
108-
// If some of the arguments in getArgs() are constants, then materialize
109-
// those constants in a clone of the variant. The specialized variant
110-
// will then be able to perform better constant propagation even if not
111-
// inlined.
108+
// If some of the arguments in getActuals() are constants, then
109+
// materialize those constants in a clone of the variant. The
110+
// specialized variant will then be able to perform better constant
111+
// propagation even if not inlined.
112112
auto calleeName = apply.getCallee()->getRootReference().str();
113113
if (func::FuncOp genericFunc =
114114
module.lookupSymbol<func::FuncOp>(calleeName)) {
115115
SmallVector<Value> newArgs;
116-
newArgs.append(apply.getArgs().begin(), apply.getArgs().end());
116+
newArgs.append(apply.getActuals().begin(), apply.getActuals().end());
117117
IRMapping mapper;
118118
SmallVector<Value> preservedArgs;
119119
SmallVector<Type> inputTys;
@@ -314,8 +314,12 @@ struct ApplyOpPattern : public OpRewritePattern<quake::ApplyOp> {
314314
LogicalResult matchAndRewrite(quake::ApplyOp apply,
315315
PatternRewriter &rewriter) const override {
316316
std::string calleeOrigName;
317-
if (apply.getCallee()) {
318-
calleeOrigName = apply.getCallee()->getRootReference().str();
317+
FunctionType calleeSignature;
318+
if (auto callee = apply.getCallee()) {
319+
calleeOrigName = callee->getRootReference().str();
320+
auto fn =
321+
SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(apply, *callee);
322+
calleeSignature = fn.getFunctionType();
319323
} else {
320324
// Check if the first argument is a func.ConstantOp.
321325
auto calleeVals = apply.getIndirectCallee();
@@ -326,24 +330,26 @@ struct ApplyOpPattern : public OpRewritePattern<quake::ApplyOp> {
326330
if (!fc)
327331
return failure();
328332
calleeOrigName = fc.getValue().str();
333+
calleeSignature = dyn_cast<FunctionType>(fc.getResult().getType());
329334
}
330335
auto calleeName = getVariantFunctionName(apply, calleeOrigName);
331336
auto *ctx = apply.getContext();
332-
auto consTy = quake::VeqType::getUnsized(ctx);
337+
auto unsizedVeqTy = quake::VeqType::getUnsized(ctx);
333338
SmallVector<Value> newArgs;
334339
if (!apply.getControls().empty()) {
335-
auto consOp = rewriter.create<quake::ConcatOp>(apply.getLoc(), consTy,
336-
apply.getControls());
340+
auto consOp = rewriter.create<quake::ConcatOp>(
341+
apply.getLoc(), unsizedVeqTy, apply.getControls());
337342
newArgs.push_back(consOp);
338343
}
339-
if (constProp) {
340-
for (auto v : apply.getArgs()) {
341-
if (auto c = v.getDefiningOp<arith::ConstantOp>())
342-
continue;
343-
newArgs.emplace_back(v);
344-
}
345-
} else {
346-
newArgs.append(apply.getArgs().begin(), apply.getArgs().end());
344+
for (auto [v, toTy] :
345+
llvm::zip(apply.getActuals(), calleeSignature.getInputs())) {
346+
if (constProp && v.getDefiningOp<arith::ConstantOp>())
347+
continue;
348+
Value arg = v;
349+
if (arg.getType() != toTy)
350+
arg =
351+
rewriter.create<quake::ConcatOp>(apply.getLoc(), unsizedVeqTy, arg);
352+
newArgs.emplace_back(arg);
347353
}
348354
rewriter.replaceOpWithNewOp<func::CallOp>(apply, apply.getResultTypes(),
349355
calleeName, newArgs);
@@ -365,8 +371,8 @@ struct FoldCallable : public OpRewritePattern<quake::ApplyOp> {
365371
Value ind = apply.getIndirectCallee()[0];
366372
if (auto callee = ind.getDefiningOp<cudaq::cc::InstantiateCallableOp>()) {
367373
auto sym = callee.getCallee();
368-
SmallVector<Value> newArguments = {ind};
369-
newArguments.append(apply.getArgs().begin(), apply.getArgs().end());
374+
SmallVector<Value> newArguments{apply.getActuals().begin(),
375+
apply.getActuals().end()};
370376
rewriter.replaceOpWithNewOp<quake::ApplyOp>(
371377
apply, apply.getResultTypes(), sym, apply.getIsAdj(),
372378
apply.getControls(), newArguments);
@@ -529,7 +535,7 @@ class ApplySpecializationPass
529535
apply.getControls().end());
530536
auto newApply = builder.create<quake::ApplyOp>(
531537
apply.getLoc(), apply.getResultTypes(), apply.getCalleeAttr(),
532-
apply.getIsAdjAttr(), newControls, apply.getArgs());
538+
apply.getIsAdjAttr(), newControls, apply.getActuals());
533539
apply->replaceAllUsesWith(newApply.getResults());
534540
apply->erase();
535541
} else if (isQuantumKernelCall(op)) {

lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
118118

119119
rewriter.replaceOpWithNewOp<quake::ApplyOp>(
120120
op, TypeRange{}, FlatSymbolRefAttr::get(ctx, replacement.getName()),
121-
adaptor.getIsAdj(), adaptor.getControls(), adaptor.getArgs());
121+
adaptor.getIsAdj(), adaptor.getControls(), adaptor.getActuals());
122122
return success();
123123
}
124124
return failure();

test/Transforms/apply-0.qke

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,57 @@ func.func @do_apply(%arg : !quake.ref, %brg : !quake.ref) {
3232
// CHECK: quake.t<adj> [%[[VAL_0]]] %[[VAL_1]] : (!quake.veq<?>, !quake.ref) -> ()
3333
// CHECK: return
3434
// CHECK: }
35+
36+
// The following test the implicit coercion of apply operands to unsized veq
37+
// type. This can be done uniformly by injecting a concat operation.
38+
39+
func.func @utility_gate(%arg : !quake.veq<?>) {
40+
quake.y %arg : (!quake.veq<?>) -> ()
41+
quake.t %arg : (!quake.veq<?>) -> ()
42+
return
43+
}
44+
45+
func.func @apply_yski() {
46+
%0 = quake.alloca !quake.ref
47+
quake.apply @utility_gate %0 : (!quake.ref) -> ()
48+
quake.dealloc %0 : !quake.ref
49+
return
50+
}
51+
52+
// CHECK-LABEL: func.func @apply_yski() {
53+
// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.ref
54+
// CHECK: %[[VAL_1:.*]] = quake.concat %[[VAL_0]] : (!quake.ref) -> !quake.veq<?>
55+
// CHECK: call @utility_gate(%[[VAL_1]]) : (!quake.veq<?>) -> ()
56+
// CHECK: quake.dealloc %[[VAL_0]] : !quake.ref
57+
// CHECK: return
58+
// CHECK: }
59+
60+
func.func @apply_kaksi() {
61+
%0 = quake.alloca !quake.veq<4>
62+
quake.apply @utility_gate %0 : (!quake.veq<4>) -> ()
63+
quake.dealloc %0 : !quake.veq<4>
64+
return
65+
}
66+
67+
// CHECK-LABEL: func.func @apply_kaksi() {
68+
// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.veq<4>
69+
// CHECK: %[[VAL_1:.*]] = quake.concat %[[VAL_0]] : (!quake.veq<4>) -> !quake.veq<?>
70+
// CHECK: call @utility_gate(%[[VAL_1]]) : (!quake.veq<?>) -> ()
71+
// CHECK: quake.dealloc %[[VAL_0]] : !quake.veq<4>
72+
// CHECK: return
73+
// CHECK: }
74+
75+
func.func @apply_kolme() {
76+
%0 = quake.alloca !quake.struq<!quake.ref, !quake.ref>
77+
quake.apply @utility_gate %0 : (!quake.struq<!quake.ref, !quake.ref>) -> ()
78+
quake.dealloc %0 : !quake.struq<!quake.ref, !quake.ref>
79+
return
80+
}
81+
82+
// CHECK-LABEL: func.func @apply_kolme() {
83+
// CHECK: %[[VAL_0:.*]] = quake.alloca !quake.struq<!quake.ref, !quake.ref>
84+
// CHECK: %[[VAL_1:.*]] = quake.concat %[[VAL_0]] : (!quake.struq<!quake.ref, !quake.ref>) -> !quake.veq<?>
85+
// CHECK: call @utility_gate(%[[VAL_1]]) : (!quake.veq<?>) -> ()
86+
// CHECK: quake.dealloc %[[VAL_0]] : !quake.struq<!quake.ref, !quake.ref>
87+
// CHECK: return
88+
// CHECK: }

test/Transforms/apply_ctrl_veq_type.qke

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
// the terms of the Apache License 2.0 which accompanies this distribution. //
77
// ========================================================================== //
88

9-
// Regression test for a bug where veq type specialization caused type mismatches
10-
// when creating control variants of functions that have inner calls to other
11-
// kernels expecting dynamic !quake.veq<?> types.
9+
// Regression test for a bug where veq type specialization caused type
10+
// mismatches when creating control variants of functions that have inner calls
11+
// to other kernels expecting dynamic !quake.veq<?> types.
1212
//
1313
// The bug scenario:
1414
// 1. main_kernel allocates !quake.veq<2> and relaxes it to !quake.veq<?>
1515
// 2. A controlled apply to thunk_kernel with a constant argument triggers
16-
// constant propagation which would un-relax the veq type back to !quake.veq<2>
16+
// constant propagation which would un-relax the veq type back to !quake.veq<2>
1717
// 3. The thunk calls inner_kernel which expects !quake.veq<?>
1818
// 4. When the .ctrl variant is created, the inner call has a type mismatch
1919
//

0 commit comments

Comments
 (0)