Skip to content

Commit 6e0339a

Browse files
Fix cudaq.control() veq type specialization bug
Fix a bug where controlled kernel calls failed with type mismatch errors when the kernel had both float and list[complex] parameters and accessed .real on complex values. The bug occurred during constant propagation in ApplySpecialization when veq argument types were specialized (un-relaxed from !quake.veq<?> to !quake.veq<N>) but inner func.call operations still expected the dynamic type, causing "'func.call' op operand type mismatch" errors. The fix adds hasCallWithDynamicVeq() to check if a function has any inner calls whose callee expects !quake.veq<?>, and skips veq type specialization in those cases to avoid the type mismatch. Changes: - lib/Optimizer/Transforms/ApplyOpSpecialization.cpp: Add check to skip veq specialization when inner calls expect dynamic veq types - test/Transforms/apply_ctrl_veq_type.qke: New MLIR regression test - python/tests/kernel/test_control_negations.py: Add regression test test_control_float_list_complex_real_access and two control tests (75 additions, 0 deletions) Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 7fa0116 commit 6e0339a

File tree

3 files changed

+186
-5
lines changed

3 files changed

+186
-5
lines changed

lib/Optimizer/Transforms/ApplyOpSpecialization.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,38 @@ struct ApplyVariants {
6363
/// Map from `func::FuncOp` to the variants to be created.
6464
using ApplyOpAnalysisInfo = DenseMap<Operation *, ApplyVariants>;
6565

66+
/// Check if a function has any func.call operations that take a dynamic
67+
/// !quake.veq<?> argument. If so, we should not specialize (un-relax) veq
68+
/// argument types during constant propagation, as this would cause type
69+
/// mismatches when the specialized function calls inner kernels expecting
70+
/// the dynamic type.
71+
///
72+
/// Alternatives to this conservative approach:
73+
/// 1. Dataflow analysis: trace if a specific argument reaches such a call,
74+
/// allowing specialization of unaffected arguments.
75+
/// 2. Recursive specialization: specialize all callees in the call tree to
76+
/// accept the concrete veq size, propagating type info deeper for better
77+
/// optimization but increasing code size.
78+
static bool hasCallWithDynamicVeq(func::FuncOp func, ModuleOp module) {
79+
bool found = false;
80+
func.walk([&](func::CallOp callOp) {
81+
if (found)
82+
return;
83+
auto callee = module.lookupSymbol<func::FuncOp>(callOp.getCallee());
84+
if (!callee)
85+
return;
86+
for (auto inputTy : callee.getFunctionType().getInputs()) {
87+
if (auto veqTy = dyn_cast<quake::VeqType>(inputTy)) {
88+
if (!veqTy.hasSpecifiedSize()) {
89+
found = true;
90+
return;
91+
}
92+
}
93+
}
94+
});
95+
return found;
96+
}
97+
6698
/// This analysis scans the IR for `ApplyOp`s to see which ones need to have
6799
/// variants created.
68100
struct ApplyOpAnalysis {
@@ -99,11 +131,14 @@ struct ApplyOpAnalysis {
99131
LLVM_DEBUG(llvm::dbgs() << "apply has constant arguments.\n");
100132
} else {
101133
if (auto relax = v.getDefiningOp<quake::RelaxSizeOp>()) {
102-
// Also, specialize any relaxed veq types.
103-
v = relax.getInputVec();
104-
updateSignature = true;
105-
LLVM_DEBUG(llvm::dbgs() << "specializing apply veq argument ("
106-
<< v.getType() << ")\n");
134+
// Specialize relaxed veq types, but only if the function has no
135+
// inner calls expecting dynamic !quake.veq<?> types.
136+
if (!hasCallWithDynamicVeq(genericFunc, module)) {
137+
v = relax.getInputVec();
138+
updateSignature = true;
139+
LLVM_DEBUG(llvm::dbgs() << "specializing apply veq argument ("
140+
<< v.getType() << ")\n");
141+
}
107142
}
108143
inputTys.push_back(v.getType());
109144
preservedArgs.push_back(v);

python/tests/kernel/test_control_negations.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,81 @@ def control_simple_gate():
255255
e.value)
256256

257257

258+
def test_control_float_list_complex_real_access():
    """
    Regression test for cudaq.control() argument synthesis.

    The failure needed all three of the following at once:
      1. a sub-kernel invoked through cudaq.control(),
      2. a sub-kernel signature with BOTH a float AND a list[complex],
      3. a `.real` access on an element of that list inside the sub-kernel.

    It used to surface as:
        'func.call' op operand type mismatch: expected operand type
        '!quake.veq<?>', but provided '!quake.veq<N>'
        RuntimeError: Could not successfully apply argument synth.

    The krylov.ipynb notebook relies on this pattern.
    """

    @cudaq.kernel
    def sub_kernel(qubits: cudaq.qview, dt: float, values: list[complex]):
        rx(dt * values[0].real, qubits[0])

    @cudaq.kernel
    def main_kernel(dt: float, values: list[complex]):
        control_qubit = cudaq.qubit()
        targets = cudaq.qvector(2)
        h(control_qubit)
        cudaq.control(sub_kernel, control_qubit, targets, dt, values)

    # Only successful compilation and sampling is checked here.
    counts = cudaq.sample(main_kernel, 0.1, [0.5 + 0j, 0.25 + 0j])
    assert len(counts) > 0
287+
288+
289+
def test_control_list_complex_real_access_no_float():
    """
    Control case: `.real` access on a list[complex] element without any float
    parameter. Confirms the regression is tied to the float + list[complex]
    combination rather than to the `.real` access alone.
    """

    @cudaq.kernel
    def sub_kernel(qubits: cudaq.qview, values: list[complex]):
        rx(values[0].real, qubits[0])

    @cudaq.kernel
    def main_kernel(values: list[complex]):
        control_qubit = cudaq.qubit()
        targets = cudaq.qvector(2)
        h(control_qubit)
        cudaq.control(sub_kernel, control_qubit, targets, values)

    # Expected to succeed.
    counts = cudaq.sample(main_kernel, [0.5 + 0j, 0.25 + 0j])
    assert len(counts) > 0
310+
311+
312+
def test_control_float_list_complex_no_real_access():
    """
    Control case: float + list[complex] parameters, but the sub-kernel never
    touches `.real`. Confirms the regression requires the `.real` access.
    """

    @cudaq.kernel
    def sub_kernel(qubits: cudaq.qview, dt: float, values: list[complex]):
        rx(dt, qubits[0])

    @cudaq.kernel
    def main_kernel(dt: float, values: list[complex]):
        control_qubit = cudaq.qubit()
        targets = cudaq.qvector(2)
        h(control_qubit)
        cudaq.control(sub_kernel, control_qubit, targets, dt, values)

    counts = cudaq.sample(main_kernel, 0.1, [0.5 + 0j, 0.25 + 0j])
    assert len(counts) > 0
331+
332+
258333
# leave for gdb debugging
259334
if __name__ == "__main__":
260335
loc = os.path.abspath(__file__)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// ========================================================================== //
2+
// Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. //
3+
// All rights reserved. //
4+
// //
5+
// This source code and the accompanying materials are made available under //
6+
// the terms of the Apache License 2.0 which accompanies this distribution. //
7+
// ========================================================================== //
8+
9+
// Regression test for a bug where veq type specialization caused type mismatches
10+
// when creating control variants of functions that have inner calls to other
11+
// kernels expecting dynamic !quake.veq<?> types.
12+
//
13+
// The bug scenario:
14+
// 1. main_kernel allocates !quake.veq<2> and relaxes it to !quake.veq<?>
15+
// 2. A controlled apply to thunk_kernel with a constant argument triggers
16+
// constant propagation which would un-relax the veq type back to !quake.veq<2>
17+
// 3. The thunk calls inner_kernel which expects !quake.veq<?>
18+
// 4. When the .ctrl variant is created, the inner call has a type mismatch
19+
//
20+
// The fix: Don't specialize (un-relax) veq types when the function has inner
21+
// func.call operations to callees that still expect the dynamic !quake.veq<?> type.
22+
// An alternative fix would be to recursively specialize all callees in the
23+
// call tree, but that would increase code size.
24+
25+
// RUN: cudaq-opt --apply-op-specialization=constant-prop=1 %s | FileCheck %s
26+
27+
module {
  // Innermost kernel: its signature takes a dynamically sized veq.
  func.func @inner_kernel(%targets: !quake.veq<?>, %theta: f64) attributes {"cudaq-kernel"} {
    %q0 = quake.extract_ref %targets[0] : (!quake.veq<?>) -> !quake.ref
    quake.rx (%theta) %q0 : (f64, !quake.ref) -> ()
    return
  }

  // Forwarding kernel: the target of the controlled apply below. Its inner
  // call passes the veq through with the dynamic type.
  func.func @thunk_kernel(%targets: !quake.veq<?>, %theta: f64) attributes {"cudaq-kernel"} {
    func.call @inner_kernel(%targets, %theta) : (!quake.veq<?>, f64) -> ()
    return
  }

  // Entry point: applies @thunk_kernel under control with a constant angle.
  func.func @main_kernel() attributes {"cudaq-entrypoint", "cudaq-kernel"} {
    %theta = arith.constant 0.1 : f64
    %ctrl = quake.alloca !quake.ref
    %sized = quake.alloca !quake.veq<2>
    %dynamic = quake.relax_size %sized : (!quake.veq<2>) -> !quake.veq<?>
    quake.h %ctrl : (!quake.ref) -> ()
    // The constant operand is what triggers specialization of the apply.
    quake.apply @thunk_kernel [%ctrl] %dynamic, %theta : (!quake.ref, !quake.veq<?>, f64) -> ()
    return
  }
}
53+
54+
// Verify that the .ctrl variant is created and veq type is NOT specialized
55+
// (remains !quake.veq<?>) because thunk_kernel calls a kernel that takes !quake.veq<?>.
56+
57+
// CHECK-LABEL: func.func private @thunk_kernel.0.ctrl(
58+
// CHECK-SAME: %[[CTRL:.*]]: !quake.veq<?>,
59+
// CHECK-SAME: %[[QUBITS:.*]]: !quake.veq<?>)
60+
// The veq argument keeps its dynamic type, so no relax_size is needed
61+
// CHECK: call @inner_kernel(%[[QUBITS]],
62+
// CHECK-SAME: : (!quake.veq<?>, f64) -> ()
63+
// CHECK: return
64+
65+
// CHECK-LABEL: func.func @main_kernel()
66+
// CHECK: %[[ANCILLA:.*]] = quake.alloca !quake.ref
67+
// CHECK: %[[QREG:.*]] = quake.alloca !quake.veq<2>
68+
// CHECK: %[[RELAXED:.*]] = quake.relax_size %[[QREG]]
69+
// CHECK: quake.h %[[ANCILLA]]
70+
// CHECK: %[[CONCAT:.*]] = quake.concat %[[ANCILLA]]
71+
// CHECK: call @thunk_kernel.0.ctrl(%[[CONCAT]], %[[RELAXED]])

0 commit comments

Comments
 (0)