@@ -105,15 +105,15 @@ struct ApplyOpAnalysis {
105105 void performAnalysis (Operation *op) {
106106 op->walk ([&](quake::ApplyOp apply) {
107107 if (constProp) {
108- // If some of the arguments in getArgs () are constants, then materialize
109- // those constants in a clone of the variant. The specialized variant
110- // will then be able to perform better constant propagation even if not
111- // inlined.
108+ // If some of the arguments in getActuals () are constants, then
109+ // materialize those constants in a clone of the variant. The
110+ // specialized variant will then be able to perform better constant
111+ // propagation even if not inlined.
112112 auto calleeName = apply.getCallee ()->getRootReference ().str ();
113113 if (func::FuncOp genericFunc =
114114 module .lookupSymbol <func::FuncOp>(calleeName)) {
115115 SmallVector<Value> newArgs;
116- newArgs.append (apply.getArgs ().begin (), apply.getArgs ().end ());
116+ newArgs.append (apply.getActuals ().begin (), apply.getActuals ().end ());
117117 IRMapping mapper;
118118 SmallVector<Value> preservedArgs;
119119 SmallVector<Type> inputTys;
@@ -314,8 +314,12 @@ struct ApplyOpPattern : public OpRewritePattern<quake::ApplyOp> {
314314 LogicalResult matchAndRewrite (quake::ApplyOp apply,
315315 PatternRewriter &rewriter) const override {
316316 std::string calleeOrigName;
317- if (apply.getCallee ()) {
318- calleeOrigName = apply.getCallee ()->getRootReference ().str ();
317+ FunctionType calleeSignature;
318+ if (auto callee = apply.getCallee ()) {
319+ calleeOrigName = callee->getRootReference ().str ();
320+ auto fn =
321+ SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(apply, *callee);
322+ calleeSignature = fn.getFunctionType ();
319323 } else {
320324 // Check if the first argument is a func.ConstantOp.
321325 auto calleeVals = apply.getIndirectCallee ();
@@ -326,27 +330,31 @@ struct ApplyOpPattern : public OpRewritePattern<quake::ApplyOp> {
326330 if (!fc)
327331 return failure ();
328332 calleeOrigName = fc.getValue ().str ();
333+ calleeSignature = dyn_cast<FunctionType>(fc.getResult ().getType ());
329334 }
330335 auto calleeName = getVariantFunctionName (apply, calleeOrigName);
331336 auto *ctx = apply.getContext ();
332- auto consTy = quake::VeqType::getUnsized (ctx);
337+ auto unsizedVeqTy = quake::VeqType::getUnsized (ctx);
333338 SmallVector<Value> newArgs;
334339 if (!apply.getControls ().empty ()) {
335- auto consOp = rewriter.create <quake::ConcatOp>(apply. getLoc (), consTy,
336- apply.getControls ());
340+ auto consOp = rewriter.create <quake::ConcatOp>(
341+ apply. getLoc (), unsizedVeqTy, apply.getControls ());
337342 newArgs.push_back (consOp);
338343 }
339- if (constProp) {
340- for (auto v : apply.getArgs ()) {
341- if (auto c = v.getDefiningOp <arith::ConstantOp>())
342- continue ;
343- newArgs.emplace_back (v);
344- }
345- } else {
346- newArgs.append (apply.getArgs ().begin (), apply.getArgs ().end ());
344+ for (auto [v, toTy] :
345+ llvm::zip (apply.getActuals (), calleeSignature.getInputs ())) {
346+ if (constProp && v.getDefiningOp <arith::ConstantOp>())
347+ continue ;
348+ Value arg = v;
349+ if (arg.getType () != toTy)
350+ arg =
351+ rewriter.create <quake::ConcatOp>(apply.getLoc (), unsizedVeqTy, arg);
352+ newArgs.emplace_back (arg);
347353 }
348- rewriter.replaceOpWithNewOp <func::CallOp>(apply, apply.getResultTypes (),
349- calleeName, newArgs);
354+ LLVM_DEBUG (llvm::dbgs () << " replacing: " << apply << ' \n ' );
355+ [[maybe_unused]] auto result = rewriter.replaceOpWithNewOp <func::CallOp>(
356+ apply, apply.getResultTypes (), calleeName, newArgs);
357+ LLVM_DEBUG (llvm::dbgs () << " with " << result << ' \n ' );
350358 return success ();
351359 }
352360
@@ -363,16 +371,18 @@ struct FoldCallable : public OpRewritePattern<quake::ApplyOp> {
363371 return failure ();
364372
365373 Value ind = apply.getIndirectCallee ()[0 ];
366- if (auto callee = ind.getDefiningOp <cudaq::cc::InstantiateCallableOp>()) {
367- auto sym = callee.getCallee ();
368- SmallVector<Value> newArguments = {ind};
369- newArguments.append (apply.getArgs ().begin (), apply.getArgs ().end ());
370- rewriter.replaceOpWithNewOp <quake::ApplyOp>(
371- apply, apply.getResultTypes (), sym, apply.getIsAdj (),
372- apply.getControls (), newArguments);
373- return success ();
374- }
375- return failure ();
374+ auto callee = ind.getDefiningOp <cudaq::cc::InstantiateCallableOp>();
375+ if (!callee)
376+ return failure ();
377+ auto sym = callee.getCallee ();
378+ SmallVector<Value> newArguments = {ind};
379+ newArguments.append (apply.getActuals ().begin (), apply.getActuals ().end ());
380+ LLVM_DEBUG (llvm::dbgs () << " replacing " << apply << ' \n ' );
381+ [[maybe_unused]] auto result = rewriter.replaceOpWithNewOp <quake::ApplyOp>(
382+ apply, apply.getResultTypes (), sym, apply.getIsAdj (),
383+ apply.getControls (), newArguments);
384+ LLVM_DEBUG (llvm::dbgs () << " as " << result << ' \n ' );
385+ return success ();
376386 }
377387};
378388
@@ -529,7 +539,7 @@ class ApplySpecializationPass
529539 apply.getControls ().end ());
530540 auto newApply = builder.create <quake::ApplyOp>(
531541 apply.getLoc (), apply.getResultTypes (), apply.getCalleeAttr (),
532- apply.getIsAdjAttr (), newControls, apply.getArgs ());
542+ apply.getIsAdjAttr (), newControls, apply.getActuals ());
533543 apply->replaceAllUsesWith (newApply.getResults ());
534544 apply->erase ();
535545 } else if (isQuantumKernelCall (op)) {
0 commit comments