diff --git a/lib/Optimizer/Transforms/ApplyOpSpecialization.cpp b/lib/Optimizer/Transforms/ApplyOpSpecialization.cpp
index 14265547d17..009b787aae9 100644
--- a/lib/Optimizer/Transforms/ApplyOpSpecialization.cpp
+++ b/lib/Optimizer/Transforms/ApplyOpSpecialization.cpp
@@ -63,6 +63,34 @@ struct ApplyVariants {
 /// Map from `func::FuncOp` to the variants to be created.
 using ApplyOpAnalysisInfo = DenseMap<Operation *, ApplyVariants>;
 
+/// Check if a function has any func.call operations that take a dynamic
+/// !quake.veq<?> argument. If so, we should not specialize (un-relax) veq
+/// argument types during constant propagation, as this would cause type
+/// mismatches when the specialized function calls inner kernels expecting
+/// the dynamic type.
+///
+/// Alternatives to this conservative approach:
+/// 1. Dataflow analysis: trace if a specific argument reaches such a call,
+///    allowing specialization of unaffected arguments.
+/// 2. Recursive specialization: specialize all callees in the call tree to
+///    accept the concrete veq size, propagating type info deeper for better
+///    optimization but increasing code size.
+static bool hasCallWithDynamicVeq(func::FuncOp func, ModuleOp module) {
+  auto result = func.walk([&](func::CallOp callOp) {
+    auto callee = module.lookupSymbol<func::FuncOp>(callOp.getCallee());
+    if (!callee)
+      return WalkResult::advance();
+    for (auto inputTy : callee.getFunctionType().getInputs()) {
+      if (auto veqTy = dyn_cast<quake::VeqType>(inputTy)) {
+        if (!veqTy.hasSpecifiedSize())
+          return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+  return result.wasInterrupted();
+}
+
 /// This analysis scans the IR for `ApplyOp`s to see which ones need to have
 /// variants created.
 struct ApplyOpAnalysis {
@@ -99,11 +127,14 @@ struct ApplyOpAnalysis {
             LLVM_DEBUG(llvm::dbgs() << "apply has constant arguments.\n");
           } else {
             if (auto relax = v.getDefiningOp<quake::RelaxSizeOp>()) {
-              // Also, specialize any relaxed veq types.
-              v = relax.getInputVec();
-              updateSignature = true;
-              LLVM_DEBUG(llvm::dbgs() << "specializing apply veq argument ("
-                                      << v.getType() << ")\n");
+              // Specialize relaxed veq types, but only if the function has no
+              // inner calls expecting dynamic !quake.veq<?> types.
+              if (!hasCallWithDynamicVeq(genericFunc, module)) {
+                v = relax.getInputVec();
+                updateSignature = true;
+                LLVM_DEBUG(llvm::dbgs() << "specializing apply veq argument ("
+                                        << v.getType() << ")\n");
+              }
             }
             inputTys.push_back(v.getType());
             preservedArgs.push_back(v);
diff --git a/python/tests/kernel/test_control_negations.py b/python/tests/kernel/test_control_negations.py
index ab42fc4aa7d..a7e285ca530 100644
--- a/python/tests/kernel/test_control_negations.py
+++ b/python/tests/kernel/test_control_negations.py
@@ -255,6 +255,81 @@ def control_simple_gate():
         e.value)
 
 
+def test_control_float_list_complex_real_access():
+    """
+    Regression test for a bug in cudaq.control() argument synthesis.
+
+    The bug occurs when these three conditions are met:
+    1. Using cudaq.control() to call a sub-kernel
+    2. The sub-kernel has BOTH float AND list[complex] parameters
+    3. The sub-kernel accesses .real on a complex value from the list
+
+    Error: 'func.call' op operand type mismatch: expected operand type
+    '!quake.veq<?>', but provided '!quake.veq'
+    RuntimeError: Could not successfully apply argument synth.
+
+    This pattern is used in the krylov.ipynb notebook.
+    """
+
+    @cudaq.kernel
+    def sub_kernel(qubits: cudaq.qview, dt: float, values: list[complex]):
+        rx(dt * values[0].real, qubits[0])
+
+    @cudaq.kernel
+    def main_kernel(dt: float, values: list[complex]):
+        ancilla = cudaq.qubit()
+        qreg = cudaq.qvector(2)
+        h(ancilla)
+        cudaq.control(sub_kernel, ancilla, qreg, dt, values)
+
+    result = cudaq.sample(main_kernel, 0.1, [0.5 + 0j, 0.25 + 0j])
+    assert len(result) > 0
+
+
+def test_control_list_complex_real_access_no_float():
+    """
+    Verify that list[complex] with .real access works when there's no float param.
+    This is a control test to confirm the bug is specific to the float + list[complex]
+    combination.
+    """
+
+    @cudaq.kernel
+    def sub_kernel(qubits: cudaq.qview, values: list[complex]):
+        rx(values[0].real, qubits[0])
+
+    @cudaq.kernel
+    def main_kernel(values: list[complex]):
+        ancilla = cudaq.qubit()
+        qreg = cudaq.qvector(2)
+        h(ancilla)
+        cudaq.control(sub_kernel, ancilla, qreg, values)
+
+    # This should work
+    result = cudaq.sample(main_kernel, [0.5 + 0j, 0.25 + 0j])
+    assert len(result) > 0
+
+
+def test_control_float_list_complex_no_real_access():
+    """
+    Verify that float + list[complex] works when .real is not accessed.
+    This is a control test to confirm the bug requires the .real access.
+    """
+
+    @cudaq.kernel
+    def sub_kernel(qubits: cudaq.qview, dt: float, values: list[complex]):
+        rx(dt, qubits[0])
+
+    @cudaq.kernel
+    def main_kernel(dt: float, values: list[complex]):
+        ancilla = cudaq.qubit()
+        qreg = cudaq.qvector(2)
+        h(ancilla)
+        cudaq.control(sub_kernel, ancilla, qreg, dt, values)
+
+    result = cudaq.sample(main_kernel, 0.1, [0.5 + 0j, 0.25 + 0j])
+    assert len(result) > 0
+
+
 # leave for gdb debugging
 if __name__ == "__main__":
     loc = os.path.abspath(__file__)
diff --git a/test/Transforms/apply_ctrl_veq_type.qke b/test/Transforms/apply_ctrl_veq_type.qke
new file mode 100644
index 00000000000..03576c86003
--- /dev/null
+++ b/test/Transforms/apply_ctrl_veq_type.qke
@@ -0,0 +1,71 @@
+// ========================================================================== //
+// Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates.                 //
+// All rights reserved.                                                       //
+//                                                                            //
+// This source code and the accompanying materials are made available under   //
+// the terms of the Apache License 2.0 which accompanies this distribution.   //
+// ========================================================================== //
+
+// Regression test for a bug where veq type specialization caused type mismatches
+// when creating control variants of functions that have inner calls to other
+// kernels expecting dynamic !quake.veq<?> types.
+//
+// The bug scenario:
+// 1. main_kernel allocates !quake.veq<2> and relaxes it to !quake.veq<?>
+// 2. A controlled apply to thunk_kernel with a constant argument triggers
+//    constant propagation which would un-relax the veq type back to !quake.veq<2>
+// 3. The thunk calls inner_kernel which expects !quake.veq<?>
+// 4. When the .ctrl variant is created, the inner call has a type mismatch
+//
+// The fix: Don't specialize (un-relax) veq types when the function has inner
+// func.call operations, since those callees still expect the dynamic type.
+// An alternative fix would be to recursively specialize all callees in the
+// call tree, but that would increase code size.
+
+// RUN: cudaq-opt --apply-op-specialization=constant-prop=1 %s | FileCheck %s
+
+module {
+  // Inner kernel that expects dynamic veq type
+  func.func @inner_kernel(%qubits: !quake.veq<?>, %angle: f64) attributes {"cudaq-kernel"} {
+    %0 = quake.extract_ref %qubits[0] : (!quake.veq<?>) -> !quake.ref
+    quake.rx (%angle) %0 : (f64, !quake.ref) -> ()
+    return
+  }
+
+  // Thunk kernel that calls inner_kernel - will be specialized with constant prop
+  func.func @thunk_kernel(%qubits: !quake.veq<?>, %angle: f64) attributes {"cudaq-kernel"} {
+    func.call @inner_kernel(%qubits, %angle) : (!quake.veq<?>, f64) -> ()
+    return
+  }
+
+  // Main kernel that uses controlled apply with a constant angle
+  func.func @main_kernel() attributes {"cudaq-entrypoint", "cudaq-kernel"} {
+    %cst = arith.constant 0.1 : f64
+    %ancilla = quake.alloca !quake.ref
+    %qreg = quake.alloca !quake.veq<2>
+    %relaxed = quake.relax_size %qreg : (!quake.veq<2>) -> !quake.veq<?>
+    quake.h %ancilla : (!quake.ref) -> ()
+    // Controlled apply with constant - triggers specialization
+    quake.apply @thunk_kernel [%ancilla] %relaxed, %cst : (!quake.ref, !quake.veq<?>, f64) -> ()
+    return
+  }
+}
+
+// Verify that the .ctrl variant is created and veq type is NOT specialized
+// (remains !quake.veq<?>) because the function has inner func.call operations.
+
+// CHECK-LABEL: func.func private @thunk_kernel.0.ctrl(
+// CHECK-SAME: %[[CTRL:.*]]: !quake.veq<?>,
+// CHECK-SAME: %[[QUBITS:.*]]: !quake.veq<?>)
+// The veq argument keeps its dynamic type, so no relax_size is needed
+// CHECK: call @inner_kernel(%[[QUBITS]],
+// CHECK-SAME: : (!quake.veq<?>, f64) -> ()
+// CHECK: return
+
+// CHECK-LABEL: func.func @main_kernel()
+// CHECK: %[[ANCILLA:.*]] = quake.alloca !quake.ref
+// CHECK: %[[QREG:.*]] = quake.alloca !quake.veq<2>
+// CHECK: %[[RELAXED:.*]] = quake.relax_size %[[QREG]]
+// CHECK: quake.h %[[ANCILLA]]
+// CHECK: %[[CONCAT:.*]] = quake.concat %[[ANCILLA]]
+// CHECK: call @thunk_kernel.0.ctrl(%[[CONCAT]], %[[RELAXED]])